/**
***************************************************************************
* @file dataQueueTest.cpp
* Source file defining tests for dlr::thread::DataQueue.
*
* Copyright (C) 2006 David LaRose, dlr@cs.cmu.edu
* See accompanying file, LICENSE.TXT, for details.
*
* $Revision: 929 $
* $Date: 2007-05-21 18:57:00 -0400 (Mon, 21 May 2007) $
***************************************************************************
**/

#include <algorithm>
#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>

#include <dlrPortability/timeUtilities.h>
#include <dlrRandom/pseudoRandom.h>
#include <dlrTest/testFixture.h>
#include <dlrThread/dataQueue.h>
#include <dlrThread/threadFunctor.h>

using namespace dlr::thread;
using dlr::portability::getCurrentTime;
using dlr::portability::portableSleep;

namespace dlr {


  /**
   ** Test fixture for dlr::thread::DataQueue.  Each test starts producer
   ** and consumer threads that share a single queue, joins them, and then
   ** checks how each thread finished.
   **/
  class DataQueueTest : public TestFixture<DataQueueTest> {
  public:
    DataQueueTest();
    ~DataQueueTest() {}

    // The fixture keeps no per-test state, so these hooks do nothing.
    void setUp(const std::string& /* testName */) {}
    void tearDown(const std::string& /* testName */) {}

    // Individual tests.
    void testNormalFunction0();    // One producer, one consumer.
    void testNormalFunction1();    // One producer, four consumers.
    void testGreedyConsumer();     // Consumer sleeps while holding the lock.
    void testSlowConsumer();       // Consumer sleeps longer than producer.
    void testSynchronize();        // Every client calls synchronize().
    void testSynchronizeTimeout(); // Placeholder; body is currently empty.
    void testSynchronizeQuorum();  // Three clients use synchronizeQuorum().
  }; // class DataQueueTest


  // Outcome of a TestThreadFunctor's thread, as reported by getExitStatus().
  enum ThreadExitStatus {
    DQT_NOSTATUS = 0,          // Initial value; kept if no other status is set.
    DQT_NORMAL = 1,            // testMain() returned without throwing.
    DQT_OVERFLOWEXCEPTION = 2, // An OverflowException was caught.
    DQT_STATEEXCEPTION = 3     // A StateException was caught.
  };
                         
  
  class TestThreadFunctor
    : public ThreadFunctor
  {
  public:
    TestThreadFunctor(const std::string& threadName,
                      const DataQueue< std::pair<size_t, double> >& dataQueue0,
                      size_t numberOfIterations,
                      double minSleep,
                      double maxSleep,
                      int verbosity)
      : ThreadFunctor(threadName),
        m_dataQueue(dataQueue0),
        m_maxSleep(maxSleep),
        m_minSleep(minSleep),
        m_numberOfIterations(numberOfIterations),
        m_pseudoRandom(),
        m_verbosity(verbosity),
        m_exitStatus(DQT_NOSTATUS) {
      // Sleep for a millisecond so that subsequent TestThreadFunctor
      // instances will get a different seed for m_pseudoRandom.
      portableSleep(0.001);
    }

    
    ThreadExitStatus
    getExitStatus() {return m_exitStatus;}

    virtual void
    main() {
      try {
        this->testMain();
        this->m_exitStatus = DQT_NORMAL;
      } catch(const dlr::thread::OverflowException& caughtException) {
        if(m_verbosity) {
          std::cout << "[" << this->getThreadName() << "] "
		    << "Caught an exception:\n"
                    << caughtException.what() << "\n" << std::endl;
        }
        this->m_exitStatus = DQT_OVERFLOWEXCEPTION;
      } catch(const StateException& caughtException) {
        if(m_verbosity) {
          std::cout << "[" << this->getThreadName() << "] "
		    << "Caught an exception:\n"
                    << caughtException.what() << "\n" << std::endl;
        }
        this->m_exitStatus = DQT_STATEEXCEPTION;
      }
    }
    
    virtual void
    testMain() = 0;
    
  protected:
    DataQueue< std::pair<size_t, double> > m_dataQueue;
    double m_maxSleep;
    double m_minSleep;
    size_t m_numberOfIterations;
    PseudoRandom m_pseudoRandom;
    int m_verbosity;
    
    ThreadExitStatus m_exitStatus;
  };


  
  class ProducerThreadFunctor
    : public TestThreadFunctor
  {
  public:
    ProducerThreadFunctor(
      const std::string& threadName,
      const DataQueue< std::pair<size_t, double> >& dataQueue0,
      size_t numberOfIterations,
      double minSleep,
      double maxSleep,
      int verbosity=0)
      : TestThreadFunctor(threadName, dataQueue0, numberOfIterations,
                          minSleep, maxSleep, verbosity) {}
        
    virtual void
    testMain() {
      DataQueue< std::pair<size_t, double> >::ClientID clientID;
      m_dataQueue.registerClient(clientID, m_threadName);
      for(size_t index0 = 0; index0 < m_numberOfIterations; ++index0) {
        m_dataQueue.pushFront(
          clientID, std::make_pair(index0, getCurrentTime()));
        if(m_verbosity >= 2) {
          std::cout << m_threadName << ": pushFront(" << index0 << ")"
                    << std::endl;
        }

        double sleepTime = m_pseudoRandom.uniform(m_minSleep, m_maxSleep);
        if(m_verbosity >= 2) {
          std::cout << m_threadName << ": sleep(" << sleepTime << ")"
                    << std::endl;
        }
        portableSleep(sleepTime);
      }
    }
  };


  /**
   ** Consumer thread.  Registers with the queue and reads
   ** m_numberOfIterations elements from the back of it.  It checks that
   ** the n-th element read carries index n, then pops it.
   **
   ** If sleepWhileLocking is true, the back lock (lockBack()/unlockBack())
   ** is held across the read and the following adaptiveSleep().
   ** testGreedyConsumer constructs it this way and expects an
   ** OverflowException in one of the threads.
   **/
  class ConsumerThreadFunctor
    : public TestThreadFunctor
  {
  public:
    /**
     * Constructor.  All arguments except sleepWhileLocking are forwarded
     * to TestThreadFunctor.
     *
     * @param sleepWhileLocking If true, call lockBack() before copyBack()
     * and unlockBack() only after adaptiveSleep() on every iteration.
     */
    ConsumerThreadFunctor(
      const std::string& threadName,
      const DataQueue< std::pair<size_t, double> >& dataQueue0,
      size_t numberOfIterations,
      double minSleep,
      double maxSleep,
      int verbosity=0,
      bool sleepWhileLocking=false)
      : TestThreadFunctor(threadName, dataQueue0, numberOfIterations,
                          minSleep, maxSleep, verbosity),
        m_sleepWhileLocking(sleepWhileLocking) {}
    
    /**
     * Thread body.
     *
     * @throws StateException if copyBack() returns false (reported as
     * "Producer appears to have died.").
     * @throws LogicException if an element's index does not match the
     * number of elements read so far.
     */
    virtual void
    testMain() {
      DataQueue< std::pair<size_t, double> >::ClientID clientID;
      m_dataQueue.registerClient(clientID, m_threadName);
      for(size_t index0 = 0; index0 < m_numberOfIterations; ++index0) {
        std::pair<size_t, double> elementValue;

        if(m_sleepWhileLocking) {
          // NOTE(review): lockBack()'s return value (if any) is ignored.
          // If it can time out (timeout argument 3 * m_maxSleep), the code
          // below proceeds as if the lock were held; confirm against
          // DataQueue.
          m_dataQueue.lockBack(clientID, 3 * m_maxSleep);
          if(m_verbosity >= 2) {
            std::cout << m_threadName << ": lock()" << std::endl;
          }
        }

        try {
          // Read the element at the back of the queue, passing
          // 30 * m_maxSleep as the timeout.
          if(m_dataQueue.copyBack(clientID, elementValue, 30 * m_maxSleep)
             == false) {
            DLR_THROW(StateException, "ConsumerThreadFunctor::testMain()",
                      "Producer appears to have died.");
          }
          if(m_verbosity >= 2) {
            std::cout << m_threadName << ": copyBack(" << elementValue.first
                      << ")" << std::endl;
          }
          
          // Check that we got the right packet.
          if(elementValue.first != index0) {
            DLR_THROW(LogicException, "ConsumerThreadFunctor::testMain()",
                      "Bad data from m_dataQueue.");
          }
          
          this->adaptiveSleep(elementValue.second);
        } catch(...) {
          // Release the back lock before propagating the exception.
          if(m_sleepWhileLocking) {
            m_dataQueue.unlockBack(clientID);
          }
          throw;
        }

        if(m_sleepWhileLocking) {
          m_dataQueue.unlockBack(clientID);
          if(m_verbosity >= 2) {
            std::cout << m_threadName << ": unlock()" << std::endl;
          }
        }
        
        // Remove the element just examined (after the lock is released).
        m_dataQueue.popBack(clientID);        
      }
    }

    
    /**
     * Sleep a pseudo-random interval chosen so this consumer does not fall
     * too far behind in the queue.
     *
     * The element's age divided by m_minSleep is used as a bound on the
     * number of queued elements.  The remaining room in the queue
     * (getMaximumLength() minus that bound minus one, times m_minSleep,
     * but at least m_minSleep) caps the sleep.  The interval is then drawn
     * with m_pseudoRandom.uniform() between m_minSleep and the smaller of
     * m_maxSleep and that cap.
     *
     * NOTE(review): this divides by m_minSleep, so m_minSleep must be
     * nonzero.  The length bound presumably assumes the producer pushes no
     * more often than once per m_minSleep; confirm for each test.
     *
     * @param elementTimestamp The element's second member.
     * ProducerThreadFunctor stores getCurrentTime() there.
     */
    void
    adaptiveSleep(double elementTimestamp) {
      // Adjust sleep times to make sure we don't fall too far
      // behind in the queue.
      double elementAge = getCurrentTime() - elementTimestamp;
      double queueLengthBound = elementAge / m_minSleep;
      double slackTime =
        (m_dataQueue.getMaximumLength() - queueLengthBound - 1) * m_minSleep;
      if(slackTime < m_minSleep) {
        slackTime = m_minSleep;
      }
      double safeMaxSleepTime = std::min(m_maxSleep, slackTime);
      double safeMinSleepTime = m_minSleep;

      // Sleep.
      double sleepTime =
        m_pseudoRandom.uniform(safeMinSleepTime, safeMaxSleepTime);
      if(m_verbosity >= 2) {
        std::cout << m_threadName << ": sleep(" << sleepTime << ")"
                  << std::endl;
      }
      portableSleep(sleepTime);
    }

  private:
    // If true, hold the back lock while reading and sleeping.
    bool m_sleepWhileLocking;
    
  };



  class SyncProducerThreadFunctor
    : public TestThreadFunctor
  {
  public:
    SyncProducerThreadFunctor(
      const std::string& threadName,
      const DataQueue< std::pair<size_t, double> >& dataQueue0,
      size_t numberOfIterations,
      double minSleep,
      double maxSleep,
      int verbosity=0,
      int numberToSync=-1)
      : TestThreadFunctor(threadName, dataQueue0, numberOfIterations,
                          minSleep, maxSleep, verbosity),
        m_numberToSync(numberToSync) {}
        
    virtual void
    testMain() {
      DataQueue< std::pair<size_t, double> >::ClientID clientID;
      m_dataQueue.registerClient(clientID, m_threadName);
      for(size_t index0 = 0; index0 < m_numberOfIterations; ++index0) {
        double syncIndexDbl =
          m_pseudoRandom.uniform(1.0, m_numberOfIterations);
        size_t syncIndex = static_cast<size_t>(syncIndexDbl);
        if(m_verbosity >= 2) {
          std::cout << m_threadName << ": sync at " << syncIndex << std::endl;
        }
        for(size_t index1 = 0; index1 < syncIndex; ++index1) {
          m_dataQueue.pushFront(
            clientID, std::make_pair(index1, syncIndexDbl));
          if(m_verbosity >= 3) {
            std::cout << m_threadName << ": pushFront(" << index1 << ")"
                      << std::endl;
          }
          portableSleep(m_minSleep);
        }
        if(m_verbosity >= 3) {
          std::cout << m_threadName << ": synchronize()..." << std::endl;
        }
        if(m_numberToSync == -1) {
          if(m_dataQueue.synchronize(clientID, 1.0) == false) {
            std::ostringstream message;
            message << m_threadName << ": couldn't sync.";
            DLR_THROW(StateException, "SyncProducerThreadFunctor::testMain()",
                      message.str().c_str());
          }
        } else {
          if(m_dataQueue.synchronizeQuorum(
               clientID, static_cast<size_t>(m_numberToSync), 1.0)
             == false) {
            std::ostringstream message;
            message << m_threadName << ": couldn't sync.";
            DLR_THROW(StateException, "SyncProducerThreadFunctor::testMain()",
                      message.str().c_str());
          }
        }
        for(size_t index1 = syncIndex; index1 < m_numberOfIterations;
            ++index1) {
          m_dataQueue.pushFront(
            clientID, std::make_pair(index1, syncIndexDbl));
          if(m_verbosity >= 3) {
            std::cout << m_threadName << ": pushFront(" << index1 << ")"
                      << std::endl;
          }
          portableSleep(m_minSleep);
        }
      }
    }

  private:
    int m_numberToSync;
  };


  class SyncConsumerThreadFunctor
    : public TestThreadFunctor
  {
  public:
    SyncConsumerThreadFunctor(
      const std::string& threadName,
      const DataQueue< std::pair<size_t, double> >& dataQueue0,
      size_t numberOfIterations,
      double minSleep,
      double maxSleep,
      int verbosity=0,
      bool doSync=true,
      int numberToSync=-1)
      : TestThreadFunctor(threadName, dataQueue0, numberOfIterations,
                          minSleep, maxSleep, verbosity),
        m_doSync(doSync),
        m_numberToSync(numberToSync) {}

    
    virtual void
    testMain() {
      DataQueue< std::pair<size_t, double> >::ClientID clientID;
      m_dataQueue.registerClient(clientID, m_threadName);

      for(size_t index0 = 0; index0 < m_numberOfIterations; ++index0) {
        std::pair<size_t, double> elementValue;
        if(m_dataQueue.copyBack(clientID, elementValue, 30 * m_maxSleep)
           == false) {
          DLR_THROW(StateException, "SyncConsumerThreadFunctor::testMain()",
                    "Producer appears to have died.");
        }
        m_dataQueue.popBack(clientID);

        double syncIndexDbl = elementValue.second;
        size_t syncIndex = static_cast<size_t>(syncIndexDbl);        
        if(m_doSync) {
          if(m_verbosity >= 2) {
            std::cout << m_threadName << ": synchronize()..." << std::endl;
          }
          if(m_numberToSync == -1) {
            if(m_dataQueue.synchronize(clientID, 1.0) == false) {
              std::ostringstream message;
              message << m_threadName << ": couldn't sync.";
              DLR_THROW(StateException,
                        "SyncConsumerThreadFunctor::testMain()",
                        message.str().c_str());
            }
          } else {
            if(m_dataQueue.synchronizeQuorum(
                 clientID, static_cast<size_t>(m_numberToSync), 1.0)
               == false) {
              std::ostringstream message;
              message << m_threadName << ": couldn't sync.";
              DLR_THROW(StateException,
                        "SyncConsumerThreadFunctor::testMain()",
                        message.str().c_str());
            }
          }
        } else {
          syncIndex = 1;
        }
          
        for(size_t index1 = syncIndex; index1 < m_numberOfIterations;
            ++index1) {
          if(m_dataQueue.copyBack(clientID, elementValue, 30 * m_maxSleep)
             == false) {
            DLR_THROW(StateException, "SyncConsumerThreadFunctor::testMain()",
                      "Producer appears to have died.");
          }
          if(m_verbosity >= 3) {
            std::cout << m_threadName << ": copyBack(" << index1 << ")"
                      << std::endl;
          }
          if(elementValue.first != index1) {
            std::ostringstream message;
            message << "Syncing since " << syncIndex << "; "
                    << "expecting " << index1 << "; "
                    << "got " << elementValue.first << ".  "
                    << "Sync appears not to have worked correctly.";
            DLR_THROW(ValueException, "SyncConsumerThreadFunctor::testMain()",
                      message.str().c_str());
          }
          m_dataQueue.popBack(clientID);
          portableSleep(m_minSleep);
        }
      }
    }

  private:
    bool m_doSync;
    int m_numberToSync;
  };

  

  /* ============== Member Function Definitions ============== */
    
  /**
   * Constructor.  Names the fixture "DataQueueTest" and registers its
   * tests with TestFixture.
   */
  DataQueueTest::
  DataQueueTest()
    : TestFixture<DataQueueTest>("DataQueueTest")
  {
    // Register every implemented test.  Leaving registrations commented
    // out silently removes those tests from the suite.
    DLR_TEST_REGISTER_MEMBER(testNormalFunction0);
    DLR_TEST_REGISTER_MEMBER(testNormalFunction1);
    DLR_TEST_REGISTER_MEMBER(testGreedyConsumer);
    DLR_TEST_REGISTER_MEMBER(testSlowConsumer);
    DLR_TEST_REGISTER_MEMBER(testSynchronize);
    DLR_TEST_REGISTER_MEMBER(testSynchronizeQuorum);

    // testSynchronizeTimeout() is deliberately not registered: its body is
    // currently empty.  Register it once it is implemented.
  }


  /**
   * One producer and one consumer exchange elements through a shared
   * queue.  Both threads are expected to finish normally.
   */
  void
  DataQueueTest::
  testNormalFunction0()
  {
    const size_t numberOfElements = 100;
    const double minSleep = 0.01;
    const double maxSleep = 0.1;
    const int verbosity = 1;

    DataQueue< std::pair<size_t, double> > dataQueue(10, 2);
    ConsumerThreadFunctor consumerThread(
      "Consumer0", dataQueue, numberOfElements, minSleep, maxSleep,
      verbosity);
    ProducerThreadFunctor producerThread(
      "Producer0", dataQueue, numberOfElements, minSleep, maxSleep,
      verbosity);

    // Start the consumer before the producer; join in reverse order.
    consumerThread.run();
    producerThread.run();
    producerThread.join();
    consumerThread.join();

    DLR_TEST_ASSERT(consumerThread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(producerThread.getExitStatus() == DQT_NORMAL);
  }


  /**
   * One producer feeds four consumers that share a queue.  Every thread is
   * expected to finish normally.
   */
  void
  DataQueueTest::
  testNormalFunction1()
  {
    const size_t numberOfElements = 1000;
    const double minSleep = 0.001;
    const double maxSleep = 0.01;
    const int verbosity = 1;

    DataQueue< std::pair<size_t, double> > dataQueue(10, 2);
    ConsumerThreadFunctor consumer0Thread(
      "Consumer0", dataQueue, numberOfElements, minSleep, maxSleep,
      verbosity);
    ConsumerThreadFunctor consumer1Thread(
      "Consumer1", dataQueue, numberOfElements, minSleep, maxSleep,
      verbosity);
    ConsumerThreadFunctor consumer2Thread(
      "Consumer2", dataQueue, numberOfElements, minSleep, maxSleep,
      verbosity);
    ConsumerThreadFunctor consumer3Thread(
      "Consumer3", dataQueue, numberOfElements, minSleep, maxSleep,
      verbosity);
    ProducerThreadFunctor producerThread(
      "Producer0", dataQueue, numberOfElements, minSleep, maxSleep,
      verbosity);

    // Consumers start first; threads are joined in reverse start order.
    consumer0Thread.run();
    consumer1Thread.run();
    consumer2Thread.run();
    consumer3Thread.run();
    producerThread.run();
    producerThread.join();
    consumer3Thread.join();
    consumer2Thread.join();
    consumer1Thread.join();
    consumer0Thread.join();

    DLR_TEST_ASSERT(consumer0Thread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer1Thread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer2Thread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer3Thread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(producerThread.getExitStatus() == DQT_NORMAL);
  }


  /**
   * The consumer holds the back lock while sleeping 0.05 s per element
   * (sleepWhileLocking = true).  Meanwhile the producer pauses only
   * 0.001-0.005 s between pushes.  One of the two threads is expected to
   * end with an OverflowException.
   */
  void
  DataQueueTest::
  testGreedyConsumer()
  {
    const size_t numberOfElements = 100;
    const double consumerSleep = 0.05;
    const double producerMinSleep = 0.001;
    const double producerMaxSleep = 0.005;
    const int verbosity = 0;
    const bool sleepWhileLocking = true;

    DataQueue< std::pair<size_t, double> > dataQueue(10, 2);
    ConsumerThreadFunctor consumer0Thread(
      "Consumer0", dataQueue, numberOfElements, consumerSleep, consumerSleep,
      verbosity, sleepWhileLocking);
    ProducerThreadFunctor producerThread(
      "Producer0", dataQueue, numberOfElements, producerMinSleep,
      producerMaxSleep, verbosity);

    consumer0Thread.run();
    producerThread.run();
    producerThread.join();
    consumer0Thread.join();

    // Accepted outcomes: either the producer finishes and the consumer gets
    // the overflow, or the producer gets the overflow and the consumer then
    // gives up waiting for data (StateException).
    if(producerThread.getExitStatus() != DQT_NORMAL) {
      DLR_TEST_ASSERT(producerThread.getExitStatus() == DQT_OVERFLOWEXCEPTION);
      DLR_TEST_ASSERT(consumer0Thread.getExitStatus() == DQT_STATEEXCEPTION);
    } else {
      DLR_TEST_ASSERT(consumer0Thread.getExitStatus()
                      == DQT_OVERFLOWEXCEPTION);
    }
  }


  /**
   * The consumer sleeps 0.1 s per element (without holding the lock),
   * while the producer pauses only 0.001-0.005 s between pushes.  One of
   * the two threads is expected to end with an OverflowException.
   */
  void
  DataQueueTest::
  testSlowConsumer()
  {
    const size_t numberOfElements = 100;
    const double consumerSleep = 0.1;
    const double producerMinSleep = 0.001;
    const double producerMaxSleep = 0.005;
    const int verbosity = 0;

    DataQueue< std::pair<size_t, double> > dataQueue(10, 2);
    ConsumerThreadFunctor consumer0Thread(
      "Consumer0", dataQueue, numberOfElements, consumerSleep, consumerSleep,
      verbosity);
    ProducerThreadFunctor producerThread(
      "Producer0", dataQueue, numberOfElements, producerMinSleep,
      producerMaxSleep, verbosity);

    consumer0Thread.run();
    producerThread.run();
    producerThread.join();
    consumer0Thread.join();

    // Accepted outcomes: either the producer finishes and the consumer gets
    // the overflow, or the producer gets the overflow and the consumer then
    // gives up waiting for data (StateException).
    if(producerThread.getExitStatus() != DQT_NORMAL) {
      DLR_TEST_ASSERT(producerThread.getExitStatus() == DQT_OVERFLOWEXCEPTION);
      DLR_TEST_ASSERT(consumer0Thread.getExitStatus() == DQT_STATEEXCEPTION);
    } else {
      DLR_TEST_ASSERT(consumer0Thread.getExitStatus()
                      == DQT_OVERFLOWEXCEPTION);
    }
  }


  /**
   * Three consumers and one producer all call synchronize() once per
   * round.  Each consumer checks that reading resumes at the producer's
   * sync point.  Every thread is expected to finish normally.
   */
  void
  DataQueueTest::
  testSynchronize()
  {
    const size_t numberOfIterations = 50;
    const double sleepTime = 0.001;
    const int verbosity = 1;

    DataQueue< std::pair<size_t, double> > dataQueue(50, 2);
    SyncConsumerThreadFunctor consumer0Thread(
      "Consumer0", dataQueue, numberOfIterations, sleepTime, sleepTime,
      verbosity);
    SyncConsumerThreadFunctor consumer1Thread(
      "Consumer1", dataQueue, numberOfIterations, sleepTime, sleepTime,
      verbosity);
    SyncConsumerThreadFunctor consumer2Thread(
      "Consumer2", dataQueue, numberOfIterations, sleepTime, sleepTime,
      verbosity);
    SyncProducerThreadFunctor producerThread(
      "Producer0", dataQueue, numberOfIterations, sleepTime, sleepTime,
      verbosity);

    // Consumers start first; threads are joined in reverse start order.
    consumer0Thread.run();
    consumer1Thread.run();
    consumer2Thread.run();
    producerThread.run();
    producerThread.join();
    consumer2Thread.join();
    consumer1Thread.join();
    consumer0Thread.join();

    DLR_TEST_ASSERT(producerThread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer0Thread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer1Thread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer2Thread.getExitStatus() == DQT_NORMAL);
  }


  /**
   * Placeholder for a test of synchronization timeouts.
   *
   * TODO: not implemented.  The body is intentionally empty, so this test
   * checks nothing.  It is not currently registered in the DataQueueTest
   * constructor.
   */
  void
  DataQueueTest::
  testSynchronizeTimeout()
  {
    // Empty (not yet implemented).
  }


  /**
   * Three clients (the producer and two consumers) call
   * synchronizeQuorum() with a quorum argument of 3.  A third consumer
   * never synchronizes and expects elements to arrive strictly in order.
   * Every thread is expected to finish normally.
   */
  void
  DataQueueTest::
  testSynchronizeQuorum()
  {
    const size_t numberOfIterations = 50;
    const double sleepTime = 0.001;
    const int verbosity = 1;
    const int quorum = 3;
    const bool joinsQuorum = true;
    const bool skipsSync = false;

    DataQueue< std::pair<size_t, double> > dataQueue(50, 2);
    SyncConsumerThreadFunctor consumer0Thread(
      "Consumer0", dataQueue, numberOfIterations, sleepTime, sleepTime,
      verbosity, joinsQuorum, quorum);
    SyncConsumerThreadFunctor consumer1Thread(
      "Consumer1", dataQueue, numberOfIterations, sleepTime, sleepTime,
      verbosity, joinsQuorum, quorum);
    SyncConsumerThreadFunctor consumer2Thread(
      "Consumer2", dataQueue, numberOfIterations, sleepTime, sleepTime,
      verbosity, skipsSync);
    SyncProducerThreadFunctor producerThread(
      "Producer0", dataQueue, numberOfIterations, sleepTime, sleepTime,
      verbosity, quorum);

    // Consumers start first; threads are joined in reverse start order.
    consumer0Thread.run();
    consumer1Thread.run();
    consumer2Thread.run();
    producerThread.run();
    producerThread.join();
    consumer2Thread.join();
    consumer1Thread.join();
    consumer0Thread.join();

    DLR_TEST_ASSERT(producerThread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer0Thread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer1Thread.getExitStatus() == DQT_NORMAL);
    DLR_TEST_ASSERT(consumer2Thread.getExitStatus() == DQT_NORMAL);
  }


  
} // namespace dlr


// Build toggle.  With "#if 0", the standalone main() below is compiled out
// and the "#else" branch is used instead.  That branch defines a
// DataQueueTest instance with static storage duration.  NOTE(review):
// presumably TestFixture's constructor registers that instance with a
// test runner linked in from elsewhere; confirm against dlrTest.  Change
// "#if 0" to "#if 1" to build this file as a standalone executable.
#if 0

// Standalone entry point: returns 0 if currentTest.run() returns true and
// 1 otherwise.
int main(int argc, char** argv)
{
  dlr::DataQueueTest currentTest;
  bool result = currentTest.run();
  return (result ? 0 : 1);
}

#else

namespace {

  // File-local static instance; see the build-toggle note above.
  dlr::DataQueueTest currentTest;

}

#endif
