#include <vector>

#include "Behaviors/StateMachine.h"
#include "DualCoding/DualCoding.h"

using namespace DualCoding;
/*
Sketch<yuv> leftPic;
Sketch<yuv> rightPic;
Sketch<uchar> segLeft;
*/
class TakeLeftPicture : public VisualRoutinesStateNode {
    public:
    TakeLeftPicture() : VisualRoutinesStateNode("TakeLeftPicture") {}

    //! Captures the left half of the stereo pair: the raw camera frame as the
    //! yuv sketch "leftPic" and its color segmentation as "segLeft".
    //! Both are retained (presumably so they outlive this node's local
    //! handles); AnalyzePictures looks them up by name and calls retain(false).
    void DoStart() {
        printf("taking left picture\n");

        NEW_SKETCH(leftPic, yuv, sketchFromYUV());
        leftPic->retain();

        NEW_SKETCH(segLeft, uchar, sketchFromSeg());
        segLeft->retain();
    }
};

class TakeRightPicture : public VisualRoutinesStateNode {
    public:
    TakeRightPicture() : VisualRoutinesStateNode("TakeRightPicture") {}

    //! Captures the right half of the stereo pair as the yuv sketch
    //! "rightPic". It is retained until AnalyzePictures, which fetches it by
    //! name, calls retain(false) on it.
    void DoStart() {
        printf("taking right picture\n");

        NEW_SKETCH(rightPic, yuv, sketchFromYUV());
        rightPic->retain();
    }
};

class TwistLeft : public XWalkNode {
    public:
    TwistLeft() : XWalkNode("TwistLeft") {}

    //! Turns the robot in place toward the left. Only the third argument of
    //! setTargetDisplacement is non-zero, so it appears to be the rotation
    //! component (units presumably radians -- confirm against XWalkMC).
    void DoStart() {
        printf("twisting left\n");
        // Hand-tuned: smaller magnitudes such as +-.45 did not work.
        getMC()->setTargetDisplacement(0, 0, 1);
    }
};

class TwistRight : public XWalkNode {
    public:
    TwistRight() : XWalkNode("TwistRight") {}

    //! Turns the robot in place toward the right (negative rotation component).
    void DoStart() {
        printf("twisting right\n");
        // Hand-tuned: a value of -2 turned the robot too far.
        const double turn = -1.4;
        getMC()->setTargetDisplacement(0, 0, turn);
    }
};

class WalkRight : public XWalkNode {
    public:
    WalkRight() : XWalkNode("WalkRight") {}

    //! Side-steps to the right to set up the baseline between the two stereo
    //! pictures. Negative second argument = rightward (units presumably mm).
    void DoStart() {
        printf("walking right\n");
        const int sideways = -300;
        getMC()->setTargetDisplacement(0, sideways, 0);
    }
};

class LookAhead : public HeadPointerNode {
    public:
    LookAhead() : HeadPointerNode("LookAhead") {}

    //! Centers the head by driving all three head joints to 0.
    void DoStart() {
        printf("looking straight ahead\n");
        getMC()->setJoints(0, 0, 0);
    }
};

class LookLeft : public HeadPointerNode {
    public:
    LookLeft() : HeadPointerNode("LookLeft") {}

    //! Turns the head to the left: +.5 on the second head joint
    //! (presumably pan, in radians -- confirm against HeadPointerMC).
    void DoStart() {
        printf("looking left\n");
        getMC()->setJoints(0, .5, 0);
    }
};

class LookRight : public HeadPointerNode {
    public:
    LookRight() : HeadPointerNode("LookRight") {}

    //! Turns the head to the right: the mirror image of LookLeft, with -.5 on
    //! the second head joint.
    void DoStart() {
        printf("looking right\n");
        getMC()->setJoints(0, -.5, 0);
    }
};

class StandUp : public XWalkNode {
    public:
    StandUp() : XWalkNode("StandUp") {}

    //! Takes a token step (1 unit sideways) so the walk engine settles the
    //! robot into a stable stance before the pictures are taken.
    void DoStart() {
        printf("standing up\n");
        getMC()->setTargetDisplacement(0, 1, 0);
    }
};

class AnalyzePictures : public VisualRoutinesStateNode {
    public:
    AnalyzePictures() : VisualRoutinesStateNode("AnalyzePictures") {}
  
    void DoStart() {
        printf("Getting pictures\n");
                
        GET_SKETCH(leftPic, yuv, camSkS);
        GET_SKETCH(rightPic, yuv, camSkS);
        GET_SKETCH(segLeft, uchar, camSkS);
        
        printf("Masking colors\n");
        // TODO have user specify these colors
        /*NEW_SKETCH(orangeStuff, bool, visops::colormask(segLeft, "orange"));
        NEW_SKETCH(blueStuff, bool, visops::colormask(segLeft, "blue"));
        NEW_SKETCH(allStuff, bool, visops::minArea(orangeStuff | blueStuff, 15));*/
        NEW_SKETCH(allStuff, bool, visops::minArea(visops::colormask(segLeft, "green"),15));

        printf("Calculating disparities\n");
        NEW_SKETCH(result, uint, getDisparities(leftPic, rightPic, allStuff, 11, 30)); 
        NEW_SKETCH(scaledResult, uint, scaleDisparityMap(result));   

        leftPic->retain(false);
        segLeft->retain(false);
        rightPic->retain(false);

        printf("Done.\n");
        
    }

    Sketch<uint> scaleDisparityMap(Sketch<uint> dispMap)
    {
        uint max = dispMap->max();
        float coeff = 511.0/max;
        Sketch<uint> result (dispMap * coeff, "scaled", false);
        return result;
    }

    // Generates the disparity map between the two given sketches
    // Uses w x w windows to determine minimum error, and starts buf pixels from
    // the left side
    Sketch<uint> getDisparities(Sketch<yuv> sk_l, Sketch<yuv> sk_r, Sketch<bool> mask, int w, int buf)
    {
        NEW_SKETCH(result, uint, visops::zeros(sk_l));
        int ht = sk_l.height;
        int wd = sk_l.width;
        int sz = (w-1)/2;

        // Loop over every pixel of the lefthand image (except the buffer space)
        for(int i = sz; i < ht - sz; i++)
        {
            int rowInd = i * wd;
            for(int j = buf; j < wd - sz; j++)
            {
                if(mask[rowInd + j])
                {
                    // Pull out the window we want to look at in the lefthand image
                    yuv* lwin = getWindow(sk_l, i, j, w);

                    // Initialize minimum error to the maximum positive int value
                    int minErr = 1 << 20;
                    int minK = 0;
    
                    // Find the corresponding pixel in the righthand image by minimizing
                    // the error of its window and that in the lefthand image
                    // (Only look at pixels in the same row as the lefthand one, and 
                    // finish when we reach the lefthand pixel)
                    for(int k = sz; k <= j; k++)
                    {
                        int err = ssdError(lwin, sk_r, i, k, w);
                        if(err < minErr + 25)
                        {
                            minErr = err;
                            minK = k;
                        }
                    }

                    //printf("MinErr: %d k: %d j:%d\n", minErr, minK, j);
                    result[rowInd + j] = j - minK;
                }    
                else
                {
                    result[rowInd + j] = 0;
                }
            }

            printf("Finished row %d\n", i);
        }
        return result;
    }
      
    // Computes the sum of squared differences error between two yuv windows
    // of size w
    int ssdError(yuv *yuv1, Sketch<yuv> sk, int r, int c, int w)
    {
        int err = 0;
        int sz = (w - 1)/2;
        int width = sk.width;
        for(int i = 0; i < w; i++)
        {
            int rowInd = i * w;
            int skRowInd = (r + i - sz) * width;
            for(int j = 0; j < w; j++)
            {
                yuv cur1 = yuv1[rowInd + j];
                yuv cur2 = sk[skRowInd + c + j - sz];
                int yErr = cur1.y - cur2.y;
                yErr *= yErr;
                int uErr = cur1.u - cur2.u;
                uErr *= uErr;
                int vErr = cur1.v - cur2.v;
                vErr *= vErr;
                err += yErr + uErr + vErr;
            }
        }
        return err;
    }

    // Grabs a window of size w around the pixel in row r and column c in sk
    yuv* getWindow(Sketch<yuv> sk, int r, int c, int w)
    {
        yuv* result = new yuv[w * w];
        int sz = (w - 1)/2;    
        int width = sk.width;

        for(int i = 0; i < w; i++)
        {
            int resRowInd = i*w;
            int skRowInd = (r + i - sz) * width;
            for(int j = 0; j < w; j++)
            {
                result[resRowInd + j] = sk[skRowInd + c + j - sz];
            }
        }
        return result;
    }
};

// Top-level behavior: takes a left picture, side-steps right, takes a right
// picture, and then computes a disparity map from the pair (AnalyzePictures).
// It derives from VisualRoutinesStateNode, matching the note at the end of
// this file about keeping sketches from being erased.
class StereoVisionBehavior : public VisualRoutinesStateNode {
    public:
    StereoVisionBehavior() : VisualRoutinesStateNode("StereoVisionBehavior") {}

    // Builds the state chain using Tekkotsu's #statemachine shorthand:
    //   StandUp -(completion)-> LookAhead -(timer 1000)-> TakeLeftPicture
    //   -(timer 500)-> WalkRight -(completion)-> LookAhead -(timer 1000)->
    //   TakeRightPicture -(timer 500)-> AnalyzePictures
    // Timer values are presumably milliseconds. The picture nodes are left via
    // timers rather than =C=> (completion) transitions.
    virtual void setup() {
        #statemachine

        startnode: StandUp() =C=> LookAhead() =T(1000)=> TakeLeftPicture()
        =T(500)=> WalkRight() =C=> LookAhead() =T(1000)=> TakeRightPicture()
        =T(500)=> AnalyzePictures()

        #endstatemachine
        // Alternative (disabled) sequence that rotates the body instead of
        // side-stepping, and turns the head to keep the scene in view.
        /*
        startnode: TwistLeft() =C=> LookRight() =T(1000)=> TakeLeftPicture() 
        =T(500)=> TwistRight() =C=> LookLeft() =T(1000)=> TakeRightPicture() 
        =T(500)=> TwistLeft() =C=> AnalyzePictures()
        */

    }
};

// Notes:
// - Displacement angles and joint angles seem to be wrong for the XWalk and HeadPointer
//   nodes; we had to experiment to find values that actually work.
// - The picture nodes don't post completion events, so they are left via timer
//   transitions instead of =C=> transitions.
// - The master state machine needs to be a VisualRoutinesStateNode so that sketches
//   are not erased.
// - The timer transitions before taking each picture keep the pictures from being
//   blurred by head motion.
// - The robot should twist left after taking the right picture to keep the right
//   middle leg from overheating.

