(* Code generation: translate takes a semantically checked AST and
produces LLVM IR

LLVM tutorial: Make sure to read the OCaml version of the tutorial

http://llvm.org/docs/tutorial/index.html

Detailed documentation on the OCaml LLVM library:

http://llvm.moe/
http://llvm.moe/ocaml/

*)

module L = Llvm
module A = Ast

module StringMap = Map.Make(String)

(* [translate (globals, functions)] lowers a semantically checked program to
   an LLVM module named "GoBackwards" and returns that module.

   - [globals] is a list of [(typ, name)] pairs; each becomes a module-level
     LLVM global initialised to the zero value of its type.
   - [functions] are the function declarations.  All of them are defined
     (signature only) before any body is generated, so a call may refer to a
     function that appears later in the list.

   Raises [Failure] for unsupported array element types, undeclared
   variables, illegal assignment targets, and array indices that are
   compile-time constants outside [0, size).  [Not_found] can escape for
   function or array names missing from the internal maps (presumably ruled
   out by the semantic checker). *)
let translate (globals, functions) =
  let context = L.global_context () in
  let the_module = L.create_module context "GoBackwards"
  and i32_t = L.i32_type context
  and i8_t = L.i8_type context
  and i1_t = L.i1_type context
  and str_t = L.pointer_type (L.i8_type context)
  and array_t = L.array_type
  and void_t = L.void_type context in

  (* Map a source type to its LLVM type.  Arrays become fixed-size LLVM
     arrays; only Int, String and Bool elements are supported. *)
  let ltype_of_typ = function
    | A.Int -> i32_t
    | A.Bool -> i1_t
    | A.String -> str_t
    | A.ArrayType (elt, size) ->
      (match elt with
       | A.Int -> array_t i32_t size
       | A.String -> array_t str_t size
       | A.Bool -> array_t i1_t size
       | _ -> raise (Failure "Undeclarable Array Type"))
    | A.Void -> void_t
  in

  (* Declare each global variable and remember it by name.  [L.const_null]
     yields the zero value of any first-class type (0 for i32/i1, a null
     pointer for strings, zeroinitializer for arrays); [L.const_int] is only
     valid for integer types and must not be used for string/array globals. *)
  let global_vars =
    let add_global m (t, name) =
      let init = L.const_null (ltype_of_typ t) in
      StringMap.add name (L.define_global name init the_module) m
    in
    List.fold_left add_global StringMap.empty globals
  in

  (* printf(i8*, ...) -> i32, used by the print/printb/println built-ins. *)
  let printf_t = L.var_arg_function_type i32_t [| L.pointer_type i8_t |] in
  let printf_func = L.declare_function "printf" printf_t the_module in

  (* ascii(i8*) -> void: a fixed-arity external with no result, used by the
     ascii built-in (presumably supplied by a separate runtime). *)
  let ascii_t = L.function_type void_t [| L.pointer_type i8_t |] in
  let ascii_func = L.declare_function "ascii" ascii_t the_module in

  (* Define every user function up front (empty body) so calls can be emitted
     before the callee's body exists.  Maps name -> (llvm function, decl). *)
  let function_decls =
    let function_decl m fdecl =
      let name = fdecl.A.fname in
      let signature = fdecl.A.signature in
      let formal_types =
        Array.of_list
          (List.map (fun (t, _) -> ltype_of_typ t) signature.A.formals)
      in
      let ftype =
        L.function_type (ltype_of_typ signature.A.ret_typ) formal_types
      in
      StringMap.add name (L.define_function name ftype the_module, fdecl) m
    in
    List.fold_left function_decl StringMap.empty functions
  in

  (* Generate the body of [fdecl] into its previously defined LLVM function. *)
  let build_function_body fdecl =
    let (the_function, _) = StringMap.find fdecl.A.fname function_decls in
    let builder = L.builder_at_end context (L.entry_block the_function) in

    let int_format_str = L.build_global_stringptr "%d\n" "fmt" builder in
    let str_format_str = L.build_global_stringptr "%s\n" "fmt" builder in

    (* Stack slots for formals and declared locals, keyed by name.  Formals
       are initialised by storing the incoming parameter value; locals are
       left uninitialised.  A local shadows a formal of the same name. *)
    let local_vars =
      let add_formal m (t, n) p =
        L.set_value_name n p;
        let local = L.build_alloca (ltype_of_typ t) n builder in
        ignore (L.build_store p local builder);
        StringMap.add n local m
      in
      let add_local m (t, n) =
        let local_var = L.build_alloca (ltype_of_typ t) n builder in
        StringMap.add n local_var m
      in
      let formals =
        List.fold_left2 add_formal StringMap.empty fdecl.A.signature.A.formals
          (Array.to_list (L.params the_function))
      in
      List.fold_left add_local formals fdecl.A.body.A.locals
    in

    (* Address of a variable: function-local names shadow globals. *)
    let lookup n =
      try StringMap.find n local_vars
      with Not_found ->
        (try StringMap.find n global_vars
         with Not_found -> raise (Failure ("undeclared variable " ^ n)))
    in

    (* Declared source type of every visible name, with the same shadowing
       as [lookup]: locals over formals over globals. *)
    let declared_types =
      List.fold_left (fun m (t, n) -> StringMap.add n t m) StringMap.empty
        (globals @ fdecl.A.signature.A.formals @ fdecl.A.body.A.locals)
    in

    (* Evaluate an index expression at compile time when it is built only from
       integer literals, unary minus and + - * /.  Returns [None] as soon as
       any part is not constant (or for division by a constant zero), so no
       run-time-dependent index is ever guessed at. *)
    let rec const_index = function
      | A.Literal i -> Some i
      | A.Unop (A.Neg, e) ->
        (match const_index e with Some v -> Some (-v) | None -> None)
      | A.Binop (e1, op, e2) ->
        (match const_index e1, const_index e2 with
         | Some a, Some b ->
           (match op with
            | A.Add -> Some (a + b)
            | A.Sub -> Some (a - b)
            | A.Mult -> Some (a * b)
            | A.Div -> if b = 0 then None else Some (a / b)
            | _ -> None)
         | _ -> None)
      | _ -> None
    in

    (* Reject, with [Failure msg], a constant index outside [0, size) into an
       array declared with [size] elements.  Non-constant indices are not
       checked here. *)
    let check_index msg name idx =
      match StringMap.find name declared_types, const_index idx with
      | A.ArrayType (_, size), Some i ->
        if i < 0 || i >= size then raise (Failure msg)
      | _ -> ()
    in

    (* Pointer to element [idx_val] of array variable [name]. *)
    let array_elem_ptr builder name idx_val =
      L.build_gep (lookup name) [| L.const_int i32_t 0; idx_val |] name builder
    in

    (* Emit code for an expression at [builder]; return its value. *)
    let rec expr builder = function
      | A.Literal i -> L.const_int i32_t i
      | A.BoolLit b -> L.const_int i1_t (if b then 1 else 0)
      | A.Strlit s ->
        (* Drop the first and last characters (presumably the quotes kept by
           the lexer) and take a pointer to the first byte. *)
        let contents = String.sub s 1 (String.length s - 2) in
        let ptr = L.build_global_string contents "" builder in
        L.build_gep ptr [| L.const_int i32_t 0; L.const_int i32_t 0 |] ""
          builder
      (* NOTE(review): an i8 zero is not a valid i1 branch condition; confirm
         Noexpr never reaches a predicate position. *)
      | A.Noexpr -> L.const_int i8_t 0
      | A.Id s -> L.build_load (lookup s) s builder
      | A.AccessArray (name, idx) ->
        let idx' = expr builder idx in
        check_index "Array out of Bounds" name idx;
        L.build_load (array_elem_ptr builder name idx') name builder
      | A.Binop (e1, op, e2) ->
        let e1' = expr builder e1 in
        let e2' = expr builder e2 in
        let build =
          match op with
          | A.Add -> L.build_add
          | A.Sub -> L.build_sub
          | A.Mult -> L.build_mul
          | A.Div -> L.build_sdiv
          | A.And -> L.build_and
          | A.Or -> L.build_or
          | A.Equal -> L.build_icmp L.Icmp.Eq
          | A.Neq -> L.build_icmp L.Icmp.Ne
          | A.Less -> L.build_icmp L.Icmp.Slt
          | A.Leq -> L.build_icmp L.Icmp.Sle
          | A.Greater -> L.build_icmp L.Icmp.Sgt
          | A.Geq -> L.build_icmp L.Icmp.Sge
        in
        build e1' e2' "tmp" builder
      | A.Unop (op, e) ->
        let e' = expr builder e in
        let build = match op with A.Neg -> L.build_neg | A.Not -> L.build_not in
        build e' "tmp" builder
      | A.Assign (lhs, rhs) ->
        (* Only variables and array elements are assignable. *)
        let dest =
          match lhs with
          | A.Id s -> lookup s
          | A.AccessArray (name, idx) ->
            let idx' = expr builder idx in
            check_index "Array out of bounds" name idx;
            array_elem_ptr builder name idx'
          | _ -> raise (Failure "illegal left assignment")
        in
        let value = expr builder rhs in
        ignore (L.build_store value dest builder);
        value
      | A.Call ("print", [e]) | A.Call ("printb", [e]) ->
        L.build_call printf_func [| int_format_str; expr builder e |]
          "printf" builder
      | A.Call ("println", [e]) ->
        L.build_call printf_func [| str_format_str; expr builder e |]
          "printf" builder
      | A.Call ("ascii", [e]) ->
        (* A call returning void must be left unnamed in LLVM IR. *)
        L.build_call ascii_func [| expr builder e |] "" builder
      | A.Call (f, act) ->
        let (fdef, callee) = StringMap.find f function_decls in
        (* Arguments are evaluated right to left. *)
        let actuals = List.rev (List.map (expr builder) (List.rev act)) in
        let result =
          match callee.A.signature.A.ret_typ with
          | A.Void -> ""
          | _ -> f ^ "_result"
        in
        L.build_call fdef (Array.of_list actuals) result builder
    in

    (* Run [f builder] unless the current block already ends in a
       terminator (e.g. a branch or return). *)
    let add_terminal builder f =
      match L.block_terminator (L.insertion_block builder) with
      | Some _ -> ()
      | None -> ignore (f builder)
    in

    (* Emit code for a statement; return the builder positioned where the
       next statement should go. *)
    let rec stmt builder = function
      | A.Block sl -> List.fold_left stmt builder sl
      | A.Expr e -> ignore (expr builder e); builder
      | A.Return e ->
        ignore
          (match fdecl.A.signature.A.ret_typ with
           | A.Void -> L.build_ret_void builder
           | _ -> L.build_ret (expr builder e) builder);
        builder
      | A.If (predicate, then_stmt, else_stmt) ->
        let bool_val = expr builder predicate in
        let merge_bb = L.append_block context "merge" the_function in
        let then_bb = L.append_block context "then" the_function in
        add_terminal (stmt (L.builder_at_end context then_bb) then_stmt)
          (L.build_br merge_bb);
        let else_bb = L.append_block context "else" the_function in
        add_terminal (stmt (L.builder_at_end context else_bb) else_stmt)
          (L.build_br merge_bb);
        ignore (L.build_cond_br bool_val then_bb else_bb builder);
        L.builder_at_end context merge_bb
      | A.While (predicate, body) ->
        let pred_bb = L.append_block context "while" the_function in
        ignore (L.build_br pred_bb builder);
        let body_bb = L.append_block context "while_body" the_function in
        add_terminal (stmt (L.builder_at_end context body_bb) body)
          (L.build_br pred_bb);
        let pred_builder = L.builder_at_end context pred_bb in
        let bool_val = expr pred_builder predicate in
        let merge_bb = L.append_block context "merge" the_function in
        ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
        L.builder_at_end context merge_bb
      | A.For (e1, e2, e3, body) ->
        (* for (e1; e2; e3) body  ==  e1; while (e2) { body; e3 } *)
        stmt builder
          (A.Block [A.Expr e1; A.While (e2, A.Block [body; A.Expr e3])])
    in

    let builder = stmt builder (A.Block fdecl.A.body.A.stmts) in

    (* Add a return if control can fall off the end of the last block. *)
    add_terminal builder
      (match fdecl.A.signature.A.ret_typ with
       | A.Void -> L.build_ret_void
       | t -> L.build_ret (L.const_null (ltype_of_typ t)))
  in

  List.iter build_function_body functions;
  the_module
