from cozmo_fsm import *
from CubeDetector1 import CubeDetector1
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import cv2
from math import *

class PartialCube(StateMachineProgram):
    """State machine that checks each half of the camera image for a cube.

    After a short delay the ``showImg`` node classifies a left and a right
    crop of the most recent grayscale frame, speaks the outcome, and posts a
    data value that selects the next step:

    * ``0`` -- the left crop was classified as containing a cube; after one
      second the ``TurnL`` node runs.
    * ``1`` -- only the right crop was classified as containing a cube; after
      one second the ``TurnR`` node runs.
    * ``2`` -- neither crop was classified as containing a cube; the machine
      moves to an idle node.
    """

    def __init__(self):
        super().__init__(cam_viewer=False)

    def user_image(self, image, gray):
        """Store the grayscale frame on the robot for ``PartialImage`` to read.

        Args:
            image: The camera frame supplied by the caller (unused here).
            gray: The grayscale frame; saved as ``self.robot.myimage``.
        """
        self.robot.myimage = gray

    class TurnL(Turn):
        """Turn by ``+0.07 * pi`` radians (presumably a left turn -- confirm
        the sign convention against the ``Turn`` node)."""

        def start(self, event=None):
            """Set the turn angle, then start the underlying ``Turn`` node.

            Args:
                event: The triggering event, if any. The payload of a
                    ``DataEvent`` is printed for debugging.
            """
            if self.running:
                return
            if isinstance(event, DataEvent):
                print("event" + str(event.data))
            self.angle = Angle(math.pi * 0.07)
            super().start(event)

    class TurnR(Turn):
        """Turn by ``-0.07 * pi`` radians (presumably a right turn -- confirm
        the sign convention against the ``Turn`` node)."""

        def start(self, event=None):
            """Set the turn angle, then start the underlying ``Turn`` node.

            Args:
                event: The triggering event, if any. The payload of a
                    ``DataEvent`` is printed for debugging.
            """
            if self.running:
                return
            if isinstance(event, DataEvent):
                print("event" + str(event.data))
            self.angle = Angle(math.pi * -0.07)
            super().start(event)

    class PartialImage(Say):
        """Classify the left and right image crops, speak and post the result."""

        # Arguments passed to CubeDetector1; they must match the saved weights.
        _CUBE_IN_DIM = (155, 235)
        _CUBE_NKERNELS1 = 32
        _CUBE_NKERNELS2 = 12
        _CUBE_POOL1 = 4
        _CUBE_POOL2 = 4
        _CUBE_WEIGHTS_PATH = './cuberec-saved.pt'

        # Network shared by every call; built on first use by _cube_detector().
        _cube_model = None

        def _cube_detector(self):
            """Return the cube classifier, building and loading it only once."""
            owner = type(self)
            if owner._cube_model is None:
                device = torch.device('cpu')
                model = CubeDetector1(
                    self._CUBE_IN_DIM,
                    self._CUBE_NKERNELS1,
                    self._CUBE_NKERNELS2,
                    self._CUBE_POOL1,
                    self._CUBE_POOL2,
                ).to(device)
                # map_location lets a checkpoint saved on a GPU load on the CPU.
                model.load_state_dict(
                    torch.load(self._CUBE_WEIGHTS_PATH, map_location=device))
                # Inference only: put dropout / batch-norm layers in eval mode.
                model.eval()
                owner._cube_model = model
            return owner._cube_model

        def _classify(self, patch, label):
            """Run the classifier on one crop and return the winning class.

            Args:
                patch: NumPy array holding exactly 155 * 235 values.
                label: Prefix printed in front of the raw network output.

            Returns:
                The index of the largest network output, as a string.
            """
            model = self._cube_detector()
            # NOTE(review): the crops are sliced as 235 rows x 155 columns but
            # are reshaped (not transposed) to (1, 1, 155, 235). Presumably
            # this matches how the training data was prepared -- confirm.
            batch = (torch.from_numpy(patch)
                     .reshape(1, 1, *self._CUBE_IN_DIM)
                     .float()
                     .to(torch.device('cpu')))
            with torch.no_grad():
                result = model(batch)
            # Prints the raw per-class outputs of the network.
            print(label, result)
            return str(result.argmax().item())

        def left_image(self, image):
            """Classify the left crop of ``image``.

            The crop is rows ``[:235]`` and columns ``[:155]``, mirrored
            horizontally before classification (presumably so that it looks
            like the right-hand crops -- confirm against the training data).

            Returns:
                The predicted class index as a string.
            """
            # fliplr returns a negatively-strided view, which torch.from_numpy
            # cannot wrap, so take a copy.
            patch = np.fliplr(image[:235, :155]).copy()
            # To inspect the crop, uncomment: cv2.imwrite("img1.jpg", patch)
            return self._classify(patch, "Left image")

        def right_image(self, image):
            """Classify the right crop of ``image``.

            The crop is rows ``[5:]`` and columns ``[165:]``.

            Returns:
                The predicted class index as a string.
            """
            patch = image[5:, 165:]
            # To inspect the frame, uncomment: cv2.imwrite('orig.jpg', image)
            return self._classify(patch, "right image")

        def start(self, event=None):
            """Classify both crops, speak the outcome and post a data value.

            Class index ``'0'`` is treated as "cube present". Posts ``0`` when
            the left crop has a cube, ``1`` when only the right crop has one,
            and ``2`` when neither does.

            Args:
                event: The triggering event, forwarded to ``Say.start``.
            """
            frame = self.robot.myimage
            left_class = self.left_image(frame)
            # The right crop is always classified, even when the left crop
            # already matched, so both predictions are printed.
            right_class = self.right_image(frame)
            if left_class == '0':
                # left half has cube
                self.text = "cube"
                data = 0
            elif right_class == '0':
                # right half has cube
                self.text = "cube"
                data = 1
            else:
                # no cube
                self.text = "no cube"
                data = 2
            super().start(event)
            self.post_data(data)

    def setup(self):
        """
           StateNode() =T(0.5)=> showImg
           showImg: self.PartialImage()
           showImg =D(0)=> StateNode() =T(1)=> self.TurnL()
           showImg =D(1)=>  StateNode() =T(1)=> self.TurnR()
           showImg =D(2) => StateNode()
        """

        # Code generated by genfsm on Thu May  7 13:49:56 2020:

        statenode1 = StateNode() .set_name("statenode1") .set_parent(self)
        showImg = self.PartialImage() .set_name("showImg") .set_parent(self)
        statenode2 = StateNode() .set_name("statenode2") .set_parent(self)
        turnl1 = self.TurnL() .set_name("turnl1") .set_parent(self)
        statenode3 = StateNode() .set_name("statenode3") .set_parent(self)
        turnr1 = self.TurnR() .set_name("turnr1") .set_parent(self)
        statenode4 = StateNode() .set_name("statenode4") .set_parent(self)

        timertrans1 = TimerTrans(0.5) .set_name("timertrans1")
        timertrans1 .add_sources(statenode1) .add_destinations(showImg)

        datatrans1 = DataTrans(0) .set_name("datatrans1")
        datatrans1 .add_sources(showImg) .add_destinations(statenode2)

        timertrans2 = TimerTrans(1) .set_name("timertrans2")
        timertrans2 .add_sources(statenode2) .add_destinations(turnl1)

        datatrans2 = DataTrans(1) .set_name("datatrans2")
        datatrans2 .add_sources(showImg) .add_destinations(statenode3)

        timertrans3 = TimerTrans(1) .set_name("timertrans3")
        timertrans3 .add_sources(statenode3) .add_destinations(turnr1)

        datatrans3 = DataTrans(2) .set_name("datatrans3")
        datatrans3 .add_sources(showImg) .add_destinations(statenode4)

        return self
