import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.models import MobileNetV2
from PIL import Image

# Build MobileNetV2 with torchvision's default pretrained weights, place it on
# the CPU and switch it to evaluation mode. Both .to() and .eval() return the
# module itself, so the calls can be chained into a single expression.
model = (
    torchvision.models
    .mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
    .to('cpu')
    .eval()
)

# Image -> tensor pipeline applied before the network sees an image:
# resize, center-crop to 224, convert to a tensor, then normalize each of the
# three channels with the per-channel mean/std below (these look like the
# usual ImageNet statistics).
# NOTE(review): the DEFAULT weights also expose their own preset through
# `weights.transforms()`; confirm that the resize size used here matches it.
preprocess = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def get_features(filename):
  """Return the pooled MobileNetV2 feature vector for one image file.

  The image is decoded, converted to three-channel RGB, passed through the
  module-level ``preprocess`` pipeline and then through ``model.features``.
  The resulting feature map is globally average-pooled and flattened, which
  is the value ``MobileNetV2.forward()`` hands to its classifier.

  Args:
    filename: Path (or file object) accepted by ``PIL.Image.open``.

  Returns:
    A 2-D tensor of shape ``(1, C)``: one row for the single image in the
    batch, ``C`` being the channel count produced by ``model.features``.

  Raises:
    Whatever ``PIL.Image.open`` raises for a missing or undecodable file
    (e.g. ``FileNotFoundError``).
  """
  # Open in a context manager so the underlying file handle is released
  # promptly. Convert to RGB because the normalization step is defined for
  # exactly three channels; grayscale, palette, RGBA or CMYK inputs would
  # otherwise fail with a channel-count mismatch.
  with Image.open(filename) as input_image:
    input_tensor = preprocess(input_image.convert('RGB'))
  input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
  with torch.no_grad():
    x = model.features(input_batch)
    # The next line comes from MobileNetV2.forward(): global average pool to
    # 1x1, then flatten everything except the batch dimension.
    classifier_input = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
  return classifier_input

# Compute one feature vector per sample image, in the order listed, and bind
# them to module-level names so they can be inspected afterwards.
f1, f3, dg, a1, a2 = [
    get_features(image_path)
    for image_path in (
        'images/ignatz1s.jpg',
        'images/ignatz3s.jpg',
        'images/dog1.jpg',
        'images/auto1.jpg',
        'images/auto2.jpg',
    )
]

print('Look at variables f1, f3, dg, a1, a2 for feature vector representations.')

