from cozmo_fsm import *

import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import transforms
from torchvision.models import MobileNetV2
from labels import imagenet_labels

# Pretrained MobileNetV2 classifier. eval() returns the module itself, so the
# model can be built and switched to inference mode in a single expression.
model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.DEFAULT
).eval()

# Converts a camera frame into the normalized 224x224 tensor fed to the model.
preprocess = transforms.Compose(
    [
        # Resize step is left disabled; the frame is only center-cropped.
        #transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean / standard deviation used to normalize the tensor.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class MobileNet(StateMachineProgram):
  """Classify Cozmo camera images with MobileNetV2 and speak the best label.

  Each text message ('tm') runs the classifier on the latest camera image,
  prints the five most probable ImageNet labels, and speaks the top one.
  """

  def start(self):
    """Start the program, print usage help, and enable color camera images."""
    super().start()
    print("Type 'tm' to take a picture and classify it.")
    self.robot.camera.color_image_enabled = True

  def user_annotate(self,image):
    """Boost the green channel and outline the region the classifier sees.

    Args:
      image: array indexed as [row, column, channel]; modified in place.
        NOTE(review): the rectangle is centered on (320, 240), so this
        presumably expects a 640x480 view, i.e. the camera frame at 2x
        scale, where the 224-pixel crop spans 448 pixels -- confirm.

    Returns:
      The same image array, annotated.
    """
    half_box = 224
    # Boost green to compensate for Cozmo camera idiosyncrasies.  Clamping to
    # 231 first keeps the boosted value (231 * 1.10 = 254.1) below 256,
    # presumably so an 8-bit channel cannot overflow.
    image[:,:,1] = np.minimum(231,image[:,:,1]) * 1.10
    cv2.rectangle(image, (320-half_box, 240-half_box), (320+half_box, 240+half_box), (255,255,0), 2)
    return image

  class MobileNetClassify(StateNode):
    """Classify the latest camera image and post the best label as data."""

    def start(self,event=None):
      """Run the classifier once and post the top label.

      Displays the preprocessed network input with matplotlib, prints the
      five most probable labels with their probabilities, and posts the
      first comma-separated name of the best label via post_data().
      """
      super().start(event)
      image = self.robot.world.latest_image.raw_image
      input_tensor = preprocess(image)

      # Show what the network actually sees.  The normalized tensor is
      # rescaled to [0, 1] for display; a uniform frame has zero range, so
      # show it as black instead of dividing by zero.
      im = input_tensor.detach().numpy().transpose((1,2,0))
      lowest = im.min()
      span = im.max() - lowest
      if span > 0:
        im = (im - lowest) / span
      else:
        im = np.zeros_like(im)
      plt.imshow(im)
      plt.axis('off')
      plt.pause(0.001)

      input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
      # Inference only: skip autograd bookkeeping for the forward pass.
      with torch.no_grad():
        output = model(input_batch)
        probs = torch.nn.functional.softmax(output[0], dim=0)

      # topk returns the best entries in descending order of probability.
      top_probs, top_indices = torch.topk(probs, min(5, probs.numel()))
      top5 = list(zip(top_probs.tolist(), top_indices.tolist()))
      for prob, index in top5:
        print('%7.5f  %s' % (prob, imagenet_labels[index]))

      # Labels may list several synonyms separated by commas; use the first.
      raw_label = imagenet_labels[top5[0][1]]
      label = raw_label.split(',', 1)[0]
      self.post_data(label)


  def setup(self):
      #     start: StateNode() =TM=> mobile
      # 
      #     mobile: self.MobileNetClassify() =SayData=> Say() =TM=> mobile
      
      # Code generated by genfsm on Mon Mar 13 21:53:24 2023:
      
      start = StateNode() .set_name("start") .set_parent(self)
      mobile = self.MobileNetClassify() .set_name("mobile") .set_parent(self)
      say1 = Say() .set_name("say1") .set_parent(self)
      
      textmsgtrans1 = TextMsgTrans() .set_name("textmsgtrans1")
      textmsgtrans1 .add_sources(start) .add_destinations(mobile)
      
      saydatatrans1 = SayDataTrans() .set_name("saydatatrans1")
      saydatatrans1 .add_sources(mobile) .add_destinations(say1)
      
      textmsgtrans2 = TextMsgTrans() .set_name("textmsgtrans2")
      textmsgtrans2 .add_sources(say1) .add_destinations(mobile)
      
      return self
