(* Code generation: translate takes a semantically checked AST and
produces LLVM IR

LLVM tutorial: Make sure to read the OCaml version of the tutorial

http://llvm.org/docs/tutorial/index.html

Detailed documentation on the OCaml LLVM library:

http://llvm.moe/
http://llvm.moe/ocaml/

*)

module L = Llvm
module A = Ast
open Sast 

module StringMap = Map.Make(String)

(* One lexical scope during code generation: [vars] maps each name declared
   at this level to its storage (an alloca'd stack slot or a global), and
   [parent] is the enclosing scope, [None] for the global scope. *)
type l_environment = { vars : L.llvalue StringMap.t; parent : l_environment option }

(* Mutable registry of the LLVM struct types generated for node types, keyed
   by the printed name of the node's element type. *)
type l_types = { mutable nodes : L.lltype StringMap.t }

(* Unwrap an option that semantic analysis guarantees to be [Some _].
   @raise Failure on [None], which signals an internal compiler error. *)
let get_opt = function
  | None -> failwith "failed to extract value (should have been caught by semant)"
  | Some x -> x

(* translate : Sast.program -> Llvm.llmodule

   Generate LLVM IR for a semantically checked program. [globals] is the list
   of global variable bindings and [functions] the list of function
   declarations ([sfdecl]). The result is a single module named "Arbol".
   Conditions that semantic analysis should already have ruled out are
   reported with [Failure]. *)
let translate (globals, functions) =
  let context = L.global_context () in

  (* Create the LLVM compilation module into which we will generate code *)
  let the_module = L.create_module context "Arbol" in

  (* Cache of the named struct types created for node<T>, keyed by the
     printed element type, so that each node type is defined only once. *)
  let mytypes = { nodes = StringMap.empty } in

  (* Get types from the context *)
  let i32_t = L.i32_type context
  and i8_t = L.i8_type context
  and i1_t = L.i1_type context
  and float_t = L.double_type context
  and str_t = L.pointer_type (L.i8_type context)
  and void_t = L.void_type context in

  (* Return the LLVM type for an Arbol type. node<T> becomes the packed named
     struct "node_T" = { T*; node_T*; node_T* }: a pointer to the node's
     value followed by pointers to its left and right children. *)
  let rec ltype_of_typ = function
    | A.Int -> i32_t
    | A.Bool -> i1_t
    | A.Float -> float_t
    | A.Char -> i8_t
    | A.String -> str_t
    | A.Void -> void_t
    | A.Node t ->
      let key = A.string_of_vtype t in
      (try StringMap.find key mytypes.nodes
       with Not_found ->
         let ntype = L.named_struct_type context ("node_" ^ key) in
         let ptype = L.pointer_type (ltype_of_typ t) in
         let ctype = L.pointer_type ntype in
         L.struct_set_body ntype [| ptype; ctype; ctype |] true;
         mytypes.nodes <- StringMap.add key ntype mytypes.nodes;
         ntype)
  in

  (* Zero/null constant of type [t]. Used to initialise globals and fresh
     node value cells, and as the fallback return value of non-void
     functions. *)
  let default_vals t = match t with
    | A.Float -> L.const_float (ltype_of_typ t) 0.0
    | A.Int | A.Bool | A.Char -> L.const_int (ltype_of_typ t) 0
    | A.String -> L.const_pointer_null (ltype_of_typ t)
    | A.Node elt ->
      (* Every field is a pointer, so the nulls need pointer types (T* and
         node_T*). The constant must also carry the named node type so that
         it matches globals and returns typed with [ltype_of_typ]. *)
      let node_t = ltype_of_typ t in
      let data_null = L.const_pointer_null (L.pointer_type (ltype_of_typ elt)) in
      let child_null = L.const_pointer_null (L.pointer_type node_t) in
      L.const_named_struct node_t [| data_null; child_null; child_null |]
    | A.Void -> raise (Failure "illegal void type (should have been checked by semant)") in

  (* Initialise the node struct pointed to by [node] (of Arbol type [node_t]
     = node<[t]>). Its value field points at a freshly malloc'd cell holding
     [default_vals t], and both child fields are null. Returns [builder]. *)
  let fill_null_node node node_t t builder =
    let cell = L.build_malloc (ltype_of_typ t) "default_val" builder in
    let cnull = L.const_pointer_null (L.pointer_type (ltype_of_typ node_t)) in
    let data_ptr = L.build_struct_gep node 0 "result" builder in
    let lchild_ptr = L.build_struct_gep node 1 "lchild_ptr" builder in
    let rchild_ptr = L.build_struct_gep node 2 "rchild_ptr" builder in
    ignore (L.build_store (default_vals t) cell builder);
    ignore (L.build_store cell data_ptr builder);
    ignore (L.build_store cnull lchild_ptr builder);
    ignore (L.build_store cnull rchild_ptr builder);
    builder
  in

  (* Define one global variable per binding, initialised with [default_vals],
     and map each name to its global *)
  let global_vars : L.llvalue StringMap.t =
    let global_var m sv =
      let init = default_vals sv.sv_type in
      StringMap.add sv.sv_name (L.define_global sv.sv_name init the_module) m in
    List.fold_left global_var StringMap.empty globals in

  let global_env = { vars = global_vars; parent = None } in

  (* External functions backing the built-in print calls *)
  let printf_t : L.lltype =
    L.var_arg_function_type i32_t [| L.pointer_type i8_t |] in
  let printf_func : L.llvalue =
    L.declare_function "print" printf_t the_module in

  (* print_int *)
  let printf_int : L.lltype =
    L.var_arg_function_type i32_t [| i32_t |] in
  let printf_int_func : L.llvalue =
    L.declare_function "print_int" printf_int the_module in

  (* print_float *)
  let printf_float : L.lltype =
    L.var_arg_function_type float_t [| float_t |] in
  let printf_float_func : L.llvalue =
    L.declare_function "print_float" printf_float the_module in

  (* built-in function 'pre-order' *)
  (* let preorder_node : L.lltype =
    L.var_arg_function_type node [| node |] in
  let preorder_node_func : L.llvalue =
    L.declare_function "preorder" preorder_node the_module in *)

  let unwrap_sargs = function (t, _) -> ltype_of_typ t in

  (* Define each function (arguments and return type) so we can
     call it even before we've created its body *)
  let function_decls : (L.llvalue * sfdecl) StringMap.t =
    let function_decl m fdecl =
      let name = fdecl.sfname
      and formal_types = Array.of_list (List.map unwrap_sargs fdecl.sargs_list) in
      let ftype = L.function_type (ltype_of_typ fdecl.srtype) formal_types in
      StringMap.add name (L.define_function name ftype the_module, fdecl) m in
    List.fold_left function_decl StringMap.empty functions in

  (* Fill in the body of the given function *)
  let build_function_body global_env fdecl =
    let (the_function, _) = StringMap.find fdecl.sfname function_decls in
    let builder = L.builder_at_end context (L.entry_block the_function) in

    (* Return the storage for a variable or formal, searching from the
       innermost scope outwards *)
    let rec lookup n env =
      try StringMap.find n env.vars
      with Not_found ->
        (match env.parent with
         | Some p -> lookup n p
         | None -> raise (Failure ("var " ^ n ^ " not found (semantic analysis failed!)")))
    in

    (* Copy an incoming parameter into a stack slot and bind its name *)
    let add_formal m (t, n) p =
      L.set_value_name n p;
      let local = L.build_alloca (ltype_of_typ t) n builder in
      ignore (L.build_store p local builder);
      StringMap.add n local m in

    (* Allocate a stack slot for a declared local. Node locals also get a
       fresh value cell and null children via [fill_null_node]. *)
    let add_local m sv builder =
      let slot = L.build_alloca (ltype_of_typ sv.sv_type) sv.sv_name builder in
      (match sv.sv_type with
       | A.Node t as node_t -> ignore (fill_null_node slot node_t t builder)
       | _ -> ());
      StringMap.add sv.sv_name slot m in

    (* Add formals *)
    let formals = List.fold_left2 add_formal StringMap.empty fdecl.sargs_list
        (Array.to_list (L.params the_function)) in
    let local_env = { vars = formals; parent = Some global_env } in

    (* Load the pointer held in field [idx] of node variable [s]
       (0 = value cell, 1 = left child, 2 = right child) *)
    let load_node_field env builder s idx =
      let field_ptr = L.build_struct_gep (lookup s env) idx "result" builder in
      L.build_load field_ptr (s ^ "_ptr") builder in

    (* Construct code for an expression. Returns its value paired with a
       pointer where one exists: the variable's slot for [SId], the new value
       cell for [SNode_assign], or the loaded field pointer for node ops. *)
    let rec expr env builder (e : sexpr) = match e with
      | SInt_lit i -> L.const_int i32_t i, None
      | SChar_lit c -> L.const_int i8_t (Char.code c), None
      | SBool_lit b -> L.const_int i1_t (if b then 1 else 0), None
      | SString_lit s -> L.build_global_stringptr s "string" builder, None
      | SFloat_lit l -> L.const_float_of_string float_t l, None
      | SUnop (op, e, t) ->
        let (e', _) = expr env builder e in
        (match op, t with
         | A.Neg, A.Float -> L.build_fneg
         | A.Neg, _ -> L.build_neg
         | A.Not, _ -> L.build_not) e' "tmp" builder, None
      (* Comparison against a Void-typed operand (presumably a null literal):
         only the other operand's pointer is tested. It is fetched with
         [node_ptr], which never dereferences a possibly-null child. *)
      | SBinop (_, op, e, A.Void, _, _) | SBinop (e, op, _, _, A.Void, _) ->
        let ptr = get_opt (node_ptr env builder e) in
        (match op with
         | A.Equals -> L.build_is_null
         | A.Not_Equals -> L.build_is_not_null
         | _ -> raise (Failure "internal error: semant should have rejected invalid op on void")
        ) ptr "tmp" builder, None
      | SBinop (e1, op, e2, A.Float, _, _) ->
        (* Sequential lets: the left operand's code is emitted first *)
        let (e1', _) = expr env builder e1 in
        let (e2', _) = expr env builder e2 in
        (match op with
         | A.Plus -> L.build_fadd
         | A.Minus -> L.build_fsub
         | A.Times -> L.build_fmul
         | A.Divide -> L.build_fdiv
         | A.Mod -> L.build_frem
         | A.Equals -> L.build_fcmp L.Fcmp.Oeq
         (* Unordered: NaN != x is true, the exact complement of Oeq *)
         | A.Not_Equals -> L.build_fcmp L.Fcmp.Une
         | A.Less -> L.build_fcmp L.Fcmp.Olt
         | A.Less_Eq -> L.build_fcmp L.Fcmp.Ole
         | A.Greater -> L.build_fcmp L.Fcmp.Ogt
         | A.Greater_Eq -> L.build_fcmp L.Fcmp.Oge
         | A.And | A.Or ->
           raise (Failure "internal error: semant should have rejected and/or on float")
        ) e1' e2' "tmp" builder, None
      | SBinop (e1, op, e2, _, _, _) ->
        let (e1', _) = expr env builder e1 in
        let (e2', _) = expr env builder e2 in
        (match op with
         | A.Plus -> L.build_add
         | A.Minus -> L.build_sub
         | A.Times -> L.build_mul
         | A.Divide -> L.build_sdiv
         | A.Mod -> L.build_srem
         | A.And -> L.build_and
         | A.Or -> L.build_or
         | A.Equals -> L.build_icmp L.Icmp.Eq
         | A.Not_Equals -> L.build_icmp L.Icmp.Ne
         | A.Less -> L.build_icmp L.Icmp.Slt
         | A.Less_Eq -> L.build_icmp L.Icmp.Sle
         | A.Greater -> L.build_icmp L.Icmp.Sgt
         | A.Greater_Eq -> L.build_icmp L.Icmp.Sge
        ) e1' e2' "tmp" builder, None
      | SCall ("print", [e], _) ->
        let (v, _) = expr env builder e in
        L.build_call printf_func [| v |] "print" builder, None
      | SCall ("print_int", [e], _) ->
        let (v, _) = expr env builder e in
        L.build_call printf_int_func [| v |] "print_int" builder, None
      | SCall ("print_float", [e], _) ->
        let (v, _) = expr env builder e in
        L.build_call printf_float_func [| v |] "print_float" builder, None
      (* | SCall ("preorder", [e], _) ->
          let (v,_) = expr env builder e in L.build_call preorder_node_func [| v |]
          "preorder" builder, None *)
      | SCall (f, args, _) ->
        let (fdef, callee) = StringMap.find f function_decls in
        (* The argument list is reversed before mapping, so argument code is
           emitted last-to-first *)
        let llargs =
          List.rev (List.map (fun arg -> fst (expr env builder arg)) (List.rev args)) in
        let result = (match callee.srtype with
            | A.Void -> ""
            | _ -> f ^ "_result") in
        L.build_call fdef (Array.of_list llargs) result builder, None
      | SAssign (s, e, _) ->
        let (e', _) = expr env builder e in
        ignore (L.build_store e' (lookup s env) builder);
        e', None
      | SId (s, _) ->
        let ptr = lookup s env in
        L.build_load ptr s builder, Some ptr
      | SNode_assign (s, e, t) ->
        (* Store the value in a freshly malloc'd cell and point the node's
           value field at it *)
        let (e', _) = expr env builder e in
        let cell = L.build_malloc (ltype_of_typ t) (s ^ "_val") builder in
        ignore (L.build_store e' cell builder);
        let field = L.build_struct_gep (lookup s env) 0 "result" builder in
        ignore (L.build_store cell field builder);
        e', Some cell
      | SNodeop (nodeop, s, _) ->
        (match nodeop with
         | A.Get_left_child ->
           let ptr = load_node_field env builder s 1 in
           L.build_load ptr (s ^ "_lchild") builder, Some ptr
         | A.Get_right_child ->
           let ptr = load_node_field env builder s 2 in
           L.build_load ptr (s ^ "_rchild") builder, Some ptr
         | A.Dref ->
           let ptr = load_node_field env builder s 0 in
           L.build_load ptr (s ^ "_val") builder, Some ptr)
      | SNoexpr -> L.const_int i32_t 0, None

    (* Pointer-only evaluation, for callers that discard the value. Child
       accesses return the loaded child pointer without dereferencing it, so
       a null child can be tested or copied safely. Every other expression
       defers to [expr]. *)
    and node_ptr env builder (e : sexpr) = match e with
      | SNodeop (A.Get_left_child, s, _) -> Some (load_node_field env builder s 1)
      | SNodeop (A.Get_right_child, s, _) -> Some (load_node_field env builder s 2)
      | _ -> snd (expr env builder e)
    in

    (* True when the builder's current block already ends in a terminator *)
    let block_terminated builder =
      match L.block_terminator (L.insertion_block builder) with
      | Some _ -> true
      | None -> false in

    (* LLVM insists each basic block end with exactly one "terminator"
       instruction that transfers control. This function runs "instr builder"
       if the current block does not already have a terminator. Used,
       e.g., to handle the "fall off the end of the function" case. *)
    let add_terminal builder instr =
      if not (block_terminated builder) then ignore (instr builder) in

    (* Allocate storage for each declaration of a block up front *)
    let create_env builder m = function
      | SVariable sv -> add_local m sv builder
      | _ -> m in

    (* Replace a declaration with the statement that initialises it. Its
       storage has already been allocated by [create_env]. *)
    let parse_vars = function
      | SVariable sv ->
        if sv.sv_node_val then
          (match sv.sv_type with
           | A.Node t -> SExpr (SNode_assign (sv.sv_name, sv.sv_val, t))
           | _ -> raise (Failure ("internal error: node assignment check failed")))
        else
          (match sv.sv_val with
           | SNoexpr -> SExpr SNoexpr
           | init -> SExpr (SAssign (sv.sv_name, init, sv.sv_type)))
      | s -> s in

    (* Build the code for the given statement; return the builder for
       the statement's successor (i.e., the next instruction will be built
       after the one generated by this call). [merge_block] and [next_block]
       are the break and continue targets of the innermost loop, if any. *)
    let rec stmt merge_block next_block env builder = function
      | SBlock sl ->
        let new_env = {
          vars = List.fold_left (create_env builder) StringMap.empty sl;
          parent = Some env;
        } in
        (* Statements after a break/continue/return are unreachable. Emitting
           them would append instructions after the block's terminator. *)
        List.fold_left
          (fun b s ->
             if block_terminated b then b
             else stmt merge_block next_block new_env b s)
          builder (List.map parse_vars sl)
      | SExpr e -> ignore (expr env builder e); builder
      | SReturn (e, _) ->
        ignore (match fdecl.srtype with
            (* Special "return nothing" instr *)
            | A.Void -> L.build_ret_void builder
            (* Build return statement *)
            | _ -> let (e', _) = expr env builder e in L.build_ret e' builder);
        builder
      | SIf (predicate, then_stmt, else_stmt) ->
        let (bool_val, _) = expr env builder predicate in
        let merge_bb = L.append_block context "merge" the_function in
        let build_br_merge = L.build_br merge_bb in (* partial function *)

        let then_bb = L.append_block context "then" the_function in
        add_terminal
          (stmt merge_block next_block env (L.builder_at_end context then_bb) then_stmt)
          build_br_merge;

        let else_bb = L.append_block context "else" the_function in
        add_terminal
          (stmt merge_block next_block env (L.builder_at_end context else_bb) else_stmt)
          build_br_merge;

        ignore (L.build_cond_br bool_val then_bb else_bb builder);
        L.builder_at_end context merge_bb

      | SWhile (predicate, body) ->
        let pred_bb = L.append_block context "while" the_function in
        ignore (L.build_br pred_bb builder);

        let body_bb = L.append_block context "while_body" the_function in

        let pred_builder = L.builder_at_end context pred_bb in
        let (bool_val, _) = expr env pred_builder predicate in
        let merge_bb = L.append_block context "merge" the_function in

        (* break -> merge_bb, continue -> re-test the predicate *)
        add_terminal
          (stmt (Some merge_bb) (Some pred_bb) env (L.builder_at_end context body_bb) body)
          (L.build_br pred_bb);

        ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
        L.builder_at_end context merge_bb

      | SFor (e1, predicate, e3, body) ->
        ignore (expr env builder e1);

        let pred_bb = L.append_block context "for" the_function in
        ignore (L.build_br pred_bb builder);

        let body_bb = L.append_block context "for_body" the_function in

        let next_bb = L.append_block context "for_next" the_function in

        let pred_builder = L.builder_at_end context pred_bb in
        let (bool_val, _) = expr env pred_builder predicate in
        let merge_bb = L.append_block context "merge" the_function in

        (* The step expression lives in its own block so continue can reach it *)
        let next_builder = L.builder_at_end context next_bb in
        ignore (expr env next_builder e3);

        add_terminal
          (stmt (Some merge_bb) (Some next_bb) env (L.builder_at_end context body_bb) body)
          (L.build_br next_bb);

        add_terminal next_builder (L.build_br pred_bb);
        ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
        L.builder_at_end context merge_bb
      (* | SFor (e1, e2, e3, body) -> stmt merge_block next_block env builder
        ( SBlock [SExpr e1 ; SWhile (e2, SBlock [body ; SExpr e3]) ] ) *)
      | SNode_child (s, nodeop, e, _) ->
        let child_ptr = (match nodeop with
            | A.Set_left_child -> L.build_struct_gep (lookup s env) 1 (s ^ "_lchild") builder
            | A.Set_right_child -> L.build_struct_gep (lookup s env) 2 (s ^ "_rchild") builder) in
        (* Only the source's pointer is stored, so avoid dereferencing it *)
        (match node_ptr env builder e with
         | Some ptr -> ignore (L.build_store ptr child_ptr builder); builder
         | None -> raise (Failure "no node detected"))
      | SBreak ->
        (match merge_block with
         | Some bb -> ignore (L.build_br bb builder); builder
         | None -> raise (Failure "not in loop--internal error, should have been checked in semant"))
      | SContinue ->
        (match next_block with
         | Some bb -> ignore (L.build_br bb builder); builder
         | None -> raise (Failure "not in loop--internal error, should have been checked in semant"))
      | SVariable _ -> raise (Failure "internal codegen error--variable not parsed out")
    in

    (* Build the code for each statement in the function *)
    let builder = stmt None None local_env builder (SBlock fdecl.sbody) in
    (* Add a return if the last block falls off the end *)
    add_terminal builder (match fdecl.srtype with
        | A.Void -> L.build_ret_void
        | t -> L.build_ret (default_vals t))
  in

  List.iter (build_function_body global_env) functions;
  the_module
