import argparse, sys, unicodedata, os, codecs, collections, time
import regex as re
import ycutils.tokenize, ycutils.tfidf, ycutils.bagofwords, ycutils.corpus, parse_sage

# Command-line interface: a vocabulary file plus input/output directories.
parser = argparse.ArgumentParser(description='Creates lag information/model terms from text corpus and SAGE file.')
# argparse.FileType opens the file itself and reports an unreadable path as a
# normal usage error; the Python-2-only `file` builtin let the IOError escape
# argparse as a raw traceback.
parser.add_argument('vocab_file', type=argparse.FileType('r'), help='Vocabulary file.')
parser.add_argument('input_dir', type=str, help='Input directory containing tokenized text files.')
parser.add_argument('output_dir', type=str, help='Output directory to store lag information.')
A = parser.parse_args()

def do_document(i, input_f, output_f, vocabulary=None):
  """Write the lag records for one tokenized document.

  Each non-blank line of `input_f` is split on whitespace. At every token
  position that starts a bigram, the code first looks up the trigram
  starting there (tokens joined with '_') in the vocabulary. If the trigram
  is absent, it tries the bigram, unless that bigram lies (as whole tokens)
  inside the previously emitted n-gram. Every match is written as
  "<ngram>\t<lag>", where <lag> is the number of positions scanned without a
  match since the previous match, or since the start of the speech. The
  records are framed by a __START_OF_SPEECH__ line, an __END_OF_SPEECH__
  line (trailing lag) and a __SPEECH_LENGTH__ line (total token count).

  Args:
    i: running document index passed by the caller; not used here.
    input_f: iterable of text lines (the script passes a UTF-8 codecs file).
    output_f: writable text stream; it is always closed before returning.
    vocabulary: container supporting `in` on n-gram strings. Defaults to
      the module-level `vocab`, so existing callers are unaffected.
  """
  if vocabulary is None:
    vocabulary = vocab

  cur_lag = 0
  prev_gram = ''
  speech_length = 0
  try:
    # u'' literals: under Python 2, str.format() with a non-ASCII unicode
    # argument raises UnicodeEncodeError, and tokens read via codecs.open
    # are unicode.
    output_f.write(u'__START_OF_SPEECH__\t0\n')
    for line in input_f:
      tokens = line.strip().split()
      if not tokens: continue

      n_tokens = len(tokens)
      speech_length += n_tokens
      # The last token of a line starts no bigram, so it never advances cur_lag.
      for j in range(n_tokens - 1):
        bigram = '_'.join(tokens[j:j + 2])
        trigram = '_'.join(tokens[j:j + 3]) if j + 2 < n_tokens else None

        # Padding with '_' makes the containment test work on whole tokens.
        # A plain substring test would treat 'he_cat' as lying inside
        # 'the_cat_sat'.
        # NOTE(review): prev_gram is never reset. A bigram is therefore also
        # skipped when it occurs inside an earlier, non-adjacent match, even
        # one from a previous line. Confirm that this is intended.
        if trigram is not None and trigram in vocabulary:
          matched = trigram
        elif ('_' + bigram + '_') not in ('_' + prev_gram + '_') and bigram in vocabulary:
          matched = bigram
        else:
          matched = None

        if matched is None:
          cur_lag += 1
        else:
          output_f.write(u'{}\t{}\n'.format(matched, cur_lag))
          prev_gram = matched
          cur_lag = 0
      #end for
    #end for

    output_f.write(u'__END_OF_SPEECH__\t{}\n'.format(cur_lag))
    output_f.write(u'__SPEECH_LENGTH__\t{}\n'.format(speech_length))
  finally:
    output_f.close()
#end def

# Driver: mirror every *.txt file under input_dir into output_dir as a *.lag
# file, keeping the relative directory layout. Directories and files are
# visited in sorted order, so the running index `i` is deterministic.
input_dir = A.input_dir
output_dir = A.output_dir
if not os.path.exists(output_dir): os.makedirs(output_dir, 0o755)

# Module-level vocabulary; do_document looks it up as the global `vocab`.
vocab = ycutils.corpus.CorpusVocabulary(from_file=A.vocab_file)

sys.stderr.write('Processing files from {} and saving it to {}\n'.format(input_dir, output_dir))
i = 1
for (dirpath, dirnames, filenames) in os.walk(input_dir):
  filenames.sort()
  # Sorting dirnames in place also fixes the order in which os.walk descends.
  dirnames.sort()
  for fname in filenames:
    if not fname.endswith('.txt'): continue

    src_path = os.path.join(dirpath, fname)
    rel_path = os.path.relpath(src_path, input_dir)
    new_dir = os.path.join(output_dir, os.path.dirname(rel_path))
    dst_path = os.path.join(new_dir, fname[:-4] + '.lag')  # swap '.txt' for '.lag'

    if not os.path.exists(new_dir): os.makedirs(new_dir, 0o755)

    sys.stderr.write('  {} -> {}\n'.format(src_path, dst_path))
    # `with` guarantees the input file is closed too. do_document already
    # closes output_f, and closing it a second time is harmless.
    with codecs.open(src_path, 'r', 'utf-8') as input_f, codecs.open(dst_path, 'w', 'utf-8') as output_f:
      do_document(i, input_f=input_f, output_f=output_f)
    i += 1
  #end for
#end for
