#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
prom.py:
promiscuity analysis: compare domain promiscuity between genuine and simulated domain architectures
Inputs:
    expPath: path to the relevant experiment folder 
    -r --real: .pkl file of real domain architectures (optional, defaults to  
         <expPath>/formatted/domainArchs.pkl)
    -s  --simulated: .pkl file of simulated domain architectures 
         (optional, defaults to <expPath>/output100super/final_DA_list.pkl)
    -n --specName: shorthand name for the relevant species for file naming purposes
        (optional, defaults to expPath)
Outputs:
     -compare_<specName>_co.txt: text file with domain IDs in descending promiscuity order
         (co-occurrence count) for both architectures, correlations, number 
         intersecting in top 100, side-by-side of intersecting domains, their 
         respective promiscuities, and a description of the domain 
     -compare_<specName>_bi.txt: same but with bigram count 
    -compare_<specName>_basu.txt: same but with Basu et al. promiscuity 
         (see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2259109/ )
    -compare_<specName>_basuco.txt: same but with Basu et al. promiscuity 
         using co-occurrence frequency 
    -compare_<specName>_co.png: plot of correlation and fraction of domains 
         shared in top n for co-occurrence 
    -compare_<specName>_bi.png: ''
    -compare_<specName>_basu.png: ''
    -compare_<specName>_basuco.png: '' 
   
example:
    in DomainEvoSimulator run:
        python3 prom.py ape_test -r ape_genuine.pkl -s ape_simulated.pkl -n ape
        python3 prom.py fly_test -r fly_genuine.pkl -s fly_simulated.pkl -n fly
        python3 prom.py fish_test -r fish_genuine.pkl -s fish_simulated.pkl -n fish
        python3 prom.py cnidaria_test -r cnidaria_genuine.pkl -s cnidaria_simulated.pkl -n cnidaria
    
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr 
import argparse 
import sys
#from scipy.stats import mannwhitneyu
#from scipy.stats import wilcoxon
import os

"""
accepts a set of architectures and returns an (unsorted) pandas series of promiscuity defined by the
number of unique domains with which each domain co-occurs within the same architecture,
counting the domain itself

inputs:
    -archset: 2d numpy ragged array of arrays of domain IDs (strings)
outputs:
    a pandas series of co-occurrence for each domain defined above 
"""
def co_occurrence(archset):
    """Count, for each domain, the distinct domains it shares an architecture with.

    A domain is always counted among its own partners (every architecture
    containing it also contributes it), so each domain that occurs scores >= 1.

    inputs:
        -archset: iterable of architectures, each an iterable of domain IDs
    outputs:
        a pandas Series mapping domain ID -> number of distinct co-occurring
        domains, in order of first appearance (not sorted)
    """
    partners = {}
    for arch in archset:
        members = set(arch)
        for d in arch:
            # Set union instead of concatenating ever-growing lists.
            partners.setdefault(d, set()).update(members)
    return pd.Series({d: len(found) for d, found in partners.items()})
"""
accepts a set of architectures and returns an (unsorted) pandas series of promiscuity defined by the number of
unique domains to which each domain is adjacent within the same architecture;
architectures with fewer than two domains are skipped

inputs:
    -archset: 2d numpy ragged array of arrays of domain IDs (strings)
outputs:
    a pandas series of promiscuity for each domain defined above 
"""
def neighbor_occurrence(archset):
    """Count, for each domain, the distinct domains adjacent to it.

    Two domains are adjacent when they sit next to each other in an
    architecture. Architectures with fewer than two domains contribute
    nothing, so domains seen only in such architectures are absent.

    inputs:
        -archset: iterable of architectures, each a sliceable sequence of domain IDs
    outputs:
        a pandas Series mapping domain ID -> number of distinct neighbours,
        in order of first appearance (not sorted)
    """
    neighbors = {}
    for arch in archset:
        # Each adjacent pair links both domains; length-0/1 architectures
        # yield no pairs.
        for left, right in zip(arch, arch[1:]):
            neighbors.setdefault(left, set()).add(right)
            neighbors.setdefault(right, set()).add(left)
    return pd.Series({d: len(found) for d, found in neighbors.items()})
    
"""
accepts a set of architectures and returns a pandas series (unsorted) of promiscuity defined
as in Basu et al. 2008:
    \\[ \\mathrm{prom}_i = \\beta_i \\log(\\beta_i / f_i) \\]
    where \\[ \\beta_i = \\frac{T_i}{\\frac{1}{2} \\sum_{j=1}^{t} T_j} \\]

    "where t is the number of distinct domain types,
    T_i is the number of unique domain neighbors of domain i,
    and f_i is the frequency of domain i in the genome,
    calculated as n_i/N, where n_i is the total count of domain i,
    and N is the total number of domains detected in the given genome:"

    \\[ N = \\sum_{i=1}^{t} n_i \\]
returns a pandas series with the promiscuity for each domain that has at least one
neighbor (domains seen only in single-domain architectures are omitted)
"""
def basu_prom(archset):
    """Basu et al. (2008) promiscuity: beta_i * log(beta_i / f_i).

    beta_i comes from get_bigram and f_i from get_freq. Only domains with at
    least one neighbour (i.e. present in get_bigram's result) are scored.

    inputs:
        -archset: iterable of architectures, each a sliceable sequence of domain IDs
    outputs:
        -pandas Series: domain ID -> promiscuity, in get_bigram's index order
    """
    bigram = get_bigram(archset)
    # Every domain with a neighbour also occurs in the flattened frequency
    # table, so reindexing selects exactly the multidomain entries and lines
    # them up label-for-label with `bigram` (no hash-ordered set index and no
    # implicit outer alignment in the arithmetic below).
    multifreq = get_freq(archset).reindex(bigram.index)
    return (np.log(bigram) - np.log(multifreq)) * bigram
"""
modified form of basu that uses co-occurrence frequency instead of bigram 
"""
def basu_prom_co(archset):
    """Basu-style promiscuity using the co-occurrence rate instead of beta_i.

    Computes rate_i * log(rate_i / f_i), where rate_i comes from get_coRate
    and f_i from get_freq.

    inputs:
        -archset: iterable of architectures, each an iterable of domain IDs
    outputs:
        -pandas Series: domain ID -> promiscuity, in get_coRate's index order
    """
    cogram = get_coRate(archset)
    # Align the frequency table label-for-label with `cogram` explicitly
    # rather than through a hash-ordered set index.
    multifreq = get_freq(archset).reindex(cogram.index)
    return (np.log(cogram) - np.log(multifreq)) * cogram
    
"""
accepts a set of architectures and returns a pandas series mapping each domain to its overall
frequency (occurrence count divided by the total number of domain occurrences)
inputs:
    -archset: 2d numpy ragged array of arrays of domain IDs (strings)
outputs:
    -a pandas series of domains and frequencies defined by Basu et al. 
    
"""
def get_freq(archset):
    """Relative frequency f_i of every domain, as used by Basu et al.

    inputs:
        -archset: iterable of architectures, each an iterable of domain IDs
    outputs:
        -pandas Series: domain ID -> (occurrences of the domain) /
         (total number of domain occurrences); not sorted by frequency
    """
    every_domain = [domain for arch in archset for domain in arch]
    tallies = pd.Series(every_domain).value_counts(sort=False)
    grand_total = np.sum(tallies)
    return tallies / grand_total
"""
accepts a set of architectures and returns a pandas series mapping each domain to its bigram
frequency (2 * number of unique neighbors / total number of unique neighbors over all domains)
inputs:
    -archset: 2d numpy ragged array of arrays of domain IDs (strings)
outputs:
    -a pandas series of domains and bigram frequencies defined by Basu et al. 
"""
def get_bigram(archset):
    """Basu et al. beta_i computed from adjacency counts.

    beta_i = T_i / (sum_j T_j / 2), where T_i is the number of distinct
    neighbours of domain i as returned by neighbor_occurrence.

    inputs:
        -archset: iterable of architectures, each a sliceable sequence of domain IDs
    outputs:
        -pandas Series: domain ID -> beta_i
    """
    neighbor_counts = neighbor_occurrence(archset)
    half_total = np.sum(neighbor_counts) / 2
    return neighbor_counts / half_total
def get_coRate(archset):
    """Co-occurrence counts normalised so they sum to 1.

    inputs:
        -archset: iterable of architectures, each an iterable of domain IDs
    outputs:
        -pandas Series: domain ID -> co-occurrence count / sum of all counts
    """
    co_counts = co_occurrence(archset)
    total = np.sum(co_counts)
    return co_counts / total
    

"""
accepts a pkl file and returns a 2d numpy ragged array of arrays of domain IDs (strings)
inputs:
    -path2pkl: the file path (str)
outputs:
    2d numpy ragged array of arrays of domain IDs (strings)
"""
def getArchs(path2pkl):
    """Load domain architectures from a pickled ``.pkl`` file.

    inputs:
        -path2pkl: path to the file (str or os.PathLike); must end in '.pkl'
    outputs:
        whatever object the pickle holds -- expected to be a 2d ragged array
        of arrays of domain IDs (strings)
    Prints an error and exits with status 1 if the path does not end in '.pkl'.
    """
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which would silently skip this validation.
    if not str(path2pkl).endswith('.pkl'):
        print(f'Invalid Argument: {path2pkl} is not a .pkl file')
        print('exiting...')
        sys.exit(1)
    # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects, which can
    # execute code -- only load .pkl files from trusted sources.
    return np.load(path2pkl, allow_pickle=True)
"""
given a series of promiscuities, return it sorted in descending order (nothing is plotted)
inputs:
    -prom_ser: pandas series of promiscuity for each domain 
   
outputs: 
    -prom_sorted: sorted series of promiscuities (descending)
    
"""
def rank(prom_ser):
    """Order a promiscuity series from most to least promiscuous.

    inputs:
        -prom_ser: pandas Series of promiscuity for each domain
    outputs:
        -a new Series with the same entries sorted by value, descending
    """
    return prom_ser.sort_values(ascending=False)
def plot_result(domainList):
    """Line-plot *domainList* on the current matplotlib axes and display it."""
    values = domainList
    plt.plot(values)
    plt.show()

def get_promiscuity(archset, tag='basu'):
    """Compute one promiscuity measure for every domain in *archset*.

    tag selects the measure:
        'co'     -> get_coRate (normalised co-occurrence count)
        'bi'     -> get_bigram (normalised adjacency count)
        'basuco' -> basu_prom_co
        anything else (default 'basu') -> basu_prom
    """
    measures = (('co', get_coRate),
                ('bi', get_bigram),
                ('basuco', basu_prom_co))
    for name, measure in measures:
        if tag == name:
            return measure(archset)
    return basu_prom(archset)
"""
analyze promiscuity for 2 architectures, output comparison statistics 
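inputs:
    -arch1: 2d list of genuine architectures
    -arch2: 2d list of simulated architectures
    -speciesPath: folder in which domain descriptions are looked up (passed to getDesc)
    -speciesName: string to label output files 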
outputs:
     -compare_<speciesName>_co.txt: text file with domain IDs in descending promiscuity (co occurrence count)
                    for both architectures, correlations, number intersecting in top 100, side-by-side
                    of intersecting domains, their respective promiscuities, and a description 
                    of the domain 
    -compare_<speciesName>_bi.txt: same but with bigram count 
    -compare_<speciesName>_basu.txt: same but with basu promiscuity 
    -compare_<speciesName>_basuco.txt: same but with basu promiscuity using co-occurrence frequency 
    -compare_<speciesName>_co.png: plot of correlation and fraction of domains shared in top n for co-occurrence 
    -compare_<speciesName>_bi.png: ''
    -compare_<speciesName>_basu.png: ''
    -compare_<speciesName>_basuco.png: '' 
                
"""
def _pearson(x, y):
    """Return Pearson's (r, p) for paired samples as plain floats.

    scipy's ``pearsonr`` raises ValueError when given fewer than two points,
    which would abort the whole comparison when two rankings barely overlap;
    (nan, nan) is returned instead. Values are converted to float so the
    report prints e.g. ``(0.5, 0.01)`` regardless of the numpy version.
    """
    if len(x) < 2:
        return (float('nan'), float('nan'))
    r_value, p_value = pearsonr(x, y)
    return (float(r_value), float(p_value))


def _plot_agreement(rank1, rank2, label, outpath):
    """Plot fraction shared and Pearson r against the top-n cutoff.

    The figure is saved to *outpath*, shown, then closed so figures do not
    accumulate across the measures handled by ``compare``.
    """
    x, y, p = agreement(rank1, rank2)
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(x, y, 'g-')
    ax2.plot(x, p, 'b-')
    ax1.set_xlabel('domains compared (ranked by promiscuity)')
    ax1.set_ylabel('fraction shared', color='g')
    ax2.set_ylabel("Pearson Correlation", color='b')
    ax1.set_ylim(0, 1)
    # Pearson's r lies in [-1, 1]; a (0, 1) axis would clip negative values.
    ax2.set_ylim(-1, 1)
    ax1.set_title(f'agreement using {label}')
    # Save before show(): calling savefig() after show() can yield a blank image.
    fig.savefig(outpath)
    plt.show()
    plt.close(fig)


def compare(arch1, arch2, speciesPath, speciesName):
    """Compare domain promiscuity between two sets of architectures.

    For each measure ('co', 'bi', 'basu', 'basuco') this:
      * writes raw scores to promisc_<speciesName>_real_<tag>.csv and
        promisc_<speciesName>_sim_<tag>.csv,
      * writes compare_<speciesName>_<tag>.txt with both top-100 rankings,
        the percentage of shared domains, Pearson correlations (all shared
        domains / shared top-100 domains) and a table of the shared top-100
        domains with descriptions looked up under *speciesPath*,
      * saves compare_<speciesName>_<tag>.png (fraction shared and Pearson r
        versus the top-n cutoff) and shows it.

    inputs:
        -arch1: genuine architectures (labelled 'genuine' / 'Real' in the report)
        -arch2: simulated architectures (labelled 'simulated' / 'Sim')
        -speciesPath: folder searched by getDesc for domain descriptions
        -speciesName: string used to label output files and report headers
    """
    tags = ['co', 'bi', 'basu', 'basuco']
    tagdict = {'co': 'co-occurrence count', 'bi': 'bigram count',
               'basu': 'Basu et al. promiscuity',
               'basuco': 'Basu et al. promiscuity with co-occurrence'}
    topn = 100
    print('\n\nstarting output... \n')
    for t in tags:
        label = tagdict[t]
        a1 = get_promiscuity(arch1, t)
        a2 = get_promiscuity(arch2, t)
        # One raw-score file per architecture and measure, so later measures
        # do not overwrite earlier ones.
        a1.to_csv(f'promisc_{speciesName}_real_{t}.csv')
        a2.to_csv(f'promisc_{speciesName}_sim_{t}.csv')

        rank1 = rank(a1)
        rank2 = rank(a2)
        top1 = rank1.iloc[:topn]
        top2 = rank2.iloc[:topn]
        common_topn = top1.index.intersection(top2.index)
        common_all = rank1.index.intersection(rank2.index)
        shared1 = rank1.loc[common_topn].sort_values(ascending=False)
        shared2 = rank2.loc[common_topn].sort_values(ascending=False)

        # Both sides are indexed by the same label list, so samples are paired.
        r_all, p_all = _pearson(rank1.loc[common_all], rank2.loc[common_all])
        r_top, p_top = _pearson(rank1.loc[common_topn], rank2.loc[common_topn])
        summary = (f"Pearson's r,p for {label}:\nall data:\n{(r_all, p_all)}\n"
                   f"only for top {topn}:\n{(r_top, p_top)}\n")

        # Fewer than topn domains may exist; divide by the size of the
        # smaller top list (guarding against an empty ranking).
        compared = min(len(top1), len(top2))
        pct = len(common_topn) / compared * 100 if compared else 0.0

        descs = pd.Series({d: getDesc(d, speciesPath) for d in shared1.index},
                          dtype=object)
        table = pd.concat((shared1, shared2, descs), axis=1)

        comparepath = f'compare_{speciesName}_{t}'
        with open(comparepath + '.txt', 'w', encoding='utf-8') as f:
            f.write(f'#{comparepath + ".txt"}: output of prom.py \n')
            f.write('=' * 70 + '\n')
            f.write(f'top {topn} domains and promiscuity of genuine {speciesName} using {label} \n\n')
            f.write(top1.to_string(header=['promiscuity']) + '\n\n')
            f.write(f'top {topn} domains and promiscuity of simulated {speciesName} using {label} \n\n')
            f.write(top2.to_string(header=['promiscuity']) + '\n\n')
            f.write(f'shared set in top {topn}:\n\n')
            f.write(f'{pct:.1f} % agreement\n\n')
            f.write('correlations:\n')
            f.write(summary)
            f.write('=' * 70 + '\n')
            f.write(table.to_string(header=['Real', 'Sim', 'Description']))

        print(summary)
        _plot_agreement(rank1, rank2, label, comparepath + '.png')
"""
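get the % agreement (fraction of shared domains) and the Pearson correlation between two
rankings as a function of the top-n cutoff, in steps of 5 up to about the length of the
shorter ranking (like a qq plot)
inputs:
    -ranked1: sorted pandas series of promiscuity for each domain 
    -ranked2: ''
outputs: 
    -three lists: the top-n cutoffs, the fraction of domains shared at each cutoff,
     and the Pearson correlation at each cutoff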
"""
def agreement(ranked1, ranked2):
    """Overlap and correlation of two rankings as a function of a top-n cutoff.

    Cutoffs are 10, 15, 20, ... up to n, plus n itself, where n is the length
    of the shorter ranking. At each cutoff the top `limit` entries of both
    rankings are compared.

    inputs:
        -ranked1: pandas Series of promiscuity per domain, sorted descending
        -ranked2: ''
    outputs:
        -outx: list of cutoffs
        -outy: fraction of the top-`limit` domains shared by both rankings
        -outpear: Pearson correlation via pandas Series.corr, which aligns on
          domain ID, so only domains present in both top lists contribute
          (NaN when too few are shared)
    """
    n = min(len(ranked1), len(ranked2))
    print(f'n={n}')
    # Cutoffs never exceed n: a larger cutoff would divide by more domains
    # than either top list can hold and understate the agreement.
    outx = list(range(10, n + 1, 5))
    if n > 0 and (not outx or outx[-1] != n):
        outx.append(n)
    outy = []
    outpear = []
    for limit in outx:
        top1 = ranked1.iloc[:limit]
        top2 = ranked2.iloc[:limit]
        outy.append(len(top1.index.intersection(top2.index)) / limit)
        outpear.append(top1.corr(top2))
    return outx, outy, outpear
          
""" def baseline(x,n):
    out=np.zeros(len(x))
    
    for k,index in zip(x,range(len(x))):
        termlist=np.zeros(k)
        
        s=0
        for i in range(1,k+1):
            termlist[i-1]=binom(k,i)/binom(n,i)
        for i in range(1,k):
            s+=(termlist[i-1]-np.sum(termlist[i:]))*i/k
        s+=termlist[-1]
        out[index]=s
    return out """
            
"""
function that takes SSF ID and returns the description of the domain
@author: trachman
inputs: 
    -ID: ID name
    -localpath: local path to species folder in DomainEvoSimulator 
outputs:
    -string corresponding to description of domain with SSF ID <ID> (empty if not found)
"""



def getDesc(ID, localpath):
    """Return the description of the domain with SSF ID *ID*.

    Files directly inside *localpath* are scanned in sorted name order, and
    subdirectories are skipped. The first line that contains *ID* (substring
    match) and has at least two tab-separated fields wins. Its second-to-last
    field is returned.

    inputs:
        -ID: ID name
        -localpath: local path to species folder in DomainEvoSimulator
    outputs:
        -description string, or '' if the folder cannot be listed or no
         matching line is found
    """
    try:
        entries = sorted(os.listdir(localpath))
    except OSError:
        return ''

    for filename in entries:
        filepath = os.path.join(localpath, filename)
        # Opening a directory would raise IsADirectoryError.
        if not os.path.isfile(filepath):
            continue
        with open(filepath, 'r') as f:
            for line in f:
                if ID not in line:
                    continue
                fields = line.strip().split('\t')
                # A matching line without tabs has no description column.
                if len(fields) >= 2:
                    return fields[-2]
    return ''



def main():
    """Command-line entry point: resolve input paths, load both architecture
    sets and run compare() on them."""
    # The module docstring is the help text (shown with -h); it used to be
    # passed through print(), which printed it on every run and left the
    # parser without a description.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('expPath', type=str, help="path to experiment folder")
    parser.add_argument('-r', '--real', type=str,
                        help="optional, path to .pkl with genuine architecture (defaults to domainArchs.pkl in <expPath>/formatted)")
    parser.add_argument('-s', '--simulated', type=str,
                        help="optional, path to .pkl with simulated architecture (defaults to <expPath>/output100super/final_DA_list.pkl)")
    parser.add_argument('-n', '--specName', type=str,
                        help="optional, shorthand species name to use for file outputs (defaults to the name of the expPath folder)")

    args = parser.parse_args()

    real_path = args.real or os.path.join(args.expPath, 'formatted', 'domainArchs.pkl')
    sim_path = args.simulated or os.path.join(args.expPath, 'output100super', 'final_DA_list.pkl')
    spec_path = os.path.join(args.expPath, 'species')
    # Use only the folder's own name: a trailing slash or subdirectory in
    # expPath would otherwise end up inside the output file names.
    spec_name = args.specName or os.path.basename(os.path.normpath(args.expPath))

    real = getArchs(real_path)
    sim = getArchs(sim_path)
    compare(real, sim, spec_path, spec_name)


if __name__ == '__main__':
    main()
   

    
