/*
 * likelihood_test.cpp
 *
 */

// Standard Include
#include "likelihood_test.h"

// Includes
#include <math.h>
#include <stdio.h>
#include <fstream>
#include "../algorithm/coalescence_algorithm.h"
#include "../dmo/admixture_model.h"
#include "../dmo/genealogy.h"
#include "../dmo/partition.h"
#include "../utilities/entropy_utilities.h"

// Default constructor. There is no initialization performed here; all work
// happens inside test_likelihood().
likelihood_test::likelihood_test() {
}

// Destructor. Nothing is released explicitly here.
likelihood_test::~likelihood_test() {
}

void likelihood_test::test_likelihood(std::ofstream& output) {

	// initialize location
	std::vector<partition*> model_p;
	std::vector<bool> b1(400, 0);
	std::vector<bool> b2(400, 0);
	std::vector<bool> b3(400, 0);
	std::vector<bool> b4(400, 0);
	std::vector<bool> b5(400, 0);
	std::vector<bool> b6(400, 0);

	std::vector<int> pop_assign(400, 0);
	for (int i = 0; i < 400; i++) {
		if (i < 100) {
			b1[i] = 1;
			b5[i] = 1;
		} else if (i < 200) {
			b2[i] = 1;
			b5[i] = 1;
			b6[i] = 1;
			pop_assign[i] = 1;
		} else if (i < 300) {
			b3[i] = 1;
			pop_assign[i] = 2;
		} else {
			b4[i] = 1;
			pop_assign[i] = 3;
			//if (i != 327 && i != 354 && i != 370) {
				b6[i] = 1;
			//}
		}
	}
	coalescence_algorithm coalescence(pop_assign, 4);

	model_p.push_back(new partition(b1));
	model_p.push_back(new partition(b2));
	model_p.push_back(new partition(b3));
	model_p.push_back(new partition(b4));
	model_p.push_back(new partition(b5));
	model_p.push_back(new partition(b6));

	std::vector<masked_model_partition_set*> model_sets;

	std::vector<bool> mask0;
	masked_model_partition_set* pSet = new masked_model_partition_set(mask0);
	pSet->add_model_partition(new partition(b1), 4956);
	pSet->add_model_partition(new partition(b2), 3891);
	pSet->add_model_partition(new partition(b3), 2009);
	pSet->add_model_partition(new partition(b4), 1749);
	pSet->add_model_partition(new partition(b5), 3799);
	pSet->add_model_partition(new partition(b6), 839);
	pSet->setNullWeight(5213);
	model_sets.push_back(pSet);

	std::vector<bool> mask(400, 1);
	for (int i = 0; i < 100; ++i) {mask[i] = 0; }
	masked_model_partition_set* pSet1 = new masked_model_partition_set(mask);
	pSet1->add_model_partition(new partition(b1), 0);
	pSet1->add_model_partition(new partition(b2), 7929);
	pSet1->add_model_partition(new partition(b3), 3014);
	pSet1->add_model_partition(new partition(b4), 1957);
	pSet1->add_model_partition(new partition(b5), 0);
	pSet1->add_model_partition(new partition(b6), 0);
	pSet1->setNullWeight(9556);
	model_sets.push_back(pSet1);

	std::vector<bool> mask2(400, 1);
	for (int i = 100; i < 200; ++i) {mask2[i] = 0; }
	masked_model_partition_set* pSet2 = new masked_model_partition_set(mask2);
	pSet2->add_model_partition(new partition(b1), 8621);
	pSet2->add_model_partition(new partition(b2), 0);
	pSet2->add_model_partition(new partition(b3), 2374);
	pSet2->add_model_partition(new partition(b4), 3199);
	pSet2->add_model_partition(new partition(b5), 0);
	pSet2->add_model_partition(new partition(b6), 0);
	pSet2->setNullWeight(8262);
	model_sets.push_back(pSet2);


	std::vector<bool> mask3(400, 1);
	for (int i = 200; i < 300; ++i) {mask3[i] = 0; }
	masked_model_partition_set* pSet3 = new masked_model_partition_set(mask3);
	pSet3->add_model_partition(new partition(b1), 5763);
	pSet3->add_model_partition(new partition(b2), 4127);
	pSet3->add_model_partition(new partition(b3), 0);
	pSet3->add_model_partition(new partition(b4), 5557);
	pSet3->add_model_partition(new partition(b5), 0);
	pSet3->add_model_partition(new partition(b6), 0);
	pSet3->setNullWeight(7009);
	model_sets.push_back(pSet3);


	std::vector<bool> mask4(400, 1);
	for (int i = 300; i <400; ++i) {mask4[i] = 0; }
	masked_model_partition_set* pSet4 = new masked_model_partition_set(mask4);
	pSet4->add_model_partition(new partition(b1), 5012);
	pSet4->add_model_partition(new partition(b2), 4672);
	pSet4->add_model_partition(new partition(b3), 5846);
	pSet4->add_model_partition(new partition(b4), 0);
	pSet4->add_model_partition(new partition(b5), 0);
	pSet4->add_model_partition(new partition(b6), 0);
	pSet4->setNullWeight(6926);
	model_sets.push_back(pSet4);

	admixture_parameters* pParams1 = new admixture_parameters();
	pParams1->addEvent(event(0.025, 0.2, 1, 2, 3));
	pParams1->addEvent(event(0.25, 0.7, 0, 2, 1));
	pParams1->addEvent(event(0.5, 1, 0, -1, 2));
	pParams1->setTheta(1400.0);
	printf("%s\n", pParams1->toString().data());
	std::vector<std::vector<double> > lambda_bench = estimate_lambda(&coalescence, pParams1, model_sets);

	for (int a = 0; a < (int)lambda_bench.size(); ++a) {
		printf("[ ");
		for (int b = 0; b < (int)lambda_bench[a].size(); ++b) {
			printf("%.0f ", lambda_bench[a][b]);
		}
		printf("] ");
	}
	printf("\n");
	for (int a = 0; a < (int)model_sets.size(); ++a) {
		printf("[ ");
		for (int b = 0; b < (int)model_sets[a]->size(); ++b) {
			printf("%.0f ", (double)model_sets[a]->getPartition(b)->getWeight());
		}
		printf("%.0f ", (double)model_sets[a]->getNullWeight());
		printf("] ");
	}
	printf("\n");

	double t1[5] = { 0.025, 0.03, 0.04, 0.06, 0.20 };
	double t2[5] = { 0.25, 0.20, 0.30, 0.60, 0.74 };
	double t3[5] = { 0.5, 0.4, 1.0, 1.7, 3.16 };
	double a1[5] = { 0.2, 0.25, 0.30, 0.35, 0.40 };
	double a2[5] = { 0.7, 0.65, 0.60, 0.85, 0.8 };
	double theta[5] = { 1400.0, 1300.0, 1200.0, 1100.0, 1000.0 };
	for (int i = 0; i < 5; ++i) {
		for (int j = 0; j < 5; ++j) {
			for (int k = 0; k < 5; ++k) {
				for (int l = 0; l < 5; ++ l) {
					for (int m = 0; m < 5; ++m) {
						for (int n = 0; n < 5; ++n) {

							if (t1[i] <= t2[j] && t2[j] <= t3[k]) {
								event e1(t1[i], a1[l], 1, 2, 3);
								event e2(t2[j], a2[m], 0, 2, 1);
								event e3(t3[k], 1, 0, -1, 2);
								admixture_parameters* pParams = new admixture_parameters();
								pParams->addEvent(e1);
								pParams->addEvent(e2);
								pParams->addEvent(e3);
								pParams->setTheta(theta[n]);


								std::vector<std::vector<double> > lambda = estimate_lambda(&coalescence, pParams, model_sets);

								double llh = likelihood_ratio(lambda, lambda_bench, pParams, model_sets);


								for (int a = 0; a < (int)lambda.size(); ++a) {
									printf("[ ");
									for (int b = 0; b < (int)lambda[a].size(); ++b) {
										printf("%.0f ", lambda[a][b]);
									}
									printf("] ");
								}
								printf(" (THETA: %.0f %s LLH: %.1f)\n", pParams->getTheta(), pParams->toString().data(), llh);
								output << t1[i] << " " << t2[j] << " " << t3[k] << " " << a1[l] << " ";
								output << a2[m] << " " << theta[n] << " " << llh << std::endl;
							}
						}
					}
				}
			}
		}
	}
}

// Poisson log-likelihood ratio contribution of a single bin.
// Inputs: observed count `observed`, expected count `lambda` under the tested
// parameters, and `lambda_bench` under the benchmark parameters.
//
//     observed * (log(lambda) - log(lambda_bench)) - (lambda - lambda_bench)
//
// When observed == 0, the log term is dropped, because 0 * log(0) would
// evaluate to NaN. The -(lambda - lambda_bench) part still applies: an
// unobserved bin penalises parameters that predict many observations in it.
static double poisson_llr_term(double observed, double lambda, double lambda_bench) {
	if (observed == 0.0) {
		return -(lambda - lambda_bench);
	}
	return (log(lambda) - log(lambda_bench)) * observed - (lambda - lambda_bench);
}

// Log-likelihood ratio of the tested lambdas against the benchmark lambdas,
// summed over every model set in `model_sets`.
//
// For model set k, entries lambdas[k][0..n-1] pair with the n model partitions
// (observed count = the partition's weight), and lambdas[k][n] pairs with the
// null model (observed count = getNullWeight()). Both `lambdas` and
// `lambdas_bench` must therefore have size()+1 entries per set, as produced by
// estimate_lambda().
//
// `pParam` is not used by the computation.
double likelihood_test::likelihood_ratio(const std::vector<std::vector<double> >& lambdas, const std::vector<std::vector<double> >& lambdas_bench, admixture_parameters* pParam, std::vector<masked_model_partition_set*> model_sets) {

	double llh = 0.0;

	for (int k = 0; k < (int)model_sets.size(); ++k) {
		masked_model_partition_set* pSet = model_sets[k];
		const int n = pSet->size();
		for (int i = 0; i < n; ++i) {
			llh += poisson_llr_term((double)pSet->getPartition(i)->getWeight(), lambdas[k][i], lambdas_bench[k][i]);
		}
		// The last entry of each lambda vector belongs to the null model.
		llh += poisson_llr_term((double)pSet->getNullWeight(), lambdas[k][n], lambdas_bench[k][n]);
	}

	return llh;
}

// Enumerate through the genealogy and computes the weights for each
// model bipartitions
// Estimates the expected weight (lambda) of every model bipartition, and of
// the null model, in each masked model set.
//
// Procedure:
//   1. Sample numGenealogies genealogies with pCoalescence->generate_genalogy_fixed(pParams).
//   2. For every branch except the last (assumed to be the root), compute a score for each candidate:
//        - null model: m * compute_entropy(branch, mask)
//        - model j:    m * compute_conditional_entropy(branch, model_j, mask) + penalty
//   3. Add the branch's length to the lowest-scoring candidate; ties go to the earlier candidate, null first.
//   4. Average the accumulated length over the genealogies and multiply by theta.
//
// Models flagged in the skip map never win a branch, so their lambda stays 0.
// A model is flagged when it is either
//   - not informative under the set's mask, or
//   - identical to, or the mirror of, an earlier non-skipped model of the same set.
//
// Returns one vector per model set, with size()+1 entries; the last entry is the null model.
const std::vector<std::vector<double> > likelihood_test::estimate_lambda(coalescence_algorithm* pCoalescence, admixture_parameters* pParams, const std::vector<masked_model_partition_set*>& model_sets) {

	const int numSets = (int)model_sets.size();

	// weights_total[k][j]: branch length assigned to model j of set k, summed over
	// all sampled genealogies. Index size() is the null model.
	std::vector<std::vector<double> > weights_total(numSets);
	// Masks are copied once up front; they do not change while genealogies are scored.
	std::vector<std::vector<bool> > masks(numSets);
	std::vector<std::vector<bool> > skip_map(numSets);

	for (int k = 0; k < numSets; ++k) {
		masked_model_partition_set* pSet = model_sets[k];
		const int numModels = (int)pSet->size();
		weights_total[k].assign(numModels + 1, 0.0);
		masks[k] = pSet->getMask();
		skip_map[k].assign(numModels, false);

		// First check if each model bipartition is informative at all.
		for (int j = 0; j < numModels; ++j) {
			if (!is_informative(pSet->getPartition(j), pSet->getMask())) {
				skip_map[k][j] = true;
			}
		}

		// Then drop later models that duplicate (or mirror) a kept earlier one.
		for (int j = 0; j < numModels; ++j) {
			if (skip_map[k][j]) {
				continue;
			}
			for (int i = j + 1; i < numModels; ++i) {
				if (!skip_map[k][i] && is_identical_or_mirror(pSet->getPartition(j), pSet->getPartition(i), pSet->getMask())) {
					skip_map[k][i] = true;
				}
			}
		}
	}

	// Model-complexity penalty charged to every non-null model (log2(8) = 3).
	const double penalty = log2(8);
	// Entropy scale factor. Hard-coded; presumably the number of individuals
	// (400 in test_likelihood) -- TODO(review): derive from the input instead.
	const int m = 400;
	const int numGenealogies = 30;
	for (int r = 0; r < numGenealogies; ++r) {
		genealogy* pGenealogy = pCoalescence->generate_genalogy_fixed(pParams);
		const int numBranches = (int)pGenealogy->getBranchSize();
		for (int k = 0; k < numSets; ++k) {
			masked_model_partition_set* pSet = model_sets[k];
			std::vector<bool>& mask = masks[k];
			const int nullIdx = (int)pSet->size();

			// Assumes that the last branch is a root that we do not care about.
			// `i + 1 < numBranches` cannot underflow when the genealogy is empty.
			for (int i = 0; i + 1 < numBranches; ++i) {
				genealogy_branch* pBranch = pGenealogy->getBranch(i);

				// The null model is the initial best and carries no penalty.
				double bestScore = m * compute_entropy(pBranch->getPartition(), mask);
				int bestModelIdx = nullIdx;

				for (int j = 0; j < nullIdx; ++j) {
					if (skip_map[k][j]) {
						continue;
					}
					const double score = m * compute_conditional_entropy(pBranch->getPartition(), pSet->getPartition(j), mask) + penalty;
					if (score < bestScore) {
						bestScore = score;
						bestModelIdx = j;
					}
				}

				const double t = pBranch->getEndTime() - pBranch->getStartTime();
				if (t <= 0) {
					printf("%d %d (%f-%f) has negative or zero time\n", i, pBranch->getUID(), pBranch->getEndTime(), pBranch->getStartTime());
					fflush(stdout);
				}
				weights_total[k][bestModelIdx] += t;
			}
		}

		delete pGenealogy;
	}

	// lambda = (mean assigned branch length per genealogy) * theta
	std::vector<std::vector<double> > lambdas(weights_total.size());
	for (size_t k = 0; k < weights_total.size(); ++k) {
		lambdas[k].resize(weights_total[k].size());
		for (size_t i = 0; i < weights_total[k].size(); ++i) {
			lambdas[k][i] = (weights_total[k][i] / (double)numGenealogies) * pParams->getTheta();
		}
	}
	return lambdas;
}
