/*
 * cfr_player.c
 * author: Kevin Waugh (waugh@cs.cmu.edu)
 */

#include "abstraction.h"
#include "game.h"
#include "sequence_form.h"
#include "strategy.h"
#include "util.h"
#include "verify.h"

#include "cfr_player.h"

/*
 * Allocate a CFR player for seat `who` of `sequence_form`.
 *
 * The player starts with a uniform strategy (new_uniform_strategy) and a
 * regret vector from new_strategy() -- presumably zero-initialised, confirm
 * in strategy.c.  `T` is only range-checked and `params` is not read; both
 * are assumed to exist for the shared player-constructor signature (verify
 * against the player interface).  Release the result with free_cfr().
 *
 * Returns the new player as an opaque pointer.
 */
void * create_cfr(sequence_form_t * restrict sequence_form, int who, int T, 
		  const char * restrict params) {
  cfr_player_t * restrict player;
  sequences_t * sequences;
  
  assert(sequence_form);
  assert(valid_player(who));
  assert(T >= 0);
  /* Neither argument is used beyond the assertion above; the casts keep
   * -Wunused-parameter quiet, including when NDEBUG removes the assert. */
  (void) T;
  (void) params;

  player                = xmalloc(sizeof *player);
  sequences             = sequence_form_sequences(sequence_form, who);
  player->sequence_form = sequence_form;
  player->who           = who;
  player->sigma         = new_uniform_strategy(sequences);
  player->regret        = new_strategy(sequences);

  return player;
}

/*
 * Release a player created by create_cfr(): first the two per-sequence
 * arrays it owns (regret, then sigma), then the player record itself.
 */
void free_cfr(cfr_player_t * restrict player) {
  assert(player);

  /* buffers owned by the player */
  xfree(player->regret);
  xfree(player->sigma);

  /* the record itself -- must come last, the calls above read through it */
  xfree(player);
}

/*
 * Borrow the player's current strategy vector.
 *
 * The pointer aliases the player's own storage.  It remains owned by the
 * player, is released by free_cfr(), and its contents are rewritten by
 * each call to update_strategy_full_cfr().
 */
const double * get_strategy_cfr(cfr_player_t * restrict player) {
  const double * current;

  assert(player);
  current = player->sigma;
  return current;
}

void update_strategy_full_cfr(cfr_player_t * restrict player, double * restrict payoffs) {
  double sum, ev;
  int i, info_set, j;
  sequences_t * restrict sequences;

  assert(player);
  assert(payoffs);

  sequences = sequence_form_sequences(player->sequence_form, player->who);

  for(i=abstract_info_sets(sequences_abstraction(sequences))-1; i>=0; --i) {
    info_set = info_set_order(sequences, i);
    sum = 0.;
    for(j=first_sequence(sequences, info_set);
	j<=last_sequence(sequences, info_set); ++j) {
      if (player->regret[j] > 0) {
	sum += player->regret[j];
      }
    }
    ev = 0.;
    if (sum > 0.) {
      sum = 1/sum;
      for(j=first_sequence(sequences, info_set);
	  j<=last_sequence(sequences, info_set); ++j) {
	if (player->regret[j] > 0) {
	  ev += payoffs[j]*player->regret[j]*sum;
	}
	player->regret[j] += payoffs[j];
      }
    } else {
      sum = 1./num_sequences(sequences, info_set);
      for(j=first_sequence(sequences, info_set);
	  j<=last_sequence(sequences, info_set); ++j) {
	ev += payoffs[j]*sum;
	player->regret[j] += payoffs[j];
      }
    }
    payoffs[parent_sequence(sequences, info_set)] += ev;
    sum = 0.;
    for(j=first_sequence(sequences, info_set);
	j<=last_sequence(sequences, info_set); ++j) {
      player->regret[j] -= ev;
      if (player->regret[j] > 0) {
	sum += player->regret[j];
      }
    }
    if (sum > 0.) {
      sum = 1/sum;
      for(j=first_sequence(sequences, info_set);
	j<=last_sequence(sequences, info_set); ++j) {
	if (player->regret[j] > 0) {
	  player->sigma[j] = player->regret[j]*sum;
	} else {
	  player->sigma[j] = 0;
	}
      }
    } else {
      sum = 1./num_sequences(sequences, info_set);
      for(j=first_sequence(sequences, info_set);
	j<=last_sequence(sequences, info_set); ++j) {
	player->sigma[j] = sum;
      } 
    }
  }
  
  normalize_strategy(sequences, player->sigma, 1.);
}
