"""
Learning a color class using a Support Vector Machine.

Mouse click to select training points: left click for positive
examples, right click for negative.  You can click in either
window.  Yellow and cyan lines in the source window mark
positive and negative examples, respectively.
"""

import cv2
import numpy as np
from matplotlib import pyplot as plt
from sklearn import svm

# Nu-support vector classifier; retrain() refits it on every click once both
# classes have at least one example.
classifier = svm.NuSVC()

# cv2.imread returns None (instead of raising) when the file is missing or
# unreadable; fail with a clear message rather than a cryptic cv2.error
# raised later by cvtColor.
image_bgr = cv2.imread('sample_image.jpg')
if image_bgr is None:
    raise FileNotFoundError("cannot read image file 'sample_image.jpg'")
# OpenCV loads images as BGR; matplotlib displays RGB.
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
# np.array(...) copies, so the click markers later drawn into ``image`` do
# not leak into the training features or into the segmented output.
image_scaled = np.array(image, dtype=np.float64)/256
pixels = np.reshape(np.array(image), (-1,3))
pixels_scaled = pixels.astype(np.float64)/256

# Figure 1 shows the source image (with click markers); figure 2 receives
# the segmentation drawn by retrain().
fig1 = plt.figure(1)
plt.imshow(image)

fig2 = plt.figure(2)

# Training examples indexed by mouse-button number: data[1] collects
# left-click (positive) pixels and data[3] right-click (negative) pixels.
data = [[], [], [], []]

def onclick(event):
    """Record the clicked pixel as a training example and retrain.

    A left click (button 1) adds a positive example and a right click
    (button 3) a negative one; any other button is ignored.  The pixel's
    scaled RGB value is appended to ``data``, a short line in the label's
    color (yellow for positive, cyan for negative) is drawn into the
    source image, and ``retrain()`` is called.

    Clicks outside any axes, or inside the axes but off the image (e.g.
    after panning), are ignored.
    """
    # retrain() only reads data[1] and data[3]; other buttons (middle,
    # back/forward = 8/9) were either stored uselessly or overflowed ``data``.
    if event.button not in (1, 3):
        return
    # xdata/ydata are None when the click is not over any axes, e.g. in the
    # still-empty result window before the first training run.
    if event.xdata is None or event.ydata is None:
        return
    col = int(round(event.xdata))
    row = int(round(event.ydata))
    height, width = image_scaled.shape[:2]
    # Rounding at the far border can yield width/height, and a panned view
    # can give negative coordinates that numpy would silently wrap around.
    if not (0 <= row < height and 0 <= col < width):
        return
    pixel = image_scaled[row, col]
    data[event.button].append(pixel.tolist())
    line_color = (255, 255, 0) if event.button == 1 else (0, 255, 255)
    # 3-pixel marker to the right of the click; the slice stops at the image
    # border instead of raising IndexError.
    image[row, col:col + 3] = line_color
    plt.figure(1)
    plt.imshow(image)
    plt.show(block=False)
    retrain()

def retrain():
    """Refit the classifier on all clicked examples and redraw the result.

    Does nothing until at least one positive (left-click) and one negative
    (right-click) example exist.  Otherwise fits ``classifier`` on the
    scaled RGB values of the examples, classifies every pixel of the source
    image, and shows in figure 2 the pixels predicted positive in their
    original color on a white background.

    The training matrix ``X``, labels ``Y`` and rendered ``result`` image
    are also stored as module globals (presumably for interactive
    inspection).
    """
    global X, Y, result
    positives, negatives = data[1], data[3]
    # A two-class fit needs at least one example of each class.
    if not positives or not negatives:
        return
    X = np.array(positives + negatives, dtype=np.float64)
    Y = np.array([1] * len(positives) + [0] * len(negatives), dtype=np.float64)
    try:
        classifier.fit(X, Y)
    except ValueError as err:
        # NuSVC rejects a nu that is infeasible for the current class balance
        # (e.g. one positive against many negatives); keep the previous
        # result and wait for more examples instead of crashing the callback.
        print('cannot train yet: %s' % err)
        return
    prediction = classifier.predict(pixels_scaled)
    # White background; copy original colors for pixels predicted positive.
    result = np.full((len(prediction), 3), 255, dtype=np.uint8)
    keep = prediction > 0
    result[keep] = pixels[keep]
    # Use the source image's dimensions rather than a hard-coded 240x320.
    result = np.reshape(result, image.shape)
    plt.figure(2)
    plt.imshow(result)
    plt.show(block=False)
    print('%d training points; %d support vectors' %
          (len(positives) + len(negatives), len(classifier.support_)))

# Clicks are accepted in either window; the result window displays an image
# with the same pixel layout as the source.
fig1.canvas.mpl_connect('button_press_event', onclick)
fig2.canvas.mpl_connect('button_press_event', onclick)

# With the default block=None, show() runs the GUI event loop until the
# windows are closed when run as a plain script, and returns immediately in
# an interactive session.  block=False made `python script.py` exit at once,
# closing the windows before any click could be handled.
plt.show()
