#include "Behaviors/StateMachine.h"
#include "DualCoding/DualCoding.h"
#include <vector>

// Colour-class switches. AnalyzePictures::DoStart() tests each of these with
// if(...): a non-zero value makes it build a colormask of that colour from the
// segmented left image and merge it into the mask of pixels for which
// disparities are computed. Only green is enabled here.
#define BLUE 0
#define GREEN 1
#define PINK 0
#define ORANGE 0

// Second argument passed to visops::minArea() for every colour mask before it
// is OR-ed into the combined mask (presumably the minimum region area, in
// pixels, that survives -- confirm against visops::minArea).
#define NOISE_THRESH 50

using namespace DualCoding;

// State node that snapshots the camera for the "left" view of the stereo pair.
// It creates two named sketches, leftPic (YUV frame) and segLeft (segmented
// frame), and calls retain() on both; AnalyzePictures later fetches them by
// those names with GET_SKETCH and calls retain(false) once it is finished.
class TakeLeftPicture : public VisualRoutinesStateNode {
    public:
    TakeLeftPicture() : VisualRoutinesStateNode("TakeLeftPicture") {}

    void DoStart() {
        // Colour frame: one side of the disparity computation.
        NEW_SKETCH(leftPic, yuv, sketchFromYUV());
        leftPic->retain();

        // Segmented frame: source of the per-colour masks.
        NEW_SKETCH(segLeft, uchar, sketchFromSeg());
        segLeft->retain();
    }
};

// State node that snapshots the camera for the "right" view of the stereo
// pair. Only the YUV frame is captured (no segmented image); it is stored as
// the sketch named rightPic, which AnalyzePictures fetches with GET_SKETCH.
class TakeRightPicture : public VisualRoutinesStateNode {
    public:
    TakeRightPicture() : VisualRoutinesStateNode("TakeRightPicture") {}
 
    void DoStart() { 
        NEW_SKETCH(rightPic, yuv, sketchFromYUV());
        // retain() is undone by AnalyzePictures via retain(false) after the
        // disparity map has been computed.
        rightPic->retain();
    }
};

// Head-pointer node that, when started, sets all three arguments of the head
// motion command's setJoints() to 0 (presumably centring the head so the
// camera faces straight ahead -- confirm joint order/units against
// HeadPointerMC::setJoints).
class LookAhead : public HeadPointerNode {
    public:
    LookAhead() : HeadPointerNode("LookAhead") {}

    void DoStart() {
        getMC()->setJoints(0,0,0);    
    }
};

// Walk node that, when started, writes a stored offset into element 3 of the
// walk motion command's groundPlane. The offset defaults to 60 (presumably
// this raises the body/legs for picture taking -- confirm units and sign
// against the XWalk ground-plane documentation).
class LegsUp : public XWalkNode {
    public:
    // Value copied into groundPlane[3] by DoStart().
    float _groundOffset;

    LegsUp(float const groundOffset = 60)
        : XWalkNode("LegsUp"),
          _groundOffset(groundOffset)
    {}

    virtual void DoStart()
    {
        getMC()->groundPlane[3] = _groundOffset;
    }
};

// Walk node that, when started, writes a stored offset into element 3 of the
// walk motion command's groundPlane. The offset defaults to -10 (presumably
// this lowers the body back to a walking posture -- confirm units and sign
// against the XWalk ground-plane documentation).
class LegsDown : public XWalkNode {
    public:
    // Value copied into groundPlane[3] by DoStart().
    float _groundOffset;

    LegsDown(float const groundOffset = -10)
        : XWalkNode("LegsDown"),
          _groundOffset(groundOffset)
    {}

    virtual void DoStart()
    {
        getMC()->groundPlane[3] = _groundOffset;
    }
};

// Final node of the stereo behavior. It fetches the stored left/right YUV
// frames and the left segmented frame, builds a mask from the enabled colour
// classes, computes a per-pixel disparity map for the masked pixels by window
// matching along image rows, and publishes both the raw map ("result") and a
// rescaled copy ("scaledResult").
class AnalyzePictures : public VisualRoutinesStateNode {
    public:
    AnalyzePictures() : VisualRoutinesStateNode("AnalyzePictures") {}

    // Runs the whole analysis once when the node starts.
    void DoStart() {
        // Sketches created (and retained) by TakeLeftPicture/TakeRightPicture.
        GET_SKETCH(leftPic, yuv, camSkS);
        GET_SKETCH(rightPic, yuv, camSkS);
        GET_SKETCH(segLeft, uchar, camSkS);

        // One colour mask per enabled colour switch (see the #defines at the
        // top of the file).
        std::vector< Sketch<uchar> > segmentedPics;

        if(GREEN) {
            NEW_SKETCH(greenStuff, bool, visops::colormask(segLeft, "green"));
            segmentedPics.push_back(greenStuff);
        }
        if(BLUE) {
            NEW_SKETCH(blueStuff, bool, visops::colormask(segLeft, "blue"));
            segmentedPics.push_back(blueStuff);
        }
        if(PINK) {
            NEW_SKETCH(pinkStuff, bool, visops::colormask(segLeft, "pink"));
            segmentedPics.push_back(pinkStuff);
        }
        if(ORANGE) {
            NEW_SKETCH(orangeStuff, bool, visops::colormask(segLeft, "orange"));
            segmentedPics.push_back(orangeStuff);
        }

        // Union of all colour masks, each filtered through visops::minArea
        // with NOISE_THRESH first.
        Sketch<bool> allStuff(camSkS, "allStuff");
        allStuff = visops::zeros(allStuff);
        allStuff->setViewable();

        for(int i = 0; i < (int)segmentedPics.size(); i++) {
            allStuff |= visops::minArea(segmentedPics[i], NOISE_THRESH);
        }

        // 11x11 matching window, starting 30 columns from the left edge.
        NEW_SKETCH(result, uint, getDisparities(leftPic, rightPic, allStuff, 11, 30));
        NEW_SKETCH(scaledResult, uint, scaleDisparityMap(result));

        // Undo the retain() calls made when the pictures were taken.
        leftPic->retain(false);
        segLeft->retain(false);
        rightPic->retain(false);

        printf("Done.\n");
    }

    // Returns a copy of dispMap, named "scaled", multiplied by 511/max so the
    // largest disparity maps to 511. If the map's maximum is 0 (e.g. the mask
    // was empty) the scale factor is 0 and the result is all zeros, instead
    // of dividing by zero.
    Sketch<uint> scaleDisparityMap(Sketch<uint> dispMap)
    {
        uint maxDisp = dispMap->max();
        float coeff = (maxDisp != 0) ? 511.0/maxDisp : 0.0;
        Sketch<uint> result (dispMap * coeff, "scaled", false);
        return result;
    }

    // Generates the disparity map between the two given sketches.
    //
    // For every masked pixel (row i, column j) of sk_l, a w x w window around
    // it is compared against windows centred on columns sz..j of the same row
    // in sk_r; the stored disparity is j minus the chosen column. Unmasked
    // pixels and the un-scanned border are 0.
    //
    //   sk_l, sk_r  left and right images (indexed as row * width + column)
    //   mask        pixels of sk_l for which a disparity is wanted
    //   w           window side length; the half-width is sz = (w-1)/2, so
    //               w is expected to be odd
    //   buf         first column of sk_l to process; values smaller than sz
    //               are raised to sz so the window never crosses the left edge
    Sketch<uint> getDisparities(Sketch<yuv> sk_l, Sketch<yuv> sk_r, Sketch<bool> mask, int w, int buf)
    {
        NEW_SKETCH(result, uint, visops::zeros(sk_l));
        int ht = sk_l.height;
        int wd = sk_l.width;
        int sz = (w-1)/2;

        // A window centred left of column sz would index pixels belonging to
        // the previous row, so never start closer to the edge than that.
        int firstCol = (buf > sz) ? buf : sz;

        // Loop over every pixel of the lefthand image (except the border)
        for(int i = sz; i < ht - sz; i++)
        {
            int rowInd = i * wd;
            for(int j = firstCol; j < wd - sz; j++)
            {
                if(mask[rowInd + j])
                {
                    // Pull out the window we want to look at in the lefthand
                    // image. The buffer is heap-allocated by getWindow and is
                    // released below once the search is done.
                    yuv* lwin = getWindow(sk_l, i, j, w);

                    // Seed the search with the first candidate (column sz)
                    // rather than a fixed sentinel: the SSD of a window can
                    // exceed any small constant, which would leave minK at an
                    // invalid column. j >= sz here, so this candidate exists.
                    int minErr = ssdError(lwin, sk_r, i, sz, w);
                    int minK = sz;

                    // Scan the remaining candidates in the same row, up to and
                    // including the lefthand pixel's own column.
                    // NOTE(review): the +25 tolerance makes a later column
                    // (smaller disparity) win unless it is worse by 25 or
                    // more, and minErr then follows the accepted value, so it
                    // can drift above the true minimum. Left as-is; confirm
                    // this bias is intended.
                    for(int k = sz + 1; k <= j; k++)
                    {
                        int err = ssdError(lwin, sk_r, i, k, w);
                        if(err < minErr + 25)
                        {
                            minErr = err;
                            minK = k;
                        }
                    }

                    delete[] lwin;

                    //printf("MinErr: %d k: %d j:%d\n", minErr, minK, j);
                    result[rowInd + j] = j - minK;
                }
                else
                {
                    result[rowInd + j] = 0;
                }
            }

            // Progress report every 20 rows.
            if(i % 20 == 0)
                printf("Finished row %d\n", i);
        }
        return result;
    }

    // Computes the sum of squared differences between a stored window and the
    // w x w window of sk centred on row r, column c.
    //
    //   yuv1  w*w pixels in row-major order (as produced by getWindow)
    //   sk    image to read the second window from
    //   r, c  centre of the second window; the caller must keep the whole
    //         window inside the image, no bounds checks are made here
    //   w     window side length
    //
    // Returns the y, u and v squared differences summed over all w*w pixels.
    int ssdError(yuv *yuv1, Sketch<yuv> sk, int r, int c, int w)
    {
        int err = 0;
        int sz = (w - 1)/2;
        int width = sk.width;
        for(int i = 0; i < w; i++)
        {
            int rowInd = i * w;
            int skRowInd = (r + i - sz) * width;
            for(int j = 0; j < w; j++)
            {
                yuv cur1 = yuv1[rowInd + j];
                yuv cur2 = sk[skRowInd + c + j - sz];
                int yErr = cur1.y - cur2.y;
                int uErr = cur1.u - cur2.u;
                int vErr = cur1.v - cur2.v;
                err += yErr * yErr + uErr * uErr + vErr * vErr;
            }
        }
        return err;
    }

    // Copies the w x w window of sk centred on row r, column c into a newly
    // allocated row-major array of w*w pixels.
    //
    // The caller owns the returned array and must release it with delete[].
    // No bounds checks are made; the caller must keep the window inside sk.
    yuv* getWindow(Sketch<yuv> sk, int r, int c, int w)
    {
        yuv* result = new yuv[w * w];
        int sz = (w - 1)/2;
        int width = sk.width;

        for(int i = 0; i < w; i++)
        {
            int resRowInd = i*w;
            int skRowInd = (r + i - sz) * width;
            for(int j = 0; j < w; j++)
            {
                result[resRowInd + j] = sk[skRowInd + c + j - sz];
            }
        }
        return result;
    }
};

// Top-level behavior that chains the nodes defined above into one stereo
// capture-and-analyse sequence:
//   LegsUp -> LookAhead -> TakeLeftPicture -> LegsDown -> XWalkNode(...)
//   -> LegsUp -> LookAhead -> TakeRightPicture -> LegsDown -> AnalyzePictures
// The chain is written in the #statemachine shorthand, which is not plain C++
// and has to be expanded by a preprocessing step before compilation.
// NOTE(review): =T(n)=> is presumably a timeout transition of n milliseconds
// and =C=> a transition on completion of the walk; the XWalkNode arguments
// presumably describe a sideways displacement between the two pictures --
// confirm against the Tekkotsu state machine and XWalkNode documentation.
class StereoVisionBehavior : public VisualRoutinesStateNode {
    public:
    StereoVisionBehavior() : VisualRoutinesStateNode("StereoVisionBehavior") {}

    // Builds the node graph; the body is state-machine shorthand, so no
    // comments are placed inside the #statemachine region.
    virtual void setup() {
        #statemachine
        
        startnode: LegsUp() =T(2000)=> LookAhead() =T(1000)=> TakeLeftPicture()
            =T(500)=> LegsDown() =T(2000)=> XWalkNode(0,0,10,-300,0,0) =C=> LegsUp() =T(2000)=> LookAhead() =T(1000)=> TakeRightPicture()
            =T(500)=> LegsDown() =T(2000)=> AnalyzePictures()

        #endstatemachine
    }
};


