import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.models import MobileNetV2
from PIL import Image

# Load MobileNetV2 with torchvision's default pretrained weights and put it in
# evaluation mode.
model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
model.eval()

# Input pipeline applied to every image before it is fed to the network:
# resize, center-crop to 224, convert to a tensor, then normalize per channel.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Take the module at index 3 of the flattened module list; the rest of the
# script treats it as the network's first layer and reads its `.weight`.
# NOTE(review): presumably this is the stem Conv2d (model.features[0][0]);
# the index depends on torchvision's module nesting -- confirm for the
# installed version.
layer1 = list(model.modules())[3]
layer1_weight = layer1.weight
# Report the overall weight range of that layer.
print('min=',layer1_weight.min().item(),'  max=',layer1_weight.max().item())
# Weights as a NumPy array (detached from autograd) for the plotting helpers.
l1 = layer1_weight.detach().numpy()
 

def get_image_maps(filename):
  """Run one image file through ``layer1`` and return the resulting maps.

  The image is opened, forced to 3-channel RGB, passed through the
  module-level ``preprocess`` pipeline and fed to ``layer1`` with gradient
  tracking disabled.

  Args:
    filename: Path (or file object) accepted by ``PIL.Image.open``.

  Returns:
    The tensor produced by ``layer1`` for a batch holding this single image
    (the leading dimension has size 1).
  """
  # The context manager closes the underlying file deterministically instead
  # of leaving it to garbage collection.  convert('RGB') lets grayscale,
  # RGBA, palette and CMYK files through: the normalization and the layer's
  # kernels are both defined for exactly three channels.
  with Image.open(filename) as input_image:
    input_tensor = preprocess(input_image.convert('RGB'))
  input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
  with torch.no_grad():
    image_maps = layer1(input_batch)
  return image_maps

def get_3kernel(i):
  """Render kernel ``i`` of ``l1`` as a 3x11 RGB strip, one tile per input channel.

  The kernel is rescaled symmetrically around zero: a weight of 0 maps to
  0.5 and the largest absolute weight maps to 0 or 1.  Tile ``c`` occupies
  columns ``4*c`` to ``4*c + 2``; inside it the scaled weight drives colour
  channel ``c`` and its complement drives the other two colour channels.
  Columns 3 and 7 are separators with colour ``[1, 1, 0.5]``.

  Args:
    i: Index of the kernel in ``l1``.

  Returns:
    A float array of shape ``(3, 11, 3)`` suitable for ``plt.imshow``.
  """
  kernel = l1[i].transpose([1, 2, 0])  # -> (row, column, input channel)
  peak = np.max(np.abs(kernel))
  scaled = (kernel + peak) / (2 * peak)

  strip = np.full((3, 11, 3), 0.5)

  for channel in range(3):
    left = 4 * channel
    plane = scaled[:, :, channel]
    complement = 1 - plane
    for colour in range(3):
      # Own colour shows the weight itself, the other two show its inverse.
      strip[:, left:left + 3, colour] = plane if colour == channel else complement

  # Separator columns between the three tiles.
  for column in (3, 7):
    strip[:, column, :] = [1, 1, 0.5]

  return strip

def show1():
  """Draw figure 1: 32 kernels from ``get_3kernel`` as strips in a 6x6 grid.

  Grid cells 1-32 each hold one strip from ``get_3kernel``; cell 33 carries
  the caption text.  Ends with a short ``plt.pause`` so the figure is drawn.
  """
  kernel_count = 32

  plt.figure(1)
  plt.clf()
  for cell in range(1, kernel_count + 1):
    plt.subplot(6, 6, cell)
    plt.imshow(get_3kernel(cell - 1))
    plt.axis('off')

  # Caption goes into the first grid cell after the kernels.
  plt.subplot(6, 6, kernel_count + 1)
  plt.text(0, 0, 'MobileNetV2: Layer 1 Kernels (R,G,B channels)')
  plt.axis('off')
  plt.pause(0.001)

def show2():
  """Draw figure 2: 32 kernels from ``l1``, each as one small colour image.

  Each kernel is rescaled symmetrically around zero (0 maps to 0.5, the
  largest absolute weight to 0 or 1) and shown with its three input
  channels interpreted as R, G and B.  Grid cells 1-32 hold the kernels;
  cell 34 carries the caption text.  Ends with a short ``plt.pause`` so the
  figure is drawn.
  """
  plt.figure(2)
  plt.clf()
  for i in range(32):
    plt.subplot(6,6,1+i)
    w = l1[i].transpose([1,2,0])  # -> (row, column, input channel)
    # Symmetric scaling keeps zero weights at mid-grey for every kernel.
    umax = np.max(np.abs(w))
    unorm = (w+umax) / (2*umax)
    plt.imshow(unorm)
    plt.axis('off')
  plt.subplot(6,6,34)
  plt.text(0,0,'MobileNetV2: Layer 1 Kernels (Color)')
  plt.axis('off')
  plt.pause(0.001)

  
# Draw the two kernel overview figures.
show1()
show2()

# First-layer maps for three sample images.  They are not referenced again in
# this part of the file -- presumably kept for interactive inspection.
auto1_maps = get_image_maps('./images/auto1.jpg')
ignatz1_maps = get_image_maps('./images/ignatz1s.jpg')
juices1_maps = get_image_maps('./extra/juices1.jpg')

# Image used for the per-kernel walkthrough below.  Only the last assignment
# takes effect; reorder these lines to switch images.
filename = './images/auto1.jpg'
filename = './images/dog2.jpg'
filename = './extra/juices1.jpg'

# Repeat the steps of `preprocess` one at a time so the intermediate images
# stay available for display.
raw_image = Image.open(filename)
resized_image = transforms.Resize(256)(raw_image)
cropped_image = transforms.CenterCrop(224)(resized_image)
tensor_image = transforms.ToTensor()(cropped_image)
# Normalize, convert to NumPy and move the first axis last for imshow.
normalized_image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor_image).numpy().transpose([1,2,0])
# Min-max rescale to [0, 1] so the normalized image can be shown as colours.
normalized_display = (normalized_image - normalized_image.min()) / (normalized_image.max() - normalized_image.min())

# Maps for the selected image; also the default `maps` of show_image_map.
image_maps = get_image_maps(filename)

def show_image_map(index, maps=image_maps):
  """Draw figure 3: input image, normalized image, kernel and its output map.

  The 2x2 figure shows the module-level ``cropped_image`` and
  ``normalized_display``, the strip from ``get_3kernel(index)``, and channel
  ``index`` of ``maps`` in grey scale.  The two image panels always come from
  the module-level globals, whatever ``maps`` is passed.

  Args:
    index: Kernel / output-channel number to display.
    maps: Layer output indexed as ``maps[0, index, :, :]``.  The default is
      the ``image_maps`` object that existed when this function was defined.
  """
  selected_map = maps[0, index, :, :]

  def draw_panel(position, picture, caption, cmap=None, **subplot_options):
    # One cell of the 2x2 grid: image without axes plus a title.
    plt.subplot(2, 2, position, **subplot_options)
    plt.imshow(picture, cmap=cmap)
    plt.axis('off')
    plt.title(caption)

  plt.figure(3)
  plt.clf()
  draw_panel(1, cropped_image, 'Cropped', label=1)
  draw_panel(2, normalized_display, 'Normalized', label=2)
  draw_panel(3, get_3kernel(index), 'Kernel %d' % index)
  draw_panel(4, selected_map, 'Image Map %d' % index, cmap='gray')
  plt.pause(1)

def show3():
  """Step through maps 0-31 of ``image_maps``, waiting for Enter after each."""
  prompt = 'Press Enter to continue...'
  for channel in range(32):
    show_image_map(channel, maps=image_maps)
    input(prompt)

# Start the interactive walkthrough; blocks on console input between maps.
show3()
