package banyan

import scala.actors.Actor
import scala.actors.AbstractActor
import scala.actors.Actor._
import scala.actors.TIMEOUT

import scala.actors.remote.RemoteActor
import scala.actors.remote.RemoteActor._
import scala.actors.remote._

import org.apache.commons.cli.Options
import org.apache.commons.cli.CommandLine
import java.net.InetAddress



/** Command-line entry point for a worker process.
  *
  * `main` parses the options, builds the coordinator's and this worker's
  * `Host`, stores the latter in `BanyanPrivate.thisHost`, and starts a
  * `ListenerActor`.
  */
private object Worker {

  import BanyanPublic._
  import BanyanPrivate._

  /** Parses the worker's command-line arguments with commons-cli.
    *
    * Declared options (each takes a value): `-c`, `-cp`, `-p` and `-r`.
    *
    * @param args raw arguments as handed to `main`
    * @return the parsed command line
    */
  def parse(args: Array[String]) : CommandLine = {

    // Fully qualified on purpose: the wildcard imports above must not be
    // able to make the short name ambiguous.
    val opts = new org.apache.commons.cli.Options()
    opts.addOption("c", true, "coordinator address")
    opts.addOption("cp", true, "coordinator port (default = 9500)")
    opts.addOption("p", true, "port this worker should run on")
    opts.addOption("r", true, "pathname to root problem to solve."+
                   " without this arg, will be waiter.")

    val parser = new org.apache.commons.cli.GnuParser()
    parser.parse(opts, args)
  }

  /** Reads option `opt` as an integer port, using `default` when absent.
    * A non-numeric value surfaces as a `NumberFormatException`.
    */
  private def portOption(cmd: CommandLine, opt: String, default: String): Int =
    java.lang.Integer.parseInt(cmd.getOptionValue(opt, default))

  /** Starts the worker.
    *
    * Exits with status 1 when the mandatory `-c` (coordinator address)
    * option is missing.
    *
    * @param args command-line arguments, see `parse`
    */
  def main(args: Array[String]) : Unit = {
    println("worker says: hello world.")

    val cmd = parse(args)

    // The coordinator address is mandatory; a missing one is a usage error,
    // so report failure to the calling shell with a non-zero status.
    if (!cmd.hasOption("c")) {
      println("Forgot to supply coordinator address. (use -c). Quitting.")
      System.exit(1)
    }

    val coorHost = new Host(cmd.getOptionValue("c"), portOption(cmd, "cp", "9500"))

    println("coordinator at " + coorHost.toString)

    //TBD: put in cheesy option bind workaround ... default to 9001 for now
    val thisAddr = InetAddress.getLocalHost().getHostAddress()
    val thisHost = Host(thisAddr, portOption(cmd, "p", "9001"))

    // Publish this worker's identity for the rest of the process.
    BanyanPrivate.thisHost = thisHost

    val listenerActor = new ListenerActor(coorHost, thisHost)
    listenerActor.start

  }

}


/** Watchdog actor for one slice of work on a node.
  *
  * Once started it waits up to `msec` milliseconds for a `'done` message.
  * If `'done` arrives first it simply exits; otherwise it prints
  * "TIMEOUT!", calls `nd.timeout()` and exits.
  *
  * @param msec how long to wait for `'done`, in milliseconds
  * @param nd   node whose `timeout()` is invoked when the wait expires
  */
class TimeoutGen(msec: Long, nd: TreeNode) extends Actor {
  def act(): Unit = {
    // Event-based wait; the original author left open whether the
    // thread-based receiveWithin would be preferable here.
    reactWithin(msec) {
      case TIMEOUT =>
        println("TIMEOUT!")
        System.out.flush()
        nd.timeout()
        exit()
      case 'done =>
        exit()
    }
  }
}


/** Actor that registers this worker with the coordinator and then
  * dispatches every incoming message to the matching helper.
  *
  * @param coorHost host of the coordinator
  * @param thisHost host this worker listens on
  */
private class ListenerActor(coorHost: Host, thisHost: Host) extends Actor {
  import BanyanPublic._
  import BanyanPrivate._

  // Handle on the coordinator, obtained via `select` under the name 'coordinator.
  val coor = select(coorHost, 'coordinator)

  // Scheduler doing the actual node work; started after registration.
  val proverActor = new SchedulerActor(coor)


  /** Announces this worker to the coordinator and blocks for the reply.
    *
    * Calls `alive` on `thisHost.port`, registers the current actor as
    * `'worker`, sends `('registerWorker, thisHost)` to the coordinator and
    * then waits for `('registered, name, target)`, storing the two values in
    * `shortName` and `ticketsTarget`. Any other message received while
    * waiting raises an `Error`.
    *
    * @param coorHost coordinator host (not read in this method's body)
    * @param thisHost host whose port is opened and which is announced
    */
  def doRegistration(coorHost: Host, thisHost: Host) = {

    alive(thisHost.port)
    register('worker, self)

    println("this host: " + thisHost.toString)
    coor ! ('registerWorker, thisHost)

    println("waiting for ack from coordinator")
    receive {
      case ('registered, assignedName: String, initialTarget: Tickets) =>
        ticketsTarget = initialTarget
        shortName = assignedName
      case other =>
        throw new Error("not an ack: " + other)
    }

  }


  /** Registers, starts the scheduler, then handles messages forever. */
  def act() : Unit = {
    doRegistration(coorHost, thisHost)
    println("registration done.")
    proverActor.start()

    loop {
      react(reactFunction)
    }
  }


  // NOTE(review): original author's open question - "do I ever need to
  // synchronize here?". Only the 'tickets and 'passNode cases take nodeLock;
  // the `tasks +=` and the 'nodeUpdate `nodeMap.put` run outside it.
  // Confirm against the declarations of `tasks` and `nodeMap`.

  /** Message dispatch table used by `act`. Unrecognised messages are only
    * printed.
    */
  val reactFunction : PartialFunction[Any,Unit] = {
    case 'quit =>
      quit()

    case ('message, target: NodeID, payload: Any) =>
      sendMessage(target, payload)

    case ('makeIrrelevant, target: NodeID) =>
      makeIrrelevant(target)

    case ('childReturned, childId: NodeID, parentId: NodeID, result: Any) =>
      returnFromChild(childId, parentId, result)

    case ('tickets, target: NodeID, granted: Tickets) =>
      addTotalTickets(granted)
      nodeLock.synchronized {
        sendTickets(target, granted)
      }
      // wake the scheduler in case it is idle
      proverActor ! 'gotTickets

    case ('passNode, arrivedId: NodeID, node: TreeNode) =>
      println("got node " + arrivedId)
      nodeLock.synchronized {
        nodeMap.put(arrivedId, NodeHere(node))
      }
      addTotalTickets(node.checkTickets())
      tasks += arrivedId
      proverActor ! 'gotNode

    case ('collectSubtree, dest: Host, rootId: NodeID) =>
      println("collectSubtree of " +  rootId)
      // the value returned by passSubtree is not used
      passSubtree(dest, rootId, true)

    case ('nodeUpdate, moved @ NodeThere(owner, movedId)) =>
      println("got nodeupdate")
      if (owner == thisHost) {
        println("I own that node")
      } else {
        nodeMap.put(movedId, moved)
      }

    case ('setTicketTarget, newTarget: Tickets) =>
      println("setting ticket target to " + newTarget)
      ticketsTarget = newTarget

    case other =>
      println("got message: " + other)
  }


  /** Terminates this actor. */
  def quit() : Unit = exit()


}


/** Actor running the worker's main scheduling loop.
  *
  * It repeatedly takes a node id from `tasks`, lets the node work for a time
  * derived from its saved-up tickets, hands leftover time to nodes that
  * received tickets during that work, and periodically calls
  * `assessTickets` to report to the coordinator or offer it a subtree.
  *
  * @param coor the coordinator to report to
  */
private class SchedulerActor(coor: AbstractActor) extends Actor {

  import BanyanPublic._
  import BanyanPrivate._

  import scala.collection.mutable.HashMap

//  self.trapExit = true


  // A local subtree is only considered for hand-off when its accumulated
  // work exceeds this many milliseconds.
  var workThreshold: Long  = 50// milliseconds


  /** (tickets held by a local subtree, time spent on that subtree). */
  type TicketStatus = (Tickets, Long)

  // Total time spent inside assessTickets and number of assessments made.
  var assessTime: Long = 0
  var assessCount: Long = 0
//  var taskTime: Long = 0
  // Number of work slices started by workOnNode.
  var taskCount: Long = 0


  /** Computes, with memoisation in `tt`, the tickets and time of the local
    * subtree rooted at `ndID`.
    *
    * A node held on another host contributes `(0,0)` and is not memoised.
    *
    * @param ndID root of the subtree to assess
    * @param tt   memo table, updated in place for nodes held here
    * @return tickets and time summed over the locally held part of the subtree
    * @throws Error when `ndID` is not in `nodeMap`
    */
  def assessTicketsAux(ndID: NodeID,
                       tt: HashMap[NodeID,TicketStatus])
  : TicketStatus = tt.get(ndID) match {
    case Some(ts) =>
      ts
    case None =>
      nodeMap.get(ndID) match {
        case Some(NodeHere(nd)) =>
          var subtreeTickets: Tickets = nd.checkTickets
          var subtreeWork: Long = nd.timeSpentHere
          for( c <- nd.children ) {
            val (t,w) = assessTicketsAux(c, tt)
            subtreeTickets += t
            subtreeWork += w
          }
          val res = (subtreeTickets,subtreeWork)
          tt.put(ndID, res)
          res
        case Some(NodeThere(_, _)) =>
          (0,0)
        case None =>
          throw new Error("node not found: " + ndID)
      }
  }


  /** Compares this worker's tickets with `ticketsTarget` and tells the
    * coordinator.
    *
    * With no surplus an `'update` is sent. With a surplus, the locally held
    * subtree whose ticket count is closest to the surplus (among those with
    * more than `workThreshold` ms of work and fewer tickets than the total)
    * is offered through a blocking `'haveSurplus` request; on
    * `('assignment, host)` the subtree is passed to that host and an
    * `'update` follows. If no subtree qualifies, an `'update` is sent.
    *
    * The visible caller holds `nodeLock` while this runs.
    *
    * @throws Error on an unexpected reply from the coordinator
    */
  def assessTickets(): Unit = {

    // for each node, how many tickets its local subtree has
    // and the amount of time spent on its local subtree.
    val ticketTable =
      new HashMap[NodeID, TicketStatus]

    val ticketSurplus =  getTotalTickets() - ticketsTarget

    if(ticketSurplus <= 0 ){
      // actually have a deficit. Just let the coordinator know.
      coor ! ('update, thisHost, hostData )
    } else {

      var bestSubtree : Option[NodeID] = None
      var bestTickets : Tickets = 0

      for( x <- nodeMap.values) x match {
        case NodeHere(nd) =>
          val (t,w) = assessTicketsAux(nd.nodeID, ticketTable)
          if(w > workThreshold){
            // strictly closer to the surplus than the best so far, and
            // not the whole of this worker's tickets
            if(Math.abs(bestTickets - ticketSurplus) >
                 Math.abs(t - ticketSurplus)
               && t < getTotalTickets()    )
              {
                bestTickets = t
                bestSubtree = Some(nd.nodeID)
              }
          }
        case NodeThere(_, _) =>
      }

      bestSubtree match {
        case Some(ndID) =>
          // do I really want to block here?
          coor.!?(waitForReply,
                  ('haveSurplus,
                   thisHost,
                   hostData,
                   bestTickets))
            match {
              case Some(('assignment, hst: Host)) =>
                passSubtree(hst, ndID, false)
                coor ! ('update, thisHost, hostData)
              case Some('noAssignment) =>
              case Some(msg) =>
                throw new Error("bad reply from coordinator: " + msg)
              case None =>
                // timed out
            }
        case None =>
          coor ! ('update, thisHost, hostData)

      }

    }
  }


  /** Lets `nd` work for at most `tmt` milliseconds under a `TimeoutGen`
    * watchdog and records the time it took.
    *
    * @param nd           node to work on
    * @param tmt          time budget in milliseconds
    * @param onlyDoSimple when true, nodes with `timeSlicesToUse > 1` are
    *                     skipped
    * @return `tmt` minus the elapsed time (may be negative), or `0L` when
    *         the node was skipped
    */
  def workOnNode(nd: TreeNode, tmt: Long, onlyDoSimple: Boolean): Long = {
    if(nd.timeSlicesToUse > 1 && onlyDoSimple) {
      0L // don't interfere with timeslice doubling
    } else {
      val watch = new TimeoutGen(tmt, nd)
      watch.start()
      val startTime = System.currentTimeMillis()
      taskCount += 1
      try {
        nd.workHere(tmt)
        val timeSpent = System.currentTimeMillis() - startTime
        nd.timeSpentHere += timeSpent
        nd.updateWorkLog(timeSpent)
      } finally {
        // Always release the watchdog, even if the work threw; otherwise
        // it would later call nd.timeout() for a slice nobody is running.
        // Sent unconditionally: the original relied on this being harmless
        // when the watchdog has already fired.
        watch ! 'done
      }
      tmt - (System.currentTimeMillis() - startTime)
    }
  }

  /** Splits `timeLeft` among the locally held nodes in `gotTix`, in
    * proportion to the tickets each one received.
    *
    * Ids that are not held here are dropped. Shares use integer division.
    *
    * @param gotTix   (node id, tickets received) pairs
    * @param timeLeft time to distribute, in milliseconds
    * @return (node, share) pairs; `Nil` when the local ticket total is not
    *         positive
    */
  def allotTime(gotTix: List[(NodeID,Long)],
                timeLeft: Long): List[(TreeNode,Long)] = {
    val auxfn : ((NodeID,Tickets)) => List[(TreeNode,Tickets )]
      = x =>
        nodeMap.get(x._1) match {
          case Some(NodeHere(nd)) => List((nd,x._2))
          case _ => Nil
        }
    val nds = gotTix.flatMap(auxfn)
    val totalNewTickets = nds.foldLeft(0L)((x,y) => x +y._2)
    if(totalNewTickets <= 0 ) {
      Nil
    } else {
      nds.map( x =>
        (x._1, x._2 * timeLeft  / totalNewTickets))
    }
  }

  /** Gives time left over from `parent`'s slice to the nodes it sent tickets
    * to, and transitively to the nodes those send tickets to, for as long as
    * time remains.
    */
  private def spendLeftover(parent: TreeNode, leftover: Long): Unit = {
    if( leftover > 0 && ! parent.sentTickets.isEmpty ){
      val pending =
        new scala.collection.mutable.Queue[(TreeNode, Long)]()

      pending ++= allotTime(parent.sentTickets.toList, leftover)

      while(! pending.isEmpty){
        val (child, childTime) = pending.dequeue
        child.sentTickets.clear
        val childLeft =
          if(childTime > 0){ workOnNode(child, childTime, true) } else {0L}
        if (childLeft > 0 && ! child.sentTickets.isEmpty){
          pending ++= allotTime(child.sentTickets.toList, childLeft)
        }
      }
    }
  }


  /** The scheduling loop; runs until a `'quit` is received while idle. */
  def act(): Unit = {

    var timeOfLastAssessment: Long = System.currentTimeMillis()

    // the main scheduling loop
    while( true ) {

      if(tasks.size == 0){
//        if(totalTickets > 0) {
//          throw new Error("should not be any tickets here")
//        }
        println("assess time: " +
                assessTime.toString + ", " + assessCount.toString)
//        println("task time: " + taskTime.toString + ", " + taskCount.toString)
        coor ! ('update, thisHost, hostData)
      }

      // Idle: block until the listener reports a new node or new tickets.
      while( tasks.size == 0  ) {
        println("no tasks in queue. waiting.")
        receive {
          case 'gotNode =>
            println("ok, got a node.")
            coor ! ('update, thisHost, hostData)
          case 'gotTickets =>
            println("ok, got tickets.")
            coor ! ('update, thisHost, hostData)
          case 'quit =>
            exit
        }
      }

      val curTime = System.currentTimeMillis()
      if(curTime - timeOfLastAssessment > timeBetweenAssessments){
        println("assessing tickets.")

        nodeLock.synchronized{
          assessTickets()
        }

        assessTime += System.currentTimeMillis() - curTime
        assessCount += 1

        println("ok done.")
        timeOfLastAssessment = System.currentTimeMillis()
      }

      println("in main loop")
      println("size = " + tasks.size)

      val tsk = tasks.dequeue

      nodeMap.get(tsk) match {
        case Some(NodeHere(nd)) =>
          // Bank this turn's slice and tickets; the node only runs once it
          // has saved timeSlicesToUse slices.
          nd.timeSlicesSaved += 1
          nd.ticketsSaved += nd.checkTickets()
          if(nd.timeSlicesSaved >= nd.timeSlicesToUse){
            val tmt = nd.ticketsSaved  * rsrc2millis
            if(tmt > 0) {

              nd.sentTickets.clear
              val timeLeft = workOnNode(nd,tmt, false)

              tasks += tsk // new nodes also get added on in the background

              // give the new children time if there's any left over
              spendLeftover(nd, timeLeft)

            } else { // no tickets
              nd.inTaskQueue = false
              println("node has no tickets")
            }
            nd.timeSlicesSaved = 0
            nd.ticketsSaved = 0

          } else { // have not saved up enough timeslices
            tasks += tsk
          }
        case Some(NodeThere(_, _)) =>
          // node now lives on another host; drop it from the queue
        case None => throw new Error("could not find node" + tsk)
      }

    }


  }

}
