#!/usr/local/bin/perl

my(@probs);

# Command-line options and their defaults.
$set = "";         # scoring set name (-s)
$transfile = "";   # n-best file to score (-n)
$scorefile = "";   # score cache file (-c); derived from -n when left empty
$mode = "force";   # weight search strategy (-m): "force" or "step"
#$mode = "step";
$base = "/tmp";    # directory holding the per-set reference subdirectories

# Print the usage summary to STDERR and exit with a failure status.
sub usage {
    print STDERR "Usage: learnweights.pl -s scoreset -n nbest-file [-m mode] [-c scorefile]
   -s\tScoring set (mt03|h1|june02|dr07bot|dr07top|french|george)
   -n\tN-best output file
   -m\tWeight search mode (force|step), default force
   -c\tScore cache file, default $base/<refdir>/<nbest-name>.scores
";
    exit 1;
}

# Every option is a flag followed by exactly one value.
my %optionvar = (
    "-s" => \$set,
    "-n" => \$transfile,
    "-m" => \$mode,
    "-c" => \$scorefile,
);
for (my $argi = 0; $argi < @ARGV; $argi += 2) {
    my $target = $optionvar{$ARGV[$argi]};
    usage() unless defined $target;      # unknown flag or stray argument
    usage() unless $argi + 1 < @ARGV;    # flag given without a value
    $$target = $ARGV[$argi + 1];
}

if ($set eq "") {
    print STDERR "Must give set type\n";
    exit 1;
}
if ($transfile eq "") {
    print STDERR "Must give n-best file to use\n";
    exit 1;
}
if ($mode ne "force" && $mode ne "step") {
    print STDERR "Unknown mode $mode\n";
    usage();
}

print STDERR "Running scoring set $set on $transfile\n";

# Scoring-set configuration:
#   set name => [reference SGML file, per-sentence reference subdirectory of $base, source language]
my %setconfig = (
    "h1"      => ["/usr0/eepeter/Hebrew/e1.ref.sgm.new", "h1refs", "Hebrew"],
    "mt03"    => ["/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/Chinese/mt03_chinese_evlset_v0-ref.sgm", "mt03refs", "Chinese"],
    "june02"  => ["/usr0/rreynold/eval/TidesEval-Chinese-June2002-Ref.sgm", "june02refs", "Chinese"],
    "dr07bot" => ["/afs/cs.cmu.edu/project/gale-1/GALE/MEMT/Translations/dryrun-chinese07-text/refs/ref-dryrun-chinese07-text-bot.sgm", "dr07botrefs", "Chinese"],
    "dr07top" => ["/afs/cs.cmu.edu/project/gale-1/GALE/MEMT/Translations/dryrun-chinese07-text/refs/ref-dryrun-chinese07-text-top.sgm", "dr07toprefs", "Chinese"],
    "french"  => ["/usr0/ghannema/results/eval/ref-001002.sgm", "french", "French"], # On Chicago
    "george"  => ["/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/Korean/k.txt", "george", "Korean"],
);
die "Unknown set $set.  Exiting..." unless exists $setconfig{$set};
($srcref, $refdir, $srclang) = @{$setconfig{$set}};

# Build the per-sentence reference files on first use.  Checking for the
# first reference file, rather than only the directory, means a run that
# died before writing them does not leave an empty directory that is never
# filled on later runs.
if (!-e "$base/$refdir/refs0.txt") {
    if (!-d "$base/$refdir") {
        mkdir("$base/$refdir") or die "Cannot create $base/$refdir: $!";
    }
    print STDERR "Creating reference files for each sentence\n";
    createSentRefs();
}

# Default score cache: <base>/<refdir>/<last path component of the n-best file>.scores
if ($scorefile eq "") {
    ($transbase) = ($transfile =~ m{([^/]+)$});
    $scorefile = "$base/$refdir/$transbase.scores";
}


#$transfile = "/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/Hebrew/h1-erik-nbest.txt";
#$transfile = "/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/Hebrew/h1-erik-nbest-may04.txt.gz";
#$scorefile = "$refdir/h1-scores.txt";

$meteordir = "/afs/cs.cmu.edu/project/avenue-1/Avenue/Transfer/weights/meteor/meteor_version_0.4.3";

$totalsents = 0;

srand(time() ^ ($$ + ($$ << 15)));

$| = 1;

# Weight-search strategy chosen by -m.  It is resolved before the (slow)
# scoring pass so an unknown mode fails immediately instead of scoring
# everything and then silently doing nothing.
my %learner = (
    "step"  => \&learnWeightsStep,
    "force" => \&learnWeightBruteForce,
);
my $learn = $learner{$mode} or die "Unknown mode $mode (expected step or force)\n";

if (-s $scorefile) {
    # Read in translations, scores and traces from the cache
    print STDERR "Reading in scores from $scorefile\n";
    print "Total sents: ", readScores($scorefile), "\n";
} else {
    # Need to score sentences (also writes the cache)
    print STDERR "Scoring each translation and writing scores to $scorefile\n";
    scoreAll($transfile);
}

$learn->();


# Try brute force
sub learnWeightBruteForce {
    # Grid-search the fragmentation and length weights.
    #
    # For every (fragweight, lenweight) pair on a grid running from 0 down to
    # $lowest in steps of $step (defaults: 0.1 and -10), pick each sentence's
    # highest-scoring candidate under
    #   probweight * prob/tgtlen + ruleweight * rule + fragweight * fragpen + lenweight * lenpen
    # (probweight fixed at 1, ruleweight at 0; the first candidate wins ties)
    # and sum the METEOR scores of the picks.  Prints each new best pair and,
    # at the end, the best pair found.
    #
    # Optional args: ($step, $lowest).
    # Reads @probs, @tgtlens, @rulescores, @fragpens, @lenpens, @meteorscores.
    my($step, $lowest) = @_;
    $step   = 0.1 unless defined $step;
    $lowest = -10 unless defined $lowest;

    my($probweight) = 1;
    my($ruleweight) = 0;
    my($bestfrag, $bestlen);
    my($maxmeteor) = -1;

    # Walk the grid by integer index so the weights do not accumulate
    # floating-point error from repeated subtraction.
    my $nsteps = int(-$lowest / $step + 1e-9);

    for my $fi (0 .. $nsteps) {
        my $fragweight = $fi ? -$fi * $step : 0;
        for my $li (0 .. $nsteps) {
            my $lenweight = $li ? -$li * $step : 0;

            # Calculate cumulative meteor score of this weight pair's picks
            my $thismeteor = 0;
            for my $i (0 .. $#probs) {
                my $row = $probs[$i] || [];
                my($maxscore, $maxindex);
                for my $j (0 .. $#$row) {
                    next if $i == 0 and $j == 0; # For some reason, $meteorscores[0][0] not valid
                    # A zero/missing target length would divide by zero; treat it as 1.
                    my $thisscore = ($probweight * ($probs[$i][$j] / ($tgtlens[$i][$j] || 1))) +
                        ($ruleweight * $rulescores[$i][$j]) +
                        ($fragweight * $fragpens[$i][$j]) +
                        ($lenweight * $lenpens[$i][$j]);
                    if (!defined $maxscore || $thisscore > $maxscore) {
                        $maxscore = $thisscore;
                        $maxindex = $j;
                    }
                }
                # A sentence with no usable candidate contributes nothing.
                $thismeteor += $meteorscores[$i][$maxindex] if defined $maxindex;
            }

            if ($thismeteor > $maxmeteor) {
                $maxmeteor = $thismeteor;
                $bestfrag = $fragweight;
                $bestlen = $lenweight;
                print "Best so far $maxmeteor: frag $bestfrag, len $bestlen\n";
            }
        }
    }

    print "Best frag weight $bestfrag, Best len weight $bestlen (M $maxmeteor)\n";
}


sub learnWeightsStep {
    # Hill-climb the fragmentation and length weights from a random start.
    #
    # The "oracle" pick for each sentence is its highest-METEOR candidate.
    # Each iteration:
    #   1. Picks every sentence's best candidate under the current weights, using
    #        probweight * prob/tgtlen + ruleweight * rule + fragweight * fragpen + lenweight * lenpen
    #      (probweight 1, ruleweight 0; the first candidate wins ties).
    #   2. Wherever that pick has a lower METEOR score than the oracle, votes
    #      +/-1 toward the oracle's frag/len penalty values ("Method 3").
    #   3. Moves each weight by (votes / number of sentences).
    # It stops when the summed METEOR of the picks changes by no more than
    # $threshold across an update, or after $maxiter iterations.
    #
    # Reads @probs, @tgtlens, @rulescores, @fragpens, @lenpens, @meteorscores.
    # Prints progress and the final weights.
    my($threshold) = 0.001;
    my($maxiter)   = 1000;   # the vote-based update can oscillate; bound the loop

    my($probweight) = 1;
    my($ruleweight) = 0;

    # Start fragweight, lenweight at random values within a certain range
    my($fragweight) = -rand(6);
    my($lenweight)  = -rand(6);
    #$fragweight = -2.9; $lenweight = -4.8;

    my($totalsents) = scalar(@probs);
    if ($totalsents == 0) {
        # Nothing to learn from; the weight update below would divide by zero.
        print "No sentences to learn weights from\n";
        return;
    }

    # Candidate of sentence $s ranked highest by the *current* weights, or
    # undef when the sentence has no usable candidate.  This closes over
    # $fragweight/$lenweight, so it sees each update.
    my $pick = sub {
        my($s) = @_;
        my $row = $probs[$s] || [];
        my($best, $bestscore);
        for my $c (0 .. $#$row) {
            next if $s == 0 and $c == 0; # For some reason, $meteorscores[0][0] not valid
            # A zero/missing target length would divide by zero; treat it as 1.
            my $score = ($probweight * ($probs[$s][$c] / ($tgtlens[$s][$c] || 1))) +
                ($ruleweight * $rulescores[$s][$c]) +
                ($fragweight * $fragpens[$s][$c]) +
                ($lenweight * $lenpens[$s][$c]);
            if (!defined $bestscore || $score > $bestscore) {
                $bestscore = $score;
                $best = $c;
            }
        }
        return $best;
    };

    # Oracle: highest-METEOR candidate for each sentence (undef if none).
    my(@maxindexes);
    for my $i (0 .. $totalsents - 1) {
        my $row = $meteorscores[$i] || [];
        my($best, $bestmeteor);
        for my $j (0 .. $#$row) {
            next if $i == 0 and $j == 0; # For some reason, $meteorscores[0][0] not valid
            if (!defined $bestmeteor || $row->[$j] > $bestmeteor) {
                $bestmeteor = $row->[$j];
                $best = $j;
            }
        }
        push @maxindexes, $best;
    }

    my($delta) = 100;
    my($iter) = 0;
    my($totalscore, $newtotalscore);
    # Adjust weights, recalculate and see how it affects
    while ($delta > $threshold) {
        if (++$iter > $maxiter) {
            print "No convergence after $maxiter iterations; stopping\n";
            last;
        }

        # Score the current weights and collect direction votes.
        my($fragdirection, $lendirection) = (0, 0);
        $totalscore = 0;
        for my $i (0 .. $totalsents - 1) {
            my $chosen = $pick->($i);
            next unless defined $chosen;
            $totalscore += $meteorscores[$i][$chosen];

            my $oracle = $maxindexes[$i];
            next unless defined $oracle;
            next unless $meteorscores[$i][$chosen] < $meteorscores[$i][$oracle];

            # Method 3: one vote per sentence toward the oracle's penalty values.
            if ($fragpens[$i][$chosen] > $fragpens[$i][$oracle]) {
                $fragdirection--;
            } elsif ($fragpens[$i][$chosen] < $fragpens[$i][$oracle]) {
                $fragdirection++;
            }
            if ($lenpens[$i][$chosen] > $lenpens[$i][$oracle]) {
                $lendirection--;
            } elsif ($lenpens[$i][$chosen] < $lenpens[$i][$oracle]) {
                $lendirection++;
            }
        }

        $fragweight += ($fragdirection / $totalsents);
        $lenweight  += ($lendirection / $totalsents);

        # Re-score with the updated weights.
        $newtotalscore = 0;
        for my $i (0 .. $totalsents - 1) {
            my $chosen = $pick->($i);
            $newtotalscore += $meteorscores[$i][$chosen] if defined $chosen;
        }

        # Get delta with previous total, stop when delta is smaller than threshold
        $delta = abs($newtotalscore - $totalscore);

        print "Old score: $totalscore, New $newtotalscore, Frag $fragweight, Len $lenweight\n";
    }

    print "Best weights: Frag $fragweight, Length $lenweight\n";
}


sub scoreOverall {
    # Combined model score of one n-best candidate:
    #   probweight * prob/tgtlen + ruleweight * rule + fragweight * fragpen + lenweight * lenpen
    #
    # Args: ($prob, $rule, $fragpen, $lenpen, $srclen, $tgtlen,
    #        [$probweight, $ruleweight, $fragweight, $lenweight])
    # $srclen is accepted for call compatibility but not used.  Omitted
    # weights fall back to the package variables $probweight, $ruleweight,
    # $fragweight and $lenweight; undef counts as 0.
    # Returns the numeric score.
    my($prob, $rule, $fragpen, $lenpen, $srclen, $tgtlen,
       $pw, $rw, $fw, $lw) = @_;

    $pw = $probweight unless defined $pw;
    $rw = $ruleweight unless defined $rw;
    $fw = $fragweight unless defined $fw;
    $lw = $lenweight  unless defined $lw;

    # A zero/missing target length would divide by zero; treat it as 1.
    return (($pw || 0) * ($prob / ($tgtlen || 1))) +
        (($rw || 0) * $rule) +
        (($fw || 0) * $fragpen) +
        (($lw || 0) * $lenpen);
}

sub readScores {
    # Load a score cache (as written by scoreAll) into the global per-candidate
    # tables @probs, @rulescores, @fragpens, @lenpens, @srclens, @tgtlens and
    # @meteorscores, indexed [sentence][candidate].
    #
    # Each data line is tab-separated:
    #   sent  n  prob  rulescore  fragpen  lenpen  srclen  tgtlen  meteor  translation
    # Lines starting with '#' and blank lines are ignored.  Lines whose first two
    # fields are not non-negative integers are skipped with a warning, so they
    # cannot overwrite row 0.
    #
    # Returns the number of rows in @probs (highest sentence index + 1).
    my($scorefile) = shift;
    my($sentcount, $ncount, $prob, $rulescore, $fragpen, $lenpen, $srclen, $tgtlen, $meteor, $trans);

    open(my $fh, '<', $scorefile) or die "Cannot read $scorefile: $!";
    while (my $line = <$fh>) {
        next if $line =~ m/^\#/;
        next if $line =~ m/^\s*$/;
        chomp($line);
        # Limit 10 keeps any tabs inside the translation in the last field.
        ($sentcount, $ncount, $prob, $rulescore, $fragpen, $lenpen, $srclen, $tgtlen, $meteor, $trans) =
            split(/\t+/, $line, 10);

        unless (defined $ncount && $sentcount =~ m/^\d+$/ && $ncount =~ m/^\d+$/) {
            warn "Skipping malformed line $. in $scorefile\n";
            next;
        }

        $probs[$sentcount][$ncount] = $prob;
        $rulescores[$sentcount][$ncount] = $rulescore;
        $fragpens[$sentcount][$ncount] = $fragpen;
        $lenpens[$sentcount][$ncount] = $lenpen;
        $srclens[$sentcount][$ncount] = $srclen;
        $tgtlens[$sentcount][$ncount] = $tgtlen;
        $meteorscores[$sentcount][$ncount] = $meteor;
    }
    close($fh);

    return scalar(@probs);
}



sub scoreAll {
    # Parse the n-best file, score every candidate with METEOR, and record the
    # results.
    #
    # Features are stored in the global tables @probs, @rulescores, @fragpens,
    # @lenpens, @srclens, @tgtlens and @meteorscores, indexed
    # [sentence][candidate].  The same rows are written to the score cache
    # $scorefile in the tab-separated layout readScores reads.  A file ending
    # in .gz is decompressed with gunzip.
    #
    # Each entry is read as:
    #   - an optional line starting with "SrcSent" (case-insensitive);
    #   - a "<sent> <n>\t<translation>" line;
    #   - a feature line, e.g.
    #       Overall: -8.57367, Prob: -168.426, Rules: 2.69743, Frag: 0.291667, Length: 0.605446, Words: 12,24
    #   - trace lines up to a blank line (or EOF);
    #   - one more line, which is read and discarded.
    #
    # Returns the number of rows in @probs.
    my($transfile) = shift;
    my($line, $transline, $sentcount, $ncount, $trans, $scores, $meteor, $trace);
    my($prob, $rulescore, $fragpen, $lenpen, $srclen, $tgtlen);
    my($table) = "";
    my($nbest);

    # List-form pipe open: the file name never passes through a shell.
    if ($transfile =~ m/\.gz$/) {
        open($nbest, '-|', 'gunzip', '-c', $transfile) or die "Cannot run gunzip on $transfile: $!";
    } else {
        open($nbest, '<', $transfile) or die "Cannot read $transfile: $!";
    }

    while (!eof($nbest)) {
        # Skip blank lines before the entry; stop if only blank lines remained.
        while ($line = <$nbest>) {
            last if $line !~ m/^\s*$/;
        }
        last unless defined $line;

        if ($line =~ m/^SrcSent/i) {
            print STDERR ".";
            $transline = <$nbest>;
            last unless defined $transline;
        } else {
            $transline = $line;
        }
        chomp($transline);
        ($sentcount, $ncount, $trans) = ($transline =~ m/^(\d+) (\d+)\t(.*)$/);

        # A feature missing from the line keeps the previous candidate's value.
        $scores = <$nbest>;
        $scores = "" unless defined $scores;
        chomp($scores);
        $prob      = $1 if $scores =~ m/Prob: ([^,]+)/;
        $rulescore = $1 if $scores =~ m/Rules: ([^,]+)/;
        $fragpen   = $1 if $scores =~ m/Frag: ([^,]+)/;
        $lenpen    = $1 if $scores =~ m/Length: ([^,]+)/;
        ($srclen, $tgtlen) = ($1, $2) if $scores =~ m/Words: (\d+),(\d+)/;

        # Skip the trace (up to a blank line or EOF), then one more line.
        # NOTE(review): if entries are separated by a single blank line, that
        # extra read consumes the next entry's first line.  Confirm against the
        # n-best format.
        while (defined($trace = <$nbest>) && $trace !~ m/^\s*$/) { }
        $line = <$nbest>;

        # Without indices the row would land at [0][0]; drop it instead.
        if (!defined $sentcount) {
            warn "Skipping unparsable n-best line: $transline\n";
            next;
        }

        # Calculate a METEOR score for the sentence
        $meteor = scoreMETEOR($sentcount, $trans);

        $table .= "$sentcount\t$ncount\t$prob\t$rulescore\t$fragpen\t$lenpen\t$srclen\t$tgtlen\t$meteor\t$trans\n";

        $probs[$sentcount][$ncount] = $prob;
        $rulescores[$sentcount][$ncount] = $rulescore;
        $fragpens[$sentcount][$ncount] = $fragpen;
        $lenpens[$sentcount][$ncount] = $lenpen;
        $srclens[$sentcount][$ncount] = $srclen;
        $tgtlens[$sentcount][$ncount] = $tgtlen;
        $meteorscores[$sentcount][$ncount] = $meteor;
    }

    print STDERR "\n";
    close($nbest) or warn "Problem reading $transfile (status $?)\n";

    open(my $out, '>', $scorefile) or die "Cannot write $scorefile: $!";
    print $out "# $transfile\n";
    print $out $table;
    close($out) or die "Cannot write $scorefile: $!";
    print "End text\n";

    return scalar(@probs);
}



# Score a particular sentence from the n-best list
sub scoreMETEOR {
    # Score one candidate translation against its sentence's reference file.
    #
    # The reference file is $base/$refdir/refs<sentindex>.txt.  The METEOR
    # script is $meteordir/meteor.pl.  The hypothesis is written as a
    # one-segment SGML document to /tmp/hypsent.sgm, which is overwritten on
    # every call.  METEOR's stderr is discarded.
    #
    # Returns the value from METEOR's last "Score:" output line, or 0 when
    # it prints none.
    my($sentindex, $trans) = @_;
    my($sysid) = "xfer";
    my($hypfile) = "/tmp/hypsent.sgm";

    # Write out to a mini-document
    open(my $hyp, '>', $hypfile) or die "Cannot write $hypfile: $!";
    print $hyp "<tstset setid=\"xfer_sent$sentindex\" srclang=\"$srclang\" trglang=\"English\">\n";
    print $hyp "<DOC docid=\"SENT$sentindex\" sysid=\"$sysid\">\n";
    print $hyp "<seg id=0> $trans </seg>\n";
    print $hyp "</DOC>\n</tstset>\n";
    close($hyp) or die "Cannot write $hypfile: $!";

    # Score against reference file for just that sentence index.  The
    # command contains a redirection, so it is run through the shell.
    my($reference) = "$base/$refdir/refs$sentindex.txt";
    my($meteorcommand) = "perl -I$meteordir $meteordir/meteor.pl -s $sysid -r $reference -t $hypfile 2> /dev/null";

    my($meteor) = 0;
    open(my $pipe, '-|', $meteorcommand) or die "Cannot run METEOR: $!";
    while (my $line = <$pipe>) {
        $meteor = $1 if $line =~ m/^Score: (.*)$/;
    }
    close($pipe);

    return $meteor;
}


# Create appropriate ref file for each sent
sub createSentRefs {
    # Split the multi-reference SGML file $srcref into one reference set per
    # sentence, written to $base/$refdir/refs<N>.txt (N counts from 0).
    #
    # Each output file holds one <DOC> per reference sysid, containing that
    # sysid's N-th <seg> text.  The directory must already exist.  If the
    # sysids have different numbers of segments, the longest count is used and
    # missing segments are written empty.
    my(%refsets);      # sysid => [segment text, ...] in file order
    my($sysid, $line);
    my($sentcount) = 0;

    open(my $ref, '<', $srcref) or die "Cannot read $srcref: $!";
    while ($line = <$ref>) {
        $line =~ s/[\r\n]*$//;
        if ($line =~ m/sysid=\"(\w+)\"/) {
            $sysid = $1;
        }
        if ($line =~ m/<seg/) {
            # Strip the tags; accepts both <seg id=1> and <seg id="1">.
            $line =~ s/<seg[^>]*>\s*//;
            $line =~ s/\s*<\/seg>//;
            push @{$refsets{$sysid}}, $line;
        }
    }
    close($ref);

    # Use the longest reference set so no sentence is dropped.
    my(%lengths);
    foreach my $id (keys %refsets) {
        my $n = scalar(@{$refsets{$id}});
        $sentcount = $n if $n > $sentcount;
        $lengths{$n} = 1;
    }
    warn "Reference sets in $srcref have differing lengths\n" if keys(%lengths) > 1;

    for (my $sentnum = 0; $sentnum < $sentcount; $sentnum++) {
        my $outfile = "$base/$refdir/refs$sentnum.txt";
        open(my $out, '>', $outfile) or die "Cannot write $outfile: $!";
        print $out "<refset setid=\"XFER.ref\" srclang=\"$srclang\" trglang=\"English\">\n";
        foreach my $id (sort keys %refsets) {
            my $seg = $refsets{$id}[$sentnum];
            $seg = "" unless defined $seg;
            print $out "<DOC docid=\"SENT$sentnum\" sysid=\"$id\">\n";
            print $out "<seg id=0> " . $seg . " </seg>\n";
            print $out "</DOC>\n";
        }
        print $out "</refset>\n";
        close($out) or die "Cannot write $outfile: $!";
    }
}


# Given a set of weights, extract the best sentence from each n-best set
sub extractBest {
    # Given a set of weights, print the best translation from each n-best set.
    #
    # Args: ($probweight, $ruleweight, $fragweight, $lenweight).
    #
    # Candidates are ranked with the formula the weight learners use,
    #   probweight * prob/tgtlen + ruleweight * rule + fragweight * fragpen + lenweight * lenpen
    # so weights printed by learnWeightBruteForce / learnWeightsStep can be
    # passed in directly.  The first candidate wins ties.
    #
    # The winners are printed one per line, in the order their
    # "<sent> <n>\t<translation>" lines appear in $transfile (gunzipped first
    # if the name ends in .gz).
    #
    # Reads @probs, @tgtlens, @rulescores, @fragpens and @lenpens.
    my($probweight, $ruleweight, $fragweight, $lenweight) = @_;
    my(@maxindexes);

    for my $i (0 .. $#probs) {
        my $row = $probs[$i] || [];
        my($maxscore, $maxindex);
        for my $j (0 .. $#$row) {
            # A zero/missing target length would divide by zero; treat it as 1.
            my $thisscore = ($probweight * ($probs[$i][$j] / ($tgtlens[$i][$j] || 1))) +
                ($ruleweight * $rulescores[$i][$j]) +
                ($fragweight * $fragpens[$i][$j]) +
                ($lenweight * $lenpens[$i][$j]);
            if (!defined $maxscore || $thisscore > $maxscore) {
                $maxscore = $thisscore;
                $maxindex = $j;
            }
        }
        push @maxindexes, $maxindex;
    }

    my($nbest);
    if ($transfile =~ m/\.gz$/) {
        open($nbest, '-|', 'gunzip', '-c', $transfile) or die "Cannot run gunzip on $transfile: $!";
    } else {
        open($nbest, '<', $transfile) or die "Cannot read $transfile: $!";
    }
    while (my $line = <$nbest>) {
        chomp($line);
        next unless $line =~ m/^(\d+)\s+(\d+)\t(.*)$/;
        my($sentcount, $ncount, $trans) = ($1, $2, $3);
        my $best = $maxindexes[$sentcount];
        print "$trans\n" if defined $best && $best == $ncount;
    }
    close($nbest);
}


# Given a set of weights, extract the best sentence from each n-best set
sub extractBestMETEOR {
    # Print the highest-METEOR translation from each n-best set (an oracle
    # selection; no model weights are involved).
    #
    # The first candidate wins ties.  The winners are printed one per line, in
    # the order their "<sent> <n>\t<translation>" lines appear in $transfile
    # (gunzipped first if the name ends in .gz).
    #
    # Reads @meteorscores.
    my(@maxindexes);

    for my $i (0 .. $#meteorscores) {
        my $row = $meteorscores[$i] || [];
        my($maxscore, $maxindex);
        for my $j (0 .. $#$row) {
            if (!defined $maxscore || $row->[$j] > $maxscore) {
                $maxscore = $row->[$j];
                $maxindex = $j;
            }
        }
        push @maxindexes, $maxindex;
    }

    my($nbest);
    if ($transfile =~ m/\.gz$/) {
        open($nbest, '-|', 'gunzip', '-c', $transfile) or die "Cannot run gunzip on $transfile: $!";
    } else {
        open($nbest, '<', $transfile) or die "Cannot read $transfile: $!";
    }
    while (my $line = <$nbest>) {
        chomp($line);
        next unless $line =~ m/^(\d+)\s+(\d+)\t(.*)$/;
        my($sentcount, $ncount, $trans) = ($1, $2, $3);
        my $best = $maxindexes[$sentcount];
        print "$trans\n" if defined $best && $best == $ncount;
    }
    close($nbest);
}
