/* Code to generate and display sparse matrices */

#include <stdio.h>
#include <stdlib.h>
#include "csr.h"

#define USE_RAND

/* Generate random (nonzero) data value */
static ftype_t rdata()
{
  /* Randomly choose +1 or -1 */
#ifdef USE_RAND
  unsigned r = rand();
#else
  unsigned r = random();
#endif
  return  (ftype_t) (r & 0x1 ? 1 : -1);
}

/* Row-major encoding: map position (r, c) of an nrow x nrow matrix to the
 * single index r*nrow + c. */
static int rm_encode(int r, int c, int nrow)
{
  int row_base = r * nrow;
  return row_base + c;
}

/* Return element (r, c) of the square nrow x nrow matrix m.
 * Indices outside 0 .. nrow-1 are reported on stderr and yield FZERO.
 * DENSE matrices are indexed row-major.  For sparse formats the entries
 * rstart[r] .. rstart[r+1]-1 of row r are scanned linearly; SCALED
 * matrices store column indices pre-multiplied by sizeof(ftype_t), so the
 * requested column is scaled the same way before comparing.  A column not
 * present in a sparse row is zero. */
ftype_t get_ele(csr_ptr m, int r, int c)
{
  /* Valid indices are 0 .. nrow-1: r == nrow would read past val[] and
     rstart[r+1] would read past the end of rstart[]. */
  if (r < 0 || r >= m->nrow) {
    fprintf(stderr, "Invalid matrix row %d\n", r);
    return FZERO;
  }
  if (c < 0 || c >= m->nrow) {
    fprintf(stderr, "Invalid matrix column %d\n", c);
    return FZERO;
  }
  if (m->type == DENSE)
    return m->val[r*m->nrow + c];
  else {
    int ci;
    int scale = (m->type == SCALED) ? sizeof(ftype_t) : 1;
    int target = c * scale;  /* column index as stored in cindex[] */
    for (ci = m->rstart[r]; ci < m->rstart[r+1]; ci++)
      if (m->cindex[ci] == target)
        return m->val[ci];
    return FZERO;
  }
}

/* Return 1 when m1 and m2 have the same dimension, the same nominal
 * nentries, and identical values at every (row, column); otherwise 0.
 * Comparison goes through get_ele(), so the two matrices may use
 * different storage formats. */
static int matrix_equal(csr_ptr m1, csr_ptr m2) {
  int n = m1->nrow;
  int row, col;
  if (m2->nrow != n)
    return 0;
  if (m2->nentries != m1->nentries)
    return 0;
  for (row = 0; row < n; row++) {
    for (col = 0; col < n; col++) {
      if (get_ele(m1, row, col) != get_ele(m2, row, col))
        return 0;
    }
  }
  return 1;
}

/* Convert matrix m into a newly allocated matrix of format new_type
 * (DENSE, UNSCALED or SCALED).  m is left untouched; the caller owns the
 * result and releases it with free_matrix().
 *
 * DENSE results hold all nrow*nrow values row-major and keep m's nominal
 * nentries; their rstart/cindex pointers are set to NULL.
 *
 * Sparse results store only the nonzero values, row by row, with
 * rstart[r] giving the first entry of row r and rstart[nrow] the total.
 * SCALED results store each column index multiplied by sizeof(ftype_t).
 * The nonzeros are counted before allocating, so the arrays always fit
 * what is written even if m->nentries is wrong; in that case a warning
 * is printed and nentries/rstart[nrow] record the actual count. */
csr_ptr retype_matrix(csr_ptr m, matrix_t new_type)
{
  csr_ptr result = malloc(sizeof(csr_rec));
  int r, c, i;
  int nrow = m->nrow;
  int scale = (new_type == SCALED ? sizeof(ftype_t) : 1);
  int nfound = 0;

  result->type = new_type;
  result->nrow = nrow;
  result->nentries = m->nentries;
  result->rstart = NULL;
  result->cindex = NULL;

  if (new_type == DENSE) {
    result->val = calloc(nrow * nrow, sizeof(ftype_t));
    for (r = 0; r < nrow; r++)
      for (c = 0; c < nrow; c++)
        result->val[r*nrow + c] = get_ele(m, r, c);
    return result;
  }

  /* Sparse: count the nonzeros actually present first, so val[] and
     cindex[] are sized to what will be written (sizing them by the
     nominal m->nentries could overflow the heap). */
  for (r = 0; r < nrow; r++)
    for (c = 0; c < nrow; c++)
      if (get_ele(m, r, c) != FZERO)
        nfound++;
  if (nfound != m->nentries)
    fprintf(stderr, "Found entries %d != nominal %d\n",
            nfound, m->nentries);
  result->nentries = nfound;
  result->val = calloc(nfound, sizeof(ftype_t));
  result->cindex = calloc(nfound, sizeof(int));
  result->rstart = calloc(nrow+1, sizeof(int));

  i = 0;
  for (r = 0; r < nrow; r++) {
    result->rstart[r] = i;
    for (c = 0; c < nrow; c++) {
      ftype_t v = get_ele(m, r, c);
      if (v != FZERO) {
        result->val[i] = v;
        result->cindex[i++] = c * scale;
      }
    }
  }
  /* End sentinel: must equal the number of entries actually stored */
  result->rstart[nrow] = i;
  return result;
}

/* Generate an nrow x nrow DENSE matrix with nentries nonzero entries.
 * The nonzeros are random +1/-1 values (from rdata()) placed at distinct,
 * randomly chosen positions; every other element is FZERO.
 * nentries is clamped to 0 .. nrow*nrow (with a warning on stderr),
 * because more nonzeros than positions cannot be placed.  The caller owns
 * the result and releases it with free_matrix(). */
csr_ptr gen_dense_matrix(int nrow, int nentries)
{
  int i;
  int nr2 = nrow*nrow;
  csr_ptr result = malloc(sizeof(csr_rec));
  /* indices[] starts as every position 0 .. nr2-1; a partial
     Fisher-Yates shuffle moves the chosen positions to the front. */
  int *indices = calloc(nr2, sizeof(int));

  /* Without this clamp, nr2 - i below would reach 0 (modulo by zero)
     or go negative (out-of-bounds index). */
  if (nentries < 0 || nentries > nr2) {
    int clamped = nentries < 0 ? 0 : nr2;
    fprintf(stderr, "Requested %d entries for %d positions, using %d\n",
            nentries, nr2, clamped);
    nentries = clamped;
  }

  result->type = DENSE;
  result->nrow = nrow;
  result->nentries = nentries;
  result->val = calloc(nr2, sizeof(ftype_t));

  /* Initialize arrays */
  for (i = 0; i < nr2; i++) {
    indices[i] = i;
    result->val[i] = FZERO;
  }
  /* Select nentries random entries.
     On each step, candidate indices are in indices[i .. nr2-1].
     NOTE: r % (nr2 - i) is slightly biased, and when nr2 - i exceeds
     RAND_MAX only the first RAND_MAX+1 candidates are reachable. */
  for (i = 0; i < nentries; i++) {
#ifdef USE_RAND
    int rnd = rand();
#else
    int rnd = random();
#endif
    int s = i + rnd % (nr2 - i); /* between i and nr2-1 */
    int index = indices[s];      /* This is my next index */
    indices[s] = indices[i];     /* Swap unused index into this place */
    indices[i] = index;          /* For tracking permutation */
  }
  /* Now generate nonzero data.  This is kept as a separate pass after all
     positions are chosen, which preserves the original sequence of
     generator calls (and hence the matrices produced for a given seed). */
  for (i = 0; i < nentries; i++)
    result->val[indices[i]] = rdata();
  free(indices);
  return result;
}


/* Generate an nrow x nrow matrix with nentries random +1/-1 entries in
 * storage format `type`.  A DENSE request returns the generated matrix
 * directly.  Otherwise the dense matrix is converted, compared with the
 * conversion (a warning is printed to stderr if they differ), and freed. */
csr_ptr gen_matrix(int nrow, int nentries, matrix_t type)
{
  csr_ptr dense = gen_dense_matrix(nrow, nentries);
  csr_ptr converted;
  if (type != DENSE) {
    converted = retype_matrix(dense, type);
    if (!matrix_equal(converted, dense))
      fprintf(stderr, "Oops, invalid transcription\n");
    free_matrix(dense);
  } else {
    converted = dense;
  }
  return converted;
}
     
/* Release matrix m: its value array, and for non-DENSE formats its
 * rstart and cindex arrays, then the record itself.  Assumes these were
 * heap-allocated (as gen_dense_matrix() and retype_matrix() do).
 * Passing NULL is a no-op, mirroring free(). */
void free_matrix(csr_ptr m)
{
  if (m == NULL)
    return;
  free(m->val);
  if (m->type != DENSE) {
    free(m->rstart);
    free(m->cindex);
  }
  free(m);
}

/* Render a value as a two-character string: "-1" if it is negative,
 * "+1" if it is positive, and " 0" otherwise.  This is meant for the
 * -1/0/+1 values used throughout this file. */
static char *sval(ftype_t v)
{
  char *text = " 0";
  if (v > FZERO)
    text = "+1";
  if (v < FZERO)
    text = "-1";
  return text;
}

/* Print m as an nrow x nrow grid, one row per line.  Each element is
 * rendered with sval(), so only the -1 / 0 / +1 sign is shown. */
void mprint(csr_ptr m)
{
  int n = m->nrow;
  int row = 0;
  while (row < n) {
    int col;
    for (col = 0; col < n; col++)
      printf(" %s ", sval(get_ele(m, row, col)));
    printf("\n");
    row++;
  }
}

/* Allocate a vector of nrow elements, each a random +1/-1 from rdata().
 * The caller owns the returned array and must free() it. */
ftype_t *rvec(int nrow)
{
  ftype_t *vec = calloc(nrow, sizeof(ftype_t));
  int k = 0;
  while (k < nrow) {
    vec[k] = rdata();
    k++;
  }
  return vec;
}

/* Print the nrow-element vector v as "[a, b, ..., z]" followed by a
 * newline, rendering each element with sval().  An empty vector
 * (nrow <= 0) prints "[]" instead of reading v[-1]. */
void vprint (ftype_t *v, int nrow)
{
  int r;
  if (nrow <= 0) {
    printf("[]\n");
    return;
  }
  printf("[");
  for (r = 0; r < nrow-1; r++)
    printf("%s, ", sval(v[r]));
  printf("%s]\n", sval(v[nrow-1]));
}


/* Print the product m * v = result as side-by-side columns.  Each line
 * shows one row of m (via sval()), then v[row], then result[row] (%f).
 * The "*" and "=" operators appear on the middle row (index nrow/2). */
void show_prod(csr_ptr m, ftype_t *v, ftype_t *result)
{
  int n = m->nrow;
  int mid = n >> 1;
  int row, col;
  for (row = 0; row < n; row++) {
    const char *last;
    const char *vr;
    float res;

    printf("|");
    for (col = 0; col < n - 1; col++)
      printf("%s, ", sval(get_ele(m, row, col)));

    last = sval(get_ele(m, row, n - 1));
    vr = sval(v[row]);
    res = (float) result[row];
    if (row == mid)
      printf("%s| * | %s | = | %f |\n", last, vr, res);
    else
      printf("%s|   | %s |   | %f |\n", last, vr, res);
  }
}

/* Compute m * v with csr_mult (the `fast` flag is passed straight through)
 * and print the result with show_prod().  The temporary result vector is
 * freed before returning; it was previously leaked. */
static void show_mult(csr_ptr m, ftype_t *v, int fast)
{
  int nrow = m->nrow;
  ftype_t *result = calloc(nrow, sizeof(ftype_t));
  csr_mult(m, v, result, fast);
  show_prod(m, v, result);
  free(result);
}

/* Return 1 if v1[0 .. nrow-1] and v2[0 .. nrow-1] are equal element by
 * element (using ==), otherwise 0.  A non-positive nrow compares equal. */
static int vector_equal(ftype_t *v1, ftype_t *v2, int nrow)
{
  int k = 0;
  while (k < nrow && v1[k] == v2[k])
    k++;
  return k >= nrow;
}

/* Run one csr_mult variant (matrix tm, `fast` flag) into vtest and
 * compare it with the reference vref.  On a mismatch, print the reference
 * product (always displayed with the dense matrix dm, and only for the
 * first mismatch, tracked via *found_error), then "!= <label> product"
 * and the differing product. */
static void check_mult_variant(csr_ptr dm, csr_ptr tm, ftype_t *v,
                               ftype_t *vref, ftype_t *vtest, int fast,
                               const char *label, int *found_error)
{
  csr_mult(tm, v, vtest, fast);
  if (vector_equal(vref, vtest, dm->nrow))
    return;
  if (!*found_error) {
    printf("Reference product:\n");
    show_prod(dm, v, vref);
    *found_error = 1;
  }
  printf("!= %s product\n", label);
  show_prod(tm, v, vtest);
}

/* Test all versions of multiplication.  m is converted to DENSE,
 * UNSCALED and SCALED form.  The dense slow product (fast = 0) is the
 * reference, and every other format/fast combination is compared with it.
 * Discrepancies are printed to stdout.  Returns 1 if all variants agree,
 * 0 otherwise.  All temporaries are freed before returning. */
int test_mult(csr_ptr m, ftype_t *v)
{
  csr_ptr dm = retype_matrix(m, DENSE);
  csr_ptr um = retype_matrix(m, UNSCALED);
  csr_ptr sm = retype_matrix(m, SCALED);
  int nrow = m->nrow;
  int found_error = 0;
  ftype_t *vref = calloc(nrow, sizeof(ftype_t));
  ftype_t *vtest = calloc(nrow, sizeof(ftype_t));

  /* Create reference with most straightforward version */
  csr_mult(dm, v, vref, 0);

  check_mult_variant(dm, dm, v, vref, vtest, 1, "Dense, fast", &found_error);
  check_mult_variant(dm, um, v, vref, vtest, 0, "Unscaled, slow", &found_error);
  check_mult_variant(dm, um, v, vref, vtest, 1, "Unscaled, fast", &found_error);
  check_mult_variant(dm, sm, v, vref, vtest, 0, "Scaled, slow", &found_error);
  check_mult_variant(dm, sm, v, vref, vtest, 1, "Scaled, fast", &found_error);

  free(vref);
  free(vtest);
  free_matrix(dm);
  free_matrix(sm);
  free_matrix(um);
  return !found_error;
}






