(* signed: Yanlin Duan, Zhuo Kong, Emily Meng, Shiyu Qiu *)

(* Code generation: translate takes a semantically checked AST and
produces LLVM IR

LLVM tutorial: Make sure to read the OCaml version of the tutorial

http://llvm.org/docs/tutorial/index.html

Detailed documentation on the OCaml LLVM library:

http://llvm.moe/
http://llvm.moe/ocaml/

*)

module L = Llvm
module A = Ast
module S = Sast
module E = Semant
open Ast
open Llvm
open Sast

(* Persistent string-keyed map; used below for the function table
   (function name -> LLVM function and its declaration). *)
module StringMap = Map.Make (String)

(* Mutable string-keyed hash table; used below for variable scopes and
   for the struct table. *)
module StringHash = Hashtbl.Make (struct
  type t = string
  let equal = String.equal
  let hash = Hashtbl.hash
end)

(* One lexical scope. [parent] is the enclosing scope ([None] for the
   outermost scope of a function). [varMap] maps each variable name to an
   llvalue plus the variable's source type; the llvalue is a stack slot,
   except for strings, which are bound directly to their pointer value
   (see add_variable inside translate). *)
type symbol_table = {
  parent: symbol_table option;
  varMap: (llvalue*typ) StringHash.t; 
}

(* Code-generation context threaded through [expr] and [stmt]. *)
type env = {
  scope : symbol_table;   (* innermost scope currently in effect *)
  return_type : typ;      (* declared return type of the enclosing function *)
  in_for : bool;          (* NOTE(review): never set to true in this file *)
  in_while : bool;        (* set when lowering a while/loop body; not read in this file *)
  
}

(* [translate functions] lowers the semantically checked function list
   into an LLVM module named "rusty" and returns that module. *)
let translate (functions) =
  let context = L.global_context () in
  let the_module = L.create_module context "rusty" in
  (* LLVM counterparts of the language's primitive types *)
  let i32_t = L.i32_type context in
  let i8_t = L.i8_type context in
  let i1_t = L.i1_type context in
  let float_t = L.double_type context in
  let string_t = L.pointer_type i8_t in
  let void_t = L.void_type context in

(* Struct table: struct name -> (field list, LLVM struct type). It is
   filled by [prestmt] and by [stmt] on [SStructDef]. *)
let structMap = StringHash.create 20 in

(* LLVM type of a primitive type. *)
let ltype_of_primitive = function
  | A.IntT -> i32_t
  | A.BoolT -> i1_t
  | A.FloatT -> float_t
  | A.CharT -> i8_t
  | A.VoidT -> void_t
in

(* LLVM type of a source type. Arrays, references and structs are all
   represented as pointers; any other type maps to [void_t].
   Raises [Failure] if a struct name has not been registered in
   [structMap]. *)
let rec ltype_of_typ = function
  | A.DataT t -> ltype_of_primitive t
  | A.StringT -> string_t
  | A.ArrayT (t, _) -> L.pointer_type (ltype_of_typ t)
  | A.RefT (_, t) -> L.pointer_type (ltype_of_typ t)
  | A.StructT name ->
    (match StringHash.find structMap name with
     | (_, struct_ltype) -> L.pointer_type struct_ltype
     | exception Not_found ->
       (* Descriptive message (previously just the bare struct name). *)
       raise (Failure ("codegen undeclared struct " ^ name)))
  | _ -> void_t
in

(* Decide the type in which a binary operation on operands of types
   [lhsType] and [rhsType] is carried out. [lhs] and [rhs] are handed back
   unchanged next to that type; no conversion code is emitted here.
   Raises [Failure] for every type pair not listed. *)
let cast lhs rhs lhsType rhsType =
  let result_type =
    match (lhsType, rhsType) with
    | (DataT IntT, DataT IntT) -> DataT IntT
    | (DataT (IntT | CharT), DataT CharT) | (DataT CharT, DataT IntT) -> DataT CharT
    | (DataT (IntT | FloatT), DataT FloatT) | (DataT FloatT, DataT IntT) -> DataT FloatT
    | (DataT BoolT, DataT BoolT) -> DataT BoolT
    | _ -> raise (Failure "cannot support other types")
  in
  ((lhs, rhs), result_type) in

  (* Declare printf(), which the print built-in function will call. The
     format string (an i8 pointer) is the only fixed parameter; the rest
     are varargs. *)
  let printf_t = L.var_arg_function_type i32_t [| string_t |] in
  let printf_func = L.declare_function "printf" printf_t the_module in
 
   (* Register in [structMap] every struct definition that appears in [s],
      descending into nested blocks only (if/while bodies are not visited),
      so that [ltype_of_typ] can resolve those struct names. Each field is
      stored with its type and an offset of i * 4; that offset is not read
      anywhere in this file. Other statements are ignored. *)
   let rec prestmt s =
     match s with
     | S.SBlock sl -> List.iter prestmt sl
     | S.SStructDef (struct_name, fields) ->
       let struct_fields = List.mapi (fun i (field, t) -> (field, (t, i * 4))) fields in
       let field_ltypes = List.map (fun (_, t) -> ltype_of_typ t) fields in
       let ltype = L.struct_type context (Array.of_list field_ltypes) in
       StringHash.replace structMap struct_name (struct_fields, ltype)
     | _ -> ()
   in

  (* Define each function (arguments and return type) so we can call it.
     Result: function name -> (LLVM function, SAST declaration). *)
  let function_decls =
    let function_decl m fdecl =
      let name = fdecl.S.sfname in
      let formal_types =
        fdecl.S.sformals |> List.map (fun (_, t) -> ltype_of_typ t) |> Array.of_list
      in
      (* NOTE: the body's struct definitions are registered after the formal
         types are lowered but before the return type is. *)
      prestmt (S.SBlock fdecl.S.sbody);
      let ftype = L.function_type (ltype_of_typ fdecl.S.soutputType) formal_types in
      StringMap.add name (L.define_function name ftype the_module, fdecl) m
    in
    List.fold_left function_decl StringMap.empty functions in
  
  (* Fill in the body of the given function *)
  let build_function_body fdecl =
    (* Outermost scope of the function; it receives the formal parameters. *)
    let varMap = StringHash.create 20 in
    let the_function = fst (StringMap.find fdecl.S.sfname function_decls) in
    let builder = L.builder_at_end context (L.entry_block the_function) in

    (* Bind variable [n] of source type [t] to the value [v] in table [m].
       Strings are bound directly to the pointer [v]; every other type gets
       an alloca initialised with [v]. NOTE: the alloca and store are always
       emitted through the entry-block builder created above, not through
       the caller's current builder. *)
    let add_variable m (n, t) v =
      let spill ltyp =
        let slot = L.build_alloca ltyp n builder in
        ignore (L.build_store v slot builder);
        slot
      in
      let binding =
        match t with
        | A.StringT -> v
        | A.StructT s -> spill (L.pointer_type (snd (StringHash.find structMap s)))
        | _ -> spill (ltype_of_typ t)
      in
      StringHash.replace m n (binding, t)
    in

  let () =
    List.iter2 (add_variable varMap) fdecl.S.sformals
      (Array.to_list (L.params the_function))
  in

  (* Resolve [name] starting in [scope] and walking outwards through the
     parent scopes. Raises [Failure] if no scope binds it. *)
  let rec lookup scope name =
    match StringHash.find scope.varMap name with
    | binding -> binding
    | exception Not_found ->
      (match scope.parent with
       | Some outer -> lookup outer name
       | None -> raise (Failure ("undeclared identifier " ^ name)))
  in
    
  (* Initial environment: the parameter scope, with no parent. *)
  let env = {
    return_type = fdecl.S.soutputType;
    scope = { varMap; parent = None };
    in_while = false;
    in_for = false;
  } in

  (* 0-based position of [field] in the field-name list [l].
     Raises [Failure] when [field] is absent. The previous version returned
     [List.length l] there, which became an out-of-range struct GEP index. *)
  let search_index l field =
    let rec go i = function
      | [] -> raise (Failure "codegen unknown struct field")
      | hd :: tl -> if field = hd then i else go (i + 1) tl
    in
    go 0 l in

  (* printf format strings shared by every println in this function. Each
     one ends with the newline that println appends. *)
  let int_format_str = L.build_global_stringptr "%d\n" "imt" builder in
  let float_format_str = L.build_global_stringptr "%0.2f\n" "fmt" builder in
  (* %c prints one character; %e is a floating-point conversion. *)
  let char_format_str = L.build_global_stringptr "%c\n" "cmt" builder in
  let string_format_str = L.build_global_stringptr "%s\n" "sfmt" builder in

    (* Convert [v], whose source type is [from_t], to source type [to_t].
       Only the numeric conversions listed here are emitted. Every other
       pair, including equal types, returns [v] unchanged. Chars are
       treated as unsigned, consistent with [SCharLit] being lowered via
       [Char.code]. *)
    let coerce v from_t to_t builder =
      match (from_t, to_t) with
      | (A.DataT A.IntT, A.DataT A.FloatT) -> L.build_sitofp v float_t "tmp" builder
      | (A.DataT A.FloatT, A.DataT A.IntT) -> L.build_fptosi v i32_t "tmp" builder
      | (A.DataT A.IntT, A.DataT A.CharT) -> L.build_trunc v i8_t "tmp" builder
      | (A.DataT A.CharT, A.DataT A.IntT) -> L.build_zext v i32_t "tmp" builder
      | _ -> v
    in

    (* Non-assignment binary operator on two double operands. *)
    let float_binop op lhs rhs builder =
      match op with
      | A.Add -> L.build_fadd lhs rhs "tmp" builder
      | A.Sub -> L.build_fsub lhs rhs "tmp" builder
      | A.Mult -> L.build_fmul lhs rhs "tmp" builder
      | A.Div -> L.build_fdiv lhs rhs "tmp" builder
      | A.Mod -> L.build_frem lhs rhs "tmp" builder
      | A.Equal -> L.build_fcmp L.Fcmp.Oeq lhs rhs "tmp" builder
      | A.Neq -> L.build_fcmp L.Fcmp.One lhs rhs "tmp" builder
      (* Ordered, like every other comparison here, so NaN < x is false. *)
      | A.Less -> L.build_fcmp L.Fcmp.Olt lhs rhs "tmp" builder
      | A.Leq -> L.build_fcmp L.Fcmp.Ole lhs rhs "tmp" builder
      | A.Greater -> L.build_fcmp L.Fcmp.Ogt lhs rhs "tmp" builder
      | A.Geq -> L.build_fcmp L.Fcmp.Oge lhs rhs "tmp" builder
      | _ -> raise (Failure "Float doesn't support this operator!")
    in

    (* Non-assignment binary operator on two integer-like operands of the
       same LLVM type (i32, i8 or i1). *)
    let int_binop op lhs rhs builder =
      match op with
      | A.Add -> L.build_add lhs rhs "addtmp" builder
      | A.Sub -> L.build_sub lhs rhs "tmp" builder
      | A.Mult -> L.build_mul lhs rhs "tmp" builder
      | A.Div -> L.build_sdiv lhs rhs "tmp" builder
      | A.Mod -> L.build_srem lhs rhs "tmp" builder
      | A.And -> L.build_and lhs rhs "tmp" builder
      | A.Or -> L.build_or lhs rhs "tmp" builder
      | A.Equal -> L.build_icmp L.Icmp.Eq lhs rhs "tmp" builder
      | A.Neq -> L.build_icmp L.Icmp.Ne lhs rhs "tmp" builder
      | A.Less -> L.build_icmp L.Icmp.Slt lhs rhs "tmp" builder
      | A.Leq -> L.build_icmp L.Icmp.Sle lhs rhs "tmp" builder
      | A.Greater -> L.build_icmp L.Icmp.Sgt lhs rhs "tmp" builder
      | A.Geq -> L.build_icmp L.Icmp.Sge lhs rhs "tmp" builder
      | _ -> raise (Failure "int doesn't support this operator!")
    in

    (* Construct code for an expression; return its value *)
    let rec expr builder env = function
      | S.SIntLit (i, _) -> L.const_int i32_t i
      | S.SBoolLit (b, _) -> L.const_int i1_t (if b then 1 else 0)
      | S.SFloatLit (f, _) -> L.const_float float_t f
      | S.SCharLit (c, _) -> L.const_int i8_t (Char.code c)
      | S.SStringLit (s, _) -> L.build_global_stringptr s "tmp" builder
      | S.SCall ("println", [e], _) -> println e env builder
      | S.SCall (fname, args, _) ->
        let (fdef, fdecl) = StringMap.find fname function_decls in
        let actuals = List.rev (List.map (expr builder env) (List.rev args)) in
        let result =
          match fdecl.S.soutputType with
          | A.DataT A.VoidT -> ""
          | _ -> fname ^ "_result"
        in
        L.build_call fdef (Array.of_list actuals) result builder
      | S.SBinop (e1, op, e2, _) -> binop e1 op e2 env builder
      | S.SUnop (op, e, _) -> unop op e env builder
      | S.SId (s, _) ->
        (* add_variable binds strings directly to their pointer; every
           other variable lives in a stack slot that has to be loaded. *)
        (match lookup env.scope s with
         | (str_ptr, A.StringT) -> str_ptr
         | (slot, _) -> L.build_load slot s builder)
      | S.SArrayLit (l, length, tp) -> array_create l length tp env builder
      | S.SArrayAccess (name, index, _) -> array_access name index env builder
      | S.SStructCreate fields -> struct_create fields env builder
      | S.SStructAccess (var, field, _) -> struct_access var field env builder
      | _ -> L.const_int i32_t 0

    (* println(e): one printf call whose format is chosen by e's type. *)
    and println e env builder =
      let printf args = L.build_call printf_func args "printf" builder in
      match S.get_type_from_sexpr e with
      | A.StringT ->
        (* The text is passed as a %s argument, never as the format
           itself, so a '%' inside it is printed literally. *)
        printf [| string_format_str; expr builder env e |]
      | A.DataT A.IntT -> printf [| int_format_str; expr builder env e |]
      | A.DataT A.FloatT -> printf [| float_format_str; expr builder env e |]
      | A.DataT A.CharT ->
        (* Widen to int, as C's default argument promotions do for varargs. *)
        let c = L.build_zext (expr builder env e) i32_t "tmp" builder in
        printf [| char_format_str; c |]
      | _ -> raise (Failure "println does not support types other than int and string yet!")

    (* Binary operation. [cast] rejects unsupported operand types and picks
       the type [d] the operation is carried out in. *)
    and binop e1 op e2 env builder =
      let type1 = S.get_type_from_sexpr e1 in
      let type2 = S.get_type_from_sexpr e2 in
      let (_, _), d = cast e1 e2 type1 type2 in
      match op with
      | A.Assign -> assign e1 e2 type1 type2 d env builder
      | _ ->
        (* Lower left then right, converting each operand to [d] so both
           sides of the LLVM instruction have the same type. *)
        let e1' = coerce (expr builder env e1) type1 d builder in
        let e2' = coerce (expr builder env e2) type2 d builder in
        (match d with
         | A.DataT A.FloatT -> float_binop op e1' e2' builder
         | A.DataT A.IntT | A.DataT A.BoolT | A.DataT A.CharT ->
           int_binop op e1' e2' builder
         | _ -> raise (Failure "no type matched!"))

    (* [lhs = rhs]: store [rhs], converted to [lhs]'s type, into the
       location denoted by [lhs]. Returns the stored value. The location
       expression is evaluated exactly once. *)
    and assign lhs rhs lhs_t rhs_t d env builder =
      let value = coerce (expr builder env rhs) rhs_t lhs_t builder in
      let target =
        match lhs with
        | S.SId (s, _) -> fst (lookup env.scope s)
        | S.SUnop (A.Deref, ptr, _) -> expr builder env ptr
        | S.SStructAccess (var, field, _) -> struct_field_ptr var field env builder
        | _ ->
          raise (Failure (match d with
              | A.DataT A.FloatT -> "Float doesn't support this operator!"
              | _ -> "int doesn't support this operator!"))
      in
      ignore (L.build_store value target builder);
      value

    (* Unary operation. Borrow yields the variable's binding itself (its
       stack slot, or the pointer for strings). *)
    and unop op e env builder =
      match op with
      | A.Neg ->
        let e' = expr builder env e in
        (* Integer negation is not valid on doubles. *)
        (match S.get_type_from_sexpr e with
         | A.DataT A.FloatT -> L.build_fneg e' "tmp" builder
         | _ -> L.build_neg e' "tmp" builder)
      | A.Not -> L.build_not (expr builder env e) "tmp" builder
      | A.Borrow _ ->
        (match e with
         | S.SId (s, _) -> fst (lookup env.scope s)
         | _ -> raise (Failure "no type matched!"))
      | A.Deref -> L.build_load (expr builder env e) "tmp" builder

    (* Heap-allocate [length + 1] elements and store the literal's values at
       indices 1..length; index 0 is left unused (array_access adds 1). *)
    and array_create l length tp env builder =
      match tp with
      | A.ArrayT (tp, _) ->
        let t = ltype_of_typ tp in
        let length_used = const_int i32_t (length + 1) in
        let a = build_array_malloc t length_used "tmp" builder in
        let a = build_pointercast a (pointer_type t) "tmp" builder in
        let llvalues = List.map (expr builder env) l in
        List.iteri
          (fun i llval ->
             let ptr = build_gep a [| const_int i32_t (i + 1) |] "tmp" builder in
             ignore (build_store llval ptr builder))
          llvalues;
        a
      | _ -> raise (Failure "array error")

    (* Load element [idx] of array [name] (stored at index idx + 1). *)
    and array_access name idx env builder =
      let idx = expr builder env idx in
      let idx = build_add idx (const_int i32_t 1) "tmp" builder in
      let arr = expr builder env name in
      let elem_ptr = build_gep arr [| idx |] "tmp" builder in
      build_load elem_ptr "tmp" builder

    (* Heap-allocate a struct whose LLVM type is built from the field
       expressions' types, and store each field in order. *)
    and struct_create fields env builder =
      let arg_list = List.map snd fields in
      let l_type_list = List.map ltype_of_typ (List.map S.get_type_from_sexpr arg_list) in
      let ltype = struct_type context (Array.of_list l_type_list) in
      let v = build_malloc ltype "tmp" builder in
      List.iteri
        (fun i arg ->
           let field_ptr = build_struct_gep v i "tmp" builder in
           ignore (build_store (expr builder env arg) field_ptr builder))
        arg_list;
      v

    (* Address of [field] inside the struct held by variable [var]. *)
    and struct_field_ptr var field env builder =
      let string_var = S.string_of_sexpr var in
      let (struct_slot, t) = lookup env.scope string_var in
      let (struct_l, _) =
        try StringHash.find structMap (A.string_of_typ t)
        with Not_found -> raise (Failure ("codegen undeclared struct " ^ A.string_of_typ t))
      in
      let index = search_index (List.map fst struct_l) field in
      let struct_ptr = L.build_load struct_slot "tmp" builder in
      build_struct_gep struct_ptr index "tmp" builder

    (* Value of [field] inside the struct held by variable [var]. *)
    and struct_access var field env builder =
      L.build_load (struct_field_ptr var field env builder) "tmp" builder
    in
    (* Run [f builder] only when the block [builder] currently inserts into
       has no terminator yet (e.g. no branch or return). *)
    let add_terminal builder f =
      match L.block_terminator (L.insertion_block builder) with
      | None -> ignore (f builder)
      | Some _ -> ()
    in
  
    (* A copy of [env] whose scope is a fresh, empty child of [env.scope]. *)
    let child_env env =
      { env with scope = { parent = Some env.scope; varMap = StringHash.create 20 } }
    in

    (* Build the code for the given statement; return the builder for
       the statement's successor. Blocks, if-branches and loop bodies each
       get a child scope. *)
    let rec stmt builder env = function
      | S.SBlock sl ->
        let block_env = child_env env in
        List.fold_left (fun acc st -> stmt acc block_env st) builder sl
      | S.SExpr (e, _) -> ignore (expr builder env e); builder
      | S.SReturn (e, _) ->
        ignore (match fdecl.S.soutputType with
            | A.DataT A.VoidT -> L.build_ret_void builder
            | _ -> L.build_ret (expr builder env e) builder);
        builder
      | S.SIf (pred_expr, then_stmt, else_stmt) ->
        let bool_val = expr builder env pred_expr in
        let merge_bb = L.append_block context "merge" the_function in
        let then_bb = L.append_block context "then" the_function in
        add_terminal (stmt (L.builder_at_end context then_bb) (child_env env) then_stmt)
          (L.build_br merge_bb);
        let else_bb = L.append_block context "else" the_function in
        add_terminal (stmt (L.builder_at_end context else_bb) (child_env env) else_stmt)
          (L.build_br merge_bb);
        ignore (L.build_cond_br bool_val then_bb else_bb builder);
        L.builder_at_end context merge_bb
      | S.SWhile (pred_expr, body_stmt) ->
        let pred_bb = L.append_block context "while" the_function in
        ignore (L.build_br pred_bb builder);
        let body_bb = L.append_block context "while_body" the_function in
        let body_env = { (child_env env) with in_while = true } in
        add_terminal (stmt (L.builder_at_end context body_bb) body_env body_stmt)
          (L.build_br pred_bb);
        let pred_builder = L.builder_at_end context pred_bb in
        let bool_val = expr pred_builder env pred_expr in
        let merge_bb = L.append_block context "merge" the_function in
        ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
        L.builder_at_end context merge_bb
      | S.SFor (e1, e2, e3, body) ->
        (* for (e1; e2; e3) body  ==>  { e1; while (e2) { body; e3 } } *)
        stmt builder env
          (S.SBlock [ S.SExpr (e1, S.get_type_from_sexpr e1);
                      S.SWhile (e2, S.SBlock [ body; S.SExpr (e3, S.get_type_from_sexpr e3) ]) ])
      | S.SLoop body ->
        let loop_env = { (child_env env) with in_while = true } in
        stmt builder loop_env
          (S.SBlock [ S.SWhile (S.SBoolLit (true, A.DataT A.BoolT), S.SBlock [ body ]) ])
      | S.SStructDef (struct_name, l) ->
        let struct_fields = List.mapi (fun i (field, t) -> (field, (t, i * 4))) l in
        let l_type_list = List.map ltype_of_typ (List.map snd l) in
        let ltype = struct_type context (Array.of_list l_type_list) in
        StringHash.replace structMap struct_name (struct_fields, ltype);
        builder
      | S.SDeclaration (_, (s, t), e) ->
        let e' = expr builder env e in
        let binding =
          match t with
          | A.StringT -> e'
          | _ ->
            (* The alloca goes at the top of the entry block, which
               dominates every use and keeps a declaration inside a loop
               from growing the stack. The initial value is stored at the
               current insertion point, where [e'] was computed. (Going
               through add_variable would append both to the end of the
               entry block, after its terminator once any branch has been
               emitted.) *)
            let entry_builder =
              L.builder_at context (L.instr_begin (L.entry_block the_function)) in
            let slot = L.build_alloca (ltype_of_typ t) s entry_builder in
            ignore (L.build_store e' slot builder);
            slot
        in
        StringHash.replace env.scope.varMap s (binding, t);
        builder
      | _ -> builder

    in

    (* Build the code for each statement in the function *)
    let builder = stmt builder env (S.SBlock fdecl.S.sbody) in

    (* Every basic block needs a terminator. If control can fall off the end
       of the body (including an empty merge block after an if/else whose
       branches both return), emit a default return. Void functions get
       `ret void`; all others return the zero value of their LLVM return
       type (0 for int, as before; 0.0, false or a null pointer otherwise). *)
    let ret_ty = ltype_of_typ fdecl.S.soutputType in
    if L.classify_type ret_ty = L.TypeKind.Void then
      add_terminal builder L.build_ret_void
    else
      add_terminal builder (L.build_ret (L.const_null ret_ty))

  in

  (* Emit every function body, then hand back the finished module. *)
  let () = List.iter build_function_body functions in
  the_module
