/*
 * Union-Find (Here's the code we wrote in class, 6/26/2013)
 * 15-122 Principles of Imperative Computation, Fall 2012
 * Frank Pfenning
 */

#include <limits.h>
#include <stdint.h>
#include <stdlib.h>
#include "unionfind.h"
#include "lib/contracts.h"
#include "lib/xalloc.h"

/* Array-based union-find representation.
 *   size: number of elements; the valid elements are 0 .. size-1.
 *   A:    parent links, stored inline after the header (flexible array
 *         member, allocated together with the header in singletons).
 *         A[i] == i means i is the canonical representative (root) of
 *         its equivalence class; otherwise A[i] is the next element on
 *         the path from i towards that root.
 */
struct ufs_header {
  int size;   /* number of elements, expected to be >= 0 */
  int A[];    /* A[0 .. size-1]: parent link of each element */
};

/* ufs_elem(eqs, i) holds exactly when eqs is non-NULL and i is an
 * in-range element index, i.e. 0 <= i < eqs->size.
 * The NULL test comes first so eqs is never dereferenced when NULL.
 */
bool ufs_elem(ufs eqs, int i) {
  return eqs != NULL && i >= 0 && i < eqs->size;
}

/* ufs_rep(eqs, i) holds when i is the canonical representative (root)
 * of its equivalence class, i.e. i's parent link points back to i.
 * Requires i to be a valid element of eqs.
 */
bool ufs_rep(ufs eqs, int i) {
  REQUIRES( ufs_elem(eqs, i) );
  int parent = eqs->A[i];
  return parent == i;
}

/* is_ufs(eqs) checks the representation invariant of a union-find
 * structure: eqs is non-NULL, its size is non-negative, and every
 * parent link A[idx] names a valid element of eqs.
 * C cannot check that A really holds `size` entries; that is guaranteed
 * by the allocation in singletons. This check also does not verify that
 * following parent links eventually reaches a root (no cycle check).
 */
bool is_ufs (ufs eqs) {
  if (eqs == NULL || eqs->size < 0) return false;
  for (int idx = eqs->size - 1; idx >= 0; idx--) {
    if (!ufs_elem(eqs, eqs->A[idx])) return false;
  }
  return true;
}

/* is_equiv(eqs, i, j) holds when i and j lie in the same equivalence
 * class, i.e. following parent links from each reaches the same root.
 * Does not modify eqs.
 */
bool is_equiv(ufs eqs, int i, int j) {
  REQUIRES(is_ufs(eqs));
  REQUIRES(ufs_elem(eqs, i));
  REQUIRES(ufs_elem(eqs, j));
  const int *parent = eqs->A;

  int root_i;
  for (root_i = i; parent[root_i] != root_i; root_i = parent[root_i]) {
    /* climb from i towards its root */
  }

  int root_j;
  for (root_j = j; parent[root_j] != root_j; root_j = parent[root_j]) {
    /* climb from j towards its root */
  }

  return root_i == root_j;
}

/* singletons(n) returns a freshly allocated union-find structure over
 * the elements 0 .. n-1, where each element is in its own equivalence
 * class (every parent link points to itself).
 *
 * Preconditions:
 *   - n <= INT_MAX, because the element count is stored in an int `size`
 *     and elements are int indices; a larger n would make size negative.
 *   - the total allocation size must not overflow size_t (this matters
 *     on platforms with a 32-bit size_t).
 * The caller owns the result and releases it with ufs_free.
 */
ufs singletons(unsigned int n) {
  REQUIRES(n <= (unsigned int)INT_MAX);
  REQUIRES(n <= (SIZE_MAX - sizeof(struct ufs_header)) / sizeof(int));
  ufs result = xmalloc(sizeof(struct ufs_header) + (size_t)n * sizeof(int));
  result->size = (int)n;
  /* int index bounded by size: each stored value fits in an int */
  for (int i = 0; i < result->size; i++) {
    result->A[i] = i;
  }
  ENSURES(is_ufs(result));
  return result;
}

/* ufs_find(eqs, i) returns the canonical representative (root) of the
 * equivalence class containing i. It follows parent links from i until
 * it reaches an element whose link points to itself. eqs is not
 * modified (no path compression).
 */
int ufs_find(ufs eqs, int i) {
  REQUIRES(is_ufs(eqs));
  REQUIRES(ufs_elem(eqs, i));
  const int *parent = eqs->A;
  int cur;
  for (cur = i; parent[cur] != cur; cur = parent[cur]) {
    /* loop invariant: cur stays in i's equivalence class */
    ASSERT( is_equiv(eqs, i, cur) );
  }
  ENSURES(is_ufs(eqs));
  ENSURES(ufs_rep(eqs, cur));
  return cur;
}

/* ufs_union(eqs, i, k) merges the equivalence classes of i and k by
 * making the root of i's class point to the root of k's class. When i
 * and k already share a root, eqs is left unchanged.
 * Afterwards i and k are equivalent.
 * No balancing (by rank or size) is done, so repeated unions may build
 * long parent chains, which slows ufs_find.
 */
void ufs_union(ufs eqs, int i, int k) {
  REQUIRES(is_ufs(eqs));
  REQUIRES(ufs_elem(eqs, i));
  REQUIRES(ufs_elem(eqs, k));
  int i_rep = ufs_find(eqs, i);
  int k_rep = ufs_find(eqs, k);
  if (i_rep != k_rep) {
    /* link i's root under k's root */
    eqs->A[i_rep] = k_rep;
  }
  ENSURES(is_ufs(eqs));
  ENSURES(is_equiv(eqs, i, k));
}

/* ufs_free(eqs) releases the single allocation that holds both the
 * header and the parent array of eqs. eqs must be a valid union-find
 * structure and must not be used after this call.
 */
void ufs_free(ufs eqs)
{
  REQUIRES(is_ufs(eqs));
  free(eqs);
}
