#include <stdlib.h>
#include "app.h"
#include <stdio.h>

#define VERBOSE 0
#define DEBUG 0
#define LOG_DIR "/afs/cs/usr/hongsuda/pvm3/source/logs/"

/*
int THRESHOLD_LENGTH = 20;
int THRESHOLD_WIDTH = 200;
*/
/* ------------------------ global variables ------------------- */

int gtid;                 /* process's task id (set from pvm_mytid() in main) */

block_type gblock1;       /* block received with INIT_DATA_TAG in main; first operand of every multiply */
block_type gblock2;       /* block most recently received from gsender_tid; second operand of the multiply */
block_type gresult_block; /* the result of the multiplication */
int gsender_tid;          /* sender's tid */
int greceiver_tid;        /* receiver's tid */
int gnum_processes;       /* number of processes */
int gchild_number;        /* main compares this with gnum_processes-2 to pick the last-child path */
int gnum_operations;      /* loop bound used by compute_phase2 and compute_phase_for_last_child_orig */
int gdimension;           /* not referenced in the visible part of this file */
int gparent_id;           /* master process's id */
FILE *gfp;   /* global file pointer; main points it at stdout */
int gcurrent_num_subblocks; /* sub-block count carried in the header message of a block transfer */

/* ------------------------------------------------------------ */
/* Ceiling of numerator/denominator using integer arithmetic only.
   Returns 0 when either argument is not positive, so a zero or negative
   tile size can never cause a division by zero. */
static int ceil_div_positive(int numerator, int denominator){
  if (numerator <= 0 || denominator <= 0)
    return(0);
  /* (n - 1) / d + 1 cannot overflow, unlike (n + d - 1) / d */
  return((numerator - 1) / denominator + 1);
}

/* Number of tiles of at most len rows by width columns needed to cover
   block1, whose size is taken from get_length() and get_width().

   block1 : block to be tiled
   len    : maximum number of rows per tile
   width  : maximum number of columns per tile

   Returns the tile count, or 0 when the block is empty or a tile
   dimension is not positive.  The computation is integer-only, so it
   does not rely on a declaration of ceil() being in scope. */
int cal_num_subblocks(block_type block1, int len, int width){
  int block_len = get_length(block1);
  int block_width = get_width(block1);

  return(ceil_div_positive(block_len, len) *
         ceil_div_positive(block_width, width));
}
/* ------------------------------------------------------------- */
/* Pack one rectangular sub-block of block1 into the active PVM send buffer.

   (i1, j1) and (i2, j2) are the upper-left and lower-right corners of the
   sub-block, expressed in the same coordinate system as block1.upper_left.
   The four corner values are packed first, followed by one run of raw
   bytes per row of the sub-block.  The caller is responsible for
   pvm_initsend() before and pvm_send() after this call. */
void pack_subblocks(block_type block1, int i1, int j1, int i2, int j2){
  int num_rows = i2 - i1 + 1;
  int num_cols = j2 - j1 + 1;
  int row_bytes = sizeof(element_type) * num_cols;
  /* translate the corner into indices local to block1's storage */
  int first_row = i1 - block1.upper_left.i;
  int first_col = j1 - block1.upper_left.j;
  int r;

  /* the corner coordinates precede the data */
  pvm_pkint(&i1, 1, 1);
  pvm_pkint(&j1, 1, 1);
  pvm_pkint(&i2, 1, 1);
  pvm_pkint(&j2, 1, 1);

  for (r = 0; r < num_rows; r++)
    pvm_pkbyte((char *)&block1.data_block[first_row + r][first_col], row_bytes, 1);
}

/* ------------------------------------------------------------- */
/* Unpack one sub-block message from the active PVM receive buffer into
   dest_block.

   The message layout mirrors pack_subblocks(): four ints giving the
   corners (i1, j1) and (i2, j2) of the sub-block, followed by one run of
   raw bytes per row.  The corners are translated by dest_block.upper_left
   to find where the rows land inside dest_block.data_block.

   The corners come off the wire, so they are checked against the bounds
   of dest_block before anything is written; a sub-block that does not fit
   is reported on stderr and ignored.  No scratch buffer is needed: rows
   are unpacked directly into the destination matrix. */
void fill_block(block_type dest_block){
  int i;
  int i1, j1, i2, j2;
  int len;
  int width;
  int size;
  int start_row, start_col;
  element_type *start_point;

  pvm_upkint(&i1, 1, 1);
  pvm_upkint(&j1, 1, 1);
  pvm_upkint(&i2, 1, 1);
  pvm_upkint(&j2, 1, 1);

  /* refuse coordinates that would make the row copies run outside
     dest_block's storage */
  if (i1 > i2 || j1 > j2 ||
      i1 < dest_block.upper_left.i || i2 > dest_block.lower_right.i ||
      j1 < dest_block.upper_left.j || j2 > dest_block.lower_right.j){
    fprintf(stderr, "Error: sub-block (%d, %d)-(%d, %d) does not fit the destination block \n",
            i1, j1, i2, j2);
    return;
  }

  len = i2 - i1 + 1;
  width = j2 - j1 + 1;
  size = width * sizeof(element_type);

  start_row = i1 - dest_block.upper_left.i;
  start_col = j1 - dest_block.upper_left.j;

  for (i=0; i<len; i++){
    start_point = &((dest_block.data_block)[start_row][start_col]);
    pvm_upkbyte((char *)start_point, size, 1);
    start_row++;
  }
}

/* ------------------------------------------------------------- */
/* Write a human-readable dump of block to fp: first a header line with the
   corner coordinates and the derived length/width, then one text line per
   matrix row, then a blank line.  Elements are printed with "%f". */
void fprint_block(FILE* fp, block_type block){
  int num_rows = block.lower_right.i - block.upper_left.i + 1;
  int num_cols = block.lower_right.j - block.upper_left.j + 1;
  int r, c;
  element_type *cells;

  fprintf(fp, "upper_left = (%d, %d), lower_right = (%d, %d) => ", block.upper_left.i, block.upper_left.j, block.lower_right.i, block.lower_right.j);
  fprintf(fp, "len = %d, width = %d \n", num_rows, num_cols);

  for (r = 0; r < num_rows; r++){
    cells = block.data_block[r];
    for (c = 0; c < num_cols; c++)
      fprintf(fp," %f ", cells[c]);
    fputc('\n', fp);
  }
  fputc('\n', fp);
}
/* ------------------------------------------------------------- */
/* Set up the global gresult_block so that it has the same corner
   coordinates as block1, backed by a freshly allocated matrix of
   get_length(block1) rows by get_width(block1) columns with every element
   set to 0. */
void init_slave_result_block(block_type block1){
  int num_rows = get_length(block1);
  int num_cols = get_width(block1);
  int r, c;

  gresult_block.upper_left = block1.upper_left;
  gresult_block.lower_right = block1.lower_right;
  gresult_block.data_block = allocate_matrix(num_rows, num_cols);

  /* clear the accumulator element by element */
  for (r = 0; r < num_rows; r++){
    element_type *cells = gresult_block.data_block[r];
    for (c = 0; c < num_cols; c++)
      cells[c] = 0;
  }
}

/* ------------------------------------------------------------- */
/* Multiply block1 by block2 and store (overwrite) the product in
   result_block.

   The width of block1 must equal the length of block2; otherwise an error
   is printed on stderr and -1 is returned.  Product column j is written to
   result column block2.upper_left.j + j, i.e. the output is shifted by
   block2's first column coordinate.

   Returns 0 on success, -1 on a dimension mismatch. */
int multiply_block2(block_type result_block, block_type block1, block_type block2){
  int rows1 = get_length(block1);
  int rows2 = get_length(block2);
  int cols1 = get_width(block1);
  int cols2 = get_width(block2);
  int out_offset = (block2.upper_left).j;
  int r, c, k;

  if (cols1 != rows2){
    fprintf(stderr, "Error: Inappropriate matrix size \n");
    return(-1);
  }

  for (r = 0; r < rows1; r++){
    element_type *out_row = result_block.data_block[r];
    for (c = 0; c < cols2; c++){
      element_type acc = 0.0;
      for (k = 0; k < cols1; k++)
        acc = acc + block1.data_block[r][k] * block2.data_block[k][c];
      out_row[out_offset + c] = acc;
    }
  }

  return(0);
}
/* ------------------------------------------------------------- */
/* Accumulate the partial product of block1 and block2 into result_block.

   block2's row coordinate (block2.upper_left.i) selects the first column
   of block1 that takes part in the product, and block2's column coordinate
   (block2.upper_left.j) selects the first column of result_block that is
   updated.  Each computed dot product is ADDED to the existing result
   element, so result_block must have been initialised beforehand.

   No dimension checks are performed.  Always returns 0. */
int multiply_block(block_type result_block, block_type block1, block_type block2){
  int rows1 = get_length(block1);
  int rows2 = get_length(block2);
  int cols2 = get_width(block2);
  int lhs_offset = block2.upper_left.i;    /* first block1 column used */
  int out_offset = (block2.upper_left).j;  /* first result column updated */
  int r, c, k;

  for (r = 0; r < rows1; r++){
    element_type *out_row = result_block.data_block[r];
    for (c = 0; c < cols2; c++){
      element_type acc = 0.0;
      for (k = 0; k < rows2; k++)
        acc += block1.data_block[r][lhs_offset + k] * block2.data_block[k][c];
      out_row[out_offset + c] += acc;
    }
  }

  return(0);
}

/* ------------------------------------------------------------- */

/* Send block1 to greceiver_tid as a header message followed by a series
   of sub-block messages, all tagged EXCHANGE_DATA_TAG.

   The header carries the global gcurrent_num_subblocks and the four corner
   coordinates of block1.  Note that the num_subblocks parameter is not
   used; the count placed in the header is the global.

   The block is then cut into tiles of at most THRESHOLD_LENGTH rows by
   THRESHOLD_WIDTH columns, walked row band by row band and left to right
   within a band; each tile goes out as its own message via
   pack_subblocks().  The disabled #else branch sends the whole block in a
   single message with pack_for_send() instead. */
void send_subblocks(block_type block1, int num_subblocks){
  int last_row = block1.lower_right.i;
  int last_col = block1.lower_right.j;
  int row_lo, row_hi;
  int col_lo, col_hi;
  int total_bytes;

  /* header: sub-block count plus the extent of the whole block */
  pvm_initsend(ENCODING);
  pvm_pkint(&gcurrent_num_subblocks, 1, 1);
  pvm_pkint(&block1.upper_left.i, 1, 1);
  pvm_pkint(&block1.upper_left.j, 1, 1);
  pvm_pkint(&block1.lower_right.i, 1, 1);
  pvm_pkint(&block1.lower_right.j, 1, 1);
  pvm_send(greceiver_tid,EXCHANGE_DATA_TAG);

#if 1
  for (row_lo = block1.upper_left.i; row_lo <= last_row; row_lo = row_hi + 1){
    /* clip the row band to the bottom edge of the block */
    row_hi = row_lo + THRESHOLD_LENGTH - 1;
    if (row_hi > last_row)
      row_hi = last_row;

    for (col_lo = block1.upper_left.j; col_lo <= last_col; col_lo = col_hi + 1){
      /* clip the tile to the right edge of the block */
      col_hi = col_lo + THRESHOLD_WIDTH - 1;
      if (col_hi > last_col)
        col_hi = last_col;

      pvm_initsend(ENCODING);
      pack_subblocks(block1, row_lo, col_lo, row_hi, col_hi);
      pvm_send(greceiver_tid,EXCHANGE_DATA_TAG);
    }
  }
#else

  pvm_initsend(ENCODING);
  pack_for_send(block1, &total_bytes);
  pvm_send(greceiver_tid,EXCHANGE_DATA_TAG);

#endif

}
/* ------------------------------------------------------------- */
/* Receive a complete block from gsender_tid, using messages with the
   given tag.

   The first message is the header: the sub-block count (stored in the
   global gcurrent_num_subblocks) and the four corner coordinates of the
   block.  A matrix of matching size is allocated, and then one further
   message is received per sub-block and copied into place by fill_block().

   Returns the assembled block; its data_block comes from allocate_matrix()
   and is handed over to the caller. */
block_type unpack_subblocks(int tag){
  block_type incoming;
  int num_rows, num_cols;
  int received;

  /* header message describing the whole block */
  pvm_recv(gsender_tid, tag);
  pvm_upkint(&gcurrent_num_subblocks, 1, 1);
  pvm_upkint(&incoming.upper_left.i, 1, 1);
  pvm_upkint(&incoming.upper_left.j, 1, 1);
  pvm_upkint(&incoming.lower_right.i, 1, 1);
  pvm_upkint(&incoming.lower_right.j, 1, 1);

  num_rows = get_length(incoming);
  num_cols = get_width(incoming);
  incoming.data_block = allocate_matrix(num_rows, num_cols);

  /* one message per sub-block, each filling its own region */
  for (received = 0; received < gcurrent_num_subblocks; received++){
    pvm_recv(gsender_tid, tag);
    fill_block(incoming);
  }

  return(incoming);
}

/* ------------------------------------------------------------- */

/* ------------------------------------------------------------- */
/* Compute phase in which every round is "receive, forward, multiply".

   For gnum_processes rounds: receive a block from gsender_tid into
   gblock2, pass the same block on to greceiver_tid, accumulate
   gblock1 * gblock2 into gresult_block, then release gblock2's matrix.

   Because every process receives before it sends, this is not efficient:
   if one process gets stuck, the whole chain gets stuck.

   NOTE(review): only gblock2.data_block itself is freed; whether that also
   releases the row storage depends on allocate_matrix() - confirm. */
void compute_phase1(){
  int round;

  for (round = 0; round < gnum_processes; round++){
    gblock2 = unpack_subblocks(EXCHANGE_DATA_TAG);

    /* forward the block before using it locally */
    send_subblocks(gblock2, gcurrent_num_subblocks);

    multiply_block(gresult_block, gblock1, gblock2);

    free(gblock2.data_block);
  }
}
/* ------------------------------------------------------------- */
/* Alternative compute phase that exchanges whole blocks with
   unpack_for_recv() / pack_for_send() rather than sub-block messages.
   It was intended for even child numbers (receive then send), with odd
   child numbers sending first; compute_phase() does not call it at the
   moment.

   One block is received up front.  Each of the first gnum_operations-1
   rounds multiplies with the current block, receives the next block, and
   only then forwards the block just used to greceiver_tid.  The final
   block is multiplied and forwarded after the loop. */
void compute_phase2(){
  int round;
  int total_bytes;
  block_type used_block;

  pvm_recv(gsender_tid, EXCHANGE_DATA_TAG);
  gblock2 = unpack_for_recv(&total_bytes);

  for (round = 0; round < gnum_operations-1; round++){
    multiply_block(gresult_block, gblock1, gblock2);
    used_block = gblock2;

    /* fetch the next block before passing the previous one along */
    pvm_recv(gsender_tid, EXCHANGE_DATA_TAG);
    gblock2 = unpack_for_recv(&total_bytes);

    pvm_initsend(ENCODING);
    pack_for_send(used_block, &total_bytes);
    pvm_send(greceiver_tid, EXCHANGE_DATA_TAG);
    free(used_block.data_block);
  }

  /* last block: multiply, forward, release */
  multiply_block(gresult_block, gblock1, gblock2);
  pvm_initsend(ENCODING);
  pack_for_send(gblock2, &total_bytes);
  pvm_send(greceiver_tid, EXCHANGE_DATA_TAG);
  free(gblock2.data_block);
}
/* ------------------------------------------------------------- */
/* Run the compute phase for a slave that is not the last child.
   Every such slave currently uses compute_phase1(); a split that would
   run compute_phase2() for even gchild_number values and compute_phase1()
   for odd ones is disabled. */
void compute_phase(){
  compute_phase1();
}

/* ------------------------------------------------------------- */
/* Compute phase for the last child: receive and multiply only.

   For gnum_processes rounds: receive a block from gsender_tid into
   gblock2, accumulate gblock1 * gblock2 into gresult_block, then release
   gblock2's matrix.  Unlike compute_phase1(), the received block is never
   forwarded to greceiver_tid.

   NOTE(review): only gblock2.data_block itself is freed; whether that also
   releases the row storage depends on allocate_matrix() - confirm. */
void compute_phase_for_last_child(){
  int round;

  for (round = 0; round < gnum_processes; round++){
    gblock2 = unpack_subblocks(EXCHANGE_DATA_TAG);

    multiply_block(gresult_block, gblock1, gblock2);

    free(gblock2.data_block);
  }
}
/* ------------------------------------------------------------- */
/* Earlier last-child compute phase based on whole-block messages.

   Only receives; it never has to rotate (forward) anything.  For
   gnum_operations rounds: receive a block with unpack_for_recv(),
   accumulate gblock1 * gblock2 into gresult_block and release the received
   matrix.  When DEBUG is enabled the running result is dumped to gfp after
   every round. */
void compute_phase_for_last_child_orig(){
  int round;
  int total_bytes;

  for (round = 0; round < gnum_operations; round++){
    pvm_recv(gsender_tid, EXCHANGE_DATA_TAG);
    gblock2 = unpack_for_recv(&total_bytes);

    multiply_block(gresult_block, gblock1, gblock2);
    free(gblock2.data_block);

#if DEBUG
    fprint_block(gfp, gresult_block);
    fflush(gfp);
#endif
  }
}
/* ------------------------------------------------------------- */
/* Manual check for init_slave_result_block(): build a sample block with
   init_block(&sample, 6, 6), derive gresult_block from it and print the
   result with print_block(). */
void test_init_slave_block(){
  block_type sample;

  init_block(&sample, 6, 6);
  init_slave_result_block(sample);
  print_block(gresult_block);
}


/* ------------------------------------------------------------- */
main(argc, argv)
     int argc;
     char **argv;
{
  int mytid;   /* my task id */
  int bufid;
  int i, j;
  int int_array[NUM_ARGC];
  char filename[100];
  char *data_buffer;
  int total_bytes;
  int iterations;
  int dtid;

  /* enroll in pvm */
  mytid = pvm_mytid();
  gtid = mytid;
  gparent_id = pvm_parent();
  
# if 0
  sprintf(filename, "%sslave.%x",LOG_DIR, gtid); 
  gfp = fopen(filename, "w");
#else
  gfp = stdout;
#endif

#if DEBUG
  fprintf(gfp, "task id is %x \n", gtid);
  fflush(gfp);
#endif

  /* tell parent I am ready */
  pvm_setopt(PvmRoute, PvmRouteDirect);
  pvm_initsend(ENCODING);
  pvm_send( gparent_id, SLAVE_TAG );
  
  /* get the arguments from the master processes */
  bufid = pvm_recv(-1, ARGV_TAG); 
  pvm_upkint(int_array, NUM_ARGC, 1);
  gsender_tid = int_array[0];
  greceiver_tid = int_array[1];
  gnum_processes = int_array[2];
  gchild_number = int_array[3];
  gnum_operations = int_array[4];
  
#if DEBUG
  fprintf(gfp, "sender = %x, receiver = %x, num_processes = %d, child_number = %d \n", gsender_tid, greceiver_tid, gnum_processes, gchild_number);
  fflush(gfp);

  fprintf(gfp, "before doing the recv \n");
  fflush(gfp);

#endif
  

#if 1
  /* get the first initial data from the master processes */
  bufid = pvm_recv(-1, INIT_DATA_TAG);
  gblock1 = unpack_for_recv(&total_bytes);
#else
  gblock1 = unpack_subblocks(INIT_DATA_TAG);
#endif

  init_slave_result_block(gblock1);
  
#if DEBUG
  fprintf(gfp, "after doing the recv init_data \n");
  fprint_block(gfp, gblock1);
  fflush(gfp);
#endif

  if (gchild_number == gnum_processes-2){
    fprintf(gfp, "it is the last child \n");
    fflush(gfp);
    compute_phase_for_last_child();
  }  else {
    fprintf(gfp, "it is not the last child \n");
    fflush(gfp);
    compute_phase();
  }

  /* send out the result block to the master process */
  pvm_initsend(ENCODING);
  pack_for_send(gresult_block, &total_bytes);
  pvm_send(gparent_id, SLAVE_TAG);
  
  pvm_exit();
}






