package banyan

import scala.actors.Actor
import scala.actors.AbstractActor
import scala.actors.Actor._
import scala.actors.TIMEOUT

import scala.actors.remote.RemoteActor
import scala.actors.remote.RemoteActor._
import scala.actors.remote._

import java.net.InetAddress


/** A mutable, serializable cell holding one value of type `T`.
 *
 *  @param init the value the cell starts with
 */
@serializable
class Ref[T](init: T) {
  private var contents: T = init

  /** Replaces the held value with `value`. */
  def set(value: T): Unit = {
    contents = value
  }

  /** The currently held value. */
  def get: T = contents
}

/** A serializable object used purely as a monitor for `synchronized` blocks. */
@serializable
class Lock()

/** Lifecycle state of a `TreeNode`. */
abstract class Status

/** Initial state: the node has neither returned nor been made irrelevant. */
case class Working() extends Status

/** Set when the node (or, in a parent's bookkeeping, a child) has returned. */
case class Returned() extends Status

/** Set by `TreeNode.becomeIrrelevant` on a node that was still `Working`. */
case class Irrelevant() extends Status


/** Where a tree node currently lives: in this process, or on another host. */
protected [banyan] abstract class NodeSomewhere

/** The node object `nd` is held locally. */
protected [banyan] case class NodeHere(nd: TreeNode) extends NodeSomewhere

/** The node `id` is held by the process at `loc`. */
protected [banyan] case class NodeThere(loc: Host, id: NodeID)
  extends NodeSomewhere



/** Network address (IP, port) of a banyan process, usable wherever a
 *  remote-actor `Node` is expected. Rendered as `ip_port`.
 */
protected [banyan] case class Host(ip: String, prt: Int) extends Node(ip, prt) {
  override def toString(): String = ip + "_" + prt
}


/** Snapshot of a host's current ticket total and its ticket target. */
private case class HostData(tickets: BanyanPublic.Tickets,
                            target: BanyanPublic.Tickets) {

  /** Tickets above the target; negative when the host is below target. */
  def surplus: BanyanPublic.Tickets = tickets - target
}

/** One work-log entry for a node: the process that worked on it, a count,
 *  and a time (units not fixed here — TODO confirm with callers).
 *  The constructor arguments only seed the mutable fields below.
 */
protected [banyan] case class LogData(wrkr: String, cnt: Int, tm: Long){
  var worker = wrkr
  var count = cnt
  var time = tm

  // Renders as "(worker,count,time)".
  override def toString(): String =
    List(worker, count, time).mkString("(", ",", ")")
}

/** Names one of a node's relatives: its parent, or one of its children. */
abstract class Relative

/** The node's parent. */
case class Parent() extends Relative

/** The child at index `id` of the node's `children` list. */
case class Child(id: Int) extends Relative



// A node is identified by the short name of the process that created it
// (`hst`) together with a counter value generated by that process (`id`).
// Rendered as "hst_id".
case class NodeID(hst: String, id: Int ) {
  override def toString(): String = hst.toString + "_" + id
}


// Process-wide state and routing helpers shared by the client and by
// TreeNode: where every known node lives, cached actor proxies for other
// hosts, and this host's ticket accounting.
private object BanyanPrivate {

  import BanyanPublic.Tickets

  val err = System.err

  // conversion factor (presumably tickets -> milliseconds, going by the
  // name — TODO confirm)
  val rsrc2millis = 1

  var waitForReply: Long = 500 // milliseconds

  val timeBetweenAssessments = 2000 // milliseconds

  // Address of this process and of the coordinator. These are placeholders
  // until overwritten (BanyanClient.setLocalPort / setCoordinator do so).
  var thisHost = Host("localhost", 0)
  var coorHost = Host("localhost", 0)

  // Prefix of every NodeID minted by this process (see newNodeID).
  var shortName = "client"


  /** Moves the locally held node `nodeID` to host `dest`.
   *
   *  The local map entry is switched to point at `dest`. `dest` is then told
   *  where the node's parent and children live, and relatives still held
   *  here are advertised as living on `thisHost`. The node's tickets are
   *  removed from this host's total, and the node object itself is sent
   *  with a 'passNode message.
   *
   *  @throws Error if `nodeID` is not held locally, or if a relative's ID is
   *                unknown
   */
  def passNode(dest: Host, nodeID: NodeID): Unit =
    nodeMap.get(nodeID) match {
      case Some(NodeHere(nd)) =>
        nodeMap.put(nodeID, NodeThere(dest, nodeID))
        val hst = getHostActor(dest)
        val nodeIDs = nd.parent :: nd.children
        val nodesToSend =
          nodeIDs.map(x => nodeMap.get(x) match {
            case Some(NodeHere(relative)) =>
              NodeThere(thisHost, relative.nodeID)
            case Some(t) => t
            case None => throw new Error("bad ID: " + x)
          })
        for (nts <- nodesToSend) {
          hst ! ('nodeUpdate, nts)
        }

        // The node's tickets leave this host together with the node.
        addTotalTickets(-nd.checkTickets())
        err.println("about to pass node to " + dest)
        hst ! ('passNode, nodeID, nd)
        err.println("success!")
      case _ =>
        throw new Error("cannot pass " + nodeID + " to " + dest)
    }


  /** Passes the whole locally held subtree rooted at `nodeID` to `dest`.
   *
   *  If `forward` is true, subtrees that already live on another host are
   *  handled by asking that host (via 'collectSubtree) to send them to `dest`.
   *
   *  @return the number of messages we passed to other workers (one per
   *          forwarded 'collectSubtree request)
   *  @throws Error if `nodeID` is unknown
   */
  def passSubtree(dest: Host, nodeID: NodeID, forward: Boolean): Int =
    nodeMap.get(nodeID) match {
      case Some(NodeHere(nd)) =>
        passNode(dest, nodeID)
        var numPasses = 0
        for (c <- nd.children) {
          numPasses += passSubtree(dest, c, forward)
        }
        numPasses
      case Some(NodeThere(hst, id)) =>
        if (forward) {
          val actr = getHostActor(hst)
          // Diagnostics go to stderr like the rest of this object.
          err.println("telling " + hst + " to collect subtree of " + id)
          actr ! ('collectSubtree, dest, id)
          1
        } else {
          0
        }
      case None =>
        throw new Error("node does not exist: " + nodeID)
    }


  // Next value handed out by newNodeID.
  var nodeNum: Int = 0

  // A HashMap whose individual operations are synchronized.
  class SyncMap[A,B]
    extends scala.collection.mutable.HashMap[A, B]
    with scala.collection.mutable.SynchronizedMap[A,B]

  // Where every node known to this process lives.
  val nodeMap = new SyncMap[NodeID, NodeSomewhere]()

  val nodeLock = new Lock()

  // Cache of actor proxies for the 'worker actor on other hosts.
  val hostTable =
    new SyncMap[Host, AbstractActor]()

  /** Returns the cached proxy for the 'worker actor on `hst`, creating it
   *  with `select` the first time the host is seen.
   */
  // NOTE(review): the get-then-put below is not atomic, so concurrent first
  // calls for the same host may each call select(); the last put wins.
  def getHostActor(hst: Host): AbstractActor
  = hostTable.get(hst) match {
    case Some(actr) => actr
    case None => // open a connection
      val newActor = select(hst, 'worker)
      hostTable.put(hst, newActor)
      newActor
  }


  // the sum of the tickets held by the nodes here (guarded by
  // totalTicketsLock; use addTotalTickets / getTotalTickets)
  var totalTickets: Tickets = 0

  // how many tickets we want here
  var ticketsTarget: Tickets = 0

  val totalTicketsLock = new Lock()

  /** Adds `t` (which may be negative) to this host's ticket total. */
  def addTotalTickets(t: Tickets): Unit = {
    totalTicketsLock.synchronized {
      totalTickets += t
    }
  }

  /** This host's current ticket total. */
  def getTotalTickets(): Tickets = {
    totalTicketsLock.synchronized {
      totalTickets
    }
  }


  /** Mints a fresh NodeID: `shortName` plus the next counter value.
   *
   *  NOTE(review): the counter update is not synchronized; concurrent callers
   *  could receive duplicate IDs — confirm callers are serialized.
   *  NOTE(review): the first ID minted, NodeID(shortName, 0), equals
   *  `unregisteredNode` if `shortName` is still "client".
   */
  def newNodeID(): NodeID = {
    val r = nodeNum
    nodeNum = nodeNum + 1
    NodeID(shortName, r)
  }

  // Placeholder ID that TreeNode uses for nodeID/parent before they are set.
  val unregisteredNode: NodeID = NodeID(shortName, 0)


  /** Snapshot of this host's ticket total and ticket target. */
  def hostData: HostData = HostData(getTotalTickets(), ticketsTarget)

  // IDs of nodes enqueued by TreeNode.newChild / receiveTickets; presumably
  // drained by the worker loop elsewhere — TODO confirm.
  val tasks = new scala.collection.mutable.SynchronizedQueue[NodeID]()


  /** Delivers `msg` to node `nodeID`: directly if it is held here, otherwise
   *  as a 'message to the worker on the node's host.
   *
   *  @throws Error if `nodeID` is unknown
   */
  def sendMessage(nodeID: NodeID, msg: Any): Unit =
    nodeMap.get(nodeID) match {
      case Some(NodeHere(nd)) =>
        nd.handleMessage(msg)
      case Some(NodeThere(hst, id)) =>
        val actr = getHostActor(hst)
        actr ! ('message, nodeID, msg)
      case None =>
        throw new Error("could not find node " + nodeID)
    }

  /** Gives `rsrc` tickets to node `nodeID2`. Tickets sent to a remote node
   *  are deducted from this host's total; local transfers leave it unchanged.
   *
   *  @throws Error if `nodeID2` is unknown
   */
  def sendTickets(nodeID2: NodeID, rsrc: Tickets)
     : Unit =
    nodeMap.get(nodeID2) match {
      case Some(NodeHere(nd)) =>
        nd.receiveTickets(rsrc)
      case Some(NodeThere(hst, id)) =>
        addTotalTickets(-rsrc)
        val actr = getHostActor(hst)
        actr ! ('tickets, nodeID2, rsrc)
      case None =>
        throw new Error("could not find node " + nodeID2)
    }


  /** Tells node `nodeID`, wherever it lives, to become irrelevant.
   *
   *  @throws Error if `nodeID` is unknown
   */
  def makeIrrelevant(nodeID: NodeID): Unit =
    nodeMap.get(nodeID) match {
      case Some(NodeHere(nd)) =>
        nd.becomeIrrelevant()
      case Some(NodeThere(hst, id)) =>
        val actr = getHostActor(hst)
        actr ! ('makeIrrelevant, nodeID)
      case None =>
        throw new Error("could not find node " + nodeID)
    }

  /** Reports that `child` returned `v` to its `parent`: directly if the parent
   *  is held here, otherwise as a 'childReturned message to the parent's host.
   *
   *  @throws Error if `parent` is unknown
   */
  def returnFromChild(child: NodeID, parent: NodeID, v: Any): Unit =
    nodeMap.get(parent) match {
      case Some(NodeHere(nd)) =>
        nd.childReturnedAux(child, v)
      case Some(NodeThere(hst, id)) =>
        val actr = getHostActor(hst)
        // Build a single string. Passing several arguments to println would
        // auto-tuple them, and v.toString would throw on a null v.
        err.println("about to send return msg: "
                    + child + ", " + id + ", " + v)
        actr ! ('childReturned, child, id, v)
        err.println("message sent")
      case None =>
        throw new Error("could not find node " + parent)
    }

}

// Small public surface of the banyan runtime: the ticket type, this
// process's short name, and a textual dump of a node's tree.
object BanyanPublic {

  import Printing.TreeData

  // Unit in which work budgets are counted.
  type Tickets = Long

  // The short name this process uses as the prefix of the NodeIDs it mints.
  def getShortName : String = BanyanPrivate.shortName

  // Renders the tree rooted at `nd` using Printing.TreeData.
  def data_TreeNode(nd: TreeNode): String = {
    val treeData = new TreeData()
    treeData.addTreeNode(nd, true)
    treeData.toString()
  }

}

// Client side of banyan: configures addresses, registers with the
// coordinator, hands the root node to a worker, and waits for the result.
object BanyanClient {

  import BanyanPrivate._
  import BanyanPublic._

  // Actor that sends 'tick to `actr` every time `tm` milliseconds pass
  // without a handled message. Sending it 'stop terminates it, so the
  // client can shut it down instead of leaving it running after it quits.
  class Ticker(tm: Long, actr: Actor) extends Actor {
    def act() = {
      loop {
        reactWithin(tm) {
          case TIMEOUT =>
            actr ! 'tick
          case 'stop =>
            exit()
        }
      }
    }
  }


  // Records the coordinator's address in BanyanPrivate.coorHost.
  def setCoordinator(addr: String, prt: Int): Unit = {
    coorHost = new Host(addr, prt)
    err.println("Coordinator: "+coorHost.toString)
  }

  // Records this machine's address, listening on `prt`, in
  // BanyanPrivate.thisHost.
  def setLocalPort(prt: Int): Unit = {
    thisHost = new Host(InetAddress.getLocalHost().getHostAddress(), prt)
    err.println("this host: "+thisHost.toString)
  }

  // Mints an ID for the root's (virtual) parent and maps it to a stub on
  // thisHost. The root's eventual 'childReturned is therefore routed to this
  // host's 'worker actor, which startRoot registers.
  def getRootParent() : NodeID = {
    val nnm = newNodeID()
    val stub = NodeThere(thisHost, nnm)
    nodeMap.put(nnm, stub)

    nnm
  }

  /** Starts a computation rooted at `rootNode`.
   *
   *  Registers the root locally, then spawns an actor that registers with the
   *  coordinator and passes the root to the host named in the reply. The actor
   *  then loops: on 'childReturned it asks for the subtree back and arms a
   *  3-tick quit countdown; on 'quit it prints the tree (if the root is here)
   *  and exits.
   *
   *  @throws Error (inside the actor) if the coordinator does not answer
   *                within 5 seconds or answers with an unexpected message
   */
  def startRoot(rootNode: TreeNode) : Unit = {

    rootNode.parent = getRootParent()
    rootNode.nodeID = newNodeID()

    nodeMap.put(rootNode.nodeID, NodeHere(rootNode))

    actor {
      alive(thisHost.port)
      //register as worker to trick banyan into notifying me
      register('worker, self)

      // Register with the coordinator; wait at most 5 seconds for the reply.
      val coor = select(coorHost, 'coordinator)
      val rep = coor !? (5000, ('registerClient))
      rep match {
        case None =>
          throw new Error("could not connect to coordinator")
        case Some(('registered, tks: Tickets, hst: Host)) =>
          rootNode.receiveTickets(tks)
          err.println("rootNode has tickets: "+ rootNode.checkTickets())
          err.println("passing  the root node")
          passNode(hst, rootNode.nodeID)
        case Some(other) =>
          // Fail with a clear message rather than an opaque MatchError.
          throw new Error("unexpected reply from coordinator: " + other)
      }

      val rootNodeID = rootNode.nodeID
      var idleTicks = 0
      var exitOnTicks: Option[Int] = None
      val ticker = new Ticker(1000, self)
      ticker.start

      val startTime: Long = System.currentTimeMillis()
      var totalTime: Long = 0

      while(true) {
        receive {
          case 'quit =>
            err.println("quitting")
            // Stop the ticker so it does not outlive this actor.
            ticker ! 'stop
            nodeMap.get(rootNodeID) match {
              case Some(NodeHere(nd)) =>
                println(data_TreeNode(nd))
                println("\n totalTime = " + totalTime)
                exit
              case Some(NodeThere(hst,id)) =>
                err.println("rootNode is not here")
              case None =>
                err.println("rootNode no longer exists")
            }
            exit

          case 'tick =>
            idleTicks += 1
            exitOnTicks match {
              case None =>
              case Some(n) =>
                err.println("will quit after " + (n - idleTicks) + " more ticks")
                if (idleTicks >= n) {
                  coor ! 'quit
                  self ! 'quit
                }
            }

          case ('childReturned, child: NodeID, id: NodeID, v: Any) =>
            err.println("ROOT ACTOR. DONE.")
            err.println("child value = " + v)
            err.println()
            totalTime = System.currentTimeMillis() - startTime

            err.println("sending out collectSubtree message")
            nodeMap.get(child) match {
              case Some(NodeThere(hst, nodeID)) =>
                val actr= getHostActor(hst)
                actr ! ('collectSubtree, thisHost, child)
                // wait long enough for everyone to get back

              case Some(NodeHere(nd)) =>
                println("the client should not have this node: " + nd)
                println("meh. no biggie. ")
              case None =>
                throw new Error("don't know where to find " + child)
            }
            // Quit after 3 consecutive ticks with no other message.
            exitOnTicks = Some(3)
            idleTicks = 0

          case ('passNode, nodeID: NodeID, nd: TreeNode) =>
            err.println("got node " + nodeID)
            nodeMap.put(nodeID, NodeHere(nd))
            idleTicks = 0

          case msg =>
            idleTicks = 0
            err.println("ROOT ACTOR. GOT MESSAGE: " + msg)
        }
      }

    }

  }


}





/** A node of a distributed search/computation tree.
 *
 *  Subclasses implement the five abstract hooks; this class keeps the node's
 *  ID, parent/children links, status, and ticket balance, and routes
 *  tickets, messages and results through BanyanPrivate.
 */
@serializable
abstract class TreeNode() {


  import BanyanPublic._
  import BanyanPrivate._


  /* The implementor defines these five methods */

  // Perform up to `timeslice` worth of work (units not fixed here — TODO
  // confirm with the scheduler).
  def workHere(timeslice: Long): Unit
  def timeout(): Unit
  // Called by becomeIrrelevant, before the node is marked Irrelevant.
  def abort(): Unit
  // Called by childReturnedAux (under statusLock) with the child's index.
  def childReturned(child: Int, childValue: Any): Unit
  // Receives messages delivered through BanyanPrivate.sendMessage.
  def handleMessage(msg: Any): Unit


  var nodeID: NodeID = unregisteredNode
  private [banyan] var parent: NodeID = unregisteredNode


  private val ticketLock = new Lock()

  val statusLock = new Lock()

  // Newest child first (newChild prepends); childrenStatus is kept parallel.
  private [banyan] var children: List[NodeID] = Nil
  protected [banyan] var childrenStatus: List[Ref[Status]] = Nil

  // tickets currently held by this node (guarded by ticketLock)
  private var tickets: Tickets = 0


  private [banyan] var status: Status = Working()


  private [banyan] var returnValue: Option[Any] = None

  protected [banyan] var timeSlicesToUse = 1
  private [banyan] var timeSlicesSaved = 0
  private [banyan] var ticketsSaved: Tickets = 0

  private [banyan] var timeSpentHere: Long = 0 // milliseconds
  private [banyan] var workLog: List[LogData] = List(LogData("", 0, 0))

  def mainBranch() : Boolean = true

  // (destination, amount) pairs recorded by transferTickets.
  val sentTickets = new scala.collection.mutable.HashSet[(NodeID,Tickets)]

  private [banyan] var inTaskQueue: Boolean = false

  // Appends a log entry crediting this process with one slice of length `tm`.
  protected [banyan] def updateWorkLog(tm: Long) : Unit = {
    val here = BanyanPublic.getShortName
    workLog = workLog ++ List(LogData(here, 1, tm))
  }

  // Index of `childID` in `children`; prints the children list and throws
  // Error if it is not there.
  // N.B. changed private to protected so that MemoTreeNode could see it
  protected def childIDtoInt (childID: NodeID): Int = {
    // indexOf replaces a nonlocal `return` from inside a for-closure.
    val i = children.indexOf(childID)
    if (i < 0) {
      println(children)
      throw new Error("child " + childID + " not found")
    }
    i
  }

  // Marks child `childID` as Returned and passes its value to childReturned,
  // all under statusLock.
  private [banyan] final def childReturnedAux
      (childID: NodeID, childValue: Any): Unit = {
    statusLock.synchronized {
      val i = childIDtoInt(childID)
      childrenStatus(i).set(Returned())
      childReturned(i, childValue)
    }
  }


  // Asks the child at index `child`, wherever it lives, to become irrelevant.
  final def makeChildIrrelevant(child: Int): Unit = {
    makeIrrelevant(children(child))
  }

  // Makes every child whose recorded status is still Working irrelevant.
  final def makeOpenChildrenIrrelevant(): Unit = {
    for (c <- children.indices) childrenStatus(c).get match {
      case Working() =>
        makeChildIrrelevant(c)
      case _  => ()
    }

  }

  // Calls abort(), marks this node Irrelevant if it is still Working, and
  // propagates irrelevance to all children.
  private [banyan] def becomeIrrelevant(): Unit = {
    abort()
    statusLock.synchronized {
      if (status == Working()) {
        status = Irrelevant()
      }
      for (childID <- children) {
        makeIrrelevant(childID)
      }
    }
  }


  // Finishes this node with result `v`. Under statusLock it marks the node
  // Returned, makes open children irrelevant, hands this node's tickets to
  // the parent, and reports `v` to the parent.
  // final
  def returnNode(v: Any): Unit = {
    statusLock.synchronized {
      this.status = Returned()
      this.returnValue = Some(v)
      makeOpenChildrenIrrelevant()
      transferTickets(Parent(), tickets)
      returnFromChild(this.nodeID, parent, v)
    }
  }


  // Returns the returnValue field itself: None until returnNode, then Some(v).
  final def checkReturnValue(): Any = {
    returnValue
  }

  final def checkTickets(): Tickets = {
    ticketLock.synchronized {
      tickets
    }
  }


  // Sends up to `rsrc` tickets to the parent or to child `n`, recording the
  // transfer in sentTickets.
  // can only transfer as many tickets as are here
  final def transferTickets(rel: Relative, rsrc: Tickets): Unit = {
    ticketLock.synchronized {
      val ttt = if (rsrc > tickets) tickets else rsrc
      rel match {
        case Parent() =>
          tickets -= ttt
          sentTickets += ((parent, ttt))
          sendTickets(parent, ttt)
        case Child(n) =>
          // Validate n before touching `tickets`. A negative index used to
          // pass this check, decrement `tickets`, and then throw from
          // children(n), so the tickets were lost.
          if (n >= 0 && n < children.length) {
            tickets -= ttt
            sentTickets += ((children(n), ttt))
            sendTickets(children(n), ttt)
          } else {
            println("bad ticket transfer")
            throw new
              Error("tried to transfer tickets to nonexistent child + " + n)
          }
      }
    }
  }

  // Adds `rsrc` tickets to this node. If the node is not yet flagged as
  // queued, it also enqueues its ID in `tasks`.
  private [banyan] def receiveTickets(rsrc: Tickets): Unit = {
    if (!inTaskQueue) {
      inTaskQueue = true
      tasks += nodeID
    }
    ticketLock.synchronized {
      tickets += rsrc
    }
  }

  final def checkStatus(): Status = {
    status
  }

  // Sends `msg` to the parent or to the child at index n.
  def sendMessageTo(rel: Relative, msg: Any): Unit = rel match {
    case Parent() =>
      sendMessage(parent, msg)
    case Child(n) =>
      sendMessage(children(n), msg)
  }



  // Adopts `child`: gives it a fresh ID and links it to this node. The child
  // is prepended to `children` with status Working (so index 0 is the newest
  // child), registered as local in nodeMap, and enqueued in `tasks`.
  def newChild(child: TreeNode): Unit = {
    val childID = newNodeID()
    child.nodeID = childID
    child.parent = this.nodeID

    children = childID :: children
    childrenStatus = (new Ref[Status](Working())) :: childrenStatus

    nodeMap.put(childID, NodeHere(child))
    tasks += childID
  }

  final def numChildren() : Int = {
    children.length
  }


  // for visualization. meant to be overridden.
  def colorMain() : String = status match {
    case Working() =>
      "white"
    case Returned() =>
      "blue"
    case Irrelevant() =>
      "gray"
  }

  // the color of non-main-branch nodes
  def color() : String = {
    "gray"
  }



  def shape() : String = "circle"


}
