#include <math.h>
#include <stdlib.h>

#include <iomanip>
#include <iostream>
#include <vector>

#include "Geometry.h"
#include "OppModelSet.h"
#include "CoachParam.h"
#include "Logger.h"
using namespace spades;
//using namespace std;

// Debug-trace hooks. Both macros expand to nothing, so whatever is passed to
// them is discarded by the preprocessor and never compiled. Because the traced
// expressions are not type-checked while disabled, they may need updating
// before these are redefined (e.g. `#define DEBUG(x) x`) to enable tracing.
#define DEBUG(x)
#define DEBUG2(x)

//DEB: Anywhere you would write (*pBlah).elem you can write pBlah->elem
// instead, which is much easier to read

/*
 * Construct an empty set of opponent models.
 *
 * Room is allocated for CoachParam's configured maximum number of models
 * (getSppOmNumModels()). Every model slot starts out NULL, every probability
 * slot starts out as LogDouble(-1.0), and no model is marked as best
 * (stored_best_OM == -1) until normalize() runs.
 */
OppModelSet::OppModelSet(void) {
  const int capacity = CoachParam::instance()->getSppOmNumModels();

  numModels = 0;
  probs = new LogDouble[capacity];
  model = new OppModel*[capacity];
  // NOTE(review): a plain new[] throws std::bad_alloc rather than returning
  // NULL, so this check only matters with a nothrow/legacy allocator.
  if (model == NULL || probs == NULL)
    errorlog << "OppModelSet did not allocate correctly" << ende;

  for (int slot = 0; slot < capacity; ++slot) {
    model[slot] = NULL;
    probs[slot] = LogDouble(-1.0);
  }
  stored_best_OM = -1;
}

/*
 * Release the probability array, every model stored through addOM(), and the
 * model pointer array. The set owns the OppModel objects it was given.
 */
OppModelSet::~OppModelSet(void) {
  delete [] probs;
  // delete on a NULL pointer is a no-op, so empty slots need no special case.
  for (int slot = 0; slot < numModels; ++slot)
    delete model[slot];
  delete [] model;
}

/*
 * Linear fall-off weight for a distance: 1 at distance 0, decreasing linearly
 * to 0 at CoachParam's SppOmMaxPlayerDist, and clamped to 0 beyond that.
 */
double OppModelSet::weight(const double distance) {
  const double maxDist = CoachParam::instance()->getSppOmMaxPlayerDist();
  const double linear = 1 - distance / maxDist;
  return Max(0.0, linear);
}

/*
 * Overwrite the (unnormalized) probability of model `index` with `newValue`.
 *
 * An index outside [0, numModels), or one whose slot holds no model, is
 * reported to errorlog and otherwise ignored. The distribution is not
 * renormalized here; call normalize() afterwards if needed.
 */
void OppModelSet::setProb(int index, double newValue) {
  actionlog(100) << "Updating probability of model " << index << " to " << newValue << ende;
  // The lower bound matters: a negative index used to pass the old
  // `index >= numModels` test and access memory before the arrays.
  if (index < 0 || index >= numModels)
    errorlog << "model index given out of bound" << ende;
  else if (model[index] == NULL)
    errorlog << "model index does not point to a model" << ende;
  else
    probs[index] = LogDouble(newValue);
}


/*
 * Multiply the (unnormalized) probability of model `index` by `newValue`.
 *
 * An index outside [0, numModels), or one whose slot holds no model, is
 * reported to errorlog and otherwise ignored. The distribution is not
 * renormalized here; call normalize() afterwards if needed.
 */
void OppModelSet::multProb(int index, double newValue) {
  actionlog(100) << "Multiplying  probability of model " << index << " by " << newValue << ende;
  // The lower bound matters: a negative index used to pass the old
  // `index >= numModels` test and access memory before the arrays.
  if (index < 0 || index >= numModels)
    errorlog << "model index given out of bound" << ende;
  else if (model[index] == NULL)
    errorlog << "model index does not point to a model" << ende;
  else
    probs[index] *= newValue;
}

void OppModelSet::normalize(void) {
  actionlog(100) << "Normalizing distribution" << ende;
  int bestOM = -1;
  LogDouble bestOMprob(-HUGE);
  int counter = 0;
  LogDouble total(0.0);
  for(counter = 0; counter < numModels; ++counter) {
    //cout << probs[counter] << "\n";
    total += probs[counter];
    if (probs[counter] > bestOMprob) {
      bestOM = counter;
      bestOMprob = probs[counter];
    }
  }
  if (total.isZero()){
    errorlog << "all probabilities of opp models are zero" << ende;
    bestOM = 0;
    for (int i = 0; i < numModels; ++i) {
      probs[i] = LogDouble(1.0)/LogDouble((double)numModels);
      /*actionlog(120) << "probability of model %d is %f", i,
	(double)(1.0/(double)numModels) << ende;*/
    }
  } else if (!total.isInf()) {
    for(counter = 0; counter < numModels; ++counter) 
      probs[counter] /= total;
  } else {
    errorlog << "sum of probs is not finite, resetting to uniform" << ende;
    bestOM = 0;
    for (int i = 0; i < numModels; ++i) {
      probs[i] = LogDouble(1.0)/LogDouble((double)numModels);
    }
  }

  stored_best_OM = bestOM;
}

void OppModelSet::setToUniform(void) {
  for (int i=0; i<numModels; i++)
    probs[i] = LogDouble(1.0);
  normalize();
}



//DEB: You shouldn't need before and after here; if you really think that
// you do, let me know
/*
 * Update the model probabilities from one movement observation.
 *
 * `before` is initialized from o's player start locations. For every player
 * and every model, predictMovement() fills `after` with the model's predicted
 * end distribution. The probability that distribution assigns to the player's
 * observed end location is multiplied into that model's likelihood.
 *
 * When getSppOmUseWeights() is on, each player's factor is blended toward 1
 * by weight() of the player's distance to the ball's final position, measured
 * relative to the closest player. Players far from the ball therefore matter
 * less.
 *
 * Finally the likelihoods are multiplied into probs, which are normalized,
 * boosted by getSppOmWeightBoost(), and normalized again.
 *
 * o         observation supplying the ball movement and player start/end locations
 * before    scratch distribution; set to the start locations, cleared on return
 * after     scratch distribution for each prediction; cleared after every use
 * pDistOut  if non-NULL, every predicted player distribution is printed here
 */
void OppModelSet::calculate(MovementObservation* o, 
			    PlayerDistribution* before,
			    PlayerDistribution* after,
			    ostream* pDistOut) {
  actionlog(100) << "Calculating probability of an observation" << ende;
  const int teamSize = ServerParam::instance()->getSPTeamSize();

  //we used to pass time into predictMovement, but I don't think we actually need it anymore;
  BallMovement* bm = o->getBallMovement();
  
  VecPosition* playersStart = o->getStartLocs();
  before->setInitial(playersStart);
  actionlog(120) << "Start Positions of players in this observation" << ende;
  for (int k = 0; k < teamSize; ++k) {
    actionlog(140) << "Starting position of player " << k+1 << playersStart[k] <<  ende;
  }
  
  VecPosition finPos = bm->getFinalPosition();
  actionlog(120) <<  "Final Ball Position: " << finPos << ende;

  VecPosition* players = o->getEndLocs();
  for (int i = 0; i < teamSize; ++i) {
    actionlog(140) << "Player " << i+1 << "'s end location: " << players[i] << ende;
  }

  // The scratch arrays below are std::vectors rather than variable-length
  // arrays. VLAs are a compiler extension, not standard C++: some compilers
  // reject them for non-POD types such as LogDouble, and a zero-length one
  // (no models) is undefined.
  double distance = HUGE;
  std::vector<double> distances(teamSize);
  actionlog(120) << "Finding distances to ball and shortest distance to ball" << ende;
  for (int i = 0; i < teamSize; i++) {
    double dist = players[i].getDistanceTo(finPos);
    distances[i] = dist;
    if (dist < distance) distance = dist;
    actionlog(140) << "Distance of player " << i+1 << ": " << distances[i]
		   << ", shortest distance: " << distance << ende;
  }

  // Distances are made relative to the player closest to the ball.
  actionlog(120) << "Subracting shortest distance from all distances" << ende;
  for (int i = 0; i < teamSize; ++i) {
    distances[i] -= distance;
    actionlog(140) << "Updated distance of player " << i+1 << ": " << distances[i] << ende;
  }

  actionlog(140) << "Calculating weights" << ende;
  std::vector<double> weights(teamSize);
  for (int i = 0; i < teamSize; ++i) {
    weights[i] = weight(distances[i]);
    actionlog(140) << "Weight of player " << i+1 << ": " << weights[i] << ende;
  }

  actionlog(120) << "Total time: " << bm->getTotalTime() << ende;

  // Running likelihood of the observation under each model.
  std::vector<LogDouble> probObs(numModels, LogDouble(1.0));

  for (int j = 0; j < teamSize; ++j) {

    for (int i = 0; i < numModels; ++i) {
      OppModel* current = getOM(i);

      //we used to pass time into here, but I don't think we actually need it anymore;
      current->predictMovement(before, after, bm);
      if (pDistOut) {
	*pDistOut << "model " << i
		  << "\tplayer " << j
		  << "\ntime: " << bm->getTotalTime()
		  << endl;
	after->getPlayDist(j+1)->PrintSimplifiedBins(*pDistOut, before->getPlayDist(j+1)->GetMean());
      }
      
      const double w = CoachParam::instance()->getSppOmUseWeights() ? weights[j] : 1.0;
      LogDouble p = LogDouble(after->getValue(j+1, players[j]));
      if (p.isZero()) actionlog(120) << "player " << j+1 << " probability was exactly zero" << ende;
      VecPosition mean = after->getPlayDist(j+1)->GetMean();
      actionlog(140) << "unweighted p for player " << j+1 << ": " << p.toDouble()
		     << " (mean: " << mean << ")" << ende;
      // This form (p*w + (1-w)) keeps the distortion inside the LogDouble class.
      probObs[i] *= p*w + (1-w); 
      if (probObs[i].isZero()) 
	actionlog(120) << "The probability of the observation for model " << i
		       << " is now exactly zero" << ende;

      // Guard against underflow: scale every likelihood up by the same factor
      // unless doing so could overflow one that is already very large. A
      // common factor cancels out in normalize().
      if (probObs[i] <= 1e-16) {
	bool multAll = true;
	for (int x = 0; x < numModels; ++x) {
	  if (probObs[x] > 1e16) {
	    // Report the model that is actually too big (x), not model i.
	    actionlog(120) << "would multiply all probs, but prob model " << x
			   << " too big " << probObs[x].toDouble() << ende;
	    multAll = false;
	    break;
	  }
	}
	
	if (multAll) {
	  actionlog(120) << "multiplying all probs by 1000" << ende;
	  for (int x = 0; x < numModels; ++x) 
	    probObs[x] *= 1000;
	}
      }
      actionlog(140) << "Probability of observation for model " << i << " thus far: "
		     << probObs[i].toDouble() << ende;
      after->clear();
    }
  }

  for (int y = 0; y < numModels; ++y) 
    probs[y] *= probObs[y];
  normalize();
  for (int sp = 0; sp < numModels; ++sp) 
    probs[sp] += CoachParam::instance()->getSppOmWeightBoost();
  normalize();
  before->clear();
  for (int z = 0; z < numModels; z++) 
    actionlog(120) << "the probability of model " << z << "  is "
		   << probs[z].toDouble() << ende;
}


/*
 * Write one human-readable line per model to `out`: its index, its name and
 * its current probability.
 */
void OppModelSet::printProbs(ostream& out) {
  for (int idx = 0; idx < numModels; ++idx) {
    const OppModel* const om = model[idx];
    out << "The probablility of model " << idx << " (" << om->GetName() << ")"
        << " is " << probs[idx] << "\n";
  }
}
/*
 * Write a terse "Model <i> (<name>): <prob>" line per model to `out`.
 * Each line ends with endl, so the stream is flushed after every model.
 */
void OppModelSet::printCompactProbs(ostream& out) {
  int idx = 0;
  while (idx < numModels) {
    out << "Model " << idx << " (" << model[idx]->GetName() << "): "
        << probs[idx] << endl;
    ++idx;
  }
}

/*
 * Write each model's index, name and probability (as a double, 6 significant
 * digits) to the action log at verbosity `level`.
 */
void OppModelSet::logProbs(int level) {
  for (int idx = 0; idx < numModels; ++idx) {
    const double p = probs[idx].toDouble();
    actionlog(level) << "The prob of model " << idx << " (" << model[idx]->GetName() << ") is "
                     << std::setprecision(6) << p << ende;
  }
}


/*
 * Replace every model's probability with the matching entry of `priors`, then
 * normalize(). `priors` must hold at least numModels values; this is not
 * checked here.
 */
void OppModelSet::setPriorProbs(double* priors) {
  const double* src = priors;
  for (int idx = 0; idx < numModels; ++idx, ++src)
    probs[idx] = LogDouble(*src);
  normalize();
}

/*
 * Append `om` with (unnormalized) probability `prob`. The set takes ownership
 * of a stored model: the destructor deletes it.
 *
 * A negative `prob`, or a set already holding getSppOmNumModels() models, is
 * reported to errorlog and nothing is stored.
 * NOTE(review): a rejected `om` is not stored, so this set never deletes it.
 */
void OppModelSet::addOM(OppModel* om, double prob) {
  if (prob < 0) {
    errorlog << "prob given is less than zero" << ende;
    return;
  }

  const int capacity = CoachParam::instance()->getSppOmNumModels();
  if (numModels == capacity) {
    errorlog << "I have too many models already! " << numModels << " " 
	     << capacity << ende;
    return;
  }

  probs[numModels] = LogDouble(prob);
  model[numModels] = om;
  ++numModels;
}

/*
 * Return model `index`. Returns NULL when `index` is outside [0, numModels)
 * or the slot holds no model.
 */
OppModel* OppModelSet::getOM(int index) {
  // Callers (e.g. calculateProbCorrect) pass arbitrary indices and treat NULL
  // as "no such model". Without this check, an index outside the allocated
  // capacity read past the model array.
  if (index < 0 || index >= numModels)
    return NULL;
  return model[index];
}


/*
 * Return the most probable model as recorded by the last normalize(), or NULL
 * if no valid best model is known.
 */
OppModel* OppModelSet::getBestOM()
{
  // stored_best_OM is -1 until normalize() runs. Also reject an index that
  // does not refer to a stored model; previously that dereferenced NULL in
  // the log line below.
  if (stored_best_OM < 0 || stored_best_OM >= numModels
      || model[stored_best_OM] == NULL)
    return NULL;

  actionlog(50) << "Best OM is model " << stored_best_OM << " ("
		<< model[stored_best_OM]->GetName() << ")" << ende;
  return model[stored_best_OM]; 
}

/*
 * Copy every model's probability, converted to a plain double, into
 * storage[0 .. numModels-1]. `storage` must have room for numModels values.
 * Returns the number of values written (numModels).
 */
int OppModelSet::getProbs(double* storage) {
  double* out = storage;
  for (int idx = 0; idx < numModels; ++idx)
    *out++ = probs[idx].toDouble();
  return numModels;
}


/*
 * Estimate how often calculate() ends up choosing `correct_model` as best.
 *
 * The end-position distribution is predicted once from correct_model. Then,
 * for each of getSppOmProbCorrectNumReps() repetitions:
 *   - the probabilities are reset to uniform;
 *   - num_samples fake end positions are generated from that distribution
 *     (o->GenerateFakeEndPos) and fed through calculate();
 *   - a data point of 1.0 is recorded if the best model afterwards is
 *     correct_model, otherwise 0.0.
 *
 * Side effects: leaves this set's probabilities as produced by the last
 * repetition. NOTE(review): GenerateFakeEndPos presumably overwrites o's end
 * locations — confirm.
 *
 * Returns an empty summary (after logging an error) if `correct_model` does
 * not name a model.
 */
MaxMeanSummary OppModelSet:: calculateProbCorrect(int num_samples,
						  int correct_model,
						  MovementObservation* o)
{
  MaxMeanSummary data;
  OppModel* pOM = getOM(correct_model);
  if (pOM == NULL) {
    errorlog << "calculateProbCorrect: NULL opp model" << ende;
    // Previously execution continued and dereferenced pOM below.
    return data;
  }

  PlayerDistribution distInit, distTemp;
  distInit.setInitial(o->getStartLocs());
  PlayerDistribution distFinal;
  pOM->predictMovement(&distInit, &distFinal, o->getBallMovement());

  for (int i= CoachParam::instance()->getSppOmProbCorrectNumReps(); i>0; i--) {
    setToUniform();
    for (int s=0; s<num_samples; s++) {
      o->GenerateFakeEndPos(&distFinal);
      //o->GenerateFakeEndPos(pOM);
      /* The resulting distributions are not needed, so a scratch
	 distribution is passed in */
      calculate(o, &distInit, &distTemp);
      DEBUG2(cout << "model " << correct_model \
	     << ":Reps left " << i << " of " \
	     << CoachParam::instance()->getSppOmProbCorrectNumReps() \
	     << "; sample " << s << " of " << num_samples << endl);
      DEBUG2(printProbs());
    }
    data.addPoint((correct_model == stored_best_OM) ? 1.0 : 0.0);
    DEBUG(cout << "Data point; correct " \
	  << correct_model << "(" << probs[correct_model] << ")\t" \
	  << "best: " << stored_best_OM << "(" << probs[stored_best_OM] << ")" \
	  << endl);
    actionlog(150) << "Added data point " << correct_model << " " << stored_best_OM << ende;
  }

  return data;
}
