/* 
 * Union-Find
 * 15-122 Principles of Imperative Computation
 * Frank Pfenning
 *
 * This version balances, but does not do path compression, so it
 * should only be of complexity O(n*log(n)) for n unions.
 */

#include <stdlib.h>
#include <limits.h>
#include "unionfind.h"
#include "lib/contracts.h"
#include "lib/xalloc.h"

/*
 * Header of a union-find structure over the elements 0..size-1.
 *
 * A[] is a flexible array member; singletons() allocates it with one
 * entry per element.  Entry encoding, as used by the functions in this
 * file:
 *   A[i] >= 0  -- i is not a representative; A[i] is the next index on
 *                 the path towards the representative of i's class
 *   A[i] <  0  -- i is the representative of its class, and -A[i] is an
 *                 upper bound on the number of nodes on any path that
 *                 ends at i (see depth_bounded)
 */
struct ufs_header {
  int size;   /* number of elements */
  int A[];    /* parent links, or negated depth bounds at representatives */
};

/* ufs_elem(eqs, i) holds when eqs is non-NULL and i is one of its
 * element indices, i.e. 0 <= i < eqs->size.
 */
bool ufs_elem(ufs eqs, int i) {
  return eqs != NULL && i >= 0 && i < eqs->size;
}

/* ufs_rep(eqs, i) holds when i is a valid element of eqs and is the
 * representative of its equivalence class, which is the case exactly
 * when its array entry is negative.
 */
bool ufs_rep(ufs eqs, int i) {
  return ufs_elem(eqs, i) && eqs->A[i] < 0;
}

/* depth_bounded(eqs, i) checks the path that starts at element i:
 * every link must point to a valid element, the path must reach a
 * representative without cycling, and the number of nodes on the path
 * (counting i itself) must not exceed the depth bound -A[rep] stored at
 * that representative.
 *
 * Returns false if eqs is NULL or i is not an element of eqs.
 */
bool depth_bounded(ufs eqs, int i)
{
  /* Validate the start before touching A[i]; ufs_elem also rejects NULL */
  if (!ufs_elem(eqs, i)) return false;

  int *A = eqs->A;
  int size = eqs->size;
  int k = i;
  int d = 1;                  /* nodes visited so far, including i */
  while (A[k] >= 0) {
    d++;
    if (!ufs_elem(eqs, A[k])) return false;  /* link out of range */
    if (d > size) return false;              /* more nodes than elements: cycle */
    k = A[k];
  }

  /* k is a valid representative here (it is either i or a link that was
   * range-checked above).  The bound is stored negated, so "-A[k] >= d"
   * is tested as "A[k] <= -d".  This avoids negating A[k], which would
   * overflow when A[k] == INT_MIN; that value is still accepted, as the
   * largest representable bound. */
  return A[k] <= -d;
}

/* is_ufs(eqs) is the data structure invariant: eqs is non-NULL, its
 * size is non-negative, and every element passes depth_bounded.
 */
bool is_ufs (ufs eqs)
{
  if (eqs == NULL || eqs->size < 0) return false;
  // eqs->size == \length(eqs->A);

  /* Advance past every element that checks out; stopping early means
   * some element failed. */
  int i = 0;
  while (i < eqs->size && depth_bounded(eqs, i))
    i++;
  return i == eqs->size;
}

/* singletons(n) allocates a fresh eqs over the nodes [0..n) in which
 * every node is the sole member, and hence the representative, of its
 * own equivalence class.
 *
 * The element count is stored in an int field, so n is expected to be
 * representable as an int.
 */
ufs singletons(unsigned int n) {
  /* A zero-initialising allocation is unnecessary: every array entry is
   * overwritten just below. */
  ufs eqs = xmalloc(sizeof(struct ufs_header) + n * sizeof(int));
  eqs->size = n;

  /* -1 marks a representative with depth bound 1 */
  for (int *p = eqs->A, *end = eqs->A + n; p != end; p++)
    *p = -1;

  ENSURES(is_ufs(eqs));
  return eqs;
}

/* is_equiv(eqs, i, j) holds when i and j belong to the same equivalence
 * class, i.e. following parent links from each of them ends at the same
 * representative.  Does not modify eqs.
 */
bool is_equiv(ufs eqs, int i, int j) {
  REQUIRES(is_ufs(eqs));
  REQUIRES(ufs_elem(eqs, i));
  REQUIRES(ufs_elem(eqs, j));

  const int *parent = eqs->A;
  int irep;
  int jrep;

  /* Walk each element up to the node whose entry is negative */
  for (irep = i; parent[irep] >= 0; irep = parent[irep]) { }
  for (jrep = j; parent[jrep] >= 0; jrep = parent[jrep]) { }

  return irep == jrep;
}

/* ufs_find(eqs, i) returns the representative of the equivalence class
 * that contains i.  The structure is left unchanged.
 */
int ufs_find(ufs eqs, int i) {
  REQUIRES(is_ufs(eqs));
  REQUIRES(ufs_elem(eqs, i));

  int *A = eqs->A;
  int rep;
  for (rep = i; A[rep] >= 0; rep = A[rep]) {
    /* loop invariant: the current node is in the same class as i */
    ASSERT(is_equiv(eqs, i, rep));
  }

  /* No path compression is performed here.  A weak form would set
   * A[i] = rep; a strong form would redirect every intermediate node
   * on the path to rep. */

  ENSURES(is_ufs(eqs));
  ENSURES(ufs_rep(eqs, rep));
  return rep;
}

/* ufs_union(eqs, i, k) merges the equivalence classes of i and k.
 * The class with the smaller depth bound is linked underneath the
 * representative of the other one; if i and k are already equivalent
 * nothing changes.
 */
void ufs_union(ufs eqs, int i, int k) {
  REQUIRES(is_ufs(eqs));
  REQUIRES(ufs_elem(eqs, i));
  REQUIRES(ufs_elem(eqs, k));

  int irep = ufs_find(eqs, i);
  int krep = ufs_find(eqs, k);
  if (irep == krep) return;   /* already in the same class */

  int *A = eqs->A;
  /* Bounds are stored negated: a smaller value means a deeper class */
  int ibound = A[irep];
  int kbound = A[krep];

  if (ibound > kbound) {
    /* k's class is deeper: hang i's class below it, bound unchanged */
    A[irep] = krep;
  } else {
    /* i's class is at least as deep: hang k's class below it */
    A[krep] = irep;
    if (ibound == kbound) {
      /* equal depth (direction was arbitrary): bound grows by one */
      A[irep]--;
    }
  }

  ENSURES(is_ufs(eqs));
  ENSURES(is_equiv(eqs, i, k));
  return;
}

/* ufs_free(eqs) releases a union-find structure.
 * Precondition: eqs satisfies is_ufs, so in particular it is not NULL.
 */
void ufs_free(ufs eqs) {
  REQUIRES(is_ufs(eqs));
  /* singletons() obtains the header and the array as one allocation,
   * so a single free releases both */
  free(eqs);
}
