import numpy as np
from math import sqrt, floor, ceil
import matplotlib.pyplot as plt
import os
from labels import imagenet_labels
import torch
import torchvision
from torchvision import transforms
from PIL import Image

# MobileNetV2 with torchvision's default pretrained weights, switched to
# eval (inference) mode. Shared by every function below.
model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
model.eval()

# Image -> tensor pipeline applied to each input image: Resize(256),
# CenterCrop(224), conversion to a tensor, then per-channel normalization
# with the mean/std listed here.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Every entry of the 'images' directory (relative to the current working
# directory), in sorted order. No filtering by file type is done here.
image_files = sorted(os.listdir('images'))
# Filled in by get_patterns(): image name (filename without its extension)
# -> flattened numpy array of the pooled features fed to the classifier.
feature_vecs = dict()

def get_patterns():
  """Classify every image in ``images/`` and record its pooled feature vector.

  For each name in the module-level ``image_files`` list this:

  * loads the image, forces it to RGB and runs it through ``preprocess``;
  * shows the original image next to the preprocessed tensor (clipped to
    [0, 1] for display) in figure 1, pausing one second;
  * runs the module-level ``model`` and stores the pooled feature vector
    (the classifier's input) in ``feature_vecs`` under the filename without
    its extension;
  * prints the five most probable labels from ``imagenet_labels``.

  Returns nothing; results are left in ``feature_vecs``.
  """
  # Moving the model is loop-invariant, so do it once up front.
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  model.to(device)

  for filename in image_files:
    # Force 3 channels: grayscale, palette or RGBA files would otherwise not
    # match the 3-channel mean/std in the Normalize step. convert() returns
    # a fully loaded copy, so the file handle can be closed right away.
    with Image.open(os.path.join('images', filename)) as raw_image:
      input_image = raw_image.convert('RGB')
    input_tensor = preprocess(input_image)

    plt.figure(1)
    plt.clf()
    plt.subplot(1,2,1)
    plt.imshow(input_image)
    plt.axis('off')
    # CHW -> HWC for imshow; normalized values fall outside [0, 1], so clip.
    img = input_tensor.detach().numpy().transpose(1,2,0)
    plt.subplot(1,2,2)
    plt.imshow(np.clip(img, 0.0, 1.0))
    plt.xlabel(filename)
    plt.pause(1)

    # Create a mini-batch of one, as expected by the model.
    input_batch = input_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
      features = model.features(input_batch)
      classifier_input = torch.nn.functional.adaptive_avg_pool2d(features, 1).reshape(features.shape[0], -1)
      # MobileNetV2.forward is features -> global average pool -> flatten ->
      # classifier, so feed the pooled features to the classifier head
      # instead of running the whole backbone a second time.
      output = model.classifier(classifier_input)
    probs = torch.nn.functional.softmax(output[0], dim=0)

    # splitext leaves extension-less names intact (slicing at rfind('.')
    # would drop their last character when no dot is present).
    name = os.path.splitext(filename)[0]
    feature_vecs[name] = classifier_input.reshape(-1).cpu().numpy()

    print('Top 5 categories:')
    # topk avoids sorting every class and does not assume exactly 1000 outputs.
    top_probs, top_indices = torch.topk(probs, min(5, probs.numel()))
    for p, i in zip(top_probs.tolist(), top_indices.tolist()):
      print('  %7.5f  %s' % (p, imagenet_labels[i]))


def show_patterns():
  """Plot every stored feature vector as a 10-row image strip in figure 2.

  The vectors in the module-level ``feature_vecs`` dict are drawn in name
  order, one subplot per vector stacked vertically, all sharing the same
  upper colour limit (the largest value found in any vector) so the strips
  are visually comparable. Prints that maximum and each (index, name) pair.

  Does nothing except print a notice when ``feature_vecs`` is empty.
  """
  if not feature_vecs:
    # Nothing to plot; without this guard the function would raise on empty input.
    print('No feature vectors to show.')
    return

  nfeats = len(feature_vecs)
  bigmax = max(f.max() for f in feature_vecs.values())
  print('bigmax=',bigmax)
  plt.figure(2)
  plt.clf()
  ordered = sorted(feature_vecs.items(), key=lambda item: item[0])
  for index,(name,feats) in enumerate(ordered, start=1):
    plt.subplot(nfeats,1,index)
    print(index,name)
    # 10 rows; the column count follows from the vector length
    # (10 x 128 for a 1280-element vector).
    plt.imshow(feats.reshape(10,-1),vmax=bigmax)
    plt.axis('off')
  plt.pause(0.001)

def show_diff(name1,name2):
  """Show the element-wise difference between two stored feature vectors.

  Looks up ``name1`` and ``name2`` in the module-level ``feature_vecs`` dict
  (raising ``KeyError`` if either is missing) and draws, in figure 1, the
  difference ``name1 - name2`` on top and ``name2 - name1`` below, each
  reshaped to a 10 x 128 image.
  """
  first = feature_vecs[name1]
  second = feature_vecs[name2]
  grid_shape = (10, 128)
  panels = (
    (first - second).reshape(grid_shape),
    (second - first).reshape(grid_shape),
  )

  plt.figure(1)
  plt.clf()
  for row, delta in enumerate(panels, start=1):
    plt.subplot(2, 1, row)
    plt.imshow(delta)
  plt.pause(0.001)

def show_matrix():
  """Print a table of pairwise Euclidean distances between feature vectors.

  Rows and columns are the entries of the module-level ``feature_vecs``
  dict in name order. Column headers are truncated to 10 characters and row
  labels to 8, so long names cannot push the numbers out of their
  fixed-width columns. Distances are printed with three decimals.
  """
  entries = sorted(feature_vecs.items(), key=lambda item: item[0])
  fvals = [vec for _, vec in entries]
  # Square root of the summed squared differences for every ordered pair.
  matrix = np.array([ [((x-y)**2).sum()**0.5 for y in fvals] for x in fvals ])

  print('Match errors:')
  # Blank corner cell above the row-label column.
  print('%10s' % '', end='')
  for label, _ in entries:
    print('%10s' % label[:10], end='')
  print()
  for (label, _), row in zip(entries, matrix):
    # Truncate to the 8-character field, mirroring the header truncation,
    # so the row stays aligned with the header.
    print('%8s  ' % label[:8], end='')
    for value in row:
      print('%10.3f' % value, end='')
    print()


# Script entry point (runs on import as well as on direct execution):
# extract features for every image, plot them, then print the pairwise
# distance table.
get_patterns()
show_patterns()
show_matrix()
