#!/usr/bin/env python3

import argparse
import copy
import filecmp
import glob
import json
import os
import os.path
import random
import shutil
import time
import sys
import numpy as np
import math
import struct
import itertools

from functools import wraps
import errno
import os
import signal

import pickle

from contextlib import contextmanager
import sys, os

# Directory holding this grader (symlinks resolved). The data files are
# looked up relative to it, and it is made importable.
autoDir = os.path.split(os.path.realpath(__file__))[0]
sys.path.insert(1, autoDir)

maxRunTime = 60  # Max number of seconds per sub-test.
eps = 1e-5  # Absolute tolerance when comparing against expected values.

def parse_images(filename):
    """Load an IDX image file and return its pixels scaled to [0, 1].

    The header is four big-endian int32 values: a magic number (read but
    unused), the image count, and the two image dimensions. Pixel data
    follows as one unsigned byte per pixel.

    Args:
        filename: path to the image file (e.g. ``train-images-idx3-ubyte``).

    Returns:
        float64 ``np.ndarray`` of shape ``(count, dim0 * dim1)``: one
        flattened image per row, each pixel divided by 255.

    Raises:
        ValueError: if the file holds fewer pixel bytes than the header says.
    """
    # ``with`` guarantees the handle is closed (it previously leaked).
    with open(filename, "rb") as fh:
        _magic, count = struct.unpack('>ii', fh.read(8))
        dim0, dim1 = struct.unpack('>ii', fh.read(8))
        n_pixels = dim0 * dim1
        raw = fh.read(count * n_pixels)

    if len(raw) != count * n_pixels:
        raise ValueError(
            "{}: expected {} pixel bytes, found {}".format(
                filename, count * n_pixels, len(raw)))

    # Vectorised uint8 -> float64 conversion. Dividing float64 by 255.0 gives
    # the same IEEE result as float(x) / 255.0 did in the old per-pixel loop.
    pixels = np.frombuffer(raw, dtype=np.uint8).astype(np.float64) / 255.0
    return pixels.reshape(count, n_pixels)

def parse_labels(filename, num_classes=10):
    """Load an IDX label file and return the labels one-hot encoded.

    The header is two big-endian int32 values: a magic number (read but
    unused) and the label count. One unsigned byte per label follows.

    Args:
        filename: path to the label file (e.g. ``train-labels-idx1-ubyte``).
        num_classes: width of the one-hot encoding (default 10, as before).
            A label outside ``range(num_classes)`` yields an all-zero row.

    Returns:
        float64 ``np.ndarray`` of shape ``(count, num_classes)``. Row ``i``
        has 1.0 in column ``label_i`` and 0.0 elsewhere.

    Raises:
        ValueError: if the file holds fewer label bytes than the header says.
    """
    # ``with`` guarantees the handle is closed (it previously leaked).
    with open(filename, "rb") as fh:
        _magic, count = struct.unpack('>ii', fh.read(8))
        raw = fh.read(count)

    if len(raw) != count:
        raise ValueError(
            "{}: expected {} label bytes, found {}".format(
                filename, count, len(raw)))

    labels = np.frombuffer(raw, dtype=np.uint8)
    # Broadcast (count, 1) against (1, num_classes) to build the one-hot matrix.
    return (labels[:, None] == np.arange(num_classes)[None, :]).astype(np.float64)

# The checks only need a small batch, so keep just the first 10 training
# examples (images are parsed before labels, as before).
X = parse_images(os.path.join(autoDir, "train-images-idx3-ubyte"))[:10]
y = parse_labels(os.path.join(autoDir, "train-labels-idx1-ubyte"))[:10]

# Reference results the sub-tests compare against. The same file also holds
# the starting weights/biases (initW, initb) handed to the student's nn code.
# NOTE(review): pickle.load can execute arbitrary code; this is only safe
# while expected.pkl is the trusted file shipped alongside the grader.
with open(os.path.join(autoDir, 'expected.pkl'), 'rb') as f:
    expected = pickle.load(f)

initW, initb = expected['initW'], expected['initb']

layer_sizes = [784, 200, 100, 10]


def f_relu(x):
    """ReLU activation.

    Returns ``(max(0, x), derivative)`` elementwise. The derivative is 1.0
    wherever ``x >= 0`` (so including ``x == 0``) and 0.0 elsewhere.
    """
    activation = np.maximum(0, x)
    derivative = (x >= 0).astype(np.float64)
    return activation, derivative


def f_lin(x):
    """Identity activation: returns ``(x, ones of x's shape)``."""
    return x, np.ones(x.shape)


# One activation per weight layer: ReLU for every hidden layer, then linear.
f = [f_relu for _ in layer_sizes[2:]] + [f_lin]

def main():
    """Grader entry point: import the student's ``cls`` module and score it.

    ``--solutionDir`` (default ``.``) names the directory containing
    ``cls.py``. If the import fails, an explanatory message is printed and
    the process exits with status -1. Otherwise every sub-test runs, and the
    scores are printed as a single JSON line ``{"scores": {...}}``.
    """
    global student

    parser = argparse.ArgumentParser()
    parser.add_argument('--solutionDir', default='.')
    args = parser.parse_args()

    print("=== Homework 2 ===")
    # Put the solution directory at the *front* of the search path. At index 1
    # it ranked behind sys.path[0] (this script's own directory when run
    # directly), so a cls.py sitting next to the grader would be imported and
    # graded instead of the student's.
    sys.path.insert(0, args.solutionDir)
    try:
        import cls as student
    except Exception as e:
        # Top-level boundary: any failure while importing student code is
        # reported with packaging advice rather than a raw traceback.
        print("""
The following error occurred when trying to load your 'cls' module:

------
{}
------

Check to see if your tgz was created properly without
a directory structure. Listing the contents like this
should result in the following output.
If you see a directory structure here, re-create a
tgz without the directory structure and submit again.

    $ tar tf handin.tgz
    writeup.pdf
    cls.py""".format(e))
        sys.exit(-1)

    # Dict literal entries are evaluated in order, so the tests run in order.
    scores = {
        'softmax_gd_theta': test_softmax_gd_theta(),
        'softmax_sgd_theta': test_softmax_sgd_theta(),
        'nn_z': test_nn_z(),
        'nn_loss': test_nn_loss(),
        'nn_params': test_nn_params(),
    }
    out = {'scores': scores}
    print(json.dumps(out))

def printTimeoutMsg():
    """Tell the student a sub-test ran longer than ``maxRunTime`` seconds."""
    banner = '-----'
    message_lines = [
        '',
        banner,
        '  + Your code timed out after {} seconds.'.format(maxRunTime),
        '    Try to find a bottleneck in your code and improve it.',
        '',
        '    You can run a line-profiler on your code like:',
        '    https://github.com/rkern/line_profiler',
        banner,
    ]
    print('\n'.join(message_lines))

class TimeoutError(Exception):
    """Raised by ``timeout`` when a guarded block exceeds its time limit.

    Note: this shadows the built-in ``TimeoutError`` within this module.
    """

class timeout:
    """Context manager that raises ``TimeoutError`` if its body runs too long.

    It is implemented with ``SIGALRM``. Python only allows installing signal
    handlers from the main thread. On platforms without ``SIGALRM`` the
    manager is a no-op instead of crashing. The SIGALRM handler that was
    active on entry is restored on exit.

    Args:
        seconds: time limit; fractional values are allowed, and ``0``
            disables the limit.
        error_message: message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: abort the guarded block."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        if not hasattr(signal, 'SIGALRM'):
            return self
        # signal.signal returns the handler being replaced, so it can be restored.
        self._previous_handler = signal.signal(signal.SIGALRM,
                                               self.handle_timeout)
        # setitimer accepts float seconds; signal.alarm only takes ints.
        signal.setitimer(signal.ITIMER_REAL, self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if not hasattr(signal, 'SIGALRM'):
            return False
        # Disarm the timer first so it cannot fire while the handler is swapped.
        signal.setitimer(signal.ITIMER_REAL, 0)
        previous = self._previous_handler
        # getsignal/signal report None for handlers not installed from Python;
        # SIG_DFL is the closest restorable value.
        signal.signal(signal.SIGALRM,
                      signal.SIG_DFL if previous is None else previous)
        return False  # never suppress exceptions

def test(tag, maxScore):
    """Decorator factory for zero-argument sub-tests that return a score.

    The wrapped test prints start/end banners and returns the score. Where
    SIGALRM exists, it runs under a ``maxRunTime``-second ``timeout``. A
    timeout, or any exception escaping the student's code, scores 0 and is
    reported. It no longer aborts the whole run, which used to lose the final
    JSON score line.

    Args:
        tag: sub-test name shown in the banners.
        maxScore: maximum achievable score, shown in the closing banner.
    """
    def test_decorator(func):
        @wraps(func)
        def func_wrapper():
            print("\n=== Start {} Test ===".format(tag))
            try:
                score = _run_with_time_limit(func, maxRunTime)
            except TimeoutError:
                printTimeoutMsg()
                score = 0
            except Exception:
                import traceback
                # Show the student's error alongside the rest of the output.
                traceback.print_exc(file=sys.stdout)
                score = 0
            print("=== End {} Test. Score: {}/{} ===\n".format(tag, score, maxScore))
            return score

        return func_wrapper
    return test_decorator


def _run_with_time_limit(func, seconds):
    """Call ``func()`` under ``timeout(seconds)``; without SIGALRM, call it directly."""
    if not hasattr(signal, 'SIGALRM'):
        return func()
    with timeout(seconds):
        return func()


@test('softmax_gd_theta', 15)
def test_softmax_gd_theta():
    """Award 15 points if ``student.softmax_gd`` reproduces the reference Theta.

    Runs one epoch with ``alpha=0.5`` on the small ``X``/``y`` batch, passing
    that batch for both data argument pairs. The result must be within
    ``eps`` (L2 norm) of ``expected['softmax_gd_theta']``.
    """
    theta = student.softmax_gd(X, y, X, y, epochs=1, alpha=0.5)
    error = np.linalg.norm(expected['softmax_gd_theta'] - theta)
    return 15 if error <= eps else 0

@test('softmax_sgd_theta', 15)
def test_softmax_sgd_theta():
    """Award 15 points if ``student.softmax_sgd`` reproduces the reference Theta.

    Runs one epoch with ``alpha=0.01`` on the small ``X``/``y`` batch, passing
    that batch for both data argument pairs. The result must be within
    ``eps`` (L2 norm) of ``expected['softmax_sgd_theta']``.
    """
    theta = student.softmax_sgd(X, y, X, y, epochs=1, alpha=0.01)
    error = np.linalg.norm(expected['softmax_sgd_theta'] - theta)
    return 15 if error <= eps else 0

def all_equal(expected, vals, tol=None):
    """Return True iff ``vals`` matches ``expected`` item-by-item within ``tol``.

    Each argument is a sequence of array-likes. Both are flattened one level
    with ``itertools.chain`` (1-D arrays yield scalars, 2-D arrays yield
    rows). Paired items are then compared by the L2 norm of their difference.

    Args:
        expected: reference sequence.
        vals: candidate sequence (e.g. the student's result).
        tol: absolute tolerance. ``None`` (the default) uses the module-level
            ``eps``, as before.

    Returns:
        False if the flattened sequences differ in length, if either cannot be
        flattened (e.g. ``None``), if paired items cannot be subtracted
        (incompatible shapes), or if any pair differs by more than ``tol``.
        Otherwise True.
    """
    if tol is None:
        tol = eps
    try:
        vals_flat = list(itertools.chain(*vals))
        exp_flat = list(itertools.chain(*expected))
        # zip() stops at the shorter input. Without this check, a missing or
        # empty result (e.g. []) compared nothing and counted as correct.
        if len(vals_flat) != len(exp_flat):
            return False
        return all(np.linalg.norm(exp_i - i) <= tol
                   for exp_i, i in zip(exp_flat, vals_flat))
    except (TypeError, ValueError):
        # Non-iterable results or non-broadcastable shapes: a mismatch, not a crash.
        return False

@test('nn_z', 5)
def test_nn_z():
    """Award 5 points if ``student.nn`` on the first example matches ``expected['nn_z']``.

    The forward pass uses the shared ``initW``/``initb`` and activations ``f``.
    The comparison is done with ``all_equal``.
    """
    z_values = student.nn(X[0], initW, initb, f)
    return 5 if all_equal(expected['nn_z'], z_values) else 0

@test('nn_loss', 5)
def test_nn_loss():
    """Score ``student.nn_loss`` on the first example (5 points total).

    ``student.nn_loss`` is called on the first example with the shared
    ``initW``/``initb`` and activations ``f``. Points:
    - 1 point if the loss ``L`` is within ``eps`` of ``expected['nn_loss']['L']``;
    - 2 points if ``dW`` matches ``expected['nn_loss']['dW']`` (``all_equal``);
    - 2 points if ``db`` matches ``expected['nn_loss']['db']`` (``all_equal``).
    """
    L, dW, db = student.nn_loss(X[0], y[0], initW, initb, f)
    e_loss = expected['nn_loss']

    score = 0
    # Compare the magnitude of the difference. The old signed check
    # ``e_loss['L'] - L <= eps`` accepted any loss larger than the reference.
    # The norm (as in the other tests) also handles 0-d / size-1 arrays.
    if np.linalg.norm(e_loss['L'] - L) <= eps:
        score += 1
    if all_equal(e_loss['dW'], dW):
        score += 2
    if all_equal(e_loss['db'], db):
        score += 2
    return score

@test('nn_params', 10)
def test_nn_params():
    """Score one epoch of ``student.nn_sgd`` (10 points total).

    The student's update is read from the ``W``/``b`` lists passed in; the
    return value is ignored. Weights and biases earn 5 points each if they
    match ``expected['learnedW']`` / ``expected['learnedb']`` via ``all_equal``.
    """
    # Work on copies so in-place updates never touch the module-level initW/initb.
    W = [layer.copy() for layer in initW]
    b = [layer.copy() for layer in initb]
    # nn_sgd should update W and b in-place.
    student.nn_sgd(X, y, X, y, W, b, f, epochs=1, alpha=0.01)

    weight_points = 5 if all_equal(expected['learnedW'], W) else 0
    bias_points = 5 if all_equal(expected['learnedb'], b) else 0
    return weight_points + bias_points

if __name__ == '__main__':
    # Grade only when run as a script, not when this module is imported.
    main()
