#!/usr/local/bin/perl

## Minimum Bayes Risk (MBR) decoding of an N-best list

use strict;
use warnings;
use List::Util qw(max sum);

# Sentence-level BLEU scorer and the scratch directory that holds the
# SGML-wrapped segments handed to it.
my $bleu = '/afs/cs/project/mteval-1/code/bleu/bleu-1.04.pl';
my $tmp  = '/tmp/mbr-ur';
my $mode = 'expb';    # NOTE(review): not read anywhere in this file

# Log-linear feature weights: a hypothesis' unnormalised log score is the
# dot product of these with its feature scores (see the reading loop).
# Earlier settings, kept for reference:
#   (1, .1, 1, 1, 1.5, 1)
#   (1.413050, 1.000000, 0.085163, 0.535632, 1.998422, 9.091130)
#   (.993243, .030911, 1, 0.002366, 3.405969, 9.925679)
#   (.998881, 1.664317, .357937, 0.009921, 4.234457, 8.555681)   # chinese
my @wts = (
	0.497464, 0.069569, 0.318774,
	0.231283, 4.586331, 9.706123,
);

# The N-best list currently being collected: hypothesis strings and their
# scores (computePost turns the scores into posteriors in place).
my @hyps;
my @probs;

# First command-line argument: the N-best file to decode.
my $fileName = shift @ARGV;

# Reference translations, presumably one per sentence in file order; only
# computeRBleu reads them.  Loading is currently disabled; a working loader
# would look like:
#   my $refFileName = shift @ARGV;
#   open(my $refFh, '<', $refFileName) or die("Couldn't open $refFileName\n");
#   while (my $ref = <$refFh>) { chomp $ref; $refNum++; push @refs, $ref; }
#   close $refFh;
my @refs;
my $refNum = 0;

# Number of N-best lists flushed so far.  It is incremented before each
# list is decoded, so it is 1-based while decoding; computeBleu and
# computeRBleu embed it in their temp-file names.
my $senNum = 0;

die("Usage: $0 nbest-file\n") unless defined $fileName;

# Temp segment files are reused by name (sentence number + hypothesis
# index) whenever they already exist.  Leftovers from an earlier run would
# therefore be scored as if they came from this input, so make sure the
# directory exists and holds none of them.
unless (-d $tmp) {
	mkdir $tmp or die("Couldn't create temp directory $tmp: $!\n");
}
unlink glob("$tmp/hyp*.txt $tmp/ref*.txt");

# Decode the list collected so far (printing one output line) and start
# a fresh one.
my $flushNbest = sub {
	$senNum++;
	decode();
	@hyps  = ();
	@probs = ();
};

# N-best file format: each hypothesis is one line of text followed by one
# line of whitespace-separated feature scores.  A blank line ends the list
# for the current sentence.
open(my $nbestFh, '<', $fileName) or die("Couldn't open the file $fileName\n");
while (my $hyp = <$nbestFh>) {
	chomp $hyp;
	if ($hyp eq '') {
		$flushNbest->();
		next;
	}
	my $line = <$nbestFh>;
	die("Missing score line for hypothesis '$hyp' in $fileName\n") unless defined $line;
	chomp $line;
	# split ' ' discards leading whitespace, so a leading space cannot
	# shift every score onto the wrong weight.
	my @scores = split ' ', $line;
	my $cost = 0;
	for my $k (0 .. $#wts) {
		$cost += $scores[$k] * $wts[$k];
	}
	push @hyps,  $hyp;
	push @probs, $cost;
}
close $nbestFh;

# Without a trailing blank line the final list would otherwise be dropped.
$flushNbest->() if @hyps;

# Expected sentence-level BLEU of the current N-best list against the
# sentence's reference translation:
#     E[BLEU] = sum_i P(i) * BLEU(hyp_i, ref)
# where P comes from computePost (which normalises @probs in place).
# Returns the expectation; 0 for an empty list.  A hypothesis for which no
# BLEU score could be read contributes 0 and a warning is printed.
sub computeExpectedBleu{
	computePost();
	my $eb = 0;
	for my $i (0 .. $#hyps) {
		# computeRBleu scores one hypothesis against the reference;
		# computeBleu expects two hypothesis indices.
		my $sim = computeRBleu($i);
		unless (defined $sim && $sim ne '') {
			warn "No BLEU score for sentence $senNum, hyp $i; using 0\n";
			$sim = 0;
		}
		$eb += $probs[$i] * $sim;
	}
	return $eb;
}

# Sentence-level BLEU of hypothesis $hyp1 (index into @hyps) of sentence
# $senNum against that sentence's reference in @refs.
# Runs the external $bleu script and returns the text after "BLEUr1n4," on
# its output line, or undef if no such line appears.  Dies if no reference
# is loaded for the sentence.
sub computeRBleu{
	my $hyp1 = shift;

	# $senNum is incremented before a list is decoded, so it is 1-based,
	# while @refs is assumed to be filled 0-based with one entry per
	# sentence.  NOTE(review): the loader is commented out at the top of
	# the file; confirm the indexing when it is re-enabled.
	my $ref = $senNum > 0 ? $refs[$senNum - 1] : undef;
	die("No reference loaded for sentence $senNum\n") unless defined $ref;

	my $hypFile = "$tmp/hyp$senNum-$hyp1.txt";
	unless (-e $hypFile) {
		open(my $hypFh, '>', $hypFile) or die("Couldn't write temp file $hypFile: $!\n");
		print $hypFh "<DOC docid=\"a\" sysid=\"t\">\n<seg id=1> ".$hyps[$hyp1]." </seg>\n</DOC>\n";
		close $hypFh or die("Couldn't close temp file $hypFile: $!\n");
	}

	my $refFile = "$tmp/ref$senNum.txt";
	unless (-e $refFile) {
		open(my $refFh, '>', $refFile) or die("Couldn't write temp file $refFile: $!\n");
		print $refFh "<DOC docid=\"a\" sysid=\"t\">\n<seg id=1> ".$ref." </seg>\n</DOC>\n";
		close $refFh or die("Couldn't close temp file $refFile: $!\n");
	}

	`$bleu -ci -t $hypFile -r $refFile > $tmp/out 2>/dev/null`;
	open(my $out, '<', "$tmp/out") or die("Couldn't open out file\n");
	my $score;
	while (my $row = <$out>) {
		chomp $row;
		if ($row =~ /^BLEUr1n4,(.+)$/) {
			$score = $1;
			last;
		}
	}
	# Closed on every path, including when no score line was found.
	close $out;
	return $score;
}


# Minimum Bayes Risk selection over the current N-best list.
#
# computePost turns @probs into posteriors.  Then, for each candidate $i
# among the first 201 hypotheses (indices 0..200), this computes
#     risk(i) = sum_j P(j) * (1 - BLEU(ref = hyp_i, test = hyp_j))
# and prints the lowest-risk candidate to STDOUT.  Per-candidate risks and
# the old (index 0) and new choice are logged to STDERR.
# An empty list prints an empty line, so the output stays one line per list.
sub decode{
	#return if($senNum == 1);
	unless (@hyps) {
		print "\n";
		return;
	}
	computePost();

	# Each candidate costs one external BLEU run per hypothesis, so only
	# the top of the list is considered.
	my $lastCandidate = $#hyps > 200 ? 200 : $#hyps;

	# Posteriors sum to 1 and BLEU is presumably within [0,1], so every
	# risk is <= 1 and this initial value acts as infinity.
	my $minRisk = 100000;
	my $bestHyp = 0;
	for my $i (0 .. $lastCandidate) {
		my $currRisk = 0;
		for my $j (0 .. $#hyps) {
			my $sim = computeBleu($i, $j);
			unless (defined $sim && $sim ne '') {
				warn "No BLEU score for sentence $senNum, hyps $i/$j; using 0\n";
				$sim = 0;
			}
			$currRisk += $probs[$j] * (1 - $sim);
			# Already worse than the best candidate so far: stop summing.
			last if $currRisk > $minRisk;
		}
		print STDERR "$i : $currRisk\n";
		if ($currRisk < $minRisk) {
			$minRisk = $currRisk;
			$bestHyp = $i;
		}
	}
	print STDERR "Prev Best: $hyps[0]\n";
	print STDERR "New Best $minRisk : $hyps[$bestHyp] \n";
	print "$hyps[$bestHyp]\n";
}

# Write $text as a single-segment SGML document to $path, which is the input
# format handed to $bleu.  An existing file is left untouched, so it acts
# as a cache for the current run.
sub _writeBleuSegment {
	my ($path, $text) = @_;
	return if -e $path;
	open(my $fh, '>', $path) or die("Couldn't write temp file $path: $!\n");
	print $fh "<DOC docid=\"a\" sysid=\"t\">\n<seg id=1> ".$text." </seg>\n</DOC>\n";
	close $fh or die("Couldn't close temp file $path: $!\n");
	return;
}

# Sentence-level BLEU of hypothesis $hyp2 (test) against hypothesis $hyp1,
# which serves as the single reference.  Both are indices into @hyps of
# sentence $senNum.
# Runs the external $bleu script and returns the text after "BLEUr1n4," on
# its output line, or undef if no such line appears.
sub computeBleu{
	my ($hyp1, $hyp2) = @_;
	my $refFile  = "$tmp/hyp$senNum-$hyp1.txt";
	my $testFile = "$tmp/hyp$senNum-$hyp2.txt";
	_writeBleuSegment($refFile,  $hyps[$hyp1]);
	_writeBleuSegment($testFile, $hyps[$hyp2]);

	`$bleu -ci -r $refFile -t $testFile > $tmp/out 2>/dev/null`;
	open(my $out, '<', "$tmp/out") or die("Couldn't open out file\n");
	my $score;
	while (my $row = <$out>) {
		chomp $row;
		if ($row =~ /^BLEUr1n4,(.+)$/) {
			$score = $1;
			last;
		}
	}
	# Closed on every path, including when no score line was found.
	close $out;
	return $score;
}

# Convert the log-linear scores in @probs into posterior probabilities, in
# place:
#     p_i = exp(s_i - max_j s_j) / sum_k exp(s_k - max_j s_j)
# Subtracting the maximum keeps exp() from overflowing.  Taking the maximum
# from the data rather than from a fixed floor (-1000000 before) keeps
# every term from underflowing to 0 when all scores are very negative,
# which would otherwise divide by zero.  Does nothing on an empty list.
sub computePost{
	return unless @probs;
	my $max = max(@probs);
	@probs = map { exp($_ - $max) } @probs;
	# At least 1, because the maximum element contributes exp(0).
	my $sum = sum(@probs);
	@probs = map { $_ / $sum } @probs;
	return;
}
