(* Code generation: translate takes a semantically checked AST and
produces LLVM IR

LLVM tutorial: Make sure to read the OCaml version of the tutorial

http://llvm.org/docs/tutorial/index.html

Detailed documentation on the OCaml LLVM library:

http://llvm.moe/
http://llvm.moe/ocaml/

author: William Hom

*)

module L = Llvm
module A = Ast
module S = Sast

open Llvm.MemoryBuffer
open Llvm_bitreader
open Llvm_linker

module StringMap = Map.Make(String)

(* [translate blargh] receives a semantically checked program from the
 * evaluator and mechanically emits LLVM IR for it.
 *
 * [blargh] is a Sast record; its [global_vars] are (type, name) pairs and its
 * [func_decls] are function declarations with [cname], [cdtype], [cparams],
 * [clocal_vars] and [cbody].  Returns a fresh module named "Macaw".
 *
 * Side effect: reads "_includes/matrix.bc" (a path relative to the current
 * working directory) and links it into the module; the LLVM bindings raise
 * if that file cannot be read, parsed or linked.
 *)
let translate (blargh) =
  let context = L.global_context () in
  let the_module = L.create_module context "Macaw"
  and i32_t  = L.i32_type    context
  and d64_t  = L.double_type context
  and i8_t   = L.i8_type     context
  and void_t = L.void_type   context in

  (* Matrix runtime representation: { double**; i32; i32 }.
     NOTE(review): the two i32 fields are interpreted by _includes/matrix.bc;
     presumably row and column counts -- confirm against that library. *)
  let matrix_t = L.struct_type context
    [| L.pointer_type (L.pointer_type d64_t); i32_t; i32_t |] in

  (* Map a source type to its LLVM type.  Every Macaw number is a double,
     strings are i8*, and matrices are handled through a matrix_t pointer. *)
  let ltype_of_typ = function
      A.NumberTyp -> d64_t
    | A.StringTyp -> L.pointer_type i8_t
    | A.VoidTyp -> void_t
    | A.MatrixTyp -> L.pointer_type matrix_t in

  (* Zero value of a source type: 0.0 for numbers, a null pointer for the
     pointer-represented types.  Used to initialise globals and as the
     fallback return value of functions that fall off their end (a float
     constant is not a valid value of a pointer type). *)
  let zero_of_typ t = match t with
      A.NumberTyp -> L.const_float d64_t 0.0
    | A.StringTyp | A.MatrixTyp -> L.const_pointer_null (ltype_of_typ t)
    (* NOTE(review): void has no zero value in LLVM; presumably the checker
       never lets a void global through -- confirm. *)
    | A.VoidTyp -> L.const_null (ltype_of_typ t) in

  (* Link the precompiled matrix library into this module.  It presumably
     supplies the definitions behind the matrix helpers declared below. *)
  let matrix_lib_path = "_includes/matrix.bc" in
  let matrix_lib_buf = Llvm.MemoryBuffer.of_file matrix_lib_path in
  let matrix_lib = Llvm_bitreader.parse_bitcode context matrix_lib_buf in
  ignore (Llvm_linker.link_modules' the_module matrix_lib);

  let globals = blargh.S.global_vars and functions = blargh.S.func_decls in

  (* Declare each global variable, zero-initialised; map name -> global. *)
  let global_vars =
    let global_var m (t, n) =
      StringMap.add n (L.define_global n (zero_of_typ t) the_module) m in
    List.fold_left global_var StringMap.empty globals in

  (* i32 printf(i8* fmt, ...): one declaration serves every print builtin. *)
  let printf_t = L.var_arg_function_type i32_t [| L.pointer_type i8_t |] in
  let printf_func = L.declare_function "printf" printf_t the_module in

  (* C strcat(dst, src).
     NOTE(review): strcat appends into dst's buffer in place; when dst is a
     string literal emitted with build_global_stringptr this writes past a
     constant global.  A heap-allocating concatenation is likely needed. *)
  let strcat_t = L.function_type (L.pointer_type i8_t)
    [| L.pointer_type i8_t; L.pointer_type i8_t |] in
  let strcat_func = L.declare_function "strcat" strcat_t the_module in

  (* Build an m x n matrix of zeroes; dimensions are passed as doubles. *)
  let zero_init_f = L.function_type (L.pointer_type matrix_t)
    [| d64_t; d64_t |] in
  let zero_init_func =
    L.declare_function "zero_matrix_init" zero_init_f the_module in

  (* Read one element; takes the address of a matrix variable. *)
  let matrix_access_f = L.function_type d64_t
    [| L.pointer_type (L.pointer_type matrix_t); d64_t; d64_t |] in
  let matrix_access_func =
    L.declare_function "access_element" matrix_access_f the_module in

  (* Overwrite one element; takes the address of a matrix variable. *)
  let matrix_replace_f = L.function_type d64_t
    [| L.pointer_type (L.pointer_type matrix_t); d64_t; d64_t; d64_t |] in
  let matrix_replace_func =
    L.declare_function "replace_element" matrix_replace_f the_module in

  (* Shape of a matrix, itself returned as a matrix. *)
  let matrix_shape_f = L.function_type (L.pointer_type matrix_t)
    [| L.pointer_type matrix_t |] in
  let matrix_shape_func =
    L.declare_function "shape" matrix_shape_f the_module in

  (* Length of a matrix, as a double. *)
  let matrix_length_f = L.function_type d64_t
    [| L.pointer_type matrix_t |] in
  let matrix_length_func =
    L.declare_function "matrix_len" matrix_length_f the_module in

  (* Define every user function up front so that bodies can call functions
     defined later in the program.  Maps name -> (llvalue, declaration). *)
  let function_decls =
    let function_decl m fdecl =
      let name = fdecl.S.cname
      and formal_types = Array.of_list
          (List.map (fun (t, _) -> ltype_of_typ t) fdecl.S.cparams) in
      let ftype = L.function_type (ltype_of_typ fdecl.S.cdtype) formal_types in
      StringMap.add name (L.define_function name ftype the_module, fdecl) m in
    List.fold_left function_decl StringMap.empty functions in

  (* Emit the body of one user function into its already-defined llvalue. *)
  let build_function_body fdecl =
    let (the_function, _) = StringMap.find fdecl.S.cname function_decls in
    let builder = L.builder_at_end context (L.entry_block the_function) in

    (* printf formats: print_* append a newline, printnl_* do not. *)
    let str_format_str = L.build_global_stringptr "%s\n" "str" builder in
    let float_format_str = L.build_global_stringptr "%.3f\n" "float" builder in
    let str_format_str_nl = L.build_global_stringptr "%s" "str" builder in
    let float_format_str_nl = L.build_global_stringptr "%.3f" "float" builder in

    (* Give every parameter and local its own stack slot.  Parameters are
       copied into their slot so that they can be assigned like locals. *)
    let local_vars =
      let add_param m (t, n) p =
        L.set_value_name n p;
        let local = L.build_alloca (ltype_of_typ t) n builder in
        ignore (L.build_store p local builder);
        StringMap.add n local m in
      let add_local m (t, n) =
        StringMap.add n (L.build_alloca (ltype_of_typ t) n builder) m in
      let params = List.fold_left2 add_param StringMap.empty fdecl.S.cparams
          (Array.to_list (L.params the_function)) in
      List.fold_left add_local params fdecl.S.clocal_vars in

    (* Address of a variable: locals/params shadow globals.
       Raises Not_found for an unknown name (the checker should prevent it). *)
    let lookup n =
      try StringMap.find n local_vars
      with Not_found -> StringMap.find n global_vars in

    (* Emit code for an expression and return the resulting llvalue. *)
    let rec expr builder = function
        A.Number (t, n1, n2) -> (match t with
            A.IntTyp -> L.const_sitofp (L.const_int i32_t n1) d64_t
          | A.FloatTyp -> L.const_float d64_t n2)
      | A.Noexpr -> L.const_int i32_t 0
      | A.Id s -> L.build_load (lookup s) s builder
      | A.Str s ->
        (* Drop the first and last characters (the quotes).
           NOTE(review): assumes the lexer always keeps both quotes. *)
        L.build_global_stringptr
          (String.sub s 1 (String.length s - 2)) "" builder
      | A.Binop (e1, op, e2) ->
        let e1' = expr builder e1
        and e2' = expr builder e2 in
        (match op with
           A.Add     -> L.build_fadd
         | A.Sub     -> L.build_fsub
         | A.Mul     -> L.build_fmul
         | A.Div     -> L.build_fdiv
         | A.Mod     -> L.build_frem
         | A.And     -> L.build_and
         | A.Or      -> L.build_or
         | A.Equal   -> L.build_fcmp L.Fcmp.Oeq
         | A.Neq     -> L.build_fcmp L.Fcmp.One
         | A.Less    -> L.build_fcmp L.Fcmp.Olt
         | A.Leq     -> L.build_fcmp L.Fcmp.Ole
         | A.Greater -> L.build_fcmp L.Fcmp.Ogt
         | A.Geq     -> L.build_fcmp L.Fcmp.Oge) e1' e2' "tmp" builder
      | A.Unop (op, e) ->
        let e' = expr builder e in
        (match op with
           A.Neg -> L.build_fneg
         | A.Not -> L.build_not) e' "tmp" builder
      | A.Assign (l, e) -> (match l with
          A.MatCellAsn (m, r, c) ->
            L.build_call matrix_replace_func
              [| lookup m; expr builder r; expr builder c; expr builder e |]
              "replace_element" builder
        | A.VarDecl (_, n) | A.IdAsn n ->
            (* Store into the variable's slot; the assignment's value is
               the stored value. *)
            let e' = expr builder e in
            ignore (L.build_store e' (lookup n) builder); e')
      | A.MatEmptyInit (r, c) ->
        L.build_call zero_init_func
          [| expr builder r; expr builder c |] "zero_init" builder
      | A.MatAcc (m, r, c) ->
        L.build_call matrix_access_func
          [| lookup m; expr builder r; expr builder c |] "mat_acc" builder
      | A.Func ("print_string", [e]) ->
        L.build_call printf_func
          [| str_format_str; expr builder e |] "printf" builder
      | A.Func ("print_number", [e]) ->
        L.build_call printf_func
          [| float_format_str; expr builder e |] "printf" builder
      | A.Func ("printnl_string", [e]) ->
        L.build_call printf_func
          [| str_format_str_nl; expr builder e |] "printf" builder
      | A.Func ("printnl_number", [e]) ->
        L.build_call printf_func
          [| float_format_str_nl; expr builder e |] "printf" builder
      | A.Func ("strcat_string_string", [e1; e2]) ->
        L.build_call strcat_func
          [| expr builder e1; expr builder e2 |] "strcat" builder
      | A.Func ("shape_matrix", [e]) ->
        L.build_call matrix_shape_func [| expr builder e |] "shape" builder
      | A.Func ("len_matrix", [e]) ->
        L.build_call matrix_length_func [| expr builder e |] "matrix_len" builder
      | A.Func (f, act) ->
        let (fdef, callee) = StringMap.find f function_decls in
        (* Arguments are evaluated right to left. *)
        let actuals = List.rev (List.map (expr builder) (List.rev act)) in
        (* A call to a void function must not be given a value name. *)
        let result = (match callee.S.cdtype with
            A.VoidTyp -> ""
          | _ -> f ^ "_result") in
        L.build_call fdef (Array.of_list actuals) result builder
    in

    (* Run [f builder] only if the current block has no terminator yet. *)
    let add_terminal builder f =
      match L.block_terminator (L.insertion_block builder) with
        Some _ -> ()
      | None -> ignore (f builder) in

    (* Emit code for a statement; return the builder positioned where the
       next statement should go. *)
    let rec stmt builder = function
        A.Block sl -> List.fold_left stmt builder sl
      | A.VDecl _ -> builder  (* slot already allocated in local_vars *)
      | A.Expr e -> ignore (expr builder e); builder
      | A.Return e ->
        ignore (match fdecl.S.cdtype with
            A.VoidTyp -> L.build_ret_void builder
          | _ -> L.build_ret (expr builder e) builder);
        builder
      | A.If (predicate, then_stmt, else_stmt) ->
        let bool_val = expr builder predicate in
        let merge_bb = L.append_block context "merge" the_function in
        (* Each branch jumps to merge unless it already ended in a return. *)
        let then_bb = L.append_block context "then" the_function in
        add_terminal (stmt (L.builder_at_end context then_bb) then_stmt)
          (L.build_br merge_bb);
        let else_bb = L.append_block context "else" the_function in
        add_terminal (stmt (L.builder_at_end context else_bb) else_stmt)
          (L.build_br merge_bb);
        ignore (L.build_cond_br bool_val then_bb else_bb builder);
        L.builder_at_end context merge_bb
      | A.While (predicate, body) ->
        (* Layout: current -> while (test) -> while_body -> while ... -> merge *)
        let pred_bb = L.append_block context "while" the_function in
        ignore (L.build_br pred_bb builder);
        let body_bb = L.append_block context "while_body" the_function in
        add_terminal (stmt (L.builder_at_end context body_bb) body)
          (L.build_br pred_bb);
        let pred_builder = L.builder_at_end context pred_bb in
        let bool_val = expr pred_builder predicate in
        let merge_bb = L.append_block context "merge" the_function in
        ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
        L.builder_at_end context merge_bb
    in

    let builder = stmt builder fdecl.S.cbody in

    (* If control can fall off the end, return the zero value of the return
       type (0.0 for numbers, null for strings/matrices). *)
    add_terminal builder (match fdecl.S.cdtype with
        A.VoidTyp -> L.build_ret_void
      | t -> L.build_ret (zero_of_typ t))
  in
  List.iter build_function_body functions;
  the_module
