from cozmo_fsm import *

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy import linalg
import sklearn
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture

class RGBClusterer(StateMachineProgram):
    """State machine program that clusters camera pixels in RGB space.

    The machine enables color images and, once that completes, runs
    ``ClusterImage``, which fits a Gaussian mixture to the pixels of the
    robot's latest color image and plots the result with matplotlib.
    """

    class ClusterImage(StateNode):
        """Fit a Bayesian Gaussian mixture to the current image and plot it."""

        def plot_results(self, x, y, means, covariances, image_shape=(240, 320, 3)):
            """Show the cluster-colored image and the color-subspace scatter plots.

            Args:
                x: (N, 3) array of RGB pixel values.
                y: length-N integer array giving the cluster index of each pixel.
                means: (K, 3) array of cluster mean colors.
                covariances: sequence of K 3x3 covariance matrices.
                image_shape: (rows, cols, 3) shape used to fold the N pixels
                    back into an image. Defaults to the previously hard-coded
                    240x320 frame size.

            Side effects:
                Stores the per-pixel mean colors in ``self.parent.im1`` (an
                (N, 3) array) and the normalized image in ``self.parent.im2``,
                then calls ``plt.show()``.
            """
            plt.figure(1)
            # Replace every pixel by the mean color of its cluster; fancy
            # indexing does this in one step instead of a per-pixel loop.
            self.parent.im1 = means[y]
            self.parent.im2 = np.reshape(self.parent.im1 / 256, image_shape)
            plt.imshow(self.parent.im2)
            plt.figure(2)
            plt.subplot(2, 1, 1)
            self.plot_result(0, 1, x, y, means, covariances)
            plt.subplot(2, 1, 2)
            self.plot_result(2, 1, x, y, means, covariances)
            plt.show()

        def plot_result(self, idx0, idx1, x, y, means, covariances):
            """Scatter-plot pixels in a 2-D color subspace with cluster ellipses.

            Args:
                idx0: color channel index plotted on the horizontal axis.
                idx1: color channel index plotted on the vertical axis.
                    Together idx0, idx1 select R-vs-G or B-vs-G.
                x: (N, 3) array of RGB pixel values.
                y: length-N integer array of cluster indices.
                means: (K, 3) array of cluster mean colors.
                covariances: sequence of K 3x3 covariance matrices.
            """
            axes = plt.gca()
            for i, (mean, c) in enumerate(zip(means, covariances)):
                color = np.array([[int(mean[0]), int(mean[1]), int(mean[2])]]) / 256.0
                members = (y == i)
                plt.scatter(x[members, idx0], x[members, idx1], 1, color=color)
                # Marginal 2x2 covariance for the two plotted channels.
                covar = np.array([[c[idx0, idx0], c[idx0, idx1]],
                                  [c[idx1, idx0], c[idx1, idx1]]])
                v, w = linalg.eigh(covar)
                # Clamp tiny negative eigenvalues (round-off) so sqrt stays real.
                v = 2.0 * np.sqrt(2.) * np.sqrt(np.maximum(v, 0.0))
                # eigh returns eigenvectors as COLUMNS; column 0 pairs with
                # v[0], which is used as the ellipse width below.
                u = w[:, 0]
                # arctan2 avoids a division by zero for a vertical axis.
                angle = np.degrees(np.arctan2(u[1], u[0]))
                ec = (color[0, 1], color[0, 0], 0)
                m = np.array([mean[idx0], mean[idx1]])
                # angle must be passed by keyword: recent Matplotlib releases
                # no longer accept it positionally.
                ell = mpl.patches.Ellipse(m, v[0], v[1], angle=180. + angle,
                                          fc='None', edgecolor=ec)
                ell.set_clip_box(axes.bbox)
                axes.add_artist(ell)

        def start(self, event=None):
            """Cluster the latest color image's pixels and plot the clusters."""
            super().start(event)
            im = self.robot.world.latest_color_image.raw_image
            pixels = np.asarray(im)
            data = np.reshape(pixels, (-1, 3))
            gmm = BayesianGaussianMixture(n_components=7, covariance_type='full')
            gmm.fit(data)
            means = gmm.means_
            covariances = gmm.covariances_
            print('means:', means)
            y = gmm.predict(data)
            # Use the real frame size rather than assuming 240x320.
            self.plot_results(data, y, means, covariances,
                              image_shape=pixels.shape[:2] + (3,))

    def setup(self):
        """
            ColorImageEnabled(True) =C=> self.ClusterImage()
        """

        # Code generated by genfsm on Thu Mar 22 23:34:37 2018:

        colorimageenabled1 = ColorImageEnabled(True) .set_name("colorimageenabled1") .set_parent(self)
        clusterimage1 = self.ClusterImage() .set_name("clusterimage1") .set_parent(self)

        completiontrans1 = CompletionTrans() .set_name("completiontrans1")
        completiontrans1 .add_sources(colorimageenabled1) .add_destinations(clusterimage1)

        return self
