#ifndef FINDAGENTS_H_
#define FINDAGENTS_H_

#include <cmath>
#include "Behaviors/StateMachine.h"
#include "DualCoding/VRmixin.h"

#include "Shared/fmatSpatial.h"

#include "project/ColorHist.h"

// Display names for detected robots, indexed by the integer robot id returned
// by getIdFromCorner()/getIdFromCenter() and used with AgentData::setName().
// Declared `static const char *const` so this header-defined table has
// internal linkage (no multiple-definition link errors if the header is
// included from more than one translation unit) and cannot be modified.
// Callers must check an id against sizeof(agent_names)/sizeof(agent_names[0])
// before indexing.
static const char *const agent_names[] = {
    "blue", //0
    "green",
    "yellow",
    "black",
};

// Which face of a belt-marked robot is visible; FindAgentsBelt chooses it from
// which corner of the square dot pattern is missing in the image.
enum RobotSide { FRONT = 0, BACK = 1, LEFT = 2, RIGHT = 3 };

// Returns exactly +1.0f or -1.0f depending on the sign bit of x.
// copysignf transfers x's sign onto a unit magnitude, so -0.0f maps to -1.0f
// and +0.0f maps to +1.0f.
static float sign(float x) {
    const float unit = 1.0f;
    return copysignf(unit, x);
}

// Returns true when the semimajor axes of a and b differ by less than 2 units
// (pixels for the camera-space ellipses built by FindCirclesManual), i.e. the
// two ellipses are treated as being the same size.
static bool compareCircleSize(const Shape<EllipseData> &a, const Shape<EllipseData> &b) {
    const float size_a = a->getSemimajor();
    const float size_b = b->getSemimajor();
    return fabs(size_a - size_b) < 2;
}

// Largest allowed difference in x between two points accepted by
// compareCircleDist. The caller passes ground-projected points, so this is in
// ground-plane units (presumably mm -- confirm).
static const int DISTANCE_LIMIT = 200;

// Returns true when the x coordinates of a and b differ by less than
// DISTANCE_LIMIT. Only x is compared; y and z are ignored.
// Declared static, like the other helpers in this header, so the definition
// has internal linkage and cannot cause multiple-definition link errors.
static bool compareCircleDist(const Point &a, const Point &b) {
    return fabs(a.coordX() - b.coordX()) < DISTANCE_LIMIT;
}

// Pixel tolerance on the difference in y ("vertical level") between two
// centroids for them to count as horizontally aligned.
static const int VLEVEL_THRESHOLD = 3;
// Angular tolerance, in radians, between the line joining two centroids and
// the image vertical for them to count as vertically aligned.
static const float HLEVEL_THRESHOLD = 0.6f;

// True when the segment joining the centroids of a and b is within
// HLEVEL_THRESHOLD radians of vertical, whichever of the two is on top.
// (The old test, |atan2(dy, dx)| < HLEVEL_THRESHOLD, accepted pairs that are
// roughly *horizontal* with a to the right of b, contradicting the name and
// the use in FindAgentsBelt to find the dot stacked above/below a pair.)
static bool circlesVerticallyAligned(const Shape<EllipseData> &a, const Shape<EllipseData> &b) {
    const float angle = (float)(a->getCentroid() - b->getCentroid()).atanYX();
    // |angle| is pi/2 for a perfectly vertical pair, with a above or below b.
    return fabs(fabs(angle) - (float)(M_PI / 2)) < HLEVEL_THRESHOLD;
}

// True when the centroids of a and b differ in y by less than
// VLEVEL_THRESHOLD pixels. (The old test compared against HLEVEL_THRESHOLD,
// an angle-sized 0.6, i.e. demanded sub-pixel agreement, while
// VLEVEL_THRESHOLD was left unused.)
static bool circlesHorizontallyAligned(const Shape<EllipseData> &a, const Shape<EllipseData> &b) {
    return fabs(a->getCentroid().coordY() - b->getCentroid().coordY()) < VLEVEL_THRESHOLD;
}

$nodeclass FindAgents {
    $provide std::vector<int> neighbor_density;
    $provide std::vector<Point> searchPoints;

    $nodeclass FindCirclesManual(bool snap_circles=true) : doStart {
        // Detects red blobs in the color-segmented camera image and adds one
        // EllipseData per accepted blob to camShS (which is cleared first).
        // Each ellipse is centered on the blob's pixel centroid, with axes
        // taken from the blob's bounding box. Blobs that are too elongated
        // (aspect ratio > 3) or too sparse (bounding-box area > 2x pixel
        // count) are rejected. If snap_circles is true the minor axis is set
        // equal to the major axis. Afterwards neighbor_density receives, for
        // each ellipse in creation order, the number of ellipses (itself
        // included) whose centroids lie within +/-deltaX, +/-deltaY pixels.
        $reference FindAgents::neighbor_density;
        neighbor_density.clear();
        camShS.clear();
        NEW_SKETCH(cam_frame, uchar, sketchFromSeg());
        NEW_SKETCH(red_stuff, bool, visops::colormask(cam_frame, "red"));
        NEW_SKETCH(blobs, usint, visops::labelcc(red_stuff));
        NEW_SKETCH(labeli, bool, visops::zeros(cam_frame));
        NEW_SHAPEVEC(circles, EllipseData, 0);

        // Assumes labelcc numbers components 1..blobs->max() with 0 as
        // background, so the upper bound is inclusive; the old `<` test
        // never examined the highest-numbered blob.
        for(unsigned int i = 1; i <= blobs->max(); ++i) {
            labeli = blobs == i;
            int count = 0;
            int x_sum = 0;
            int y_sum = 0;
            // Scanning starts at the row of the pixel reported by findMinPlus
            // (presumably the first set pixel in raster order).
            int min_index = labeli->findMinPlus();
            int xstart = min_index % labeli.width;
            int ystart = min_index / labeli.width;
            // Bounding box of the blob, grown as its pixels are found.
            int lbound = labeli.width;
            int rbound = xstart;
            int tbound = ystart;
            int bbound = ystart;
            for(int y = ystart; y < labeli.height; ++y) {
                for(int x = 0; x < labeli.width; ++x) {
                    if(labeli(x, y)) {
                        if(x < lbound) {
                            lbound = x;
                        } else if(x > rbound) {
                            rbound = x;
                        }
                        if(y > bbound) {
                            bbound = y;
                        }
                        x_sum += x;
                        y_sum += y;
                        ++count;
                    }
                }
            }
            if(count == 0) {
                continue;
            }

            // Axis lengths are the larger/smaller bounding-box extents.
            float major;
            float minor;
            if(bbound - tbound > rbound - lbound) {
                major = bbound - tbound;
                minor = rbound - lbound;
            } else {
                major = rbound - lbound;
                minor = bbound - tbound;
            }
            ++major;
            ++minor;

            float center_x = x_sum/((float)count);
            float center_y = y_sum/((float)count);
            //aspect ratio filter
            if(major/minor > 3) {
                cout << i << ": aspect ratio " << major/minor << endl;
                continue;
            }
            //area filter: bounding-box area relative to actual pixel count
            if((major*minor)/count > 2) {
                cout << i << ": area " << (major*minor)/count << endl;
                continue;
            }
            if(snap_circles) {
                minor = major;
            }
            NEW_SHAPE(circle, EllipseData, new EllipseData(camShS,
                                                Point(center_x, center_y, 0),
                                                major/2, minor/2));
            circle->setColor("red");
            circles.push_back(circle);
        }

        // Half-size, in pixels, of the window used to count neighbors.
        const int deltaX = 60;
        const int deltaY = 60;

        SHAPEVEC_ITERATE(circles, EllipseData, circle) {
            neighbor_density.push_back(0);
            Point centroid = circle->getCentroid();
            float center_x = centroid.coordX();
            float center_y = centroid.coordY();
            float left = max(center_x - deltaX, 0.0f);
            float right = min(center_x + deltaX, (float)blobs.width - 1);
            float top = max(center_y - deltaY, 0.0f);
            float bottom = min(center_y + deltaY, (float)blobs.height - 1);

            SHAPEVEC_ITERATE(circles, EllipseData, neighbor) {
                Point neighbor_centroid = neighbor->getCentroid();
                float neighbor_x = neighbor_centroid.coordX();
                float neighbor_y = neighbor_centroid.coordY();
                if(neighbor_x >= left && neighbor_x <= right && neighbor_y >= top && neighbor_y <= bottom) {
                    ++neighbor_density.back();
                }
            } END_ITERATE;
        } END_ITERATE;
        postStateCompletion();
    }

    $nodeclass FindAgentsBelt : doStart {
        // Detects robots marked with a belt of dots. Among the camera-space
        // ellipses in camShS (skipping those in crowded regions according to
        // FindAgents::neighbor_density), it looks for triples i, j, k where
        // i and j are horizontally aligned and of similar size, and k is
        // vertically aligned with one of them (called v; the other is h),
        // i.e. three corners of a square of dots. The dots are projected onto
        // the plane above_ground, the robot id is read from the image by
        // getIdFromCorner(), the visible side is inferred from which corner is
        // missing, and an AgentData with estimated center and orientation is
        // added to localShS.
        $reference FindAgents::neighbor_density;

        const float robot_radius = VRmixin::theAgent->getBoundingBoxHalfDims()[1];
        NEW_SHAPEVEC(circles, EllipseData, select_type<EllipseData>(camShS));
        cout << "num circles: " << circles.size() << endl;

        enum RobotSide robot_side = FRONT;

        PlaneEquation above_ground(0, 0, 1, 65); //when on ground
        //PlaneEquation above_ground(0, 0, 1, 65 - 737); //when on table
        const fmat::Transform cam_transform = kine->linkToBase(CameraFrameOffset);
        NEW_SKETCH(yuv_frame, yuv, VRmixin::sketchFromYUV());

        // neighbor_density is produced by FindCirclesManual and is assumed to
        // line up index-for-index with `circles`; bound every access by both
        // sizes instead of trusting that they match.
        for(unsigned int i = 0; i < neighbor_density.size() && i < circles.size(); ++i) {
            cout << "density: " << circles[i]->getId() << " - " << neighbor_density[i] << endl;
        }

        // Robot ids outside [0, num_agent_names) have no entry in agent_names.
        const int num_agent_names = sizeof(agent_names) / sizeof(agent_names[0]);

        const int DENSITY_THRESHOLD = 6;
        for(unsigned int i = 0; i < circles.size(); ++i) {
            // Dots in very crowded regions are treated as clutter.
            if(i < neighbor_density.size() && neighbor_density[i] > DENSITY_THRESHOLD) {
                continue;
            }
            for(unsigned int j = i + 1; j < circles.size(); ++j) {
                if(!circlesHorizontallyAligned(circles[i], circles[j])) {
                    //cout << circles[i]->getId() << " and " << circles[j]->getId() <<
                        //" not horizontally aligned." << endl;
                    continue;
                }
                if(!compareCircleSize(circles[i], circles[j])) {
                    //cout << circles[i]->getId() << " and " << circles[j]->getId() <<
                        //" not similar size." << endl;
                    continue;
                }
                for(unsigned int k = 0; k < circles.size(); ++k) {
                    if(k == i || k == j) {
                        continue;
                    }
                    // v: the member of the i/j pair that k sits above or below;
                    // h: the other member of the pair.
                    unsigned int v, h;
                    if(circlesVerticallyAligned(circles[i], circles[k])) {
                        v = i;
                        h = j;
                    } else if(circlesVerticallyAligned(circles[j], circles[k])) {
                        v = j;
                        h = i;
                    } else {
                        //cout << "neither " << circles[i]->getId() << " nor " << circles[j]->getId() <<
                            //" is vertically aligned with " << circles[k]->getId() << "." << endl;
                        continue;
                    }
                    if(!compareCircleSize(circles[i], circles[k]) &&
                        !compareCircleSize(circles[j], circles[k])) {
                        //cout << circles[i]->getId() << ", " << circles[j]->getId() << " and " << 
                            //circles[k]->getId() << " not similar size." << endl;
                        continue;
                    }
                    float radius = max(circles[i]->getSemimajor(),
                                       max(circles[j]->getSemimajor(),
                                           circles[k]->getSemimajor()));

                    if(fabs(circles[i]->getCentroid().coordX() - circles[j]->getCentroid().coordX()) >= 7*radius) {
                        //cout << circles[i]->getId() << " too far in x from " << circles[j]->getId() << endl;
                        continue;
                    }
                    if(fabs(circles[v]->getCentroid().coordY() - circles[k]->getCentroid().coordY()) >= 7*radius) {
                        //cout << circles[v]->getId() << " too far in y from " << circles[k]->getId() << endl;
                        continue;
                    }

                    Point pi = circles[i]->getCentroid();
                    Point pj = circles[j]->getCentroid();
                    Point pk = circles[k]->getCentroid();

                    // pi and pj do not depend on k: if either cannot be
                    // projected, no other k can succeed, so stop scanning k.
                    if(!pi.projectToGround(cam_transform, above_ground)) {
                        cout << "project to ground failed." << endl;
                        break;
                    }
                    if(!pj.projectToGround(cam_transform, above_ground)) {
                        cout << "project to ground failed." << endl;
                        break;
                    }
                    if(!pk.projectToGround(cam_transform, above_ground)) {
                        cout << "project to ground failed." << endl;
                        continue;
                    }
                    // Image-space (unprojected) centroids of v and h.
                    Point pv = circles[v]->getCentroid();
                    Point ph = circles[h]->getCentroid();

                    Shape<PointData> pishape(localShS, pi);
                    Shape<PointData> pjshape(localShS, pj);
                    localShS.importShape(pishape);
                    localShS.importShape(pjshape);

                    if(!compareCircleDist(pi, pj) || !compareCircleDist(pj, pk)) {
                        cout << circles[i]->getCentroid() << endl;;
                        cout << pi << endl;
                        cout << circles[j]->getCentroid() << endl;;
                        cout << pj << endl;
                        continue;
                    }

                    circles[i]->setLandmark();
                    circles[j]->setLandmark();
                    circles[k]->setLandmark();

                    cout << "v: " << circles[v]->getId() << endl;
                    cout << "h: " << circles[h]->getId() << endl;
                    cout << "k: " << circles[k]->getId() << endl;

                    /*
                    doColorHist(circles[i]->getCentroid(),
                                circles[j]->getCentroid(),
                                circles[k]->getCentroid());
                    printIdColor(circles[i]->getCentroid(), circles[j]->getCentroid());
                    */
                    int robot_id = getIdFromCorner(yuv_frame,
                                                   circles[k]->getCentroid(),
                                                   circles[v]->getCentroid(),
                                                   circles[h]->getCentroid());
                    cout << "robot id: " << robot_id << endl;
                    // Guard the agent_names lookup below against unknown ids.
                    if(robot_id < 0 || robot_id >= num_agent_names) {
                        cout << "robot id out of range; skipping." << endl;
                        continue;
                    }

                    if(circles[k]->getCentroid().coordY() < pv.coordY()) {
                        if(pv.coordX() < ph.coordX()) { //top right missing
                            robot_side = FRONT;
                        } else { //top left missing
                            robot_side = LEFT;
                        }
                    } else {
                        if(pv.coordX() < ph.coordX()) { //bottom right missing
                            robot_side = BACK;
                        } else { //bottom left missing
                            robot_side = RIGHT;
                        }
                    }

                    // Ground-plane midpoint of the i-j dot pair. (Previously
                    // computed as pi + diff/2, which lies beyond pi rather
                    // than between the two dots.)
                    Point diff = (pi - pj);
                    Point mid = (pi + pj) / 2;
                    // Offset the midpoint by robot_radius perpendicular to the
                    // dot pair to estimate the robot's center.
                    fmat::Matrix<3,3,float> rotate = fmat::rotationZ(-sign(diff.coordY())*M_PI/2);
                    fmat::Column<3,float> robot_center = robot_radius*rotate*diff.unitVector().getCoords();
                    Point center(0, 0, 0, egocentric);
                    center.setCoords(robot_center);
                    AngTwoPi orient_to_dots((float)center.atanYX());
                    center += mid;

                    AngTwoPi orient;
                    switch(robot_side) {
                    case FRONT:
                        cout << "saw front: ";
                        orient = (float)orient_to_dots + M_PI;
                        cout << orient_to_dots << ", " << orient << endl;
                        break;
                    case BACK:
                        cout << "saw back" << endl;
                        orient = (float)orient_to_dots;
                        break;
                    case LEFT:
                        cout << "saw left: ";
                        orient = (float)orient_to_dots + M_PI/2;
                        cout << orient_to_dots << ", " << orient << endl;
                        break;
                    case RIGHT:
                        cout << "saw right" << endl;
                        orient = (float)orient_to_dots + 3*M_PI/2;
                        break;
                    }

                    Shape<AgentData> new_agent(localShS, center);
                    new_agent->setOrientation(orient);
                    new_agent->setName(agent_names[robot_id]);
                    localShS.importShape(new_agent);
                }
            }
        }
    }

    $nodeclass FindAgentsHat : doStart {
        // Detects robots marked with a hat pattern. Each camera-space ellipse
        // in camShS is copied and projected onto the plane hat_from_table. A
        // hat is accepted when three projected ellipses i, j, k form a roughly
        // equilateral triangle (pairwise distance ratios >= DIST_RATIO_THRESHOLD),
        // their image areas are similar (ratios >= AREA_RATIO_THRESHOLD), and a
        // fourth similar-area ellipse lies near the midpoint of i-j. The robot
        // id is read from the image with getIdFromCenter(), and an AgentData
        // oriented from the midpoint ellipse toward k is added to localShS.
        PlaneEquation hat_from_table(0, 0, 1, -1320); //when on table
        //const fmat::Transform cam_transform = kine->linkToBase(CameraFrameOffset);
        const fmat::Transform cam_transform = fmat::Transform::aboutY(-0.6981317007977318);
        //TODO check cam_transform of aboutY vs linkToBase
        cout << "cam_transform" << endl;
        cout << cam_transform << endl << endl;
        NEW_SKETCH(cam_frame, yuv, VRmixin::sketchFromYUV());

        NEW_SHAPEVEC(ellipses, EllipseData, select_type<EllipseData>(camShS));
        NEW_SHAPEVEC(hat_ellipses, EllipseData, 0);
        // hat_ellipses[n] is the ground-projected copy of ellipses[n]; the two
        // vectors are indexed in parallel below (areas come from the image-space
        // ellipses, positions from the projected copies).
        SHAPEVEC_ITERATE(ellipses, EllipseData, hat_ellipse) {
            Shape<EllipseData> hat_ellipse_copy(new EllipseData(*hat_ellipse));
            hat_ellipse_copy->projectToGround(cam_transform, hat_from_table);
            cout << hat_ellipse->getId() << ": " << hat_ellipse_copy->getCentroid() << endl;
            hat_ellipses.push_back(hat_ellipse_copy);
        } END_ITERATE;

        const float DIST_RATIO_THRESHOLD = 0.50f;
        const float AREA_RATIO_THRESHOLD = 0.50f;
        const float MID_RATIO_THRESHOLD = 0.20f;
        // Robot ids outside [0, num_agent_names) have no entry in agent_names.
        const int num_agent_names = sizeof(agent_names) / sizeof(agent_names[0]);
        for(unsigned int i = 0; i < hat_ellipses.size(); ++i) {
            for(unsigned int j = i + 1; j < hat_ellipses.size(); ++j) {
                Point i_centroid = hat_ellipses[i]->getCentroid();
                Point j_centroid = hat_ellipses[j]->getCentroid();
                float ij_dist = (i_centroid - j_centroid).xyzNorm();
                float area_i = ellipses[i]->getArea();
                // Previously read ellipses[i] here, which made area_i == area_j
                // and silently disabled the pairwise size check below.
                float area_j = ellipses[j]->getArea();
                float area_max = max(area_i, area_j);
                if(area_max <= 0) {
                    continue; // degenerate ellipses: area ratios would be NaN
                }
                if(area_i/area_max < AREA_RATIO_THRESHOLD || area_j/area_max < AREA_RATIO_THRESHOLD) {
                    cout << ellipses[i]->getId() << ", " << ellipses[j]->getId() <<
                        " passed on size check" << endl;
                    continue;
                }
                for(unsigned int k = 0; k < hat_ellipses.size(); ++k) {
                    if(k == i || k == j) {
                        continue;
                    }
                    Point k_centroid = hat_ellipses[k]->getCentroid();
                    float jk_dist = (k_centroid - j_centroid).xyzNorm();
                    float ik_dist = (k_centroid - i_centroid).xyzNorm();

                    //check distances between all three pairs
                    //TODO maybe also check ratio to robot size
                    float max_dist = max(max(ij_dist, jk_dist), ik_dist);
                    if(max_dist <= 0) {
                        continue; // all three coincide: distance ratios would be NaN
                    }
                    if(ij_dist/max_dist < DIST_RATIO_THRESHOLD ||
                        jk_dist/max_dist < DIST_RATIO_THRESHOLD ||
                        ik_dist/max_dist < DIST_RATIO_THRESHOLD) {
                        cout << ellipses[i]->getId() << ", " << ellipses[j]->getId() << ", " <<
                            ellipses[k]->getId() << " passed on equilateral check" << endl;
                        continue;
                    }
                    //check that areas are similar
                    float area_k = ellipses[k]->getArea();
                    if(area_k/area_max < AREA_RATIO_THRESHOLD) {
                        cout << ellipses[i]->getId() << ", " << ellipses[j]->getId() << ", " <<
                            ellipses[k]->getId() << " passed on size check" << endl;
                        continue;
                    }
                    //check for midpoint circle
                    //TODO stricter midpoint checks
                    // ij_dist > 0 here: the equilateral check above rejects
                    // ij_dist < DIST_RATIO_THRESHOLD * max_dist with max_dist > 0.
                    Point midpoint = (i_centroid + j_centroid)/2;
                    unsigned int mid;
                    for(mid = 0; mid < hat_ellipses.size(); ++mid) {
                        if(mid == i || mid == j || mid == k) {
                            continue;
                        }
                        Point mid_centroid = hat_ellipses[mid]->getCentroid();
                        float mid_dist = (mid_centroid - midpoint).xyzNorm();
                        if(mid_dist/ij_dist < MID_RATIO_THRESHOLD &&
                            ellipses[mid]->getArea()/area_max >= AREA_RATIO_THRESHOLD) {
                            break;
                        }
                    }
                    if(mid == hat_ellipses.size()) {
                        /*
                        cout << ellipses[i]->getId() << ", " << ellipses[j]->getId() << ", " <<
                            ellipses[k]->getId() << " passed on midpoint check" << endl;
                            */
                        continue;
                    }

                    NEW_SHAPE(ipoint_shape, PointData, new PointData(localShS, hat_ellipses[i]->getCentroid()));
                    NEW_SHAPE(jpoint_shape, PointData, new PointData(localShS, hat_ellipses[j]->getCentroid()));
                    NEW_SHAPE(kpoint_shape, PointData, new PointData(localShS, hat_ellipses[k]->getCentroid()));
                    NEW_SHAPE(midpoint_shape, PointData, new PointData(localShS, hat_ellipses[mid]->getCentroid()));

                    localShS.importShape(ipoint_shape);
                    localShS.importShape(jpoint_shape);
                    localShS.importShape(kpoint_shape);
                    localShS.importShape(midpoint_shape);

                    cout << ellipses[i]->getId() << " " << ellipses[j]->getId() << " " <<
                            ellipses[k]->getId() << " " << ellipses[mid]->getId() << endl;
                    //check id in center
                    int robot_id = getIdFromCenter(cam_frame,
                                                   ellipses[i]->getCentroid(),
                                                   ellipses[j]->getCentroid(),
                                                   ellipses[k]->getCentroid());
                    cout << "robot id: " << robot_id << endl;
                    // Guard the agent_names lookup below against unknown ids.
                    if(robot_id < 0 || robot_id >= num_agent_names) {
                        cout << "robot id out of range; skipping." << endl;
                        continue;
                    }

                    Point center(midpoint.coordX(), midpoint.coordY(), 0);
                    Point dir = hat_ellipses[k]->getCentroid() - hat_ellipses[mid]->getCentroid();
                    AngTwoPi orientation = (float)dir.atanYX();

                    Shape<AgentData> new_agent(localShS, center);
                    new_agent->setOrientation(orientation);
                    new_agent->setName(agent_names[robot_id]);
                    localShS.importShape(new_agent);
                }
            }
        }
    }

    $nodeclass FilterAgentOverlap : doStart {
        // Merges the AgentData shapes in localShS into the world map. A local
        // agent duplicates another local agent if both have the same name, or
        // if their centers are within CENTER_DIST and their orientations
        // within ORIENT_DIST radians. An agent whose first duplicate has a
        // lower index is dropped from localShS. Every kept agent replaces any
        // world agent with the same name and is then imported into worldShS.
        const int CENTER_DIST = 100;
        const float ORIENT_DIST = 1.0f; // radians; was declared int (same value)
        NEW_SHAPEVEC(possible_agents, AgentData, select_type<AgentData>(localShS));
        NEW_SHAPEVEC(existing_agents, AgentData, select_type<AgentData>(worldShS));
        vector<ShapeRoot> to_delete;
        for(unsigned int i = 0; i < possible_agents.size(); ++i) {
            bool is_unique = true;
            unsigned int k = 0; // index of the first duplicate, if any
            const std::string &name = possible_agents[i]->getName();
            const Point &centroid_i = possible_agents[i]->getCentroid();
            const float orient_i = (float)possible_agents[i]->getOrientation();
            for(unsigned int j = 0; j < possible_agents.size(); ++j) {
                if(i == j) {
                    continue;
                }
                float center_dist = (centroid_i - possible_agents[j]->getCentroid()).xyzNorm();
                // AngTwoPi wraps the difference into [0, 2pi); fold it into
                // [0, pi] so that e.g. -0.1 rad counts as close, the same as
                // +0.1 rad, and the test is symmetric in i and j.
                AngTwoPi orient_wrapped = orient_i - (float)possible_agents[j]->getOrientation();
                float orient_dist = (float)orient_wrapped;
                if(orient_dist > M_PI) {
                    orient_dist = (float)(2 * M_PI) - orient_dist;
                }
                if((center_dist <= CENTER_DIST && orient_dist <= ORIENT_DIST) ||
                    name == possible_agents[j]->getName()) {
                    is_unique = false;
                    k = j;
                    break;
                }
            }
            if(!is_unique && k < i) {
                to_delete.push_back(possible_agents[i]);
                continue;
            }
            cout << "adding agent... " << name << endl;

            // Replace any world agent with the same name. Deleted handles are
            // also erased from existing_agents so later iterations never
            // dereference a shape that has already been removed from worldShS.
            for(unsigned int a = 0; a < existing_agents.size(); ) {
                if(existing_agents[a]->getName() == name) {
                    worldShS.deleteShape(existing_agents[a]);
                    existing_agents.erase(existing_agents.begin() + a);
                } else {
                    ++a;
                }
            }

            VRmixin::mapBuilder->importLocalShapeToWorld(possible_agents[i]);
        }
        localShS.deleteShapes(to_delete);
    }

    $nodeclass ClearLocal : doStart {
        // Removes every shape from the local shape space (localShS).
        localShS.clear();
    }

    $nodeclass CheckSearchPoints : doStart {
        // Branch node for the search loop: posts success while gaze points
        // remain queued (=S=> loops back) and failure once the queue is
        // empty (=F=> goes to Done).
        $reference FindAgents::searchPoints;
        const bool exhausted = searchPoints.empty();
        if(exhausted)
            postStateFailure();
        else
            postStateSuccess();
    }

    $nodeclass LookOptional : HeadPointerNode : doStart {
        // Aims the head at the first queued search point and dequeues it; the
        // =C=> transition out of this node then continues the search. With
        // nothing queued there is nothing to look at, so complete at once.
        $reference FindAgents::searchPoints;
        if(!searchPoints.empty()) {
            getMC()->lookAtPoint(searchPoints.front());
            searchPoints.erase(searchPoints.begin());
        } else {
            postStateCompletion();
        }
    }

    $nodeclass Done : doStart {
        // Terminal node, entered via check_points =F=> Done once no search
        // points remain; it only signals completion.
        postStateCompletion();
    }

    $setupmachine {
        // Clear localShS once, then repeatedly: aim the head at the next
        // search point (if any), wait 500 ms, detect red circles and
        // belt-marked agents, merge them into the world map, and loop while
        // search points remain.
        // NOTE(review): ClearLocal runs only at s0, so agents (and debug
        // PointData shapes) from earlier iterations stay in localShS and are
        // seen again by FilterAgentOverlap, where an older same-named agent
        // outranks a newer detection. Consider routing =S=> back through s0.
        s0: ClearLocal =N=> loop
        //timer transition to avoid motion blur from gaze points
        loop: LookOptional =C=> StateNode =T(500)=> FindCirclesManual =C=>
              FindAgentsBelt =N=> FilterAgentOverlap =N=> check_points
        check_points: CheckSearchPoints
        // CheckSearchPoints posts success while points remain, failure when done.
        check_points =S=> loop
        check_points =F=> Done
        // Alternative pipeline using hat markers instead of belt dots:
        //s0: FindCirclesManual(false) =C=> FindAgentsHat =N=> FilterAgentOverlap =B(PlayButOffset)=> s0;
    }
}

$nodeclass FindAgentsSub : FindAgents : doStart {
    // Queue two gaze targets, (1000, -1000, 0) then (1000, 1000, 0), so the
    // inherited FindAgents search loop visits both before finishing.
    const Point first_target(1000, -1000, 0);
    const Point second_target(1000, 1000, 0);
    searchPoints.push_back(first_target);
    searchPoints.push_back(second_target);
}

REGISTER_BEHAVIOR(FindAgents);
REGISTER_BEHAVIOR(FindAgentsSub);
#endif
