import sys
import os
import argparse
import pickle
import time
import multiprocessing as multiproc

sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/lib")
from Parsing import Superfamily
from Parsing import NULLDOM
from Processing import Doubles, Triples, Domains, countDomainOccurrencesThread
import Util


def parseArgs():
    """
    Build the command-line parser and parse ``sys.argv``.

    Every output switch (universal, singles, doubles, triples, alphabet,
    frequency, rawlength, domArchs) defaults to True, so passing the positive
    flag does not change anything.  Each switch therefore also gets a
    ``--no-<name>`` counterpart that stores False into the same destination,
    which is the only way to skip an individual output.

    Returns the ``argparse.Namespace`` with the attributes: procs, universal,
    singles, doubles, triples, alphabet, frequency, rawlength, domArchs, ID,
    uniqueMethod and inputPath.
    """
    parser = argparse.ArgumentParser(description="Read species files and parse the data into various "
                                     "formats convenient for analysis and domain architecture simulation.")
    parser.add_argument('--procs', action='store', type=int, dest='procs', default=1,
                        help="Number of processes to use, default=1")

    # (destination / long option name, short option letter, help text)
    outputSwitches = [
        ('universal', 'u', "Write out the universal file."),
        ('singles', 's', "Write out the singles count file."),
        ('doubles', 'd', "Write out the doubles count file."),
        ('triples', 't', "Write out the triples count file."),
        ('alphabet', 'a', "Write out the domain alphabet file."),
        ('frequency', 'f', "Write out the frequency stats of each domain."),
        ('rawlength', 'l', "Write out the full lengths and compact lengths of the domain architectures from the dataset."),
        ('domArchs', 'c', "Write out all the domain architectures from the dataset."),
    ]
    for name, short, helpText in outputSwitches:
        # Original flag, kept as-is so existing invocations keep working.
        parser.add_argument('--'+name, '-'+short, action='store_true', dest=name, default=True,
                            help=helpText)
        # Negative counterpart: both options share one dest, and both default to True.
        parser.add_argument('--no-'+name, action='store_false', dest=name,
                            help="Disable --{} (enabled by default).".format(name))

    parser.add_argument('--id', action='store', type=str, dest='ID', default='sf', choices=['sf','fam'],
                        help="Which domain identifier to use (default=sf)")
    parser.add_argument('--unique', action='store', type=str, dest='uniqueMethod', default='all',
                        choices=['all','species','none'], help="Select a method to consolidate domain architectures: across all species, per species, or none (default=all)")
    parser.add_argument('inputPath', action='store', type=str,
                        help="Path to the dataset, with 'species' subfolder")
    return parser.parse_args()


def sanityChecks(arguments):
    """
    Validate the dataset layout and prepare the output directory.

    arguments -- parsed namespace; only ``inputPath`` is read.

    Terminates the program with exit status 1 when ``<inputPath>/species`` is
    not a directory.  Otherwise sets ``arguments.output`` to
    ``<inputPath>/formatted``, creates that directory if it is missing, and
    returns the same namespace object.
    """
    # isdir (not exists): a regular file called 'species' must be rejected
    # here, because the caller goes on to list the directory's contents.
    if not os.path.isdir(arguments.inputPath+'/species'):
        print("'species' folder not present at '{}'!".format(arguments.inputPath))
        print("Exiting...")
        # Non-zero status so shells and pipelines can detect the failure.
        sys.exit(1)

    arguments.output = arguments.inputPath+'/formatted'
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(arguments.output, exist_ok=True)

    return arguments


def getDomainOccurrenceDict(numProcs, uniqueDomains, domainTuples):
    """
    Fan the domains out to ``numProcs`` worker processes and merge their results.

    numProcs      -- number of worker processes to start
    uniqueDomains -- sized iterable of domains; each one is put on the input queue
    domainTuples  -- passed unchanged to every worker

    Each worker runs ``countDomainOccurrencesThread`` with the input, output
    and progress queues.  The worker itself is defined in another module;
    presumably it puts one item on the progress queue per processed domain and
    dict-like results on the output queue -- verify against Processing.

    Returns a dict built by ``update``-ing with everything found on the output
    queue once all workers have been joined.
    """
    # Queues come from a Manager so they can be handed to the worker processes.
    queueManager = multiproc.Manager()
    inQue = queueManager.Queue()
    outQue = queueManager.Queue()
    progQue = queueManager.Queue()

    # Load the work items.
    print("Filling input queue with {} domains...".format(len(uniqueDomains)))
    for domain in uniqueDomains:
        inQue.put(domain)
    # One None sentinel per worker, queued after all real work items.
    for _ in range(numProcs):
        inQue.put(None)

    # Spawn the counting workers.
    workers = []
    for workerNumber in range(1, numProcs + 1):
        print("Starting counting process {} of {}".format(workerNumber, numProcs))
        worker = multiproc.Process(target=countDomainOccurrencesThread,
                                   args=(inQue, outQue, progQue, domainTuples))
        worker.start()
        workers.append(worker)

    # Poll the progress queue once a second until it holds one entry per domain.
    tracker = Util.Progress(len(uniqueDomains))
    tracker.start()
    while progQue.qsize() < len(uniqueDomains):
        tracker.update(progQue.qsize())
        time.sleep(1)
    # Let every worker exit before touching the output queue.
    for worker in workers:
        worker.join()
    tracker.finish()

    # Drain the output queue into a single dict.
    print("Combining domain occurrence results...")
    merged = {}
    while not outQue.empty():
        merged.update(outQue.get())

    return merged


def countDomainTupleOccurrences(domainTuples):
    """
    Tally how many times each item appears in ``domainTuples``.

    domainTuples -- any iterable of hashable items; it is traversed exactly
                    once, so a generator is acceptable.

    Returns a plain dict mapping each distinct item to its number of
    occurrences (an empty dict for an empty iterable).
    """
    tally = {}
    for item in domainTuples:
        tally[item] = tally.get(item, 0) + 1
    return tally

#=============================================
#################### Main ####################
#=============================================

def _dumpPickle(path, obj):
    """Pickle ``obj`` to ``path``; the file is closed even if pickling fails."""
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def _writeCountsText(path, counts):
    """Write one ``key<TAB>count`` line per entry of ``counts``, leaving out the 'sum' total."""
    with open(path, 'w') as handle:
        for key in counts:
            if key == 'sum':  # aggregate entry, not a real key
                continue
            handle.write(str(key)+'\t'+str(counts[key])+'\n')


if __name__ == "__main__":
    # Script entry point: parse the species files found under
    # <inputPath>/species and write the requested outputs to
    # <inputPath>/formatted.  All text files are written through ``with`` so
    # handles are released even when a write fails part-way.
    args = parseArgs()

    print("Sanity checks...")
    args = sanityChecks(args)

    # Open and read all species files
    print("Reading sequence data from {}".format(args.inputPath))
    results = {}
    domArchs = []
    for filename in os.listdir(args.inputPath+'/species'):
        parsed = Superfamily.insertEnds(
            Superfamily.parseRawTextFile(args.inputPath+'/species/'+filename, argID=args.ID))
        results.update(parsed)
        if args.uniqueMethod == 'species':
            # Per-species consolidation: collect this file's architectures on
            # their own; the second return value is not needed here.
            partialDomArchs, _ = Superfamily.getDomainArchitectures(parsed)
            domArchs += partialDomArchs

    # Pre-processing (for code layout sanity)
    print("Pre-processing...")
    if args.uniqueMethod == 'species':
        # domArchs was already accumulated per species above; only the
        # sequences from the combined results are used.
        _, sequences = Superfamily.getDomainArchitectures(results)
    elif args.uniqueMethod == 'all':
        domArchs, sequences = Superfamily.getDomainArchitectures(results)
    elif args.uniqueMethod == 'none':
        domArchs, sequences = Superfamily.getDomainArchitecturesAll(results)

    domAlphabet = Domains.getDomainAlphabet(domArchs)

    # If specified, write domain architectures file
    if args.domArchs:
        print("There are {} distinct domain architectures".format(len(domArchs)))
        print("Writing raw domain architectures file to {}".format(args.output))
        # Drop the first and last element of every architecture (presumably
        # the NULLDOM end markers added by insertEnds -- confirm in Parsing).
        rawDomainArchs = [rawDA[1:-1] for rawDA in domArchs]
        _dumpPickle(args.output+'/domainArchs.pkl', rawDomainArchs)
        with open(args.output+'/domainArchs.txt', 'w') as outPutFile:
            for rawDA in rawDomainArchs:
                outPutFile.write(str(rawDA)+'\n')
        with open(args.output+'/sequences.txt', 'w') as outPutFile:
            for seq in sequences:
                entries = sequences[seq]
                outPutFile.write(str(seq)+'\t')
                for i in range(1, len(entries)-1):  # Ignore NULLDOM
                    outPutFile.write(str(entries[i][0])+'\t')
                outPutFile.write('\n')

    # If specified, write raw length and raw compact length files
    if args.rawlength:
        print("Writing raw domain architecture length and compact length files to {}".format(args.output))
        with open(args.output+'/rawDomainArchLength.txt', 'w') as lengthFile, \
                open(args.output+'/rawDomainArchCompactLength.txt', 'w') as compactFile:
            for domArch in domArchs:
                # Full length excludes the two end elements.
                fullLength = len(domArch) - 2
                lengthFile.write(str(fullLength)+'\n')
                # Compact length: each element equal to its predecessor
                # (end elements excluded) is not counted again.
                repeats = sum(1 for j in range(2, len(domArch)-1)
                              if domArch[j] == domArch[j-1])
                compactFile.write(str(fullLength - repeats)+'\n')

    # If specified, write universal file
    if args.universal:
        print("Writing universal file to {}".format(args.output))
        Superfamily.writeUniversalFile(args.output+'/universal-'+args.ID+'.txt', results)

    # If specified, write the domain alphabet file
    if args.alphabet:
        print("Writing out the domain alphabet to {}".format(args.output))
        _dumpPickle(args.output+'/alphabet.pkl', domAlphabet)
        with open(args.output+'/alphabet.txt', 'w') as outPutFile:
            for key in domAlphabet:
                outPutFile.write(str(key)+'\n')

    # If specified, write file of domain counts (singles)
    if args.singles:
        print("Processing single counts...")
        # Count occurrences over the domains yielded by Domains.iterate
        singlesResults = countDomainTupleOccurrences(Domains.iterate(domArchs))
        # Total is computed from the architecture lengths minus the two end elements.
        singlesResults['sum'] = sum(len(da)-2 for da in domArchs)
        # Save results ('sum' is an extra entry, hence the -1)
        print("There are {} domains including NULLDOM".format(len(singlesResults)-1))
        print("Writing domain counts to {}".format(args.output))
        _dumpPickle(args.output+'/domainCounts-domArchs.pkl', singlesResults)
        _writeCountsText(args.output+'/domainCounts-domArchs.txt', singlesResults)

    # If specified, write file of doubles and its domain counts
    if args.doubles:
        print("Processing double counts...")
        # Get doubles from domain architectures and count occurrences
        dblResults = countDomainTupleOccurrences(Doubles.getDoubles(domArchs))
        dblResults['sum'] = sum(dblResults.values())
        # Save results
        print("There are {} domain pairs".format(len(dblResults)-1))
        print("Writing pair counts to {}".format(args.output))
        _dumpPickle(args.output+'/doubleCounts-domArchs.pkl', dblResults)
        _writeCountsText(args.output+'/doubleCounts-domArchs.txt', dblResults)

    # If specified, write file of triples and its domain counts
    if args.triples:
        print("Processing triple counts...")
        # Get triples from domain architectures and count occurrences
        trplResults = countDomainTupleOccurrences(Triples.getTriples(domArchs))
        trplResults['sum'] = sum(trplResults.values())
        # Save results
        print("There are {} domain triples".format(len(trplResults)-1))
        print("Writing triple counts to {}".format(args.output))
        _dumpPickle(args.output+'/tripleCounts-domArchs.pkl', trplResults)
        _writeCountsText(args.output+'/tripleCounts-domArchs.txt', trplResults)

    # If specified, write file of domain frequency numbers
    if args.frequency:
        print("Processing domain frequency...")
        domFreq = Domains.getDomainFrequency(domArchs)
        _dumpPickle(args.output+'/domainFrequency-domArchs.pkl', domFreq)

    print("Done")
