#include "Behaviors/StateMachine.h"
#include "Behaviors/Services/CameraBehavior.h"
#include "Events/TextMsgEvent.h"

#include <wait.h>
#include <unistd.h>

#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <vector>


$nodeclass SurfTest : VisualRoutinesStateNode {

  // Signals posted by WaitForCommand (see $setupmachine): `testing` runs
  // Clear -> SaveImage -> Fork, `cont` simply re-enters WaitForCommand.
  enum command {
    testing = 0,
    cont = 1
  };

  // Settings shared with the child nodes via $reference: file paths and
  // image geometry.
  $provide std::string tempSavePath;
  $provide std::string tempReturnPath;
  $provide std::string keypointDataPath;
  $provide std::string trainDataPath;
  $provide std::string objectDataPath;
  $provide int picWidth;
  $provide int picHeight;
  $provide int factor;

  $provide std::string currentObject;
  $provide bool findMatches;
  
  // Fills in every setting provided by SurfTest (file paths, image geometry
  // and the default query), then completes immediately.
  $nodeclass Initialize : StateNode : doStart {
    $reference SurfTest::tempSavePath;
    $reference SurfTest::tempReturnPath;
    $reference SurfTest::keypointDataPath;
    $reference SurfTest::trainDataPath;
    $reference SurfTest::objectDataPath;
    $reference SurfTest::picWidth;
    $reference SurfTest::picHeight;
    $reference SurfTest::factor;
    $reference SurfTest::currentObject;
    $reference SurfTest::findMatches;

    // Where SaveImage writes the raw frame, plus the result directory and
    // database files handed to the external find_object-test program.
    tempSavePath     = "/tmp/image.raw";
    tempReturnPath   = "surf_data/";
    objectDataPath   = "surf_data/objectDatabase.dat";
    trainDataPath    = "surf_data/trainData.dat";
    keypointDataPath = "surf_data/keypointData.dat";

    // Base frame size and the upsampling factor (used by Fork's box checks).
    picWidth  = 640;
    picHeight = 480;
    factor    = 2;

    // Default query: "-" selects every object; no prefix ("name*") matching.
    currentObject = "-";
    findMatches   = false;

    postStateCompletion();
  }

  // node that saves the image into a temporary file
  $nodeclass SaveImage : VisualRoutinesStateNode {
		
    virtual void doStart()
    {

      $reference SurfTest::tempSavePath;

      NEW_SKETCH(camY, uchar, sketchFromRawY());
      int width = camY->getWidth();
      int height = camY->getHeight();
      size_t buffsize = sizeof(char) * height * width;
      char *imgBuf = (char *) malloc(buffsize);
      camY->savePixels(imgBuf,buffsize);
      int factor = 2;
      char *destBuf = resample(imgBuf, width, height, factor);
      ofstream myfile;
      myfile.open(tempSavePath.c_str(), fstream::out | ios::binary);
      myfile.write(destBuf, buffsize * factor * factor);
      myfile.close();
      cout << "saved to " << tempSavePath << ". w: " << width << ", h: " 
	   << height  << endl;
      postStateCompletion();
    }
		
    // Nearest-neighbour upsampling: every source pixel is replicated into a
    // factor x factor block of a new (width*factor) x (height*factor) image.
    // The result is malloc'd; the caller owns it and must free() it.
    char * resample(char * imgBuf, int width, int height, int factor) {
      const int outWidth = width * factor;
      char *out = static_cast<char *>(
          malloc(sizeof(char) * height * factor * width * factor));

      for (int row = 0; row < height; ++row) {
        for (int col = 0; col < width; ++col) {
          const char pixel = imgBuf[xy2pos(row, col, width)];
          // Top-left corner of this pixel's block in the output image.
          char *block = out + xy2pos(row * factor, col * factor, outWidth);
          for (int dy = 0; dy < factor; ++dy)
            for (int dx = 0; dx < factor; ++dx)
              block[dy * outWidth + dx] = pixel;
        }
      }
      return out;
    }

    // Rebuilds a width x height 8-bit image by averaging neighbours and
    // returns it in a new malloc'd buffer of the same size. The caller must
    // free() it; NULL is returned if allocation fails. It is not called
    // anywhere in this behavior.
    //  - Even rows and the last row: even columns are copied; odd columns
    //    become the mean of their left and right neighbours.
    //  - Other (odd) rows: even columns become the mean of the pixels above
    //    and below; odd columns the mean of the up-left and down-right pixels.
    // A "right" neighbour past the last column falls back to the column on
    // the left, so no read leaves its row or the buffer. Pixels are averaged
    // as unsigned bytes (signed char arithmetic broke values above 127).
    // `factor` is kept for interface compatibility and is ignored.
    char *interpolate(char *imgBuf, int width, int height, int factor) {
      (void) factor;
      char *destBuf = static_cast<char *>(malloc(sizeof(char) * height * width));
      if (destBuf == NULL)
        return NULL;
      const unsigned char *src = reinterpret_cast<const unsigned char *>(imgBuf);

      for (int row = 0; row < height; row++) {
        const bool interpolateRight = (row == height - 1 || row % 2 == 0);
        const int rowVal = row * width;
        for (int col = 0; col < width; col++) {
          // Only used for odd columns, where col - 1 >= 0 always holds.
          const int right = (col + 1 < width) ? col + 1 : col - 1;
          int value;
          if (interpolateRight) {
            if (col % 2 == 1)
              value = (src[rowVal + col - 1] + src[rowVal + right]) / 2;
            else
              value = src[rowVal + col];
          } else {
            // Odd rows here satisfy 1 <= row <= height - 2, so the rows
            // above and below both exist.
            if (col % 2 == 0)
              value = (src[rowVal - width + col] + src[rowVal + width + col]) / 2;
            else
              value = (src[rowVal - width + col - 1] + src[rowVal + width + right]) / 2;
          }
          destBuf[rowVal + col] = static_cast<char>(value);
        }
      }
      return destBuf;
    }

    // Row-major buffer index of row `x`, column `y` in an image `width` wide.
    int xy2pos(int x, int y, int width) {
      const int rowStart = x * width;
      return rowStart + y;
    }
		
  }

  // node that calls the test function in Find-Object
  $nodeclass Fork : VisualRoutinesStateNode  {
	  
    virtual void doStart() {
      $reference SurfTest::tempSavePath;
      $reference SurfTest::tempReturnPath;
      $reference SurfTest::trainDataPath;
      $reference SurfTest::keypointDataPath;
      $reference SurfTest::objectDataPath;

      $reference SurfTest::currentObject;

			
      cout << "Launching SURF engine..." << endl;
			
      pid_t child_id = fork();

      // if in the child, execute the test program
      if ( child_id == 0 )
	{
	  char* tekrootval = getenv("TEKKOTSU_ROOT");
	  //    std::string const tekkotsuRoot = tekrootval==NULL 
	  //	? "/usr/local/Tekkotsu" : std::string(tekrootval);
	  std::string const tekkotsuRoot =  "";
	  std::string const testProgram = "find_object-test";
	  std::string const testProgramPath = tekkotsuRoot + testProgram;
	  execl(testProgramPath.c_str(), testProgram.c_str(), 
		tempSavePath.c_str(),
		objectDataPath.c_str(),
		trainDataPath.c_str(),
		keypointDataPath.c_str(), 
		"1280", "960",
		tempReturnPath.c_str(),
		currentObject.c_str(), NULL);

	  // If we get here, the execlp() failed
	  std::cerr << "ERROR: failed to launch Mary server from " 
		    << testProgramPath << std::endl 
		    << "Check that TEKKOTSU_ROOT is set properly." << std::endl;
	  _exit(0);
	}

      // if in the parent, wait for the child to finish, then display the results
      else {
	int status;
	pid_t done = wait(&status);
	readObjects(objectDataPath, tempReturnPath, currentObject);
	postStateCompletion();
      }
    }

    // read the relevant objects from the object database
    // Reads the tab-separated object database (name \t id ...) and, for each
    // selected entry with a non-empty id, draws the result file
    // "<tempReturnPath>/<id>.txt" via readObjectFile(). An entry is selected when
    //   - object is "-" (every entry), or
    //   - its name equals object exactly, or
    //   - findMatches is set and its name starts with object minus its last
    //     character (the trailing '*' of a "name*" query).
    void readObjects(const std::string &objectDataPath,
                     const std::string &tempReturnPath,
                     const std::string &object) {
      $reference SurfTest::findMatches;

      ifstream objectDB(objectDataPath.c_str());
      if (!objectDB.is_open()) {
        // Previously a failed open made the eof()-driven loop spin forever.
        cout << "Could not open object database " << objectDataPath << endl;
        return;
      }

      // Wildcard prefix: the query with its final character removed.
      const std::string prefix =
        object.empty() ? object : object.substr(0, object.size() - 1);

      std::string row;
      // Looping on getline() stops on end-of-file and on read errors alike.
      while (getline(objectDB, row)) {
        // At most 3 tab-separated fields; extras are ignored instead of
        // overflowing the array.
        std::istringstream iss(row);
        std::string values[3];
        std::string token;
        for (int i = 0; i < 3 && getline(iss, token, '\t'); i++)
          values[i] = token;

        const bool selected = object == "-"
          || object == values[0]
          || (findMatches && values[0].compare(0, prefix.size(), prefix) == 0);
        const std::string &objectId = values[1];
        if (!selected || objectId.empty())
          continue;

        cout << "Object id: " << objectId << endl;
        // std::string replaces a fixed 1000-byte buffer filled by strcpy/strcat.
        std::string fileName = tempReturnPath + "/" + objectId + ".txt";
        // readObjectFile() takes a non-const char*; fileName is never empty.
        readObjectFile(&fileName[0], values[0]);
      }
    }

    bool checkPoints(std::vector<float> points)
    {
      $reference SurfTest::picWidth;
      $reference SurfTest::picHeight;
      $reference SurfTest::factor;
			
      for(int count = 0; count < points.size(); count++)
	{
	  if(count % 2 == 0)
	    {
	      if(points.at(count) < - picWidth * factor/2 
		 || points.at(count) > 3* picWidth * factor /2)
		return false;
	    }
	  else
	    if(points.at(count) < -picHeight * factor/2 
	       || points.at(count) > 3*picHeight * factor/2)
	      return false;
	}
      if(!checkIntersections(points))
	return false;
      return true;
    }

    bool checkIntersections(std::vector<float> points)
    {
      Point *pt1 = new Point(points.at(0), points.at(1));
      Point *pt2 = new Point(points.at(2), points.at(3));
      Point *pt3 = new Point(points.at(4), points.at(5));
      Point *pt4 = new Point(points.at(6), points.at(7));

      if(!ptsOnSameSide(*pt3, *pt4, points.at(0), points.at(2), points.at(1),
			points.at(3)))
	return false;
      if(!ptsOnSameSide(*pt4, *pt1, points.at(2), points.at(4), points.at(3),
			points.at(5)))
	return false;
      if(!ptsOnSameSide(*pt1, *pt2, points.at(4), points.at(6), points.at(5),
			points.at(7)))
	return false;
      if(!ptsOnSameSide(*pt2, *pt3, points.at(6), points.at(0), points.at(7),
			points.at(1)))
	return false;
      return true;
    }

    // True when p1 and p2 fall on the same side of the line through
    // (x1,y1) and (x2,y2), judged by the sign of each point's 2-D cross
    // product with the line direction. A point exactly on the line is
    // grouped with the negative side.
    bool ptsOnSameSide(const Point& p1, const Point& p2, float x1, float x2,
		       float y1, float y2)
    {
      const float dirX = x2 - x1;
      const float dirY = y2 - y1;

      const float side1 = (p1.coordY() - y1)*dirX - (p1.coordX() - x1)*dirY;
      const float side2 = (p2.coordY() - y1)*dirX - (p2.coordX() - x1)*dirY;

      const bool firstPositive = side1 > 0;
      const bool secondPositive = side2 > 0;
      return firstPositive == secondPositive;
    }

    void readObjectFile(char * fname, std::string name) {
      int factor = 2;
      rgb octaves[] = {rgb(127, 127, 255), rgb(255, 127, 127),
		       rgb(127, 255, 127), rgb(255, 255, 127)};
      cout << "Reading file and creating box." << endl;
      NEW_SHAPE(objects, GraphicsData, new GraphicsData(camShS));
      objects->setName(name);
      bool inBBox = false;
      ifstream pointsFile;
      pointsFile.open(fname);
      while (pointsFile.is_open() && !pointsFile.eof()) {
	std::string row = "";
	getline(pointsFile, row);
	if (row.find("##") != string::npos) {
	  inBBox = true;
	  continue;
	} 
	if (inBBox) {
	  std::istringstream iss(row);
	  std::string values[8];
	  std::string token;
	  int i = 0;
	  while(getline(iss, token, ','))	{
	    values[i++] = token;
	  }
	  float topLeftX = atof(values[0].c_str());
	  float topLeftY = atof(values[1].c_str());
	  float topRightX = atof(values[2].c_str());
	  float topRightY = atof(values[3].c_str());
	  float botRightX = atof(values[4].c_str());
	  float botRightY = atof(values[5].c_str());
	  float botLeftX = atof(values[6].c_str());
	  float botLeftY = atof(values[7].c_str());

	  std::vector<float> points;
	  points.push_back(topLeftX);
	  points.push_back(topLeftY);
	  points.push_back(topRightX);
	  points.push_back(topRightY);
	  points.push_back(botRightX);
	  points.push_back(botRightY);
	  points.push_back(botLeftX);
	  points.push_back(botLeftY);
					
	  if(checkPoints(points))	{
	    std::vector<std::pair<float,float> > rectangle;
	    rectangle.push_back(std::pair<float,float>(topLeftX/factor, 
						       topLeftY/factor));
	    rectangle.push_back(std::pair<float,float>(topRightX/factor, 
						       topRightY/factor));
	    rectangle.push_back(std::pair<float,float>(botRightX/factor, 
						       botRightY/factor));
	    rectangle.push_back(std::pair<float,float>(botLeftX/factor, 
						       botLeftY/factor));
	    objects->add(new GraphicsData::PolygonElement(rectangle, true, 
							  rgb(0, 255, 0)));
	  }
	  else
	    cout << "Not creating impossible bounding box." << endl;
	} else {
	  std::istringstream iss(row);
	  std::string values[8];
	  std::string token;
	  int i = 0;
	  while(getline(iss, token, ','))	{
	    values[i++] = token;
	  }
	  if ((values[0].compare("in") != 0) 
	      && (values[0].compare("out") != 0)) {
	    camShS.deleteShape(objects);
	    continue;
	  }
	  boolean inlier = (values[0].compare("in") == 0);
	  float x = atof(values[1].c_str());
	  float y = atof(values[2].c_str());
	  float size = atof(values[3].c_str())*1.2/9.*2;
	  float angle = atof(values[4].c_str());
	  float response = atof(values[5].c_str());
	  int octave = atoi(values[6].c_str());
	  float classId = atof(values[7].c_str());
	  std::pair<float,float> p(x/factor,y/factor);
	  objects->add(new GraphicsData::CircleElement(p, size, false, 
						       octaves[octave - 1]));
	  float r = size;
	  float xp = r * cos(deg2rad(angle)) + x;
	  float yp = r * sin(deg2rad(angle)) + y;
	  std::pair<float,float> t(xp/factor,yp/factor);
	  if (inlier) {
	    objects->add(new GraphicsData::LineElement(p, t, rgb(255, 0, 0)));
	  } else  {
	    objects->add(new GraphicsData::LineElement(p, t, rgb(0, 255, 255)));
	  } 	
	}
      }
      pointsFile.close();
    }
  }

  // Waits for a text-message event (textmsgEGID) and signals the parent:
  //   list         print the object names stored in the database (no signal;
  //                keeps listening)
  //   test         test the next image against every object
  //   test NAME    test against NAME only; a trailing '*' (NAME*) selects
  //                every object whose name starts with NAME
  // Any other text posts `cont`, which re-enters this node.
  $nodeclass WaitForCommand : StateNode {

    virtual void doStart() {
      cout << "Enter command (list|test [class_name])>" << endl;
      erouter->addListener(this, EventBase::textmsgEGID);
    }

    // Prints the name column (text before the first tab) of every non-empty
    // line of the object database at `path`, or a notice if there are none
    // or the file cannot be opened.
    void listObjects(const std::string &path) {
      ifstream db(path.c_str());
      int count = 0;
      std::string line;
      while (getline(db, line)) {
        if (line.empty())
          continue;
        cout << line.substr(0, line.find('\t')) << endl;
        ++count;
      }
      if (count == 0)
        cout << "No objects in database." << endl;
    }

    // Handles one text command (see the class comment).
    void doEvent() {
      $reference SurfTest::currentObject;
      $reference SurfTest::findMatches;
      $reference SurfTest::objectDataPath;

      const TextMsgEvent *txtev = NULL;
      if (event->getGeneratorID() == EventBase::textmsgEGID)
        txtev = dynamic_cast<const TextMsgEvent*>(event);
      if (txtev == NULL) {
        cout << "Unexpected event: " << event->getGeneratorID() << endl;
        return;
      }

      const std::string text = txtev->getText();
      const std::string testPrefix = "test ";

      if (text == "list") {
        listObjects(objectDataPath);
      } else if (text == "test"
                 || text.compare(0, testPrefix.size(), testPrefix) == 0) {
        // Everything after "test " is the class name (it may contain spaces).
        // An empty name ("test ") means the same as plain "test" instead of
        // crashing on .at(size() - 1).
        const std::string className = (text.size() > testPrefix.size())
          ? text.substr(testPrefix.size()) : std::string();
        if (className.empty()) {
          currentObject = "-";
          cout << "Testing against all objects" << endl;
          findMatches = false;
        } else {
          currentObject = className;
          cout << "Testing against " << className << endl;
          findMatches = (className[className.size() - 1] == '*');
        }
        postStateSignal<command>(testing);
      } else {
        cout << "Unexpected command: " << text << endl;
        findMatches = false;
        postStateSignal<command>(cont);
      }
    }
  }

  // Empties the camera sketch space and then the camera shape space, so each
  // test run starts from a clean display.
  $nodeclass Clear : VisualRoutinesStateNode : doStart {
    camSkS.clear();
    camShS.clear();
    postStateCompletion();
  }


  // Prints a short usage hint when the behavior starts.
  virtual void doStart() {
    std::cout << "Type test to find objects. " << std::endl;
  }

  // State machine: Initialize fills in the shared settings, then
  // WaitForCommand ("wait") listens for text commands. A `testing` signal
  // clears the camera spaces, saves an upsampled camera image, runs the
  // external find_object-test program and draws its results, then returns to
  // wait. A `cont` signal (unrecognised command) simply re-enters wait.
  $setupmachine{
    Initialize =C=> wait
      wait: WaitForCommand
      wait =S<command>(testing)=> Clear =C=> SaveImage =C=> Fork =C=> wait
      wait =S<command>(cont)=> wait
      }
}

REGISTER_BEHAVIOR(SurfTest);
