import os
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt

def load_snaps(directory='snaps'):
    """Load every decodable image in *directory* as a single-channel grayscale array.

    Files are visited in sorted filename order so the resulting list (and hence
    the generated training set) is reproducible; ``os.listdir`` order is
    arbitrary. Entries OpenCV cannot decode (``cv2.imread`` returns ``None``,
    e.g. sub-directories or stray non-image files) are skipped instead of
    crashing ``cv2.cvtColor``.

    Args:
        directory: folder to read, relative to the current working directory
            by default (``'snaps'``).

    Returns:
        list of grayscale images, one per decodable file.
    """
    images = []
    for name in sorted(os.listdir(directory)):
        bgr = cv2.imread(os.path.join(directory, name))
        if bgr is None:
            continue
        # cv2.imread stores colour channels in B, G, R order, so BGR2GRAY (not
        # RGB2GRAY) is needed for the luminance weights to hit the right channels.
        images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY))
    return images

def load_cubes(directory='cube_png'):
    """Load every decodable image in *directory* as a 3-channel (BGR) array.

    Files are read in sorted filename order for reproducibility. Entries that
    OpenCV cannot decode (``cv2.imread`` returns ``None``) are skipped. Otherwise
    they would be returned as ``None`` and fail much later, inside ``generate``.

    Args:
        directory: folder to read, relative to the current working directory
            by default (``'cube_png'``).

    Returns:
        list of colour images as returned by ``cv2.imread``.
    """
    images = []
    for name in sorted(os.listdir(directory)):
        image = cv2.imread(os.path.join(directory, name))
        if image is not None:
            images.append(image)
    return images

# Source images are loaded once, at import time; generate() reads these module globals.
snaps, cubes = load_snaps(), load_cubes()

def _shift_cube(cube, dy, dx, height, width):
    """Return a (height, width, 3) uint8 canvas with *cube* pasted at offset (dy, dx).

    The background is the key colour (255, 0, 255) that ``generate`` treats as
    transparent. The cube's top-left corner lands at row ``dy``, column ``dx``.
    Either offset may be negative, and any part of the cube falling outside the
    canvas is cropped. Cubes smaller than the canvas are also accepted.
    """
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    canvas[:, :, 0] = 255
    canvas[:, :, 2] = 255
    cube_h, cube_w = cube.shape[:2]
    # Overlap between the canvas and the shifted cube, in canvas coordinates.
    top, bottom = max(dy, 0), min(dy + cube_h, height)
    left, right = max(dx, 0), min(dx + cube_w, width)
    if top < bottom and left < right:
        canvas[top:bottom, left:right] = cube[top - dy:bottom - dy, left - dx:right - dx]
    return canvas


def generate(snap_images=None, cube_images=None, nsamples=5, max_shift=(80, 120),
             all_quadrants=False):
    """Build a binary "cube present" training set by compositing cubes onto snaps.

    For every snap, each cube is pasted ``nsamples`` times at a random offset and
    labelled 1. Pixels of the shifted cube that still carry the key colour
    (channel 0 >= 254, channel 1 == 0, channel 2 >= 254) are transparent, so the
    snap shows through there. The untouched snap is then added once per positive
    produced for it, labelled 0, so the two classes stay balanced.

    Args:
        snap_images: grayscale background images (2-D arrays); defaults to the
            module-level ``snaps``.
        cube_images: 3-channel cube images; defaults to the module-level ``cubes``.
        nsamples: random placements per (snap, cube) pair.
        max_shift: exclusive bound ``(rows, cols)`` on the random offset.
        all_quadrants: if False (default), offsets are drawn from ``[0, bound)``,
            so the cube only moves down/right. If True, they are drawn from
            ``[-bound, bound)``, so the cube can move in every direction.

    Returns:
        dict with ``'train_data'`` (stacked images) and ``'train_labels'``
        (0/1 labels), both numpy arrays.
    """
    if snap_images is None:
        snap_images = snaps
    if cube_images is None:
        cube_images = cubes
    max_dy, max_dx = max_shift
    images = []
    labels = []
    for snap in snap_images:
        height, width = snap.shape[:2]
        # Same visiting order as iterating over ``cubes * nsamples``.
        for _ in range(nsamples):
            for cube in cube_images:
                if all_quadrants:
                    # floor() keeps the distribution uniform over [-bound, bound);
                    # int() would truncate toward zero and double-weight 0.
                    dy = int(np.floor((2 * np.random.rand() - 1) * max_dy))
                    dx = int(np.floor((2 * np.random.rand() - 1) * max_dx))
                else:
                    dy = int(np.random.rand() * max_dy)
                    dx = int(np.random.rand() * max_dx)
                shifted_cube = _shift_cube(cube, dy, dx, height, width)
                opaque = ~((shifted_cube[:, :, 0] >= 254)
                           & (shifted_cube[:, :, 1] == 0)
                           & (shifted_cube[:, :, 2] >= 254))
                image = snap.copy()
                # NOTE(review): channel 0 of the cube is used as its grey level,
                # while snaps go through a real grayscale conversion. Confirm the
                # cube renders are effectively grey so this matches.
                image[opaque] = shifted_cube[:, :, 0][opaque]
                images.append(image)
                labels.append(1)
                print(len(images))
        # Negatives: the bare snap, once per positive generated above (balances classes).
        n_positive = len(cube_images) * nsamples
        images.extend([snap] * n_positive)
        labels.extend([0] * n_positive)
    return {'train_data': np.array(images),
            'train_labels': np.array(labels)}

# Generate and save the training set only when run as a script, so importing
# this module (e.g. to reuse generate()) does not overwrite trainset.pt.
if __name__ == "__main__":
    d = generate()
    torch.save(d, 'trainset.pt')
    print('saved trainset.pt')
