/*
 * Union-Find
 * 15-122 Principles of Imperative Computation
 * Frank Pfenning
 *
 * IMPORTANT
 * This version, written in lecture, does not guarantee to pick the
 * representative of the larger class as the representative of a
 * union, so the find operation can be O(n), where n is the number of
 * elements in the data structure.
 * 
 * See unionfind-log.c0 for a better version.
 */

#include <limits.h>
#include <stdint.h>
#include <stdlib.h>
#include "unionfind.h"
#include "lib/contracts.h"
#include "lib/xalloc.h"

/* Header of a union-find structure over the elements 0..size-1.
 * A[i] is the parent link of element i; i is the representative of its
 * equivalence class exactly when A[i] == i (see ufs_rep).  A is a
 * flexible array member allocated in the same block as the header
 * (see singletons).
 */
struct ufs_header {
  int size;   /* number of elements in the structure */
  int A[];    /* parent links; each A[i] lies in [0, size) when is_ufs holds */
};

/* ufs_elem(eqs, i) holds when eqs is non-NULL and i is an index in
 * [0, eqs->size).  The NULL test comes first, so eqs is never
 * dereferenced when it is NULL.
 */
bool ufs_elem(ufs eqs, int i) {
  return eqs != NULL && i >= 0 && i < eqs->size;
}

/* ufs_rep(eqs, i) holds when i is a valid element of eqs whose parent
 * link points to itself, i.e. i represents its equivalence class.
 */
bool ufs_rep(ufs eqs, int i) {
  return ufs_elem(eqs, i) && eqs->A[i] == i;
}

/* is_ufs(eqs) is the data-structure invariant.  It requires:
 *  - eqs is non-NULL,
 *  - eqs has a non-negative size,
 *  - every parent link A[i] is itself a valid element.
 * It does not check that the parent links are free of cycles.
 */
bool is_ufs(ufs eqs) {
  if (eqs == NULL || eqs->size < 0) return false;
  /* C cannot verify that size matches the allocated length of A;
   * that is up to the code that allocates the structure. */
  for (int idx = 0; idx < eqs->size; idx++) {
    if (!ufs_elem(eqs, eqs->A[idx])) return false;
  }
  return true;
}

/* singletons(n) returns a freshly allocated eqs in which each element of
 * [0..n) is its own representative, i.e. sits in a singleton class.
 * Requires n <= INT_MAX, because the size is stored in an int field.
 * The result is heap-allocated; release it with ufs_free.
 */
ufs singletons(unsigned int n) {
  /* A larger n would be converted to a bogus (negative) int size. */
  REQUIRES(n <= INT_MAX);
  /* Where size_t is narrow (e.g. 32 bits), the byte count below could wrap
   * around and yield an undersized block that the loop would overrun.
   * This is a hard check rather than REQUIRES because contract macros may
   * be compiled out, and this guards memory safety. */
  if ((size_t)n > (SIZE_MAX - sizeof(struct ufs_header)) / sizeof(int)) {
    abort();
  }
  /* malloc rather than calloc: every slot is initialized by the loop. */
  ufs eqs = xmalloc(sizeof(struct ufs_header) + sizeof(int) * (size_t)n);
  int *A = eqs->A;
  for (unsigned int i = 0; i < n; i++) {
    A[i] = (int)i;   /* safe: i < n <= INT_MAX */
  }
  eqs->size = (int)n;
  ENSURES(is_ufs(eqs));
  return eqs;
}

/* is_equiv(eqs, i, j) holds when i and j reach the same representative
 * by following parent links.  Unlike ufs_find, it only reads the links
 * and never modifies them, so ufs_find's assertions can call it without
 * side effects.
 */
bool is_equiv(ufs eqs, int i, int j) {
  REQUIRES(is_ufs(eqs));
  REQUIRES(ufs_elem(eqs, i));
  REQUIRES(ufs_elem(eqs, j));
  const int *parent = eqs->A;
  int root_i;
  int root_j;
  for (root_i = i; parent[root_i] != root_i; root_i = parent[root_i]) { }
  for (root_j = j; parent[root_j] != root_j; root_j = parent[root_j]) { }
  return root_i == root_j;
}

/* ufs_find(eqs, i) returns the representative of i's equivalence class.
 * Side effect: A[i] is redirected straight to that representative.
 * Intermediate nodes on the path are left unchanged.
 */
int ufs_find(ufs eqs, int i) {
  REQUIRES(is_ufs(eqs));
  REQUIRES(ufs_elem(eqs, i));
  int *parent = eqs->A;
  int root;
  for (root = i; parent[root] != root; root = parent[root]) {
    ASSERT(is_equiv(eqs, i, root));
  }
  ASSERT(is_equiv(eqs, i, root));
  ASSERT(ufs_rep(eqs, root));

  /* Weak path compression: only i is linked directly to the root.
   * Strong compression would also redirect every intermediate node. */
  parent[i] = root;

  ENSURES(is_ufs(eqs));
  ENSURES(ufs_rep(eqs, root));
  return root;
}

/* ufs_union(eqs, i, k) merges the equivalence classes of i and k by
 * linking the representative of i's class under the representative of
 * k's class.  If both already share a representative, the write leaves
 * that representative pointing at itself.
 */
void ufs_union(ufs eqs, int i, int k) {
  REQUIRES(is_ufs(eqs));
  int from = ufs_find(eqs, i);
  int to = ufs_find(eqs, k);
  /* Ideally the smaller class would be linked under the larger one, but
   * this version tracks no class sizes, so the direction is fixed. */
  eqs->A[from] = to;
  ENSURES(is_ufs(eqs));
  ENSURES(is_equiv(eqs, i, k));
}

/* ufs_free(eqs) releases the storage of a valid eqs (header and array
 * together, as one block).  eqs must not be used afterwards.
 */
void ufs_free(ufs eqs) {
  REQUIRES(is_ufs(eqs));
  free(eqs);
}
