/*

Copyright (C) 1998,1999,2000,2001  Franz Josef Och (RWTH Aachen - Lehrstuhl fuer Informatik VI)

This file is part of GIZA++ ( extension of GIZA ).

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, 
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, 
USA.

*/
#include "HMMTables.h"
#include <fstream>
#include <sstream>
#include "Globals.h"
#include "Parameter.h"

// Prints a human-readable listing of the jump table alProb to `out`.
// For every key a "Distribution for: <key>" header is written (the key is
// rendered by printAlDeps using mapper1/mapper2), followed by the non-zero
// entries as "offset:value; " and the sum of those entries. A final
// "FULL-SUM" line reports the total over all keys.
template<class CLS,class MAPPERCLASSTOSTRING>
void HMMTables<CLS,MAPPERCLASSTOSTRING>::writeJumps(ostream&out) const
{
    typedef typename map<AlDeps<CLS>,FlexArray<double> >::const_iterator JumpIter;

    double grandTotal=0.0;
    for(JumpIter row=alProb.begin();row!=alProb.end();++row){
        const FlexArray<double>&dist=row->second;

        out << "\n\nDistribution for: ";
        printAlDeps(out,row->first,*mapper1,*mapper2);
        out << ' ';

        double rowTotal=0.0;
        for(int a=dist.low();a<=dist.high();++a){
            // Zero entries are neither printed nor summed.
            if( !dist[a] )
                continue;
            out << a << ':' << dist[a] << ';' << ' ';
            rowTotal+=dist[a];
        }

        out << '\n' << '\n';
        out << "SUM: " << rowTotal << '\n';
        grandTotal+=rowTotal;
    }
    out << "FULL-SUM: " << grandTotal << '\n';
}
// Stream-based counterpart of writeJumps(ostream&). The body is empty: the
// stream argument is ignored and no table is modified.
template<class CLS,class MAPPERCLASSTOSTRING>
void HMMTables<CLS,MAPPERCLASSTOSTRING>::readJumps(istream&)
{
}
// Looks up one value of the jump table alProb.
//
// The table row is selected by AlDeps(sentLength, istrich, j, w1, w2); the
// index inside that row depends on the global PredictionInAlignments:
//   0: istrich-k
//   1: k
//   2: (k*J - j*sentLength) divided by J, rounded to nearest (half away from zero)
// Any other setting aborts the program.
//
// If no row exists for the key, 1.0/(2*sentLength-1) is returned instead, and
// a warning is printed to cout when 0 < iter < 5000.
template<class CLS,class MAPPERCLASSTOSTRING>
double HMMTables<CLS,MAPPERCLASSTOSTRING>::getAlProb(int istrich,int k,int sentLength,int J,CLS w1,CLS w2,int j,int iter) const{
    typedef typename map<AlDeps<CLS>,FlexArray<double> >::const_iterator JumpIter;

    massert(k<sentLength&&k>=0);
    massert(istrich<sentLength&&istrich>=-1);

    const int mode=PredictionInAlignments;
    int offset=istrich-k;
    if( mode==0 ){
        offset=istrich-k;
    }else if( mode==1 ){
        offset=k;
    }else if( mode==2 ){
        offset=(k*J-j*sentLength);
        // Add half of the divisor (with the sign of the numerator) so that the
        // truncating integer division below rounds to the nearest value.
        if( offset>0 )
            offset+=J/2;
        else
            offset-=J/2;
        offset/=J;
    }else{
        abort();
    }

    const JumpIter hit=alProb.find(AlDeps<CLS>(sentLength,istrich,j,w1,w2));
    if( hit==alProb.end() ){
        if( iter>0&&iter<5000 )
            cout << "WARNING: Not found: " << ' ' << J << ' ' << sentLength << '\n';
        return 1.0/(2*sentLength-1);
    }
    return (hit->second)[offset];
}

// Adds `value` to alProb and, when non-zero, `valuePredicted` to
// alProbPredicted, in the row keyed by AlDeps(sentLength, istrich, j, w1, w2).
//
// The index inside the row depends on the global PredictionInAlignments in the
// same way as in getAlProb (0: istrich-k, 1: k, 2: rounded (k*J-j*sentLength)/J;
// anything else aborts).
//
// A missing row is created zero-filled. Its index range is
// [-MAX_SENTENCE_LENGTH, MAX_SENTENCE_LENGTH] when bit 0 of CompareAlDeps is
// clear, and [-sentLength, sentLength] otherwise.
template<class CLS,class MAPPERCLASSTOSTRING>
void HMMTables<CLS,MAPPERCLASSTOSTRING>::addAlCount(int istrich,int k,int sentLength,int J,CLS w1,CLS w2,int j,double value,double valuePredicted){
    typedef map<AlDeps<CLS>,FlexArray<double> > JumpTable;
    typedef typename JumpTable::iterator JumpIter;

    const int mode=PredictionInAlignments;
    int offset=istrich-k;
    if( mode==0 ){
        offset=istrich-k;
    }else if( mode==1 ){
        offset=k;
    }else if( mode==2 ){
        offset=(k*J-j*sentLength);
        // Round to nearest when dividing by J below.
        if( offset>0 )
            offset+=J/2;
        else
            offset-=J/2;
        offset/=J;
    }else{
        abort();
    }

    const AlDeps<CLS> key(sentLength,istrich,j,w1,w2);
    const bool fullRange=((CompareAlDeps&1)==0);

    // Observed counts.
    JumpIter slot=alProb.find(key);
    if( slot==alProb.end() ){
        slot=alProb.insert(make_pair(key,
                fullRange ? FlexArray<double> (-MAX_SENTENCE_LENGTH,MAX_SENTENCE_LENGTH,0.0)
                          : FlexArray<double> (-sentLength,sentLength,0.0))).first;
    }
    slot->second[offset]+=value;

    // Predicted counts are only recorded when there is something to add.
    if( !valuePredicted )
        return;

    JumpIter predictedSlot=alProbPredicted.find(key);
    if( predictedSlot==alProbPredicted.end() ){
        predictedSlot=alProbPredicted.insert(make_pair(key,
                fullRange ? FlexArray<double> (-MAX_SENTENCE_LENGTH,MAX_SENTENCE_LENGTH,0.0)
                          : FlexArray<double> (-sentLength,sentLength,0.0))).first;
    }
    predictedSlot->second[offset]+=valuePredicted;
}

// Returns a mutable reference to the init_alpha entry stored under key I.
// If there is no entry yet, one is created first as Array<double>(I,0).
template<class CLS,class MAPPERCLASSTOSTRING>
Array<double>&HMMTables<CLS,MAPPERCLASSTOSTRING>::doGetAlphaInit(int I)
{
  hash_map<int,Array<double> >::iterator entry=init_alpha.find(I);
  if( entry==init_alpha.end() )
    entry=init_alpha.insert(make_pair(I,Array<double>(I,0))).first;
  return entry->second;
}
// Returns a mutable reference to the init_beta entry stored under key I.
// If there is no entry yet, one is created first as Array<double>(I,0).
template<class CLS,class MAPPERCLASSTOSTRING>
Array<double>&HMMTables<CLS,MAPPERCLASSTOSTRING>::doGetBetaInit(int I)
{
  hash_map<int,Array<double> >::iterator entry=init_beta.find(I);
  if( entry==init_beta.end() )
    entry=init_beta.insert(make_pair(I,Array<double>(I,0))).first;
  return entry->second;
}

// Copies the init_alpha entry stored under key I into x.
// Returns false and leaves x untouched when no such entry exists. Otherwise
// every element of x from index size/2+1 onwards is set to zero and true is
// returned.
template<class CLS,class MAPPERCLASSTOSTRING>
bool HMMTables<CLS,MAPPERCLASSTOSTRING>::getAlphaInit(int I,Array<double>&x)const
{
  hash_map<int,Array<double> >::const_iterator entry=init_alpha.find(I);
  if( entry==init_alpha.end() )
    return false;

  x=entry->second;
  // only first empty word can be chosen
  unsigned int j=x.size()/2+1;
  while( j<x.size() ){
    x[j]=0;
    ++j;
  }
  return true;
}
// Copies the init_beta entry stored under key I into x.
// Returns false and leaves x untouched when no such entry exists, true otherwise.
template<class CLS,class MAPPERCLASSTOSTRING>
bool HMMTables<CLS,MAPPERCLASSTOSTRING>::getBetaInit(int I,Array<double>&x)const{
    hash_map<int,Array<double> >::const_iterator entry=init_beta.find(I);
    if( entry==init_beta.end() )
        return false;

    x=entry->second;
    return true;
}


/***********************************
By Edward Gao
************************************/

template<class CLS,class MAPPERCLASSTOSTRING>
bool HMMTables<CLS,MAPPERCLASSTOSTRING>::writeJumps(const char* alprob, const char* alpredict, const char* alpha, const char* beta ) const
{
    if(alprob){
        ofstream ofs(alprob);
        if(!ofs.is_open()){
            cerr << "Cannot open file for HMM output "  << alprob << endl;
            return false;
        }
        cerr << "Dumping HMM table to " << alprob << endl;

        for(typename map<AlDeps<CLS>,FlexArray<double> >::const_iterator i=alProb.begin();i!=alProb.end();++i){
            double sum=0.0;
            ofs <<i->first.englishSentenceLength << " " 
                << i->first.classPrevious << " " 
                << i->first.previous << " "
                << i->first.j << " "
                << i->first.Cj <<" "
                << i->second.low() <<" "
                << i->second.high()<< " ";
            for(int a=i->second.low();a<=i->second.high();++a)
                if( i->second[a] ) {
                    ofs << a << ' ' << i->second[a] << ' ' ;
                    sum+=i->second[a];
                }
            ofs << endl;
        }
        ofs.close();
    }
    if(alpredict){
        ofstream ofs(alpredict);
        if(!ofs.is_open()){
            cerr << "Cannot open file for HMM output "  << alpredict << endl;
            return false;
        }
        cerr << "Dumping HMM table to " << alpredict << endl;
        for(typename map<AlDeps<CLS>,FlexArray<double> >::const_iterator i=alProbPredicted.begin();i!=alProbPredicted.end();++i){
            double sum=0.0;
            ofs << i->first.englishSentenceLength << " " 
                << i->first.classPrevious << " " 
                << i->first.previous << " "
                << i->first.j << " "
                << i->first.Cj <<" "
                << i->second.low() <<" "
                << i->second.high()<< " ";
            for(int a=i->second.low();a<=i->second.high();++a)
                if( i->second[a] ) {
                    ofs << a << ' ' << i->second[a] << ' ';
                    sum+=i->second[a];
                }
            ofs << endl;
        }
        ofs.close();
    }
    if(alpha){
        ofstream ofs(alpha);
        
        if(!ofs.is_open()){
            cerr << "Cannot open file for HMM output "  << alpha << endl;
            return false;
        }
        cerr << "Dumping HMM table to " << alpha << endl;
        for(typename hash_map<int,Array<double> >::const_iterator i=init_alpha.begin();
            i!=init_alpha.end();i++)
        {
            ofs << i->first << " " << i->second.size() <<" ";
            int j;
            for(j=0;j<i->second.size();j++){
                ofs << i->second[j] << " ";
            }
            ofs<<endl;
        }
        ofs.close();
    }
    if(beta){
        ofstream ofs(beta);
        if(!ofs.is_open()){
            cerr << "Cannot open file for HMM output "  << beta << endl;
            return false;
        }
        cerr << "Dumping HMM table to " << beta << endl;
        for(typename hash_map<int,Array<double> >::const_iterator i=init_beta.begin();
            i!=init_beta.end();i++)
        {
            ofs  << i->first << " " << i->second.size() << " ";
            int j;
            for(j=0;j<i->second.size();j++){
                ofs << i->second[j] << " ";
            }
            ofs << endl;
        }
        ofs.close();
    }
    return true;
}

template<class CLS,class MAPPERCLASSTOSTRING>
bool HMMTables<CLS,MAPPERCLASSTOSTRING>::readJumps(const char* alprob, const char* alpredict, const char* alpha, const char* beta){
    if(alprob){
        ifstream ifs(alprob);
        if(!ifs.is_open()){
            cerr << "Cannot open file for HMM input "  << alprob << endl;
            return false;
        }
        cerr << "Reading HMM table from " << alprob << endl;
        string strLine="";
        bool expect_data = false;
        while(!ifs.eof()){
            strLine = "";
            getline(ifs,strLine);
            if(strLine.length()){
                stringstream ss(strLine.c_str());
                AlDeps<CLS> dep;
                int low, high;
                ss >> dep.englishSentenceLength >> dep.classPrevious
                    >> dep.previous
                    >> dep.j >> dep.Cj >> low >> high;
                typename map<AlDeps<CLS>,FlexArray<double> >::iterator p=alProb.find(dep);
                if( p==alProb.end() ) {
                    p=alProb.insert(make_pair(dep,FlexArray<double> (low,high,0.0))).first;
                }
                int pos;
                double val;
                while(!ss.eof()){
                    pos = low-1;
                    val = 0;
                    ss >> pos >> val;
                    if(pos>low-1){
                        p->second[pos]+=val;
                    }
                }
            }
        }
    }
    if(alpredict){
        ifstream ifs(alpredict);
        if(!ifs.is_open()){
            cerr << "Cannot open file for HMM input "  << alpredict << endl;
            return false;
        }
        cerr << "Reading HMM table from " << alpredict << endl;
        string strLine="";
        bool expect_data = false;
        while(!ifs.eof()){
            strLine = "";
            getline(ifs,strLine);
            if(strLine.length()){
                stringstream ss(strLine.c_str());
                AlDeps<CLS> dep;
                int low, high;
                ss >> dep.englishSentenceLength >> dep.classPrevious
                    >> dep.previous
                    >> dep.j >> dep.Cj >> low >> high;
                typename map<AlDeps<CLS>,FlexArray<double> >::iterator p=alProbPredicted.find(dep);
                if( p==alProbPredicted.end() ) {
                    p=alProbPredicted.insert(make_pair(dep,FlexArray<double> (low,high,0.0))).first;
                }
                int pos;
                double val;

                while(!ss.eof()){
                    pos = low-1;
                    val = 0;
                    ss >> pos >> val;
                    if(pos>low-1){
                        p->second[pos]+=val;
                    }
                }
            }
        }
    }
    
    if(alpha){
        ifstream ifs(alpha);
        
        if(!ifs.is_open()){
            cerr << "Cannot open file for HMM input "  << alpha << endl;
            return false;
        }
        string strLine="";
        bool expect_data = false;
        while(!ifs.eof()){
            strLine = "";
            getline(ifs,strLine);
            if(strLine.length()){
                stringstream ss(strLine.c_str());
                int id = -1,size = -1;
                ss >> id >> size ;
                if(id<0||size<0||id!=size){
                    cerr << "Mismatch in alpha init table!" << endl;
                    return false;
                }
                Array<double>&alp = doGetAlphaInit(id);
                int j;
                double v;
                for(j=0;j<alp.size();j++){
                    ss >> v;
                    alp[j]+=v;
                }
            }
        }
    }
    
    if(beta){
        ifstream ifs(beta);
        
        if(!ifs.is_open()){
            cerr << "Cannot open file for HMM input "  << beta << endl;
            return false;
        }
        string strLine="";
        bool expect_data = false;
        while(!ifs.eof()){
            strLine = "";
            getline(ifs,strLine);
            if(strLine.length()){
                stringstream ss(strLine.c_str());
                int id = -1,size = -1;
                ss >> id >> size ;
                if(id<0||size<0||id!=size){
                    cerr << "Mismatch in alpha init table!" << endl;
                    return false;
                }
                Array<double>&bet = doGetBetaInit(id);
                int j;
                double v;
                for(j=0;j<bet.size();j++){
                    ss >> v;
                    bet[j]+=v;
                }
            }
        }
    }
    
    return true;
}

//////////////////////////////////////
// Constructs the tables object.
//   _probForEmpty - probabilityForEmpty is initialised to mfabs(_probForEmpty)
//                   (presumably the absolute value - confirm against mfabs);
//                   a negative argument additionally sets
//                   updateProbabilityForEmpty to true.
//   m1, m2        - their addresses (not copies) are stored in mapper1 and
//                   mapper2.
template<class CLS,class MAPPERCLASSTOSTRING>
HMMTables<CLS,MAPPERCLASSTOSTRING>::  HMMTables(double _probForEmpty,const MAPPERCLASSTOSTRING&m1,const MAPPERCLASSTOSTRING&m2): 
  probabilityForEmpty(mfabs(_probForEmpty)),
  updateProbabilityForEmpty(_probForEmpty<0.0),
  mapper1(&m1),
  mapper2(&m2)
{}
// Destructor with an empty body; no explicit clean-up is performed here.
template<class CLS,class MAPPERCLASSTOSTRING>
HMMTables<CLS,MAPPERCLASSTOSTRING>::~HMMTables() {}
