(* Code generation: translate takes a semantically checked AST and
produces LLVM IR

LLVM tutorial: Make sure to read the OCaml version of the tutorial

http://llvm.org/docs/tutorial/index.html

Detailed documentation on the OCaml LLVM library:

http://llvm.moe/
http://llvm.moe/ocaml/

*)

module L = Llvm
module A = Ast
open Sast 

module StringMap = Map.Make(String)

(* translate : Sast.program -> Llvm.module *)
let translate decls =
  let context = L.global_context () in

  (* Read the runtime support library; Llvm_bitreader.parse_bitcode is used
     on it, so "util.o" is expected to contain LLVM bitcode.  Only its named
     struct types are consulted below. *)
  let llmem = L.MemoryBuffer.of_file "util.o" in
  let llm = Llvm_bitreader.parse_bitcode context llmem in

  (* The LLVM module that receives all generated code. *)
  let the_module = L.create_module context "Pixel" in

  (* Pointer to the struct called [struct_name] in the runtime library.
     Raises [Failure missing_msg] when the library has no such type. *)
  let runtime_struct_ptr struct_name missing_msg =
    match L.type_by_name llm struct_name with
    | Some struct_t -> L.pointer_type struct_t
    | None -> raise (Failure missing_msg)
  in

  (* LLVM types used throughout code generation. *)
  let i32_t = L.i32_type context
  and i8_t = L.i8_type context
  and float_t = L.double_type context
  and void_t = L.void_type context
  and string_t = L.pointer_type (L.i8_type context)
  and image_t = runtime_struct_ptr "struct.image" "missing struct Image"
  and matrix_t = runtime_struct_ptr "struct.matrix" "missing struct Matrix"
  in

  (* Return the LLVM type for a Pixel type *)
  let ltype_of_typ typ =
    match typ with
    (* Scalars and void *)
    | A.Void -> void_t
    | A.Int -> i32_t
    | A.Float -> float_t
    (* Values represented as pointers *)
    | A.String -> string_t
    | A.Image -> image_t
    | A.Matrix -> matrix_t
  in

  (* Define built-in function types *)
  (* Declare [name] with signature [fn_type] in the output module. *)
  let declare_runtime name fn_type =
    L.declare_function name fn_type the_module
  in

  (* C library printf: variadic, takes a format string, returns i32. *)
  let printf_t = L.var_arg_function_type i32_t [| L.pointer_type i8_t |] in
  let printf_func = declare_runtime "printf" printf_t in

  (* Matrix construction and element access. *)
  let initMatrix_t =
    L.function_type matrix_t
      [| L.pointer_type (L.pointer_type float_t); i32_t; i32_t |]
  in
  let initMatrix_f = declare_runtime "initMatrix" initMatrix_t in
  let access_t = L.function_type float_t [| matrix_t; i32_t; i32_t |] in
  let access_f = declare_runtime "access" access_t in
  let accessAssign_t =
    L.function_type void_t [| matrix_t; i32_t; i32_t; float_t |]
  in
  let accessAssign_f = declare_runtime "accessAssign" accessAssign_t in

  (* Image input and output. *)
  let image_in_t = L.function_type image_t [| string_t; string_t |] in
  let image_in_f = declare_runtime "image_in" image_in_t in
  let image_out_t = L.function_type void_t [| string_t; image_t; string_t |] in
  let image_out_f = declare_runtime "image_out" image_out_t in

  let convolute_t = L.function_type matrix_t [| matrix_t; matrix_t |] in
  let convolute_f = declare_runtime "convolute" convolute_t in

  (* NOTE(review): both join variants are declared under the single symbol
     "join" with different signatures.  Confirm what the runtime actually
     exports; the second declaration cannot be an independent function. *)
  let join_color_t =
    L.function_type image_t [| matrix_t; matrix_t; matrix_t |]
  in
  let join_color_f = declare_runtime "join" join_color_t in
  let join_grayscale_t = L.function_type image_t [| matrix_t |] in
  let join_grayscale_f = declare_runtime "join" join_grayscale_t in

  (* Matrix arithmetic. *)
  let multiply_matrix_t = L.function_type matrix_t [| matrix_t; matrix_t |] in
  let multiply_matrix_f = declare_runtime "multiply_matrix" multiply_matrix_t in
  let add_matrix_t = L.function_type matrix_t [| matrix_t; matrix_t |] in
  let add_matrix_f = declare_runtime "add_matrix" add_matrix_t in
  let scale_matrix_t = L.function_type matrix_t [| matrix_t; float_t |] in
  let scale_matrix_f = declare_runtime "scale_matrix" scale_matrix_t in
  let exp_matrix_t = L.function_type matrix_t [| matrix_t; float_t |] in
  let exp_matrix_f = declare_runtime "exp_matrix" exp_matrix_t in

  (* Define each function (arguments and return type) so we can 
     call it even before we've created its body *)
  let function_decls : (L.llvalue * sfunc_decl) StringMap.t =
    (* Map each function name to its LLVM definition paired with the
       checked declaration it came from. *)
    List.fold_left
      (fun known fd ->
        let param_types =
          fd.sformals
          |> List.map (fun (t, _, _) -> ltype_of_typ t)
          |> Array.of_list
        in
        let fn_type = L.function_type (ltype_of_typ fd.styp) param_types in
        let fn = L.define_function fd.sfname fn_type the_module in
        StringMap.add fd.sfname (fn, fd) known)
      StringMap.empty decls
  in
  
  (* Fill in the body of the given function *)
  let build_function_body fdecl =
    (* The LLVM function defined earlier for this declaration, and a builder
       positioned at the end of its entry block. *)
    let the_function = fst (StringMap.find fdecl.sfname function_decls) in
    let builder = L.entry_block the_function |> L.builder_at_end context in

    (* printf format strings for int, float and string values; each is
       requested under the global name "fmt". *)
    let format_global text = L.build_global_stringptr text "fmt" builder in
    let int_format_str = format_global "%d\n" in
    let float_format_str = format_global "%g\n" in
    let string_format_str = format_global "%s\n" in
    (* Construct the function's "locals": formal arguments and locally
       declared variables.  Allocate each on the stack, initialize their
       value, if appropriate, and remember their values in the "locals" map *)
    (* let local_vars =
      let add_formal m (t, n) p = 
        L.set_value_name n p;
	    let local = L.build_alloca (ltype_of_typ t) n builder in
        ignore (L.build_store p local builder);
	    StringMap.add n local m 

      (* Allocate space for any locally declared variables and add the
       * resulting registers to our map *)
      and add_local m (t, n) =
        let local_var = L.build_alloca (ltype_of_typ t) n builder
        in StringMap.add n local_var m
      in

      let formals = List.fold_left2 add_formal StringMap.empty fdecl.sformals
          (Array.to_list (L.params the_function)) in
      List.fold_left add_local formals fdecl.slocals
    in *)
    let local_vars =
      (* Give formal [n] a stack slot, copy the incoming parameter [p] into
         it, and record the slot under [n].  The third component of each
         formal is not needed here. *)
      let add_formal m (t, n, _) p =
        L.set_value_name n p;
        let local = L.build_alloca (ltype_of_typ t) n builder in
        ignore (L.build_store p local builder);
        StringMap.add n local m
      in
      (* Only formals are registered up front: the former add_local helper was
         only ever folded over an empty list, so it has been removed.
         List.fold_left2 raises Invalid_argument if the declared formals and
         the LLVM parameters differ in number. *)
      List.fold_left2 add_formal StringMap.empty fdecl.sformals
        (Array.to_list (L.params the_function))
    in

    (* Return the value for a variable or formal argument *)
    (* [lookup n] returns the stack slot recorded for [n] in [local_vars].
       Raises [Failure] naming the variable when it is not bound. *)
    let lookup n =
      try StringMap.find n local_vars
      with Not_found -> raise (Failure ("Undefined variable " ^ n))
    in

    (* Construct code for an expression; return its value *)
    (*
      | SBinop of sexpr * op * sexpr
      | SUnop of uop * sexpr
      | SMatrixAccess of string * sexpr * expr
      | SImageRedAccess of string
      | SImageGreenAccess of string
      | SImageBlueAccess of string
      | SImageGrayscaleAccess of string
      | SMatrixRows of string
      | SMatrixCols of string
    *)
    (* let rec expr builder ((_, e) : sexpr) local_vars = match e with *)
    (* [expr builder sexpr lvars] emits code for [sexpr] at [builder] and
       returns the resulting LLVM value.  [lvars] maps every variable visible
       at this point to its stack slot; names missing from it fall back to
       [lookup], which raises [Failure] for unknown names.

       NOTE(review): SBinop is dispatched on the left operand's type only.
       Matrix operands are not lowered to add_matrix / multiply_matrix /
       scale_matrix / exp_matrix (declared above but unused) and currently
       fall into the integer case. *)
    let rec expr builder ((_, e) : sexpr) lvars =
      let eval sub = expr builder sub lvars in
      let slot_of name =
        try StringMap.find name lvars with Not_found -> lookup name
      in
      let load_var name = L.build_load (slot_of name) name builder in
      (* Variables of struct type hold a pointer to the struct, so the pointer
         is loaded from the variable's slot before indexing field [index]. *)
      let load_field name index label =
        let struct_ptr = load_var name in
        let field_ptr = L.build_struct_gep struct_ptr index label builder in
        L.build_load field_ptr label builder
      in
      match e with
      | SLiteral i -> L.const_int i32_t i
      | SFliteral l -> L.const_float_of_string float_t l
      | SStrLiteral l -> L.build_global_stringptr l "tmp" builder
      | SNoexpr -> L.const_int i32_t 0
      | SId s -> load_var s
      | SAssign (s, rhs) ->
        let rhs' = eval rhs in
        ignore (L.build_store rhs' (slot_of s) builder);
        rhs'
      | SMLiteral (contents, rows, cols) ->
        (* Build the double** that initMatrix is declared to take: one heap
           buffer for the cells plus one array of row pointers into it.
           NOTE(review): assumes [contents] is a flat, row-major list of
           rows * cols cells, and that the runtime either copies or takes
           ownership of the buffers -- confirm against the C runtime. *)
        if List.length contents <> rows * cols then
          raise (Failure
            "internal error: matrix literal size does not match its dimensions");
        let cells =
          List.map
            (fun ((t, _) as cell) ->
              let v = eval cell in
              (* Integer cells are widened to the matrix element type. *)
              match t with
              | A.Int -> L.build_sitofp v float_t "tmp" builder
              | _ -> v)
            contents
        in
        let data =
          L.build_array_malloc float_t
            (L.const_int i32_t (rows * cols)) "mat_data" builder
        in
        let row_ptrs =
          L.build_array_malloc (L.pointer_type float_t)
            (L.const_int i32_t rows) "mat_rows" builder
        in
        List.iteri
          (fun i v ->
            let cell_ptr =
              L.build_gep data [| L.const_int i32_t i |] "mat_cell" builder
            in
            ignore (L.build_store v cell_ptr builder))
          cells;
        for r = 0 to rows - 1 do
          let row_start =
            L.build_gep data [| L.const_int i32_t (r * cols) |] "mat_row" builder
          in
          let row_slot =
            L.build_gep row_ptrs [| L.const_int i32_t r |] "mat_row_slot" builder
          in
          ignore (L.build_store row_start row_slot builder)
        done;
        L.build_call initMatrix_f
          [| row_ptrs; L.const_int i32_t rows; L.const_int i32_t cols |]
          "initMatrix" builder
      | SBinop ((A.Float, _) as e1, op, e2) ->
        let e1' = eval e1 in
        let e2' = eval e2 in
        (match op with
          A.Add     -> L.build_fadd
        | A.Sub     -> L.build_fsub
        | A.Mult    -> L.build_fmul
        | A.Div     -> L.build_fdiv
        | A.Equal   -> L.build_fcmp L.Fcmp.Oeq
        | A.Neq     -> L.build_fcmp L.Fcmp.One
        | A.Less    -> L.build_fcmp L.Fcmp.Olt
        | A.Leq     -> L.build_fcmp L.Fcmp.Ole
        | A.Greater -> L.build_fcmp L.Fcmp.Ogt
        | A.Geq     -> L.build_fcmp L.Fcmp.Oge
        | A.And | A.Or ->
            raise (Failure "internal error: semant should have rejected and/or on float")
        ) e1' e2' "tmp" builder
      | SBinop (e1, op, e2) ->
        let e1' = eval e1 in
        let e2' = eval e2 in
        (match op with
          A.Add     -> L.build_add
        | A.Sub     -> L.build_sub
        | A.Mult    -> L.build_mul
        | A.Div     -> L.build_sdiv
        | A.And     -> L.build_and
        | A.Or      -> L.build_or
        | A.Equal   -> L.build_icmp L.Icmp.Eq
        | A.Neq     -> L.build_icmp L.Icmp.Ne
        | A.Less    -> L.build_icmp L.Icmp.Slt
        | A.Leq     -> L.build_icmp L.Icmp.Sle
        | A.Greater -> L.build_icmp L.Icmp.Sgt
        | A.Geq     -> L.build_icmp L.Icmp.Sge
        ) e1' e2' "tmp" builder
      | SUnop (op, ((t, _) as operand)) ->
        let operand' = eval operand in
        (match op with
          A.Neg when t = A.Float -> L.build_fneg
        | A.Neg                  -> L.build_neg
        | A.Not                  -> L.build_not
        ) operand' "tmp" builder
      | SMatrixAccess (s, e1, e2) ->
        (* NOTE(review): treats [s] as the matrix variable's name, like the
           other accessors below -- confirm against Sast. *)
        let m = load_var s in
        let row = eval e1 in
        let col = eval e2 in
        L.build_call access_f [| m; row; col |] "access" builder
      | SImageRedAccess s -> load_field s 0 "red"
      | SImageGreenAccess s -> load_field s 1 "green"
      | SImageBlueAccess s -> load_field s 2 "blue"
      | SImageGrayscaleAccess s -> load_field s 3 "grayscale"
      | SMatrixRows s -> load_field s 0 "rows"
      | SMatrixCols s -> load_field s 1 "cols"
      (* Built-in functions, matched by name and arity *)
      | SCall ("print", [arg]) ->
        L.build_call printf_func [| int_format_str; eval arg |] "printf" builder
      | SCall ("printf", [arg]) ->
        L.build_call printf_func [| float_format_str; eval arg |] "printf" builder
      | SCall ("image_in", [e1; e2]) ->
        let a1 = eval e1 in
        let a2 = eval e2 in
        L.build_call image_in_f [| a1; a2 |] "image_in" builder
      | SCall ("image_out", [e1; e2; e3]) ->
        let a1 = eval e1 in
        let a2 = eval e2 in
        let a3 = eval e3 in
        (* image_out is declared void, so the call must not be named. *)
        L.build_call image_out_f [| a1; a2; a3 |] "" builder
      | SCall ("convolute", [e1; e2]) ->
        let a1 = eval e1 in
        let a2 = eval e2 in
        L.build_call convolute_f [| a1; a2 |] "convolute" builder
      | SCall ("join", [e1; e2; e3]) ->
        let a1 = eval e1 in
        let a2 = eval e2 in
        let a3 = eval e3 in
        L.build_call join_color_f [| a1; a2; a3 |] "join" builder
      | SCall ("join", [e1]) ->
        L.build_call join_grayscale_f [| eval e1 |] "join" builder
      (* User-defined functions *)
      | SCall (f, args) ->
        let (fdef, callee) = StringMap.find f function_decls in
        let llargs = List.rev (List.map eval (List.rev args)) in
        (* Void calls must not be given a result name. *)
        let result =
          match callee.styp with
          | A.Void -> ""
          | _ -> f ^ "_result"
        in
        L.build_call fdef (Array.of_list llargs) result builder
    in
    
      (* LLVM insists each basic block end with exactly one "terminator" 
        instruction that transfers control.  This function runs "instr builder"
        if the current block does not already have a terminator.  Used,
        e.g., to handle the "fall off the end of the function" case. *)
      let add_terminal builder instr =
        (* Emit [instr] only when the block the builder currently sits in has
           no terminator yet; otherwise leave the block untouched. *)
        let current_block = L.insertion_block builder in
        match L.block_terminator current_block with
        | None -> ignore (instr builder)
        | Some _ -> ()
      in
	
    (* Build the code for the given statement; return the builder for
       the statement's successor (i.e., the next instruction will be built
       after the one generated by this call) *)
    (* [stmt builder lvals s] emits code for [s] starting at [builder], with
       [lvals] as the variables in scope, and returns the builder positioned
       where the following statement should be emitted. *)
    let rec stmt builder lvals s = fst (stmt_scoped (builder, lvals) s)

    (* Worker for [stmt]: additionally returns the variable map that later
       statements of the same block must use, so that declarations become
       visible to the statements that follow them. *)
    and stmt_scoped (builder, lvars) s =
      let eval e = expr builder e lvars in
      let slot_of name =
        try StringMap.find name lvars with Not_found -> lookup name
      in
      (* Declare [name] of type [t]: allocate its slot in the function's entry
         block (so a declaration inside a loop body does not allocate on every
         iteration), bind it, and store the initializer unless it is SNoexpr. *)
      let declare t name init =
        let entry_builder =
          L.builder_at context (L.instr_begin (L.entry_block the_function))
        in
        let slot = L.build_alloca (ltype_of_typ t) name entry_builder in
        let lvars' = StringMap.add name slot lvars in
        (match init with
         | (_, SNoexpr) -> ()
         | _ -> ignore (L.build_store (expr builder init lvars') slot builder));
        (builder, lvars')
      in
      match s with
      | SBlock sl ->
        (* Declarations made inside the block do not escape it. *)
        let (builder', _) = List.fold_left stmt_scoped (builder, lvars) sl in
        (builder', lvars)
      | SExpr e -> ignore (eval e); (builder, lvars)
      | SVariable (t, n, e) -> declare t n e
      | SMatrixAssign (t, n, e, _, _) ->
        (* NOTE(review): stores the value of [e] directly and ignores the
           declared dimensions, assuming [e] already evaluates to a matrix
           (e.g. a matrix literal) -- confirm against Sast and semant. *)
        declare t n e
      | SMatrixAccessAssign (n, e1, e2, e3) ->
        (* NOTE(review): treats [n] as the matrix variable's name -- confirm
           against Sast. *)
        let m = L.build_load (slot_of n) n builder in
        let row = eval e1 in
        let col = eval e2 in
        let value = eval e3 in
        (* accessAssign is declared void, so the call must not be named. *)
        ignore (L.build_call accessAssign_f [| m; row; col; value |] "" builder);
        (builder, lvars)
      | SReturn e ->
        ignore (match fdecl.styp with
                (* Special "return nothing" instr *)
                | A.Void -> L.build_ret_void builder
                (* Build return statement *)
                | _ -> L.build_ret (eval e) builder);
        (builder, lvars)
      | SIf (predicate, then_stmt, else_stmt) ->
        let bool_val = eval predicate in
        let merge_bb = L.append_block context "merge" the_function in
        let build_br_merge = L.build_br merge_bb in (* partial function *)

        let then_bb = L.append_block context "then" the_function in
        add_terminal (stmt (L.builder_at_end context then_bb) lvars then_stmt)
          build_br_merge;

        let else_bb = L.append_block context "else" the_function in
        add_terminal (stmt (L.builder_at_end context else_bb) lvars else_stmt)
          build_br_merge;

        ignore (L.build_cond_br bool_val then_bb else_bb builder);
        (L.builder_at_end context merge_bb, lvars)
      | SWhile (predicate, body) ->
        let pred_bb = L.append_block context "while" the_function in
        ignore (L.build_br pred_bb builder);

        let body_bb = L.append_block context "while_body" the_function in
        add_terminal (stmt (L.builder_at_end context body_bb) lvars body)
          (L.build_br pred_bb);

        let pred_builder = L.builder_at_end context pred_bb in
        let bool_val = expr pred_builder predicate lvars in

        let merge_bb = L.append_block context "merge" the_function in
        ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
        (L.builder_at_end context merge_bb, lvars)
      (* Implement for loops as while loops *)
      | SFor (e1, e2, e3, body) ->
        stmt_scoped (builder, lvars)
          (SBlock [SExpr e1; SWhile (e2, SBlock [body; SExpr e3])])
    in

        (* Build the code for each statement in the function *)
        let builder = stmt builder local_vars (SBlock fdecl.sbody) in

        (* Add a return if the last block falls off the end.  Non-void,
           non-float functions return the zero value of their LLVM type;
           L.const_null is used because L.const_int is only valid for integer
           types, while String, Image and Matrix are pointers. *)
        add_terminal builder (match fdecl.styp with
            A.Void -> L.build_ret_void
          | A.Float -> L.build_ret (L.const_float float_t 0.0)
          | t -> L.build_ret (L.const_null (ltype_of_typ t)))
  in

  (* Generate a body for every declaration, then return the finished module. *)
  let () = List.iter build_function_body decls in
  the_module