open Ast
open Bytecode

module StringMap = Map.Make(String)

(* Symbol table: every name visible while a function body is translated.
   Both maps are keyed by the source-level name. *)
type env = {
  function_index : int StringMap.t;
  (* function name -> 1-based index; Jsr operands holding this index are
     replaced by the function's code offset once all bodies are laid out *)
  local_index : int StringMap.t;
  (* variable name -> frame-pointer offset: locals are positive,
     parameters negative *)
}

(* val enum : int -> int -> 'a list -> (int * 'a) list *)
(* [enum stride n xs] pairs each element of [xs] with an index, starting at
   [n] and adding [stride] for each subsequent element, e.g.
     enum 1 1 [a; b]       = [(1, a); (2, b)]
     enum (-1) (-2) [a; b] = [(-2, a); (-3, b)]
   Written with an accumulator so it is tail-recursive. *)
let enum stride n xs =
  let rec go acc i = function
    | [] -> List.rev acc
    | hd :: tl -> go ((i, hd) :: acc) (i + stride) tl
  in
  go [] n xs
  
(* [get_names vars] returns the [varName] field of each variable record in
   [vars], in the same order.  The function is not recursive, so it is
   bound without [rec] (avoids the unused-rec warning 39). *)
let get_names vars = List.map (fun v -> v.varName) vars

(* val string_map_pairs :
     'a StringMap.t -> ('a * string) list -> 'a StringMap.t *)
(* [string_map_pairs map pairs] adds every [(value, name)] pair of [pairs]
   to [map], binding [name] to [value].  Note the pair order: the value
   comes first, matching the output of [enum].  Pairs are added left to
   right, so a later pair overrides an earlier pair (or an existing
   binding in [map]) with the same name. *)
let string_map_pairs map pairs =
  List.fold_left (fun m (i, n) -> StringMap.add n i m) map pairs

(* Translate a program in AST form into a bytecode program.

   [translate (others, functions)] compiles every declaration in
   [functions] and lays the results out after a two-instruction entry stub
   ([Jsr] to the function named "MAIN", then [Hlt]).  The first component
   of the pair is not used by this code generator.

   Raises [Failure] on a reference to an undeclared variable (read or
   assignment), a call to an undefined function, or when no function is
   named "MAIN". *)
let translate (_others, functions) =

  (* Assign 1-based indexes to function names; slot 0 of the final layout
     is the entry stub.  Jsr operands are rewritten to code offsets at the
     end. *)
  let function_indexes = string_map_pairs StringMap.empty
      (enum 1 1 (List.map (fun f -> f.fname) functions)) in

  (* Translate one function declaration into a list of bytecode
     instructions. *)
  let translate_fdecl env fdecl =
    (* Bookkeeping: FP offsets for locals (1, 2, ...) and
       parameters (-2, -3, ...). *)
    let num_params = List.length fdecl.params
    and num_locals = List.length fdecl.locals
    and local_offsets = enum 1 1 (get_names fdecl.locals)
    and param_offsets = enum (-1) (-2) (get_names fdecl.params) in
    (* Parameters are added after locals, so a parameter's offset wins if it
       shares a name with a local. *)
    let env = { env with local_index = string_map_pairs
                  StringMap.empty (local_offsets @ param_offsets) } in

    (* FP offset of a variable in scope.  Shared by reads ([Id]) and writes
       ([Assign]) so both report an undeclared name the same way instead of
       letting a bare [Not_found] escape from the assignment case. *)
    let var_offset id =
      try StringMap.find id env.local_index
      with Not_found ->
        raise (Failure ("cannot reference undeclared variable " ^ id))
    in

    (* Static type tag of a term, passed as the type field of [Opr]
       instructions.  Variables and Null are not resolved here and are
       tagged [Ast.Any]. *)
    let typ = function
      | Integer _ -> Ast.Int
      | String _ -> Ast.Str
      | Array _ -> Ast.Arr
      | Id _ | Null -> Ast.Any
    in

    (* Code that pushes the value of a term. *)
    let rec term = function
      | Id id -> [Lod (var_offset id)]
      | Integer i -> [Int i]
      | String s -> [Str s]
      | Array [] -> [Arr (0, 0)]
      | Array (first :: _ as elems) ->
          (* Push every element, then build the array.  The flag is 1 when
             the first element is a string literal, else 0 (how the VM
             interprets it is defined in Bytecode/the interpreter). *)
          let is_str = if typ first = Ast.Str then 1 else 0 in
          List.concat (List.map term elems)
          @ [Arr (List.length elems, is_str)]
      | Null -> []
    in

    (* Code that pushes the value of an expression.  The last field of each
       [Opr] matches the number of values pushed for it. *)
    let rec expr = function
      | Term t -> term t
      | Call (fname, args) ->
          let index =
            try StringMap.find fname env.function_index
            with Not_found -> raise (Failure ("undefined function " ^ fname))
          in
          (* Arguments are pushed last-to-first. *)
          List.concat (List.map expr (List.rev args)) @ [Jsr index]
      | Range (id, a1, a2) ->
          term a1 @ term a2 @ term id @ [Opr (Ast.Rng, typ a2, 3)]
      | Binop (e, op, a) -> expr e @ term a @ [Opr (op, typ a, 2)]
      | Unop (e, op) -> expr e @ [Opr (op, Ast.Int, 1)]
    in

    (* Code for a statement.  Branch offsets are relative to the branch
       instruction itself. *)
    let rec stmt = function
      | Block sl -> List.concat (List.map stmt sl)
      | Return e -> expr e @ [Ret num_params]
      | If (p, t) ->
          (* Beq's offset lands on the instruction just after the
             then-branch (presumably taken when the predicate is false). *)
          let t' = stmt t in
          expr p @ [Beq (1 + List.length t')] @ t'
      | While (e, b) ->
          (* Layout: Bra over the body to the test; body; test; Bne back to
             the first body instruction. *)
          let b' = stmt b and e' = expr e in
          [Bra (1 + List.length b')] @ b' @ e'
          @ [Bne (-(List.length b' + List.length e'))]
      | Assign (s, e) -> expr e @ [Sto (var_offset s.varName)]
      | SetAIndex (id, i, e) ->
          expr e @ term i @ term id @ [Opr (Ast.Sai, Ast.Arr, 3)]
      | Output (t, e) -> expr e @ [Opr (Ast.Out, typ t, 1)]
    in

    [Ent num_locals]            (* Entry: allocate space for locals *)
    @ stmt (Block fdecl.body)   (* Body *)
    @ [Int 0; Ret num_params]   (* Default = return 0 *)
  in

  let env = { function_index = function_indexes;
              local_index = StringMap.empty } in

  (* Code executed to start the program: Jsr MAIN; Hlt *)
  let entry_function =
    try [Jsr (StringMap.find "MAIN" function_indexes); Hlt]
    with Not_found -> raise (Failure ("no \"MAIN\" function"))
  in

  (* Compile the functions; position i in this list holds the body of the
     function with index i (position 0 is the entry stub). *)
  let func_bodies =
    entry_function :: List.map (translate_fdecl env) functions in

  (* Calculate function entry points by adding their lengths *)
  let (fun_offset_list, _) = List.fold_left
      (fun (l, i) f -> (i :: l, i + List.length f)) ([], 0) func_bodies in
  let func_offset = Array.of_list (List.rev fun_offset_list) in

  {
    (* Concatenate the compiled functions and replace the function
       indexes in Jsr statements with PC values *)
    text = Array.of_list (List.map (function
        | Jsr i when i > 0 -> Jsr func_offset.(i)
        | s -> s) (List.concat func_bodies))
  }

