/*
 * consensus_algorithm.cpp
 *
 */

// My include
#include "consensus_algorithm.h"

// Standard Includes
#include <algorithm>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>

// Includes
#include "../dmo/partition.h"
#include "../utilities/entropy_utilities.h"
#include "../utilities/partition_utilities.h"

// Construct a solver that run() will drive for maxIters iterations, keeping at
// most maxParts model bipartitions and charging `penalty` per model in the
// score (see compute()).
//
// For every observed partition this precomputes m * entropy(partition); these
// values are the null-model costs compute() uses when no mask is given.
// NOTE(review): partitions[0] is dereferenced, so `parts` must be non-empty;
// all observed partitions are presumably the same length m.
consensus_algorithm::consensus_algorithm(const std::vector<partition*> parts, int maxParts, int maxIters, double penalty) :
	partitions(parts),
	maxIters(maxIters),
	maxParts(maxParts),
	penalty(penalty)
{
	this->score.score = 0;
	this->n = (int)this->partitions.size();
	this->m = (int)this->partitions[0]->size();

	// Null-model cost of each observed partition, scaled by its length.
	const double scale = (double)this->m;
	for (std::vector<partition*>::const_iterator it = parts.begin(); it != parts.end(); ++it) {
		this->entropies.push_back(scale * compute_entropy(*it));
	}

	printf("Finished Initializing Observed Entropies\n");
	fflush(stdout);
}

// Construct a solver around an existing set of model bipartitions and score it
// immediately. The `models` pointers become owned by this object (the
// destructor deletes them). Search parameters are zeroed (maxIters = 0,
// maxParts = 0, penalty = 0), so this form is presumably meant for scoring a
// fixed model set rather than for run().
// NOTE(review): partitions[0] is dereferenced, so `parts` must be non-empty.
consensus_algorithm::consensus_algorithm(const std::vector<partition*> parts, const std::vector<partition*> models) :
	partitions(parts),
	models(models),
	maxIters(0),
	maxParts(0),
	penalty(0.0)
{
	this->n = (int)this->partitions.size();
	this->m = (int)this->partitions[0]->size();
	// One null row plus one row per supplied model.
	this->score = Score(this->n, (int)models.size()+1);

	// Null-model cost of each observed partition, scaled by its length.
	const double scale = (double)this->m;
	for (std::vector<partition*>::const_iterator it = parts.begin(); it != parts.end(); ++it) {
		this->entropies.push_back(scale * compute_entropy(*it));
	}
	printf("Finished Initializing Observed Entropies\n");
	fflush(stdout);

	// this->models holds exactly the pointers passed in `models`.
	compute(this->models, this->score);
	printf("Finished Initializing Scores\n");
	fflush(stdout);
}

// Release the model bipartitions owned by this object. The observed partitions
// in this->partitions are not deleted here.
consensus_algorithm::~consensus_algorithm() {
	typedef std::vector<partition*>::iterator model_iter;
	for (model_iter it = this->models.begin(); it != this->models.end(); ++it) {
		delete *it;
	}
	this->models.clear();
}

//------------------------------------------------------------------------------------
// Methods

// Run the algorithm
void consensus_algorithm::run() {
	//std::vector<partition*> models;

	// Compute the initial score
	this->score = Score(this->n, 1);
	this->compute(models, this->score);

	int iterations = 0;
	while (iterations < this->maxIters) {
		time_t rawtime;
		struct tm * timeinfo;
		time ( &rawtime );
		timeinfo = localtime ( &rawtime );

		printf("%s (%d/%d):", asctime(timeinfo), iterations, this->maxIters);
		if ((this->maxParts == 0 || (int)models.size() < this->maxParts) && (models.size() == 0 || (rand() % 100) < 90)) {
			int pidx = this->sampleObserved(score);

			partition* pModel = new partition(*(this->partitions[pidx]));
			std::pair<partition*, Score> output(insertModelPartition(models, pModel));
			printf("Inserting partition: %s\n", pModel->getBase64String().data());

			if (output.second.score < score.score) {
				models.push_back(pModel);
				score = output.second;
			} else {
				delete pModel;
			}
		} else {

			fflush(stdout);
			int pidx = this->sampleObserved(score);
			int midx = this->sampleModel();
			partition* pModel = new partition(*(this->partitions[pidx]));
			std::pair<partition*, Score> output(replaceModelPartition(models, pModel, midx));

			printf("(%d) Replacing %d-th partition with: %s\n", (int)this->models.size(), midx, pModel->getBase64String().data());

			if (output.second.score < score.score) {
				delete models[midx];
				models[midx] = pModel;
				score = output.second;
			} else {
				delete pModel;
			}
		}

		std::vector<int> weights((int)this->models.size()+1, 0);
		for (int i = 0; i < (int)this->score.assignments.size(); ++i) {
			int assign = this->score.assignments[i];
			if (assign >= 0) {
				weights[assign]++;
			}
		}
		printf("Score: %.1f [%d]\n", score.score, weights[0]);
		for (int i = 0; i < (int)models.size(); ++i) {
			printf("%d: %s [%d]\n", i, models[i]->getBase64String().data(), weights[i+1]);
		}
		fflush(stdout);

		iterations++;
	}

	// This section of the code tries to patch additional bipartition if maxPart is not
	// reached
	if ((int)this->models.size() < this->maxParts) {
		printf("Filling Up Empty Slots");
		// Add missing model bipartition if there exists one
		std::vector<int> pop_assign(this->m, 0);
		int popid = 0;
		for (int j = 0; j < (int)this->models.size(); ++j) {
			partition* pPart = this->models[j];
			int p0 = -1;
			bool conflict = 0;
			for (int i = 0; i < this->m; ++i) {
				if (!pPart->get(i)) {
					if (p0 < 0) {
						p0 = pop_assign[i];
					} else if (p0 != pop_assign[i]) {
						conflict = 1;
						break;
					}
				}
			}

			popid++;
			for (int i = 0; i < this->m; ++i) {
				if (pPart->get(j) == conflict) {
					pop_assign[j] = popid;
				}
			}
		}

		std::vector<partition*> cands;
		std::vector<bool> pts(this->m, 0);
		popid = pop_assign[0];
		for (int i = 0; i < this->m; ++i) {
			if (popid != pop_assign[i]) {
				bool duplicate = 0;
				partition* pPart = new partition(pts);
				for (int j = 0; j < (int)this->models.size(); ++j) {
					if (pPart == this->models[j]) {
						duplicate = 1;
						break;
					}
				}
				if (!duplicate) {
					cands.push_back(pPart);
				} else {
					delete pPart;
				}
				pts = std::vector<bool>(this->m, 0);
			}
			pts[i] = 1;
		}
		partition* pPart = new partition(pts);
		cands.push_back(pPart);

		//
		int k = std::min(this->maxParts - (int)this->models.size(), (int)cands.size());
		for (int i = 0; i < k; ++i) {
			std::pair<partition*, Score> output(insertModelPartition(models, cands[i]));
			models.push_back(cands[i]);
			score = output.second;
		}
	}

	// If we are not limiting the number of model bipartition, we will construct a set of
	if (this->maxParts <= 0) {
		constructFullPartitionSet();

		score = Score(this->n, (int)models.size()+1);
		compute(models, score);
	}
}

// Build a masked_model_partition_set describing the current model bipartitions.
//
// mask:                   forwarded to compute() and to the returned set; an
//                         empty mask means no masking.
// filterNull:             forwarded to compute(). There, observed partitions
//                         whose null cost is exactly 0 start with assignment
//                         -1, which is excluded from every weight.
// addAdditionalPartition: if true, first replaces this->models via
//                         constructFullPartitionSet().
//
// Each model's weight is the number of observed partitions assigned to it; the
// null weight is the number assigned to the null model. Returns a newly
// allocated set that the caller must delete.
masked_model_partition_set* consensus_algorithm::getModelPartitionSet(const std::vector<bool>& mask, bool filterNull, bool addAdditionalPartition) {

	if (addAdditionalPartition) {
		constructFullPartitionSet();
	}

	// In this file, this->score is produced by compute(models, score) with the
	// default mask/filterNull arguments. It can therefore only stand in for an
	// unmasked, unfiltered request on an unchanged model set; filterNull must
	// not be silently ignored on the cached path.
	// NOTE(review): assumes compute()'s default filterNull is false; confirm in
	// the header. If it were true, recomputing here is only redundant.
	Score score(this->n, (int)this->models.size()+1);
	Score* pScore = &this->score;
	if (!mask.empty() || addAdditionalPartition || filterNull) {
		compute(this->models, score, mask, filterNull);
		pScore = &score;
	}

	const int k = (int)this->models.size();
	std::vector<int> weights(k+1, 0);
	for (int i = 0; i < (int)pScore->assignments.size(); ++i) {
		int assign = pScore->assignments[i];
		if (assign >= 0) {
			weights[assign]++;
		}
	}

	masked_model_partition_set* pSet = new masked_model_partition_set(mask);
	for (int i = 0; i < k; ++i) {
		pSet->add_model_partition(this->models[i], weights[i+1]);
	}
	pSet->setNullWeight(weights[0]);

	return pSet;
}

// Return the total score of the current model set. Lower is better: run()
// keeps a proposal only if it lowers this value. The value includes the
// penalty*k term added by compute().
const double consensus_algorithm::getModelScore() {
	const double total = this->score.score;
	return total;
}

//------------------------------------------------------------------------------------
// Helpers


// Evaluate adding pNewPart as one more model bipartition.
//
// Scores `parts` extended by pNewPart, then refines pNewPart in place with
// performEM() followed by performGC(). `parts` itself is not modified, and
// pNewPart is not taken over; run() deletes it when the proposal is rejected.
// Returns pNewPart paired with the refined score of the extended set.
std::pair<partition*, Score> consensus_algorithm::insertModelPartition(std::vector<partition*>& parts, partition* pNewPart) {
	// Work on a copy so the caller's model list is untouched.
	std::vector<partition*> extended(parts);
	extended.push_back(pNewPart);
	const int slot = (int)extended.size() - 1;

	// One null row plus one row per model in the extended set.
	Score trial(this->n, (int)extended.size() + 1);
	compute(extended, trial);

	performEM(pNewPart, extended, trial, slot);
	performGC(pNewPart, extended, trial, slot);

	return std::pair<partition*, Score>(pNewPart, trial);
}

// Evaluate replacing the model at position idx with pNewPart.
//
// Scores a copy of `parts` with slot idx set to pNewPart, then refines
// pNewPart in place with performEM() followed by performGC(). `parts` is not
// modified. idx must be a valid index into `parts` (not checked here).
// Returns pNewPart paired with the refined score of the modified set.
std::pair<partition*, Score> consensus_algorithm::replaceModelPartition(std::vector<partition*>& parts, partition* pNewPart, int idx) {
	// Work on a copy so the caller's model list is untouched.
	std::vector<partition*> swapped(parts);
	swapped[idx] = pNewPart;

	// One null row plus one row per model.
	Score trial(this->n, (int)swapped.size() + 1);
	compute(swapped, trial);

	performEM(pNewPart, swapped, trial, idx);
	performGC(pNewPart, swapped, trial, idx);

	return std::pair<partition*, Score>(pNewPart, trial);
}

// Greedy single-bit local search on a freshly inserted or replaced model.
//
// pPartition is modified in place. Each sweep visits every bit position,
// tentatively flips that bit, and rescores with update(), which recomputes
// only row idx+1. A flip is kept only if the total score strictly drops;
// otherwise it is undone. Sweeps repeat until a full sweep keeps no flip.
// On return `score` holds the score of the final pPartition. Callers in this
// file pass `models` with pPartition stored at position idx.
void consensus_algorithm::performEM(partition* pPartition, std::vector<partition*> models, Score& score, int idx) {
	Score accepted(score);   // score of the last configuration that was kept
	Score trial(score);      // scratch score for the current tentative flip

	bool sweepAgain = true;
	while (sweepAgain) {
		sweepAgain = false;
		for (int bit = 0; bit < (int)pPartition->size(); ++bit) {
			pPartition->set(bit, !pPartition->get(bit));
			update(trial, models, pPartition, idx);

			if (trial.score >= accepted.score) {
				// No strict improvement: undo the flip.
				pPartition->set(bit, !pPartition->get(bit));
				continue;
			}
			sweepAgain = true;
			accepted = trial;
		}
	}

	score = accepted;
}

// Performs four gametes condition to ensure we do not have admixture tree
void consensus_algorithm::performGC(partition* pPartition, std::vector<partition*> models, Score& score, int idx) {

	// initialize gamates array (count how many 00, 01, 10, 11 there is for every model biparts)
	std::vector<std::vector<int> > gametes((int)models.size());
	for (int i = 0; i < (int)models.size(); ++i) {
		gametes[i] = std::vector<int>(4, 0);
		for (int j = 0; j < this->m; ++j) {
			int idx = (models[i]->get(j) ? 2 : 0) + (pPartition->get(j) ? 1 : 0);
			gametes[i][idx]++;
		}
	}

	// determine if there is any conflicts
	bool hasConflict = true;
	while (hasConflict) {
		hasConflict = false;
		int minRow = -1;
		int minCol = -1;
		int minVal = INT_MAX;
		for (int i = 0; i < (int)gametes.size(); ++i) {
			int minCount = INT_MAX;
			int minIndex = -1;
			for (int j = 0; j < (int)gametes[i].size(); ++j) {
				if (minCount > gametes[i][j]) {
					minCount = gametes[i][j];
					minIndex = j;
				}
			}
			if (minCount > 0) {
				if (minVal > minCount) {
					minVal = minCount;
					minRow = i;
					minCol = minIndex;
				}
			}
		}

		if (minRow >= 0) {
			hasConflict = 1;

			// try to flip bit that will eliminate gametes violation
			bool sourceBit = (minCol / 2) == 1;
			bool newBit = (minCol % 2) == 1;

			for (int i = 0; i < (int)pPartition->size(); ++i) {
				if (models[minRow]->get(i) == sourceBit &&
					pPartition->get(i) == newBit) {
					pPartition->set(i, !newBit);
				}
			}

			for (int i = 0; i < (int)models.size(); ++i) {
				gametes[i] = std::vector<int>(4, 0);
				for (int j = 0; j < this->m; ++j) {
					int idx = (models[i]->get(j) ? 2 : 0) + (pPartition->get(j) ? 1 : 0);
					gametes[i][idx]++;
				}
			}
		}
	}

	return update(score, models, pPartition, idx);
}


// Return a random index of an observed bipartition. The probability is
// proportional to its current cost, entropies[assignment][i], under `score`,
// so poorly explained observations are proposed more often.
//
// The sampling mass is the sum of these per-observation costs. score.score
// cannot be used, because it also contains the penalty*k term added by
// compute(); that surplus would always fall through to the last index and bias
// sampling toward observation n-1. Observations with a negative (filtered)
// assignment contribute no mass, rather than indexing entropies[-1].
// Returns 0 when there are no observations.
int consensus_algorithm::sampleObserved(const Score& score) {
	std::vector<double> costs(this->n, 0.0);
	double total = 0.0;
	for (int i = 0; i < this->n; ++i) {
		int assign = score.assignments[i];
		if (assign >= 0) {
			costs[i] = score.entropies[assign][i];
			total += costs[i];
		}
	}

	double r = ((double)rand()/(double)RAND_MAX)*total;
	double value = 0.0;
	for (int i = 0; i < (this->n-1); ++i) {
		value += costs[i];
		if (value >= r) {
			return i;
		}
	}
	// Remaining mass belongs to the last observation.
	return (this->n > 0) ? this->n - 1 : 0;
}

// Return the index (into this->models) of a model bipartition to replace.
//
// Each model's weight is the number of observed partitions assigned to it in
// this->score. A model is chosen with probability proportional to 1/weight, so
// rarely used models are replaced more often. Models with weight 0 have
// unbounded inverse weight (1/0 would make the running sum inf/NaN, and then
// no model could be selected), so if any exist one of them is picked
// uniformly at random.
// If sampling falls through, returns models.size()-1, which is -1 when there
// are no models.
int consensus_algorithm::sampleModel() {
	const int k = (int)this->models.size();

	std::vector<int> weights(k+1, 0);
	for (int i = 0; i < (int)this->score.assignments.size(); ++i) {
		int assign = this->score.assignments[i];
		// Guard against assignments that do not match the current model count.
		if (assign >= 0 && assign <= k) {
			weights[assign]++;
		}
	}

	// Unused models are the natural replacement targets.
	std::vector<int> unused;
	for (int i = 1; i <= k; ++i) {
		if (weights[i] == 0) {
			unused.push_back(i-1);
		}
	}
	if (!unused.empty()) {
		return unused[rand() % (int)unused.size()];
	}

	double total = 0.0;
	for (int i = 1; i <= k; ++i) {
		total += 1/(double)weights[i];
	}
	double v = ((double)(rand())/(double)RAND_MAX);
	double cum = 0;
	printf("%f, %f: ", total, v);
	for (int i = 1; i <= k; ++i) {
		cum += (1/(double)weights[i])/total;
		printf("%f, ", cum);
		if (v < cum) {
			printf("\n");
			return i-1;
		}
	}
	printf("\n");


	// This should not happen
	printf("############THIS SHOULD NOT HAPPEN\n");
	return k-1;
}

// Score the model set `parts` (k = parts.size()) against all observed partitions.
//
// For observed partition i:
//   - null cost        = this->entropies[i] (m * entropy) when `mask` is empty,
//                        otherwise m * compute_entropy(observed, mask);
//   - cost of model j  = m * compute_conditional_entropy(observed, model_j, mask)
//                        + log2(k+2).
// The observation is assigned to the cheapest option (0 = null, j+1 = model j;
// ties keep the earlier option). score.score is penalty*k plus the sum of the
// chosen costs.
//
// A model is skipped if is_informative() is false for it under `mask`, or if
// is_identical_or_mirror() matches it to an earlier non-skipped model. A
// skipped model's row is set to the null cost, so it can never win a strict
// comparison here or in update(), which compares every row.
//
// If filterNull is set, observations whose null cost is exactly 0 start with
// assignment -1 instead of 0; a model can still take them only if strictly
// cheaper.
//
// `score` must have rows 0..k and n columns; callers in this file build it
// with Score(n, k+1).
void consensus_algorithm::compute(const std::vector<partition*>& parts, Score& score, const std::vector<bool>& mask, bool filterNull ) {
	// constants
	int k = (int)parts.size();
	double m = (double)this->m;
	double pen = log((double)k+2)/log(2.0);

	score.score = this->penalty*k;

	// Models excluded from assignment.
	std::vector<bool> skip_map((int)parts.size(), 0);

	// First check if the model bipartition is informative at all.
	for (int j = 0; j < (int) parts.size(); ++j) {
		if (!is_informative(parts[j], mask)) {
			skip_map[j] = 1;
		}
	}

	// Keep only the first of any identical/mirrored group of models.
	for (int j = 0; j < (int) parts.size(); ++j) {
		if (!skip_map[j]) {
			for (int i = j+1; i < (int) parts.size(); ++i) {
				if (!skip_map[i]) {
					if (is_identical_or_mirror(parts[j], parts[i], mask)) {
						skip_map[i] = 1;
					}
				}
			}
		}
	}

	// Now compute the best score for each observed partition.
	for (int i = 0; i < this->n; ++i) {

		partition* pObserved = this->partitions[i];
		double bestScore = (mask.size() == 0) ? this->entropies[i] : m*compute_entropy(pObserved, mask);
		int bestAssign = 0;

		if (filterNull && bestScore == 0.0) {
			bestAssign = -1;
		}

		// Null partition
		score.entropies[0][i] = bestScore;

		// Compute cost for each model
		for (int j = 0; j < (int)parts.size(); ++j) {
			if (skip_map[j]) {
				// Defined but never strictly better than the null cost.
				score.entropies[j+1][i] = score.entropies[0][i];
				continue;
			}
			partition* pModel = parts[j];
			double ch = m*compute_conditional_entropy(pObserved, pModel, mask) + pen;
			score.entropies[j+1][i] = ch;
			if (ch < bestScore) {
				bestScore = ch;
				bestAssign = j+1;
			}
		}

		score.assignments[i] = bestAssign;
		score.score += bestScore;
	}
}

// Re-score `score` after the model in row idx+1 has been set to pNewPart,
// without recomputing the other rows.
//
// Row idx+1 is recomputed as m * compute_conditional_entropy(observed_i,
// pNewPart) + log2(k+2) for every observed partition (k = parts.size();
// `parts` is only used for its size). Every observation is then reassigned to
// the cheapest of rows 0..k (ties keep the lower row), and score.score is
// rebuilt as penalty*k plus the chosen costs.
// Unlike compute(), no mask is applied, every row is compared, and
// assignments start from 0, so -1 (filterNull) assignments are not preserved.
void consensus_algorithm::update(Score& score, std::vector<partition*> parts, partition* pNewPart, int idx) {
	// Same constants as compute(). The explicit double arguments avoid an
	// ambiguous log(int) call with C++98 <math.h> overloads.
	int k = (int)parts.size();
	double pen = log((double)k+2)/log(2.0);
	double m = (double)this->m;

	score.score = this->penalty*k;
	for (int i = 0; i < this->n; ++i) {
		partition* pObserved = this->partitions[i];
		double ch = m*compute_conditional_entropy(pObserved, pNewPart) + pen;
		score.entropies[idx+1][i] = ch;

		double bestScore = score.entropies[0][i];
		int bestAssign = 0;
		for (int l = 1; l < k+1; ++l) {
			if (score.entropies[l][i] < bestScore) {
				bestScore = score.entropies[l][i];
				bestAssign = l;
			}
		}
		score.assignments[i] = bestAssign;
		score.score += bestScore;
	}
}

// Add any additional partition into model
void consensus_algorithm::constructFullPartitionSet() {
	// Once sorted, perform split operation from the complete set
	std::vector<std::vector<bool> > membership;
	membership.push_back(std::vector<bool>(this->m, 1));
	for (int i = 0; i < (int)models.size(); ++i) {
		partition* pPart = models[i];
		bool doubleSplit = 0; // Indicator that tells us whether this split will split more than one set
		int splitPop = -1; // Indicator that tells us whether we are going to split or not
		for (int j = 0; j < (int)membership.size(); ++j) {
			int states = -1;
			for (int k = 0; k < m; ++k) {
				if (membership[j][k]) {
					if (states >= 0) {
						if (states != pPart->get(k)) {
							if (splitPop >= 0) {
								doubleSplit = 1;
							} else {
								splitPop = j;
							}
							break;
						}
					} else {
						states = pPart->get(k) ? 1 : 0;
					}
				}
			}
		}

		// If there are no double split, then we can proceed to subdivide the individuals
		// into subclusters
		if (!doubleSplit && splitPop >= 0) {
			// Identify individuals that belongs to the new population
			int states = -1;
			std::vector<bool> membervec(m, 0);
			for (int j = 0; j < m; ++j) {
				if (membership[splitPop][j]) {
					if (states >= 0) {
						if (pPart->get(j) != (bool)states) {
							membership[splitPop][j] = 0;
							membervec[j] = 1;
						}
					} else {
						states = pPart->get(j) ? 1 : 0;
					}
				}
			}
			membership.push_back(membervec);
		}
	}

	std::vector<int> pop_assign(m, 0);
	for (int i = 1; i < (int)membership.size(); ++i) {
		for (int j = 0; j < m; ++j) {
			if (membership[i][j]) {
				pop_assign[j] = i;
			}
		}
	}

	// Given the population assignment, generate missing model bipartition that were missing
	std::vector<partition*> candidates;
	std::vector<partition*> sources;
	std::vector<partition*> targets;
	for (int i = 0; i < (int)membership.size(); ++i) {
		candidates.push_back(new partition(membership[i]));
		sources.push_back(candidates[i]);
	}

	fflush(stdout);
	// Merge partitions
	while (sources.size() > 1) {
		for (int i = 0; i < (int)sources.size(); ++i) {
			for (int j = i+1 ; j < (int)sources.size(); ++j) {
				partition* pPart = merge_partition(sources[i], sources[j]);

				bool exists = !pPart->isEmpty();
				if (exists) {
					for (int k = 0; k < (int)targets.size(); ++k) {
						if (*targets[k] == *pPart) {
							exists = 0;
							break;
						}
					}
				}
				if (exists) {
					for (int k = 0; k < (int)candidates.size(); ++k) {
						if (*candidates[k] == *pPart) {
							exists = 0;
							break;
						}
					}
				}


				if (exists) {
					targets.push_back(pPart);
					candidates.push_back(pPart);
				} else {
					delete pPart;
				}
			}
		}
		sources.clear();
		for (int i = 0; i < (int)targets.size(); ++i) {
			sources.push_back(targets[i]);
		}
		targets.clear();

		fflush(stdout);

	}

	for (int i = 0; i < (int)models.size(); ++i) {
		delete models[i];
	}
	this->models.clear();
	for (int i = 0 ; i < (int)candidates.size(); ++i) {
		models.push_back(candidates[i]);
	}
}
