/**
***************************************************************************
* @file transform3D.cpp
*
* Source file defining Transform3D class.
*
* Copyright (C) 2001-2007 David LaRose, dlr@cs.cmu.edu
* See accompanying file, LICENSE.TXT, for details.
*
* $Revision: 880 $
* $Date: 2007-05-04 00:33:49 -0400 (Fri, 04 May 2007) $
***************************************************************************
**/

#include <dlrNumeric/transform3D.h>

namespace dlr {

  namespace numeric {
    
    // Build a Transform3D from a homogeneous 4x4 matrix.
    //
    // Argument source must have exactly 4 rows and 4 columns;
    // otherwise ValueException is thrown.  Its elements are read
    // through the single-index operator(), indices 0 through 15, and
    // copied into m_00, m_01, ..., m_32 in that order (this assumes
    // single-index access to Array2D is row-major -- confirm against
    // Array2D).  The final element, source(15), is not stored:
    // it is passed to normalize(), which divides every stored element
    // by it and throws ValueException if it is 0.0.
    Transform3D::
    Transform3D(const Array2D<double>& source)
    {
      if((source.rows() != 4) || (source.columns() != 4)) {
        std::ostringstream message;
        // The leading space in " Array2D<double> instance." keeps the
        // column count from running straight into the type name.
        message << "Can't create a Transform3D from a " << source.rows()
                << " x " << source.columns() << " Array2D<double> instance.";
        DLR_THROW(ValueException, "Transform3D::Transform3D()",
                  message.str().c_str());
      }
      m_00 = source(0); m_01 = source(1); m_02 = source(2); m_03 = source(3);
      m_10 = source(4); m_11 = source(5); m_12 = source(6); m_13 = source(7);
      m_20 = source(8); m_21 = source(9); m_22 = source(10); m_23 = source(11);
      m_30 = source(12); m_31 = source(13); m_32 = source(14);
      this->normalize(source(15));
    }


    // Wrap *this in a Transform3DFunctor and return it, so that the
    // transform can be handed to generic algorithms such as
    // std::transform() and applied to whole sequences of points.
    Transform3DFunctor
    Transform3D::
    getFunctor() const
    {
      Transform3DFunctor functor(*this);
      return functor;
    }
    
  
    // This member function returns the inverse of *this, computed
    // from the cofactors (signed 3x3 minors) of the 4x4 matrix
    // representation.
    //
    // The bottom-right matrix element is not stored in the class; it
    // is treated as 1.0 throughout.  That is why several expressions
    // below contain a bare term (for example "m_10 - m_13 * m_30")
    // where a product with that element would otherwise appear.
    //
    // Throws ValueException if the accumulated determinant value is
    // exactly 0.0, and LogicException if the cofactor that becomes
    // the bottom-right element of the result is exactly 0.0.
    Transform3D
    Transform3D::
    invert() const
    {
      // We use the cofactor method for now, since it's easier to code
      // than Gauss-Jordan elimination.  We suspect that it's less
      // efficient, however.
    
      // Notation for determinant values is detRRRCCC, where the
      // Rs indicate the involved rows, from top to bottom, and the Cs
      // indicate the involved columns, from left to right.

      // 2x2 determinants drawn from rows 1 and 2.
      double det1201 = m_10 * m_21 - m_11 * m_20;
      double det1202 = m_10 * m_22 - m_12 * m_20;
      double det1203 = m_10 * m_23 - m_13 * m_20;
      double det1212 = m_11 * m_22 - m_12 * m_21;
      double det1213 = m_11 * m_23 - m_13 * m_21;
      double det1223 = m_12 * m_23 - m_13 * m_22;

      // 2x2 determinants drawn from rows 1 and 3.  Wherever column 3
      // is involved, the implicit bottom-right element (1.0) has been
      // folded into the first term.
      double det1301 = m_10 * m_31 - m_11 * m_30;
      double det1302 = m_10 * m_32 - m_12 * m_30;
      double det1303 = m_10 - m_13 * m_30;
      double det1312 = m_11 * m_32 - m_12 * m_31;
      double det1313 = m_11 - m_13 * m_31;
      double det1323 = m_12 - m_13 * m_32;

      // 2x2 determinants drawn from rows 2 and 3, with the same
      // folding of the implicit 1.0.
      double det2301 = m_20 * m_31 - m_21 * m_30;
      double det2302 = m_20 * m_32 - m_22 * m_30;
      double det2303 = m_20 - m_23 * m_30;
      double det2312 = m_21 * m_32 - m_22 * m_31;
      double det2313 = m_21 - m_23 * m_31;
      double det2323 = m_22 - m_23 * m_32;
    
      // 3x3 determinants over rows 0, 1, 2, each expanded along row 0
      // using the 2x2 determinants above.
      double det012012 = (m_00 * det1212 - m_01 * det1202 + m_02 * det1201);
      double det012013 = (m_00 * det1213 - m_01 * det1203 + m_03 * det1201);
      double det012023 = (m_00 * det1223 - m_02 * det1203 + m_03 * det1202);
      double det012123 = (m_01 * det1223 - m_02 * det1213 + m_03 * det1212);

      // 3x3 determinants over rows 0, 1, 3.
      double det013012 = (m_00 * det1312 - m_01 * det1302 + m_02 * det1301);
      double det013013 = (m_00 * det1313 - m_01 * det1303 + m_03 * det1301);
      double det013023 = (m_00 * det1323 - m_02 * det1303 + m_03 * det1302);
      double det013123 = (m_01 * det1323 - m_02 * det1313 + m_03 * det1312);

      // 3x3 determinants over rows 0, 2, 3.
      double det023012 = (m_00 * det2312 - m_01 * det2302 + m_02 * det2301);
      double det023013 = (m_00 * det2313 - m_01 * det2303 + m_03 * det2301);
      double det023023 = (m_00 * det2323 - m_02 * det2303 + m_03 * det2302);
      double det023123 = (m_01 * det2323 - m_02 * det2313 + m_03 * det2312);
    
      // 3x3 determinants over rows 1, 2, 3, expanded along row 1.
      double det123012 = (m_10 * det2312 - m_11 * det2302 + m_12 * det2301);
      double det123013 = (m_10 * det2313 - m_11 * det2303 + m_13 * det2301);
      double det123023 = (m_10 * det2323 - m_12 * det2303 + m_13 * det2302);
      double det123123 = (m_11 * det2323 - m_12 * det2313 + m_13 * det2312);

      // This sums the cofactor expansion along each of the four rows
      // (two lines per row; the final bare det012012 is the implicit
      // bottom-right 1.0 times its cofactor).  Each row's expansion
      // equals the determinant, so this value is four times the
      // determinant of the matrix, not the determinant itself.
      // NOTE(review): the uniform factor only cancels out of the
      // result if the 16-argument constructor normalizes by its last
      // argument, as setTransform() does -- confirm.
      double det01230123 = (
        m_00 * det123123 - m_01 * det123023
        + m_02 * det123013 - m_03 * det123012
        - m_10 * det023123 + m_11 * det023023
        - m_12 * det023013 + m_13 * det023012
        + m_20 * det013123 - m_21 * det013023
        + m_22 * det013013 - m_23 * det013012
        - m_30 * det012123 + m_31 * det012023
        - m_32 * det012013 + det012012);

      // Note that in general, roundoff error will make us pass these
      // tests, even for singular matrices.
      if(det01230123 == 0.0) {
        DLR_THROW(ValueException, "Transform3D::invert()",
                  "Transform is not invertible.");
      }
      // det012012 is the numerator of the bottom-right element of the
      // result (last argument below), so it must be nonzero.
      if(det012012 == 0.0) {
        DLR_THROW(LogicException, "Transform3D::invert()",
                  "Illegal value for projective scale.");
      }
    
      // Element (i, j) of the result is the signed cofactor of element
      // (j, i) of this matrix (i.e. the transposed cofactor matrix),
      // divided by det01230123.  Arguments are listed row by row.
      return Transform3D(
        det123123 / det01230123, -det023123 / det01230123,
        det013123 / det01230123, -det012123 / det01230123,
        -det123023 / det01230123, det023023 / det01230123,
        -det013023 / det01230123, det012023 / det01230123,
        det123013 / det01230123, -det023013 / det01230123,
        det013013 / det01230123, -det012013 / det01230123,
        -det123012 / det01230123, det023012 / det01230123,
        -det013012 / det01230123, det012012 / det01230123);
    }

  
    // Replace the value of *this with the transform described by the
    // sixteen arguments, which are interpreted as the elements of a
    // 4x4 transformation matrix listed row by row:
    //    [[a00, a01, a02, a03],
    //     [a10, a11, a12, a13],
    //     [a20, a21, a22, a23],
    //     [a30, a31, a32, a33]]
    // The first fifteen arguments are stored, and then normalize() is
    // called with a33, which divides the stored elements by a33 and
    // throws ValueException if a33 is 0.0.
    void
    Transform3D::
    setTransform(double a00, double a01, double a02, double a03,
                 double a10, double a11, double a12, double a13,
                 double a20, double a21, double a22, double a23,
                 double a30, double a31, double a32, double a33)
    {
      // Top row.
      m_00 = a00;
      m_01 = a01;
      m_02 = a02;
      m_03 = a03;

      // Second row.
      m_10 = a10;
      m_11 = a11;
      m_12 = a12;
      m_13 = a13;

      // Third row.
      m_20 = a20;
      m_21 = a21;
      m_22 = a22;
      m_23 = a23;

      // Bottom row; a33 is not stored, it only sets the scale.
      m_30 = a30;
      m_31 = a31;
      m_32 = a32;

      this->normalize(a33);
    }


    // This member function sets one element from the matrix
    // representation of the coordinate transform to the specified
    // value.
    //
    // Arguments row and column select the element, and value is
    // stored there without any renormalization.  Element (3, 3) has
    // no stored member (the const operator()() reports it as 1.0), so
    // it cannot be set: it is rejected in the same way as any index
    // pair outside the 4x4 matrix, by throwing IndexException.
    void
    Transform3D::
    setValue(size_t row, size_t column, double value)
    {
      switch(row) {
      case 0:
        switch(column) {
        case 0: m_00 = value; return;
        case 1: m_01 = value; return;
        case 2: m_02 = value; return;
        case 3: m_03 = value; return;
        default: break;
        }
        break;
      case 1:
        switch(column) {
        case 0: m_10 = value; return;
        case 1: m_11 = value; return;
        case 2: m_12 = value; return;
        case 3: m_13 = value; return;
        default: break;
        }
        break;
      case 2:
        switch(column) {
        case 0: m_20 = value; return;
        case 1: m_21 = value; return;
        case 2: m_22 = value; return;
        case 3: m_23 = value; return;
        default: break;
        }
        break;
      case 3:
        // No case for column 3: the bottom-right element is not
        // stored, so (3, 3) falls through to the exception below.
        switch(column) {
        case 0: m_30 = value; return;
        case 1: m_31 = value; return;
        case 2: m_32 = value; return;
        default: break;
        }
        break;
      default:
        break;
      }

      // Only reached when no element was assigned.  The location
      // string names this function, so that the exception points at
      // the call that actually failed.
      std::ostringstream message;
      message << "Indices (" << row << ", " << column << ") are out of bounds.";
      DLR_THROW(IndexException, "Transform3D::setValue()",
                message.str().c_str());
    }

      
    // Return, by value, the element at (row, column) of the 4x4
    // matrix representation of the coordinate transform.  The
    // bottom-right element, (3, 3), is not stored and is always
    // reported as 1.0.  IndexException is thrown if either index is
    // outside the 4x4 matrix.
    double
    Transform3D::
    operator()(size_t row, size_t column) const
    {
      // Range-check before flattening, so that very large indices
      // can't wrap around and alias a valid element.
      if((row < 4) && (column < 4)) {
        switch(row * 4 + column) {
        case 0: return m_00;
        case 1: return m_01;
        case 2: return m_02;
        case 3: return m_03;
        case 4: return m_10;
        case 5: return m_11;
        case 6: return m_12;
        case 7: return m_13;
        case 8: return m_20;
        case 9: return m_21;
        case 10: return m_22;
        case 11: return m_23;
        case 12: return m_30;
        case 13: return m_31;
        case 14: return m_32;
        case 15: return 1.0;
        default: break;
        }
      }

      std::ostringstream message;
      message << "Indices (" << row << ", " << column << ") are out of bounds.";
      DLR_THROW(IndexException, "Transform3D::operator()(size_t, size_t)",
                message.str().c_str());
      return 0.0; // Dummy return to keep the compiler happy.
    }
  
  
    // Apply the coordinate transform to a point and return the
    // result.  Each of the four matrix rows is multiplied by the
    // column (x, y, z, 1), with the unstored bottom-right matrix
    // element taken to be 1.0, and the four results are handed to
    // the four-argument Vector3D constructor (presumably the fourth
    // is a homogeneous scale -- confirm against Vector3D).
    Vector3D
    Transform3D::
    operator*(const Vector3D& vector0) const
    {
      const double x = vector0.x();
      const double y = vector0.y();
      const double z = vector0.z();

      const double row0 = m_00 * x + m_01 * y + m_02 * z + m_03;
      const double row1 = m_10 * x + m_11 * y + m_12 * z + m_13;
      const double row2 = m_20 * x + m_21 * y + m_22 * z + m_23;
      const double row3 = m_30 * x + m_31 * y + m_32 * z + 1.0;

      return Vector3D(row0, row1, row2, row3);
    }

  
    // Copy assignment: every stored matrix element of source is
    // copied into *this, and a reference to *this is returned.  The
    // bottom-right element is not stored, so there are fifteen
    // values to copy.
    Transform3D&
    Transform3D::
    operator=(const Transform3D& source)
    {
      // Top row.
      this->m_00 = source.m_00;
      this->m_01 = source.m_01;
      this->m_02 = source.m_02;
      this->m_03 = source.m_03;

      // Second row.
      this->m_10 = source.m_10;
      this->m_11 = source.m_11;
      this->m_12 = source.m_12;
      this->m_13 = source.m_13;

      // Third row.
      this->m_20 = source.m_20;
      this->m_21 = source.m_21;
      this->m_22 = source.m_22;
      this->m_23 = source.m_23;

      // Bottom row.
      this->m_30 = source.m_30;
      this->m_31 = source.m_31;
      this->m_32 = source.m_32;

      return *this;
    }

  
    // Divide every stored matrix element by scaleFactor, which is the
    // value the (unstored) bottom-right element would otherwise have,
    // so that the bottom-right element becomes 1.0.  Throws
    // ValueException if scaleFactor is 0.0.  A scaleFactor of exactly
    // 1.0 leaves the elements untouched.
    void
    Transform3D::
    normalize(double scaleFactor)
    {
      if(scaleFactor == 0.0) {
        DLR_THROW(ValueException, "Transform3D::normalize()",
                  "Invalid normalization constant. "
                  "The bottom right element of a homogeneous transformation "
                  "cannot be equal to 0.0.");
      }

      // Nothing to do if the transform is already normalized.
      if(scaleFactor == 1.0) {
        return;
      }

      // Visit the fifteen stored elements in row-major order.
      double* const elementPtrs[] = {
        &m_00, &m_01, &m_02, &m_03,
        &m_10, &m_11, &m_12, &m_13,
        &m_20, &m_21, &m_22, &m_23,
        &m_30, &m_31, &m_32};
      const size_t numberOfElements =
        sizeof(elementPtrs) / sizeof(elementPtrs[0]);
      for(size_t elementIndex = 0; elementIndex < numberOfElements;
          ++elementIndex) {
        *(elementPtrs[elementIndex]) /= scaleFactor;
      }
    }

    /* ============== Non-member functions which should ============== */
    /* ============== probably live in a different file ============== */
  
  
    // Compose two Transform3D instances by multiplying their 4x4
    // matrix representations.  The resulting transform satisfies the
    // equation:
    //   (transform0 * transform1) * v0 = transform0 * (transform1 * v0),
    // where v0 is a Vector3D instance.  All sixteen elements of the
    // product, including the bottom-right one, are passed to the
    // 16-argument Transform3D constructor.
    Transform3D
    operator*(const Transform3D& transform0, const Transform3D& transform1)
    {
      // Elements of the left-hand operand.
      const double p00 = transform0.value<0, 0>();
      const double p01 = transform0.value<0, 1>();
      const double p02 = transform0.value<0, 2>();
      const double p03 = transform0.value<0, 3>();
      const double p10 = transform0.value<1, 0>();
      const double p11 = transform0.value<1, 1>();
      const double p12 = transform0.value<1, 2>();
      const double p13 = transform0.value<1, 3>();
      const double p20 = transform0.value<2, 0>();
      const double p21 = transform0.value<2, 1>();
      const double p22 = transform0.value<2, 2>();
      const double p23 = transform0.value<2, 3>();
      const double p30 = transform0.value<3, 0>();
      const double p31 = transform0.value<3, 1>();
      const double p32 = transform0.value<3, 2>();
      const double p33 = transform0.value<3, 3>();

      // Elements of the right-hand operand.
      const double q00 = transform1.value<0, 0>();
      const double q01 = transform1.value<0, 1>();
      const double q02 = transform1.value<0, 2>();
      const double q03 = transform1.value<0, 3>();
      const double q10 = transform1.value<1, 0>();
      const double q11 = transform1.value<1, 1>();
      const double q12 = transform1.value<1, 2>();
      const double q13 = transform1.value<1, 3>();
      const double q20 = transform1.value<2, 0>();
      const double q21 = transform1.value<2, 1>();
      const double q22 = transform1.value<2, 2>();
      const double q23 = transform1.value<2, 3>();
      const double q30 = transform1.value<3, 0>();
      const double q31 = transform1.value<3, 1>();
      const double q32 = transform1.value<3, 2>();
      const double q33 = transform1.value<3, 3>();

      // Row 0 of the product.
      const double a00 = p00 * q00 + p01 * q10 + p02 * q20 + p03 * q30;
      const double a01 = p00 * q01 + p01 * q11 + p02 * q21 + p03 * q31;
      const double a02 = p00 * q02 + p01 * q12 + p02 * q22 + p03 * q32;
      const double a03 = p00 * q03 + p01 * q13 + p02 * q23 + p03 * q33;

      // Row 1 of the product.
      const double a10 = p10 * q00 + p11 * q10 + p12 * q20 + p13 * q30;
      const double a11 = p10 * q01 + p11 * q11 + p12 * q21 + p13 * q31;
      const double a12 = p10 * q02 + p11 * q12 + p12 * q22 + p13 * q32;
      const double a13 = p10 * q03 + p11 * q13 + p12 * q23 + p13 * q33;

      // Row 2 of the product.
      const double a20 = p20 * q00 + p21 * q10 + p22 * q20 + p23 * q30;
      const double a21 = p20 * q01 + p21 * q11 + p22 * q21 + p23 * q31;
      const double a22 = p20 * q02 + p21 * q12 + p22 * q22 + p23 * q32;
      const double a23 = p20 * q03 + p21 * q13 + p22 * q23 + p23 * q33;

      // Row 3 of the product.
      const double a30 = p30 * q00 + p31 * q10 + p32 * q20 + p33 * q30;
      const double a31 = p30 * q01 + p31 * q11 + p32 * q21 + p33 * q31;
      const double a32 = p30 * q02 + p31 * q12 + p32 * q22 + p33 * q32;
      const double a33 = p30 * q03 + p31 * q13 + p32 * q23 + p33 * q33;

      return Transform3D(a00, a01, a02, a03,
                         a10, a11, a12, a13,
                         a20, a21, a22, a23,
                         a30, a31, a32, a33);
    }


    // Write a text representation of transform0 to stream and return
    // the stream.  The format is "Transform3D(" followed by all
    // sixteen matrix elements in row-major order and a closing ")".
    // Elements are separated by ", ", except that the last element of
    // each of the first three rows is followed by ",\n".
    std::ostream&
    operator<<(std::ostream& stream, const Transform3D& transform0)
    {
      stream << "Transform3D(";

      // Row 0.
      stream << transform0.value<0, 0>() << ", "
             << transform0.value<0, 1>() << ", "
             << transform0.value<0, 2>() << ", "
             << transform0.value<0, 3>() << ",\n";

      // Row 1.
      stream << transform0.value<1, 0>() << ", "
             << transform0.value<1, 1>() << ", "
             << transform0.value<1, 2>() << ", "
             << transform0.value<1, 3>() << ",\n";

      // Row 2.
      stream << transform0.value<2, 0>() << ", "
             << transform0.value<2, 1>() << ", "
             << transform0.value<2, 2>() << ", "
             << transform0.value<2, 3>() << ",\n";

      // Row 3, followed by the closing parenthesis.
      stream << transform0.value<3, 0>() << ", "
             << transform0.value<3, 1>() << ", "
             << transform0.value<3, 2>() << ", "
             << transform0.value<3, 3>() << ")";

      return stream;
    }

    // Read a Transform3D from stream into transform0 and return the
    // stream.  The expected text is "Transform3D(", then sixteen
    // values with a "," after each of the first fifteen, then ")".
    //
    // If stream is already in a failed state it is returned untouched.
    // If reading fails part way through, transform0 is left unchanged
    // and failbit is set on stream, so the caller can detect the
    // failure in the usual way.  The stream's exception mask is
    // restored before returning, and also if an exception other than
    // std::ios_base::failure (for example from setTransform(), which
    // rejects a zero final element) propagates out of this function.
    std::istream&
    operator>>(std::istream& stream, Transform3D& transform0)
    {
      // If stream is in a bad state, we can't read from it.
      if (!stream){
        return stream;
      }
    
      // It's a lot easier to use a try block than to be constantly
      // testing whether the IO has succeeded, so we tell stream to
      // complain if anything goes wrong.
      std::ios_base::iostate oldExceptionState = stream.exceptions();
      bool parseFailed = false;

      // Now on with the show.
      try{
        // Changing the mask throws immediately if one of the newly
        // masked bits is already set (e.g. eofbit on a stream that has
        // been read to its end), so this must happen inside the try
        // block.
        stream.exceptions(
          std::ios_base::badbit | std::ios_base::failbit
          | std::ios_base::eofbit);

        // Construct an InputStream instance so we can use our
        // convenience functions.
        InputStream inputStream(stream);

        // Advance to the next relevant character.
        inputStream.skipWhiteSpace();
      
        // Read the "Transform3D(" part.
        inputStream.expect("Transform3D(");

        // Read all the data except the last element.
        std::vector<double> inputValues(16);
        for(size_t index = 0; index < (inputValues.size() - 1); ++index) {
          // Read the value.
          inputStream >> inputValues[index];

          // Read punctuation before the next value.
          inputStream.expect(",");
        }

        // Read the final value.
        inputStream >> inputValues[inputValues.size() - 1];

        // Read the closing parenthesis.
        inputStream.expect(")");

        // And update the transform.
        transform0.setTransform(inputValues[0], inputValues[1],
                                inputValues[2], inputValues[3], 
                                inputValues[4], inputValues[5], 
                                inputValues[6], inputValues[7], 
                                inputValues[8], inputValues[9], 
                                inputValues[10], inputValues[11], 
                                inputValues[12], inputValues[13], 
                                inputValues[14], inputValues[15]);
      } catch(const std::ios_base::failure&) {
        // The read did not complete, so transform0 was never updated.
        // Remember this so it can be reported through the stream state
        // once the caller's exception mask is back in place.
        parseFailed = true;
      } catch(...) {
        // Not an I/O failure.  Don't leave our exception mask
        // installed on the caller's stream.
        stream.exceptions(oldExceptionState);
        throw;
      }

      // Restoring the mask (and setting failbit below) will throw
      // std::ios_base::failure if the caller had asked for exceptions
      // on a bit that is now set, as with any standard extraction.
      stream.exceptions(oldExceptionState);
      if(parseFailed) {
        // A read cut short by end-of-file may have set only eofbit;
        // make sure the failure is visible to "if(stream >> x)".
        stream.setstate(std::ios_base::failbit);
      }
      return stream;
    }

  } // namespace numeric

} // namespace dlr
