/*
 * metropolis_algorithm.cpp
 *
 */

// Standard Include
#include "../cleax.h"

// My Include
#include "metropolis_algorithm.h"

// Includes
#include <algorithm>
#include <assert.h>
#include <cmath>
#include <fstream>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include "../dmo/admixture_model.h"
#include "../dmo/genealogy.h"
#include "../dmo/partition.h"
#include "../utilities/entropy_utilities.h"
#include "../utilities/math_utilities.h"
#include "../utilities/random/randomc.h"
#include "../utilities/random/stocc.h"
#include "../utilities/permutation_utilities.h"
#include "../utilities/set_utilities.h"
#include "coalescence_algorithm.h"

// Bit flags telling parameters_sample() which parts of an admixture_parameters
// object to re-draw. They are combined with bitwise OR by the callers and
// tested with bitwise AND inside parameters_sample().

// Bit 0: perturb event times.
const int ConstantSampleNewTime = 1;
// Bit 1: perturb admixture proportions (alpha).
const int ConstantSampleNewAdmixture = 2;
// Bits 0|1: times and admixture proportions.
const int ConstantSampleNewTimeOrAdmixture = 3;
// Bit 2: perturb theta.
const int ConstantSampleNewTheta = 4;
// Bits 0|1|2: times, admixture proportions and theta.
const int ConstantSampleNewTimeOrAdmixtureOrTheta = 7;
// Bit 3: re-draw the populations (p1/p2/pa) of every event.
const int ConstantSampleNewModel = 8;
// All four bits.
const int ConstantSampleNewAll = 15;


//---------------------------------------------------------------------
// Constructors
// Builds a sampler around pModel.
//
// pModel          model description; it is deleted by the destructor, so the
//                 sampler takes ownership once construction succeeds.
// numIters        number of samples stored after burn-in (size of `samples`).
// numGenealogies  genealogies simulated per call to estimate_lambda().
// burnins         number of burn-in iterations run before samples are stored.
// thetalock       stored in numThetaLocks.
// extern_theta    non-zero means theta is fixed to this value by run() and
//                 parameters_create(); zero means theta is sampled.
// initialTMax     upper bound for the random initial event times and scale of
//                 the time proposal widths.
//
// Throws std::invalid_argument when no model set with an empty mask exists:
// that set is the "key" set (key_idx) whose size determines `penalty`, and
// there is no meaningful fallback without it.
metropolis_algorithm::metropolis_algorithm(admixture_model* pModel, int numIters, int numGenealogies, int burnins, int thetalock, double extern_theta, double initialTMax) :
	samples(std::vector<admixture_output*>(numIters)),
	numIters(numIters),
	numGenealogies(numGenealogies),
	numBurnIns(burnins),
	numThetaLocks(thetalock),
	extern_theta(extern_theta),
	initialTMax(initialTMax),
	pModel(pModel),
	m(pModel->getData()->getPopAssignment().size())
{
	const int numSets = (int)pModel->getData()->getNumberOfModelSets();

	// The key model set is the first one whose mask is empty. Look it up
	// before allocating anything so that a failure cannot leak.
	masked_model_partition_set* pKeySet = NULL;
	for (int i = 0; i < numSets; ++i) {
		masked_model_partition_set* pCandidate = pModel->getData()->getModelSet(i);
		if (pCandidate->getMask().size() == 0) {
			pKeySet = pCandidate;
			key_idx = i;
			break;
		}
	}
	if (pKeySet == NULL) {
		throw std::invalid_argument("metropolis_algorithm: no model set with an empty mask was found");
	}

	// Random number generator seeded from the wall clock.
	pStoc = new StochasticLib1((int)time(0));

	// Cache logkfactorial() of every observed weight, one row per model set;
	// the last entry of each row belongs to the set's null weight.
	printf("."); fflush(stdout);
	for (int i = 0; i < numSets; ++i) {
		printf("."); fflush(stdout);
		std::vector<double> logfact;
		masked_model_partition_set* pPartSet = pModel->getData()->getModelSet(i);
		for (int j = 0; j < (int)pPartSet->size(); ++j) {
			logfact.push_back(logkfactorial(pPartSet->getPartition(j)->getWeight()));
		}
		logfact.push_back(logkfactorial((double)pPartSet->getNullWeight()));
		logkfact.push_back(logfact);
	}

	// Score penalty added in estimate_lambda() when a branch is matched
	// against a model partition.
	penalty = log2((int)pKeySet->size()+2);

	// Score (in bits) given to a partition separating one sample from the
	// other m-1. The negation is applied after the conversion to double so
	// the sign is correct even if m has an unsigned type.
	oneBitEntropy = -(double)m*(plogp((double)1/(double)m)+plogp((double)(m-1)/(double)m))/log(2.0);
	printf("done\n"); fflush(stdout);
	printf("OneBitEntropy: %.4f\n", oneBitEntropy);
}

// Releases everything the sampler holds: the model handed to the constructor,
// every stored sample and the random number generator.
metropolis_algorithm::~metropolis_algorithm() {
	delete pModel;

	std::vector<admixture_output*>::iterator it = samples.begin();
	while (it != samples.end()) {
		delete *it;
		++it;
	}
	samples.clear();

	delete pStoc;
}

//---------------------------------------------------------------------
// Methods

// Converts an snprintf() return value into the number of characters that were
// really stored in a buffer of `capacity` bytes: 0 on error, capacity-1 when
// the output was truncated.
static size_t metropolis_stored_length(int written, size_t capacity) {
	if (written <= 0) {
		return 0;
	}
	if ((size_t)written >= capacity) {
		return capacity - 1;
	}
	return (size_t)written;
}

// Runs the Metropolis chain for numBurnIns + numIters iterations.
//
// pParamsIni  starting parameters, or NULL to build them with
//             parameters_create(). When extern_theta is non-zero its theta is
//             overwritten with extern_theta.
// pOutput     optional stream; when non-NULL every post-burn-in state is
//             written to it as one text line.
//
// Every post-burn-in state is also stored in `samples`.
//
// NOTE(review): the starting parameters are deleted at the end of this
// function even when they were supplied by the caller through pParamsIni;
// confirm that callers expect to give up ownership.
void metropolis_algorithm::run(admixture_parameters* pParamsIni, std::ofstream* pOutput) {

	//--------------------------------------
	// Initializing parameters

	// Proposal widths. With an externally fixed theta the theta proposal
	// width is set to zero.
	admixture_hyperparameters* pHyperParam = hyperparameters_create();
	if (this->extern_theta != 0) { pHyperParam->theta.sigma = 0; }
	if (pParamsIni != NULL && extern_theta != 0) { pParamsIni->setTheta(extern_theta); }

	// Current state of the chain and the proposal (initially a copy of it)
	admixture_parameters* pParamsCur = (pParamsIni == NULL) ? parameters_create(pHyperParam) : pParamsIni;
	admixture_parameters* pParamsNew = new admixture_parameters(*pParamsCur);

	printf("%s\n", pParamsCur->toString().data());

	// Coalescence simulator used to generate genealogies
	printf("Creating coalescence model\n"); fflush(stdout);
	coalescence_algorithm coalescence(pModel->getData()->getPopAssignment(), pModel->getData()->getNumOfPopulation());

	// Long and short step hyperparameters, created when burn-in ends
	admixture_hyperparameters* pHyperParamLong = NULL;
	admixture_hyperparameters* pHyperParamShort = NULL;
	admixture_hyperparameters* pHyperParamCurrent = pHyperParam;

	//--------------------------------------
	// Performs MCMC
	const int burnins = numBurnIns;
	std::vector<std::vector<double> > lambdas_cur(estimate_lambda(&coalescence, pParamsCur));
	for (int i = 0; i < burnins+numIters; ++i) {

		printf("(%d) Theta: %.1f->%.1f ", i, pParamsCur->getTheta(), pParamsNew->getTheta());
		printf("OLD:%s ", pParamsCur->toString().data());
		printf("NEW:%s ", pParamsNew->toString().data());
		fflush(stdout);

		// Score the proposal against the current state
		std::vector<std::vector<double> > lambdas_new(estimate_lambda(&coalescence, pParamsNew));
		double llr = this->likelihood_ratio(lambdas_cur, lambdas_new, pParamsCur, pParamsNew, pHyperParam);

		// Display observed weights next to the proposed and current lambdas.
		// "%.0f" of the largest finite double needs 309 digits plus sign,
		// separator and terminator, so the buffer is sized for that and all
		// writes are bounded.
		char buffer[352];
		std::string snew, scur, sobs;
		for (int r = 0; r < (int)lambdas_new.size(); ++r) {
			int c;
			for (int j = 0; j < (int)lambdas_new[r].size(); ++j) {
				c = snprintf(buffer, sizeof(buffer), "%.0f ", lambdas_new[r][j]);
				snew.append(buffer, metropolis_stored_length(c, sizeof(buffer)));
				c = snprintf(buffer, sizeof(buffer), "%.0f ", lambdas_cur[r][j]);
				scur.append(buffer, metropolis_stored_length(c, sizeof(buffer)));
				// The last column of every row is the null weight
				if (j+1 < (int)lambdas_new[r].size()) {
					c = snprintf(buffer, sizeof(buffer), "%d ", pModel->getData()->getModelSet(r)->getPartition(j)->getWeight());
				} else {
					c = snprintf(buffer, sizeof(buffer), "%d ", pModel->getData()->getModelSet(r)->getNullWeight());
				}
				sobs.append(buffer, metropolis_stored_length(c, sizeof(buffer)));
			}
			snew.append(" ");
			scur.append(" ");
			sobs.append(" ");
		}
		printf("llr:%.1f]\n  OBS:%s\n  NEW:%s\n  OLD:%s\n", llr, sobs.data(), snew.data(), scur.data());

		// Accept or reject. On acceptance the proposal becomes the current
		// state; on rejection the proposal is reset to the current state.
		// Either way both objects hold the same values afterwards.
		admixture_parameters* pParamsFrom = pParamsCur;
		admixture_parameters* pParamsTo = pParamsNew;
		if (llr >= 0 || pStoc->Random() <= exp(llr)) {
			pParamsFrom = pParamsNew;
			pParamsTo = pParamsCur;
			lambdas_cur = lambdas_new;
		}
		pParamsTo->copy(pParamsFrom);


		// Change sigma of the hyperparameter only during burn-in
		if (i < burnins) {
			this->hyperparameters_update(pHyperParam, lambdas_cur[key_idx], (double)i/(double)burnins, 0, key_idx);
		} else {
			if (i == burnins) {
				// First post-burn-in iteration: freeze two variants of the
				// hyperparameters, one with the widest and one with the
				// narrowest time proposal.
				pHyperParamLong = new admixture_hyperparameters(*pHyperParam);
				pHyperParamShort = new admixture_hyperparameters(*pHyperParam);

				for (int j = 0; j < (int)pHyperParamLong->time.size(); ++j) {
					pHyperParamLong->time[j]->sigma = pHyperParamLong->time[j]->sigma_max;
					pHyperParamShort->time[j]->sigma = pHyperParamShort->time[j]->sigma_min;
				}
			}

			// Use the wide proposal when the draw is below 0.2, otherwise the narrow one
			pHyperParamCurrent = (pStoc->Random() < 0.2) ? pHyperParamLong : pHyperParamShort;
			for (int j = 0; j < (int)pHyperParamCurrent->time.size(); ++j) {
				printf("(%.1f %.1f) ", pHyperParamCurrent->time[j]->sigma, pHyperParamCurrent->alpha[j]->sigma);
			}
			printf("Theta: %.1f\n", pHyperParamCurrent->theta.sigma);
			fflush(stdout);
		}

		// Store output
		if (i >= burnins) {
			// Drop a sample left over from a previous run() so it is not leaked
			delete this->samples[i-burnins];
			this->samples[i-burnins] = new admixture_output(pParamsCur);

			if (pOutput != NULL) {

				*pOutput << pParamsCur->getTheta() << " ";
				for (int j = 0; j < (int)pParamsCur->getNumberOfEvent(); ++j) {
					*pOutput << "[ ";
					*pOutput << pParamsCur->getEvent(j)->t << " ";
					*pOutput << pParamsCur->getEvent(j)->alpha << " ";
					*pOutput << pParamsCur->getEvent(j)->p1 << " ";
					*pOutput << pParamsCur->getEvent(j)->p2 << " ";
					*pOutput << pParamsCur->getEvent(j)->pa << " ";
					*pOutput << "] ";
				}
				*pOutput << "( ";
				for (int j = 0; j < (int)lambdas_cur[key_idx].size(); ++j) {
					*pOutput << lambdas_cur[key_idx][j] << " ";
				}
				*pOutput << ")" << std::endl;
				pOutput->flush();
			}
		}

		// Choose the kind of proposal for the next iteration
		double rval = pStoc->Random();
		if (rval < 0.30) {
			printf("(A)"); fflush(stdout);
			this->parameters_sample(pParamsNew, pHyperParamCurrent, ConstantSampleNewModel | ConstantSampleNewAdmixture);
		} else if (rval <= 0.5) {
			printf("(B)"); fflush(stdout);
			this->parameters_sample(pParamsNew, pHyperParamCurrent, ConstantSampleNewTimeOrAdmixtureOrTheta);
		} else {
			printf("(C)"); fflush(stdout);
			this->parameters_sample(pParamsNew, pHyperParamCurrent, ConstantSampleNewAll);
		}
	}


	delete pParamsNew;
	delete pParamsCur;
	delete pHyperParam;
	delete pHyperParamLong;
	delete pHyperParamShort;
}

// Get sample output
const std::vector<admixture_output*>& metropolis_algorithm::getSamples() {
	return this->samples;
}


//---------------------------------------------------------------------
// Helpers

// Simulates numGenealogies genealogies under pParams and, for every model
// set, accumulates the branch time assigned to each model partition.
//
// Returns one row per model set. Row k has getModelSet(k)->size()+1 entries:
// one per partition plus a trailing "null" entry for branches that matched
// no partition. Each entry is the mean accumulated branch time per genealogy
// multiplied by pParams->getTheta().
const std::vector<std::vector<double> > metropolis_algorithm::estimate_lambda(coalescence_algorithm* pCoalescence, admixture_parameters* pParams) {

	const int numSets = (int)this->pModel->getData()->getNumberOfModelSets();

	std::vector<std::vector<double> > weights_total(numSets);
	std::vector<std::vector<double> > lambdas(numSets);

	// Per-set data that does not depend on the genealogy: a copy of the mask
	// (taken once here rather than once per genealogy) and the skip map.
	std::vector<std::vector<bool> > masks;
	std::vector<std::vector<bool> > skip_map;

	for (int k = 0; k < numSets; ++k) {
		masked_model_partition_set* pSet = this->pModel->getData()->getModelSet(k);
		const int numPartitions = (int)pSet->size();

		weights_total[k] = std::vector<double>(numPartitions+1, 0.0);
		lambdas[k] = std::vector<double>(numPartitions+1, 0.0);

		std::vector<bool> mask = pSet->getMask();
		masks.push_back(mask);

		skip_map.push_back(std::vector<bool>(numPartitions, false));

		// Skip partitions that are not informative under the mask
		for (int j = 0; j < numPartitions; ++j) {
			if (!is_informative(pSet->getPartition(j), pSet->getMask())) {
				skip_map[k][j] = true;
			}
		}

		// Of several partitions that are identical or mirror images under
		// the mask, keep only the first one
		for (int j = 0; j < numPartitions; ++j) {
			if (skip_map[k][j]) {
				continue;
			}
			for (int i = j+1; i < numPartitions; ++i) {
				if (!skip_map[k][i] && is_identical_or_mirror(pSet->getPartition(j), pSet->getPartition(i), pSet->getMask())) {
					skip_map[k][i] = true;
				}
			}
		}
	}


	for (int r = 0; r < numGenealogies; ++r) {
		genealogy* pGenealogy = pCoalescence->generate_genalogy_fixed(pParams);

		for (int k = 0; k < numSets; ++k) {
			masked_model_partition_set* pSet = this->pModel->getData()->getModelSet(k);
			std::vector<bool>& mask = masks[k];
			const int nullIdx = (int)pSet->size();

			// Assumes that the last branch is a root that we do not care about
			for (int i = 0; i < (pGenealogy->getBranchSize()-1); ++i) {
				genealogy_branch* pBranch = pGenealogy->getBranch(i);

				// Index of the slot this branch contributes to; -1 drops it
				int bestModelIdx = -1;
				if (pBranch->getChild1() == NULL && pBranch->getChild2() == NULL) {
					// Branch without children: a one-bit partition. It goes
					// to the null slot if its score passes the threshold.
					// NOTE(review): infoThres is not assigned anywhere in
					// this file; confirm where it is initialised.
					if (infoThres < oneBitEntropy) {
						bestModelIdx = nullIdx;
					}
				} else {
					// Start from the null slot scored by the branch entropy
					// and switch to any partition with a strictly lower
					// penalised conditional-entropy score
					double bestScore = m*compute_entropy(pBranch->getPartition(), mask);
					bestModelIdx = nullIdx;

					for (int j = 0; j < nullIdx; ++j) {
						if (skip_map[k][j]) {
							continue;
						}
						double score = m*compute_conditional_entropy(pBranch->getPartition(), pSet->getPartition(j), mask) + this->penalty;
						if (score < bestScore) {
							bestScore = score;
							bestModelIdx = j;
						}
					}

					// Unmatched branches scoring at or below infoThres are dropped
					if (infoThres >= bestScore && bestModelIdx == nullIdx) {
						bestModelIdx = -1;
					}
				}

				if (bestModelIdx >= 0) {
					double t = pBranch->getEndTime() - pBranch->getStartTime();
					if (t <= 0) {
						printf("%d %d (%f-%f) has negative or zero time\n", i, pBranch->getUID(), pBranch->getEndTime(), pBranch->getStartTime());
						fflush(stdout);
					}
					assert(t > 0);
					weights_total[k][bestModelIdx] += t;
				}
			}
		}

		delete pGenealogy;
	}


	// Average over genealogies and scale by theta
	for (int k = 0; k < numSets; ++k) {
		for (int i = 0; i < (int)weights_total[k].size(); ++i) {
			lambdas[k][i] = (weights_total[k][i] / (double)numGenealogies) * pParams->getTheta();
		}
	}
	return lambdas;
}


// Sum of squared differences between the observed weights of every model set
// (each partition weight followed by the null weight) and the matching
// entries of `lambdas`. pParam is not used.
double metropolis_algorithm::squared_error(const std::vector<std::vector<double> >& lambdas, admixture_parameters* pParam) {
	double total = 0.0;

	for (int s = 0; s < (this->pModel->getData()->getNumberOfModelSets()); ++s) {
		masked_model_partition_set* pSet = this->pModel->getData()->getModelSet(s);
		const int count = pSet->size();

		// Index `count` is the trailing null-weight slot
		for (int idx = 0; idx <= count; ++idx) {
			const double observed = (idx < count)
				? (double)pSet->getPartition(idx)->getWeight()
				: (double)pSet->getNullWeight();
			const double diff = observed - lambdas[s][idx];
			total += diff*diff;
		}
	}

	return total;
}

// One term of the log ratio: observed*(log(new) - log(old)) - (new - old),
// which has the form of a Poisson log-likelihood ratio. A lambda of exactly
// zero is replaced by 0.01, and the logarithmic part is skipped when the
// observed weight is zero.
static double metropolis_log_ratio_term(double observed, double lambda_old, double lambda_new) {
	const double wt_old = (lambda_old == 0) ? 0.01 : lambda_old;
	const double wt_new = (lambda_new == 0) ? 0.01 : lambda_new;
	const double log_part = (observed == 0.0) ? 0.0 : (log(wt_new) - log(wt_old))*observed;
	return log_part - (wt_new - wt_old);
}

// Log acceptance ratio of the proposal (pNewParam, lambdas_new) against the
// current state (pOldParam, lambdas_old).
//
// The result is the sum of the per-weight terms over all model sets, where
// the null-weight term of every set is counted twice, plus a transition
// correction for events that gained or lost their second parent population
// (p2 == -1 means "none"). Intermediate values are printed to stdout.
double metropolis_algorithm::likelihood_ratio(const std::vector<std::vector<double> >& lambdas_old, const std::vector<std::vector<double> >& lambdas_new, admixture_parameters* pOldParam, admixture_parameters* pNewParam, admixture_hyperparameters* pHyperParams) {

	double total = 0.0;

	for (int k = 0; k < (this->pModel->getData()->getNumberOfModelSets()); ++k) {
		masked_model_partition_set* pSet = this->pModel->getData()->getModelSet(k);
		const int numPartitions = pSet->size();

		printf("[");
		// Index numPartitions is the trailing null-weight slot
		for (int idx = 0; idx <= numPartitions; ++idx) {
			const bool isNull = (idx == numPartitions);
			const double observed = isNull
				? (double)pSet->getNullWeight()
				: (double)pSet->getPartition(idx)->getWeight();
			const double term = metropolis_log_ratio_term(observed, lambdas_old[k][idx], lambdas_new[k][idx]);
			printf("%.2f ", term);
			if (isNull) {
				total += 2*term;
			} else {
				total += term;
			}
		}
		printf("]");
	}


	// Transition correction, event by event
	const double scale = 100.0;
	double transition = 0.0;
	for (int i = 0; i < (int)pOldParam->getNumberOfEvent(); ++i) {
		const bool oldAdmixed = (pOldParam->getEvent(i)->p2 != -1);
		const bool newAdmixed = (pNewParam->getEvent(i)->p2 != -1);
		if (oldAdmixed == newAdmixed) {
			continue;
		}

		// Number of events that follow event i
		const double remaining = pOldParam->getNumberOfEvent()-i-1;
		if (newAdmixed) {
			// The event gained a second parent population
			const double alpha = pNewParam->getEvent(i)->alpha;
			transition -= remaining*log_normal_pdf(alpha*scale, (alpha < 0.5) ? 0.0 : scale, pHyperParams->alpha[i]->sigma*scale);
		} else {
			// The event lost its second parent population
			const double alpha = pOldParam->getEvent(i)->alpha;
			transition += remaining*log_normal_pdf((alpha < 0.5) ? 0.0 : scale, alpha*scale, pHyperParams->alpha[i]->sigma*scale);
		}
	}
	printf("(Q:%.2f)", transition);
	total += transition;
	printf("(L:%4f)", total);

	return total;
}

// Perturbs pParams in place.
//
// pHyperParams  proposal widths (time[i]->sigma, alpha[i]->sigma, theta.sigma).
// commands      bitwise OR of the ConstantSampleNew* flags selecting what to
//               re-draw.
// events        index of the single event to perturb; -1 picks one event
//               uniformly at random; any other negative value perturbs all
//               events.
// direction     sign of the time and theta steps: > 0 forces an increase,
//               < 0 a decrease, 0 leaves the drawn sign untouched.
void metropolis_algorithm::parameters_sample(admixture_parameters* pParams, admixture_hyperparameters* pHyperParams, int commands, int events, int direction) {

	if (commands & ConstantSampleNewModel) {
		parameters_events_sample(pParams);
	}

	if (commands & ConstantSampleNewTimeOrAdmixture) {
		const int numEvents = (int)pParams->getNumberOfEvent();

		// IRandom() includes its upper bound, so the last valid index has to
		// be passed; passing numEvents could select an event that does not
		// exist and silently perturb nothing.
		if (events == -1 && numEvents > 0) {
			events = pStoc->IRandom(0, numEvents-1);
		}

		printf(" {");
		for (int i = 0; i < numEvents; ++i) {
			const bool selected = (events < 0 || events == i);

			// Update time
			if ((commands & ConstantSampleNewTime) && selected) {
				double toffset = pStoc->Normal(0.0, pHyperParams->time[i]->sigma);
				printf("t%d:%.2f,%.2f ", i, toffset, pHyperParams->time[i]->sigma);

				// fabs() rather than abs(): unqualified abs() may resolve to
				// the integer overload and truncate the offset
				double step = toffset;
				if (direction > 0) {
					step = fabs(toffset);
				} else if (direction < 0) {
					step = -fabs(toffset);
				}
				double t = pParams->getEvent(i)->t + step;

				// Accept the new time only if the events stay ordered and
				// the first event stays strictly positive
				const bool lowerOk = (i > 0) ? (t >= pParams->getEvent(i-1)->t) : (t > 0);
				const bool upperOk = (i+1 < numEvents) ? (t <= pParams->getEvent(i+1)->t) : true;
				if (lowerOk && upperOk) {
					pParams->getEvent(i)->t = t;
				}
			}

			// Update alpha
			if ((commands & ConstantSampleNewAdmixture) && selected) {
				if (pParams->getEvent(i)->p2 >= 0) {
					// There is admixture in this event
					double alpha = pStoc->Normal(pParams->getEvent(i)->alpha, pHyperParams->alpha[i]->sigma);

					// Discretise alpha to multiples of 0.01
					alpha = alpha - fmod(alpha, 0.01);

					// Wrap values outside [0, 1] back into the interval
					if (alpha < 0.0 || alpha > 1.0) {
						double rem = fmod(alpha, 1.0);
						if (rem < 0) {
							rem += 1.0;
						}
						alpha = rem;
					}
					pParams->getEvent(i)->alpha = alpha;
				} else {
					// No second parent population
					pParams->getEvent(i)->alpha = 1.0;
				}
			}
		}
		printf("} ");
	}

	if (commands & ConstantSampleNewTheta) {
		// Update theta; negative results are discarded
		double sigma = pHyperParams->theta.sigma;
		double offset = pStoc->Normal(0, sigma);
		double theta;
		if (direction > 0) {
			theta = pParams->getTheta() + fabs(offset);
		} else if (direction < 0) {
			theta = pParams->getTheta() - fabs(offset);
		} else {
			theta = pParams->getTheta() + offset;
		}
		if (theta >= 0) {
			pParams->setTheta(theta);
		}
	}
}

// Sign of a value as an int: +1, -1, or 0 (also for NaN).
static int metropolis_sign_of(double value) {
	if (value > 0) {
		return 1;
	}
	if (value < 0) {
		return -1;
	}
	return 0;
}

// Creates the starting parameters of the chain. The caller owns the result.
//
// Events get randomly sampled populations, sorted random times in
// [0, initialTMax) and an admixture proportion. Theta is extern_theta when
// that is non-zero, the midpoint of the theta bounds in pHyperParams
// otherwise, or 1.0 when pHyperParams is NULL.
//
// When theta is not fixed externally, 1000 Metropolis steps are then run to
// move the starting point towards the observed weights of the key model set.
// NOTE(review): that search dereferences pHyperParams, so NULL is only
// usable together with a non-zero extern_theta.
admixture_parameters* metropolis_algorithm::parameters_create(admixture_hyperparameters* pHyperParams) {

	const int numEvents = (int)pModel->getNumberOfEvents();

	admixture_parameters* pParams = new admixture_parameters(numEvents);
	parameters_events_sample(pParams);

	// Draw one random time per event, keeping `times` sorted ascending by
	// bubbling each new value into place
	std::vector<double> times(numEvents, 0.0);
	for (int i = 0; i < numEvents; ++i) {
		double t = pStoc->Random()*this->initialTMax;
		printf("t%d=%.3f ", i, t);
		for (int j = 0; j < i; ++j) {
			if (times[j] > t) {
				double tmp = times[j];
				times[j] = t;
				t = tmp;
			}
		}
		times[i] = t;
		for (int j = 0; j < (int)times.size(); ++j) {
			printf("%.3f ", times[j]);
		}
		printf("\n");

	}

	printf("\n");

	// alpha: 0 without a first parent, 1 without a second parent, random otherwise
	for (int i = 0; i < numEvents; ++i) {
		double alpha = 0.0;
		if (pParams->getEvent(i)->p1 < 0) {
			alpha = 0;
		} else if (pParams->getEvent(i)->p2 < 0)  {
			alpha = 1;
		} else {
			alpha = pStoc->Random();
		}
		pParams->getEvent(i)->alpha = alpha;
		pParams->getEvent(i)->t = times[i];
	}

	if (pHyperParams == NULL) {
		pParams->setTheta(1.0);
	} else {
		if (this->extern_theta != 0) {
			pParams->setTheta(this->extern_theta);
		} else {
			pParams->setTheta((pHyperParams->theta.sigma_min + pHyperParams->theta.sigma_max)/2.0);
		}
	}

	// Now if we do not have fixed theta, we will try to optimize it first
	if (this->extern_theta == 0) {

		printf("Creating starting point for the MCMC\n");
		coalescence_algorithm coalescence(pModel->getData()->getPopAssignment(), pModel->getData()->getNumOfPopulation());

		// pParamsOld is the state that is returned; pParamsNew is the proposal
		admixture_parameters* pParamsOld = pParams;
		admixture_parameters* pParamsNew = new admixture_parameters(*pParamsOld);

		std::vector<std::vector<double> > lambdas_cur = estimate_lambda(&coalescence, pParams);
		double total_lambdas = 0.0;
		for (int i = 0; i < (int)lambdas_cur[this->key_idx].size(); ++i) {
			total_lambdas += lambdas_cur[this->key_idx][i];
		}

		// Total observed weight of the key model set
		masked_model_partition_set* pKeySet = this->pModel->getData()->getModelSet(this->key_idx);
		double total_weight = pKeySet->getNullWeight();
		for (int i = 0; i < (int)pKeySet->size(); ++i) {
			total_weight += pKeySet->getPartition(i)->getWeight();
		}

		printf("%s\n", pParams->toString().data());

		// Cumulative thresholds selecting the move type. pvec[0] grows with
		// the relative gap between expected and observed totals.
		std::vector<double> pvec(5, 0.0);
		double prob = 0.4;
		double multiplier = std::min(1.0, fabs((double)(total_lambdas - total_weight)/(total_weight)));
		pvec[0] = 0.05+multiplier*prob;
		pvec[1] = 0.5;
		pvec[2] = 0.6;
		pvec[3] = 0.67;
		pvec[4] = 0.90;

		for (int iter = 0; iter < 1000; ++iter) {
			// parameters_sample() takes the direction as an int, so only the
			// sign of the gap is passed; passing the double itself would
			// truncate small gaps to 0 ("no direction").
			const int gapDirection = metropolis_sign_of(total_weight - total_lambdas);

			double rval = pStoc->Random();
			if (rval < pvec[0]) {
				// Sample new theta or sample new time
				if (pStoc->Random() < 0.25) {
					printf("1a: ");
					this->parameters_sample(pParamsNew, pHyperParams, ConstantSampleNewTime, -2, gapDirection);
				} else {
					printf("1b: ");
					this->parameters_sample(pParamsNew, pHyperParams, ConstantSampleNewTheta, -2, gapDirection);
				}
			} else if (rval < pvec[1]) {
				// Move theta one way and times/admixture the opposite way
				printf("2: ");
				const int direction = (pStoc->Random() < 0.5) ? -1 : 1;
				this->parameters_sample(pParamsNew, pHyperParams, ConstantSampleNewTheta, -2, direction);
				this->parameters_sample(pParamsNew, pHyperParams, ConstantSampleNewTimeOrAdmixture, -2, -direction);
			} else if (rval < pvec[2]) {
				// Sample new time and admixture only, on one of the events
				printf("3: ");
				int index = pStoc->IRandom(0, numEvents-1);
				this->parameters_sample(pParamsNew, pHyperParams, ConstantSampleNewTimeOrAdmixture, index, gapDirection);
			} else if (rval < pvec[3]) {
				// Sample new time and admixture on all events
				printf("4: ");
				this->parameters_sample(pParamsNew, pHyperParams, ConstantSampleNewTimeOrAdmixture, -2, gapDirection);
			} else if (rval < pvec[4]) {
				// Sample new admixture model, but keep theta constant
				printf("5: ");
				this->parameters_sample(pParamsNew, pHyperParams, ConstantSampleNewModel | ConstantSampleNewAdmixture);
			} else {
				// Sample new time and new model, but keep theta constant
				printf("6: ");
				this->parameters_sample(pParamsNew, pHyperParams, ConstantSampleNewModel | ConstantSampleNewTimeOrAdmixture);
			}

			// Total of the PROPOSED lambdas (the proposal, not the current state)
			std::vector<std::vector<double> > lambdas_new = estimate_lambda(&coalescence, pParamsNew);
			double total_lambdas_new = 0.0;
			for (int i = 0; i < (int)lambdas_new[this->key_idx].size(); ++i) {
				total_lambdas_new += lambdas_new[this->key_idx][i];
			}

			double llr = likelihood_ratio(lambdas_cur, lambdas_new, pParamsOld, pParamsNew, pHyperParams);
			printf("Model: %s, llr: %.1f, lambda: %.1f, observed:%.1f (pvec: [%.2f %.2f %.2f])\n", pParamsNew->toString().data(), llr, total_lambdas_new, total_weight, pvec[0], pvec[1], pvec[2]);

			// Accept or reject; afterwards both objects hold the same values
			admixture_parameters* pParamsFrom = pParamsOld;
			admixture_parameters* pParamsTo = pParamsNew;
			if (llr >= 0 || pStoc->Random() <= exp(llr)) {
				pParamsFrom = pParamsNew;
				pParamsTo = pParamsOld;
				lambdas_cur = lambdas_new;
				total_lambdas = total_lambdas_new;

				// update probability vector
				multiplier = std::min(1.0, fabs((total_lambdas - total_weight)/(total_weight)));
				pvec[0] = 0.05+multiplier*prob;
			}
			pParamsTo->copy(pParamsFrom);
		}

		// The proposal object was only scratch space
		delete pParamsNew;
		pParams = pParamsOld;

		for (int i = 0; i < (int)lambdas_cur[this->key_idx].size(); ++i) {
			printf("%.0f ", lambdas_cur[this->key_idx][i]);
		}
		printf("\n");
	}

	return pParams;
}

// Draws the populations (p1, p2, pa) of every event in pParams.
//
// Events are visited in index order. For each one the model's candidate
// lists are intersected with the populations still available (p2 may also be
// -1); an empty intersection falls back to all available populations. After
// an event is drawn, its `pa` population is removed from the available pool.
void metropolis_algorithm::parameters_events_sample(admixture_parameters* pParams /* OUTPUT */) {

	// Populations that may still take part in an event
	const int numPopulations = (int)pModel->getData()->getNumOfPopulation();
	std::vector<int> available;
	for (int p = 0; p < numPopulations; ++p) {
		available.push_back(p);
	}

	const int numEvents = (int)pModel->getNumberOfEvents();
	for (int n = 0; n < numEvents; ++n) {
		std::vector<int> p1_cands(set_intersection(pModel->getModelEvent(n)->p1_candidates, available));
		std::vector<int> p2_cands(set_intersection(pModel->getModelEvent(n)->p2_candidates, set_append(available, -1)));
		std::vector<int> pa_cands(set_intersection(pModel->getModelEvent(n)->pa_candidates, available));

		if (p1_cands.empty()) { p1_cands = available; }
		if (p2_cands.empty()) { p2_cands = available; }
		if (pa_cands.empty()) { pa_cands = available; }

		// The sample is ordered (p1, pa, p2), matching the argument order
		std::vector<int> sample = uniform_sample_without_repeats_3set(p1_cands, pa_cands, p2_cands);
		pParams->getEvent(n)->p1 = sample[0];
		pParams->getEvent(n)->pa = sample[1];
		pParams->getEvent(n)->p2 = sample[2];

		// Remove the first occurrence of pa from the pool
		std::vector<int>::iterator it = available.begin();
		while (it != available.end()) {
			if (*it == pParams->getEvent(n)->pa) {
				available.erase(it);
				break;
			}
			++it;
		}
	}
}

// Adapts the proposal widths in pHyperParams to the current fit.
//
// lambdas        expected weights of model set key_model_idx (the last entry
//                is the null weight).
// iter_frac      interpolation factor for the admixture proposal width
//                between sigma_min (0) and sigma_max (1).
// theta_std      new theta proposal width; ignored unless positive.
// key_model_idx  index of the model set `lambdas` belongs to.
//
// The time proposal width of every event is sigma_max scaled by
// min(1, total absolute deviation / error_width), but never below sigma_min.
void metropolis_algorithm::hyperparameters_update(admixture_hyperparameters* pHyperParams, const std::vector<double>& lambdas, double iter_frac, double theta_std, int key_model_idx) {

	masked_model_partition_set* pSet = this->pModel->getData()->getModelSet(key_model_idx);

	// Total absolute deviation between expected and observed weights.
	// fabs() is used throughout: unqualified abs() may resolve to the integer
	// overload and truncate the difference.
	double dev = 0.0;
	for (int i = 0; i < (int)pSet->size(); ++i) {
		dev += fabs(lambdas[i] - (double)pSet->getPartition(i)->getWeight());
	}
	dev += fabs(lambdas[lambdas.size()-1] - (double)pSet->getNullWeight());

	double multiplier = dev / pHyperParams->error_width;
	if (multiplier > 1.0) {
		multiplier = 1.0;
	}

	// Time and Alpha
	for (int i = 0; i < this->pModel->getNumberOfEvents(); ++i) {
		double sigma = multiplier*pHyperParams->time[i]->sigma_max;
		pHyperParams->time[i]->sigma = (sigma < pHyperParams->time[i]->sigma_min) ? pHyperParams->time[i]->sigma_min : sigma;
		pHyperParams->alpha[i]->sigma = pHyperParams->alpha[i]->sigma_min + (pHyperParams->alpha[i]->sigma_max-pHyperParams->alpha[i]->sigma_min)*iter_frac;
	}

	// Theta
	if (theta_std > 0) {
		pHyperParams->theta.sigma = theta_std;
	}
}


// Creates the initial hyperparameters (proposal widths) of the chain. The
// caller owns the result.
//
// Two reference scenarios are simulated with theta = 1: one with every event
// at time 0 and one with every event at a time t derived from the number of
// samples. The ratio of the total observed weight of the key model set to the
// total expected branch length of each scenario gives the two theta bounds
// stored in theta.sigma_min and theta.sigma_max.
//
// NOTE(review): a scenario whose branch lengths sum to zero makes the
// corresponding bound a division by zero; confirm this cannot happen.
admixture_hyperparameters* metropolis_algorithm::hyperparameters_create() {

	// create a coalescence algorithm
	coalescence_algorithm coalescence(pModel->getData()->getPopAssignment(), pModel->getData()->getNumOfPopulation());

	// Scenario 1: every event at time 0
	admixture_parameters parameters1;
	parameters1.setTheta(1.0);
	for (int i = 0; i < this->pModel->getNumberOfEvents(); ++i) {
		parameters1.addEvent(event(0, 0.5, 0, 0, 0));
	}
	parameters_events_sample(&parameters1);
	for (int i = 0; i < this->pModel->getNumberOfEvents(); ++i) {
		// No second parent population means no admixture
		if (parameters1.getEvent(i)->p2 < 0) {
			parameters1.getEvent(i)->alpha = 1;
		}
	}
	std::vector<double> branch_length1(estimate_lambda(&coalescence, &parameters1)[key_idx]);
	printf("[ ");
	for (int i = 0; i < (int)branch_length1.size(); ++i) { printf("%.3f ", branch_length1[i]); }
	printf("]\n");

	// Scenario 2: every event at time t = sum over i of 1/((n-i)(n-i-1))
	int n = (int)pModel->getData()->getPopAssignment().size();
	double t = 0.0;
	for (int i = 0; i < (n-1); ++i) {
		t += 1.0/((double)((n-i)*(n-i-1)));
	}
	admixture_parameters parameters2;
	parameters2.setTheta(1.0);
	for (int i = 0; i < this->pModel->getNumberOfEvents(); ++i) {
		parameters2.addEvent(event(t, 0.5, 0, 0, 0));
	}
	parameters_events_sample(&parameters2);
	for (int i = 0; i < this->pModel->getNumberOfEvents(); ++i) {
		if (parameters2.getEvent(i)->p2 < 0) {
			parameters2.getEvent(i)->alpha = 1;
		}
	}
	std::vector<double> branch_length2(estimate_lambda(&coalescence, &parameters2)[key_idx]);
	printf("[ ");
	for (int i = 0; i < (int)branch_length2.size(); ++i) { printf("%.3f ", branch_length2[i]); }
	printf("]\n");

	// Per-event proposal widths; the time widths scale with initialTMax and
	// with the event index
	admixture_hyperparameters* pHyperParams = new admixture_hyperparameters();
	for (int i = 0; i < this->pModel->getNumberOfEvents(); ++i) {
		pHyperParams->time.push_back(new admixture_hyperparameter(i, admixture_hyperparameter::Time, 0.05*this->initialTMax*(double)(i+1), 0.001*this->initialTMax*(double)(i+1), 0.15*this->initialTMax*(double)(i+1)));
		pHyperParams->alpha.push_back(new admixture_hyperparameter(i, admixture_hyperparameter::Admixture, 0.07, 0.02, 0.1));
	}

	// Totals of both scenarios (b1, b2), of the observed weights (w) and of
	// the absolute branch lengths (width1, width2). The last index is the
	// null weight.
	double b1 = 0.0;
	double b2 = 0.0;
	double w = 0.0;
	double width1 = 0.0;
	double width2 = 0.0;
	for (int i = 0; i < (int)branch_length1.size(); ++i) {
		b1 += branch_length1[i];
		b2 += branch_length2[i];
		w += (i+1 < (int)branch_length1.size()) ? (double)pModel->getData()->getModelSet(key_idx)->getPartition(i)->getWeight() : pModel->getData()->getModelSet(key_idx)->getNullWeight();
		// fabs() rather than abs(): unqualified abs() may resolve to the
		// integer overload and truncate the branch length
		width1 += fabs(branch_length1[i]);
		width2 += fabs(branch_length2[i]);
	}
	pHyperParams->theta.eventIdx = -1;
	pHyperParams->theta.eventType = admixture_hyperparameter::Theta;
	pHyperParams->theta.sigma_min = w / b2;
	pHyperParams->theta.sigma_max = w / b1;
	pHyperParams->theta.sigma = (pHyperParams->theta.sigma_max - pHyperParams->theta.sigma_min)*0.01;
	pHyperParams->error_width = std::max(width1, width2)*pHyperParams->theta.sigma;

	return pHyperParams;
}

//------------------------------------------------------------------------------------------------------
// MetropolisCache
// Return the probability that a branch with k leaves will be discovered by n sequenced individuals

//
//metropolis_cache::metropolis_cache(int numOfSamps, int numOfSeqSamps, double variantCallErrorProb, const std::vector<int>& ascertainment, const std::vector<int>& popsizes) :
//	n(numOfSamps),
//	pop_sizes(popsizes),
//	discovery_sizes(ascertainment),
//	snp_discovery_prob_cache(std::vector<double>(numOfSamps, 1.0)) {
//
//
//
//	double prob = 1.0;
//	for (int i = 0; i < (int)(numOfSamps+1)/2; ++i) {
//		prob *= variantCallErrorProb;
//		//printf("%f, %f\n", snp_discovery_prob_cache[i], prob);
//		snp_discovery_prob_cache[i] *= (1.0-prob);
//		snp_discovery_prob_cache[numOfSamps-i-1] *=  snp_discovery_prob_cache[i];
//	}
//
//	/* Case 1 model (ascertainment sample in genotyped sample, fixed d)
//	printf("Initializing cache: %d %d\n", numOfSamps, numOfSeqSamps);
//	double prob = 1.0;
//	for (int i = 0; i < (numOfSamps-numOfSeqSamps); ++i) {
//		prob *= (double)(numOfSamps-numOfSeqSamps-i)/(double)(numOfSamps-i);
//		snp_discovery_prob_cache[i] = 1-prob;
//	} */
//	/* Case 3 model (ascertainment sample unknown) */
//	if (numOfSeqSamps > 0 && numOfSamps != numOfSeqSamps) {
//		for (int i = 0; i < numOfSamps; ++i) {
//			printf("%d ", i); fflush(stdout);
//			double prob = 0.0;
//			for (int j = 1; j < numOfSeqSamps; ++j) {
//				prob += compute_ascertainment_prob3(i+j+1, numOfSamps, numOfSeqSamps);
//			}
//			prob /= (double)(numOfSeqSamps+1);
//			snp_discovery_prob_cache[i] = prob;
//		}
//	}
//
//	// Do it only if we need to do ascertainment
//	if (numOfSeqSamps > 0 && numOfSeqSamps < numOfSamps) {
//		int m = (int)((double)(numOfSamps/2)*0.05);
//		for (int i = 0; i <= m; ++i) {
//			snp_discovery_prob_cache[i] = snp_discovery_prob_cache[i] * 0.3;
//			snp_discovery_prob_cache[numOfSamps-i-1] = snp_discovery_prob_cache[numOfSamps-i-1] * 0.3;
//		}
//	}
//
//	/*
//	prob = 1.0;
//	for (int i = 0; i < numOfSeqSamps; ++i) {
//		prob *= (double)(numOfSeqSamps-i)/(double)(numOfSamps-i);
//		snp_discovery_prob_cache[i] -= prob;
//	} */
//
//	for (int i = 0; i < (int)snp_discovery_prob_cache.size(); ++i) {
//		printf("%.2f ", snp_discovery_prob_cache[i]);
//		if (i > 0 && i % 30 == 0) {
//			printf("\n");
//		}
//	}
//	fflush(stdout);
//}
//
//double metropolis_cache::compute_snp_discovery_prob(int k) {
//	return this->snp_discovery_prob_cache[k-1];
//}
//
//double metropolis_cache::compute_snp_discovery_prob_by_pop(const std::vector<bool>& popvec, int nleaf) {
//	double prob = 0;
//	int d = 0;
//	//int n = 0;
//	for (int i = 0; i < (int)popvec.size(); ++i) {
//		if (popvec[i] > 0) {
//			d += this->discovery_sizes[i];
//			//n += this->pop_sizes[i];
//		}
//	}
//
//	for (int j = nleaf+1; j < nleaf+d; ++j) {
//		prob += compute_ascertainment_prob3(j, n, d);
//	}
//	prob /= (d+1);
//
//	return prob;
//}
//
//double metropolis_cache::compute_ascertainment_prob3(int j, int n, int d) {
//
//	if (j >= n + d - 1 || j < 2) {
//		return 0.0;
//	}
//
//	double p = 0;
//	double ldenom = log_nchoosek(n+d, n);
//	if (j >= n) {
//		p += exp(log_nchoosek(j, n)-ldenom);
//	}
//
//	if (n+d-j >= n) {
//		p += exp(log_nchoosek(n+d-j, n)-ldenom);
//	}
//
//	if (j >= d) {
//		p += exp(log_nchoosek(j, d)-ldenom);
//	}
//
//	if (n+d-j >= d) {
//		p += exp(log_nchoosek(n+d-j, d)-ldenom);
//	}
//
//	if (d == j) {
//		p = p + exp(-ldenom);
//	}
//
//	if (n == j) {
//		p = p + exp(-ldenom);
//	}
//
//	p = 1 - p;
//	return p;
//}

