/**
***************************************************************************
* @file dlrNumeric/solveCubic.h
*
* Header file declaring functions for solving cubic polynomial
* equations of a single variable.
*
* Copyright (C) 2001-2009 David LaRose, dlr@cs.cmu.edu
* See accompanying file, LICENSE.TXT, for details.
*
* $Revision: $
* $Date: $
***************************************************************************
**/

#ifndef DLR_NUMERIC_SOLVECUBIC_H
#define DLR_NUMERIC_SOLVECUBIC_H

#include <complex>

namespace dlr {

  namespace numeric {

    /** 
     * This function computes the real roots of the monic cubic
     * polynomial x^3 + c0*x^2 + c1*x + c2 = 0.
     *
     * @param c0 This argument is the quadratic coefficient of the
     * polynomial.
     * 
     * @param c1 This argument is the linear coefficient of the
     * polynomial.
     * 
     * @param c2 This argument is the constant coefficient of the
     * polynomial.
     * 
     * @param root0 This reference argument is used to return the
     * first real root of the polynomial.  It is always set.
     * 
     * @param root1 If the polynomial has three real roots, this
     * reference argument is used to return the second root.
     *
     * @param root2 If the polynomial has three real roots, this
     * reference argument is used to return the third root.
     *
     * @return If the polynomial has three real roots, the return
     * value is true.  If the polynomial has only one real root, the
     * return value is false, and arguments root1 and root2 are not
     * changed.  The roots are not returned in any particular order.
     * Because the three-real-roots test is a floating point
     * comparison, polynomials whose roots are nearly repeated may be
     * classified either way.
     */
    template <class Type>
    bool
    solveCubic(Type c0, Type c1, Type c2,
               Type& root0, Type& root1, Type& root2);

  
    /** 
     * This function computes the (possibly complex) roots of the
     * monic cubic polynomial x^3 + c0*x^2 + c1*x + c2 = 0.
     * 
     * @param c0 This argument is the quadratic coefficient of the
     * polynomial.
     * 
     * @param c1 This argument is the linear coefficient of the
     * polynomial.
     * 
     * @param c2 This argument is the constant coefficient of the
     * polynomial.
     * 
     * @param root0 This reference argument is used to return the
     * first root of the polynomial.  This root is always real (its
     * imaginary part is zero).
     * 
     * @param root1 This reference argument is used to return the
     * second root of the polynomial.
     *
     * @param root2 This reference argument is used to return the
     * third root of the polynomial.  When the polynomial has complex
     * roots, root1 and root2 are complex conjugates of each other.
     * The roots are not returned in any particular order.
     */
    template <class Type>
    void
    solveCubic(Type c0, Type c1, Type c2,
               std::complex<Type>& root0, std::complex<Type>& root1,
               std::complex<Type>& root2);

  
  } // namespace numeric

} // namespace dlr


/* ======== Inline and template definitions below. ======== */

#include <cmath>
#include <dlrCommon/constants.h>

namespace dlr {

  namespace numeric {

    // This function computes the real roots of the cubic polynomial
    // x^3 + c0*x^2 + c1*x + c2 = 0.  It returns true and fills all
    // three roots when every root is real (including the case of an
    // exactly repeated root); otherwise it returns false, sets only
    // root0, and leaves root1 and root2 untouched.
    template <class Type>
    bool
    solveCubic(Type c0, Type c1, Type c2,
               Type& root0, Type& root1, Type& root2)
    {
      // We follow the formulation in Press et al, "Numerical Recipes,
      // The Art of Scientific Computing," third edition, Cambridge
      // University Press, 2007.
      Type c0Squared = c0 * c0;
      Type qq = ((c0Squared - (Type(3.0) * c1)) / Type(9.0));
      Type rr = ((Type(2.0) * c0Squared * c0
                  - Type(9.0) * c0 * c1
                  + Type(27.0) * c2)
                 / Type(54.0));

      Type rrSquared = rr * rr;
      Type qqCubed = qq * qq * qq;
      Type c0OverThree = c0 / Type(3.0);

      if(rrSquared < qqCubed) {
        // Three distinct real roots.  Since rr^2 >= 0, this branch
        // implies qq > 0, so both square roots below are defined.
        //
        // Mathematically |rr / sqrt(qq^3)| < 1 here, but rounding can
        // push the computed ratio slightly past +/-1, where std::acos
        // would return NaN.  Clamp it back into acos's domain.
        Type cosTheta = rr / std::sqrt(qqCubed);
        if(cosTheta > Type(1.0)) {
          cosTheta = Type(1.0);
        } else if(cosTheta < Type(-1.0)) {
          cosTheta = Type(-1.0);
        }
        Type thetaOverThree = std::acos(cosTheta) / Type(3.0);

        // Numerical Recipes gives the roots as
        //   -2*sqrt(qq)*cos((theta + k*2*pi)/3) - c0/3,  k = 0, +1, -1.
        // With cos(a +/- 2*pi/3) = -cos(a)/2 -/+ (sqrt(3)/2)*sin(a)
        // these become the expressions below (same root order), which
        // need only one cosine and one sine and no explicit pi.
        Type rootQq = std::sqrt(qq);
        Type cosTerm = rootQq * std::cos(thetaOverThree);
        Type sinTerm =
          Type(std::sqrt(3.0)) * rootQq * std::sin(thetaOverThree);

        root0 = Type(-2.0) * cosTerm - c0OverThree;
        root1 = (cosTerm + sinTerm) - c0OverThree;
        root2 = (cosTerm - sinTerm) - c0OverThree;
        return true;
      }

      // One real root plus a complex-conjugate pair.
      bool signRr = rr > Type(0.0);
      Type absRr = signRr ? rr : -rr;
      Type aa = Type(std::pow(absRr + std::sqrt(rrSquared - qqCubed),
                              1.0 / 3.0));
      if(signRr) {
        aa = -aa;
      }
      Type bb = (aa == Type(0.0)) ? Type(0.0) : (qq / aa);

      root0 = (aa + bb) - c0OverThree;

      // The other two roots are
      //   -(aa + bb)/2 - c0/3 +/- i*(sqrt(3)/2)*(aa - bb).
      // When aa == bb exactly (e.g. x^3 - 3x + 2, or x^3) the imaginary
      // part vanishes: they form a repeated real root, so all three
      // roots are real and we report them.
      if(aa == bb) {
        root1 = Type(-0.5) * (aa + bb) - c0OverThree;
        root2 = root1;
        return true;
      }
      return false;
    }



    // This function computes the (possibly complex) roots of the
    // cubic polynomial x^3 + c0*x^2 + c1*x + c2 = 0.  root0 always
    // has zero imaginary part; when complex roots exist, root1 and
    // root2 are a conjugate pair.
    template <class Type>
    void
    solveCubic(Type c0, Type c1, Type c2,
               std::complex<Type>& root0, std::complex<Type>& root1,
               std::complex<Type>& root2)
    {
      // We follow the formulation in Press et al, "Numerical Recipes,
      // The Art of Scientific Computing," third edition, Cambridge
      // University Press, 2007.
      Type c0Squared = c0 * c0;
      Type qq = ((c0Squared - (Type(3.0) * c1)) / Type(9.0));
      Type rr = ((Type(2.0) * c0Squared * c0
                  - Type(9.0) * c0 * c1
                  + Type(27.0) * c2)
                 / Type(54.0));

      Type rrSquared = rr * rr;
      Type qqCubed = qq * qq * qq;
      Type c0OverThree = c0 / Type(3.0);

      // Note: standard std::complex::real()/imag() return by value and
      // cannot be assigned to, so whole complex values are assigned.
      if(rrSquared < qqCubed) {
        // Three distinct real roots; see the real-valued overload above
        // for the acos clamp and the trigonometric identity used here.
        Type cosTheta = rr / std::sqrt(qqCubed);
        if(cosTheta > Type(1.0)) {
          cosTheta = Type(1.0);
        } else if(cosTheta < Type(-1.0)) {
          cosTheta = Type(-1.0);
        }
        Type thetaOverThree = std::acos(cosTheta) / Type(3.0);

        Type rootQq = std::sqrt(qq);
        Type cosTerm = rootQq * std::cos(thetaOverThree);
        Type sinTerm =
          Type(std::sqrt(3.0)) * rootQq * std::sin(thetaOverThree);

        root0 = std::complex<Type>(Type(-2.0) * cosTerm - c0OverThree,
                                   Type(0.0));
        root1 = std::complex<Type>((cosTerm + sinTerm) - c0OverThree,
                                   Type(0.0));
        root2 = std::complex<Type>((cosTerm - sinTerm) - c0OverThree,
                                   Type(0.0));
        return;
      }

      // One real root plus a complex-conjugate pair.
      bool signRr = rr > Type(0.0);
      Type absRr = signRr ? rr : -rr;
      Type aa = Type(std::pow(absRr + std::sqrt(rrSquared - qqCubed),
                              1.0 / 3.0));
      if(signRr) {
        aa = -aa;
      }
      Type bb = (aa == Type(0.0)) ? Type(0.0) : (qq / aa);

      Type pairReal = Type(-0.5) * (aa + bb) - c0OverThree;
      Type pairImag = (Type(std::sqrt(3.0)) / Type(2.0)) * (aa - bb);

      root0 = std::complex<Type>((aa + bb) - c0OverThree, Type(0.0));
      root1 = std::complex<Type>(pairReal, pairImag);
      root2 = std::complex<Type>(pairReal, -pairImag);
    }

  } // namespace numeric

} // namespace dlr

#endif /* #ifndef DLR_NUMERIC_SOLVECUBIC_H */
