(* Code generation: translate takes a semantically checked AST and
produces LLVM IR

LLVM tutorial: Make sure to read the OCaml version of the tutorial

http://llvm.org/docs/tutorial/index.html

Detailed documentation on the OCaml LLVM library:

http://llvm.moe/
http://llvm.moe/ocaml/

*)

module L = Llvm
module A = Ast
module S = Sast
module E = Exceptions
module Semant = Semant 

module StringMap = Map.Make(String)


let translate (globals, functions) =

    (* ---- LLVM context, module and primitive types ---- *)
    let context = L.global_context () in
    let the_module = L.create_module context "SetC" in
    let i32_t = L.i32_type context in
    let i8_t = L.i8_type context in
    let i1_t = L.i1_type context in
    let str_t = L.pointer_type i8_t in
    let float_t = L.double_type context in
    let void_t = L.void_type context in

    (* Merge block of the innermost loop currently being generated, i.e. the
       target of `break`; [None] outside any loop.  SWhile saves and restores
       it around its body so that a `break` which follows a nested loop still
       exits the enclosing loop instead of jumping to the inner loop's merge. *)
    let br_block : L.llbasicblock option ref = ref None in

    (* Mutable state filled in during generation: storage for global and
       local variables, the function currently being generated, and the
       element type of each set-typed local or formal (consulted when a value
       of type Settype(Void) is assigned to such a variable). *)
    let global_vars = ref StringMap.empty in
    let local_vars = ref StringMap.empty in
    let current_f = ref (List.hd functions) in
    let set_lookup = ref StringMap.empty in

    (* One packed named struct per element type, representing a set:
         field 0: pointer to the element array
         field 1: i32 element count (read by the Card operator)
         field 2: i32 that every set constructor in this file sets to 0
       NOTE(review): the "void" entry uses [L.pointer_type void_t], which LLVM
       does not accept as a pointer element type; kept unchanged because
       Settype(Void) values flow through the same code paths - confirm. *)
    let pointer_wrapper =
        List.fold_left
            (fun m (name, elem_t) ->
                let t = L.named_struct_type context name in
                L.struct_set_body t [| L.pointer_type elem_t; i32_t; i32_t |] true;
                StringMap.add name t m)
            StringMap.empty
            [ ("string", str_t); ("int", i32_t); ("float", float_t);
              ("void", void_t); ("bool", i1_t) ]
    in

    (* printf format strings; every call emits a fresh global string. *)
    let int_format_str builder = L.build_global_stringptr "%d\n" "fmt" builder
    and str_format_str builder = L.build_global_stringptr "%s\n" "fmt" builder
    and float_format_str builder = L.build_global_stringptr "%f\n" "fmt" builder in

    (* Declare an external (C / runtime wrapper) variadic function. *)
    let declare name ret_t arg_ts =
        L.declare_function name (L.var_arg_function_type ret_t arg_ts) the_module
    in
    let printf_func = declare "printf" i32_t [| L.pointer_type i8_t |] in
    let strcmp_func = declare "strcmp" i32_t [| str_t; str_t |] in
    let strint_func = declare "strint" str_t [| str_t; L.pointer_type i32_t; i32_t |] in
    let read_func = declare "read_sc" str_t [| i32_t |] in
    let writei_func = declare "writei_sc" i32_t [| i32_t; i32_t |] in
    let writes_func = declare "writes_sc" i32_t [| i32_t; str_t |] in
    let writef_func = declare "writef_sc" i32_t [| i32_t; float_t |] in
    let close_func = declare "close_sc" i32_t [| i32_t |] in
    let open_func = declare "open_sc" i32_t [| str_t; str_t |] in
    let str_to_int_func = declare "str_to_int" i32_t [| str_t |] in
    let split_func = declare "split" (L.pointer_type str_t) [| str_t; str_t |] in
    let split_len_func = declare "split_len" i32_t [| str_t; str_t |] in

    (* Key into pointer_wrapper for a type; a set maps to its element type. *)
    let rec string_of_typ = function
        | A.Datatype A.Int -> "int"
        | A.Datatype A.String -> "string"
        | A.Datatype A.Void -> "void"
        | A.Datatype A.Bool -> "bool"
        | A.Datatype A.Float -> "float"
        | A.Settype t -> string_of_typ (A.Datatype t)
    in

    (* The set struct type associated with [typ]'s element type. *)
    let lookup_struct typ = StringMap.find (string_of_typ typ) pointer_wrapper in

    (* LLVM type of a SetC type; sets are pointers to their wrapper struct. *)
    let ltype_of_typ = function
        | A.Datatype A.Int -> i32_t
        | A.Datatype A.String -> str_t
        | A.Datatype A.Void -> void_t
        | A.Datatype A.Bool -> i1_t
        | A.Datatype A.Float -> float_t
        | A.Settype t -> L.pointer_type (lookup_struct (A.Datatype t))
    in

    (* Every function of the program (including SetC-level helpers such as
       appendi / seti looked up below), mapped to its LLVM definition and its
       SAST declaration. *)
    let function_decls =
        let function_decl m fdecl =
            let name = fdecl.S.sfname in
            let formal_types =
                Array.of_list (List.map (fun (_, t) -> ltype_of_typ t) fdecl.S.sformals)
            in
            let ftype = L.function_type (ltype_of_typ fdecl.S.styp) formal_types in
            StringMap.add name (L.define_function name ftype the_module, fdecl) m
        in
        List.fold_left function_decl StringMap.empty functions
    in

    (* Call function [name] from function_decls.  A void result must stay
       unnamed, so only non-void calls get a "<name>_result" label. *)
    let call_user name args builder =
        let (fdef, fdecl) = StringMap.find name function_decls in
        let result = (match fdecl.S.styp with
            | A.Datatype A.Void -> ""
            | _ -> name ^ "_result") in
        L.build_call fdef args result builder
    in

    (* Return [builder] unless its block already ends in a terminator (after
       `return` or `break`).  In that case return a builder on a fresh,
       unreachable block, so later statements are not emitted after the
       terminator (which is invalid IR). *)
    let builder_past_terminator the_function builder =
        match L.block_terminator (L.insertion_block builder) with
        | Some _ -> L.builder_at_end context (L.append_block context "dead" the_function)
        | None -> builder
    in

    (* Load field [idx] of the set struct pointed to by [struct_ptr]. *)
    let load_field struct_ptr idx name builder =
        L.build_load (L.build_struct_gep struct_ptr idx name builder) (name ^ "_val") builder
    in

    (* Heap-allocate [size] element slots for element type [it] and return the
       array typed as a pointer to [it].
       NOTE(review): this mallocs [size] *pointers* to [it] and casts the
       result, so i32 / i1 arrays get more room than needed.  Kept as-is in
       case the SetC library code relies on the slack - confirm before
       shrinking it to [ltype_of_typ (A.Datatype it)]. *)
    let build_elem_array it size name builder =
        let elem_ptr_t = L.pointer_type (ltype_of_typ (A.Datatype it)) in
        let arr = L.build_array_malloc elem_ptr_t size name builder in
        L.build_pointercast arr elem_ptr_t (name ^ "_cast") builder
    in

    (* Heap-allocate the set struct for [typ] and initialise all three fields:
       the element array [arr], the element count [len], and field 2 (= 0). *)
    let build_set_struct typ arr len builder =
        let struct_ptr = L.build_malloc (lookup_struct typ) "set_struct" builder in
        let store_field idx v =
            ignore (L.build_store v (L.build_struct_gep struct_ptr idx "set_field" builder) builder)
        in
        store_field 0 arr;
        store_field 1 len;
        store_field 2 (L.const_int i32_t 0);
        struct_ptr
    in

    (* Run [f builder] unless the current block is already terminated. *)
    let rec add_terminal builder f =
        match L.block_terminator (L.insertion_block builder) with
        | Some _ -> ()
        | None -> ignore (f builder)

    (* Emit code for an expression at [builder]'s position; returns its value.
       Set elements are stored 1-based: slot 0 of every element array is left
       unused, so SetC index i lives at slot i + 1 and the last element of a
       set of length n lives at slot n. *)
    and expr builder = function
        | S.SIntLit (i, _) -> L.const_int i32_t i
        | S.SStrLit (s, _) -> L.build_global_stringptr s "string" builder
        | S.SBoolLit (b, _) -> L.const_int i1_t (if b then 1 else 0)
        | S.SFloatLit (f, _) -> L.const_float float_t f
        | S.SNoexpr -> L.const_int i32_t 0
        | S.SId (s, _) -> L.build_load (lookup s) s builder

        | S.SSet (el, t) ->
            let it = (match t with A.Settype it -> it | _ -> raise E.Invalid) in
            let count = List.length el in
            (* count + 1 slots because slot 0 is unused. *)
            let arr = build_elem_array it (L.const_int i32_t (count + 1)) "set2" builder in
            let values = List.map (expr builder) el in
            List.iteri
                (fun i v ->
                    let slot = L.build_gep arr [| L.const_int i32_t (i + 1) |] "set4" builder in
                    ignore (L.build_store v slot builder))
                values;
            build_set_struct t arr (L.const_int i32_t count) builder

        | S.SSetAccess (s, e, t) ->
            let idx = L.build_add (expr builder e) (L.const_int i32_t 1) "access1" builder in
            let arr = load_field (expr builder (S.SId (s, t))) 0 "access2" builder in
            L.build_load (L.build_gep arr [| idx |] "access3" builder) "access4" builder

        | S.SBinop (e1, op, e2, _) ->
            let e1' = expr builder e1 in
            let e2' = expr builder e2 in
            let build =
                (match Semant.sexpr_to_type e1 with
                 | A.Datatype A.Int | A.Datatype A.Bool ->
                     (match op with
                      | A.Add -> L.build_add
                      | A.Sub -> L.build_sub
                      | A.Mult -> L.build_mul
                      | A.Div -> L.build_sdiv
                      | A.And -> L.build_and
                      | A.Or -> L.build_or
                      | A.Equal -> L.build_icmp L.Icmp.Eq
                      | A.Neq -> L.build_icmp L.Icmp.Ne
                      | A.Less -> L.build_icmp L.Icmp.Slt
                      | A.Leq -> L.build_icmp L.Icmp.Sle
                      | A.Greater -> L.build_icmp L.Icmp.Sgt
                      | A.Geq -> L.build_icmp L.Icmp.Sge
                      | A.Mod -> L.build_srem
                      | _ -> raise E.InvalidBinaryOperation)
                 | A.Datatype A.Float ->
                     (* Ordered predicates (false if either side is NaN),
                        except `!=`, which is unordered so NaN != x holds -
                        the IEEE 754 / C semantics. *)
                     (match op with
                      | A.Add -> L.build_fadd
                      | A.Sub -> L.build_fsub
                      | A.Mult -> L.build_fmul
                      | A.Div -> L.build_fdiv
                      | A.Equal -> L.build_fcmp L.Fcmp.Oeq
                      | A.Neq -> L.build_fcmp L.Fcmp.Une
                      | A.Less -> L.build_fcmp L.Fcmp.Olt
                      | A.Leq -> L.build_fcmp L.Fcmp.Ole
                      | A.Greater -> L.build_fcmp L.Fcmp.Ogt
                      | A.Geq -> L.build_fcmp L.Fcmp.Oge
                      | A.Mod -> L.build_frem
                      | _ -> raise E.InvalidBinaryOperation)
                 | _ -> raise E.InvalidBinaryOperation)
            in
            build e1' e2' "tmp" builder

        | S.SUnop (op, e, _) ->
            (* Evaluate the operand exactly once, whatever the operator. *)
            let e' = expr builder e in
            (match op with
             | A.Neg ->
                 (* Integer `sub` is invalid on doubles; floats need fneg. *)
                 (match Semant.sexpr_to_type e with
                  | A.Datatype A.Float -> L.build_fneg e' "tmp" builder
                  | _ -> L.build_neg e' "tmp" builder)
             | A.Not -> L.build_not e' "tmp" builder
             | A.Card ->
                 L.build_load (L.build_struct_gep e' 1 "struct1" builder) "idl" builder)

        | S.SCall ("str_to_int", [e], _) ->
            L.build_call str_to_int_func [| expr builder e |] "str_to_int" builder

        | S.SCall ("split", [e1; e2], _) ->
            (* Evaluate each argument once (right to left, like the other calls
               here) and share the values between both runtime helpers. *)
            let delim = expr builder e2 in
            let str = expr builder e1 in
            let len = L.build_call split_len_func [| str; delim |] "split_len" builder in
            let arr = L.build_call split_func [| str; delim |] "split" builder in
            build_set_struct (A.Datatype A.String) arr len builder

        | S.SCall ("open", [e1; e2], _) ->
            L.build_call open_func [| expr builder e1; expr builder e2 |] "open" builder

        | S.SCall ("read", [e1], _) ->
            L.build_call read_func [| expr builder e1 |] "read" builder

        | S.SCall ("write", [e1; e2], _) ->
            let write_func = (match Semant.sexpr_to_type e2 with
                | A.Datatype A.String -> writes_func
                | A.Datatype A.Int -> writei_func
                | A.Datatype A.Float -> writef_func
                | _ -> raise E.Invalid) in
            L.build_call write_func [| expr builder e1; expr builder e2 |] "write" builder

        | S.SCall ("close", [e], _) ->
            L.build_call close_func [| expr builder e |] "close" builder

        | S.SCall ("print", [e], _) ->
            L.build_call printf_func [| int_format_str builder; (expr builder e) |]
                "print" builder
        | S.SCall ("prints", [e], _) ->
            L.build_call printf_func [| str_format_str builder; (expr builder e) |]
                "prints" builder
        | S.SCall ("printf", [e], _) ->
            L.build_call printf_func [| float_format_str builder; (expr builder e) |]
                "printf" builder
        | S.SCall ("strcmp", [e1; e2], _) ->
            L.build_call strcmp_func [| expr builder e1; expr builder e2 |] "strcmp" builder

        | S.SCall ("append", [e1; e2], t) ->
            let struct_ptr1 = expr builder e1 in
            let len1 = load_field struct_ptr1 1 "append1" builder in
            let struct_ptr2 = expr builder e2 in
            let len2 = load_field struct_ptr2 1 "append2" builder in
            let it = (match t with A.Settype it -> it | _ -> raise E.Invalid) in
            let total = L.build_add len1 len2 "append4" builder in
            (* One extra slot because slot 0 is unused. *)
            let size = L.build_add total (L.const_int i32_t 1) "append5" builder in
            let arr = build_elem_array it size "append6" builder in
            let struct_ptr = build_set_struct t arr total builder in
            (* The SetC-level helper presumably copies both inputs into the
               new struct - its body is not part of this file. *)
            let helper = (match it with
                | A.Int -> "appendi"
                | A.Float -> "appendf"
                | A.String -> "appends"
                | A.Bool -> "appendb"
                | _ -> raise E.Invalid) in
            call_user helper [| struct_ptr1; struct_ptr2; struct_ptr; len1; len2 |] builder

        | S.SCall ("set", act, t) ->
            let helper = (match t with
                | A.Settype A.Int -> "seti"
                | A.Settype A.Float -> "setf"
                | A.Settype A.String -> "sets"
                | A.Settype A.Bool -> "setb"
                | _ -> raise E.Invalid) in
            let actuals = List.rev (List.map (expr builder) (List.rev act)) in
            call_user helper (Array.of_list actuals) builder

        | S.SCall ("str", [e], _) ->
            let struct_ptr = expr builder e in
            let int_ptr = load_field struct_ptr 0 "append1" builder in
            let size = load_field struct_ptr 1 "append1" builder in
            (* NOTE(review): reserves [size] * sizeof(i8* ) bytes for the output
               of the C helper `strint`; whether that suffices depends on its
               implementation - confirm. *)
            let buf = L.build_array_malloc str_t size "set2" builder in
            let buf = L.build_pointercast buf str_t "set3" builder in
            L.build_call strint_func [| buf; int_ptr; size |] "strint" builder

        | S.SCall ("pop", [e], _) ->
            let struct_ptr = expr builder e in
            let arr_ptr = load_field struct_ptr 0 "pop1" builder in
            let len_ptr = L.build_struct_gep struct_ptr 1 "pop2" builder in
            let size = L.build_load len_ptr "size" builder in
            ignore (L.build_store (L.build_sub size (L.const_int i32_t 1) "sub" builder)
                        len_ptr builder);
            (* Storage is 1-based, so the last element sits at slot [size].
               NOTE(review): popping an empty set is not checked here. *)
            L.build_load (L.build_gep arr_ptr [| size |] "pop3" builder) "pop4" builder

        | S.SCall (f, act, _) ->
            (* Arguments are evaluated right to left. *)
            let actuals = List.rev (List.map (expr builder) (List.rev act)) in
            call_user f (Array.of_list actuals) builder

    (* Emit code for a statement; returns a builder positioned where the next
       statement should go. *)
    and stmt builder =
        let (the_function, _) = StringMap.find !current_f.S.sfname function_decls in
        function
        | S.SBlock sl ->
            List.fold_left
                (fun b s -> stmt (builder_past_terminator the_function b) s)
                builder sl

        | S.SExpr (e, _) -> ignore (expr builder e); builder

        | S.SAssign (s, e, _) ->
            (match Semant.sexpr_to_type e with
             | A.Settype A.Void ->
                 (* A Settype(Void) value (presumably the empty set literal):
                    build an empty set of the variable's element type. *)
                 let elem = StringMap.find s !set_lookup in
                 if elem <> A.Void then begin
                     let null_arr =
                         L.const_pointer_null (L.pointer_type (ltype_of_typ (A.Datatype elem)))
                     in
                     let struct_ptr =
                         build_set_struct (A.Datatype elem) null_arr (L.const_int i32_t 0) builder
                     in
                     ignore (L.build_store struct_ptr (lookup s) builder)
                 end;
                 builder
             | _ ->
                 let e' = expr builder e in
                 ignore (L.build_store e' (lookup s) builder);
                 builder)

        | S.SSetElementAssign (s, e1, e2, t) ->
            let value = expr builder e2 in
            let idx = L.build_add (expr builder e1) (L.const_int i32_t 1) "setassign1" builder in
            let arr = load_field (expr builder (S.SId (s, t))) 0 "setassign2" builder in
            ignore (L.build_store value (L.build_gep arr [| idx |] "setassign3" builder) builder);
            builder

        | S.SReturn (e, _) ->
            (match !current_f.S.styp with
             | A.Datatype A.Void -> ignore (L.build_ret_void builder)
             | _ -> ignore (L.build_ret (expr builder e) builder));
            builder

        | S.SIf (predicate, then_stmt, else_stmt) ->
            let bool_val = expr builder predicate in
            let merge_bb = L.append_block context "merge" the_function in
            let then_bb = L.append_block context "then" the_function in
            add_terminal (stmt (L.builder_at_end context then_bb) then_stmt)
                (L.build_br merge_bb);
            let else_bb = L.append_block context "else" the_function in
            add_terminal (stmt (L.builder_at_end context else_bb) else_stmt)
                (L.build_br merge_bb);
            ignore (L.build_cond_br bool_val then_bb else_bb builder);
            L.builder_at_end context merge_bb

        | S.SWhile (predicate, body) ->
            let pred_bb = L.append_block context "while" the_function in
            let body_bb = L.append_block context "while_body" the_function in
            let pred_builder = L.builder_at_end context pred_bb in
            let bool_val = expr pred_builder predicate in
            let merge_bb = L.append_block context "merge" the_function in
            ignore (L.build_br pred_bb builder);
            (* `break` inside this body exits to merge_bb; restore the outer
               target afterwards so nested loops behave correctly. *)
            let enclosing = !br_block in
            br_block := Some merge_bb;
            add_terminal (stmt (L.builder_at_end context body_bb) body)
                (L.build_br pred_bb);
            br_block := enclosing;
            ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
            L.builder_at_end context merge_bb

        | S.SIter (constraints, body) ->
            (* for (assign; predicate; increment) body  ==>
               { assign; while (predicate) { body; increment } } *)
            let (assign, predicate, increment) = constraints in
            stmt builder
                (S.SBlock [ assign; S.SWhile (predicate, S.SBlock [ body; increment ]) ])

        | S.SBreak ->
            (match !br_block with
             | Some target -> ignore (L.build_br target builder)
             | None -> raise E.Invalid);
            builder

    (* Storage (alloca or global) of a variable; locals shadow globals. *)
    and lookup n =
        try StringMap.find n !local_vars
        with Not_found -> StringMap.find n !global_vars
    in

    (* Define each global.  Initialisers are generated with a builder at the
       end of main's entry block.  NOTE(review): L.define_global needs a
       constant initialiser, so only constant-producing expressions work. *)
    global_vars :=
        (let (main_f, _) = StringMap.find "main" function_decls in
         let builder = L.builder_at_end context (L.entry_block main_f) in
         List.fold_left
             (fun m (n, e, _) ->
                 StringMap.add n (L.define_global n (expr builder e) the_module) m)
             StringMap.empty globals);

    (* Generate one function: spill formals into stack slots, allocate locals,
       emit the body, then add a default return if control can reach the end. *)
    let build_function_body fdecl =
        let (the_function, _) = StringMap.find fdecl.S.sfname function_decls in
        let builder = L.builder_at_end context (L.entry_block the_function) in
        current_f := fdecl;

        (* Record the element type of set-typed variables (formals included)
           for the Settype(Void) case of SAssign. *)
        let remember_set n = function
            | A.Settype it -> set_lookup := StringMap.add n it !set_lookup
            | _ -> ()
        in
        let add_formal m (n, t) p =
            L.set_value_name n p;
            remember_set n t;
            let slot = L.build_alloca (ltype_of_typ t) n builder in
            ignore (L.build_store p slot builder);
            StringMap.add n slot m
        in
        let add_local m (n, t) =
            remember_set n t;
            StringMap.add n (L.build_alloca (ltype_of_typ t) n builder) m
        in
        let formals =
            List.fold_left2 add_formal StringMap.empty
                fdecl.S.sformals (Array.to_list (L.params the_function))
        in
        local_vars := List.fold_left add_local formals fdecl.S.slocals;

        let builder = stmt builder (S.SBlock fdecl.S.sbody) in

        (* Falling off the end returns the zero value of the return type:
           const_null gives 0 / 0.0 / null, whereas const_int is only valid
           for integer types. *)
        add_terminal builder (match fdecl.S.styp with
            | A.Datatype A.Void -> L.build_ret_void
            | t -> L.build_ret (L.const_null (ltype_of_typ t)))
    in
    List.iter build_function_body functions;
    the_module
  
