/*@@@**************************************************************************
* \file  pyrDPStereo
* \author Hernan Badino
* \date  Mon Sep 28 10:48:18 EDT 2009
* \notes 
*******************************************************************************
*****          (C) COPYRIGHT Hernan Badino - All Rights Reserved          *****
******************************************************************************/

/* INCLUDES */
#include "uintParam.h"
#include "pyrDPStereo.h"
#include "parameterSet.h"
#include "paramBaseConnector.h"
#include "int2DParam.h"
#include "floatParam.h"
#include "intParam.h"
#include "boolParam.h"
#include "paramGroup.h"
#include "paramGroupEnd.h"
#include "uint2DParam.h"

using namespace VIC;

//const static  int CPyramidalDynProgStereo::m_maxLevels_si = 6;

/// Construct the pyramidal dynamic programming stereo object.
///
/// The left, right and disparity pyramids are built for the given image size
/// and number of levels. The per-level DSI and dynamic programming arrays
/// start out NULL and are allocated by setPyramidParams().
/// Defaults: disparity range [-1, 20], aggregation kernel 3x3.
///
/// \param f_imageWidth_ui  Input image width [px].
/// \param f_imageHeight_ui Input image height [px].
/// \param f_levels_ui      Requested number of pyramid levels.
CPyramidalDynProgStereo::CPyramidalDynProgStereo( unsigned int f_imageWidth_ui,
                                                  unsigned int f_imageHeight_ui,
                                                  unsigned int f_levels_ui )
        : m_pyrLeft    ( f_imageWidth_ui, f_imageHeight_ui, f_levels_ui ),
          m_pyrRight   ( f_imageWidth_ui, f_imageHeight_ui, f_levels_ui ),
          m_pyrDisp    ( f_imageWidth_ui, f_imageHeight_ui, f_levels_ui, true ),
          m_dsi_p      ( NULL ),
          m_dynProg_p  ( NULL ),
          m_levels_ui  ( f_levels_ui ),
          m_dispRange  ( -1, 20 ),
          m_kernelSize ( 3, 3 )
{
    /// The per-level objects depend on the final number of levels, so they
    /// are created here instead of in the initializer list.
    setPyramidParams ( f_imageWidth_ui, f_imageHeight_ui, f_levels_ui );
}

/// Release the per-level DSI and dynamic programming arrays.
/// delete[] on a NULL pointer is a no-op, so no explicit checks are needed.
/// NOTE(review): the class owns these raw arrays; unless the header disables
/// copying, a copied instance would delete them twice — confirm.
CPyramidalDynProgStereo::~CPyramidalDynProgStereo( )
{
    delete [] m_dsi_p;
    delete [] m_dynProg_p;
}

bool
CPyramidalDynProgStereo::compute ( const CFloatImage &f_leftImg,
                                   const CFloatImage &f_rightImg )
{
    int level_ui = m_levels_ui-1;
    
    /// 1.- Compute Gaussian Pyramids.
    m_pyrLeft.compute  ( f_leftImg );
    m_pyrRight.compute ( f_rightImg );
    
    /// 2.- Compute first result.
    /// 2a.- DSI computation.
    if ( m_kernelSize.width == m_kernelSize.height && 
         m_kernelSize.width == 1 )
        m_dsi_p[level_ui].compute ( * m_pyrLeft.getLevelImage (level_ui),
                                    * m_pyrRight.getLevelImage(level_ui) );
    else
        m_dsi_p[level_ui].computeZSSD ( * m_pyrLeft.getLevelImage (level_ui),
                                        * m_pyrRight.getLevelImage(level_ui),
                                        m_kernelSize.width,
                                        m_kernelSize.height,
                                        true );
    /// 2b.- Dynamic Programming
    computeDynProg ( level_ui );

    for (--level_ui; level_ui >= 0; --level_ui)
    {
        /// 3. - Transfer result from previous level to this level.
        transferLevel ( level_ui + 1);
        
        /// 4.- Compute next level.
        /// 4a.- DSI computation.
        if ( m_kernelSize.width == m_kernelSize.height && 
             m_kernelSize.width == 1 )
            m_dsi_p[level_ui].compute ( * m_pyrLeft.getLevelImage (level_ui),
                                        * m_pyrRight.getLevelImage(level_ui) );
        else
            m_dsi_p[level_ui].computeZSSD ( * m_pyrLeft.getLevelImage (level_ui),
                                            * m_pyrRight.getLevelImage(level_ui),
                                            m_kernelSize.width,
                                            m_kernelSize.height,
                                            true );
        /// 4b.- Dynamic Programming
        computeDynProg ( level_ui );
    }
    return true;
}

/// Upsample one source disparity scanline into one destination scanline of
/// twice the resolution, multiplying every disparity by 2.
/// Destination pixels 2j and 2j+1 both receive 2*src[j]. If the destination
/// width is odd, its last pixel is taken from source column
/// min(dstWidth/2, srcWidth-1), so the source is never read out of range.
static void
upsampleDisparityRow ( const float *f_src_p,
                       int          f_srcWidth_i,
                       float       *f_dst_p,
                       int          f_dstWidth_i )
{
    const int half_i = f_dstWidth_i / 2;

    for (int j = 0; j < half_i; ++j)
    {
        const float disp_f = f_src_p[j] * 2;

        f_dst_p[2*j]   = disp_f;
        f_dst_p[2*j+1] = disp_f;
    }

    /// Odd destination width: one trailing pixel is left.
    if ( 2 * half_i < f_dstWidth_i )
        f_dst_p[2*half_i] = f_src_p[std::min(half_i, f_srcWidth_i - 1)] * 2;
}

/// Transfer the disparity result of pyramid level f_level_ui to the next
/// finer level (f_level_ui - 1) by pixel replication, doubling disparities.
/// The result serves as prediction for computeDynProg() on the finer level.
///
/// \param f_level_ui Source (coarser) level; must be > 0.
/// \return false if the level is invalid, an image is missing or the source
///         is smaller than half the destination; true otherwise.
bool
CPyramidalDynProgStereo::transferLevel ( int f_level_ui )
{
    if ( f_level_ui <= 0 ) return false;

    CFloatImage * srcImg_p = m_pyrDisp.getLevelImage ( f_level_ui );
    CFloatImage * dstImg_p = m_pyrDisp.getLevelImage ( f_level_ui - 1);

    if ( !srcImg_p || !dstImg_p ) return false;

    const int srcW_i = (int)srcImg_p -> getWidth();
    const int srcH_i = (int)srcImg_p -> getHeight();
    const int dstW_i = (int)dstImg_p -> getWidth();
    const int dstH_i = (int)dstImg_p -> getHeight();

    /// Number of complete destination row pairs.
    const int h_i = dstH_i / 2;

    /// The source must cover at least half of the destination in each
    /// dimension, otherwise the loops below would read past it.
    if ( srcW_i <= 0 || srcH_i <= 0 ||
         srcW_i < dstW_i / 2 || srcH_i < h_i )
        return false;

#if defined ( _OPENMP )
    const int numThreads_i = std::min(omp_get_max_threads(), 32);
#pragma omp parallel for num_threads(numThreads_i) schedule(static)
#endif
    for (int i = 0; i < h_i; ++i)
    {
        const float *src_p  = srcImg_p -> getScanline ( i );
        float       *dst1_p = dstImg_p -> getScanline ( i*2   );
        float       *dst2_p = dstImg_p -> getScanline ( i*2+1 );

        upsampleDisparityRow ( src_p, srcW_i, dst1_p, dstW_i );
        upsampleDisparityRow ( src_p, srcW_i, dst2_p, dstW_i );
    }

    /// Odd destination height: the last destination row has no partner row.
    /// Fill it from the matching source row, or from the last source row if
    /// the source is shorter.
    if ( 2 * h_i < dstH_i )
    {
        upsampleDisparityRow ( srcImg_p -> getScanline ( std::min(h_i, srcH_i - 1) ),
                               srcW_i,
                               dstImg_p -> getScanline ( 2 * h_i ),
                               dstW_i );
    }

    return true;
}


/// Run the dynamic programming optimisation on every row of one pyramid
/// level and store the resulting disparities in the disparity pyramid.
///
/// At the coarsest level (m_levels_ui - 1) the DP runs without prediction.
/// At finer levels, this level's disparity image is expected to hold the
/// prediction written by transferLevel(). It is converted into DSI column
/// indices, clamped to the valid range, and passed to the DP.
///
/// Rows within half the kernel height of the top/bottom border are not
/// processed and keep their previous content.
///
/// \param f_level_i Pyramid level to process.
/// \return Always true.
bool
CPyramidalDynProgStereo::computeDynProg ( int f_level_i )
{
    CTestFloatDSI_t::SDispSpaceImage dsi = m_dsi_p[f_level_i].getDisparitySpaceImage( );
    S2D<int> halfMaskSize ( m_kernelSize.width/2, m_kernelSize.height/2 );
    CFloatImage *   dispImg_p = m_pyrDisp.getLevelImage ( f_level_i );

    const int h_i = m_pyrLeft.getLevelImage (f_level_i) -> getHeight();

    /// resVec/predVec are sized by the DSI width; never index past it.
    const int w_i = std::min( (int) m_pyrLeft.getLevelImage (f_level_i) -> getWidth(),
                              (int) dsi.width_ui );

    const bool usePrediction_b = ( f_level_i != (int)(m_levels_ui - 1) );

    /// Largest valid DSI column index (column 0 corresponds to minDisp_i).
    const int maxIdx_i = dsi.maxDisp_i - dsi.minDisp_i;

    std::vector<int> resVec  ( dsi.width_ui, 0 );
    /// Allocated once and reused for all rows; entries >= w_i stay 0.
    std::vector<int> predVec ( usePrediction_b ? dsi.width_ui : 0, 0 );

    for (int i = halfMaskSize.y; i < h_i-halfMaskSize.y; ++i)
    {
        /// Obtain disparities for this row using Dynamic programming.
        /// First set an image pointing to the column/disp slice.
        CFloatImage img;
        img.setWidth  ( dsi.dispRange_ui );
        img.setHeight ( dsi.width_ui );
        img.setData   ( dsi.getDispColumnSlice(i) );

        if ( !usePrediction_b )
        {
            m_dynProg_p[f_level_i].compute ( img, resVec );
        }
        else
        {
            const float *disp_p = dispImg_p -> getScanline ( i );
            for (int j = 0; j < w_i; ++j)
            {
                /// Disparity -> DSI column index, clamped to [0, maxIdx_i].
                /// The lower bound is index 0, not minDisp_i.
                int idx_i = (int)( disp_p[j] - dsi.minDisp_i );
                idx_i = std::max(idx_i, 0);
                idx_i = std::min(idx_i, maxIdx_i);
                predVec[j] = idx_i;
            }

            m_dynProg_p[f_level_i].compute ( img, resVec, predVec );
        }

        float *  result_p = dispImg_p->getScanline(i);
        for ( int j = 0 ; j < w_i ; ++j )
        {
            /// Add dsi.minDisp_i to the result because the result refers to
            /// the column image position (which is 0 for minDisp_i).
            result_p[j] = resVec[j] + dsi.minDisp_i;
        }
    }

    return true;
}
        
/// Return level 0 (the finest level) of the disparity pyramid, which holds
/// the result of the last compute() call.
/// NOTE(review): the image appears to live inside m_pyrDisp; callers should
/// presumably not delete it — confirm with the pyramid class.
CFloatImage *
CPyramidalDynProgStereo::getDisparityImage () const
{
    const int finestLevel_i = 0;
    return m_pyrDisp.getLevelImage ( finestLevel_i );
}
        
/// Set the image size and number of pyramid levels, then (re)allocate and
/// reconfigure everything that depends on them.
///
/// The requested level count is clamped to m_maxLevels_si. The pyramids may
/// build fewer levels than requested; m_levels_ui always mirrors the number
/// of levels the left pyramid actually has.
///
/// Update rules:
/// - Pyramids are rebuilt when the size or the requested level count changes.
/// - The DSI array is reallocated when the level count changed (or on first
///   call) and reconfigured when only the image size changed.
/// - The DP operator array is allocated once with m_maxLevels_si entries, and
///   every level's parameter block is bound to its operator.
/// - DP cost image sizes are refreshed whenever the DSI configuration was
///   updated, because updateParamDynProgOp() derives them from the DSIs.
///
/// \return Always true.
bool
CPyramidalDynProgStereo::setPyramidParams ( unsigned int f_imageWidth_ui,
                                            unsigned int f_imageHeight_ui,
                                            unsigned int f_levels_ui )
{
    const unsigned int prevLevels_ui = m_levels_ui;

    /// Must be evaluated before the pyramids are updated below.
    const bool imageSizeChanged_b = ( f_imageWidth_ui  != m_pyrLeft.getWidth() ||
                                      f_imageHeight_ui != m_pyrLeft.getHeight() );

    f_levels_ui = std::min(f_levels_ui, m_maxLevels_si);

    if ( imageSizeChanged_b ||
         m_levels_ui != f_levels_ui )
    {
        m_pyrLeft.setPyramidParams  (  f_imageWidth_ui,
                                       f_imageHeight_ui,
                                       f_levels_ui );

        m_pyrRight.setPyramidParams  ( f_imageWidth_ui,
                                       f_imageHeight_ui,
                                       m_pyrLeft.getLevels() );

        m_pyrDisp.setPyramidParams  (  f_imageWidth_ui,
                                       f_imageHeight_ui,
                                       m_pyrLeft.getLevels() );
    }

    /// Always mirror the level count the pyramid actually has. This also
    /// covers the pyramids built in the constructor.
    m_levels_ui = m_pyrLeft.getLevels();

    const bool levelChanged_b = ( m_levels_ui != prevLevels_ui );
    bool       dsiUpdated_b   = false;

    if ( levelChanged_b || !m_dsi_p )
    {
        delete [] m_dsi_p;

        m_dsi_p = new CTestFloatDSI_t[m_levels_ui];
        updateParamDsi();
        dsiUpdated_b = true;
    }
    else if ( imageSizeChanged_b )
    {
        updateParamDsi();
        dsiUpdated_b = true;
    }

    if ( !m_dynProg_p )
    {
        m_dynProg_p  = new CDynamicProgrammingOp[m_maxLevels_si];

        /// Bind every parameter block, not only the currently active levels,
        /// so levels enabled later through setLevels() also receive their
        /// configured parameters.
        for (unsigned int i = 0; i < m_maxLevels_si; ++i)
        {
            m_dpParams_p[i].dp_p = m_dynProg_p + i;
            m_dpParams_p[i].applyParametersToObject();
        }

        /// Newly created operators still need to be sized from the DSIs.
        dsiUpdated_b = true;
    }

    if ( dsiUpdated_b )
    {
        updateParamDynProgOp();
    }
    
    return true;
}

/// Configure the DSI object of every pyramid level.
///
/// Level 0 uses the user disparity range widened by one pixel on each side.
/// Each coarser level halves the range of the previous one (integer
/// division). The range handed to each DSI is additionally extended by half
/// the aggregation kernel width on both sides. The resulting sizes are
/// printed for every level.
void 
CPyramidalDynProgStereo::updateParamDsi()
{
    if ( !m_dsi_p )
        return;

    S2D<int> dispRange = m_dispRange;
    
    dispRange.min--;
    dispRange.max++;    

    /// Kept as int: m_kernelSize holds unsigned values, and
    /// "dispRange.min - m_kernelSize.width/2" would otherwise be evaluated in
    /// unsigned arithmetic and wrap for negative disparities.
    const int halfKernelW_i = (int)( m_kernelSize.width / 2 );

    for (unsigned int i = 0 ; i < m_levels_ui; ++i)
    {
        const CFloatImage *temp_p = m_pyrDisp.getLevelImage ( i );

        const int minDisp_i = dispRange.min - halfKernelW_i;
        const int maxDisp_i = dispRange.max + halfKernelW_i;
        
        m_dsi_p[i].setDSIType( CTestFloatDSI_t::CT_DUV );
        m_dsi_p[i].setImageSizes ( temp_p -> getWidth(), 
                                   temp_p -> getHeight(),
                                   minDisp_i,
                                   maxDisp_i );
        
        printf("For pyramid %i width: %i height: %i minDisp = %i maxDisp = %i\n",
               (int) i,
               (int) temp_p -> getWidth(), 
               (int) temp_p -> getHeight(),
               minDisp_i,
               maxDisp_i );
        
        dispRange.min/=2;
        dispRange.max/=2;            
    }
}


void 
CPyramidalDynProgStereo::updateParamDynProgOp()
{    
    for (unsigned int i = 0 ; i < m_levels_ui; ++i)
    {
        /// The range is obtained from the dsi object since the image cost image size
        /// might differ from the disparity range stored in this class (because of 
        /// mask size).
        m_dynProg_p[i].setCostImageSize( m_dsi_p[i].getDisparitySpaceImage().dispRange_ui,
                                         m_dsi_p[i].getDisparitySpaceImage().width_ui );
    }        
}

/// Set the disparity range [min, max] in pixels.
/// If the range differs from the current one, the DSIs and then the DP cost
/// images (which are derived from the DSIs) are reconfigured.
/// \return Always true.
bool
CPyramidalDynProgStereo::setDisparityRange ( S2D<int> f_dispRange )
{
    if ( m_dispRange == f_dispRange )
        return true;

    printf("Setting disparity range to %i %i\n", f_dispRange.min, f_dispRange.max);

    m_dispRange = f_dispRange;

    updateParamDsi();
    updateParamDynProgOp();

    return true;
}

/// Return the configured disparity range [min, max] in pixels (before the
/// per-level widening done in updateParamDsi()).
S2D<int>
CPyramidalDynProgStereo::getDisparityRange ( ) const
{
    const S2D<int> range = m_dispRange;
    return range;
}

/// Set the aggregation kernel size (width, height).
/// The kernel width enters the DSI disparity extension, so on change the
/// DSIs and then the DP cost images are reconfigured.
/// \return Always true.
bool
CPyramidalDynProgStereo::setKernelSize ( S2D<unsigned int> f_kernelSize )
{
    const bool changed_b = ( m_kernelSize != f_kernelSize );

    if ( changed_b )
    {
        m_kernelSize = f_kernelSize;

        updateParamDsi();
        updateParamDynProgOp();
    }

    return true;
}


/// Return the aggregation kernel size (width, height). A 1x1 kernel selects
/// the plain DSI computation in compute(); larger kernels select ZSSD.
S2D<unsigned int> 
CPyramidalDynProgStereo::getKernelSize ( ) const
{
    const S2D<unsigned int> kernel = m_kernelSize;
    return kernel;
}

/// Build (on first call) and return the parameter set exposing this object's
/// settings:
/// - pyramid levels, disparity range and aggregation kernel size;
/// - for each of the m_maxLevels_si possible levels, the dynamic programming
///   parameters stored in m_dpParams_p.
///
/// NOTE(review): the set is cached in a function-local static pointer. It is
/// therefore created once per process and shared by every instance, and all
/// connectors stay bound to the instance that called this first. A
/// per-instance cache needs a member in the header. The lazy creation is also
/// not synchronized.
///
/// \return The cached parameter set.
CParameterSet *   
CPyramidalDynProgStereo::getParameterSet ( )
{
    static CParameterSet * m_paramSet_p;
    
    if ( !m_paramSet_p )
    {
        m_paramSet_p = new CParameterSet ( NULL );
        m_paramSet_p -> setName ( "Pyramidal Dyn Prog Stereo" );

        m_paramSet_p -> addParameter ( new CParameterGroup ( "Base parameters" ) );

        m_paramSet_p -> addParameter (
                new CUIntParameter ( "Pyramid Levels", 
                                     "Number of pyramid levels", 
                                     m_levels_ui,
                                     new CParameterConnector< CPyramidalDynProgStereo, unsigned int, CUIntParameter>
                                     ( this,
                                       &CPyramidalDynProgStereo::getLevels,
                                       &CPyramidalDynProgStereo::setLevels ) ) );
        
        m_paramSet_p -> addParameter (
                new CInt2DParameter ( "Disparity Range", "Min and max disparities [px].", 
                                      m_dispRange, 
                                      "Min", "Max",
                                      new CParameterConnector< CPyramidalDynProgStereo, S2D<int>, CInt2DParameter>
                                      ( this,
                                        &CPyramidalDynProgStereo::getDisparityRange,
                                        &CPyramidalDynProgStereo::setDisparityRange ) ) );

        m_paramSet_p -> addParameter (
                new CUInt2DParameter ( "Agg. Kernel Size", "Agregation kernel size.", 
                                      m_kernelSize, 
                                      "W", "H",
                                      new CParameterConnector< CPyramidalDynProgStereo, S2D<unsigned int>, CUInt2DParameter>
                                      ( this,
                                        &CPyramidalDynProgStereo::getKernelSize,
                                        &CPyramidalDynProgStereo::setKernelSize ) ) );

        m_paramSet_p -> addParameter ( new CParameterGroupEnd ( ) );
        
        //////////////////////

        m_paramSet_p -> addParameter ( new CParameterGroup ( "Dynamic Programming" ) );

        /// One parameter group per possible level (not only the active ones),
        /// so that levels enabled later are already configurable.
        char str[256];
        for ( unsigned int i = 0; i < m_maxLevels_si; ++i )
        {
            snprintf(str, sizeof(str), "Level %u", i );
            m_paramSet_p -> addParameter ( new CParameterGroup ( str ) );

            snprintf(str, sizeof(str), "L%u Distance Cost", i ); 
            m_paramSet_p -> addParameter (
                    new CFloatParameter ( str,
                                          "Distance Cost for the DP [cost/px].",
                                          m_dpParams_p[i].getDistanceCost(),
                                          new CParameterConnector< CPyramidalDynProgStereo::SDynProgLevelParam, float, CFloatParameter>
                                          ( m_dpParams_p+i,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::getDistanceCost,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::setDistanceCost ) ) );

            snprintf(str, sizeof(str), "L%u Max Distance", i ); 
            m_paramSet_p -> addParameter (
                    new CFloatParameter ( str,
                                          "Max Distance for bounding jump cost [px].",
                                          m_dpParams_p[i].getMaxCostDist(),
                                          new CParameterConnector< CPyramidalDynProgStereo::SDynProgLevelParam, float, CFloatParameter>
                                          ( m_dpParams_p+i,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::getMaxCostDist,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::setMaxCostDist ) ) );

            snprintf(str, sizeof(str), "L%u Prediction Cost", i ); 
            m_paramSet_p -> addParameter (
                    new CFloatParameter ( str,
                                          "Prediction Cost [cost/px].",
                                          m_dpParams_p[i].getPredictionCost(),
                                          new CParameterConnector< CPyramidalDynProgStereo::SDynProgLevelParam, float, CFloatParameter>
                                          ( m_dpParams_p+i,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::getPredictionCost,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::setPredictionCost ) ) );

            snprintf(str, sizeof(str), "L%u Max Prediction Dist", i ); 
            m_paramSet_p -> addParameter (
                    new CFloatParameter ( str,
                                          "Max Distance for bounding prediction cost [px].",
                                          m_dpParams_p[i].getMaxPredictionDist(),
                                          new CParameterConnector< CPyramidalDynProgStereo::SDynProgLevelParam, float, CFloatParameter>
                                          ( m_dpParams_p+i,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::getMaxPredictionDist,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::setMaxPredictionDist ) ) );

            snprintf(str, sizeof(str), "L%u Initial Cost", i ); 
            m_paramSet_p -> addParameter (
                    new CFloatParameter ( str,
                                          "Initial Cost [cost/px].",
                                          m_dpParams_p[i].getInitialCost(),
                                          new CParameterConnector< CPyramidalDynProgStereo::SDynProgLevelParam, float, CFloatParameter>
                                          ( m_dpParams_p+i,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::getInitialCost,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::setInitialCost ) ) );

            snprintf(str, sizeof(str), "L%u Min Cost", i ); 
            m_paramSet_p -> addParameter (
                    new CFloatParameter ( str,
                                          "Min cost to consider. Cells with smaller cost are ignored [cost].",
                                          m_dpParams_p[i].getMinCost(),
                                          new CParameterConnector< CPyramidalDynProgStereo::SDynProgLevelParam, float, CFloatParameter>
                                          ( m_dpParams_p+i,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::getMinCost,
                                            &CPyramidalDynProgStereo::SDynProgLevelParam::setMinCost ) ) );

            snprintf(str, sizeof(str), "L%u Apply Median Filter", i ); 
            m_paramSet_p -> addParameter (
                    new CBoolParameter ( str,
                                         "Apply median filter to the resulting path?",
                                         m_dpParams_p[i].getApplyMedianFilter(),
                                         new CParameterConnector< CPyramidalDynProgStereo::SDynProgLevelParam, bool, CBoolParameter>
                                         ( m_dpParams_p+i,
                                           &CPyramidalDynProgStereo::SDynProgLevelParam::getApplyMedianFilter,
                                           &CPyramidalDynProgStereo::SDynProgLevelParam::setApplyMedianFilter ) ) );

            /// The initial value comes from the tolerance getter, matching the
            /// connector below (it previously used the median filter flag).
            snprintf(str, sizeof(str), "L%u Follow Path Tolerance", i ); 
            m_paramSet_p -> addParameter (
                    new CIntParameter ( str,
                                         "Delta number of pixels to consider around the prediction [px].",
                                         m_dpParams_p[i].getFollowPathTolerance(),
                                         new CParameterConnector< CPyramidalDynProgStereo::SDynProgLevelParam, int, CIntParameter>
                                         ( m_dpParams_p+i,
                                           &CPyramidalDynProgStereo::SDynProgLevelParam::getFollowPathTolerance,
                                           &CPyramidalDynProgStereo::SDynProgLevelParam::setFollowPathTolerance ) ) );

            m_paramSet_p -> addParameter ( new CParameterGroupEnd ( ) );
        }
        
        m_paramSet_p -> addParameter ( new CParameterGroupEnd ( ) );
    }

    return m_paramSet_p;
}

/// Change the number of pyramid levels while keeping the current image size.
/// Delegates to setPyramidParams(), which clamps the value and reallocates
/// the per-level objects as needed.
/// \return The result of setPyramidParams().
bool
CPyramidalDynProgStereo::setLevels ( unsigned int f_levels_ui )
{
    const unsigned int width_ui  = m_pyrLeft.getWidth();
    const unsigned int height_ui = m_pyrLeft.getHeight();

    return setPyramidParams ( width_ui, height_ui, f_levels_ui );
}

unsigned int
CPyramidalDynProgStereo::getLevels ( ) const
{
    return m_levels_ui;
}

