/**
***************************************************************************
* @file opticalFlowTest.cpp
*
* Source file defining tests for the OpticalFlow class template.
*
* Copyright (C) 2006 David LaRose, dlr@cs.cmu.edu
* See accompanying file, LICENSE.TXT, for details.
*
* $Revision: $
* $Date: $
***************************************************************************
**/

#include <dlrComputerVision/test/testImages.h>
#include <dlrComputerVision/opticalFlow.h>
#include <dlrComputerVision/imageIO.h>
#include <dlrNumeric/subArray2D.h>
#include <dlrTest/testFixture.h>

namespace dlr {

  namespace computerVision {
    
    /**
     ** Test fixture for the OpticalFlow class template.  The
     ** constructor registers each of the test member functions
     ** declared below.
     **/
    class OpticalFlowTest : public TestFixture<OpticalFlowTest> {
    public:

      /// Sets the default tolerances and registers the test members.
      OpticalFlowTest();

      /// Nothing to clean up.
      ~OpticalFlowTest() {}

      // These tests need no per-test preparation or cleanup, so both
      // hooks are empty and ignore the test name they are given.
      void setUp(const std::string& /* testName */) {}
      void tearDown(const std::string& /* testName */) {}

      // Tests.

      /// Checks flow recovery on a 2x2 image pair with explicitly
      /// supplied gradient arrays.
      void testOpticalFlow0();

      /// Checks flow recovery on a synthesized 50x75 ramp image.
      void testOpticalFlow1();

      /// Checks flow recovery on a test image read from a PGM file.
      void testOpticalFlow2();

    private:

      /// Shifts inputImage by each (x, y) offset in [0, 4) x [0, 4)
      /// and verifies that OpticalFlow recovers the offset over the
      /// region bounded by corner0 and corner1.  Columns in
      /// [ignoreBegin, ignoreEnd) are flagged to be ignored.
      void runTest(const Image<GRAY8>& inputImage,
                   const Index2D& corner0,
                   const Index2D& corner1,
                   double tolerance,
                   double sigma = 2.5,
                   size_t ignoreBegin = 0,
                   size_t ignoreEnd = 0);

      /// Tight tolerance, used with the synthetic inputs.
      Float64 m_defaultTolerance0;

      /// Looser tolerance, used with the real-image test.
      Float64 m_defaultTolerance1;

    }; // class OpticalFlowTest


    /* ============== Member Function Definitions ============== */

    // Names the fixture "OpticalFlowTest", initializes the two default
    // tolerances, and registers the three test members in the order
    // listed.
    OpticalFlowTest::OpticalFlowTest()
      : TestFixture<OpticalFlowTest>("OpticalFlowTest"),
        m_defaultTolerance0(1.0E-8),
        m_defaultTolerance1(0.5)
    {
      DLR_TEST_REGISTER_MEMBER(testOpticalFlow0);
      DLR_TEST_REGISTER_MEMBER(testOpticalFlow1);
      DLR_TEST_REGISTER_MEMBER(testOpticalFlow2);
    }


    void
    OpticalFlowTest::
    testOpticalFlow0()
    {
      // Define an arbitrary shift to recover.
      const Index2D shift(-1, -2);

      // Set up input to optical flow algorithm.
      Image<GRAY8> inputImage(Array2D<UnsignedInt8>("[[0, 0], [0, 0]]"));
      Image<GRAY8> shiftedImage(Array2D<UnsignedInt8>("[[3, 4], [2, 1]]"));
      Array2D<Float64> dIdx("[[3.0, 2.0], [0.0, 0.0]]");
      Array2D<Float64> dIdy("[[-3.0, 0.0], [2.0, 1.0]]");

      // Verify that the OpticalFlow class recovers our shift with
      // reasonable accuracy.
      Index2D corner0(0, 0);
      Index2D corner1(2, 2);
      bool valid = false;
      OpticalFlow<GRAY8> opticalFlow(inputImage, shiftedImage, dIdx, dIdy);
      Vector2D recoveredShift = opticalFlow.getFlow(corner0, corner1, valid);

      DLR_TEST_ASSERT(valid == true);
      DLR_TEST_ASSERT(
	approximatelyEqual(static_cast<double>(shift.getColumn()),
                           recoveredShift.x(), m_defaultTolerance0));
      DLR_TEST_ASSERT(
	approximatelyEqual(static_cast<double>(shift.getRow()),
                           recoveredShift.y(), m_defaultTolerance0));
    }


    void
    OpticalFlowTest::
    testOpticalFlow1()
    {
      // Load input image.
      Image<GRAY8> inputImage(50, 75);
      for(int row = 0; row < static_cast<int>(inputImage.rows()); ++row) {
        for(int column = 0;
            column < static_cast<int>(inputImage.columns() / 2);
            ++column) {
          inputImage(row, column) =
            static_cast<UnsignedInt8>(2 * row + column);
        }
        for(int column = static_cast<int>(inputImage.columns() / 2);
            column < static_cast<int>(inputImage.columns());
            ++column) {
          inputImage(row, column) = static_cast<UnsignedInt8>(
            2 * row + column +
            (column -  static_cast<int>(inputImage.columns() / 2)));
        }
      }

      size_t ignoreBegin = inputImage.columns() / 2 - 3;
      size_t ignoreEnd = inputImage.columns() / 2 + 4;
      
      this->runTest(inputImage, Index2D(20, 25), Index2D(30, 50),
                    m_defaultTolerance0, 0.0, ignoreBegin, ignoreEnd);
    }


    // Runs the shift recovery test on a test image read from disk,
    // using the looser default tolerance and the default runTest()
    // smoothing and ignore-region arguments.
    //
    // Throws LogicException if the test image has fewer than 500 rows
    // or fewer than 500 columns.
    void
    OpticalFlowTest::
    testOpticalFlow2()
    {
      // Load input image.
      Image<GRAY8> inputImage = readPGM8(getTestImageFileNamePGM0());

      // The region examined below runs from (100, 100) to (400, 400),
      // so refuse to proceed with an unexpectedly small image.
      if(inputImage.rows() < 500 || inputImage.columns() < 500) {
        DLR_THROW(LogicException, "OpticalFlowTest::testOpticalFlow2()",
                  "Input image is too small.");
      }

      this->runTest(inputImage, Index2D(100, 100), Index2D(400, 400),
                    m_defaultTolerance1);
    }


    // Verifies that OpticalFlow recovers known integer translations of
    // inputImage.
    //
    // For every (shiftX, shiftY) with both components in [0, 4), a
    // translated copy of inputImage is synthesized, the flow between
    // inputImage and the copy is estimated over the region bounded by
    // corner0 and corner1, and the result is compared to the known
    // translation.
    //
    // Arguments:
    //   inputImage: The reference image.
    //   corner0, corner1: Corners of the region, passed straight to
    //     OpticalFlow::getFlow().
    //   tolerance: Maximum permitted difference between each component
    //     of the recovered flow and the corresponding true shift.
    //   sigma: Passed to the OpticalFlow constructor.
    //   ignoreBegin, ignoreEnd: Every pixel in columns
    //     [ignoreBegin, ignoreEnd) is flagged and handed to
    //     OpticalFlow::ignoreArea().  If ignoreBegin >= ignoreEnd, no
    //     pixels are flagged.  ignoreEnd must not exceed the number of
    //     columns in inputImage.
    void
    OpticalFlowTest::
    runTest(const Image<GRAY8>& inputImage,
            const Index2D& corner0,
            const Index2D& corner1,
            double tolerance,
            double sigma,
            size_t ignoreBegin,
            size_t ignoreEnd)
    {
      // Build the mask of pixels to be ignored: whole columns in the
      // range [ignoreBegin, ignoreEnd).
      Array2D<bool> ignoreFlags(inputImage.rows(), inputImage.columns());
      ignoreFlags = false;
      for(size_t column = ignoreBegin; column < ignoreEnd; ++column) {
        for(size_t row = 0; row < inputImage.rows(); ++row) {
          ignoreFlags(row, column) = true;
        }
      }

      for(int shiftX = 0; shiftX < 4; ++shiftX) {
        for(int shiftY = 0; shiftY < 4; ++shiftY) {

          // Synthesize an image which matches our chosen shift.  The
          // source block, with shiftY trailing rows and shiftX
          // trailing columns removed, is copied to a destination
          // block offset by (shiftY, shiftX); the rest stays zero.
          // NOTE(review): this relies on Slice treating a stop value
          // of 0 as "through the end" and a negative stop as counting
          // back from the end -- confirm against dlrNumeric.
          Image<GRAY8> shiftedImage(inputImage.rows(), inputImage.columns());
          shiftedImage = 0;
          typedef dlr::numeric::Slice Slice;
          dlr::numeric::subArray(shiftedImage, Slice(shiftY, 0),
                                 Slice(shiftX, 0))
            = dlr::numeric::subArray(inputImage, Slice(0, -shiftY),
                                     Slice(0, -shiftX));

          bool valid = false;
          OpticalFlow<GRAY8> opticalFlow(inputImage, shiftedImage, sigma);
          opticalFlow.ignoreArea(ignoreFlags);
          Vector2D recoveredShift =
            opticalFlow.getFlow(corner0, corner1, valid);

          // Both components are held to the caller's tolerance, so a
          // tight tolerance constrains y as well as x.
          DLR_TEST_ASSERT(valid == true);
          DLR_TEST_ASSERT(
            approximatelyEqual(static_cast<double>(shiftX),
                               recoveredShift.x(), tolerance));
          DLR_TEST_ASSERT(
            approximatelyEqual(static_cast<double>(shiftY),
                               recoveredShift.y(), tolerance));
        }
      }

    }
    
  } // namespace computerVision

} // namespace dlr


#if 0

int main(int argc, char** argv)
{
  dlr::computerVision::OpticalFlowTest currentTest;
  bool result = currentTest.run();
  return (result ? 0 : 1);
}

#else

namespace {

  // File-local OpticalFlowTest instance, constructed during static
  // initialization; its constructor registers the test members.
  // NOTE(review): presumably the test framework picks up and runs
  // fixtures instantiated this way -- confirm against
  // dlrTest/testFixture.h.
  dlr::computerVision::OpticalFlowTest currentTest;

}

#endif
