open Ast
open Bytecode

module StringMap = Map.Make(String)

(* Symbol table: information about all the names in scope while
   translating one function. *)
type env = {
   function_index : int StringMap.t; (* Symbolic Jsr index for each function *)
   global_index   : int StringMap.t; (* "Address" of the first word of each global *)
   local_index    : int StringMap.t; (* FP offset of the first word of each
                                        local (positive) or formal (negative) *)
}

(* val enum : int -> int -> 'a list -> (int * 'a) list
   [enum stride n l] pairs every element of [l] with an index; the first
   element gets [n] and each following one the previous index plus
   [stride].  Order of [l] is preserved. *)
let enum stride n l =
   let rec go idx acc = function
      | [] -> List.rev acc
      | x :: rest -> go (idx + stride) ((idx, x) :: acc) rest
   in
   go n [] l

(* Number of stack words a value of each type occupies: a literal is one
   word, a point its two components, a curve the eight components of a
   Curve literal, and a layer ten 8-word curve slots followed by one word
   holding the number of curves actually present. *)
let sizeoftype t =
   match t with
   | Literalt -> 1
   | Pointt -> 2
   | Curvet -> 8
   | Layert -> 10 * 8 + 1

(* [sizereq acc l] is [acc] plus the total number of words needed by the
   variable declarations in [l]. *)
let sizereq acc l =
   let add_size total vdecl = total + sizeoftype vdecl.t in
   List.fold_left add_size acc l

(* [multiLfp index i n] emits [Lfp (index+i); ...; Lfp (index+n-1)], one
   load per word of a multi-word local or formal starting at FP offset
   [index].  Yields [] once [i >= n], so an out-of-range start cannot
   recurse forever. *)
let rec multiLfp index i n =
   if i >= n then []
   else Lfp (index + i) :: multiLfp index (i + 1) n
   
(* [multiLod index i n] emits [Lod (index+i); ...; Lod (index+n-1)], one
   load per word of a multi-word global starting at address [index].
   Yields [] once [i >= n], so an out-of-range start cannot recurse
   forever. *)
let rec multiLod index i n =
   if i >= n then []
   else Lod (index + i) :: multiLod index (i + 1) n

(* [multiSfp index i n] emits [Drp; Sfp (index+i); Drp; Sfp (index+i+1);
   ...], a Drp before each FP-relative store.  Callers drop the leading
   Drp and reverse the list so the top of the stack is stored into the
   highest offset first.  Yields [] once [i >= n]. *)
let rec multiSfp index i n =
   if i >= n then []
   else Drp :: Sfp (index + i) :: multiSfp index (i + 1) n

(* [multiStr index i n] emits [Drp; Str (index+i); Drp; Str (index+i+1);
   ...], a Drp before each global store.  Callers drop the leading Drp
   and reverse the list so the top of the stack is stored into the
   highest address first.  Yields [] once [i >= n]. *)
let rec multiStr index i n =
   if i >= n then []
   else Drp :: Str (index + i) :: multiStr index (i + 1) n

(* Assign a starting offset to each variable declaration.  [dir] = 1
   allocates upward from [n] (globals from 0, locals from FP+1); [dir] =
   -1 allocates downward from [n] (formals from FP-2), in which case the
   recorded offset is the lowest word of the variable, n - size + 1.
   Returns (offset, name) pairs in declaration order. *)
let rec enum_var dir n = function
   | [] -> []
   | v :: rest ->
      let sz = sizeoftype v.t in
      let start = if dir < 0 then n - sz + 1 else n in
      (start, v.name) :: enum_var dir (n + dir * sz) rest
         

(* val string_map_pairs : 'a StringMap.t -> ('a * string) list -> 'a StringMap.t
   Add every (value, key) pair to [map]; a later pair for the same key
   overrides an earlier one (and any existing binding). *)
let string_map_pairs map pairs =
   let bind m (value, key) = StringMap.add key value m in
   List.fold_left bind map pairs

(* Index variable declarations by name, adding them to [map]; a later
   declaration with the same name replaces an earlier one. *)
let vdeclmap map vdlist =
   let bind m vd = StringMap.add vd.name vd m in
   List.fold_left bind map vdlist

(* Map each function name to its declared return type, adding to [map];
   a later declaration with the same name replaces an earlier one. *)
let fdeclmap map fdlist =
   let bind m fd = StringMap.add fd.fname fd.return m in
   List.fold_left bind map fdlist

(** Translate a program in AST form into a bytecode program.  Throw an
    exception if something is wrong, e.g., a reference to an unknown
    variable or function.

    Jsr targets are symbolic function indexes until the very end, when
    every positive index is replaced by the entry PC of the matching
    element of [func_bodies]:
    - index 0 is the start-up stub ([entry_function]);
    - indexes 1 .. 7 are built-ins with hand-written bytecode bodies
      (getX, getY, getPoint, setPoint, getCurve, getSize, setCurveh);
    - user functions follow, numbered from 8 in declaration order.
    Graphics built-ins (print, draw, pause, clear, random) use negative
    indexes, which are emitted unchanged.

    @raise Failure on an undeclared variable, an undefined function, a
    missing [main], or a layer literal with more than 10 curves. *)
let translate (globals, functions) =

   (* Allocate "addresses" for each global variable *)
   let global_indexes = string_map_pairs StringMap.empty
      (enum_var 1 0 globals) in

   (* Return type of every callable name; used to size call results *)
   let funcdmap = fdeclmap StringMap.empty functions in
   let funcdmap = StringMap.add "print" (Literalt) funcdmap in
   let funcdmap = StringMap.add "draw" (Layert) funcdmap in
   let funcdmap = StringMap.add "pause" (Literalt) funcdmap in
   let funcdmap = StringMap.add "clear" (Literalt) funcdmap in
   let funcdmap = StringMap.add "random" (Literalt) funcdmap in
   let funcdmap = StringMap.add "getX" (Literalt) funcdmap in
   let funcdmap = StringMap.add "getY" (Literalt) funcdmap in
   let funcdmap = StringMap.add "setX" (Literalt) funcdmap in
   let funcdmap = StringMap.add "setY" (Literalt) funcdmap in
   let funcdmap = StringMap.add "getPoint" (Pointt) funcdmap in
   let funcdmap = StringMap.add "setPoint" (Literalt) funcdmap in
   let funcdmap = StringMap.add "getCurve" (Curvet) funcdmap in
   let funcdmap = StringMap.add "getSize" (Literalt) funcdmap in
   let funcdmap = StringMap.add "setCurveh" (Curvet) funcdmap in
   let funcdmap = StringMap.add "setCurve" (Literalt) funcdmap in
   let globvdmap = vdeclmap StringMap.empty globals in

   (* Assign indexes to function names.  Graphics built-ins get negative
      indexes; the other built-ins get 1 .. 7, matching the order of their
      hand-written bodies in [func_bodies].  setX, setY and setCurve need
      no index: they are expanded inline by [expr]. *)
   let graphics_functions = StringMap.add "print" (-1) StringMap.empty in
   let graphics_functions = StringMap.add "draw" (-2) graphics_functions in
   let graphics_functions = StringMap.add "pause" (-3) graphics_functions in
   let graphics_functions = StringMap.add "clear" (-4) graphics_functions in
   let graphics_functions = StringMap.add "random" (-5) graphics_functions in
   let built_in_functions = StringMap.add "getX" (1) graphics_functions in
   let built_in_functions = StringMap.add "getY" (2) built_in_functions in
   let built_in_functions = StringMap.add "getPoint" (3) built_in_functions in
   let built_in_functions = StringMap.add "setPoint" (4) built_in_functions in
   let built_in_functions = StringMap.add "getCurve" (5) built_in_functions in
   let built_in_functions = StringMap.add "getSize" (6) built_in_functions in
   let built_in_functions = StringMap.add "setCurveh" (7) built_in_functions in
   (* User functions are numbered right after the non-graphics built-ins *)
   let function_indexes = string_map_pairs built_in_functions
      (enum 1 ((StringMap.cardinal built_in_functions) -
         (StringMap.cardinal graphics_functions) + 1)
         (List.map (fun f -> f.fname) functions)) in

   (* Translate a function in AST form into a list of bytecode statements *)
   let translate env fdecl =
   (* Bookkeeping: FP offsets for locals (upward from 1) and arguments
      (downward from -2) *)
   let sz_formals = sizereq 0 fdecl.formals
   and sz_locals = sizereq 0 fdecl.locals
   and sz_ret = sizeoftype fdecl.return
   and local_offsets = enum_var 1 1 fdecl.locals
   and locvdmap = vdeclmap StringMap.empty (fdecl.locals@fdecl.formals)
   and formal_offsets = enum_var (-1) (-2) fdecl.formals in
   (* Number of words an expression's value occupies on the stack *)
   let rec sizeofexpr = function
      | Literal i -> 1
      | Dotop (id, fn, e) -> sizeofexpr (Call(fn,[Id(id)]@e))
      | Curve (a,b,c,d,e,f,g,h) ->  8
      | Point (a,b) -> 2
      | Layer lst -> 81
      | Id s->
         (try
               sizeoftype (StringMap.find s locvdmap).t
         with Not_found -> try
               sizeoftype (StringMap.find s globvdmap).t
         with Not_found -> raise (Failure ("Undeclared variable " ^ s)))
      | Binop (e1,op,e2) -> 1
      | Assign(s,e) -> sizeofexpr(Id(s))
      | Call(fname, arg) ->
         (let rett =
            (try
               (StringMap.find fname funcdmap)
            with Not_found ->
               raise (Failure ("undefined function " ^ fname)))
         in sizeoftype rett)
      | Noexpr -> 1 in
   let env = { env with local_index = string_map_pairs
      StringMap.empty (local_offsets @ formal_offsets) } in

   let rec expr = function
      | Literal i -> [Lit i]
      | Point (a,b) -> expr a @ expr b
      | Curve (a,b,c,d,e,f,g,h) ->
         expr a @
         expr b @
         expr c @
         expr d @
         expr e @
         expr f @
         expr g @
         expr h
      | Dotop (id, fn, e) -> expr (Call(fn,[Id(id)]@e))
      | Layer lst ->
         (* A layer is 81 words: zero padding for the unused curve slots,
            the listed curves, then the number of curves.  Only 10 slots
            exist; without this check a longer list made the padding
            count negative and [blanks] recursed until the stack
            overflowed. *)
         let ncurves = List.length lst in
         if ncurves > 10 then
            raise (Failure ("layer holds at most 10 curves, got "
               ^ string_of_int ncurves));
         let rec blanks i n =
            if i >= n then [] else Lit 0 :: blanks (i + 1) n in
         blanks 0 (8 * (10 - ncurves))
         @ List.concat (List.map (fun el -> expr (Id(el))) lst)
         @ [Lit ncurves]
      | Id s ->
         (try
            let start = (StringMap.find s env.local_index)
            and sz = sizeoftype (StringMap.find s locvdmap).t in
            multiLfp start 0 sz (* @ [Lit sz] *)
         with Not_found -> try
            let startg = (StringMap.find s env.global_index)
            and sz = sizeoftype (StringMap.find s globvdmap).t in
            multiLod startg 0 sz (* @ [Lit sz] *)
         with Not_found -> raise (Failure ("Undeclared variable " ^ s)))
      | Binop (e1, op, e2) -> expr e1 @ expr e2 @ [Bin op]
      | Assign (s, e) ->
            (* Store the value word by word, highest offset first; the
               first word stays on the stack as the expression's result. *)
            expr e @
            (try
               let start = (StringMap.find s env.local_index) in
                  let sz = sizeoftype (StringMap.find s locvdmap).t in
               List.rev (List.tl (multiSfp start 0 sz))
            with Not_found -> try
               let startg = (StringMap.find s env.global_index) in
                  let sz = sizeoftype (StringMap.find s globvdmap).t in
               List.rev (List.tl (multiStr startg 0 sz))
            with Not_found -> raise (Failure ("undeclared variable " ^ s)))
      | Call (fname, actuals) -> (match fname with
         (* setX/setY rebuild the point and assign it back to the variable *)
         | "setX" -> expr (Assign(
            (match (List.hd actuals) with
               | Id (s)-> s | _ -> ""),
            Point(
               List.hd (List.tl actuals), Call("getY", [(List.hd actuals)])
               )))
         | "setY" -> expr (Assign(
            (match (List.hd actuals) with
               | Id (s)-> s | _ -> ""),
            Point(
               Call("getX",[(List.hd actuals)]),List.hd (List.tl actuals)
               )))
         | "getPoint" ->
            (* a b c d e f g h 8 8 a b *)
            let indindx =
               (Binop(Binop(Literal(4),Sub,(List.hd (List.tl actuals))),
                  Mult,Literal(2))) in
            (try
               expr (List.hd actuals) @
               expr indindx @ expr indindx @
               [Jsr (StringMap.find fname env.function_index) ]
               with Not_found ->
                  raise (Failure ("undefined function " ^ fname))
            )
         | "setPoint" ->
            (* a b c d e f g h 8 8 n1 n2 *)
            let s = match (List.hd actuals) with
               | Id (s)-> s | _ -> "" in
            let indindx =
               (Binop(Binop(Literal(4),Sub,(List.hd (List.tl actuals))),
                  Mult,Literal(2))) in
            (try
               expr (List.hd actuals) @
               expr indindx @ expr indindx @
               expr
               (List.hd (List.rev actuals))
               @
               [Jsr (StringMap.find fname env.function_index) ]
               with Not_found ->
                  raise (Failure ("undefined function " ^ fname)))
               @
               (* Write the updated curve back to the variable *)
               (try
                  let start = (StringMap.find s env.local_index) in
                     let sz = sizeoftype (StringMap.find s locvdmap).t in
                  List.rev (List.tl (multiSfp start 0 sz))
               with Not_found -> try
                  let startg = (StringMap.find s env.global_index) in
                     let sz = sizeoftype (StringMap.find s globvdmap).t in
                  List.rev (List.tl (multiStr startg 0 sz))
               with Not_found -> raise (Failure ("undeclared variable " ^ s))
               )
         | "getCurve" ->
            (try
               expr (List.hd actuals) @
               expr (List.hd (List.tl actuals)) @
               [Jsr (StringMap.find fname env.function_index) ]
               with Not_found ->
                  raise (Failure ("undefined function " ^ fname))
            )
         | "getSize" ->
            (try
               expr (List.hd actuals) @
               [Jsr (StringMap.find fname env.function_index) ]
               with Not_found ->
                  raise (Failure ("undefined function " ^ fname))
            )
         | "setCurveh" ->
            (try
               expr (List.hd actuals) @
               expr (List.hd (List.tl actuals)) @
               expr (List.hd (List.rev actuals)) @
               [Jsr (StringMap.find fname env.function_index) ]
               with Not_found ->
                  raise (Failure ("undefined function " ^ fname))
            )

         | "setCurve" ->
            (* setCurveh returns the updated layer; assign it back *)
            expr (Assign(
            (match (List.hd actuals) with
               | Id (s)-> s | _ -> ""),
            Call("setCurveh", actuals)))
         | _ ->
            (* Ordinary call: arguments are pushed last-to-first *)
            (try
               (List.concat (List.map expr (List.rev actuals))) @
               [Jsr (StringMap.find fname env.function_index) ]
               with Not_found ->
                  raise (Failure ("undefined function " ^ fname))))
      | Noexpr -> []

   in let rec stmt = function
      | Block sl     -> List.concat (List.map stmt sl)
      | Expr e       -> (
         (* Discard the expression's value.  An assignment leaves only one
            word on the stack (see [Assign] above), other expressions
            leave [sizeofexpr] words. *)
         let drpsize a = match a with
            | Assign _ -> 1
            | _-> sizeofexpr a in
         let rec drpe i ex =
            if i = drpsize ex then
               []
            else
               [Drp] @ drpe (i+1) ex in
         expr e @ drpe 0 e)
      | Return e     ->
         (* Store each word of the return value with Sfp at offsets below
            the formals, popping after each store, then return. *)
         (let rec loadretval i exp sargs =
            let sz = sizeofexpr exp in
            if i = sz then
               []
            else
               [Sfp (-sargs + sz -2 -i )] @ [Drp] @
               loadretval (i+1) exp sargs
         in
         expr e @ [Rta] @ (loadretval 0 e sz_formals) @
         [Rts (sz_formals - (sizeofexpr e) + 1)])
      | If (p, t, f) -> let t' = stmt t and f' = stmt f in
         expr p @ [Beq(2 + List.length t')] @
         t' @ [Bra(1 + List.length f')] @ f'
      | For (e1, e2, e3, b) ->
         stmt (Block([Expr(e1); While(e2, Block([b; Expr(e3)]))]))
      | While (e, b) ->
         let b' = stmt b and e' = expr e in
            [Bra (1+ List.length b')] @ b' @ e' @
            [Bne (-(List.length b' + List.length e'))]

   in
      [Ent (max sz_locals sz_ret)] @(* Entry: allocate space for locals *)
      stmt (Block fdecl.body) @  (* Body *)
      [Rta; Lit 0; Rts sz_formals]   (* Default = return 0 *)

   in let env = { function_index = function_indexes;
      global_index = global_indexes;
      local_index = StringMap.empty } in

   (* Code executed to start the program: Ogr; Jsr main; halt.
      NOTE(review): Ogr presumably initialises graphics — confirm in VM. *)
   let entry_function = try
      [Ogr; Jsr (StringMap.find "main" function_indexes); Hlt]
      with Not_found -> raise (Failure ("no \"main\" function"))
   in

   (* Compile the functions.  Position in this list is the function
      index: 0 entry stub, 1 .. 7 hand-written built-ins, then user
      functions. *)
   let func_bodies = entry_function ::
      (* 1: getX *)
      [[Ent 0;Lfp (-3);Rta;Rts 2]] @
      (* 2: getY *)
      [[Ent 0;Lfp (-2);Rta;Sfp (-3);Rts 2]] @
      (* 3: getPoint *)
      [[Ent 0;Ind 1;Ind 0;Rta; Sfp (-10);Drp;Sfp (-11);Rts 9]] @
      (* 4: setPoint *)
      [[Ent 0;Lfp(-3);Lfp(-2);Ins 2;Drp;Ins 3;Rta;Rts 5]] @
      (* 5: getCurve *)
      [[Ent 0;Lfp(-3);Lfp(-2);Bin(Sub);Lit 8;Bin(Mult);Lit 4;Bin(Add);
         Lfp (1);Lfp (1);Lfp (1);Lfp (1);Lfp (1);Lfp (1);Lfp (1);
         Ind (-3);Ind (-4);Ind (-5);Ind (-6);Ind (-7);Ind (-8);
         Ind (-9);Ind (-10);Rta;Sfp(-76);Drp;Sfp(-77);Drp;
         Sfp(-78);Drp;Sfp(-79);Drp;Sfp(-80);Drp;Sfp(-81);Drp;Sfp(-82);
         Drp;Sfp(-83);Rts (75);]] @
      (* 6: getSize *)
      [[Ent 0;Lfp(-2);Rta;Sfp(-82);Rts(81)]] @
      (* 7: setCurveh *)
      [[Ent 0;Lfp(-11);Lfp(-10);Bin(Sub);Lit 8;Bin(Mult);Lit 12;Bin(Add);
         Lfp (1);Lfp (1);Lfp (1);Lfp (1);Lfp (1);Lfp (1);Lfp (1);
         Lfp (-9);Lfp(-8);Lfp(-7);Lfp(-6);Lfp(-5);Lfp(-4);Lfp(-3);Lfp(-2);
         Ins(-10);Drp;Ins(-9);Drp;Ins(-8);Drp;
         Ins(-7);Drp;Ins(-6);Drp;Ins(-5);Drp;Ins(-4);Drp;
         Ins(-3);Rta;Rts (10);]] @
      (List.map (translate env) functions) in

   (* Calculate function entry points by adding their lengths *)
   let (fun_offset_list, _) = List.fold_left
      (fun (l,i) f -> (i :: l, (i + List.length f))) ([],0) func_bodies in
   let func_offset = Array.of_list (List.rev fun_offset_list) in

   { num_globals = sizereq 0 globals;
         (* Concatenate the compiled functions and replace the positive
          function indexes in Jsr statements with PC values; negative
          (graphics built-in) indexes are kept as-is *)
     text = Array.of_list (List.map (function
        | Jsr i when i > 0 -> Jsr func_offset.(i)
        | _ as s -> s) (List.concat func_bodies))
   }
