import numpy as np
import os

import util

import gmm

def max_score():
    """Return the number of points awarded when this test passes."""
    points = 4
    return points

def timeout():
    """Return the time limit value for this test (60; units not stated here, presumably seconds)."""
    limit = 60
    return limit

def test():
    """Run the GMM toy-dataset check and return its score.

    Builds an 8-point 2-D dataset and fits it with ``gmm.learn_gmm``. The fit
    starts from fixed initial parameters with K=2 and max_iters=5. The learned
    per-cluster parameters are printed.

    The data and learned means/covariances are passed to ``util.plot_gmm``
    with a ``save_filename`` under ``figures/`` (presumably written there).
    The value of ``gmm.log_likelihood`` for the learned parameters is then
    compared against a reference value.

    Returns:
        tuple: ``(max_score(), 'PASS' + newline)`` when the log likelihood is
        within 1e-3 of the expected value.

    Raises:
        AssertionError: if the log likelihood differs from the expected value
            by 1e-3 or more (this includes a NaN result).
    """
    figures_directory = 'figures'
    os.makedirs(figures_directory, exist_ok=True)

    print('Testing GMM on simple 2-D dataset...')

    X = np.array([
            [1, 1], [5, 1], [7, 7], [13,7],
            [5, 3], [1, 3], [13, 13], [7,13]], dtype=np.float32)

    print('X:')
    print(X)

    # Initial variables and graphing code. DO NOT MODIFY.
    K = 2
    pi_list = [0.5, 0.5]
    mu_list = [np.array([1, 1]), np.array([13, 13])]
    Sigma_list = [np.array([[1, 0], [0, 1]]), np.array([[1, 0], [0, 1]])]

    print('Num clusters: {}'.format(K))

    max_iters = 5

    pi_list, mu_list, Sigma_list = gmm.learn_gmm(X, pi_list, mu_list, Sigma_list, max_iters)

    print("The final parameters are: ")
    for k in range(K):
        print("Cluster {}".format(k))
        print("\tpi: {}".format(pi_list[k]))
        print("\tmu: {}".format(mu_list[k]))
        print("\tSigma: [{}, {}; {}, {}]".format(Sigma_list[k][0, 0], Sigma_list[k][0, 1], Sigma_list[k][1, 0], Sigma_list[k][1, 1]))


    filename = '{}/gmm_toy2_K_{}_max_iter_{}'.format(figures_directory, K, max_iters)
    util.plot_gmm(X, mu_list, Sigma_list, show_figure=False, save_filename=filename)

    log_likelihood = gmm.log_likelihood(X, pi_list, mu_list, Sigma_list)
    expected_log_likelihood = -4.9746618
    tolerance = 1e-3

    # Raise explicitly instead of using `assert`. Assert statements are
    # removed under `python -O`, which would let a wrong result earn full
    # marks. Writing the check as `not (... < tolerance)` also fails on NaN,
    # because every ordered comparison with NaN is False.
    if not abs(log_likelihood - expected_log_likelihood) < tolerance:
        raise AssertionError(
            'Incorrect log likelihood found. Expected {}, found {}'.format(
                expected_log_likelihood, log_likelihood))

    test_score = max_score()
    test_output = 'PASS\n'

    return test_score, test_output

if __name__ == "__main__":
    # Script entry point. Runs the check once; the (score, output) pair that
    # test() returns is not used when invoked directly.
    _score, _output = test()

