from cozmo_fsm import *

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy import linalg
import sklearn
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture

class RGBClusterer(StateMachineProgram):
    """Cluster the pixels of one camera color image in RGB space and plot them.

    The state machine enables color images; when that node completes, it
    transitions to ClusterImage. ClusterImage fits a Bayesian Gaussian
    mixture to the latest color image's pixels and displays the result
    with matplotlib.
    """

    class ClusterImage(StateNode):
        """Fit a BayesianGaussianMixture to image pixels, print stats, and plot."""

        # Upper bound on the number of mixture components passed to
        # BayesianGaussianMixture (n_components).
        max_components = 7

        def plot_results(self, x, y, means, covariances, shape=(240, 320, 3)):
            """Show the cluster-recolored image and three pairwise scatter plots.

            x: pixel data, one row of (R, G, B) values per pixel.
            y: cluster label for each row of x.
            means, covariances: per-component mixture parameters.
            shape: shape the recolored pixels are reshaped to for display.
                Defaults to the 240x320x3 size this code originally assumed.

            Side effects: stores the per-pixel cluster colors on the parent
            as ``im1`` (list of [R, G, B]) and the displayable image as
            ``im2`` (scaled by 1/256). Then opens figures 1 and 2 and calls
            plt.show().
            """
            means = np.asarray(means)
            labels = np.asarray(y)
            # Replace every pixel by the mean color of its cluster (vectorized).
            pixel_colors = means[labels, :3]
            self.parent.im1 = pixel_colors.tolist()
            self.parent.im2 = np.reshape(pixel_colors / 256, shape)
            plt.figure(1)
            plt.imshow(self.parent.im2)
            plt.figure(2)
            # Subplot rows: red-vs-green, blue-vs-green, red-vs-blue.
            for row, (idx0, idx1) in enumerate(((0, 1), (2, 1), (0, 2)), start=1):
                plt.subplot(3, 1, row)
                self.plot_result(idx0, idx1, x, y, means, covariances)
            plt.show()

        def plot_result(self, idx0, idx1, x, y, means, covariances):
            """Scatter pixels in one 2-D color subspace, with component ellipses.

            idx0, idx1: channel indices (0=red, 1=green, 2=blue) used for the
                horizontal and vertical axes of the current subplot.

            Each cluster's pixels are drawn in that cluster's mean color. Each
            component's covariance is projected onto the two channels and drawn
            as a yellow ellipse outline centred on the component mean.
            """
            axis = plt.gca()
            x = np.asarray(x)
            y = np.asarray(y)
            edge_color = (1, 1, 0)  # yellow ellipse outlines
            for i, mean in enumerate(means):
                color = np.array([[int(mean[0]), int(mean[1]), int(mean[2])]]) / 256.0
                members = (y == i)
                plt.scatter(x[members, idx0], x[members, idx1], 1, color=color)
                c = covariances[i]
                covar = np.array([[c[idx0, idx0], c[idx0, idx1]],
                                  [c[idx1, idx0], c[idx1, idx1]]])
                v, w = linalg.eigh(covar)
                # Axis lengths as in the scikit-learn GMM plotting example:
                # 2*sqrt(2) times the standard deviation along each eigenvector.
                v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
                # eigh returns eigenvectors as the COLUMNS of w; column 0 pairs
                # with v[0], which is used as the ellipse width below.
                u = w[:, 0] / linalg.norm(w[:, 0])
                # arctan2 handles u[0] == 0 (vertical axis) without dividing by 0.
                angle = np.degrees(np.arctan2(u[1], u[0]))
                m = np.array([mean[idx0], mean[idx1]])
                # angle must be passed by keyword: newer matplotlib versions
                # make Ellipse's angle parameter keyword-only.
                ell = mpl.patches.Ellipse(m, v[0], v[1], angle=180. + angle,
                                          fc='None', edgecolor=edge_color)
                ell.set_clip_box(axis.bbox)
                axis.add_artist(ell)
            components = ['red', 'green', 'blue']
            plt.title(components[idx0] + ' vs. ' + components[idx1])

        def start(self, event=None):
            """Fit the mixture to the latest color image, print stats, and plot."""
            super().start(event)
            n = self.max_components
            # NOTE(review): assumes latest_color_image is already populated with
            # an RGB image once ColorImageEnabled completes -- confirm.
            im = np.asarray(self.robot.world.latest_color_image.raw_image)
            data = np.reshape(im, (-1, 3))
            gmm = BayesianGaussianMixture(n_components=n, covariance_type='full')
            print('Got image.  Fitting data...')
            gmm.fit(data)
            means = gmm.means_
            covariances = gmm.covariances_
            print('Means:\n', means)
            y = gmm.predict(data)
            # Number of pixels assigned to each component (including empty ones).
            counts = np.bincount(y, minlength=n).tolist()
            print('Counts =', counts)
            self.plot_results(data, y, means, covariances, shape=im.shape)

    def setup(self):
        """
            ColorImageEnabled(True) =C=> self.ClusterImage()
        """
        
        # Code generated by genfsm on Fri Mar 23 05:18:41 2018:
        
        colorimageenabled1 = ColorImageEnabled(True) .set_name("colorimageenabled1") .set_parent(self)
        clusterimage1 = self.ClusterImage() .set_name("clusterimage1") .set_parent(self)
        
        completiontrans1 = CompletionTrans() .set_name("completiontrans1")
        completiontrans1 .add_sources(colorimageenabled1) .add_destinations(clusterimage1)
        
        return self
