#!/usr/local/bin/perl

# Per-candidate model features read by learnWeightsStep and extractBest.
# NOTE(review): no code shown here fills @probs; readScores and scoreAll only
# populate @scores and @meteorscores.
my(@probs);

# Command-line settings and their defaults.
$set = "";
$transfile = "";
$scorefile = "";
$oraclefile = "";
$mode = "force";
#$mode = "step";
$base = "/tmp";

my $usage = "Usage: memtweights.pl -s scoreset -n nbest-dir [-m force|step] [-c scorefile] [-o oraclefile]
   -s\tScoring set (mt03|h1|june02|arabic05|french|george)
   -n\tDirectory holding the n-best files (oracle<N>.hyp)
   -m\tWeight-learning mode: force (default) or step
   -c\tScore cache file (default: $base/<refdir>/<n-best name>.scores)
   -o\tWrite the METEOR-oracle translations to this file and exit
";

# Every flag takes exactly one value; map each flag to the global it sets.
my %flagtarget = (
    "-s" => \$set,
    "-n" => \$transfile,
    "-m" => \$mode,
    "-c" => \$scorefile,
    "-o" => \$oraclefile,
);

for (my $argi = 0; $argi < @ARGV; $argi++) {
    my $target = $flagtarget{$ARGV[$argi]};
    # Unknown flag, or a flag given last with no value after it.
    if (!defined($target) || $argi + 1 >= @ARGV) {
        print STDERR $usage;
        exit 1;
    }
    $argi++;
    $$target = $ARGV[$argi];
}

if ($set eq "") {
    print STDERR "Must give set type\n";
    exit 1;
}
if ($transfile eq "") {
    print STDERR "Must give n-best file to use\n";
    exit 1;
}

print STDERR "Running scoring set $set on $transfile\n";

# Reference data for each scoring set:
#   [ multi-reference SGML file, subdirectory of $base, source language ]
my %setconfig = (
    "h1"       => ["/usr0/eepeter/Hebrew/e1.ref.sgm.new",
                   "h1refs", "Hebrew"],
    "mt03"     => ["/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/Chinese/mt03_chinese_evlset_v0-ref.sgm",
                   "mt03refs", "Chinese"],
    "june02"   => ["/usr0/rreynold/eval/TidesEval-Chinese-June2002-Ref.sgm",
                   "june02refs", "Chinese"],
    "arabic05" => ["/afs/cs/project/gale-1/GALE/MEMT/Translations/arabic05/refs/mt05_arabic_evlset_v0-ref.sgm",
                   "arabic05", "Arabic"],
    "french"   => ["/usr0/ghannema/results/eval/ref-001002.sgm", # On Chicago
                   "french", "French"],
    "george"   => ["/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/Korean/k.txt",
                   "george", "Korean"],
);

# Without a known set $refdir stays undefined and every later path collapses
# onto $base itself, so refuse to continue.
if (!exists $setconfig{$set}) {
    print STDERR "Unknown scoring set '$set'\n";
    exit 1;
}
($srcref, $refdir, $srclang) = @{$setconfig{$set}};

# Split the multi-reference file into one reference file per sentence the
# first time this set is used.
if (!-e "$base/$refdir") {
    mkdir("$base/$refdir") or die "Cannot create $base/$refdir: $!\n";
    print STDERR "Creating reference files for each sentence\n";
    createSentRefs();
}

# Default score cache: $base/$refdir/<last path component of -n>.scores
if ($scorefile eq "") {
    # Drop trailing slashes first so "dir/" still yields "dir".
    ($transbase = $transfile) =~ s{/+$}{};
    $transbase =~ s{.*/}{};
    $scorefile = "$base/$refdir/$transbase.scores";
}


#$transfile = "/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/Hebrew/h1-erik-nbest.txt";
#$transfile = "/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/Hebrew/h1-erik-nbest-may04.txt.gz";
#$scorefile = "$refdir/h1-scores.txt";

# Location of the METEOR scorer invoked by scoreMETEOR.
$meteordir = "/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/weights/meteor/meteor_version_0.4.3";

$totalsents = 0;

srand(time() ^ ($$ + ($$ << 15)));

# Autoflush the selected handle (STDOUT) so progress lines appear at once.
$| = 1;

# Oracle extraction ignores -m; otherwise reject an unknown mode now rather
# than after the (slow) METEOR scoring pass below.
my $wantoracle = defined($oraclefile) && $oraclefile ne "";
if (!$wantoracle && $mode ne "force" && $mode ne "step") {
    print STDERR "Unknown mode '$mode' (expected force or step)\n";
    exit 1;
}

if (-s $scorefile) {
    # Reuse the cached feature and METEOR scores.
    print STDERR "Reading in scores from $scorefile\n";
    print "Total sents: ", readScores($scorefile), "\n";
} else {
    # No cache yet: score every candidate with METEOR and write the cache.
    print STDERR "Scoring each translation and writing scores to $scorefile\n";
    scoreAll($transfile);
}

if ($wantoracle) {
    extractBestMETEOR($transfile);
    exit;
}

if ($mode eq "step") {
    learnWeightsStep();
} else {
    learnWeightBruteForce();
}


# Try brute force: exhaustive grid search over the three feature weights.
#
# Each of the three weights takes every value $minweight, $minweight+$step,
# ..., $maxweight.  For each weight vector, every sentence in @scores picks
# the candidate with the highest weighted feature sum
# (sum_k weight[k] * $scores[sent][cand][k]); the METEOR scores of those picks
# (from @meteorscores) are added up.  Each time a vector beats the best total
# so far it is printed to STDOUT and STDERR.  At the end the best vector is
# printed again and left in the global $maxweights as a space-separated
# string.
#
# Optional arguments (the defaults reproduce the original fixed grid):
#   $step       grid spacing, must be > 0   (default 0.1)
#   $minweight  smallest weight tried       (default 0)
#   $maxweight  largest weight tried        (default 5)
sub learnWeightBruteForce {
    my ($step, $minweight, $maxweight) = @_;
    $step      = 0.1 unless defined $step;
    $minweight = 0   unless defined $minweight;
    $maxweight = 5   unless defined $maxweight;
    die "learnWeightBruteForce: step must be positive\n" unless $step > 0;

    my $numweights = 3;

    # Walk the grid with integer positions and derive each weight from its
    # position.  Adding $step repeatedly accumulates rounding error (fifty
    # additions of 0.1 land just below 5), which would add a bogus extra
    # grid point just past $maxweight.
    my $laststep = int(($maxweight - $minweight) / $step + 0.5);
    my @gridpos  = (0) x $numweights;

    my $maxmeteor = -1;

    while (1) {
        my @weights = map { $minweight + $_ * $step } @gridpos;

        # Pick each sentence's best candidate under these weights.  Starting
        # from "no best yet" (instead of a fixed floor) guarantees a pick
        # whenever the sentence has any candidate.
        my @maxindexes;
        for my $i (0 .. $#scores) {
            my $cands = $scores[$i] || [];
            my ($maxscore, $maxindex);
            for my $j (0 .. $#$cands) {
                my $feats = $cands->[$j] || [];
                my $thisscore = 0;
                for my $k (0 .. $#weights) {
                    $thisscore += $weights[$k] * ($feats->[$k] || 0);
                }
                if (!defined($maxscore) || $thisscore > $maxscore) {
                    $maxscore = $thisscore;
                    $maxindex = $j;
                }
            }
            push @maxindexes, $maxindex;
        }

        # Cumulative METEOR score of the picks.
        my $thismeteor = 0;
        for my $i (0 .. $#meteorscores) {
            next unless defined $maxindexes[$i];
            $thismeteor += $meteorscores[$i][$maxindexes[$i]];
        }

        if ($thismeteor > $maxmeteor) {
            $maxmeteor  = $thismeteor;
            $maxweights = join(" ", @weights);

            print STDOUT "Best so far $maxmeteor: $maxweights\n";
            print STDERR "Best so far $maxmeteor: $maxweights\n";
        }

        # Advance to the next grid point like an odometer (weight 0 turns
        # fastest); stop after every position has wrapped around.
        my $pos = 0;
        while ($pos < $numweights && $gridpos[$pos] >= $laststep) {
            $gridpos[$pos] = 0;
            $pos++;
        }
        last if $pos == $numweights;
        $gridpos[$pos]++;
    }

    print "Best weights $maxweights (M $maxmeteor)\n";
    print STDERR "Best weights $maxweights (M $maxmeteor)\n";
}


# Hill-climbing search over the fragmentation and length-penalty weights.
#
# Reads the per-candidate arrays @probs, @tgtlens, @rulescores, @fragpens,
# @lenpens and @meteorscores, all indexed [sentence][candidate].
# NOTE(review): no code shown here fills @probs, @tgtlens, @rulescores,
# @fragpens or @lenpens (readScores/scoreAll only fill @scores and
# @meteorscores).  With @probs empty this mode used to die with "Illegal
# division by zero"; it now dies with an explicit message instead.
#
#   1. For each sentence, note the candidate with the highest METEOR score.
#   2. Repeatedly pick each sentence's candidate under the current weights.
#      Wherever that pick is worse (by METEOR) than the best one, vote on
#      whether the frag/len weights should move up or down, then shift each
#      weight by its average vote.
#   3. Stop when the summed METEOR before and after a shift differs by no
#      more than $threshold.
# Prints progress each round and the final weights.  Assumes every
# $tgtlens[i][j] is non-zero.
sub learnWeightsStep {
    my $threshold = 0.001;

    my $probweight = 1;
    my $ruleweight = 0;

    # Start fragweight, lenweight at random values in (-6, 0].
    my $fragweight = -rand(6);
    my $lenweight  = -rand(6);
    #$fragweight = -2.9; $lenweight = -4.8;

    my $totalsents = scalar(@probs);
    if ($totalsents == 0) {
        die "learnWeightsStep: no per-candidate features loaded (\@probs is empty)\n";
    }

    # Index of the best candidate for sentence $i under the current weights,
    # or undef when the sentence has no usable candidate.
    my $pickbest = sub {
        my ($i) = @_;
        my ($maxscore, $maxindex);
        my $cands = $probs[$i] || [];
        for my $j (0 .. $#$cands) {
            next if $i == 0 and $j == 0; # For some reason, $meteorscores[0][0] not valid
            my $thisscore = ($probweight * ($probs[$i][$j] / $tgtlens[$i][$j])) +
                ($ruleweight * $rulescores[$i][$j]) +
                ($fragweight * $fragpens[$i][$j]) +
                ($lenweight  * $lenpens[$i][$j]);
            if (!defined($maxscore) || $thisscore > $maxscore) {
                $maxscore = $thisscore;
                $maxindex = $j;
            }
        }
        return $maxindex;
    };

    # Best-METEOR (oracle) candidate for each sentence.
    my @oracle;
    for my $i (0 .. $totalsents - 1) {
        my ($maxscore, $maxindex);
        my $cands = $meteorscores[$i] || [];
        for my $j (0 .. $#$cands) {
            next if $i == 0 and $j == 0; # For some reason, $meteorscores[0][0] not valid
            if (!defined($maxscore) || $cands->[$j] > $maxscore) {
                $maxscore = $cands->[$j];
                $maxindex = $j;
            }
        }
        push @oracle, $maxindex;
    }

    my $delta = 100;
    # Adjust weights, recalculate and see how it affects the total.
    while ($delta > $threshold) {
        my ($fragdirection, $lendirection, $totalscore) = (0, 0, 0);

        for my $i (0 .. $totalsents - 1) {
            my $pick = $pickbest->($i);
            next unless defined $pick;
            $totalscore += $meteorscores[$i][$pick];

            my $best = $oracle[$i];
            next unless defined $best;
            next unless $meteorscores[$i][$pick] < $meteorscores[$i][$best];

            # Method 3: one vote per sentence per weight.  The vote is -1 when
            # the pick's penalty exceeds the oracle's, +1 when it is lower.
            $fragdirection += ($fragpens[$i][$best] <=> $fragpens[$i][$pick]);
            $lendirection  += ($lenpens[$i][$best]  <=> $lenpens[$i][$pick]);

            # Method 2 (previously tried, unused): magnitude-weighted votes
            #   $fragdirection += ($fragpens[$i][$best] - $fragpens[$i][$pick]);
            #   $lendirection  += ($lenpens[$i][$best]  - $lenpens[$i][$pick]);
        }

        # Move each weight by its average vote.  (An earlier, unused variant
        # moved each weight by a fixed 0.1 step in the vote's direction,
        # keeping the weight negative.)
        $fragweight += ($fragdirection / $totalsents);
        $lenweight  += ($lendirection / $totalsents);

        my $newtotalscore = 0;
        for my $i (0 .. $totalsents - 1) {
            my $pick = $pickbest->($i);
            $newtotalscore += $meteorscores[$i][$pick] if defined $pick;
        }

        # Stop when the shift no longer changes the summed METEOR by more
        # than $threshold.
        $delta = abs($newtotalscore - $totalscore);

        print "Old score: $totalscore, New $newtotalscore, Frag $fragweight, Len $lenweight\n";
    }

    print "Best weights: Frag $fragweight, Length $lenweight\n";
}


# Combined model score of one n-best candidate:
#   probweight * (prob / tgtlen) + ruleweight * rule
#     + fragweight * fragpen + lenweight * lenpen
#
# Arguments: ($prob, $rule, $fragpen, $lenpen, $srclen, $tgtlen
#             [, $probweight, $ruleweight, $fragweight, $lenweight])
# $srclen is accepted for interface compatibility but not used.  Weights not
# passed fall back to the globals of the same name (unset globals act as 0).
# $tgtlen must be non-zero.  Returns the score.
sub scoreOverall {
    my ($prob, $rule, $fragpen, $lenpen, $srclen, $tgtlen,
        $pweight, $rweight, $fweight, $lweight) = @_;

    $pweight = $probweight unless defined $pweight;
    $rweight = $ruleweight unless defined $rweight;
    $fweight = $fragweight unless defined $fweight;
    $lweight = $lenweight  unless defined $lweight;

    my $thisscore = ($pweight * ($prob / $tgtlen)) +
        ($rweight * $rule) +
        ($fweight * $fragpen) +
        ($lweight * $lenpen);

    return $thisscore;
}

# Load a score table previously written by scoreAll.
#
# Each line that is neither blank nor starts with '#' is tab-separated:
#   sentence-index  candidate-index  feature_1 ... feature_k  METEOR
# Lines with fewer than four fields are ignored.  Features go into
# $scores[sentence][candidate][k] and the last column into
# $meteorscores[sentence][candidate].
#
# Returns the number of sentence slots in @scores (highest index + 1).
# Dies if the file cannot be opened.
sub readScores {
    my ($scorefile) = @_;

    open(my $fh, '<', $scorefile) or die "Cannot read $scorefile: $!\n";
    while (my $line = <$fh>) {
        next if $line =~ m/^\#/;
        next if $line =~ m/^\s*$/;
        chomp($line);

        my @fields = split(/\t+/, $line);
        next if @fields < 4;
        my ($sentcount, $ncount) = @fields[0, 1];

        # Everything between the two indices and the final METEOR column is
        # a feature value.
        for my $k (2 .. $#fields - 1) {
            $scores[$sentcount][$ncount][$k - 2] = $fields[$k];
        }
        $meteorscores[$sentcount][$ncount] = $fields[$#fields];
    }
    close($fh);

    return scalar(@scores);
}

# sort comparator for n-best file names of the form .../<name><N>.hyp.
# Orders by the numeric index N.  Names without such an index sort after
# all indexed names (alphabetically among themselves).  Equal indices fall
# back to the full name, so the resulting order is always deterministic.
sub byindex {
    # List-context match yields undef on failure instead of a stale $1.
    my ($aindex) = ($a =~ m/(\d+)\.hyp$/);
    my ($bindex) = ($b =~ m/(\d+)\.hyp$/);

    return (defined($aindex) ? 0 : 1) <=> (defined($bindex) ? 0 : 1)
        || (defined($aindex) && $aindex <=> $bindex)
        || $a cmp $b;
}


# Score every candidate translation in $transdir/oracle<N>.hyp with METEOR.
#
# Each n-best line is "<space-separated feature scores><two whitespace
# chars><translation>".  Every feature score is divided by the translation's
# word count and stored in $scores[N][candidate][k].  The METEOR score from
# scoreMETEOR goes to $meteorscores[N][candidate].  The same data is written
# as a tab-separated table (the format readScores reads) to the global
# $scorefile.  Files whose names carry no index N are skipped with a warning.
sub scoreAll {
    my ($transdir) = @_;
    my $table = "";

    my @transfiles = <$transdir/oracle*.hyp>;
    @transfiles = sort byindex @transfiles;

    $maxmeteor = -1;

    foreach my $file (@transfiles) {
        print "Sentence file $file\n";
        my ($sentcount) = ($file =~ m/(\d+)\.hyp$/);
        if (!defined $sentcount) {
            print STDERR "Skipping $file: no sentence index in its name\n";
            next;
        }
        my $ncount = 0;

        open(my $nbest, '<', $file) or die "Cannot read $file: $!\n";
        while (my $transline = <$nbest>) {
            chomp($transline);

            # Limit 2: a double space inside the translation stays part of it
            # instead of truncating the translation.
            my ($scorefield, $trans) = split(/\s\s/, $transline, 2);
            $scorefield = "" unless defined $scorefield;
            $trans      = "" unless defined $trans;
            my @trscores = split(/\s/, $scorefield);

            # split ' ' skips leading whitespace, so only real words count.
            my @tgtwords = split(' ', $trans);
            my $tgtlen = scalar(@tgtwords);
            # An empty translation would divide by zero; keep raw scores then.
            my $norm = $tgtlen || 1;

            # Calculate a METEOR score for the sentence
            my $meteor = scoreMETEOR($sentcount, $trans);

            my $tableline = "$sentcount\t$ncount\t";
            for my $k (0 .. $#trscores) {
                my $feature = $trscores[$k] / $norm;
                $tableline .= "$feature\t";
                $scores[$sentcount][$ncount][$k] = $feature;
            }
            $table .= "$tableline$meteor\n";

            $meteorscores[$sentcount][$ncount] = $meteor;
            $ncount++;
        }
        close($nbest);
    }

    print STDERR "\n";

    open(my $out, '>', $scorefile) or die "Cannot write $scorefile: $!\n";
    print $out $table;
    print "End text\n";
    close($out) or die "Error writing $scorefile: $!\n";
}



# Score one candidate translation with METEOR.
#
# The candidate is wrapped in a one-segment SGML test set written to
# /tmp/hypsent.sgm.  Then $meteordir/meteor.pl is run against the
# per-sentence reference file $base/$refdir/refs<sentindex>.txt (created by
# createSentRefs), with its stderr discarded.  Returns the value from the last
# "Score: ..." line METEOR prints, or 0 if none appears.
# NOTE(review): the fixed hypothesis path means two concurrent runs on the
# same host overwrite each other's file.
sub scoreMETEOR {
    my ($sentindex, $trans) = @_;
    my $sysid   = "xfer";
    my $hypfile = "/tmp/hypsent.sgm";

    # Write out to a mini-document
    open(my $hyp, '>', $hypfile) or die "Cannot write $hypfile: $!\n";
    print $hyp "<tstset setid=\"xfer_sent$sentindex\" srclang=\"$srclang\" trglang=\"English\">\n";
    print $hyp "<DOC docid=\"SENT$sentindex\" sysid=\"$sysid\">\n";
    print $hyp "<seg id=0> $trans </seg>\n";
    print $hyp "</DOC>\n</tstset>\n";
    close($hyp) or die "Error writing $hypfile: $!\n";

    # Score against reference file for just that sentence index
    my $reference = "$base/$refdir/refs$sentindex.txt";
    my $meteorcommand = "perl -I$meteordir $meteordir/meteor.pl -s $sysid -r $reference -t $hypfile 2> /dev/null";

    my $meteor = 0;
    open(my $pipe, '-|', $meteorcommand) or die "Cannot run METEOR: $!\n";
    while (my $line = <$pipe>) {
        $meteor = $1 if $line =~ m/^Score: (.*)$/;
    }
    close($pipe);

    return $meteor;
}


# Split the multi-reference SGML file $srcref into one file per sentence,
# $base/$refdir/refs<N>.txt.  Each file holds sentence N from every
# reference set (one <DOC> per sysid, in sorted sysid order).
#
# Reference text is taken from <seg ...>...</seg> lines and grouped under
# the most recent sysid="..." seen.  The sentence count is the size of the
# largest reference set; a set that is missing a sentence contributes an
# empty segment for it.
sub createSentRefs {
    my %refsets;        # sysid => [ reference sentence, ... ]
    my $sysid = "";

    open(my $ref, '<', $srcref) or die "Cannot read $srcref: $!\n";
    while (my $line = <$ref>) {
        $line =~ s/[\r\n]*$//;
        # Accept any quoted sysid (e.g. "ref-1"), not only word characters.
        if ($line =~ m/sysid=\"([^\"]+)\"/) {
            $sysid = $1;
        }
        if ($line =~ m/<seg/) {
            # Strip the whole opening tag, whether the id is quoted or not.
            $line =~ s/<seg[^>]*>\s*//;
            $line =~ s/\s*<\/seg>//;
            push @{$refsets{$sysid}}, $line;
        }
    }
    close($ref);

    my $sentcount = 0;
    foreach my $set (values %refsets) {
        $sentcount = scalar(@$set) if @$set > $sentcount;
    }

    for my $sentnum (0 .. $sentcount - 1) {
        my $outfile = "$base/$refdir/refs$sentnum.txt";
        open(my $out, '>', $outfile) or die "Cannot write $outfile: $!\n";
        print $out "<refset setid=\"XFER.ref\" srclang=\"$srclang\" trglang=\"English\">\n";
        foreach my $id (sort keys %refsets) {
            my $text = $refsets{$id}[$sentnum];
            $text = "" unless defined $text;
            print $out "<DOC docid=\"SENT$sentnum\" sysid=\"$id\">\n";
            print $out "<seg id=0> " . $text . " </seg>\n";
            print $out "</DOC>\n";
        }
        print $out "</refset>\n";
        close($out) or die "Error writing $outfile: $!\n";
    }
}


# Given a set of weights, extract the best sentence from each n-best set.
#
# Arguments: ($probweight, $ruleweight, $fragweight, $lenweight).
# Each candidate's score is the weighted sum of @probs, @rulescores,
# @fragpens and @lenpens (all indexed [sentence][candidate]).  The highest
# scoring candidate per sentence is looked up in the global $transfile,
# whose lines are expected as "<sentence> <candidate>\t<translation>".  Each
# selected translation is printed on its own line to the currently selected
# output handle.
sub extractBest {
    my ($probweight, $ruleweight, $fragweight, $lenweight) = @_;
    my $totalsents = scalar(@fragpens);

    # Best candidate per sentence.  Starting from "no best yet" rather than a
    # floor of -1 keeps sentences whose scores are all negative.
    my @maxindexes;
    for my $i (0 .. $totalsents - 1) {
        my ($maxscore, $maxindex);
        my $cands = $probs[$i] || [];
        for my $j (0 .. $#$cands) {
            my $thisscore = ($probweight * $probs[$i][$j]) +
                ($ruleweight * $rulescores[$i][$j]) +
                ($fragweight * $fragpens[$i][$j]) +
                ($lenweight  * $lenpens[$i][$j]);
            if (!defined($maxscore) || $thisscore > $maxscore) {
                $maxscore = $thisscore;
                $maxindex = $j;
            }
        }
        push @maxindexes, $maxindex;
    }

    open(my $nbest, '<', $transfile) or die "Cannot read $transfile: $!\n";
    while (my $line = <$nbest>) {
        chomp($line);
        next unless $line =~ m/^(\d+)\s+(\d+)\t(.*)$/;
        my ($sentcount, $ncount, $trans) = ($1, $2, $3);
        # An undefined entry would compare equal to candidate 0; skip it.
        if (defined $maxindexes[$sentcount] && $maxindexes[$sentcount] == $ncount) {
            print "$trans\n";
        }
    }
    close($nbest);
}


# Write the METEOR oracle.  For each sentence, take the n-best candidate
# with the highest score in @meteorscores (ties keep the earliest candidate).
# Read its text from $transdir/oracle<N>.hyp (the part after the first two
# whitespace characters) and write it on its own line to the global
# $oraclefile.
sub extractBestMETEOR {
    my ($transdir) = @_;
    my $totalsents = scalar(@scores);
    print "Total sents $totalsents\n";

    my @maxindexes;
    for my $i (0 .. $totalsents - 1) {
        my $maxscore = -1;
        my $maxindex = -1;
        my $cands = $scores[$i] || [];
        for my $j (0 .. $#$cands) {
            if ($meteorscores[$i][$j] > $maxscore) {
                # Track the best METEOR value itself so later candidates must
                # actually beat it.
                $maxscore = $meteorscores[$i][$j];
                $maxindex = $j;
            }
        }
        push @maxindexes, $maxindex;
    }

    my @transfiles = <$transdir/oracle*.hyp>;
    @transfiles = sort byindex @transfiles;

    $maxmeteor = -1;

    open(my $oracle, '>', $oraclefile) or die "Cannot write $oraclefile: $!\n";
    foreach my $file (@transfiles) {
        my ($sentcount) = ($file =~ m/(\d+)\.hyp$/);
        next unless defined $sentcount;
        my $ncount = 0;

        open(my $nbest, '<', $file) or die "Cannot read $file: $!\n";
        while (my $transline = <$nbest>) {
            chomp($transline);
            # Same line layout scoreAll reads; limit 2 keeps the whole translation.
            my (undef, $trans) = split(/\s\s/, $transline, 2);
            # An undefined entry would compare equal to candidate 0; skip it.
            if (defined $maxindexes[$sentcount] && $maxindexes[$sentcount] == $ncount) {
                $trans = "" unless defined $trans;
                print $oracle "$trans\n";
            }
            $ncount++;
        }
        close($nbest);
    }
    close($oracle) or die "Error writing $oraclefile: $!\n";
}
