/**
***************************************************************************
* @file dlrComputerVision/threePointAlgorithm.h
*
* Header file declaring the threePointAlgorithm() function template.
*
* Copyright (C) 2009 David LaRose, dlr@cs.cmu.edu
* See accompanying file, LICENSE.TXT, for details.
*
* $Revision: $
* $Date: $
***************************************************************************
*/

#ifndef DLR_COMPUTERVISION_THREEPOINTALGORITHM_H
#define DLR_COMPUTERVISION_THREEPOINTALGORITHM_H

#include <dlrComputerVision/cameraIntrinsicsPinhole.h>
#include <dlrNumeric/transform3D.h>
#include <dlrNumeric/vector2D.h>
#include <dlrNumeric/vector3D.h>
#include <dlrRandom/pseudoRandom.h>

namespace dlr {

  namespace computerVision {

    /**
     * This function implements the "three point perspective pose
     * estimation algorithm" of Grunert[1][2] for recovering the
     * camera-frame coordinates of the corners of a triangle of known
     * size, given the projections of those corners in the camera
     * image.  We follow the derivation in [2].  That is, given w_0,
     * w_1, and w_2, all 3D positions in world coordinates, u_0, u_1,
     * u_2, the projections of those coordinates in the image, and
     * pinhole camera parameters, the algorithm recovers p_0, p_1, and
     * p_2, the positions of those points in camera coordinates.
     *
     * The problem reduces to finding the roots of a quartic
     * polynomial, so as many as four candidate solutions may be
     * reported.  The i-th elements written through the three output
     * iterators belong together and describe the i-th candidate.
     *
     * [1] J. A. Grunert, "Das Pothenotische Problem in Erweiterter
     * Gestalt Nebst Uber Seine Anwendungen in der Geodisie," Grunerts
     * Archiv fur Mathematik und Physik, Band 1, 1841, pp. 238-248.
     *
     * [2] R. M. Haralick, C. Lee, K. Ottenberg, and M. Nolle, "Review
     * and Analysis of Solutions of the Three Point Perspective Pose
     * Estimation Problem," International Journal of Computer Vision,
     * 13, 3, 331-356 (1994).
     *
     * @param w0 This argument is the first of the three 3D points in
     * world coordinates.
     *
     * @param w1 This argument is the second of the three 3D points in
     * world coordinates.
     *
     * @param w2 This argument is the third of the three 3D points in
     * world coordinates.
     *
     * @param u0 This argument is the image location of the projection
     * of world point w0.
     *
     * @param u1 This argument is the image location of the projection
     * of world point w1.
     *
     * @param u2 This argument is the image location of the projection
     * of world point w2.
     *
     * @param intrinsics This argument describes the intrinsic
     * calibration of the camera that generated the image from which
     * u0, u1, u2 were drawn.
     *
     * @param p0OutputIter This argument is used to return estimates
     * of the point in camera coordinates corresponding to w0.  It is
     * an iterator pointing to the beginning of an output sequence,
     * and must be able to accept at least four values.
     *
     * @param p1OutputIter This argument is used to return estimates
     * of the point in camera coordinates corresponding to w1.  It is
     * an iterator pointing to the beginning of an output sequence,
     * and must be able to accept at least four values.
     *
     * @param p2OutputIter This argument is used to return estimates
     * of the point in camera coordinates corresponding to w2.  It is
     * an iterator pointing to the beginning of an output sequence,
     * and must be able to accept at least four values.
     *
     * @param epsilon This argument sets some internal tolerances of
     * the algorithm, and should be left at its default value for now.
     * It is forwarded unchanged to
     * solveThreePointAlgorithmQuarticSystem().
     *
     * @return The return value indicates how many solutions for the
     * camera position and orientation were found (between zero and
     * four).  If returnValue >= 1, then the first elements of the
     * three output sequences correspond to the first solution found.
     * If return value >= 2, then the second elements of the three
     * output sequences correspond to the second solution found, and
     * so on.
     */
    template <class IterType>
    unsigned int
    threePointAlgorithm(dlr::numeric::Vector3D const& w0,
                        dlr::numeric::Vector3D const& w1,
                        dlr::numeric::Vector3D const& w2,
                        dlr::numeric::Vector2D const& u0,
                        dlr::numeric::Vector2D const& u1,
                        dlr::numeric::Vector2D const& u2,
                        CameraIntrinsicsPinhole const& intrinsics,
                        IterType p0OutputIter,
                        IterType p1OutputIter,
                        IterType p2OutputIter,
                        double epsilon = 1.0E-8);

    /**
     * This function implements the "robust" version of
     * threePointAlgorithm().  Multiple solutions for camera pose are
     * computed using randomly selected sets of three input points, an
     * error value is computed (based on all of the input points) for
     * each of the potential solutions, and the solution having the
     * best error value is retained.
     *
     * @param worldPointsBegin This argument is the beginning (in the
     * STL sense) of a sequence of 3D points expressed in world
     * coordinates, and represented as dlr::numeric::Vector3D
     * instances.
     *
     * @param worldPointsEnd This argument is the end (in the STL
     * sense) of the sequence begun by worldPointsBegin.  The sequence
     * must contain at least four points, otherwise a ValueException
     * is thrown.
     *
     * @param imagePointsBegin This argument is the beginning (in the
     * STL sense) of a sequence of 2D points corresponding to the
     * elements of [worldPointsBegin, worldPointsEnd), and expressed in
     * image coordinates.
     *
     * @param intrinsics This argument describes the intrinsic
     * calibration of the camera that generated the input points.
     *
     * @param iterations This argument specifies how many random
     * samples of three input points should be processed to generate
     * solution hypotheses.
     *
     * @param inlierProportion This argument specifies what proportion
     * of the input points are expected to be "inliers" and conform to
     * the correct solution (once we find it).  It is used to tune the
     * error value computation, and should be in the range [0.0, 1.0].
     *
     * @param score This output argument is set to a projection
     * residual indicating the goodness of the final solution.
     *
     * @param pRandom This argument is a pseudorandom number generator
     * used by the algorithm to select sets of three input points.
     *
     * @return The return value is a coordinate transformation that
     * takes points in world coordinates and converts them to camera
     * coordinates.
     */
    template<class InIter3D, class InIter2D>
    dlr::numeric::Transform3D
    threePointAlgorithmRobust(InIter3D worldPointsBegin,
                              InIter3D worldPointsEnd,
                              InIter2D imagePointsBegin,
                              CameraIntrinsicsPinhole const& intrinsics,
                              size_t iterations,
                              double inlierProportion,
                              double& score,
                              dlr::random::PseudoRandom& pRandom);


    /**
     * This overload behaves exactly like the eight-argument version of
     * threePointAlgorithmRobust(), but draws its random samples from
     * a default-constructed dlr::random::PseudoRandom instance that
     * lives only for the duration of the call.  It takes the place of
     * a default value for the pRandom argument: a temporary generator
     * cannot legally be bound to a non-const reference parameter.
     * See the eight-argument version for a description of the
     * arguments and return value.
     */
    template<class InIter3D, class InIter2D>
    dlr::numeric::Transform3D
    threePointAlgorithmRobust(InIter3D worldPointsBegin,
                              InIter3D worldPointsEnd,
                              InIter2D imagePointsBegin,
                              CameraIntrinsicsPinhole const& intrinsics,
                              size_t iterations,
                              double inlierProportion,
                              double& score)
    {
      dlr::random::PseudoRandom pRandom;
      return threePointAlgorithmRobust(
        worldPointsBegin, worldPointsEnd, imagePointsBegin, intrinsics,
        iterations, inlierProportion, score, pRandom);
    }


    /**
     * This function is the numerical core of threePointAlgorithm().
     * For one ordering of the three points it builds the quartic
     * polynomial of the three point algorithm, finds its roots, and
     * converts each usable real root into a set of distances from
     * the camera focus to the three points.
     *
     * In the descriptions below, "point 0", "point 1", and "point 2"
     * refer to the ordering implied by the arguments of this
     * function, which threePointAlgorithm() permutes between calls.
     *
     * @param cosAlpha This argument is the cosine of the angle
     * between the viewing rays toward points 1 and 2.
     *
     * @param cosBeta This argument is the cosine of the angle between
     * the viewing rays toward points 0 and 2.
     *
     * @param cosGamma This argument is the cosine of the angle
     * between the viewing rays toward points 0 and 1.
     *
     * @param a2 This argument is the squared distance between points
     * 1 and 2.
     *
     * @param b2 This argument is the squared distance between points
     * 0 and 2.  It is used as a divisor.
     *
     * @param c2 This argument is the squared distance between points
     * 0 and 1.
     *
     * @param epsilon This argument is a tolerance.  A root of the
     * quartic is treated as real if the magnitude of its imaginary
     * part does not exceed epsilon, and a real root is discarded if
     * the magnitude of the denominator used to recover the second
     * scale factor is less than epsilon.
     *
     * @param s0Iter This argument is an iterator to the beginning of
     * a sequence that receives, for each solution, the distance from
     * the camera focus to point 0.  The sequence must be able to
     * hold at least four values.
     *
     * @param s1Iter This argument is like s0Iter, but receives the
     * distance to point 1.
     *
     * @param s2Iter This argument is like s0Iter, but receives the
     * distance to point 2.
     *
     * @param condition This output argument is set to the smallest
     * denominator magnitude encountered among the real roots, or to
     * 0.0 if no solution was produced.  threePointAlgorithm() prefers
     * the point ordering that gives the largest value.
     *
     * @return The return value is the number of solutions written to
     * each of the three output sequences (between zero and four).
     */
    template <class OutIter>
    unsigned int
    solveThreePointAlgorithmQuarticSystem(
      double cosAlpha, double cosBeta, double cosGamma,
      double a2, double b2, double c2, double epsilon,
      OutIter s0Iter, OutIter s1Iter, OutIter s2Iter,
      double& condition);
    
  } // namespace computerVision
    
} // namespace dlr


/* ============ Definitions of inline & template functions ============ */


#include <algorithm>
#include <cmath>
#include <complex>
#include <limits>
#include <vector>
#include <dlrComputerVision/registerPoints3D.h>
#include <dlrNumeric/solveQuartic.h>
#include <dlrNumeric/utilities.h>

namespace dlr {

  namespace computerVision {

    // Grunert's three point perspective pose estimation algorithm,
    // following the derivation of Haralick et al. (reference [2] in
    // the documentation of the declaration above).  Given three
    // points with known world coordinates and known image
    // projections, it reports the candidate positions of those
    // points in camera coordinates.
    template <class IterType>
    unsigned int
    threePointAlgorithm(dlr::numeric::Vector3D const& w0,
                        dlr::numeric::Vector3D const& w1,
                        dlr::numeric::Vector3D const& w2,
                        dlr::numeric::Vector2D const& u0,
                        dlr::numeric::Vector2D const& u1,
                        dlr::numeric::Vector2D const& u2,
                        CameraIntrinsicsPinhole const& intrinsics,
                        IterType p0OutputIter,
                        IterType p1OutputIter,
                        IterType p2OutputIter,
                        double epsilon)
    {
      // Direction vectors, in the camera frame, of the rays through
      // the three image points.  These play the role of the unit
      // vectors j0, j1, j2 of the paper.
      dlr::numeric::Vector3D ray0 =
        intrinsics.reverseProject(u0).getDirectionVector();
      dlr::numeric::Vector3D ray1 =
        intrinsics.reverseProject(u1).getDirectionVector();
      dlr::numeric::Vector3D ray2 =
        intrinsics.reverseProject(u2).getDirectionVector();

      // alpha is the angle between rays 1 and 2, beta the angle
      // between rays 0 and 2, and gamma the angle between rays 0 and
      // 1.  Only the cosines are needed.
      double cosAlpha = dlr::numeric::dot(ray1, ray2);
      double cosBeta = dlr::numeric::dot(ray0, ray2);
      double cosGamma = dlr::numeric::dot(ray0, ray1);

      // a, b, and c are the lengths of the triangle sides subtending
      // alpha, beta, and gamma.  Only the squared lengths are needed,
      // which avoids three sqrt() calls.
      double a2 = dlr::numeric::magnitudeSquared(w1 - w2);
      double b2 = dlr::numeric::magnitudeSquared(w0 - w2);
      double c2 = dlr::numeric::magnitudeSquared(w0 - w1);

      // With s0, s1, s2 the distances from the camera focus to the
      // three points, the camera-frame positions are
      //
      //   p0 = s0 * j0,   p1 = s1 * j1,   p2 = s2 * j2
      //
      // and the law of cosines gives
      //
      //   s1^2 + s2^2 - 2*s1*s2*cos(alpha) = a^2
      //   s0^2 + s2^2 - 2*s0*s2*cos(beta)  = b^2
      //   s0^2 + s1^2 - 2*s0*s1*cos(gamma) = c^2
      //
      // Writing s1 = k1*s0 and s2 = k2*s0 and rearranging:
      //
      //   s0^2 = a^2 / (k1^2 + k2^2 - 2*k1*k2*cos(alpha))
      //   s0^2 = b^2 / (1 + k2^2 - 2*k2*cos(beta))
      //   s0^2 = c^2 / (1 + k1^2 - 2*k1*cos(gamma))
      //
      // Eliminating k1 from these yields an expression for k1 in
      // terms of k2 and a quartic equation in k2 (the original
      // comment cites these as equations 8 and 9 of the paper).
      // Both are implemented in
      // solveThreePointAlgorithmQuarticSystem().
      //
      // The system is solved three times, each time with the roles
      // of the three points permuted.  Index [h][n] below selects
      // solution n of hypothesis set h; the output pointers are
      // permuted to match the inputs, so that rangeK always holds
      // distances to point K.
      double range0[3][4];
      double range1[3][4];
      double range2[3][4];
      double conditions[3];
      unsigned int counts[3];

      counts[0] = solveThreePointAlgorithmQuarticSystem(
        cosAlpha, cosBeta, cosGamma, a2, b2, c2, epsilon,
        &(range0[0][0]), &(range1[0][0]), &(range2[0][0]),
        conditions[0]);
      counts[1] = solveThreePointAlgorithmQuarticSystem(
        cosBeta, cosAlpha, cosGamma, b2, a2, c2, epsilon,
        &(range1[1][0]), &(range0[1][0]), &(range2[1][0]),
        conditions[1]);
      counts[2] = solveThreePointAlgorithmQuarticSystem(
        cosBeta, cosGamma, cosAlpha, b2, c2, a2, epsilon,
        &(range1[2][0]), &(range2[2][0]), &(range0[2][0]),
        conditions[2]);

      // Keep the hypothesis set with the largest condition value.
      // The strict comparison means that ties go to the earliest set.
      unsigned int best = 0;
      for(unsigned int candidate = 1; candidate < 3; ++candidate) {
        if(conditions[candidate] > conditions[best]) {
          best = candidate;
        }
      }

      // Scale each ray by the recovered distance to get the point
      // position in camera coordinates.
      unsigned int const numberOfSolutions = counts[best];
      for(unsigned int ii = 0; ii < numberOfSolutions; ++ii) {
        *p0OutputIter = range0[best][ii] * ray0;
        ++p0OutputIter;

        *p1OutputIter = range1[best][ii] * ray1;
        ++p1OutputIter;

        *p2OutputIter = range2[best][ii] * ray2;
        ++p2OutputIter;
      }
      return numberOfSolutions;
    }


    // Robust version of threePointAlgorithm().  Repeatedly picks
    // three input points at random, runs threePointAlgorithm() on
    // them, and scores each resulting pose hypothesis against all of
    // the input points.  The hypothesis with the lowest score is
    // returned.
    //
    // Throws ValueException if fewer than four points are supplied,
    // or if inlierProportion selects a residual outside the list of
    // residuals (which cannot happen for values in [0.0, 1.0]).
    template<class InIter3D, class InIter2D>
    dlr::numeric::Transform3D
    threePointAlgorithmRobust(InIter3D worldPointsBegin,
                              InIter3D worldPointsEnd,
                              InIter2D imagePointsBegin,
                              CameraIntrinsicsPinhole const& intrinsics,
                              size_t iterations,
                              double inlierProportion,
                              double& score,
                              dlr::random::PseudoRandom& pRandom)
    {
      // State variables so we'll remember the best camTworld
      // transform once we find it.  If no hypothesis is ever
      // generated, the default-constructed transform is returned and
      // score is set to the largest representable double.
      double bestErrorSoFar = std::numeric_limits<double>::max();
      dlr::numeric::Transform3D selectedCandidate;

      // Sanity check arguments.
      size_t numberOfPoints = worldPointsEnd - worldPointsBegin;
      if(numberOfPoints < 4) {
        DLR_THROW(ValueException, "threePointAlgorithmRobust()",
                  "Input sequence must have at least four elements.");
      }

      // The error statistic is the residual found at this (rounded)
      // position in the sorted list of residuals.  The position
      // depends only on the arguments, so compute and validate it
      // once.  The test is written so that NaN is rejected too, and
      // so that every input that previously produced a valid index
      // still produces the same index.
      double const scaledIndex =
        inlierProportion * (numberOfPoints - 1) + 0.5;
      if(!((scaledIndex > -1.0)
           && (scaledIndex < static_cast<double>(numberOfPoints)))) {
        DLR_THROW(ValueException, "threePointAlgorithmRobust()",
                  "Argument inlierProportion must be in the range "
                  "[0.0, 1.0].");
      }
      size_t const testIndex =
        (scaledIndex <= 0.0) ? 0 : static_cast<size_t>(scaledIndex);

      // Copy input points into local buffers, so that we can shuffle
      // them without disturbing the caller's data.
      std::vector<dlr::numeric::Vector3D> worldPoints(numberOfPoints);
      std::vector<dlr::numeric::Vector2D> imagePoints(numberOfPoints);
      std::copy(worldPointsBegin, worldPointsEnd, worldPoints.begin());
      std::copy(imagePointsBegin, imagePointsBegin + numberOfPoints,
                imagePoints.begin());

      // Make a buffer to hold points in camera space (and from which
      // to compute residual errors), and a buffer for the residuals
      // themselves.  Both are completely overwritten for every
      // hypothesis, so they can be shared across iterations.
      std::vector<dlr::numeric::Vector3D> cameraPoints(numberOfPoints);
      std::vector<double> residualVector(numberOfPoints);

      // Start the algorithm!
      for(size_t ii = 0; ii < iterations; ++ii) {

        // Select three points by swapping randomly chosen elements
        // into the first three positions of the local buffers.
        // NOTE(review): this presumes uniformInt(jj, numberOfPoints)
        // never returns numberOfPoints itself -- confirm against the
        // PseudoRandom documentation.
        for(size_t jj = 0; jj < 3; ++jj) {
          int selectedIndex = pRandom.uniformInt(jj, numberOfPoints);
          if(selectedIndex != static_cast<int>(jj)) {
            std::swap(worldPoints[jj], worldPoints[selectedIndex]);
            std::swap(imagePoints[jj], imagePoints[selectedIndex]);
          }
        }

        // Get candidate cameraPoints.
        dlr::numeric::Vector3D testPoints0_cam[4];
        dlr::numeric::Vector3D testPoints1_cam[4];
        dlr::numeric::Vector3D testPoints2_cam[4];
        unsigned int numberOfSolutions = threePointAlgorithm(
          worldPoints[0], worldPoints[1], worldPoints[2],
          imagePoints[0], imagePoints[1], imagePoints[2], intrinsics,
          testPoints0_cam, testPoints1_cam, testPoints2_cam);

        // Test each candidate solution.
        for(size_t jj = 0; jj < numberOfSolutions; ++jj) {

          // Recover the camTworld transform corresponding to this
          // solution by registering the three selected world points
          // to their hypothesized camera-frame positions.
          cameraPoints[0] = testPoints0_cam[jj];
          cameraPoints[1] = testPoints1_cam[jj];
          cameraPoints[2] = testPoints2_cam[jj];
          dlr::numeric::Transform3D camTworld = registerPoints3D(
            worldPoints.begin(), worldPoints.begin() + 3,
            cameraPoints.begin());

          // Transform all world points into camera coordinates.
          std::transform(worldPoints.begin(), worldPoints.end(),
                         cameraPoints.begin(), camTworld.getFunctor());

          // Project all camera points into image coordinates and
          // compute residuals (squared image-plane distances).
          for(size_t kk = 0; kk < cameraPoints.size(); ++kk) {
            dlr::numeric::Vector2D testPoint_image =
              intrinsics.project(cameraPoints[kk]);
            residualVector[kk] = dlr::numeric::magnitudeSquared(
              testPoint_image - imagePoints[kk]);
          }

          // Compute robust error statistic.  Only one order
          // statistic is needed, so a partial ordering is enough:
          // nth_element leaves at position testIndex the value a full
          // sort would put there, in linear average time.
          std::nth_element(residualVector.begin(),
                           residualVector.begin() + testIndex,
                           residualVector.end());
          double errorValue = residualVector[testIndex];

          // Remember candidate if it's the best so far.
          if(errorValue < bestErrorSoFar) {
            selectedCandidate = camTworld;
            bestErrorSoFar = errorValue;
          }
        }
      }
      score = bestErrorSoFar;
      return selectedCandidate;
    }


    // Builds and solves the quartic polynomial of the three point
    // algorithm for one ordering of the three points, and converts
    // each usable real root into distances from the camera focus to
    // the three points.
    //
    // cosAlpha, cosBeta, and cosGamma are the cosines of the angles
    // between the viewing rays to points (1,2), (0,2), and (0,1);
    // a2, b2, and c2 are the squared distances between those same
    // pairs of points.  Up to four values are written through each
    // of s0Iter, s1Iter, and s2Iter, which are only ever written to
    // and incremented, so write-only output iterators are
    // acceptable.  On return, condition holds the smallest
    // denominator magnitude seen among the real roots, or 0.0 if no
    // solution was produced.  The return value is the number of
    // solutions written.
    template <class OutIter>
    unsigned int
    solveThreePointAlgorithmQuarticSystem(
      double cosAlpha, double cosBeta, double cosGamma,
      double a2, double b2, double c2,
      double epsilon,
      OutIter s0Iter, OutIter s1Iter, OutIter s2Iter,
      double& condition)
    {
      unsigned int numberOfSolutions = 0;
      condition = std::numeric_limits<double>::max();

      // Sub-expressions shared by the polynomial coefficients.  Note
      // that every ratio has b2 as its divisor.
      double cos2Alpha = cosAlpha * cosAlpha;
      double cos2Beta = cosBeta * cosBeta;
      double cos2Gamma = cosGamma * cosGamma;
      double a2OverB2 = a2 / b2;
      double a2MinusC2OverB2 = (a2 - c2) / b2;
      double a2MinusC2OverB2Sq = a2MinusC2OverB2 * a2MinusC2OverB2;
      double a2MinusC2OverB2Minus1 = a2MinusC2OverB2 - 1.0;
      double a2PlusC2OverB2 = (a2 + c2) / b2;
      double b2MinusA2OverB2 = (b2 - a2) / b2;
      double b2MinusC2OverB2 = (b2 - c2) / b2;
      double c2OverB2 = c2 / b2;
      double oneMinusA2PlusC2OverB2 = (1.0 - a2PlusC2OverB2);
      double oneMinusA2MinusC2OverB2 = (1.0 - a2MinusC2OverB2);
      double onePlusA2MinusC2OverB2 = (1.0 + a2MinusC2OverB2);

      // Coefficients of the quartic
      //
      //   A4*k2^4 + A3*k2^3 + A2*k2^2 + A1*k2 + A0 = 0.
      double A0 = (onePlusA2MinusC2OverB2 * onePlusA2MinusC2OverB2
                   - 4.0 * a2OverB2 * cos2Gamma);

      double A1 = 4.0 * ((-a2MinusC2OverB2 * onePlusA2MinusC2OverB2 * cosBeta)
                         + (2.0 * a2OverB2 * cos2Gamma * cosBeta)
                         - (oneMinusA2PlusC2OverB2 * cosAlpha * cosGamma));

      double A2 = 2.0 * (a2MinusC2OverB2Sq - 1.0
                         + 2.0 * a2MinusC2OverB2Sq * cos2Beta
                         + 2.0 * b2MinusC2OverB2 * cos2Alpha
                         - 4.0 * a2PlusC2OverB2 * cosAlpha * cosBeta * cosGamma
                         + 2.0 * b2MinusA2OverB2 * cos2Gamma);

      double A3 = 4.0 * (a2MinusC2OverB2 * oneMinusA2MinusC2OverB2 * cosBeta
                         - oneMinusA2PlusC2OverB2 * cosAlpha * cosGamma
                         + 2.0 * c2OverB2 * cos2Alpha * cosBeta);

      double A4 = (a2MinusC2OverB2Minus1 * a2MinusC2OverB2Minus1
                   - 4.0 * c2OverB2 * cos2Alpha);

      // Now we solve for the roots of the quartic, which tell us
      // valid values of k2.  Each root corresponds to a scale factor
      // that's consistent with the observed data.  The coefficients
      // are normalized by A4 before being handed to the solver.
      std::complex<double> k2Roots[4];
      dlr::numeric::solveQuartic(
        A3 / A4, A2 / A4, A1 / A4, A0 / A4,
        k2Roots[0], k2Roots[1], k2Roots[2], k2Roots[3]);

      // For each real value of k2, there's a corresponding value of
      // k1 (the original comment cites this as Eq. 8 of the paper).
      for(unsigned int ii = 0; ii < 4; ++ii) {
        bool isReal = std::fabs(k2Roots[ii].imag()) <= epsilon;
        if(isReal) {
          double k2 = k2Roots[ii].real();
          double numerator = (a2MinusC2OverB2Minus1 * k2 * k2
                              - 2.0 * a2MinusC2OverB2 * cosBeta * k2
                              + onePlusA2MinusC2OverB2);
          double denominator = 2.0 * (cosGamma - k2 * cosAlpha);

          // Track the worst (smallest) denominator over all real
          // roots, including any that are rejected just below.
          double magDenominator = std::fabs(denominator);
          if(magDenominator < condition) {
            condition = magDenominator;
          }
          if(magDenominator < epsilon) {
            continue;
          }
          double k1 = numerator / denominator;

          // Now that we have k1 and k2, recover the distance from the
          // focus to each of the observed points.
          double s0Sq = c2 / (1.0 + k1 * k1 - 2.0 * k1 * cosGamma);
          if(s0Sq < 0.0) {
            continue;
          }

          // Hold s0 in a local variable rather than reading it back
          // through s0Iter: dereferencing an output iterator to read
          // a value is not valid for write-only iterators.
          double s0 = std::sqrt(s0Sq);
          *s0Iter = s0;
          *s1Iter = k1 * s0;
          *s2Iter = k2 * s0;

          ++s0Iter;
          ++s1Iter;
          ++s2Iter;
          ++numberOfSolutions;
        }
      }

      // With no solutions, report the worst possible conditioning so
      // that the caller prefers any other ordering of the points.
      if(numberOfSolutions == 0) {
        condition = 0.0;
      }
      return numberOfSolutions;
    }

  } // namespace computerVision
    
} // namespace dlr

#endif /* #ifndef DLR_COMPUTERVISION_THREEPOINTALGORITHM_H */
