import csv
import statistics

def readData(filename):
    """Read a UTF-8 CSV file and return every row, header included.

    Args:
        filename: Path to the CSV file to read.

    Returns:
        A list of rows, each a list of strings. Row 0 is the header row;
        the other functions in this file look up the '#1 category' column in it.
    """
    # Expected layout: semester, 3 original answers, 3 cleaned answers,
    # 3 categories — TODO confirm against the actual CSV header.
    # newline="" is what the csv module requires so quoted fields that contain
    # line breaks are parsed correctly; `with` closes the file handle promptly
    # instead of leaking it until garbage collection.
    with open(filename, "r", encoding="utf8", newline="") as f:
        return list(csv.reader(f))

def getFlavorCounts(data, flavor):
    """Return, for each data row, how many category cells equal `flavor`.

    Args:
        data: Rows as returned by readData; row 0 is the header and must
            contain a '#1 category' column.
        flavor: The category value to count.

    Returns:
        A list with one int per non-header row. Each int counts the cells from
        the '#1 category' column to the end of that row that equal `flavor`.

    Raises:
        ValueError: if the header has no '#1 category' column.
    """
    header = data[0]
    firstCategoryCol = header.index('#1 category')
    return [row[firstCategoryCol:].count(flavor) for row in data[1:]]

def getCondProb(data, givenFlavor, probFlavor):
    """Estimate P(#2 category == probFlavor | #1 category == givenFlavor).

    The #2 category is assumed to sit in the column immediately after
    '#1 category'.

    Args:
        data: Rows as returned by readData; row 0 is the header.
        givenFlavor: Value the '#1 category' cell must equal for a row to count.
        probFlavor: Value of the #2 cell whose relative frequency is returned.

    Returns:
        A float in [0, 1]: among the rows whose #1 category is `givenFlavor`,
        the fraction whose #2 category is `probFlavor`.

    Raises:
        ValueError: if the header has no '#1 category' column.
        ZeroDivisionError: if no row has `givenFlavor` as its #1 category, in
            which case the conditional probability is undefined. This is the
            same exception type as before, now with an explanatory message.
    """
    firstCol = data[0].index('#1 category')
    secondCol = firstCol + 1
    # Rows too short to hold a #2 entry are skipped instead of raising IndexError.
    seconds = [
        row[secondCol]
        for row in data[1:]
        if len(row) > secondCol and row[firstCol] == givenFlavor
    ]
    if not seconds:
        raise ZeroDivisionError(
            f"no rows with #1 category {givenFlavor!r}; "
            "conditional probability is undefined"
        )
    return seconds.count(probFlavor) / len(seconds)

# Load the survey data for the summary statistics printed below.
data = readData("all-icecream.csv")
# Per-row count of "fruit" cells from the '#1 category' column onward;
# print the mean of those counts.
counts = getFlavorCounts(data, "fruit")
print(statistics.mean(counts))
# Among rows whose #1 category is "chocolate", the fraction whose #2 is "cookie".
prob = getCondProb(data, "chocolate", "cookie")
print(prob)

###

import matplotlib.pyplot as plt
# Standalone scatter-plot demo on hard-coded sample points; it does not use
# the ice-cream data. NOTE(review): the "Empty" title looks like a placeholder.
plt.title("Empty")
x = [2, 4, 5, 7, 7, 9]
y = [3, 5, 4, 6, 9, 7]
plt.scatter(x, y)
plt.show()

###

from collections import Counter

# Pie chart of how often each value appears in the '#1 category' column.
data = readData("all-icecream.csv")
firstCol = data[0].index("#1 category")
numberOneData = [row[firstCol] for row in data[1:]]

# Counter remembers first-seen order, so the slices keep the order in which
# each flavor first appears in the file. This single O(n) pass replaces a
# list-membership test plus a per-flavor .count() scan, which was O(n * flavors).
flavorCounts = Counter(numberOneData)
flavors = list(flavorCounts)
counts = list(flavorCounts.values())

plt.pie(counts, labels=flavors)
plt.show()