from cozmo_fsm import *
import numpy as np
import math

import cozmo
from cozmo.nav_memory_map import NodeContentTypes

class NextLocation(StateNode):
    """Choose the next exploration target and post it as a ``Pose`` DataEvent.

    Normally the target is the centre of an unexplored leaf of the robot's
    navigation memory map. Eligible leaves have content ``Unknown``, size > 20,
    and a centre not already in ``visitedLocs``. They are ranked by ``mode``:

    * ``"Closest"`` -- lowest ``closestHeuristic`` (near the robot);
    * anything else -- lowest ``defaultHeuristic`` (near known objects).

    When entered with ``DataEvent("GoalCollide")`` the previous target is
    instead shifted away from the nearest obstacle (see ``goal_collision``).
    """

    def __init__(self, mode="Default"):
        self.mode = mode
        self.visitedLocs = []   # (x, y) targets already posted; excluded by tree_crawl
        self.previous_x = None  # last posted target, reused by goal_collision
        self.previous_y = None
        super().__init__()

    def start(self, event=None):
        if self.running: return
        super().start(event)
        if isinstance(event, DataEvent) and event.data == "GoalCollide":
            self.goal_collision()
        else:
            # Every other entry, including a DataEvent carrying other data,
            # picks a fresh target so the node always posts something.
            self._explore_next()

    def _explore_next(self):
        """Post the best-ranked unexplored map node, or a nearby fallback point."""
        candidates = []
        self.tree_crawl(self.robot.world.nav_memory_map.root_node, candidates)
        if candidates:
            if self.mode == "Closest":
                heuristic = self.closestHeuristic
            else:
                heuristic = self.defaultHeuristic
            best = min(candidates, key=heuristic)
            x, y = best.center.x, best.center.y
        else:
            # Nothing unexplored is known: aim 100 units along the world x-axis.
            position = self.robot.pose.position
            x, y = position.x + 100, position.y
        self.setDest(x, y)

    def setDest(self, x, y):
        """Record (x, y) as the current target, mark it in the world map, and post it.

        The target is stored in ``previous_x``/``previous_y``, appended to
        ``visitedLocs``, and drawn as a ``ChipObj`` under the world-map key
        ``"Dest"``. It is then posted as a ``Pose`` whose ``angle_z`` is NaN
        degrees (presumably "no required final heading" -- TODO confirm).
        """
        self.previous_x = x
        self.previous_y = y
        self.visitedLocs.append((x, y))

        self.robot.world.world_map.objects["Dest"] = (ChipObj(id = "dest", x = x, y = y, z = 20, radius = 20))

        dest = Pose(x = x, y = y, z = 0, angle_z = degrees(math.nan))
        self.post_data(dest)

    def _object_locations(self):
        """Return ``(key, xy)`` pairs for every positioned world-map object.

        An object with an SDK counterpart uses its SDK pose and is skipped when
        that pose is ``None``. An object where reading ``sdk_obj.pose`` raises
        AttributeError (no ``sdk_obj``, or ``sdk_obj`` is None) falls back to
        its map ``x``/``y``.
        """
        locations = []
        # Snapshot the items so a concurrent map update cannot break iteration.
        for key, obj in list(self.robot.world.world_map.objects.items()):
            try:
                pose = obj.sdk_obj.pose
                if pose is None:
                    continue
                xy = (pose.position.x, pose.position.y)
            except AttributeError:
                xy = (obj.x, obj.y)
            locations.append((key, np.array(xy, dtype=float)))
        return locations

    def defaultHeuristic(self, node):
        """Score ``node`` by mean distance to known objects minus half its size.

        Lower is better. With no known objects the mean is taken as 0. The
        ``"Dest"`` marker written by ``setDest`` is counted like any other
        object.
        """
        nodeLoc = np.array((node.center.x, node.center.y), dtype=float)
        distances = [np.linalg.norm(nodeLoc - loc) for _, loc in self._object_locations()]
        mean = sum(distances) / len(distances) if distances else 0
        return mean - (node.size / 2)

    def closestHeuristic(self, node):
        """Score ``node`` by its distance from the robot minus its full size. Lower is better."""
        nodeLoc = np.array((node.center.x, node.center.y), dtype=float)
        position = self.robot.pose.position
        robotLoc = np.array((position.x, position.y), dtype=float)
        return np.linalg.norm(nodeLoc - robotLoc) - node.size

    def goal_collision(self):
        """Shift the previous target away from its nearest obstacle and re-post it.

        The nearest world-map object other than the ``"Dest"`` marker is found.
        The marker is excluded because it sits exactly on the previous target.
        The target is then moved along one axis by twice its offset from that
        object. With no previous target or no other object, a fresh
        exploration target is chosen instead.
        """
        if self.previous_x is None or self.previous_y is None:
            self._explore_next()
            return
        # float dtype so the in-place shift below never truncates.
        nodeLoc = np.array((self.previous_x, self.previous_y), dtype=float)
        obstacles = [loc for key, loc in self._object_locations() if key != "Dest"]
        if not obstacles:
            self._explore_next()
            return
        nearest = min(obstacles, key=lambda loc: np.linalg.norm(nodeLoc - loc))

        xdiff = nodeLoc[0] - nearest[0]
        ydiff = nodeLoc[1] - nearest[1]
        print("X locations", nodeLoc[0], nearest[0], xdiff)
        print("Y locations", nodeLoc[1], nearest[1], ydiff)

        # NOTE(review): signed comparison kept as-is; abs() may have been
        # intended here -- confirm before changing.
        if xdiff < ydiff:
            nodeLoc[0] += xdiff * 2
        else:
            nodeLoc[1] += ydiff * 2

        print("New locations", nodeLoc[0], nodeLoc[1])
        self.setDest(nodeLoc[0], nodeLoc[1])

    def tree_crawl(self, node, locs):
        """Append to ``locs`` every leaf under ``node`` that is worth exploring.

        A leaf qualifies when its content is ``NodeContentTypes.Unknown``, its
        size exceeds 20, and its centre is not already in ``visitedLocs``.
        """
        if node.children is not None:
            for child in node.children:
                self.tree_crawl(child, locs)
        elif (node.content == NodeContentTypes.Unknown
              and node.size > 20
              and (node.center.x, node.center.y) not in self.visitedLocs):
            locs.append(node)

class GoToPose(PilotToPose):
    """``PilotToPose`` whose destination can arrive on an incoming ``DataEvent``.

    On entry via a DataEvent, ``event.data`` becomes ``target_pose`` before the
    pilot starts. In this program that is the ``Pose`` posted by
    ``NextLocation``. Any other event leaves ``target_pose`` untouched.
    """
    def start(self, event=None):
        if not self.running:
            if isinstance(event, DataEvent):
                self.target_pose = event.data
            super().start(event)

class StartCollision(StateNode):
    """Compute a turn pointing the robot away from the nearest known object.

    Used after the pilot reports ``StartCollides``. Posts a DataEvent whose
    data is a relative turn (``degrees``) for the downstream node, then posts
    completion.
    """
    def start(self, event = None):
      if self.running: return
      super().start(event)
      self.start_collision()
      self.post_completion()

    def _object_locations(self):
      """Return the (x, y) numpy location of every positioned world-map object.

      An object with an SDK counterpart uses its SDK pose and is skipped when
      that pose is ``None``. An object where reading ``sdk_obj.pose`` raises
      AttributeError falls back to its map ``x``/``y``.
      """
      locations = []
      for obj in list(self.robot.world.world_map.objects.values()):
        try:
          pose = obj.sdk_obj.pose
          if pose is None:
            continue
          xy = (pose.position.x, pose.position.y)
        except AttributeError:
          xy = (obj.x, obj.y)
        locations.append(np.array(xy, dtype=float))
      return locations

    def start_collision(self):
      """Post the relative turn (``degrees``) that faces away from the nearest object.

      The turn is wrapped into (-180, 180], the shortest rotation reaching the
      same final heading. When no object location is known, 180 is posted
      (turn around) rather than raising IndexError.
      """
      position = self.robot.pose.position
      here = np.array((position.x, position.y), dtype=float)

      locations = self._object_locations()
      if not locations:
        self.post_data(degrees(180))
        return
      nearest = min(locations, key=lambda loc: np.linalg.norm(here - loc))

      xdiff, ydiff = nearest - here
      bearing = math.degrees(math.atan2(ydiff, xdiff))
      heading = self.robot.pose.rotation.angle_z.degrees
      # Facing directly away from the object means heading toward bearing + 180.
      turn = bearing - heading + 180
      turn = 180 - (180 - turn) % 360  # wrap into (-180, 180]
      self.post_data(degrees(turn))

class GoalCollision(StateNode):
    """Relay a pilot goal collision as ``DataEvent("GoalCollide")``.

    ``NextLocation`` checks for exactly this payload to decide whether to
    shift its previous target instead of choosing a new one.
    """
    def start(self, event = None):
      if not self.running:
        super().start(event)
        self.post_data("GoalCollide")

class Explore(StateMachineProgram):
    """Top-level program that repeatedly picks map targets and pilots to them.

    The machine (spec in ``setup``'s docstring) sets the head angle to 0 and
    runs ``DriveTurn(360, 15)``. It then loops: ``NextLocation`` (findNext)
    posts a target, ``GoToPose`` (pose) pilots there, and another
    ``DriveTurn(360, 15)`` runs on arrival. Pilot ``StartCollides`` goes
    through ``StartCollision`` -> ``DriveTurn`` -> ``Forward(50)``, and
    ``GoalCollides`` goes through ``GoalCollision``. Both, like pilot
    failures, return to findNext.
    """
    def __init__(self):
        # Enable the world-map and path viewers for this program.
        super().__init__(worldmap_viewer=True, path_viewer=True)

    def setup(self):
        """
            SetHeadAngle(0) =C=> DriveTurn(360, 15) =C=> findNext
    
            findNext: NextLocation()
            findNext =D=> pose
            findNext =F=> findNext
    
            pose: GoToPose(max_iter = 4000)
            pose =C=> DriveTurn(360, 15) =C=> findNext
            pose =PILOT(StartCollides)=> StartCollision() =D=> DriveTurn() =C=> Forward(50) =C=> findNext
            pose =PILOT(GoalCollides)=> GoalCollision() =D=> findNext
            pose =F=> findNext
        """
        
        # Code generated by genfsm on Fri May  8 21:47:20 2020:
        # NOTE: generated wiring -- presumably regenerated by genfsm from an
        # .fsm source; prefer editing that source over hand-editing below.
        
        # States (node names match the labels in the spec above).
        setheadangle1 = SetHeadAngle(0) .set_name("setheadangle1") .set_parent(self)
        driveturn1 = DriveTurn(360, 15) .set_name("driveturn1") .set_parent(self)
        findNext = NextLocation() .set_name("findNext") .set_parent(self)
        pose = GoToPose(max_iter = 4000) .set_name("pose") .set_parent(self)
        driveturn2 = DriveTurn(360, 15) .set_name("driveturn2") .set_parent(self)
        startcollision1 = StartCollision() .set_name("startcollision1") .set_parent(self)
        # No constructor args: the turn angle arrives via DataEvent from startcollision1.
        driveturn3 = DriveTurn() .set_name("driveturn3") .set_parent(self)
        forward1 = Forward(50) .set_name("forward1") .set_parent(self)
        goalcollision1 = GoalCollision() .set_name("goalcollision1") .set_parent(self)
        
        # Transitions, one per arrow in the spec.
        completiontrans1 = CompletionTrans() .set_name("completiontrans1")
        completiontrans1 .add_sources(setheadangle1) .add_destinations(driveturn1)
        
        completiontrans2 = CompletionTrans() .set_name("completiontrans2")
        completiontrans2 .add_sources(driveturn1) .add_destinations(findNext)
        
        # findNext posts the target Pose; GoToPose reads it from the DataEvent.
        datatrans1 = DataTrans() .set_name("datatrans1")
        datatrans1 .add_sources(findNext) .add_destinations(pose)
        
        failuretrans1 = FailureTrans() .set_name("failuretrans1")
        failuretrans1 .add_sources(findNext) .add_destinations(findNext)
        
        completiontrans3 = CompletionTrans() .set_name("completiontrans3")
        completiontrans3 .add_sources(pose) .add_destinations(driveturn2)
        
        completiontrans4 = CompletionTrans() .set_name("completiontrans4")
        completiontrans4 .add_sources(driveturn2) .add_destinations(findNext)
        
        pilottrans1 = PilotTrans(StartCollides) .set_name("pilottrans1")
        pilottrans1 .add_sources(pose) .add_destinations(startcollision1)
        
        datatrans2 = DataTrans() .set_name("datatrans2")
        datatrans2 .add_sources(startcollision1) .add_destinations(driveturn3)
        
        completiontrans5 = CompletionTrans() .set_name("completiontrans5")
        completiontrans5 .add_sources(driveturn3) .add_destinations(forward1)
        
        completiontrans6 = CompletionTrans() .set_name("completiontrans6")
        completiontrans6 .add_sources(forward1) .add_destinations(findNext)
        
        pilottrans2 = PilotTrans(GoalCollides) .set_name("pilottrans2")
        pilottrans2 .add_sources(pose) .add_destinations(goalcollision1)
        
        # Carries "GoalCollide" so findNext shifts its previous target.
        datatrans3 = DataTrans() .set_name("datatrans3")
        datatrans3 .add_sources(goalcollision1) .add_destinations(findNext)
        
        failuretrans2 = FailureTrans() .set_name("failuretrans2")
        failuretrans2 .add_sources(pose) .add_destinations(findNext)
        
        return self
