(*
  Allie Costa
  Laura Smerling
  Jacob Penn
  Laura Matos
*)
module L = Llvm
module A = Ast
open Sast 

module StringMap = Map.Make(String)

(* Translate a semantically-checked BURGer program into an LLVM module.

   [functions] and [statements] are both lists of program items
   ([SStmt _] / [SFunction _]).
   - Top-level variable declarations become zero-initialised LLVM globals.
   - The top-level statements in [statements] become the body of a
     synthesized [main : unit -> int].
   - Every item matching [SFunction _] is emitted as an LLVM function.
   Returns the populated module. *)
let translate (functions, statements) =
  let theprogram = functions @ statements in
  let context = L.global_context () in
  let the_module = L.create_module context "MicroC"
  and i8_t    = L.i8_type context
  and str_t   = L.pointer_type (L.i8_type context)
  and i1_t    = L.i1_type context
  and i32_t   = L.i32_type context
  and float_t = L.double_type context
  and void_t  = L.void_type context in

  (* LLVM representation of each BURGer type *)
  let ltype_of_typ = function
      A.Int    -> i32_t
    | A.String -> str_t
    | A.Bool   -> i1_t
    | A.Float  -> float_t
    | A.Void   -> void_t
  in

  (* Zero value of each BURGer type. Used for global initialisers,
     declarations without an initialiser, and the implicit return of a
     function whose body falls off the end. [L.const_int] is only valid on
     integer types, so floats and strings need their own constants. *)
  let zero_of_typ = function
      A.Int | A.Bool as t -> L.const_int (ltype_of_typ t) 0
    | A.Float  -> L.const_float float_t 0.0
    | A.String -> L.const_null str_t
    | A.Void   -> failwith "internal error: void has no zero value"
  in

  (* Keep only the statement items of an item list (order preserved) *)
  let stmts_of_items items =
    List.fold_right
      (fun item acc -> match item with SStmt s -> s :: acc | _ -> acc)
      items []
  in

  (* All statement items of the whole program *)
  let stmt_list = stmts_of_items theprogram in

  (* Top-level variable declarations: these become LLVM globals *)
  let globals =
    List.fold_right
      (fun s acc -> match s with SVarDec (decl, _) -> decl :: acc | _ -> acc)
      stmt_list []
  in

  (* User-defined functions plus the hidden main. The top-level statements
     form main's body. Their SVarDecs are kept so that a declaration's
     initialiser is stored into the corresponding global at that point in
     program order. They are not given stack slots in main (see
     [local_vars] below). *)
  let functions =
    let main_decl = {
      styp = A.Int;
      sfname = "main";
      sformals = [];
      sbody = stmts_of_items statements;
    } in
    let user_functions =
      List.fold_right
        (fun item acc -> match item with SFunction f -> f :: acc | _ -> acc)
        theprogram []
    in
    main_decl :: user_functions
  in

  (* Global variables, keyed by name, each initialised to its type's zero *)
  let global_vars =
    let add_global map (t, n) =
      StringMap.add n (L.define_global n (zero_of_typ t) the_module) map
    in
    List.fold_left add_global StringMap.empty globals
  in

  (* External C functions used by the built-in print calls *)
  let printf_t = L.var_arg_function_type i32_t [| L.pointer_type i8_t |] in
  let printf_func = L.declare_function "printf" printf_t the_module in

  let printbig_t = L.var_arg_function_type i32_t [| i32_t |] in
  let printbig_func = L.declare_function "printbig" printbig_t the_module in

  (* Define each function (arguments and return type) so we can call it *)
  let function_decls =
    let function_decl map func_dec =
      let name = func_dec.sfname
      and formal_types =
        Array.of_list (List.map (fun (t, _) -> ltype_of_typ t) func_dec.sformals)
      in
      let ftype = L.function_type (ltype_of_typ func_dec.styp) formal_types in
      StringMap.add name (L.define_function name ftype the_module, func_dec) map
    in
    List.fold_left function_decl StringMap.empty functions
  in

  (* Fill in the body of the given function *)
  let build_function_body func_dec =
    let (the_function, _) = StringMap.find func_dec.sfname function_decls in
    let builder = L.builder_at_end context (L.entry_block the_function) in

    let int_format_str = L.build_global_stringptr "%d\n" "fmt" builder
    and float_format_str = L.build_global_stringptr "%g\n" "fmt" builder in

    (* Stack slots for formals (initialised from the incoming parameters)
       and for the top-level declarations of the function body. *)
    let local_vars =
      let add_formal var_map (formal_type, formal_name) param =
        L.set_value_name formal_name param;
        let local = L.build_alloca (ltype_of_typ formal_type) formal_name builder in
        ignore (L.build_store param local builder);
        StringMap.add formal_name local var_map
      in
      let add_local map (local_type, local_name) =
        let local_var = L.build_alloca (ltype_of_typ local_type) local_name builder in
        StringMap.add local_name local_var map
      in
      let formals =
        List.fold_left2 add_formal StringMap.empty func_dec.sformals
          (Array.to_list (L.params the_function))
      in
      (* main's declarations are globals, so they get no stack slots *)
      let function_locals =
        if func_dec.sfname = "main" then []
        else
          List.fold_left
            (fun acc s -> match s with SVarDec (decl, _) -> decl :: acc | _ -> acc)
            [] func_dec.sbody
      in
      List.fold_left add_local formals function_locals
    in

    (* Locals shadow globals *)
    let lookup n =
      try StringMap.find n local_vars
      with Not_found -> StringMap.find n global_vars
    in

    (* Generate code for an expression, returning its value *)
    let rec expr builder ((_, e) : sexpr) = match e with
        SLiteral i  -> L.const_int i32_t i
      | SBoolLit b  -> L.const_int i1_t (if b then 1 else 0)
      | SFliteral l -> L.const_float_of_string float_t l
      | SNoexpr     -> L.const_int i32_t 0
      | SId s       -> L.build_load (lookup s) s builder
      | SBinop ((A.Float, _) as e1, op, e2) ->
          (* Floating-point operands need the FP instruction family;
             integer add/sdiv/icmp are invalid on doubles. *)
          let e1' = expr builder e1
          and e2' = expr builder e2 in
          (match op with
             A.Add     -> L.build_fadd
           | A.Sub     -> L.build_fsub
           | A.Mult    -> L.build_fmul
           | A.Div     -> L.build_fdiv
           | A.Mod     -> L.build_frem
           | A.Equal   -> L.build_fcmp L.Fcmp.Oeq
           | A.Neq     -> L.build_fcmp L.Fcmp.One
           | A.Less    -> L.build_fcmp L.Fcmp.Olt
           | A.Leq     -> L.build_fcmp L.Fcmp.Ole
           | A.Greater -> L.build_fcmp L.Fcmp.Ogt
           | A.Geq     -> L.build_fcmp L.Fcmp.Oge
           | A.And | A.Or ->
               failwith "internal error: semant should have rejected and/or on float"
          ) e1' e2' "tmp" builder
      | SBinop (e1, op, e2) ->
          (* Integer and boolean operands *)
          let e1' = expr builder e1
          and e2' = expr builder e2 in
          (match op with
             A.Add     -> L.build_add
           | A.Sub     -> L.build_sub
           | A.Mult    -> L.build_mul
           | A.Div     -> L.build_sdiv
           | A.Mod     -> L.build_srem
           | A.And     -> L.build_and
           | A.Or      -> L.build_or
           | A.Equal   -> L.build_icmp L.Icmp.Eq
           | A.Neq     -> L.build_icmp L.Icmp.Ne
           | A.Less    -> L.build_icmp L.Icmp.Slt
           | A.Leq     -> L.build_icmp L.Icmp.Sle
           | A.Greater -> L.build_icmp L.Icmp.Sgt
           | A.Geq     -> L.build_icmp L.Icmp.Sge
          ) e1' e2' "tmp" builder
      | SUnop (op, ((t, _) as e)) ->
          let e' = expr builder e in
          (match op, t with
             A.Neg, A.Float -> L.build_fneg
           | A.Neg, _       -> L.build_neg
           | A.Not, _       -> L.build_not
          ) e' "tmp" builder
      | SAssign (s, e) ->
          let e' = expr builder e in
          ignore (L.build_store e' (lookup s) builder);
          e'
      | SFunCall ("print", [e]) | SFunCall ("printb", [e]) ->
          L.build_call printf_func [| int_format_str ; (expr builder e) |]
            "printf" builder
      | SFunCall ("printbig", [e]) ->
          L.build_call printbig_func [| (expr builder e) |] "printbig" builder
      | SFunCall ("printf", [e]) ->
          L.build_call printf_func [| float_format_str ; (expr builder e) |]
            "printf" builder
      | SFunCall (f, args) ->
          let (fdef, func_dec) = StringMap.find f function_decls in
          (* arguments are evaluated right-to-left *)
          let llargs = List.rev (List.map (expr builder) (List.rev args)) in
          (* a call to a void function must not be given a result name *)
          let result = (match func_dec.styp with
                          A.Void -> ""
                        | _ -> f ^ "_result") in
          L.build_call fdef (Array.of_list llargs) result builder
    in

    (* Invoke "f builder" if the current block doesn't already
       have a terminal (e.g., a branch). *)
    let add_terminal builder f =
      match L.block_terminator (L.insertion_block builder) with
        Some _ -> ()
      | None -> ignore (f builder)
    in

    (* Generate code for a statement; returns the builder positioned where
       the next statement should be emitted. *)
    let rec stmt builder = function
        SSeq sl -> List.fold_left stmt builder sl
      | SExpr e -> ignore (expr builder e); builder
      | SVarDec ((typ, name), (_, SNoexpr)) ->
          (* No initialiser: store a zero of the declared type (a bare
             SNoexpr would yield an i32, wrong for float/bool/string). *)
          ignore (L.build_store (zero_of_typ typ) (lookup name) builder);
          builder
      | SVarDec ((typ, name), e) ->
          ignore (expr builder (typ, SAssign (name, e)));
          builder
      | SReturn e ->
          ignore (match func_dec.styp with
                    (* Special "return nothing" instr *)
                    A.Void -> L.build_ret_void builder
                    (* Build return statement *)
                  | _ -> L.build_ret (expr builder e) builder);
          builder
      | SIf (predicate, then_stmt, else_stmt) ->
          let bool_val = expr builder predicate in
          let merge_bb = L.append_block context "merge" the_function in
          let build_br_merge = L.build_br merge_bb in (* partial function *)

          let then_bb = L.append_block context "then" the_function in
          add_terminal (stmt (L.builder_at_end context then_bb) then_stmt)
            build_br_merge;

          let else_bb = L.append_block context "else" the_function in
          add_terminal (stmt (L.builder_at_end context else_bb) else_stmt)
            build_br_merge;

          ignore (L.build_cond_br bool_val then_bb else_bb builder);
          L.builder_at_end context merge_bb
      | SWhile (predicate, body) ->
          let pred_bb = L.append_block context "while" the_function in
          ignore (L.build_br pred_bb builder);

          let body_bb = L.append_block context "while_body" the_function in
          add_terminal (stmt (L.builder_at_end context body_bb) body)
            (L.build_br pred_bb);

          let pred_builder = L.builder_at_end context pred_bb in
          let bool_val = expr pred_builder predicate in

          let merge_bb = L.append_block context "merge" the_function in
          ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
          L.builder_at_end context merge_bb
    in

    (* Build the code for each statement in the function *)
    let builder = stmt builder (SSeq func_dec.sbody) in

    (* Add a return if the last block falls off the end: `ret void` for void
       functions, otherwise the zero value of the return type (so the hidden
       main exits with status 0). *)
    add_terminal builder (match func_dec.styp with
        A.Void -> L.build_ret_void
      | t -> L.build_ret (zero_of_typ t))
  in

  List.iter build_function_body functions;
  the_module

