import itertools
import random

import numpy as np

import tree
from util import *

"""
Metrics for evaluating the output of an algorithm against a reference:
quartet-topology agreement and entrywise relative/absolute distance error.
"""

def eval_quartets(D1, D2, numSamples=0, tolerance=1):
    """
    Return the fraction of quartets whose topology agrees between D1 and D2.

    For a quartet of distinct leaves (w, x, y, z), each input is tested with
    util's ``is_star`` and ``pairs_with`` (presumably: star vs. the splits
    wx|yz, wy|xz, wz|xy -- confirm against util). The quartet counts as
    agreeing when both inputs are stars, or both give the same split.

    Args:
        D1, D2: objects exposing a ``.data`` array whose first dimension is
            the number of leaves.
        numSamples: 0 evaluates every quartet of distinct leaves exactly
            (O(n^4) -- expensive); a positive value instead evaluates that
            many uniformly random quartets of distinct leaves.
        tolerance: total tolerance; half of it is passed to every quartet
            test, in both the exact and the sampled mode.

    Returns:
        float: agreeing quartets / quartets examined.

    Raises:
        AssertionError: if D1 and D2 have different numbers of leaves.
        ValueError: if there are fewer than 4 leaves or numSamples < 0.
    """
    assert D1.data.shape[0] == D2.data.shape[0]
    n = D1.data.shape[0]
    if n < 4:
        raise ValueError("need at least 4 leaves to form a quartet, got %r" % (n,))
    if numSamples < 0:
        raise ValueError("numSamples must be non-negative, got %r" % (numSamples,))
    # Each individual quartet test gets half of the total tolerance budget.
    tol = float(tolerance) / 2

    if numSamples == 0:
        # Exact evaluation: every unordered set of 4 *distinct* leaves, once.
        quartets = itertools.combinations(range(n), 4)
    else:
        quartets = (random.sample(range(n), 4) for _ in range(numSamples))

    correct = 0
    total = 0
    for w, x, y, z in quartets:
        total += 1
        if _same_quartet_topology(D1, D2, w, x, y, z, tol):
            correct += 1
    return float(correct) / float(total)


def _same_quartet_topology(D1, D2, w, x, y, z, tol):
    """Return True if D1 and D2 agree on the topology of quartet (w, x, y, z)."""
    if is_star(D1, w, x, y, z, tolerance=tol) and is_star(D2, w, x, y, z, tolerance=tol):
        return True
    # Try the three possible pairings in the same order as before:
    # w with x, w with y, w with z.
    for a, b, c, d in ((w, x, y, z), (w, y, x, z), (w, z, x, y)):
        if pairs_with(D1, a, b, c, d, tolerance=tol) and pairs_with(D2, a, b, c, d, tolerance=tol):
            return True
    return False


def eval_relative_error(D1, D2):
    """
    Compute entrywise relative errors |D1 - D2| / D1, with D1 as reference.

    Only the upper triangle (diagonal included) is examined, in row-major
    order. Entries where D1 is not strictly positive are skipped, which
    avoids division by zero.

    Args:
        D1: reference matrix (array-like; at least n x n, n = D1.shape[0]).
        D2: matrix to compare; must have exactly the same shape as D1.

    Returns:
        list of float relative errors.

    Raises:
        AssertionError: if D1 and D2 do not have the same shape.
    """
    # Compare full shapes: checking only the first dimension would let numpy
    # silently broadcast e.g. an (n, 1) array against an (n, n) one.
    assert np.shape(D1) == np.shape(D2)
    ref = np.asarray(D1)
    n = ref.shape[0]
    if n == 0:
        return []
    rows, cols = np.triu_indices(n)  # row-major upper triangle, diagonal included
    ref_vals = ref[rows, cols]
    diff = np.abs(ref - np.asarray(D2))[rows, cols]
    keep = ref_vals > 0
    return (diff[keep].astype(float) / ref_vals[keep]).tolist()


def eval_absolute_error(D1, D2):
    """
    Compute entrywise absolute errors |D1 - D2| between two matrices.

    Only the upper triangle (diagonal included) is examined, in row-major
    order.

    Args:
        D1: first matrix (array-like; at least n x n, n = D1.shape[0]).
        D2: second matrix; must have exactly the same shape as D1.

    Returns:
        list of float absolute errors.

    Raises:
        AssertionError: if D1 and D2 do not have the same shape.
    """
    # Compare full shapes: checking only the first dimension would let numpy
    # silently broadcast e.g. an (n, 1) array against an (n, n) one.
    assert np.shape(D1) == np.shape(D2)
    first = np.asarray(D1)
    n = first.shape[0]
    if n == 0:
        return []
    rows, cols = np.triu_indices(n)  # row-major upper triangle, diagonal included
    return np.abs(first - np.asarray(D2))[rows, cols].astype(float).tolist()

