(* Written by Erica Sponsler and Nate Weiss *)

open Sast
open Bytecode

(* Persistent finite maps keyed by identifier names; every symbol table
   in [env] is one of these. *)
module StringMap = Map.Make(String)

(* Symbol table: information about all the names in scope.
   Each table maps a name to a pair (slot or index, declared type). *)
type env = {
    (* Index for each function: user functions are numbered from 1 in
       list order; the built-in "print" is mapped to -1. *)
    function_index : (int * Ast.obj_type) StringMap.t;
    (* "Address" for global variables, numbered from 0 *)
    global_index   : (int * Ast.obj_type) StringMap.t;
    (* "Address" for node storage.  NOTE(review): in the visible source this
       is only ever initialized to StringMap.empty and never read. *)
    node_heap_index : (int * Ast.obj_type) StringMap.t;
    (* FP offset for formals (-2, -3, ...) and locals (1, 2, ...); variables
       declared in nested blocks are numbered after the locals in scope *)
    local_index    : (int * Ast.obj_type) StringMap.t;
    (* number of local slots currently in scope; the next block variable
       gets offset number_of_locals + 1 *)
    number_of_locals : int;
  }

(* Storage class of an l-value, as computed by [heap_glob_or_loc]: a
   global variable, a frame-relative local or formal, or a heap access
   (any [Unop] applied to an l-value). *)
type glh = Glob | Loc | Heap 

(* Size of a value of the given type; this is the operand of the [Drp]
   emitted when an expression statement's result is discarded.
   NOTE(review): the unit is presumably bytes -- confirm against the VM. *)
let size_of ty =
  match ty with
  | Ast.IntType | Ast.NodeType _ | Ast.NullType -> 4
  | Ast.CharType | Ast.BooleanType | Ast.VoidType -> 1
  
(* val enum : int -> int -> 'a list -> (int * 'a) list
   [enum stride start xs] pairs every element of [xs] with an index; the
   first element gets [start] and each following one adds [stride]. *)
let enum stride start xs =
  let rec number idx = function
      [] -> []
    | x :: rest -> (idx, x) :: number (idx + stride) rest
  in
  number start xs

(* Second component of a pair (same behavior as Stdlib.snd). *)
let snd (_, second) = second

(* First component of a pair (same behavior as Stdlib.fst). *)
let fst (first, _) = first

(* NOTE(review): an immutable binding, so its value is always [false].
   It is not a [ref]: an expression such as [inFunc = true] is a boolean
   comparison, never an assignment. *)
let inFunc = false

(* [remove_print fdecls] drops every function declaration named "print"
   (the built-in, which the code generator handles specially), keeping
   the remaining declarations in their original order. *)
let remove_print fdecls =
  List.filter (fun f -> f.ffname <> "print") fdecls

(* Declared type of the variable at the root of an l-value.  [Unop]
   wrappers are peeled off first; the name is then looked up in the
   locals and, failing that, in the globals.
   @raise Failure "undeclared variable <name>" if it is in neither table. *)
let rec find_base_type env lval =
  match lval with
  | Unop (inner, _) -> find_base_type env (fst inner)
  | Id s ->
      let name = s.vvname in
      let (_, declared) =
        try StringMap.find name env.local_index
        with Not_found ->
          (try StringMap.find name env.global_index
           with Not_found -> raise (Failure ("undeclared variable " ^ name)))
      in
      declared

(* [find_this_type base_type lval] walks [lval] from the outermost
   operator inwards, starting from [base_type]:
   - [Id _]: [base_type] must be [NodeType t]; the result is [t].
   - [Unop (l, ValueOf)]: [base_type] must be [NodeType t]; continue on
     [l] with [t], i.e. one node layer is stripped.
   - [Unop (l, Child _)]: [base_type] must be a node; continue on [l]
     with [base_type] unchanged.
   NOTE(review): in the visible source this is only reached through
   [find_type], which is itself never called; confirm the outer-first
   stripping matches the intended typing of nested accesses.
   @raise Failure when a non-node type is reached; the message names the
   l-value constructor that was being examined. *)
let rec find_this_type base_type lval =
  match base_type, lval with
  | Ast.NodeType t, Id _ -> t
  | Ast.NodeType t, Unop (inner, ValueOf) -> find_this_type t (fst inner)
  | Ast.NodeType _, Unop (inner, Child _) -> find_this_type base_type (fst inner)
  | _, Id _ -> raise (Failure "base_type is not a node in find_this_type Id")
  | _, Unop (_, _) ->
      (* previously reported "Id" here too, a copy-paste slip *)
      raise (Failure "base_type is not a node in find_this_type Unop")

(* Type obtained by walking [l] with [find_this_type], starting from the
   declared type of the variable at its root. *)
let find_type env l =
  let base = find_base_type env l in
  find_this_type base l

(* Storage class of an l-value: any [Unop] access is [Heap]; a plain
   variable is [Loc] if it is a local/formal, otherwise [Glob] if it is
   a global (locals shadow globals).
   @raise Failure "undeclared variable <name>" if it is in neither table. *)
let heap_glob_or_loc env lval =
  match lval with
  | Unop (_, _) -> Heap
  | Id s ->
      if StringMap.mem s.vvname env.local_index then Loc
      else if StringMap.mem s.vvname env.global_index then Glob
      else raise (Failure ("undeclared variable " ^ s.vvname))

(* val string_map_pairs :
     (int * 'b) StringMap.t -> (int * (string * 'b)) list -> (int * 'b) StringMap.t
   Adds each (index, (name, ty)) entry to [map] as name -> (index, ty),
   in list order, so a later entry for the same name wins. *)
let string_map_pairs map pairs =
  let add_entry acc (index, (name, ty)) = StringMap.add name (index, ty) acc in
  List.fold_left add_entry map pairs

(** Translate a program in SAST form into a bytecode program.  Throw an
    exception if something is wrong, e.g., a reference to an unknown
    variable or function.

    The emitted text is the entry stub [Jsr root; Hlt] followed by the
    compiled body of every non-"print" function, in list order.  Calls to
    those functions are emitted as [Jsr index] (index >= 1) and patched to
    [Jsr pc] at the end; calls to the built-in "print" use negative [Jsr]
    operands and are left untouched.

    @raise Failure on an undeclared variable, an undefined function, a
    print of an unsupported type, or a program without a "root" function. *)
let translate (globals, functions) =

  (* The built-in "print" gets neither an index nor a compiled body.  Both
     the index table and the list of compiled bodies are derived from this
     one list, so [func_offset.(i)] is the entry point of the function
     with index [i] wherever "print" declarations appear in [functions]. *)
  let user_functions = remove_print functions in

  (* Allocate "addresses" for each global variable *)
  let global_indexes = string_map_pairs StringMap.empty
      (enum 1 0 (List.map (fun v -> (v.vvname, v.vvtype)) globals)) in

  (* Assign indexes to function names; built-in "print" is special *)
  let built_in_functions = StringMap.add "print" (-1, Ast.IntType) StringMap.empty in
  let function_indexes = string_map_pairs built_in_functions
      (enum 1 1 (List.map (fun f -> (f.ffname, f.fftype)) user_functions)) in

  (* Translate a function in SAST form into a list of bytecode statements *)
  let translate env fdecl =
    (* Bookkeeping: FP offsets for locals (1, 2, ...) and formals (-2, -3, ...) *)
    let num_formals = List.length fdecl.fformals
    and num_locals = List.length fdecl.flocals
    and local_offsets =
      enum 1 1 (List.map (fun v -> (v.vvname, v.vvtype)) fdecl.flocals)
    and formal_offsets =
      enum (-1) (-2) (List.map (fun v -> (v.vvname, v.vvtype)) fdecl.fformals) in
    let env = { env with
                local_index = string_map_pairs StringMap.empty
                    (local_offsets @ formal_offsets);
                number_of_locals = num_locals } in

    (* Code that loads the storage an l-value refers to; used as the base
       of an enclosing [Unop] access.  NOTE(review): the constants 37 and 56
       pushed before [Lfp] / [Lod] are interpreted by the VM, which is not
       visible here. *)
    let rec l_value_helper lenv = function
        Id s ->
          (try let id_pair = StringMap.find s.vvname lenv.local_index in
             [LitI (fst id_pair)] @ [LitI 37] @ [Lfp]
           with Not_found ->
             try let id_glob_pair = StringMap.find s.vvname lenv.global_index in
               [LitI (fst id_glob_pair)] @ [LitI 56] @ [Lod]
             with Not_found -> raise (Failure ("undeclared variable " ^ s.vvname)))
      | Unop (l, op) ->
          (match op with
             ValueOf -> l_value_helper lenv (fst l) @ [LitI (-1)] @ [Ldh]
           | Child exp ->
               l_value_helper lenv (fst l) @ simple_expr lenv (fst exp) @ [Ldh])

    (* Code that pushes the operand a store/load expects for an l-value:
       the slot index of a variable, or, for a [Unop] access, the base from
       [l_value_helper] followed by an offset ([-1] for [ValueOf], the child
       expression for [Child]). *)
    and l_value lenv = function
        Id s ->
          (try let id_pair = StringMap.find s.vvname lenv.local_index in
             [LitI (fst id_pair)]
           with Not_found ->
             try let id_glob_pair = StringMap.find s.vvname lenv.global_index in
               [LitI (fst id_glob_pair)]
             with Not_found -> raise (Failure ("undeclared variable " ^ s.vvname)))
      | Unop (l, op) ->
          (match op with
             ValueOf -> l_value_helper lenv (fst l) @ [LitI (-1)]
           | Child exp -> l_value_helper lenv (fst l) @ simple_expr lenv (fst exp))

    and simple_expr lenv = function
        Literal c ->
          (match c with
             Ast.Integer i -> [LitI i]
           | Ast.Character ch -> [LitC ch]
           | Ast.Boolean b -> [LitB b]
           | Ast.Null -> [LitNull])
      | Binop (e1, op, e2) ->
          simple_expr lenv (fst e1) @ simple_expr lenv (fst e2) @ [Bin op]
      | Assign (l, e) ->
          (* value, then target operand, then a store chosen by storage class *)
          simple_expr lenv (fst e) @ l_value lenv (fst l) @
          (match heap_glob_or_loc lenv (fst l) with
             Glob -> [Str]
           | Loc -> [Sfp]
           | Heap -> [Sth])
      | Call (fname, actuals) ->
          (* arguments are pushed last-to-first *)
          (try
             List.concat (List.map (fun a -> simple_expr lenv (fst a)) (List.rev actuals)) @
             (match fname.ffname with
                "print" ->
                  (* built-in print: the negative operand selects the variant
                     by the declared type of its first formal *)
                  (match (List.hd fname.fformals).vvtype with
                     Ast.IntType -> [Jsr (-1)]
                   | Ast.BooleanType -> [Jsr (-2)]
                   | Ast.CharType -> [Jsr (-3)]
                   | _ -> raise (Failure ("Cannot print this type.")))
              | _ -> [Jsr (fst (StringMap.find fname.ffname env.function_index))])
           with Not_found -> raise (Failure ("undefined function " ^ fname.ffname)))
      | Neg e -> [LitI 0] @ simple_expr lenv (fst e) @ [Bin Ast.Sub]
      | Bang e -> [LitI 1] @ simple_expr lenv (fst e) @ [Bin Ast.Sub]
      | Node e -> simple_expr lenv e @ [Cnd]
      | LValue l ->
          (* NOTE(review): the constants 8 and 7 are interpreted by the VM *)
          l_value lenv (fst l) @
          (match heap_glob_or_loc lenv (fst l) with
             Glob -> [LitI 8] @ [Lod]
           | Loc -> [LitI 7] @ [Lfp]
           | Heap -> [Ldh])
      | Noexpr -> []
    in

    let rec stmt lenv = function
        Block (vars, sl) ->
          (* Variables declared in a block are numbered after the locals
             already in scope, and the block's code is bracketed by
             [LitI n; Ssp] ... [Rsp] with n the number of such variables.
             The function body is compiled as [Block ([], body)], so it is
             bracketed by [LitI 0; Ssp] ... [Rsp] as well. *)
          let new_env =
            { lenv with
              local_index = string_map_pairs lenv.local_index
                  (enum 1 (lenv.number_of_locals + 1)
                     (List.map (fun v -> (v.vvname, v.vvtype)) vars));
              number_of_locals = lenv.number_of_locals + List.length vars } in
          [LitI (List.length vars)] @ [Ssp] @
          List.concat (List.map (fun a -> stmt new_env a) sl) @ [Rsp]
      | Expr e -> simple_expr lenv (fst e) @ [Drp (size_of (snd e))]
      | Return e -> simple_expr lenv (fst e) @ [Rts num_formals]
      | If (p, t, f) ->
          let t' = stmt lenv t and f' = stmt lenv f in
          simple_expr lenv (fst p) @ [Beq (2 + List.length t')] @
          t' @ [Bra (1 + List.length f')] @ f'
      | While (e, b) ->
          let b' = stmt lenv b and e' = simple_expr lenv (fst e) in
          [Bra (1 + List.length b')] @ b' @ e' @
          [Bne (-(List.length b' + List.length e'))]
    in
    [Ent num_locals] @                    (* Entry: allocate space for locals *)
    stmt env (Block ([], fdecl.fbody)) @  (* Body *)
    [LitI 0; Rts num_formals]             (* Default = return 0 *)

  in
  let env = { function_index = function_indexes;
              global_index = global_indexes;
              node_heap_index = StringMap.empty;
              local_index = StringMap.empty;
              number_of_locals = 0 } in

  (* Code executed to start the program: Jsr root; halt *)
  let entry_function =
    try [Jsr (fst (StringMap.find "root" function_indexes)); Hlt]
    with Not_found -> raise (Failure ("no \"root\" function - sincerely, compile.ml"))
  in

  (* Compile the functions, in the same order their indexes were assigned *)
  let func_bodies = entry_function :: List.map (translate env) user_functions in

  (* Calculate function entry points by adding their lengths *)
  let (fun_offset_list, _) = List.fold_left
      (fun (l, i) f -> (i :: l, i + List.length f)) ([], 0) func_bodies in
  let func_offset = Array.of_list (List.rev fun_offset_list) in

  { num_globals = List.length globals;
    (* Concatenate the compiled functions and replace the function
       indexes in Jsr statements with PC values *)
    text = Array.of_list (List.map (function
        Jsr i when i > 0 -> Jsr func_offset.(i)
      | s -> s) (List.concat func_bodies))
  }
