#include "Behaviors/StateMachine.h"
#include "Behaviors/Services/CameraBehavior.h"
#include "Events/TextMsgEvent.h"
#include <sys/wait.h>
#include <unistd.h>
#include <wait.h>


$nodeclass SurfTest : VisualRoutinesStateNode {

   // Signals posted by WaitForCommand to select the next transition.
   enum command {
       testing = 0,  // a "test" command was entered: run the capture/recognition pipeline
       cont = 1      // anything else: go back to waiting for a command
   };

	// Shared state for all child nodes. Every value below is assigned by the
	// Initialize node before any other node runs.
	//   tempSavePath     - file SaveImage writes the captured raw image to
	//   tempReturnPath   - directory handed to the external test program; Fork
	//                      reads the per-object result files from it
	//   keypointDataPath - passed through to the external test program
	//   trainDataPath    - passed through to the external test program
	//   objectDataPath   - object database; passed to the test program and also
	//                      parsed by Fork::readObjects
	//   picWidth/picHeight/factor - used by Fork::checkPoints to bound the
	//                      corner coordinates of reported bounding boxes
	$provide std::string tempSavePath;
	$provide std::string tempReturnPath;
	$provide std::string keypointDataPath;
	$provide std::string trainDataPath;
	$provide std::string objectDataPath;
	$provide int picWidth;
	$provide int picHeight;
	$provide int factor;

	// Object name selected by the last command: "-" means "all objects".
	// findMatches is set when the name ended in '*', which makes
	// Fork::readObjects match database names by prefix.
	$provide std::string currentObject;
	$provide bool findMatches;
  
	// Assigns the initial value of every shared SurfTest variable (image
	// geometry, file locations, current selection) and then completes.
	$nodeclass Initialize : StateNode : doStart {
		$reference SurfTest::tempSavePath;
		$reference SurfTest::tempReturnPath;
		$reference SurfTest::keypointDataPath;
		$reference SurfTest::trainDataPath;
		$reference SurfTest::objectDataPath;
		$reference SurfTest::picWidth;
		$reference SurfTest::picHeight;
		$reference SurfTest::factor;

		$reference SurfTest::currentObject;
		$reference SurfTest::findMatches;

		// Selection state: "-" stands for "every object", no prefix matching.
		findMatches = false;
		currentObject = "-";

		// File locations.
		objectDataPath = "surf_data/objectDatabase.dat";
		keypointDataPath = "surf_data/keypointData.dat";
		trainDataPath = "surf_data/trainData.dat";
		tempReturnPath = "surf_data/";
		tempSavePath = "/tmp/image.raw";

		// Image geometry used when validating bounding boxes.
		factor = 2;
		picHeight = 480;
		picWidth = 640;

		postStateCompletion();
	}

	// node that saves the image into a temporary file
	$nodeclass SaveImage : VisualRoutinesStateNode {
		
		/**
		 * Captures the raw Y camera channel, enlarges it with resample() and
		 * writes the enlarged image to SurfTest::tempSavePath as raw bytes.
		 *
		 * Both temporary buffers are released before returning. Problems
		 * (allocation failure, file that cannot be written) are reported on
		 * std::cerr. Completion is always posted, because the enclosing state
		 * machine only defines a completion transition out of this node.
		 */
		virtual void doStart()
		{
			$reference SurfTest::tempSavePath;

			// Enlargement factor; the same value Initialize stores in
			// SurfTest::factor.
			const int factor = 2;

			NEW_SKETCH(camY, uchar, sketchFromRawY());
			const int width = camY->getWidth();
			const int height = camY->getHeight();
			const size_t buffsize = sizeof(char) * height * width;

			char *imgBuf = (char *) malloc(buffsize);
			if (imgBuf == NULL) {
				std::cerr << "ERROR: out of memory while capturing the camera image." << std::endl;
				postStateCompletion();
				return;
			}
			camY->savePixels(imgBuf,buffsize);

			char *destBuf = resample(imgBuf, width, height, factor);
			free(imgBuf);	// the source copy is no longer needed
			if (destBuf == NULL) {
				std::cerr << "ERROR: could not resample the camera image." << std::endl;
				postStateCompletion();
				return;
			}

			ofstream myfile;
			myfile.open(tempSavePath.c_str(), fstream::out | ios::binary);
			bool written = false;
			if (myfile.is_open()) {
				myfile.write(destBuf, buffsize * factor * factor);
				written = !myfile.fail();
				myfile.close();
			}
			free(destBuf);

			if (written)
				cout << "saved to " << tempSavePath << ". w: " << width << ", h: " << height  << endl;
			else
				std::cerr << "ERROR: could not write image to " << tempSavePath << std::endl;
			postStateCompletion();
		}
		
		/**
		 * Enlarges a row-major, one-byte-per-pixel image by pixel replication:
		 * every source pixel becomes a factor x factor block in the result.
		 *
		 * @param imgBuf  source pixels, width * height bytes
		 * @param width   source width in pixels
		 * @param height  source height in pixels
		 * @param factor  integer enlargement factor
		 * @return a malloc'ed buffer of (width*factor) * (height*factor) bytes
		 *         that the caller must free(), or NULL if imgBuf is NULL, any
		 *         dimension or the factor is not positive, or allocation fails.
		 */
		char * resample(char * imgBuf, int width, int height, int factor) {
			if (imgBuf == NULL || width <= 0 || height <= 0 || factor <= 0)
				return NULL;

			// size_t arithmetic keeps the offsets from overflowing int on
			// large enlarged images.
			const size_t srcWidth = static_cast<size_t>(width);
			const size_t srcHeight = static_cast<size_t>(height);
			const size_t scale = static_cast<size_t>(factor);
			const size_t destWidth = srcWidth * scale;

			char *destBuf = static_cast<char *>(malloc(destWidth * srcHeight * scale));
			if (destBuf == NULL)
				return NULL;

			for (size_t row = 0; row < srcHeight; row++) {
				for (size_t col = 0; col < srcWidth; col++) {
					const char val = imgBuf[row * srcWidth + col];
					// Top-left corner of the block this pixel expands into.
					char *block = destBuf + (row * scale) * destWidth + col * scale;
					for (size_t dy = 0; dy < scale; dy++) {
						for (size_t dx = 0; dx < scale; dx++) {
							block[dy * destWidth + dx] = val;
						}
					}
				}
			}
			return destBuf;
		}

		/**
		 * Builds a same-sized copy of a row-major, one-byte-per-pixel image in
		 * which only the pixels at (even row, even column) - plus the even
		 * columns of the last row - are copied; every other pixel is replaced
		 * by the average of two neighbours:
		 *   - even rows and the last row: odd columns average left and right;
		 *   - other (odd) rows: even columns average the pixels above and
		 *     below, odd columns average upper-left and lower-right.
		 * When the last column is odd there is no right-hand neighbour, so the
		 * left-hand column is used in its place instead of reading past the
		 * end of the row/buffer.
		 *
		 * Pixels are averaged as unsigned bytes so values above 127 are
		 * handled correctly regardless of whether plain char is signed.
		 *
		 * @param factor unused; kept so the signature matches resample()
		 * @return a malloc'ed buffer of width * height bytes that the caller
		 *         must free(), or NULL on invalid arguments or allocation
		 *         failure.
		 */
		char *interpolate(char *imgBuf, int width, int height, int factor) {
			(void) factor;
			if (imgBuf == NULL || width <= 0 || height <= 0)
				return NULL;

			const size_t buffsize = sizeof(char) * height * width;
			char *destBuf = (char *) malloc(buffsize);
			if (destBuf == NULL)
				return NULL;

			const unsigned char *src = reinterpret_cast<const unsigned char *>(imgBuf);
			unsigned char *dst = reinterpret_cast<unsigned char *>(destBuf);
			const size_t stride = static_cast<size_t>(width);

			for(int row = 0; row < height; row++)
			{
				// Row 0 is even and the last row is forced into this branch,
				// so the vertical/diagonal branch never leaves the image
				// vertically.
				const bool interpolateRight = (row == height - 1 || row % 2 == 0);
				const size_t rowVal = static_cast<size_t>(row) * stride;
				for(int col = 0; col < width; col++)
				{
					const size_t here = rowVal + col;
					if(col % 2 == 0)
					{
						if(interpolateRight)
							dst[here] = src[here];
						else
							dst[here] = static_cast<unsigned char>(
								(src[here - stride] + src[here + stride]) / 2);
					}
					else
					{
						// col is odd, hence col >= 1 and col - 1 is valid.
						const size_t leftCol = static_cast<size_t>(col - 1);
						const size_t rightCol = (col + 1 < width)
							? static_cast<size_t>(col + 1) : leftCol;
						if(interpolateRight)
							dst[here] = static_cast<unsigned char>(
								(src[rowVal + leftCol] + src[rowVal + rightCol]) / 2);
						else
							dst[here] = static_cast<unsigned char>(
								(src[rowVal - stride + leftCol] + src[rowVal + stride + rightCol]) / 2);
					}
				}
			}
			return destBuf;
		}

		// Flattens a 2-D position into a linear buffer offset. The first
		// argument is the one scaled by width, i.e. the result is
		// x * width + y.
		int xy2pos(int x, int y, int width) {
			const int rowStart = x * width;
			return rowStart + y;
		}
		
	}

	// node that calls the test function in Find-Object
	$nodeclass Fork : VisualRoutinesStateNode  {
	  
		/**
		 * Runs the external "find_object-test" program on the saved image and,
		 * once it has finished, loads its results with readObjects().
		 *
		 * The program is started with fork() + execl() and receives, in order:
		 * image path, object database, training data, keypoint data, the image
		 * size ("1280" x "960"), the result directory and the object name.
		 * The parent blocks until that specific child exits.
		 *
		 * Completion is always posted so the state machine returns to the
		 * command prompt; results are only read if the program could actually
		 * be launched.
		 */
		virtual void doStart() {
			$reference SurfTest::tempSavePath;
			$reference SurfTest::tempReturnPath;
			$reference SurfTest::trainDataPath;
			$reference SurfTest::keypointDataPath;
			$reference SurfTest::objectDataPath;

			$reference SurfTest::currentObject;

			// Exit status the child uses to report that execl() itself failed.
			const int execFailedStatus = 127;

			cout << "Launching SURF engine..." << endl;

			pid_t child_id = fork();

			if ( child_id < 0 )
			{
				std::cerr << "ERROR: fork() failed; cannot launch the SURF test program." << std::endl;
				postStateCompletion();
				return;
			}

			// if in the child, execute the test program
			if ( child_id == 0 )
			{
				// TEKKOTSU_ROOT is not consulted: the prefix is empty, and
				// execl() does not search PATH, so the program is looked up
				// relative to the current working directory.
				std::string const tekkotsuRoot =  "";
				std::string const testProgram = "find_object-test";
				std::string const testProgramPath = tekkotsuRoot + testProgram;
				// The argument list of a variadic exec call must end with a
				// null *pointer*; a bare NULL may be passed as an int.
				execl(testProgramPath.c_str(), testProgram.c_str(), 
					  tempSavePath.c_str(),
					  objectDataPath.c_str(),
					  trainDataPath.c_str(),
					  keypointDataPath.c_str(), 
					  "1280", "960",
					  tempReturnPath.c_str(),
					  currentObject.c_str(), (char *) NULL);

				// If we get here, the execl() failed
				std::cerr << "ERROR: failed to launch SURF test program " 
						  << testProgramPath << std::endl << "Check that it exists in the current working directory and is executable." << std::endl;
				_exit(execFailedStatus);
			}

			// if in the parent, wait for the child to finish, then display the results
			else {
				int status = 0;
				// Wait for this child only; wait() could return for any
				// other child process of the host program.
				const pid_t done = waitpid(child_id, &status, 0);
				const bool launchFailed = (done == child_id
										   && WIFEXITED(status)
										   && WEXITSTATUS(status) == execFailedStatus);
				if (done != child_id)
					std::cerr << "WARNING: could not wait for the SURF test program; results may be incomplete." << std::endl;
				else if (!launchFailed && !(WIFEXITED(status) && WEXITSTATUS(status) == 0))
					std::cerr << "WARNING: the SURF test program did not exit cleanly." << std::endl;
//				NEW_SHAPE(scene, GraphicsData, new GraphicsData(camShS));
//				char tmp[1000];			
//				ifstream pointsFile;
//				strcpy(tmp, tempReturnPath.c_str());
//				strcat(tmp,"/scene.txt");
//				pointsFile.open(tmp);
//				while (!pointsFile.eof()) {
//					std::string row = "";
//					getline(pointsFile, row);
//					std::istringstream iss(row);
//					std::string values[4];
//					std::string token;
//					int i = 0;
//					while(getline(iss, token, ','))	{
//						values[i++] = token;
//					}
//					float x = atof(values[0].c_str());
//					float y = atof(values[1].c_str());	
//					float size = atof(values[2].c_str())*1.2/9.*2;
//					float angle = atof(values[3].c_str());
//					std::pair<float,float> p(x,y);
//					scene->add(new GraphicsData::CircleElement(p, size, false, rgb(127,127,127)));
//					float r = size;
//					float xp = r * cos(deg2rad(angle)) + x;
//					float yp = r * sin(deg2rad(angle)) + y;
//					std::pair<float,float> t(xp,yp);
//					scene->add(new GraphicsData::LineElement(p, t, rgb(127,127,127)));
//				}
//				pointsFile.close();
				// Result files left over from an earlier run would be
				// misleading if the program never started.
				if (!launchFailed)
					readObjects(objectDataPath, tempReturnPath, currentObject);
				postStateCompletion();
			}
		}

		/**
		 * Scans the object database and loads the result file of every entry
		 * that matches the requested object.
		 *
		 * Each database row holds tab-separated fields; field 0 is the object
		 * name and field 1 its id. Fields beyond the third are ignored. A row
		 * is selected when
		 *   - object is "-" (all objects), or
		 *   - object equals the row's name, or
		 *   - SurfTest::findMatches is set and the row's name starts with
		 *     object minus its last character (the trailing '*').
		 * For a selected row with a non-empty id, the file
		 * "<tempReturnPath>/<id>.txt" is handed to readObjectFile().
		 *
		 * @param objectDataPath path of the object database file
		 * @param tempReturnPath directory containing the per-object results
		 * @param object         requested object name, "-" for all
		 */
		void readObjects(std::string objectDataPath, 
						 std::string tempReturnPath,
						 std::string object) {
			$reference SurfTest::findMatches;

			ifstream objectDB;
			objectDB.open(objectDataPath.c_str());
			if (!objectDB.is_open()) {
				// Looping on !eof() would never end on a stream that failed
				// to open, so bail out explicitly.
				std::cerr << "ERROR: cannot open object database " << objectDataPath << std::endl;
				return;
			}

			const int maxFields = 3;
			const bool matchAll = (object.compare("-") == 0);
			const std::string prefix =
				object.empty() ? std::string() : object.substr(0, object.size() - 1);

			std::string row;
			while (getline(objectDB, row)) {
				std::istringstream iss(row);
				std::string values[maxFields];
				std::string token;
				int i = 0;
				// Bounded: extra fields must not run past the array.
				while (i < maxFields && getline(iss, token, '\t')) {
					values[i++] = token;
				}

				const std::string &entryName = values[0];
				const std::string &objectId = values[1];
				const bool selected =
					matchAll
					|| (object.compare(entryName) == 0)
					|| (findMatches && entryName.compare(0, prefix.size(), prefix) == 0);
				if (!selected || objectId.empty())
					continue;

				cout << "Object id: " << objectId << endl;
				const std::string resultPath = tempReturnPath + "/" + objectId + ".txt";
				// readObjectFile() takes a mutable char*, so hand it a
				// writable, NUL-terminated copy of the path.
				std::vector<char> fname(resultPath.begin(), resultPath.end());
				fname.push_back('\0');
				readObjectFile(&fname[0], entryName);
			}
			objectDB.close();
		}

		bool checkPoints(std::vector<float> points)
		{
			$reference SurfTest::picWidth;
			$reference SurfTest::picHeight;
			$reference SurfTest::factor;
			
			for(int count = 0; count < points.size(); count++)
			{
				if(count % 2 == 0)
				{
					if(points.at(count) < - picWidth * factor/2 
					   || points.at(count) > 3* picWidth * factor /2)
						return false;
				}
				else
					if(points.at(count) < -picHeight * factor/2 
					   || points.at(count) > 3*picHeight * factor/2)
						return false;
			}
			if(!checkIntersections(points))
				return false;
			return true;
		}

		bool checkIntersections(std::vector<float> points)
		{
//			EndPoint *oneEndPoint = new EndPoint(points.at(0), points.at(1));
//			EndPoint *twoEndPoint = new EndPoint(points.at(2), points.at(3));
//			EndPoint *threeEndPoint = new EndPoint(points.at(4), points.at(5));
//			EndPoint *fourEndPoint = new EndPoint(points.at(6), points.at(7));
//			NEW_SHAPE(lineOne, LineData, new LineData(camSkS, *oneEndPoint, *twoEndPoint));
//			NEW_SHAPE(lineTwo, LineData, new LineData(camSkS, *twoEndPoint, *threeEndPoint));
//			NEW_SHAPE(lineThree, LineData, new LineData(camSkS, *threeEndPoint, *fourEndPoint));
//			NEW_SHAPE(lineFour, LineData, new LineData(camSkS, *fourEndPoint, *oneEndPoint));
			Point *pt1 = new Point(points.at(0), points.at(1));
			Point *pt2 = new Point(points.at(2), points.at(3));
			Point *pt3 = new Point(points.at(4), points.at(5));
			Point *pt4 = new Point(points.at(6), points.at(7));

			if(!ptsOnSameSide(*pt3, *pt4, points.at(0), points.at(2), points.at(1), points.at(3)))
				return false;
			if(!ptsOnSameSide(*pt4, *pt1, points.at(2), points.at(4), points.at(3), points.at(5)))
				return false;
			if(!ptsOnSameSide(*pt1, *pt2, points.at(4), points.at(6), points.at(5), points.at(7)))
				return false;
			if(!ptsOnSameSide(*pt2, *pt3, points.at(6), points.at(0), points.at(7), points.at(1)))
				return false;
			return true;
		}

		/**
		 * Tells whether p1 and p2 fall on the same side of the line through
		 * (x1, y1) and (x2, y2).
		 *
		 * Each point is classified by the sign of the cross product between
		 * the line direction and the vector from (x1, y1) to the point. Only
		 * "strictly positive" versus "not positive" is distinguished, so a
		 * point exactly on the line counts as being on the non-positive side.
		 */
		bool ptsOnSameSide(const Point& p1, const Point& p2, float x1, float x2, float y1, float y2)
		{
			const float dirX = x2 - x1;
			const float dirY = y2 - y1;

			const float side1 = (p1.coordY() - y1)*dirX - (p1.coordX() - x1)*dirY;
			const float side2 = (p2.coordY() - y1)*dirX - (p2.coordX() - x1)*dirY;

			const bool secondPositive = (side2>0);
			if (side1>0)
				return secondPositive;
			return !secondPositive;
		}

		void readObjectFile(char * fname, std::string name) {
			int factor = 2;
			rgb octaves[] = {rgb(127, 127, 255), rgb(255, 127, 127),
							 rgb(127, 255, 127), rgb(255, 255, 127)};
			cout << "Reading file and creating box." << endl;
			NEW_SHAPE(objects, GraphicsData, new GraphicsData(camShS));
			objects->setName(name);
			bool inBBox = false;
			ifstream pointsFile;
			pointsFile.open(fname);
			while (pointsFile.is_open() && !pointsFile.eof()) {
				std::string row = "";
				getline(pointsFile, row);
				if (row.find("##") != string::npos) {
					inBBox = true;
					continue;
				} 
				if (inBBox) {
					std::istringstream iss(row);
					std::string values[8];
					std::string token;
					int i = 0;
					while(getline(iss, token, ','))	{
						values[i++] = token;
					}
					float topLeftX = atof(values[0].c_str());
					float topLeftY = atof(values[1].c_str());
					float topRightX = atof(values[2].c_str());
					float topRightY = atof(values[3].c_str());
					float botRightX = atof(values[4].c_str());
					float botRightY = atof(values[5].c_str());
					float botLeftX = atof(values[6].c_str());
					float botLeftY = atof(values[7].c_str());

					std::vector<float> points;
					points.push_back(topLeftX);
					points.push_back(topLeftY);
					points.push_back(topRightX);
					points.push_back(topRightY);
					points.push_back(botRightX);
					points.push_back(botRightY);
					points.push_back(botLeftX);
					points.push_back(botLeftY);
					
					if(checkPoints(points))	{
						std::vector<std::pair<float,float> > rectangle;
						rectangle.push_back(std::pair<float,float>(topLeftX/factor, topLeftY/factor));
						rectangle.push_back(std::pair<float,float>(topRightX/factor, topRightY/factor));
						rectangle.push_back(std::pair<float,float>(botRightX/factor, botRightY/factor));
						rectangle.push_back(std::pair<float,float>(botLeftX/factor, botLeftY/factor));
						objects->add(new GraphicsData::PolygonElement(rectangle, true, rgb(0, 255, 0)));
					}
					else
						cout << "Not creating impossible bounding box." << endl;
				} else {
					std::istringstream iss(row);
					std::string values[8];
					std::string token;
					int i = 0;
					while(getline(iss, token, ','))	{
						values[i++] = token;
					}
					if ((values[0].compare("in") != 0) 
						&& (values[0].compare("out") != 0)) {
						camShS.deleteShape(objects);
						continue;
					}
					boolean inlier = (values[0].compare("in") == 0);
					float x = atof(values[1].c_str());
					float y = atof(values[2].c_str());
					float size = atof(values[3].c_str())*1.2/9.*2;
					float angle = atof(values[4].c_str());
					float response = atof(values[5].c_str());
					int octave = atoi(values[6].c_str());
					float classId = atof(values[7].c_str());
					std::pair<float,float> p(x/factor,y/factor);
					objects->add(new GraphicsData::CircleElement(p, size, false, octaves[octave - 1]));		
					float r = size;
					float xp = r * cos(deg2rad(angle)) + x;
					float yp = r * sin(deg2rad(angle)) + y;
					std::pair<float,float> t(xp/factor,yp/factor);
					if (inlier) {
						objects->add(new GraphicsData::LineElement(p, t, rgb(255, 0, 0)));
					} else  {
						objects->add(new GraphicsData::LineElement(p, t, rgb(0, 255, 255)));
					} 	
				}
			}
			pointsFile.close();
		}
	}

    // node that waits for the user to input what class the image was
    $nodeclass WaitForCommand : StateNode {
		
        // Shows the command prompt and subscribes this node to text-message
        // events so that doEvent() receives what the user types.
        virtual void doStart() {
            const char *const prompt = "Enter command (test [class_name])>";
            cout << prompt << endl;
            erouter->addListener(this, EventBase::textmsgEGID);
        }
      
        /**
         * Handles a text message typed by the user.
         *
         *   "test"            - test against all objects (currentObject "-").
         *   contains "test"   - everything after the first space is the object
         *                       name (the whole text if there is no space); a
         *                       trailing '*' enables prefix matching. An empty
         *                       name, e.g. "test ", is treated like plain
         *                       "test".
         *   anything else     - reported as unexpected; signals `cont`.
         *
         * The first two forms signal `testing`.
         */
        void doEvent() {
            $reference SurfTest::currentObject;
            $reference SurfTest::findMatches;

            switch ( event->getGeneratorID() ) {
            case EventBase::textmsgEGID: {
                const TextMsgEvent *txtenv = 
                    dynamic_cast<const TextMsgEvent*>(event);
                if (txtenv == NULL) {
                    // Should not happen for this generator, but never
                    // dereference a failed cast.
                    cout << "Unexpected event: " << event->getGeneratorID() << endl;
                    break;
                }
                const std::string text = txtenv->getText();
                if (text.compare("test") == 0) {
                    currentObject = "-";
                    cout << "Testing against all objects" << endl;
                    findMatches = false;
                    postStateSignal<command>(testing);
                } else if (text.find("test") != std::string::npos) {
                    const std::string::size_type sep = text.find(" ");
                    const std::string subs =
                        (sep == std::string::npos) ? text : text.substr(sep + 1);
                    if (subs.empty()) {
                        // Nothing after the space: there is no last character
                        // to inspect, so fall back to "all objects".
                        currentObject = "-";
                        cout << "Testing against all objects" << endl;
                        findMatches = false;
                    } else {
                        cout << "Testing against " << subs << endl;
                        currentObject = subs;
                        findMatches = (subs.at(subs.size() - 1) == '*');
                    }
                    postStateSignal<command>(testing);
                } else {
                    cout << "Unexpected command: " << text << endl;
                    findMatches = false;
                    postStateSignal<command>(cont);
                }
                break;
            }
            default: {
                cout << "Unexpected event: " << event->getGeneratorID() << endl;
            }         
            }
        }
    }

	// Empties the camera sketch space and the camera shape space, then
	// completes, so every test run starts without graphics from the last one.
	$nodeclass Clear : VisualRoutinesStateNode : doStart{
		camSkS.clear();
		camShS.clear();
		postStateCompletion();
	}


	// Prints a one-line usage hint when the behavior starts.
	virtual void doStart() {
		const char *const usage = "Type test to find objects. ";
		cout << usage << endl;
	}

	// State machine:
	//   Initialize sets the shared variables, then control passes to `wait`
	//   (a WaitForCommand node).
	//   On a `testing` signal: Clear the sketch/shape spaces, SaveImage the
	//   current camera frame, Fork the external test program and display its
	//   results, then return to `wait`.
	//   On a `cont` signal (unrecognised command): return to `wait`.
	$setupmachine{
		Initialize =C=> wait
			wait: WaitForCommand
			wait =S<command>(testing)=> Clear =C=> SaveImage =C=> Fork =C=> wait
			wait =S<command>(cont)=> wait
			}
}

	// Registers SurfTest through the framework's REGISTER_BEHAVIOR macro
	// (presumably so it can be launched by name - confirm against the macro).
	REGISTER_BEHAVIOR(SurfTest);
