(* ColorNest *)
(* --- Perry --- *)

(* Color trick from handout and more! *)

structure ColorNest :> COLORNEST =
  struct
      (* Optimization pass over Il.  Within a Seq, find two elements that set the
         same top-level color and have only colored elements between them.  Hoist
         that color into one enclosing Tag around the whole run, dropping it from
         the two end elements.  The middle elements all carry their own color, so
         (assuming an element's own color overrides an enclosing one) the hoisted
         color cannot change how they look.

         Limitations visible here: only the first match found in a Seq is
         rewritten, and a top-level Text or Tag is returned unchanged. *)
      open Il
      structure M = Meaning

      exception Error of string

      (* The color an element sets at its top level, if any. *)
      datatype color = Color of int | NoColor

      (* Debugging aid: prints "Col<n>" or "None". *)
      fun printColor (Color i) = (print "Col"; print (Int.toString i))
        | printColor NoColor = print "None"

      fun isColorToken (Token.Clr _) = true
        | isColorToken _ = false
      fun isNotColorToken tok = not (isColorToken tok)

      (* Wrap e in a Tag only when there is at least one tag to apply. *)
      fun wrapTags ([], e) = e
        | wrapTags (tags, e) = Tag (tags, e)

      (* Compute the color an element sets at its top level.  Return it together
         with the element, with redundant color tokens removed.
         - Text and Seq set no color.
         - For Tag (tags, e), the tag list is searched from its end.  The last
           Token.Clr found is kept and every earlier Token.Clr is dropped.
           Presumably later tags nest inside earlier ones and override them --
           TODO confirm against Il's tag ordering.
         - If tags holds no color token, the color of the body e is used instead.
           Only a nested Tag can supply one.  The body is normalized recursively. *)
      fun assignColor (e as (Text _)) = (NoColor, e)
        | assignColor (e as (Seq _)) = (NoColor, e)
        | assignColor (Tag (tags, e)) =
          let (* First argument: the rest of the reversed tag list, still to examine.
                 Second argument: tags already examined, back in original order. *)
              fun search ([], _) =
                  let val (col, e') = assignColor e
                  in  (col, Tag (tags, e'))
                  end
                | search ((ctag as Token.Clr c) :: rest, seen) =
                  let val outer = List.filter isNotColorToken (rev rest)
                  in  (Color c, Tag (outer @ (ctag :: seen), e))
                  end
                | search (t :: rest, seen) = search (rest, t :: seen)
          in  search (rev tags, [])
          end

      (* Remove the color token from an element that assignColor reported as
         colored, keeping every other tag.
         - If this Tag's own tag list carries no color, assignColor found the
           color in a nested Tag.  So descend into the body, following the same
           path assignColor took.  Previously this case raised Error.
         - Raises Error if the element is not a Tag, if no color is found along
           that path, or if a tag list holds more than one color token.  Neither
           should occur after assignColor. *)
      fun removeColor (Tag (tags, e)) =
          let val others = List.filter isNotColorToken tags
          in  case List.filter isColorToken tags of
                  [_] => wrapTags (others, e)
                | [] => wrapTags (others, removeColor e)
                | _ => raise (Error "internal error in ColorNest")
          end
        | removeColor _ = raise (Error "internal error in ColorNest")

      (* Debugging aid: print each (color, expression) pair on its own line. *)
      fun show_col_e_list col_e_list = (app (fn (c,e) => (printColor c;
                                                          print "  ";
                                                          print (Print.treestructure e);
                                                          print "\n")) col_e_list)

      (* Apply the color-hoisting rewrite to a Seq.  Text and Tag inputs are
         returned unchanged.  If no match is found, the original Seq is returned
         as-is, without the assignColor normalization. *)
      fun opt (e as (Text _)) = e
        | opt (e as (Tag _)) = e
        | opt (Seq elist) =
          let (* A vector gives O(1) indexed access; List.nth was O(n) per lookup. *)
              val col_e_vec = Vector.fromList (map assignColor elist)
              val len = Vector.length col_e_vec
              fun colorAt i = #1 (Vector.sub (col_e_vec, i))
              fun exprAt i = #2 (Vector.sub (col_e_vec, i))

              (* Expressions at positions lo .. hi-1, in order. *)
              fun exprRange (lo, hi) = List.tabulate (hi - lo, fn i => exprAt (lo + i))

              (* Index of the first NoColor element after pos, or len if none. *)
              fun nextUncolored pos =
                  let fun go i = if i >= len orelse colorAt i = NoColor
                                     then i
                                 else go (i + 1)
                  in  go (pos + 1)
                  end

              (* Replace positions leftPos..rightPos by one Tag carrying
                 commonColor.  The color is stripped from both ends, and the
                 middle is optimized recursively. *)
              fun rewrite (commonColor, leftPos, rightPos) =
                  let val prefix = exprRange (0, leftPos)
                      val middle = exprRange (leftPos + 1, rightPos)
                      val suffix = exprRange (rightPos + 1, len)
                      val newMiddle = Tag ([commonColor],
                                           Seq [removeColor (exprAt leftPos),
                                                opt (Seq middle),
                                                removeColor (exprAt rightPos)])
                  in  Seq (prefix @ [newMiddle] @ suffix)
                  end

              (* Try each colored element as a left end, in order.  For each one,
                 look backwards for the farthest right end with the same color,
                 keeping at least one element between them. *)
              fun scan leftPos =
                  if leftPos >= len - 1
                      then Seq elist                  (* no match anywhere *)
                  else (case colorAt leftPos of
                            NoColor => scan (leftPos + 1)
                          | leftCol as (Color commonCol) =>
                            let (* Every element strictly between the pair must be
                                   colored.  So the right end must come before the
                                   first uncolored element after leftPos; later
                                   candidates could never qualify. *)
                                val limit = nextUncolored leftPos
                                fun seek rightPos =
                                    if rightPos <= leftPos + 1   (* need a non-empty middle *)
                                        then scan (leftPos + 1)
                                    else if colorAt rightPos = leftCol
                                        then rewrite (Token.Clr commonCol, leftPos, rightPos)
                                    else seek (rightPos - 1)
                            in  seek (limit - 1)
                            end)
          in  scan 0
          end

      val optimize = opt

      val _ = Driver.register_opt {optfun = Driver.Iltoil optimize,
                                   optname = "Color Nesting",
                                   scorefn=(fn x=>x),
                                   trusted=false,
                                   initscore=NONE,
                                   blowup=false,
                                   disabled=false}
  end
