/** @file StereoVision.h.fsm 
 * 
 *  @brief Defines a depth mapping algorithm and behavior for the Chiara.
 *  To use the behavior defined in this file, include StereoVisionBehavior as
 *  a menu item.
 *
 *  @author Ian Lenz (ilenz)
 *  @author David Klionsky (dklionsk)
 *  @author Chandrasekhar Bhagavatula (cbhagava)
 */

#include "Behaviors/StateMachine.h"
#include "DualCoding/DualCoding.h"
#include <vector>

// Set the colors to be included in the mask here.
// All other colors will be ignored by the depth mapper.
// A non-zero value enables the color: AnalyzePictures::DoStart tests each
// macro with if(...) and, when enabled, adds visops::colormask(segLeft, "<color>")
// to the mask.
// NOTE(review): these are unprefixed global macros defined in a header;
// confirm they cannot clash with other identifiers named BLUE/GREEN/PINK/ORANGE.
#define BLUE 0
#define GREEN 1
#define PINK 0
#define ORANGE 0

// Minimum size for a connected component in the mask.
// Anything smaller is considered noise and ignored. 
// (Passed as the threshold argument to visops::minArea in AnalyzePictures.)
#define NOISE_THRESH 50

using namespace DualCoding;

/**
 * @class TakeLeftPicture
 * @brief Captures the left-position images and keeps them alive for later nodes.
 *
 * On start, creates two camera-space sketches, "leftPic" (raw YUV) and
 * "segLeft" (color-segmented), and marks both as retained so that
 * AnalyzePictures can later fetch them by name with GET_SKETCH.
 */
class TakeLeftPicture : public VisualRoutinesStateNode
{
public:
    TakeLeftPicture()
        : VisualRoutinesStateNode("TakeLeftPicture")
    {
    }

    void DoStart()
    {
        // Grab both views of the current camera frame.
        NEW_SKETCH(leftPic, yuv, sketchFromYUV());
        NEW_SKETCH(segLeft, uchar, sketchFromSeg());

        // Keep them around until AnalyzePictures releases them.
        segLeft->retain();
        leftPic->retain();
    }
};

/**
 * @class TakeRightPicture
 * @brief Captures the right-position image and keeps it alive for later nodes.
 *
 * On start, creates the camera-space YUV sketch "rightPic" and marks it as
 * retained so that AnalyzePictures can later fetch it by name with GET_SKETCH.
 */
class TakeRightPicture : public VisualRoutinesStateNode
{
public:
    TakeRightPicture()
        : VisualRoutinesStateNode("TakeRightPicture")
    {
    }

    void DoStart()
    {
        // Only the raw image is needed here; the mask comes from the left view.
        NEW_SKETCH(rightPic, yuv, sketchFromYUV());
        rightPic->retain();
    }
};

/**
 * @class LookAhead
 * @brief Points the head straight ahead.
 *
 * On start, commands the node's head-pointer motion command with
 * setJoints(0,0,0).
 */
class LookAhead : public HeadPointerNode
{
public:
    LookAhead()
        : HeadPointerNode("LookAhead")
    {
    }

    void DoStart()
    {
        getMC()->setJoints(0, 0, 0);
    }
};

/**
 * @class LegsUp
 * @brief Rests the robot on its belly.
 *
 * On start, writes the configured offset (60 by default) into element 3 of
 * the motion command's groundPlane.
 */
class LegsUp : public XWalkNode {
    public:
    /**
     * @param groundOffset Value assigned to groundPlane[3] when the node starts.
     *
     * explicit: prevents a bare float from silently converting into a node.
     */
    explicit LegsUp(float const groundOffset=60) : XWalkNode("LegsUp"), 
        _groundOffset(groundOffset) {}
    virtual void DoStart() {
        // NOTE(review): XWalkNode::DoStart() is not called here; confirm this is intended.
        getMC()->groundPlane[3] = _groundOffset;
    }
    /** Value written to groundPlane[3] in DoStart(). */
    float _groundOffset;
};

/**
 * @class LegsDown
 * @brief Raises the robot back on its legs after a call to LegsUp.
 *
 * On start, writes the configured offset (-10 by default) into element 3 of
 * the motion command's groundPlane.
 */
class LegsDown : public XWalkNode {
    public:
    /**
     * @param groundOffset Value assigned to groundPlane[3] when the node starts.
     *
     * explicit: prevents a bare float from silently converting into a node.
     */
    explicit LegsDown(float const groundOffset=-10) : XWalkNode("LegsDown"), 
        _groundOffset(groundOffset) {}
    virtual void DoStart() {
        // NOTE(review): XWalkNode::DoStart() is not called here; confirm this is intended.
        getMC()->groundPlane[3] = _groundOffset;
    }
    /** Value written to groundPlane[3] in DoStart(). */
    float _groundOffset;
};

/**
 * @class AnalyzePictures
 * @brief Calculates disparities and generates a depth map.
 * TakeLeftPicture and TakeRightPicture are assumed to have run before this
 * class, since they create the sketches used by this class to calculate
 * disparities.
 */
class AnalyzePictures : public VisualRoutinesStateNode {
    public:
    AnalyzePictures() : VisualRoutinesStateNode("AnalyzePictures") {}
  
    void DoStart() {

        // retrieve sketches made by TakeLeftPicture and TakeRightPicture
        GET_SKETCH(leftPic, yuv, camSkS);
        GET_SKETCH(rightPic, yuv, camSkS);
        GET_SKETCH(segLeft, uchar, camSkS);
        
        // use only user-specified colors to make the mask
        std::vector< Sketch<uchar> > segmentedPics;
        if(GREEN) {
            NEW_SKETCH(greenStuff, bool, visops::colormask(segLeft, "green"));
            segmentedPics.push_back(greenStuff);
        }
        if(BLUE) {
            NEW_SKETCH(blueStuff, bool, visops::colormask(segLeft, "blue"));
            segmentedPics.push_back(blueStuff);
        }
        if(PINK) {
            NEW_SKETCH(pinkStuff, bool, visops::colormask(segLeft, "pink"));
            segmentedPics.push_back(pinkStuff);
        }
        if(ORANGE) {
            NEW_SKETCH(orangeStuff, bool, visops::colormask(segLeft, "orange"));
            segmentedPics.push_back(orangeStuff);
        }
        
        Sketch<bool> allStuff(camSkS, "allStuff");
				allStuff = visops::zeros(allStuff);
        allStuff->setViewable();

        // make the mask by combining the segmented color images
        for(int i = 0; i < (int)segmentedPics.size(); i++) {
            allStuff |= visops::minArea(segmentedPics[i], NOISE_THRESH);
        }

        // generate a depth map and a scaled depth map
        NEW_SKETCH(result, uint, getDisparities(leftPic,rightPic,allStuff,11,30));
        NEW_SKETCH(scaledResult, uint, scaleDisparityMap(result));

        // allow garbage collector to erase these sketches
        leftPic->retain(false);
        segLeft->retain(false);
        rightPic->retain(false);

        printf("Done.\n");
    }

    /** 
     * @brief Scales a disparity map to span the full color range.
     * 
     * @param dispMap YUV sketch of disparities.
     * @return The input sketch with the range of color increased.
     */
    Sketch<uint> scaleDisparityMap(Sketch<uint> dispMap)
    {
        uint max = dispMap->max();
        float coeff = 511.0/max;
        Sketch<uint> result (dispMap * coeff, "scaled", false);
        return result;
    }

    /** 
     * @brief Generates the disparity map between the two input sketches.
     * Algorithm finds a matching pixel in the right image for each pixel in the
     * left image masked by the mask sketch, and uses the disparity between the 
     * two pixel positions to color the output image. 
     *
     * @param sk_l YUV sketch of the left image.
     * @param sk_r YUV sketch of the right image.
     * @param mask Boolean mask indicating which pixels in the left image to consider.
     * @param w Width of the window used to calculate SSD minimum error.
     * @param buf Number of pixels from the left side of the right image to start searching.
     * @return YUV sketch of the disparity for every pixel.
     */
    Sketch<uint> getDisparities(Sketch<yuv> sk_l, Sketch<yuv> sk_r, Sketch<bool> mask, int w, int buf)
    {
        NEW_SKETCH(result, uint, visops::zeros(sk_l));
        int ht = sk_l.height;
        int wd = sk_l.width;
        int sz = (w-1)/2;

        // Loop over every pixel of the lefthand image (except the buffer space)
        for(int i = sz; i < ht - sz; i++)
        {
            int rowInd = i * wd;
            for(int j = buf; j < wd - sz; j++)
            {
                if(mask[rowInd + j])
                {
                    // Pull out the window we want to look at in the lefthand image
                    yuv* lwin = getWindow(sk_l, i, j, w);

                    // Initialize minimum error to the maximum positive int value
                    int minErr = 1 << 20;
                    int minK = 0;
    
                    // Find the corresponding pixel in the righthand image by minimizing
                    // the error of its window and that in the lefthand image
                    // (Only look at pixels in the same row as the lefthand one, and 
                    // finish when we reach the lefthand pixel)
                    for(int k = sz; k <= j; k++)
                    {
                        int err = ssdError(lwin, sk_r, i, k, w);
                        if(err < minErr + 25)
                        {
                            minErr = err;
                            minK = k;
                        }
                    }

                    result[rowInd + j] = j - minK;
                }    
                else
                {
                    result[rowInd + j] = 0;
                }
            }

            if(i % 20 == 0)
                printf("Finished row %d\n", i);
        }
        return result;
    }
      
    /** 
     * @brief Computes the sum of squared differences error between two yuv windows
     *
     * @param yuv1 Pointer to an array of pixels corresponding to a window in the left image
     * @param sk YUV sketch of the right image.
     * @param r Row containing the center pixel in the window
     * @param c Column containing the center pixel in the window
     * @param w Width and height of the window.
     * @return The error between the two windows.
     */
    int ssdError(yuv *yuv1, Sketch<yuv> sk, int r, int c, int w)
    {
        int err = 0;
        int sz = (w - 1)/2;
        int width = sk.width;
        for(int i = 0; i < w; i++)
        {
            int rowInd = i * w;
            int skRowInd = (r + i - sz) * width;
            for(int j = 0; j < w; j++)
            {
                yuv cur1 = yuv1[rowInd + j];
                yuv cur2 = sk[skRowInd + c + j - sz];
                int yErr = cur1.y - cur2.y;
                yErr *= yErr;
                int uErr = cur1.u - cur2.u;
                uErr *= uErr;
                int vErr = cur1.v - cur2.v;
                vErr *= vErr;
                err += yErr + uErr + vErr;
            }
        }
        return err;
    }

    /** 
     * @brief Grabs a window around a pixel.
     *
     * @param sk YUV sketch.
     * @param r Row containing the center pixel in the window
     * @param c Column containing the center pixel in the window
     * @param w Width and height of the window.
     * @return Pointer to the window.
     */
    yuv* getWindow(Sketch<yuv> sk, int r, int c, int w)
    {
        yuv* result = new yuv[w * w];
        int sz = (w - 1)/2;    
        int width = sk.width;

        for(int i = 0; i < w; i++)
        {
            int resRowInd = i*w;
            int skRowInd = (r + i - sz) * width;
            for(int j = 0; j < w; j++)
            {
                result[resRowInd + j] = sk[skRowInd + c + j - sz];
            }
        }
        return result;
    }
};

/**
 * @class StereoVisionBehavior
 * @brief Defines the state machine for the behavior that generates a depth map.
 * The behavior rests the robot on its belly, takes the left picture, walks the 
 * robot a short distance to the right, rests it on its belly again, takes the
 * right picture, and then generates a depth map.
 */
class StereoVisionBehavior : public VisualRoutinesStateNode {
    public:
    StereoVisionBehavior() : VisualRoutinesStateNode("StereoVisionBehavior") {}

    /**
     * @brief Builds the node chain. Most transitions are timer-based (=T(ms)=>).
     * The only completion transition (=C=>) follows the XWalkNode(0,0,10,-300,0,0)
     * move between the two captures.
     * NOTE(review): the meaning of the XWalkNode arguments is not visible in
     * this file; confirm against the XWalkNode constructor.
     */
    virtual void setup() {
        #statemachine
        
        startnode: LegsUp() =T(2000)=> LookAhead() =T(1000)=> TakeLeftPicture()
            =T(500)=> LegsDown() =T(2000)=> XWalkNode(0,0,10,-300,0,0) 
            =C=> LegsUp() =T(2000)=> LookAhead() =T(1000)=> TakeRightPicture()
            =T(500)=> LegsDown() =T(2000)=> AnalyzePictures()

        #endstatemachine
    }
};


