/*
 * strategy.c
 * author: Kevin Waugh (waugh@cs.cmu.edu)
 */

#include <math.h>
#include <stdio.h>
#include <string.h>

#include "abstraction.h"
#include "game.h"
#include "sequence_form.h"
#include "util.h"
#include "verify.h"

#include "strategy.h"

/* Absolute tolerance for floating-point comparisons of realization weights
   (e.g. in is_strategy: non-negativity and "children sum to parent"). */
#define EPSILON 1e-10

/*
 * Read a strategy for the player described by `sequences` from `stream`
 * into `sigma` (one entry per sequence).
 *
 * Expected layout (remove_comment() is called before every field,
 * presumably skipping whitespace and '#' comment lines — TODO confirm):
 *   'game name' who 'abstraction name' n
 *   <seq>: <value>        (exactly n entries, in any order)
 *
 * The header must match the game name, player, abstraction name and
 * sequence count of `sequences`.  Every sequence index must be valid and
 * appear exactly once, and no value may be NaN; any violation fails a
 * verify().
 */
void read_strategy(sequences_t * restrict sequences, FILE * restrict stream, 
		   double * restrict sigma) {
  /* Holds a quoted name from the header.  The %255[ field width used below
     must stay sizeof(tmp)-1 so fscanf can never write past the buffer; an
     over-long name is truncated and then fails the strcmp verification. */
  char tmp[256];
  int i, j;
  double v;

  assert(sequences);
  assert(stream);
  assert(sigma);

  verify(remove_comment(stream));
  verify(fscanf(stream, "'%255[^']'", tmp) == 1);
  verify(!strcmp(tmp, game_name(sequences_game(sequences))));

  verify(remove_comment(stream));
  verify(fscanf(stream, "%d", &i) == 1);
  verify(i == abstraction_who(sequences_abstraction(sequences)));

  verify(remove_comment(stream));
  verify(fscanf(stream, "'%255[^']'", tmp) == 1);
  verify(!strcmp(tmp, abstraction_name(sequences_abstraction(sequences))));
  
  verify(remove_comment(stream));
  verify(fscanf(stream, "%d", &i) == 1);
  verify(i == sequences_n(sequences));
  
  /* NaN marks "not read yet", which lets duplicate indices be detected. */
  for(i=0; i<sequences_n(sequences); ++i) {
    sigma[i] = NAN;
  }
  
  for(i=0; i<sequences_n(sequences); ++i) {
    verify(remove_comment(stream));
    verify(fscanf(stream, "%d:", &j) == 1);
    verify(remove_comment(stream));
    verify(fscanf(stream, "%lf", &v) == 1);

    /* Validate j before it is used as an index into sigma. */
    verify(valid_sequence(sequences, j));
    verify(isnan(sigma[j]));   /* each sequence may appear only once */
    verify(!isnan(v));         /* a NaN value would defeat the duplicate check */
    sigma[j] = v;
  }
}

/*
 * Write `sigma` to `stream` in the layout read by read_strategy: a header
 * line "'game' who 'abstraction' n" followed by one "i: value" line per
 * sequence, in index order.
 *
 * Values use %.17g, which has enough significant digits for an IEEE double
 * to survive the text round trip unchanged.  Plain %g keeps only 6 digits,
 * which perturbs weights by far more than EPSILON.
 */
void write_strategy(sequences_t * restrict sequences, const double * restrict sigma,
		    FILE * restrict stream) {
  int i;
  
  assert(sequences);
  assert(sigma);
  assert(stream);
  
  fprintf(stream, "'%s' %d '%s' %d\n", 
	  game_name(sequences_game(sequences)),
	  abstraction_who(sequences_abstraction(sequences)),
	  abstraction_name(sequences_abstraction(sequences)),
	  sequences_n(sequences));
  for(i=0; i<sequences_n(sequences); ++i) {
    fprintf(stream, "%d: %.17g\n", i, sigma[i]);
  }
}

/*
 * Like write_strategy, but the sequence lines are grouped by abstract info
 * set (visited in info_set_order).  Each group is preceded by a
 * "# abstract info set: <id> members: ..." comment line listing the info
 * set's members.  read_strategy accepts sequence lines in any order and
 * calls remove_comment() before each field, presumably skipping these '#'
 * lines — TODO confirm.
 *
 * Values use %.17g so doubles round-trip exactly through the text format.
 */
void write_strategy_verbose(sequences_t * restrict sequences, const double * restrict sigma,
			    FILE * restrict stream) {
  int i, info_set, j, k;
  abstraction_t * restrict abstraction;

  assert(sequences);
  assert(sigma);
  assert(stream);
  
  abstraction = sequences_abstraction(sequences);

  fprintf(stream, "'%s' %d '%s' %d\n", 
	  game_name(sequences_game(sequences)),
	  abstraction_who(abstraction),
	  abstraction_name(abstraction),
	  sequences_n(sequences));
  /* The empty (root) sequence is printed on its own, before the
     per-info-set groups. */
  fprintf(stream, "%d: %.17g\n", (int)EMPTY_SEQUENCE, sigma[EMPTY_SEQUENCE]);
  for(i=0; i<abstract_info_sets(abstraction); ++i) {
    info_set = info_set_order(sequences, i);
    fprintf(stream, "# abstract info set: %d members:", info_set);
    for(j=0; j<num_members(abstraction, info_set); ++j) {
      k = ith_member(abstraction, info_set, j);
      fprintf(stream, " %d", k);
    }
    fprintf(stream, "\n");
    for(j=first_sequence(sequences, info_set);
	j<=last_sequence(sequences, info_set); ++j) {
      fprintf(stream, "%d: %.17g\n", j, sigma[j]);
    }
  }
}

/*
 * Return 1 if `sigma` is a valid realization plan for `sequences`, within
 * an absolute tolerance of EPSILON:
 *   - the empty sequence has weight 1,
 *   - every sequence weight is non-negative,
 *   - each info set's sequence weights sum to its parent sequence's weight.
 * Return 0 otherwise, including when any checked entry is NaN.
 *
 * Each test is written as "!(within tolerance)".  Every relational
 * comparison involving NaN is false, so this form rejects NaN, whereas an
 * "out of tolerance" test would silently accept it.
 */
int is_strategy(sequences_t * restrict sequences, const double * restrict sigma) {
  int i, j;
  double sum;

  assert(sequences);
  assert(sigma);

  if (!(fabs(1.-sigma[EMPTY_SEQUENCE]) <= EPSILON)) {
    return 0;
  }
  for(i=0; i<abstract_info_sets(sequences_abstraction(sequences)); ++i) {
    sum = 0.;
    for(j=first_sequence(sequences, i); j<=last_sequence(sequences, i); ++j) {
      if (!(sigma[j] >= -EPSILON)) {
	return 0;
      }
      sum += sigma[j];
    }
    if (!(fabs(sigma[parent_sequence(sequences, i)]-sum) <= EPSILON)) {
      return 0;
    }
  }
  
  return 1;
}

/*
 * Rebuild realization weights top-down from a root weight of `alpha`.
 *
 * First the empty sequence is set to `alpha`.  Then, for each info set in
 * the order given by info_set_order, normalize() is called on that info
 * set's contiguous run of sequences, with the parent sequence's current
 * weight as the target.  This presumably rescales each run to sum to its
 * parent's weight.  It relies on info_set_order visiting parents before
 * children — TODO confirm both.
 */
void normalize_strategy(sequences_t * restrict sequences, double * restrict sigma, 
			double alpha) {
  int pos, set, count, width;
  double *run;
  double target;

  assert(sequences);
  assert(sigma);

  sigma[EMPTY_SEQUENCE] = alpha;

  count = abstract_info_sets(sequences_abstraction(sequences));
  pos = 0;
  while (pos < count) {
    set    = info_set_order(sequences, pos);
    run    = sigma + first_sequence(sequences, set);
    width  = num_sequences(sequences, set);
    target = sigma[parent_sequence(sequences, set)];
    normalize(run, width, target);
    ++pos;
  }
}

/*
 * Lift `sigma`, a strategy over the sequences `from`, to `sigma2` over the
 * sequences `to`.  `from`'s abstraction must be coarser than `to`'s, which
 * is asserted via is_coarser.
 *
 * Every abstract info set of `to` copies, position by position, the
 * weights of the `from` info set that its first member maps to.  The
 * result is then renormalized from the root with alpha = 1.
 */
void lift_strategy(sequences_t * restrict from, sequences_t * restrict to,
		   const double * restrict sigma, double * restrict sigma2) {
  int i, j, k, coarse;
  abstraction_t *coarse_abs, *fine_abs;

  assert(from);
  assert(to);
  assert(sigma);
  assert(sigma2);
  assert(is_coarser(sequences_abstraction(from), sequences_abstraction(to)));

  coarse_abs = sequences_abstraction(from);
  fine_abs   = sequences_abstraction(to);
  
  for(i=0; i<abstract_info_sets(fine_abs); ++i) {
    /* The first member identifies the coarse info set.  This presumably
       relies on is_coarser guaranteeing all members map to the same one. */
    coarse = abstraction_map(coarse_abs, ith_member(fine_abs, i, 0));
    /* Sequences are copied positionally.  A size mismatch would read into
       a neighbouring info set's sequences. */
    assert(num_sequences(from, coarse) == num_sequences(to, i));
    k = first_sequence(from, coarse);
    for(j=first_sequence(to, i); j<=last_sequence(to, i); ++j, ++k) {
      sigma2[j] = sigma[k];
    }
  }

  normalize_strategy(to, sigma2, 1.);
}

/*
 * Accumulate per-sequence `payoffs` over the sequences `from` into
 * `payoffs2` over the sequences `to`.  `to`'s abstraction must be coarser
 * than `from`'s, which is asserted via is_coarser.
 *
 * `payoffs2` is zeroed and its root entry is copied from `payoffs`.  Then
 * each `from` info set's sequence payoffs are added, position by position,
 * to the `to` info set that its first member maps to.
 */
void flatten_payoffs(sequences_t * restrict from, sequences_t * restrict to,
		     const double * restrict payoffs, double * restrict payoffs2) {
  int i, j, k, coarse;
  abstraction_t *fine_abs, *coarse_abs;

  assert(from);
  assert(to);
  assert(payoffs);
  assert(payoffs2);
  assert(is_coarser(sequences_abstraction(to), sequences_abstraction(from)));

  fine_abs   = sequences_abstraction(from);
  coarse_abs = sequences_abstraction(to);

  zero_strategy(to, payoffs2);
  
  payoffs2[EMPTY_SEQUENCE] = payoffs[EMPTY_SEQUENCE];

  for(i=0; i<abstract_info_sets(fine_abs); ++i) {
    coarse = abstraction_map(coarse_abs, ith_member(fine_abs, i, 0));
    /* Payoffs are accumulated positionally.  A size mismatch would write
       into a neighbouring info set's sequences. */
    assert(num_sequences(to, coarse) == num_sequences(from, i));
    k = first_sequence(to, coarse);
    for(j=first_sequence(from, i); j<=last_sequence(from, i); ++j, ++k) {
      payoffs2[k] += payoffs[j];
    }
  }
}

/*
 * Replace `sigma` by a pure strategy.  In every info set, the sequence
 * with the largest weight (the earliest one on ties) is set to 1 and the
 * others to 0.  The weights are then rebuilt from the root via
 * normalize_strategy with alpha = 1.
 */
void purify_strategy(sequences_t * restrict sequences, double * restrict sigma) {
  int order, info_set, first, last, best, s;

  assert(sequences);
  assert(sigma);

  for(order=0; order<abstract_info_sets(sequences_abstraction(sequences)); ++order) {
    info_set = info_set_order(sequences, order);
    first    = first_sequence(sequences, info_set);
    last     = last_sequence(sequences, info_set);

    /* Strict '>' keeps the earliest maximum on ties. */
    best = first;
    for(s=first+1; s<=last; ++s) {
      if (sigma[s] > sigma[best]) {
	best = s;
      }
    }

    for(s=first; s<=last; ++s) {
      if (s == best) {
	sigma[s] = 1.;
      } else {
	sigma[s] = 0.;
      }
    }
  }

  normalize_strategy(sequences, sigma, 1.);
}

/*
 * Zero out low-probability actions, then renormalize from the root with
 * alpha = 1.
 *
 * Info sets are visited in info_set_order; info sets whose parent weight
 * is not above EPSILON are left untouched.  Otherwise a sequence j is
 * zeroed when its conditional probability sigma[j]/sigma[parent] is below
 * `threshold`.  If even the largest one is below `threshold`, only that
 * largest sequence is kept, so the info set is not zeroed entirely.
 */
void threshold_strategy(sequences_t * restrict sequences, double * restrict sigma, double threshold) {
  int i, j, m, info_set, first, last;
  double parent;

  assert(sequences);
  assert(sigma);

  for(i=0; i<abstract_info_sets(sequences_abstraction(sequences)); ++i) {
    info_set = info_set_order(sequences, i);
    /* The parent sequence is not one of this info set's own sequences, so
       zeroing entries below cannot change its weight; read it once. */
    parent = sigma[parent_sequence(sequences, info_set)];
    /* Written as !(x > EPSILON) so that a NaN parent weight is skipped too. */
    if (!(parent > EPSILON)) {
      continue;
    }
    first = first_sequence(sequences, info_set);
    last  = last_sequence(sequences, info_set);

    m = first;
    for(j=first+1; j<=last; ++j) {
      if (sigma[j] > sigma[m]) {
	m = j;
      }
    }

    if (sigma[m]/parent < threshold) {
      /* Even the most likely sequence is below threshold: keep it alone. */
      for(j=first; j<=last; ++j) {
	if (j != m) {
	  sigma[j] = 0.;
	}
      }
    } else {
      for(j=first; j<=last; ++j) {
	if (sigma[j]/parent < threshold) {
	  sigma[j] = 0.;
	}
      }
    }
  }
  normalize_strategy(sequences, sigma, 1.);
}

/*
 * Fold `new` into the running average `average` as sample number `t`
 * (counting from 1), elementwise over all sequences:
 *   average = ((t-1)/t) * average + (1 - (t-1)/t) * new
 *
 * For t == 1 the result is `new` itself.  It is copied directly, so any
 * NaN/inf left in `average` (e.g. from an uninitialized buffer) is not
 * propagated through 0 * average[i].  Requires t >= 1; t == 0 would
 * divide by zero.
 */
void average_strategy(sequences_t * restrict sequences, const double * restrict new, double * restrict average, int t) {
  int i;
  double tau;

  assert(sequences);
  assert(new);
  assert(average);
  assert(t >= 1);

  if (t == 1) {
    for(i=0; i<sequences_n(sequences); ++i) {
      average[i] = new[i];
    }
    return;
  }

  tau = (t-1.)/t;

  for(i=0; i<sequences_n(sequences); ++i) {
    average[i] = tau*average[i] + (1-tau)*new[i];
  }
}

/*
 * Compute `who`'s per-sequence payoffs against the opponent's realization
 * weights `sigma`, for a single sampled chance outcome.
 *
 * The game tree is walked breadth-first from the root:
 *   - at a chance history, a single outcome is chosen with sample(), using
 *     chance[chance_depth(...)] as the draw for that depth (presumably
 *     uniform random numbers supplied by the caller — TODO confirm);
 *   - at any other non-terminal history, every action is expanded;
 *   - at a terminal history, sigma[opponent's sequence] * value_for(who)
 *     is added to payoffs[who's sequence].
 *
 * payoffs: output buffer.  If NULL, one is allocated with new_strategy.
 *          It is zeroed before accumulation.
 * q:       scratch queue of history ids.  No bound is checked, so it must
 *          have room for every history enqueued in one traversal.
 * Returns `payoffs` (the caller's buffer or the newly allocated one).
 */
double * sampled_payoffs(sequence_form_t * sequence_form, int who, const double * restrict sigma,
			 double * restrict payoffs, double * restrict chance, int * restrict q) {
  int head, tail, k, id;
  sequences_t * restrict me;
  sequences_t * restrict opp;
  game_t * restrict game;

  assert(sequence_form);
  assert(sigma);
  assert(chance);
  assert(q);

  me   = sequence_form_sequences(sequence_form, who);
  opp  = sequence_form_sequences(sequence_form, opponent(who));
  game = sequences_game(me);

  if (!payoffs) {
    payoffs = new_strategy(me);
  }
  assert(payoffs);
  zero_strategy(me, payoffs);
  
  /* FIFO queue in q: entries [head, tail) are still to be expanded. */
  head = 0;
  tail = 0;
  q[tail++] = game_root(game);
  while (head != tail) {
    id = q[head++];
    if (terminal_history(game, id)) {
      payoffs[sequence_from_history(me, id)] +=
	sigma[sequence_from_history(opp, id)]*value_for(game, id, who);
    } else if (chance_history(game, id)) {
      /* Follow only the single sampled chance outcome. */
      k = sample(sigma_chance_history(game, id), num_actions(game, id),
		 chance[chance_depth(game, id)]);
      q[tail++] = do_action(game, ith_action(game, id, k));
    } else {
      for(k=first_action(game, id); k<=last_action(game, id); ++k) {
	q[tail++] = do_action(game, k);
      }
    }
  }

  return payoffs;
}
