import pickle
import sys
import numpy as np
import matplotlib.pyplot as plt
import numbers
import matplotlib.pylab as pylab
tsize = 14.5

# Global matplotlib styling for every figure produced by this script.
# pylab.rcParams and plt.rcParams are the same global object, so one update
# is enough.
params = {
    # One shared font size for axes titles, axis labels and legend text.
    'axes.titlesize': tsize,
    'axes.labelsize': tsize,
    'legend.fontsize': tsize,
    'legend.frameon': False,
    # Render all text through LaTeX.
    'text.usetex': True,
    # With text.usetex=True matplotlib only honours the generic families
    # (serif, sans-serif, cursive, monospace) or fonts it has a LaTeX package
    # for. "Times New Roman" is not one of them, so it is ignored and the
    # default serif (Computer Modern) is used. Ask for Times via the serif
    # list instead; 'Times' maps to the mathptmx LaTeX package.
    'font.family': 'serif',
    'font.serif': ['Times', 'Times New Roman'],
}
pylab.rcParams.update(params)


def _load_pickle(path):
    """Open *path* in binary mode and return the unpickled object.

    pickle can execute arbitrary code while loading, so only pass files
    produced by this pipeline (trusted input).
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Command-line inputs, read in the order 3, 4, 5, 1, 2:
#   argv[3]: doubleCounts-domArchs.pkl  (pair counts)
#   argv[4]: alphabet.pkl               (domain vocabulary)
#   argv[5]: doubleEndCount.pkl         (per-first-domain totals)
#   argv[1]: genuine architectures
#   argv[2]: simulated architectures
doublecount = _load_pickle(sys.argv[3])
alphabet = _load_pickle(sys.argv[4])
endc = _load_pickle(sys.argv[5])
genuine = _load_pickle(sys.argv[1])
randa = _load_pickle(sys.argv[2])

# glen = np.sort([len(da) for da in genuine])
# slen = np.sort([len(da) for da in randa])
# nda = len(genuine)
# thresh = int(0.99*nda)
# gthresh = glen[thresh]
# sthresh = slen[thresh]


# Vocabulary size: every domain in the alphabet plus the '0000000'
# start/end marker.
vsize = len(alphabet) + 1

# Pseudocount added to every possible domain pair.
phi = 0.0009

# Smoothed denominator for each first domain: its count from endc plus one
# pseudocount for each of the vsize possible successors.
# NOTE(review): endc keys are assumed to be tuples whose element [1] is the
# first domain; keys sharing the same [1] overwrite each other — confirm.
endc2 = {}
for key, count in endc.items():
    endc2[key[1]] = count + vsize * phi
# The ('0000000', '0000000') pair is not allowed, so the marker's row gets
# one pseudocount fewer.
endc2['0000000'] -= phi

# NOTE(review): `start` is loaded (relative to the working directory) but
# never used below.
with open('startDomain.pkl', 'rb') as fh:
    start = pickle.load(fh)

# Start-domain total with a pseudocount for each real domain.
# NOTE(review): currently unused.
totalstart = len(genuine) + (vsize - 1) * phi

# Number of possible pairs over |alphabet| + marker, minus the
# ('0000000', '0000000') pair.
totalpair = vsize ** 2 - 1

# Total smoothed pair count. NOTE(review): currently unused.
allcount = np.sum(list(doublecount.values())) + phi * totalpair

allprobs = []

def start_prob(w):
    """Return the smoothed probability that an architecture begins with *w*.

    This is the pair probability of the transition from the '0000000'
    start/end marker to *w*.
    """
    start_pair = ('0000000', w)
    return pair_prob(start_pair)

def pair_prob(p):
    """Return the smoothed transition probability of the pair *p*.

    *p* is a (first_domain, next_domain) tuple. The numerator is the pair's
    count in ``doublecount`` (0 when unseen) plus the pseudocount ``phi``.
    The denominator is the smoothed total for ``p[0]`` from ``endc2``. A
    first domain missing from ``endc2`` gets ``vsize * phi``, i.e.
    pseudocounts only.
    """
    numerator = doublecount.get(p, 0) + phi
    denominator = endc2.get(p[0], vsize * phi)
    return numerator / denominator



def _da_log_likelihood(da):
    """Return the log-likelihood of domain architecture *da*.

    The architecture is scored as a first-order chain of smoothed pair
    probabilities: the start transition ('0000000', da[0]), every
    consecutive pair (da[i-1], da[i]), and the end transition
    (da[-1], '0000000').

    The logs of the individual factors are summed rather than taking the
    log of their product. Multiplying many small probabilities underflows
    to 0.0 for long architectures, which np.log then maps to -inf.

    Raises IndexError if *da* is empty (there is no first/last domain).
    """
    factors = [start_prob(da[0])]
    factors.extend(pair_prob(pair) for pair in zip(da, da[1:]))
    factors.append(pair_prob((da[-1], '0000000')))
    return np.sum(np.log(factors))


# Lengths and log-likelihoods of the genuine architectures...
glens = [len(da) for da in genuine]
allprobs = np.array([_da_log_likelihood(da) for da in genuine])

# ...and of the simulated ones.
slens = [len(da) for da in randa]
allprobs2 = np.array([_da_log_likelihood(da) for da in randa])

fig, ax = plt.subplots()

# Simulated architectures are plotted first so the genuine ones sit on top.
series = (
    (slens, allprobs2, 'blue', 'simulated'),
    (glens, allprobs, 'red', 'genuine'),
)
for lengths, loglik, colour, label in series:
    plt.scatter(lengths, loglik, color=colour, alpha=0.2, label=label)

ax.set_xlabel('DA length')
ax.set_ylabel('log likelihood')
# Legend currently disabled; the series labels are kept so it can be
# re-enabled with plt.legend().

# argv[6] is the output file prefix.
outspecies = sys.argv[6]
outname = f'{outspecies}_scatter.ps'
ax.set_rasterized(True)
plt.savefig(outname, dpi=1200, bbox_inches="tight")
