/* 

Hill-climbing code (WalkSAT) whose start points are, optionally, sampled
from a dependency tree built from the best solutions found so far.

*/

#include "walksat.c"
#include <float.h>

#define MAX(a,b) ((a) > (b) ? (a) : (b)) 
 

#include <stdlib.h> 
#include <stdio.h> 
#include <assert.h> 
#include <math.h> 
#include <sys/time.h> 
#include <string.h> 


/* Debug flag; not referenced by any code visible in this part of the file. */
int print = 0;

/* Old compile-time sizing, disabled in favour of the run-time variables below. */
#if 0
#define VARIABLES 700

#define LENGTH (VARIABLES)
#define NODES (VARIABLES)
#endif

/* Problem size.  main() sets VARIABLES from numatom; PBIL_init() then copies
   it into LENGTH and NODES. */
int VARIABLES;
int LENGTH;
int NODES;

/* Number of iterations of the outer loop in main(). */
#define RUNS 25

/* #define NUMBER_OF_HC_BEFORE_TREE 100000 */
/* Size of the global pool (history) that the pairwise counts are built from. */
#define NUMBER_TO_UPDATE_FROM 1000
/* Size of the short-list (local_history) maintained by keep_top(). */
#define NUMBER_TO_TAKE_FROM_EACH_RUN 100

/* Set by the "-tree" option: non-zero makes main() sample start points from
   the tree, zero makes it use uniformly random start points. */
int USE_TREE = 1;
/* Set by the "-generate" option: how many tree samples
   PBIL_sample_from_tree_many() draws before keeping the best one. */
int NUMBER_OF_START_POINTS_TO_GENERATE=100;

/* -------------------------------------------------------- */
/* Pool of bit vectors (LENGTH ints each) and their scores.
   add_top_to_data() overwrites the lowest-scored entry with better ones. */
int *history [NUMBER_TO_UPDATE_FROM];
double history_val [NUMBER_TO_UPDATE_FROM];
/* Best distinct vectors recorded through keep_top(), with their scores.
   NOTE(review): no enabled code in this part of the file calls keep_top();
   presumably the walksat code does -- confirm. */
int *local_history[NUMBER_TO_TAKE_FROM_EACH_RUN];
double local_history_val[NUMBER_TO_TAKE_FROM_EACH_RUN];

/* Modulus and divisor used by randBit(); main() sets it to 2^30. */
long int divide_by; 
/*int *children[LENGTH];*/
/*int *count_ar[LENGTH];*/
/*int num_child[LENGTH];*/
/* Tree structure: children[i][0 .. num_child[i]-1] are the child nodes of i. */
short **children;
/* count_ar[i][j] = number of history vectors in which bits i and j are both
   set (so count_ar[i][i] counts vectors with bit i set). */
short **count_ar;
int *num_child;

/* Total number of examples assumed by counts() when deriving the (0,0) cell. */
int total_examples = NUMBER_TO_UPDATE_FROM;
/* Root node of the current tree, chosen in PBIL_make_tree(). */
int head; 
/* Best score so far; PBIL_init() resets it to -DBL_MAX and main() prints it.
   NOTE(review): only the disabled hc() updates it in the code visible here. */
double BEST_EVER = -1.0; 

/*int bestMatchInTree[LENGTH];*/
/*float bestMatchMutInf[LENGTH];*/
/* Scratch for PBIL_make_tree(): for each node not yet in the tree, the tree
   node with the highest mutual information to it, and that score. */
int *bestMatchInTree;
float *bestMatchMutInf;

/* -------------------------------------------------------- */

/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/


/*
 * Draw one random bit: returns 1 when a uniform draw in [0, 1) falls
 * below p, and 0 otherwise.  The draw is random() reduced modulo the
 * global divide_by and then scaled by divide_by.
 */
int randBit(double p)
{
  const long int bucket = random () % divide_by;
  const double draw = (double) bucket / (double) divide_by;

  return (draw < p) ? 1 : 0;
}




/*
 * Smoothed number of examples with (bit a == c) and (bit b == d), derived
 * from the co-occurrence table count_ar.  c and d are treated as booleans
 * (zero / non-zero).  0.0001 is added so the result is never exactly zero.
 */
float counts (int a,int b,int c,int d) {
  const int both_set = count_ar[a][b];
  int tally;

  if (c != 0) {
    /* a is set: either together with b, or a alone */
    tally = (d != 0) ? both_set : count_ar[a][a] - both_set;
  }
  else if (d != 0) {
    /* b set, a clear */
    tally = count_ar[b][b] - both_set;
  }
  else {
    /* neither set: everything that is left of total_examples */
    tally = total_examples - (count_ar[b][b] + count_ar[a][a] - both_set);
  }

  return ((float) tally) + 0.0001;
}


/* Allocate count * size bytes for PBIL_init, aborting the program on failure. */
static void *pbil_alloc (size_t count, size_t size) {
  void *block = malloc (count * size);

  if (block == NULL && count != 0) {
    fprintf (stderr, "PBIL_init: out of memory\n");
    exit (EXIT_FAILURE);
  }
  return block;
}

/*
 * Reset the tree-learning state for a new run.
 *
 * Copies VARIABLES into LENGTH and NODES, then
 *   - zeroes the pairwise count table,
 *   - refills history[] with uniformly random bit vectors scored -DBL_MAX,
 *   - marks every local_history slot as empty (score -DBL_MAX),
 *   - resets BEST_EVER to -DBL_MAX.
 *
 * All buffers are allocated once, on the first call, and reused afterwards.
 * (Previously the top-level tables were re-allocated on every call while
 * their rows were only allocated on the first one, so a second call leaked
 * the old tables and then wrote through uninitialised row pointers.)
 * Because the buffers are sized on the first call, a later call with a
 * larger VARIABLES is rejected instead of overflowing them.
 */
void PBIL_init () {
  int i, j;
  static int first_time = 1;
  static int allocated_length = 0;

  LENGTH = VARIABLES;
  NODES = VARIABLES;

  if (first_time) {
    const size_t n = (LENGTH > 0) ? (size_t) LENGTH : 0;

    children = pbil_alloc (n, sizeof *children);
    count_ar = pbil_alloc (n, sizeof *count_ar);
    num_child = pbil_alloc (n, sizeof *num_child);
    bestMatchInTree = pbil_alloc (n, sizeof *bestMatchInTree);
    bestMatchMutInf = pbil_alloc (n, sizeof *bestMatchMutInf);

    for (i = 0; i < LENGTH; i ++) {
      count_ar[i] = pbil_alloc (n, sizeof *count_ar[i]);
      children[i] = pbil_alloc (n, sizeof *children[i]);
    }
    for (i = 0; i < NUMBER_TO_UPDATE_FROM; i ++) {
      history[i] = pbil_alloc (n, sizeof *history[i]);
    }
    for (i = 0; i < NUMBER_TO_TAKE_FROM_EACH_RUN; i ++) {
      /* contents are left unset; the slot is only "empty" by its score */
      local_history[i] = pbil_alloc (n, sizeof *local_history[i]);
    }

    allocated_length = LENGTH;
    first_time = 0;
  }
  else if (LENGTH > allocated_length) {
    fprintf (stderr, "PBIL_init: problem size grew from %d to %d\n",
             allocated_length, LENGTH);
    exit (EXIT_FAILURE);
  }

  for (i = 0; i < LENGTH; i ++) {
    for (j = 0; j < LENGTH; j ++) {
      count_ar[i][j] = 0;
    }
  }

  for (i = 0; i < NUMBER_TO_UPDATE_FROM; i ++) {
    history_val[i] = -DBL_MAX;
    for (j = 0; j < LENGTH; j ++) {
      history[i][j] = randBit (0.5);
    }
  }

  for (i = 0; i < NUMBER_TO_TAKE_FROM_EACH_RUN; i ++) {
    local_history_val[i] = -DBL_MAX;
  }

  BEST_EVER = -DBL_MAX;
}



/*
 * P(x_a = 1 | x_b = bval), estimated from the smoothed pair counts:
 * count(a=1, b=bval) / (count(a=0, b=bval) + count(a=1, b=bval)).
 */
float condProb(int a, int b, int bval)
{
  const float a_on = counts (a, b, 1, bval);
  const float a_off = counts (a, b, 0, bval);

  return a_on / (a_off + a_on);
}

/*
 * P(x_a = 1): the marginal is obtained by summing the smoothed pair counts
 * of a against variable 0 over both values of variable 0.
 */
float uncondProb(int a)
{
  const float a_on = counts (a, 0, 1, 0) + counts (a, 0, 1, 1);
  const float a_off = counts (a, 0, 0, 0) + counts (a, 0, 0, 1);

  return a_on / (a_off + a_on);
}

/* Entropy (natural log) of a binary variable that is 1 with probability p;
   defined as 0 at p == 0 and p == 1. */
static float pbil_binary_entropy (float p) {
  if (p == 0.0 || p == 1.0) {
    return 0.0;
  }
  return -p * log(p) - (1.0-p)*log(1.0-p);
}

/*
 * Mutual information between variables `in` and `out`, computed as
 *   H(out) - P(in=1) * H(out | in=1) - P(in=0) * H(out | in=0)
 * from the probability estimates of uncondProb() and condProb().
 */
float mutual (int in, int out) {
  const float p_out = uncondProb(out);
  const float p_out_given_in0 = condProb(out, in, 0);
  const float p_out_given_in1 = condProb(out, in, 1);
  const float p_in = uncondProb(in);

  const float h_out = pbil_binary_entropy (p_out);
  const float h_out_in0 = pbil_binary_entropy (p_out_given_in0);
  const float h_out_in1 = pbil_binary_entropy (p_out_given_in1);

  const float info = h_out - p_in * h_out_in1 - (1.0-p_in) * h_out_in0;

  return info;
}

/*
 * Rebuild the pairwise co-occurrence table from the current history pool:
 * afterwards count_ar[i][j] is the number of history vectors in which
 * bits i and j are both set.
 */
void PBIL_use_new_history () {
  int row, col, sample;

  /* start from an all-zero table */
  for (row = 0; row < LENGTH; row++) {
    memset (count_ar[row], 0, sizeof (short) * LENGTH);
  }

  for (sample = 0; sample < NUMBER_TO_UPDATE_FROM; sample++) {
    const int *bits = history[sample];

    for (row = 0; row < LENGTH; row++) {
      if (!bits[row]) {
        continue;
      }
      for (col = 0; col < LENGTH; col++) {
        if (bits[col]) {
          count_ar[row][col] += 1;
        }
      }
    }
  }
}




/*
 * Choose the root of the tree: the variable whose marginal probability is
 * furthest from 0.5.  The scan starts at index 1 with a reference
 * probability of exactly 0.5, so index 0 is returned only when no other
 * variable deviates from 0.5.  Ties keep the earliest index.
 */
int select_head () {
  int winner = 0;
  float winner_p = .5;
  int idx;

  for (idx = 1; idx < LENGTH; ++idx) {
    const float cand_p = uncondProb(idx);

    if (!((fabs (cand_p - 0.5)) > (fabs (winner_p - 0.5)))) {
      continue;
    }
    winner = idx;
    winner_p = cand_p;
  }

  return winner;
}

/*
 * Build the dependency tree over all LENGTH variables.
 *
 * The root (head) is chosen by select_head().  Nodes are then attached one
 * at a time: among the nodes not yet in the tree, the one with the highest
 * mutual information to some node already in the tree is attached as a
 * child of that node.  bestMatchInTree[] / bestMatchMutInf[] cache, for
 * every outside node, its best partner inside the tree, and are refreshed
 * against each newly added node.
 *
 * Results are written to the globals head, children[] and num_child[].
 *
 * Returns the number of nodes placed in the tree (LENGTH), or 0 when
 * LENGTH <= 0, in which case nothing is modified.
 */
int PBIL_make_tree () {
  int j, n_in_tree, last_added;

  if (LENGTH <= 0) {
    return 0;
  }

  int in_tree [LENGTH];

  last_added = head = select_head ();

  for (j = 0; j < LENGTH; j ++) {
    num_child[j] = 0;
    in_tree[j] = 0;
    bestMatchInTree[j] = head;
    bestMatchMutInf[j] = mutual(head, j);
  }
  n_in_tree = 1;
  in_tree[head] = 1;

  while (n_in_tree != LENGTH) {
    int best_in = head;
    int best_out = -1;
    float best_mut = 0.0;

    /* Find the next node to add.  The first outside node is always accepted
       as a candidate, so a node is selected even if every score is very low
       or not comparable (NaN); ties keep the lowest index. */
    for (j = 0; j < LENGTH; j ++) {
      if (in_tree[j]) {
        continue;
      }
      if (best_out < 0 || best_mut < bestMatchMutInf[j]) {
        best_mut = bestMatchMutInf[j];
        best_in = bestMatchInTree[j];
        best_out = j;
      }
    }

    /* n_in_tree < LENGTH here, so at least one outside node was found */
    children [best_in][num_child[best_in]] = best_out;
    num_child[best_in] ++;
    in_tree[best_out] = 1;
    n_in_tree ++;
    last_added = best_out;

    /* update the best matching mutual information for
       all nodes which are not in the tree */
    for (j = 0; j < LENGTH; j ++) {
      if (! in_tree[j]) {
        float mut = mutual(last_added, j);
        if (mut > bestMatchMutInf[j]) {
          bestMatchMutInf[j] = mut;
          bestMatchInTree[j] = last_added;
        }
      }
    }
  }

  return n_in_tree;
}






void PBIL_generate_new_helper( int *a, int parent, int child) 
{ 
  double childP = condProb((int) child, (int) parent, a[parent]); 
  int i; 

  a[child] = randBit(childP); 
  for (i = 0; i < num_child[child]; i++) 
    { 
      PBIL_generate_new_helper(a, (int)child, (int)children[child][i]); 
    } 
} 



void PBIL_sample_from_tree ( int *a) { 
  double headP = uncondProb(head); 
  int i; 

  a[head] = randBit(headP); 

  for (i = 0; i < num_child[head]; i++) 
    { 
      PBIL_generate_new_helper(a, (int)head, (int)children[head][i]); 
    } 
} 




/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ---------- HILLCLIMBING CODE -------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/
/* ------------------------------------------------------------------------*/




/*
 * Offer a scored bit vector to the local_history short-list.
 *
 * The vector replaces the currently lowest-scored entry unless its score
 * is below that entry's score, or an identical vector is already stored.
 */
void keep_top (int *newOne, double val) {
  const size_t vec_bytes = sizeof (int) * LENGTH;
  int slot;
  int worst = 0;

  /* locate the lowest-scored slot (first one wins on ties) */
  for (slot = 1; slot < NUMBER_TO_TAKE_FROM_EACH_RUN; slot++) {
    if (local_history_val[slot] < local_history_val[worst]) {
      worst = slot;
    }
  }

  if (val < local_history_val[worst]) {
    return;
  }

  /* only distinct vectors are kept */
  for (slot = 0; slot < NUMBER_TO_TAKE_FROM_EACH_RUN; slot++) {
    if (memcmp (local_history[slot], newOne, vec_bytes) == 0) {
      return;
    }
  }

  memcpy (local_history[worst], newOne, vec_bytes);
  local_history_val[worst] = val;
}

  
/*
 * Returns 1 when the history pool already holds a vector identical to x
 * whose score is within 0.000001 of val, otherwise 0.  The score test is
 * done first so the vector comparison only runs for matching scores.
 */
int exact_copy_in_history (int *x, double val) {
  const size_t vec_bytes = sizeof (int) * LENGTH;
  int slot;

  for (slot = 0; slot < NUMBER_TO_UPDATE_FROM; slot++) {
    if (!(fabs (val - history_val[slot]) < 0.000001)) {
      continue;
    }
    if (memcmp (history[slot], x, vec_bytes) == 0) {
      return 1;
    }
  }
  return 0;
}




/*
 * Merge the local_history short-list into the global history pool.
 *
 * Each short-list entry that is not already in the pool replaces the
 * pool's lowest-scored entry, provided its score is strictly higher.
 * The number of entries merged is printed to stdout.
 */
void add_top_to_data () {
  const size_t vec_bytes = sizeof (int) * LENGTH;
  int merged = 0;
  int src, slot;

  for (src = 0; src < NUMBER_TO_TAKE_FROM_EACH_RUN; src++) {
    int worst = 0;

    /* skip anything the pool already contains */
    if (exact_copy_in_history (local_history[src], local_history_val[src])) {
      continue;
    }

    /* lowest-scored pool entry (first one wins on ties) */
    for (slot = 1; slot < NUMBER_TO_UPDATE_FROM; slot++) {
      if (history_val[slot] < history_val[worst]) {
        worst = slot;
      }
    }

    if (local_history_val[src] > history_val[worst]) {
      memcpy (history[worst], local_history[src], vec_bytes);
      history_val[worst] = local_history_val[src];
      merged++;
    }
  }

  printf ("%d added from the hc run \n", merged);
}






/*
 * Disabled (#if 0) bit-flip hill climber, kept for reference.
 *
 * Starting from `vector`, it repeatedly flips one randomly chosen bit,
 * scores the result with walksatEval() and offers it to keep_top().  A flip
 * is kept when the score does not drop and undone otherwise.  The loop ends
 * once HC_MOVES_BEFORE_RESTART consecutive flips have been undone without a
 * strict improvement in between.  Returns the number of evaluations made.
 *
 * HC_MOVES_BEFORE_RESTART is not defined in the code visible here, so this
 * block would not compile as-is if re-enabled.
 */
#if 0
int hc (int *vector) {
  double i ;
  int k,j;
  int best[LENGTH];
  double best_value = -1;
  int changed = 0, changed_pos, last_change = 0;
  int touched[LENGTH];
  int evaluated =0;



  /* touched[] is the identity mapping; best[] starts as a copy of the input */
  for (j = 0; j  < LENGTH; j ++) { 
    touched[j] = j;
    best[j] =vector[j];
  } 

  /* mark every short-list slot as empty for this climb */
  for (j =0; j < NUMBER_TO_TAKE_FROM_EACH_RUN; j ++) {
    local_history_val[j] = -1;
  }
  
         
         
  /************************************************
   * PERTURB
   ************************************************/


  while (last_change < HC_MOVES_BEFORE_RESTART) {
    evaluated ++;    
    changed_pos = random () % (LENGTH);
    changed = touched[changed_pos];
    
    best[changed] = 1 - best[changed];
    i = walksatEval (best);
    keep_top (best, i);         

    /* an equal score keeps the flip but neither resets nor advances
       last_change; only a strict improvement resets it */
    if (i  >= best_value) {
      if (i > best_value) {
        last_change = 0;
      }
      best_value = i;
    }
    else {
      /* worse: undo the flip */
      best[changed] = 1 - best[changed];
      last_change ++;
    }
    
    if (best_value > BEST_EVER) {
/*
      if (((int) (1.0/best_value)) < ((int) (1.0/BEST_EVER)))
        printf ("---------- best EVER:%f %f\n", best_value, 1.0/best_value);
*/
      BEST_EVER = best_value;
    }
  }
  fprintf (stderr, "                      ******** best in run %f %f \n", best_value, 1.0/best_value);
  printf ("                      ******** best in run %f %f \n", best_value, 1.0/best_value);

  return (evaluated);
}

#endif 

/*
 * Log-likelihood of `example` under the tree, for the subtree rooted at n.
 *
 * Sums the contributions of all children of n (recursively), then adds the
 * log-probability of example[n] itself: conditional on parent_val when
 * parent != -1, otherwise the marginal (parent == -1 marks the root).
 */
float prop_through_tree (int example[VARIABLES],  
                          int n,
                          int parent, 
                          int parent_val) {
  float log_sum = 0;
  float p_on;
  int c;

  for (c = 0; c < num_child[n]; c++) {
    log_sum += prop_through_tree (example, children[n][c], n, example[n]);
  }

  /* probability that bit n is 1, given what is known about its parent */
  p_on = (parent != -1) ? condProb (n, parent, parent_val)
                        : uncondProb (n);

  if (example[n]) {
    log_sum += log (p_on);
  }
  else {
    log_sum += log (1.0 - p_on);
  }

  return log_sum;
}







/*
 * Draw NUMBER_OF_START_POINTS_TO_GENERATE samples from the tree, score each
 * with walksatEval(), and copy the highest-scoring one into a[] (the first
 * sample wins ties).  a[] is left untouched when no samples are requested.
 */
void PBIL_sample_from_tree_many (int *a) {
  int candidate[LENGTH];
  double best_score = 0.0;   /* only meaningful once trial 0 has been scored */
  int trial;

  for (trial = 0; trial < NUMBER_OF_START_POINTS_TO_GENERATE; trial++) {
    double score;

    PBIL_sample_from_tree (candidate);
    score = walksatEval (candidate);

    if (trial != 0 && !(score > best_score)) {
      continue;
    }
    memcpy (a, candidate, sizeof(int) * LENGTH);
    best_score = score;
  }
}


/*
 * Parse the options owned by this file from the command line:
 *
 *   -tree N       sets USE_TREE
 *   -generate N   sets NUMBER_OF_START_POINTS_TO_GENERATE
 *
 * Every other argument is skipped; main() passes the same argv to
 * walksat_init() as well.  (Previously an unrecognised argument never
 * advanced the cursor, so the loop spun forever.)  A flag given as the last
 * argument, with no value after it, is reported on stderr and ignored
 * instead of handing a NULL pointer to atoi().
 *
 * Values are converted with atoi(), so non-numeric text yields 0.
 */
void get_parameters (int argc, char *argv[]) {
  int cur_arg = 1;

  while (cur_arg < argc) {
    const char *flag = argv[cur_arg];
    int *target = NULL;

    if (strcmp (flag, "-tree") == 0) {
      target = &USE_TREE;
    }
    else if (strcmp (flag, "-generate") == 0) {
      target = &NUMBER_OF_START_POINTS_TO_GENERATE;
    }

    if (target == NULL) {
      /* not one of ours */
      cur_arg ++;
      continue;
    }

    if (cur_arg + 1 >= argc) {
      fprintf (stderr, "get_parameters: option %s requires a value\n", flag);
      break;
    }

    *target = atoi (argv[cur_arg + 1]);
    cur_arg += 2;
  }
}



/*
 * Program entry point.
 *
 * Parses this file's options, then performs RUNS independent runs.  Each
 * run re-initialises walksat and the tree state and loops until roughly
 * 50,000,000 evaluations have been counted:
 *   - with USE_TREE: rebuild the pair counts and the tree from the history
 *     pool, sample a start point from the tree, run walksatIter() from it
 *     and merge the short-list back into the pool;
 *   - otherwise: run walksatIter() from a uniformly random start point.
 * BEST_EVER is printed to stdout and stderr at the end of every run.
 *
 * Returns 0 on completion, 1 if the start-point buffer cannot be allocated.
 */
int main (int argc, char *argv[]) {
  /* Printed through %s so that the run of '%' characters can never be read
     as conversion specifications. */
  static const char percent_bar[] = "%%%%" "%%%%";
  const int max_evals = 50000000;
  int r;
  int *start_pos;
  int count_evals;
  /* No seed is chosen in this function; print a defined placeholder rather
     than an uninitialised value.
     NOTE(review): walksat_init() may do its own seeding -- confirm. */
  int seed = 0;

  get_parameters (argc, argv);

  divide_by = (long int) pow (2.0,30.0);

  printf ("%s\n", (USE_TREE) ? "---------------------TREE-----------" :
                               "************* NO TREE ******************" );

  printf ("\n\n NUMBER_TO_UPDATE_FROM %d \n NUMBER_TO_TAKE_FROM_EACH_RUN %d\n NUMBER_OF_START_POINTS_TO_GENERATE %d\n SEED %d\n\n\n", NUMBER_TO_UPDATE_FROM, NUMBER_TO_TAKE_FROM_EACH_RUN, NUMBER_OF_START_POINTS_TO_GENERATE, seed);

  for (r = 0; r < RUNS; r ++) {

    walksat_init(argc, argv);
    VARIABLES = numatom;
    PBIL_init ();

    start_pos = malloc (sizeof *start_pos * VARIABLES);
    if (start_pos == NULL && VARIABLES > 0) {
      fprintf (stderr, "main: out of memory\n");
      return 1;
    }

    count_evals = 0;
    while (count_evals < max_evals) {

      if (USE_TREE) {
        printf ("using new history %d \n", count_evals);
        fprintf (stderr, "using new history %d \n", count_evals);
        PBIL_use_new_history ();
        printf ("updating tree \n");
        PBIL_make_tree ();

        printf ("sampling tree \n");
        PBIL_sample_from_tree_many (start_pos);
        /* every tree sample was scored with one evaluation */
        count_evals += NUMBER_OF_START_POINTS_TO_GENERATE;
      }
      else {
        int i;

        for (i = 0; i < LENGTH; i ++) {
          start_pos[i] = randBit (0.5);
        }
      }

      printf ("new walksat \n");
      count_evals += walksatIter (start_pos);

      if (USE_TREE) {
        add_top_to_data ();
      }
    }

    printf ("HC_MOVES >>> %s FINAL : %f %f\n", percent_bar, BEST_EVER, 1.0/BEST_EVER);
    fprintf (stderr, "HC_MOVES >>> %s FINAL : %f %f\n", percent_bar, BEST_EVER, 1.0/BEST_EVER);

    /* the buffer is per-run: release it before the next run allocates again */
    free (start_pos);
  }

  return 0;
}



  

