""" Week12-1 Notes """

import csv
# Header:
# Semester,
# #1 Orig,     #2 Orig,     #3 Orig,
# #1 Cleaned,  #2 Cleaned,  #3 Cleaned,
# #1 Category, #2 Category, #3 Category
# Load the survey rows, keeping only the semester (column 0) and the three
# category columns (indices 7-9 in the header layout described above).
# newline="" is how the csv module expects files to be opened so that
# newlines embedded in quoted fields are parsed correctly; the with-block
# guarantees the file is closed even if parsing raises.
with open("all-icecream.csv", "r", newline="") as f:
    orig = list(csv.reader(f))

allData = [
    [line[0]] + line[7:10]
    for line in orig
    # Skip blank rows (csv.reader yields [] for them) and the header row.
    if line and line[0] != "Semester"
]

def mapToIndex(flavor):
    """Return the integer code (0-5) for a flavor category name.

    Codes follow the fixed order: chocolate, coffee/tea, cookie, fruit,
    vanilla, other. Raises ValueError if the name is not one of these.
    """
    knownCategories = [
        "chocolate",
        "coffee/tea",
        "cookie",
        "fruit",
        "vanilla",
        "other",
    ]
    position = knownCategories.index(flavor)
    return position

# Split the encoded rows into a training set and a held-out test set.
# Each row of allData is [semester, category1, category2, category3]:
# the first two categories are the model inputs, the third is the label.
dataInput, dataLabels = [], []
testInput, testLabels = [], []
for row in allData:
    term = row[0]
    features = [mapToIndex(category) for category in (row[1], row[2])]
    label = mapToIndex(row[3])
    # Rows from semester "S22" are reserved for testing; all others train.
    if term != "S22":
        dataInput.append(features)
        dataLabels.append(label)
    else:
        testInput.append(features)
        testLabels.append(label)

from sklearn.naive_bayes import CategoricalNB

# Fit a categorical Naive Bayes classifier on the training rows
# (fit() returns the estimator itself, so the calls can be chained).
model = CategoricalNB().fit(dataInput, dataLabels)

# Predict the third category for one sample whose first two category
# codes are 0 and 2 (see mapToIndex for the code order).
sampleQuery = [[0, 2]]
print(model.predict(sampleQuery)) # [0]

# Score the model on the held-out test rows.
accuracy = model.score(testInput, testLabels)
print(accuracy) # 0.26174496644295303


