(* Code generation: translate takes a semantically checked AST and
produces LLVM IR

LLVM tutorial: Make sure to read the OCaml version of the tutorial

http://llvm.org/docs/tutorial/index.html

Detailed documentation on the OCaml LLVM library:

http://llvm.moe/
http://llvm.moe/ocaml/

Author: Boya Song, Yuli Han
Copyright 2018, MathLight

*)

module L = Llvm
module A = Ast
open Sast

module StringMap = Map.Make(String)

(* translate : Sast.program -> Llvm.module *)
let translate (globals, functions) =
  (* Every LLVM type below is uniqued inside this one global context. *)
  let context = L.global_context () in

  (* The single module that receives all generated globals and functions. *)
  let the_module = L.create_module context "MathLight" in

  (* Scalar LLVM types shared by the whole code generator. *)
  let i32_t = L.i32_type context in
  let i8_t = L.i8_type context in
  let i1_t = L.i1_type context in
  let double_t = L.double_type context in
  let void_t = L.void_type context in

  (* [array_t m n] is the LLVM type [m x [n x double]]: m rows, each an
     array of n doubles. *)
  let array_t m n = L.array_type (L.array_type double_t n) m in

  (* Map a MathLight source type to its LLVM representation.  Strings and
     matrices are handled through pointers to fixed-size arrays:
     [100 x i8]* for strings and [10 x [10 x double]]* for matrices. *)
  let ltype_of_typ ty =
    match ty with
    | A.Int -> i32_t
    | A.Double -> double_t
    | A.Bool -> i1_t
    | A.Void -> void_t
    | A.String -> L.pointer_type (L.array_type i8_t 100)
    | A.Matrix -> L.pointer_type (array_t 10 10)
  in

  (* Storage for global variables; [lookup] falls back to this table when a
     name is not bound locally.  NOTE(review): confirm where this table is
     populated (presumably from [globals]). *)
  let global_vars = Hashtbl.create 12 in

  (* Add a declaration (no body) of an external C function to the module. *)
  let declare_extern name fn_t = L.declare_function name fn_t the_module in

  (* int printf(char *fmt, ...) *)
  let printf_t : L.lltype =
    L.var_arg_function_type i32_t [| L.pointer_type i8_t |] in
  let printf_func : L.llvalue = declare_extern "printf" printf_t in

  (* double sqrt(double) *)
  let sqrt_t : L.lltype = L.function_type double_t [| double_t |] in
  let sqrt_func : L.llvalue = declare_extern "sqrt" sqrt_t in

  (* int abs(int) *)
  let abs_t : L.lltype = L.function_type i32_t [| i32_t |] in
  let abs_func : L.llvalue = declare_extern "abs" abs_t in

  (* double fabs(double) *)
  let fabs_t : L.lltype = L.function_type double_t [| double_t |] in
  let fabs_func : L.llvalue = declare_extern "fabs" fabs_t in

  (* double pow(double, double) *)
  let pow_t : L.lltype = L.function_type double_t [| double_t; double_t |] in
  let pow_func : L.llvalue = declare_extern "pow" pow_t in

  (* double log(double) *)
  let log_t : L.lltype = L.function_type double_t [| double_t |] in
  let log_func : L.llvalue = declare_extern "log" log_t in

  (* NOTE(review): presumably intended to record matrix dimensions; confirm
     where (or whether) it is used before relying on it. *)
  let matrix_size = Hashtbl.create 12 in

  (* Define every user function (signature plus an empty entry block) before
     any body is generated, so a call can refer to a function whose body has
     not been emitted yet.  The map is keyed by function name and pairs the
     LLVM function value with its checked declaration. *)
  let function_decls : (L.llvalue * sfunc_decl) StringMap.t =
    let add_decl decls fdecl =
      let param_types =
        Array.of_list
          (List.map (fun (ty, _, (_, _), _) -> ltype_of_typ ty) fdecl.sargs)
      in
      let fn_type = L.function_type (ltype_of_typ fdecl.styp) param_types in
      let llfn = L.define_function fdecl.sfname fn_type the_module in
      StringMap.add fdecl.sfname (llfn, fdecl) decls
    in
    List.fold_left add_decl StringMap.empty functions
  in

  (* Fill in the body of the given function *)
  let build_function_body fdecl =
    (* The LLVM function created for [fdecl] in [function_decls]; raises
       Not_found if no function with this name was defined there. *)
    let (the_function, _) = StringMap.find fdecl.sfname function_decls in
    (* The builder starts positioned at the end of the function's entry block. *)
    let builder = L.builder_at_end context (L.entry_block the_function) in

    (* printf format strings used by the print builtin:
       "%d\n" for ints/bools, "%s\n" for strings, "%g\n" for doubles,
       "%g " for a single matrix element and "\n" to end a matrix row.
       They are created as new module-level globals (all requested under the
       name "fmt") every time a function body is built; LLVM gives the
       duplicates distinct names. *)
    let int_format_str = L.build_global_stringptr "%d\n" "fmt" builder
    and string_format_str = L.build_global_stringptr "%s\n" "fmt" builder
    and double_format_str = L.build_global_stringptr "%g\n" "fmt" builder
    and matrix_format_str = L.build_global_stringptr "%g " "fmt" builder
    and return_format_str = L.build_global_stringptr "\n" "fmt" builder in
    (* Variable storage is resolved through two tables: [local_vars] for
       names bound in the current function and [global_vars] for globals.
       NOTE(review): [local_vars] starts empty here; it is presumably filled
       by later code as formals and locals are declared -- confirm. *)
    let local_vars = Hashtbl.create 10 in

    (* [lookup n] returns the storage recorded for variable [n], preferring a
       local binding over a global one.  Raises Not_found when [n] is in
       neither table. *)
    let lookup n =
      match Hashtbl.find local_vars n with
      | slot -> slot
      | exception Not_found -> Hashtbl.find global_vars n
    in

    (* [range i j] is the ascending list [i; i+1; ...; j], or [] when i > j.
       The list is built back-to-front with an accumulator, so it uses
       constant stack space even for long ranges (e.g. long string literals). *)
    let range i j =
      let rec build k acc = if k < i then acc else build (k - 1) (k :: acc) in
      build j []
    in

    (* Helpers for values of type [m x [n x double]]* (see [array_t]).
       Dimensions are compile-time facts of the LLVM pointer type, so the
       size queries return plain OCaml ints and emit no IR. *)

    (* Address of element (i, j): getelementptr ptr, 0, i, j. *)
    let get_element_address ptr i j builder =
      L.build_gep ptr
        [| L.const_int i32_t 0; L.const_int i32_t i; L.const_int i32_t j |]
        "tmp" builder
    in

    (* Load element (i, j) through [get_element_address]. *)
    let extract_element ptr i j builder =
      L.build_load (get_element_address ptr i j builder) "tmp" builder
    in

    (* Length of the outer array type that [ptr] points to: the row count of
       a matrix, or the capacity of a [k x i8]* string buffer.  Read straight
       from the type instead of emitting a dead load of the whole aggregate.
       The builder argument is unused and kept only so call sites are
       unchanged. *)
    let get_matrix_row ptr _builder =
      L.array_length (L.element_type (L.type_of ptr))
    in

    (* Length of the inner (row) array type: the column count of a matrix.
       Like [get_matrix_row], this emits no IR. *)
    let get_matrix_col ptr _builder =
      L.array_length (L.element_type (L.element_type (L.type_of ptr)))
    in
    (* Construct code for an expression; return its value *)
    let rec expr builder (((_, (g_row, g_col)), e) : sexpr) = match e with
        SStrLit i  -> let m = String.length i in
                         let str = L.build_alloca (L.array_type i8_t (m + 1)) "res" builder in
                         let idxs = range 0 (m - 1) in
                           List.iter (fun idx -> ignore(L.build_store  (L.const_int i8_t (int_of_char (String.get i idx))) (L.build_gep str [|L.const_int i32_t 0; L.const_int i32_t idx|]  "tmp" builder)  builder)) idxs; ignore(L.build_store  (L.const_int i8_t 0) (L.build_gep str [|L.const_int i32_t 0; L.const_int i32_t m|]  "tmp" builder)  builder); str
      | SIntLit i -> L.const_int i32_t i
      | SBoolLit i -> L.const_int i1_t (if i then 1 else 0)
      | SDoubleLit i -> L.const_float double_t i
      | SMatrixLit s -> let m = Array.length s and n = Array.length (Array.get s 0) in
                         let matrix = L.build_malloc (array_t m n) "res" builder in
                            let row_idxs = range 0 (m - 1) in
                           let col_idxs = range 0 (n-1) in
                           List.iter (fun row_idx -> List.iter (fun idx -> ignore(L.build_store  (L.const_float double_t (Array.get (Array.get s row_idx) idx)) (L.build_gep matrix [|L.const_int i32_t 0; L.const_int i32_t row_idx; L.const_int i32_t idx|]  "tmp" builder)  builder)) col_idxs; ) row_idxs; matrix
      | SMatrixOp (s1, op, s2) -> let s1' = L.build_load (lookup s1) s1 builder and s2' = L.build_load (lookup s2) s2 builder in
                                  let m1 = get_matrix_row s1' builder and
                                  n1 = get_matrix_col s1' builder and
                                  m2 = get_matrix_row s2' builder and
                                  n2 = get_matrix_col s2' builder in
                                  let n = n1 + n2 and m = m1 + m2 in
                                  (match op with
                                  A.Comma -> let result = L.build_malloc (array_t m1 n) "res" builder in
                                    for i = 0 to m1-1 do
                                      for j = 0 to n1+n2-1 do
                                        if j < n1 then
                                          ignore(L.build_store (extract_element s1' i j builder) (get_element_address result i j builder)  builder)
                                        else
                                          ignore(L.build_store (extract_element s2' i (j-n1) builder) (get_element_address result i j builder)  builder)
                                      done
                                    done;
                                    result
                                  | _ -> let result = L.build_malloc (array_t m n1) "res" builder in
                                    for i = 0 to m1 + m2 - 1 do
                                      for j = 0 to n1 do
                                        if i < m1 then
                                          ignore(L.build_store  (extract_element s1' i j builder) (get_element_address result i j builder)  builder)
                                        else
                                          ignore(L.build_store  (extract_element s2' (i-m1) j builder) (get_element_address result i j builder)  builder)
                                      done
                                    done;
                                    result)
      | SBinop (e1, op, e2) ->
        let e1' = expr builder e1 and
        e2' = expr builder e2 in
        let type_of_e1 = L.classify_type (L.type_of e1') and
        type_of_e2 = L.classify_type (L.type_of e2') in
        if type_of_e1 = L.TypeKind.Double && type_of_e2 = L.TypeKind.Double then
        (match op with
          A.Add     -> L.build_fadd e1' e2' "tmp" builder
        | A.Sub     -> L.build_fsub e1' e2' "tmp" builder
        | A.Mult    -> L.build_fmul e1' e2' "tmp" builder
        | A.Div     -> L.build_fdiv e1' e2' "tmp" builder
        | A.Equal   -> L.build_fcmp L.Fcmp.Oeq e1' e2' "tmp" builder
        | A.Neq     -> L.build_fcmp L.Fcmp.One e1' e2' "tmp" builder
        | A.Less    -> L.build_fcmp L.Fcmp.Olt e1' e2' "tmp" builder
        | A.Leq     -> L.build_fcmp L.Fcmp.Ole e1' e2' "tmp" builder
        | A.Greater -> L.build_fcmp L.Fcmp.Ogt e1' e2' "tmp" builder
        | A.Geq     -> L.build_fcmp L.Fcmp.Oge e1' e2' "tmp" builder
        | A.Pow     -> L.build_call pow_func [| e1'; e2' |] "pow" builder
        | _ -> raise (Failure "internal error: semant should have rejected on doubles")
        )  else if type_of_e1 = L.TypeKind.Double && type_of_e2 = L.TypeKind.Integer then
        (match op with
          A.Add     -> L.build_fadd e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Sub     -> L.build_fsub e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Mult    -> L.build_fmul e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Div     -> L.build_fdiv e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Equal   -> L.build_fcmp L.Fcmp.Oeq e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Neq     -> L.build_fcmp L.Fcmp.One e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Less    -> L.build_fcmp L.Fcmp.Olt e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Leq     -> L.build_fcmp L.Fcmp.Ole e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Greater -> L.build_fcmp L.Fcmp.Ogt e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Geq     -> L.build_fcmp L.Fcmp.Oge e1' (L.build_uitofp e2' double_t "tmp" builder) "tmp" builder
        | A.Pow     -> L.build_call pow_func [| e1'; (L.build_uitofp e2' double_t "tmp" builder) |] "pow" builder
        | _ -> raise (Failure "internal error: semant should have rejected")
        ) else if type_of_e1 = L.TypeKind.Integer && type_of_e2 = L.TypeKind.Double then
        (match op with
          A.Add     -> L.build_fadd (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Sub     -> L.build_fsub (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Mult    -> L.build_fmul (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Div     -> L.build_fdiv (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Equal   -> L.build_fcmp L.Fcmp.Oeq (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Neq     -> L.build_fcmp L.Fcmp.One (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Less    -> L.build_fcmp L.Fcmp.Olt (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Leq     -> L.build_fcmp L.Fcmp.Ole (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Greater -> L.build_fcmp L.Fcmp.Ogt (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Geq     -> L.build_fcmp L.Fcmp.Oge (L.build_uitofp e1' double_t "tmp" builder) e2' "tmp" builder
        | A.Pow     -> L.build_call pow_func [| (L.build_uitofp e1' double_t "tmp" builder); e2' |] "pow" builder
        | _ -> raise (Failure "internal error: semant should have rejected")
        ) else if type_of_e1 = L.TypeKind.Integer && type_of_e2 = L.TypeKind.Integer then
        (match op with
          A.Add     -> L.build_add e1' e2' "tmp" builder
        | A.Sub     -> L.build_sub e1' e2' "tmp" builder
        | A.Mult    -> L.build_mul e1' e2' "tmp" builder
        | A.Div     -> L.build_sdiv e1' e2' "tmp" builder
        | A.Equal   -> L.build_icmp L.Icmp.Eq e1' e2' "tmp" builder
        | A.Neq     -> L.build_icmp L.Icmp.Ne e1' e2' "tmp" builder
        | A.Less    -> L.build_icmp L.Icmp.Slt e1' e2' "tmp" builder
        | A.Leq     -> L.build_icmp L.Icmp.Sle e1' e2' "tmp" builder
        | A.Greater -> L.build_icmp L.Icmp.Sgt e1' e2' "tmp" builder
        | A.Geq     -> L.build_icmp L.Icmp.Sge e1' e2' "tmp" builder
        | A.Pow     -> L.build_fptosi (L.build_call pow_func [| (L.build_uitofp e1' double_t "tmp" builder); (L.build_uitofp e2' double_t "tmp" builder) |] "pow" builder) i32_t "tmp" builder
        | A.And     -> L.build_and e1' e2' "tmp" builder
        | A.Or      -> L.build_or e1' e2' "tmp" builder
        | _ -> raise (Failure "internal error: semant should have rejected")
        ) else if (type_of_e1 = L.TypeKind.Double || type_of_e1 = L.TypeKind.Integer) && type_of_e2 = L.TypeKind.Pointer then
        (match op with
          A.Mult    ->   let m = (get_matrix_row e2' builder) and n = (get_matrix_col e2' builder) in
                          let result = L.build_malloc (array_t m n) "res" builder  in
                              for i = 0 to (m - 1) do
                                for j = 0 to (n - 1) do
                                  if type_of_e1 = L.TypeKind.Double then
                                  ignore(L.build_store (L.build_fmul e1' (extract_element e2' i j builder) "val" builder)  (get_element_address result i j builder)  builder)
                                  else ignore(L.build_store (L.build_fmul (L.build_uitofp e1' double_t "tmp" builder) (extract_element e2' i j builder) "val" builder)  (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
          | A.Add   ->    let m = (get_matrix_row e2' builder) and n = (get_matrix_col e2' builder) in
                          let result = L.build_alloca (array_t m n) "res" builder  in
                              for i = 0 to (m - 1) do
                                for j = 0 to (n - 1) do
                                  if type_of_e1 = L.TypeKind.Double then
                                  ignore(L.build_store (L.build_fadd e1' (extract_element e2' i j builder) "val" builder)  (get_element_address result i j builder)  builder)
                                  else ignore(L.build_store (L.build_fadd (L.build_uitofp e1' double_t "tmp" builder) (extract_element e2' i j builder) "val" builder)  (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
          | A.Sub   ->    let m = (get_matrix_row e2' builder) and n = (get_matrix_col e2' builder) in
                          let result = L.build_alloca (array_t m n) "res" builder  in
                              for i = 0 to (m - 1) do
                                for j = 0 to (n - 1) do
                                  if type_of_e1 = L.TypeKind.Double then
                                  ignore(L.build_store (L.build_fsub e1' (extract_element e2' i j builder) "val" builder)  (get_element_address result i j builder)  builder)
                                  else ignore(L.build_store (L.build_fsub (L.build_uitofp e1' double_t "tmp" builder) (extract_element e2' i j builder) "val" builder)  (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
          | A.Div   ->    let m = (get_matrix_row e2' builder) and n = (get_matrix_col e2' builder) in
                          let result = L.build_alloca (array_t m n) "res" builder  in
                              for i = 0 to (m - 1) do
                                for j = 0 to (n - 1) do
                                  if type_of_e1 = L.TypeKind.Double then
                                  ignore(L.build_store (L.build_fdiv e1' (extract_element e2' i j builder) "val" builder)  (get_element_address result i j builder)  builder)
                                  else ignore(L.build_store (L.build_fdiv (L.build_uitofp e1' double_t "tmp" builder) (extract_element e2' i j builder) "val" builder)  (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
         | _         ->    raise (Failure "internal error: semant does not support operation on given type.")
        ) else if type_of_e1 = L.TypeKind.Pointer && (type_of_e2 = L.TypeKind.Double || type_of_e2 = L.TypeKind.Integer) then (
          match op with
          A.Mult -> let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
                          let result = L.build_malloc (array_t m n) "res" builder  in
                              for i = 0 to (m - 1) do
                                for j = 0 to (n - 1) do
                                  if type_of_e2 = L.TypeKind.Double then
                                  ignore(L.build_store (L.build_fmul e2' (extract_element e1' i j builder) "val" builder)  (get_element_address result i j builder)  builder)
                                 else ignore(L.build_store (L.build_fmul (L.build_uitofp e2' double_t "tmp" builder) (extract_element e1' i j builder) "val" builder)  (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
          | A.Add -> let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
                          let result = L.build_alloca (array_t m n) "res" builder  in
                              for i = 0 to (m - 1) do
                                for j = 0 to (n - 1) do
                                  if type_of_e2 = L.TypeKind.Double then
                                  ignore(L.build_store (L.build_fadd e2' (extract_element e1' i j builder) "val" builder)  (get_element_address result i j builder)  builder)
                                 else ignore(L.build_store (L.build_fadd (L.build_uitofp e2' double_t "tmp" builder) (extract_element e1' i j builder) "val" builder)  (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
          | A.Sub -> let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
                          let result = L.build_alloca (array_t m n) "res" builder  in
                              for i = 0 to (m - 1) do
                                for j = 0 to (n - 1) do
                                  if type_of_e2 = L.TypeKind.Double then
                                  ignore(L.build_store (L.build_fsub (extract_element e1' i j builder) e2' "val" builder)  (get_element_address result i j builder)  builder)
                                 else ignore(L.build_store (L.build_fsub (extract_element e1' i j builder) (L.build_uitofp e2' double_t "tmp" builder) "val" builder)  (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
          | A.Div -> let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
                          let result = L.build_alloca (array_t m n) "res" builder  in
                              for i = 0 to (m - 1) do
                                for j = 0 to (n - 1) do
                                  if type_of_e2 = L.TypeKind.Double then
                                  ignore(L.build_store (L.build_fdiv (extract_element e1' i j builder) e2' "val" builder)  (get_element_address result i j builder)  builder)
                                 else ignore(L.build_store (L.build_fdiv (extract_element e1' i j builder) (L.build_uitofp e2' double_t "tmp" builder) "val" builder)  (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
         | _         ->    raise (Failure "internal error: semant does not support operation on given type.")
        )else if type_of_e1 = L.TypeKind.Pointer && type_of_e2 = L.TypeKind.Pointer then
        (match op with
          A.Add     ->    let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
                          let result = L.build_malloc (array_t m n) "res" builder  in
                              for i = 0 to m - 1 do
                                for j = 0 to (n - 1) do
                                  ignore(L.build_store (L.build_fadd (extract_element e1' i j builder) (extract_element e2' i j builder) "tmp" builder) (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
        | A.Sub   ->      let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
                          let result = L.build_malloc (array_t m n) "res" builder  in
                              for i = 0 to m - 1 do
                                for j = 0 to (n - 1) do
                                  ignore(L.build_store (L.build_fsub (extract_element e1' i j builder) (extract_element e2' i j builder) "tmp" builder) (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
        | A.Mult  ->      let matrix1 = L.build_load e1' "tmp" builder and matrix2 = L.build_load e2' "tmp" builder in
                          let m = L.array_length (L.type_of matrix1) and k = L.array_length (L.type_of matrix2) and
                          n = L.array_length (L.type_of (L.build_load (L.build_gep e2' [|L.const_int i32_t 0; L.const_int i32_t 0|] "tmp" builder) "indexing"  builder)) in
                            let result = L.build_malloc (array_t m n) "res" builder  in
                            for i = 0 to m - 1 do
                              for j = 0 to (n - 1) do
                                ignore(L.build_store (L.const_float double_t 0.0) (get_element_address result i j builder) builder);
                                for p = 0 to k - 1 do
                                  let prod = L.build_fmul (extract_element e1' i p builder) (extract_element e2' p j builder) "prod" builder in
                                  let sum = L.build_fadd prod (extract_element result i j builder) "sum" builder in
                                  ignore(L.build_store sum (get_element_address result i j builder) builder);
                                done
                              done
                            done;
                        result
        | A.Dotmul  ->    let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
                          let result = L.build_malloc (array_t m n) "res" builder  in
                              for i = 0 to m - 1 do
                                for j = 0 to (n - 1) do
                                  ignore(L.build_store (L.build_fmul (extract_element e1' i j builder) (extract_element e2' i j builder) "tmp" builder) (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
        | A.Dotdiv  ->    let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
                          let result = L.build_malloc (array_t m n) "res" builder  in
                              for i = 0 to m - 1 do
                                for j = 0 to (n - 1) do
                                  ignore(L.build_store (L.build_fdiv (extract_element e1' i j builder) (extract_element e2' i j builder) "tmp" builder) (get_element_address result i j builder)  builder);
                                done
                              done;
                          result
        | _ -> raise (Failure "internal error: semant should have rejected")
        ) else raise (Failure "internal error: semant does not support operation on given type.")
      | SId s       -> L.build_load (lookup s) s builder
      | SAssign (s, (((typ, _), _) as e)) -> let e' = expr builder e in
                           (match typ with
                           A.String -> let m = get_matrix_row e' builder in
                            let dest = L.build_bitcast (lookup s) (L.pointer_type (L.pointer_type (L.array_type i8_t m))) "tmp" builder in
                             ignore(L.build_store e' dest builder); e'
                           | _ -> ignore(L.build_store e' (lookup s) builder); e')
      | SNoexpr     -> L.const_int i32_t 0
      | SMatrix1DElement (name, e) -> let matrix_ptr = (L.build_load (lookup name) name builder) and e' = expr builder e in
                    (L.build_load (L.build_gep matrix_ptr [|L.const_int i32_t 0; L.const_int i32_t 0; e'|] "tmp" builder) "indexing"  builder)
      | SMatrix2DElement (name, e1, e2) -> let matrix_ptr = (L.build_load (lookup name) name builder) and e1' = expr builder e1 and e2' = expr builder e2 in
                    (L.build_load (L.build_gep matrix_ptr [|L.const_int i32_t 0; e1'; e2'|] "tmp" builder) "indexing" builder)
      | SMatrix1DModify (name, e1, e2) -> let matrix_ptr = (L.build_load (lookup name) name builder) and e1' = expr builder e1 and                                  e2' = expr builder e2 in
                                      (L.build_store e2' (L.build_gep matrix_ptr [|L.const_int i32_t 0; e1'|] "tmp" builder) builder)
      | SMatrix2DModify (name, (e_row, e_col), e) -> let matrix_ptr = (L.build_load (lookup name) name builder) and er' = expr builder e_row and ec' = expr builder e_col and e' = expr builder e in
                    (L.build_store e' (L.build_gep matrix_ptr [|L.const_int i32_t 0; er'; ec'|] "tmp" builder) builder)
      | SUnop(op, (((t, _), _) as e)) ->
        let e' = expr builder e in
          (match op with
            A.Neg when t = A.Double -> L.build_fneg e' "tmp" builder
            | A.Neg                  -> L.build_neg e' "tmp" builder
            | A.Not                  -> L.build_not e' "tmp" builder
            | A.Abs when t = A.Double -> L.build_call fabs_func [| e' |] "fabs" builder
            | A.Abs -> L.build_call abs_func [| e' |] "abs" builder
            | A.Transpose when t = A.Matrix -> let original = L.build_load e' "tmp" builder in
                                 let m = L.array_length (L.type_of original) and n = L.array_length (L.type_of (L.build_load (L.build_gep e' [|L.const_int i32_t 0; L.const_int i32_t 0|] "tmp" builder) "indexing"  builder)) in
                                 let row_idxs = range 0 (n - 1) in
                           let col_idxs = range 0 (m-1) in
                            let matrix = L.build_malloc (array_t n m) "res" builder in
                           List.iter (fun row_idx -> List.iter (fun idx -> ignore(L.build_store (extract_element e' idx row_idx builder) (L.build_gep matrix [|L.const_int i32_t 0; L.const_int i32_t row_idx; L.const_int i32_t idx|]  "tmp" builder)  builder)) col_idxs; ) row_idxs; matrix
            | _ -> raise (Failure "internal error: semant should have rejected."))
      | SCall ("print", [e]) -> let expr_value = expr builder e in (match e with (((typ, _), _) : sexpr) ->
           (match typ with
         | A.Matrix -> let matrix = L.build_load expr_value "tmp" builder in
                                 let m = L.array_length (L.type_of matrix) and n = L.array_length (L.type_of (L.build_load (L.build_gep expr_value [|L.const_int i32_t 0; L.const_int i32_t 0|] "tmp" builder) "indexing"  builder)) in
                                 let row_idxs = range 0 (m - 1) in
                           let col_idxs = range 0 (n-1) in
                           List.iter (fun row_idx -> List.iter (fun idx -> ignore(L.build_call printf_func [| matrix_format_str ; (L.build_load (L.build_gep expr_value [|L.const_int i32_t 0; L.const_int i32_t row_idx; L.const_int i32_t idx|] "tmp" builder) "indexing"  builder) |] "printf" builder)) col_idxs; ignore(L.build_call printf_func [| return_format_str |] "printf" builder)) row_idxs; L.const_int i32_t 0
         | A.Int | A.Bool -> L.build_call printf_func [| int_format_str ; (expr_value) |] "printf" builder
         | A.Double -> L.build_call printf_func [| double_format_str ; (expr_value) |] "printf" builder
         | A.String -> L.build_call printf_func [| string_format_str; (expr builder e) |] "printf" builder
         | A.Void -> raise (Failure ("Void type cannot be printed"))))
      | SCall ("sqrt", [e]) -> L.build_call sqrt_func [| (expr builder e) |] "sqrt" builder
      | SCall ("log", [e]) -> L.build_call log_func [| (expr builder e) |] "log" builder
      | SCall ("fill", [(_, row); (_, col); value]) -> (match (row, col) with
        (SIntLit m, SIntLit n) -> let v = expr builder value in
                         let matrix = L.build_malloc (array_t m n) "res" builder in
                            let row_idxs = range 0 (m - 1) in
                           let col_idxs = range 0 (n-1) in
                           List.iter (fun row_idx -> List.iter (fun idx -> ignore(L.build_store v (L.build_gep matrix [|L.const_int i32_t 0; L.const_int i32_t row_idx; L.const_int i32_t idx|]  "tmp" builder)  builder)) col_idxs; ) row_idxs; matrix
        | _ -> raise (Failure ("The first two arguments of function fill must be integer")))
      | SCall ("det", [e]) -> let e' = expr builder e in
          let m = (get_matrix_row e' builder) and n = (get_matrix_col e' builder) in
          let matrix_pointer = L.pointer_type (array_t m n) in
          let determinant_t : L.lltype = L.function_type double_t [| matrix_pointer; i32_t; i32_t |] in
          let determinant_func : L.llvalue = L.declare_function "determinant" determinant_t the_module in
          if m <> n then raise (Failure ("Cannot compute determinant for given matrix. Row size and col size must be same."))
          else L.build_call determinant_func [| e'; (L.const_int i32_t m); (L.const_int i32_t n)|] "determinant" builder
      | SCall ("max_eigvalue", [e]) -> let e' = expr builder e in
          let m = (get_matrix_row e' builder) and n = (get_matrix_col e' builder) in
          let matrix_pointer = L.pointer_type (array_t m n) in
          let max_eigvalue_t : L.lltype = L.function_type double_t [| matrix_pointer; i32_t; i32_t |] in
          let max_eigvalue_func : L.llvalue = L.declare_function "max_eigvalue" max_eigvalue_t the_module in
          if m <> n then raise (Failure ("Codegen: Cannot compute the max eigvalue for given matrix."))
          else L.build_call max_eigvalue_func [| e'; (L.const_int i32_t m); (L.const_int i32_t n)|] "max_eigvalue" builder
      | SCall ("norm1", [e]) -> let e' = expr builder e in
          let m = (get_matrix_row e' builder) and n = (get_matrix_col e' builder) in
          let matrix_pointer = L.pointer_type (array_t m n) in
          let norm1_t : L.lltype = L.function_type double_t [| matrix_pointer; i32_t; i32_t |] in
          let norm1_func : L.llvalue = L.declare_function "norm1" norm1_t the_module in
          L.build_call norm1_func [| e'; (L.const_int i32_t m); (L.const_int i32_t n)|] "norm1" builder
      | SCall ("norm2", [e]) -> let e' = expr builder e in
          let m = (get_matrix_row e' builder) and n = (get_matrix_col e' builder) in
          let matrix_pointer = L.pointer_type (array_t m n) in
          let norm2_t : L.lltype = L.function_type double_t [| matrix_pointer; i32_t; i32_t |] in
          let norm2_func : L.llvalue = L.declare_function "norm2" norm2_t the_module in
          L.build_call norm2_func [| e'; (L.const_int i32_t m); (L.const_int i32_t n)|] "norm2" builder
      | SCall ("tr", [e1]) -> let e1' = expr builder e1  in
          let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
          let result = (L.build_alloca (array_t 1 1) "res" builder) in
            if m <> n then raise (Failure ("Cannot compute trace for given matrix. Row size and col size must be same."))
           else (
            ignore(L.build_store (L.const_float double_t 0.0) (get_element_address result 0 0 builder) builder);
            for i = 0 to m - 1 do
              ignore(L.build_store (L.build_fadd (extract_element e1' i i builder) (extract_element result 0 0 builder) "tmp" builder) (get_element_address result 0 0 builder) builder);
            done;
            (extract_element result 0 0 builder))
      | SCall ("sum_row", [e1]) -> let e1' = expr builder e1  in
          let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
          let result = (L.build_malloc (array_t m 1) "res" builder) in
            for i = 0 to m - 1 do
                ignore(L.build_store (L.const_float double_t 0.0) (get_element_address result i 0 builder) builder);
                for j = 0 to n - 1 do
                  ignore(L.build_store (L.build_fadd (extract_element e1' i j builder) (extract_element result i 0 builder) "tmp" builder) (get_element_address result i 0 builder) builder);
                done
            done;
            result
      | SCall ("sum_col", [e1]) -> let e1' = expr builder e1  in
          let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
          let result = (L.build_malloc (array_t 1 n) "res" builder) in
          for j = 0 to n - 1 do
              ignore(L.build_store (L.const_float double_t 0.0) (get_element_address result 0 j builder) builder);
              for i = 0 to m - 1 do
                ignore(L.build_store (L.build_fadd (extract_element e1' i j builder) (extract_element result 0 j builder) "tmp" builder) (get_element_address result 0 j builder) builder);
              done
          done;
          result
      | SCall ("sizeof_row", [e1]) -> let e1' = expr builder e1  in
          let m = (get_matrix_row e1' builder) in
          let result = (L.build_alloca (L.array_type (L.array_type i32_t 1) 1) "res" builder) in
          ignore(L.build_store (L.const_int i32_t m) (get_element_address result 0 0 builder) builder);
          (extract_element result 0 0 builder)
      | SCall ("sizeof_col", [e1]) -> let e1' = expr builder e1  in
          let m = (get_matrix_col e1' builder) in
          let result = (L.build_alloca (L.array_type (L.array_type i32_t 1) 1) "res" builder) in
          ignore(L.build_store (L.const_int i32_t m) (get_element_address result 0 0 builder) builder);
          (extract_element result 0 0 builder)
      | SCall ("mean_row", [e1]) -> let e1' = expr builder e1  in
          let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
          let result = (L.build_malloc (array_t m 1) "res" builder)  and len = (L.build_alloca double_t "s" builder) in
            for i = 0 to m - 1 do
                ignore(L.build_store (L.const_float double_t 0.0) (get_element_address result i 0 builder) builder);
                ignore(L.build_store (L.const_float double_t 0.0) len builder);
                for j = 0 to n - 1 do
                  ignore(L.build_store (L.build_fadd (extract_element e1' i j builder) (extract_element result i 0 builder) "tmp" builder) (get_element_address result i 0 builder) builder);
                  ignore(L.build_store (L.build_fadd (L.build_load len "len" builder) (L.const_float double_t 1.0) "add" builder) len builder);
                done;
                ignore(L.build_store (L.build_fdiv (extract_element result i 0 builder) (L.build_load len "l" builder) "div" builder) (get_element_address result i 0 builder) builder);
            done;
            result
      | SCall ("mean_col", [e1]) -> let e1' = expr builder e1  in
          let m = (get_matrix_row e1' builder) and n = (get_matrix_col e1' builder) in
          let result = (L.build_malloc (array_t 1 n) "res" builder) and len = (L.build_alloca double_t "s" builder) in
            for j = 0 to n - 1 do
                ignore(L.build_store (L.const_float double_t 0.0) (get_element_address result 0 j builder) builder);
                ignore(L.build_store (L.const_float double_t 0.0) len builder);
                for i = 0 to m - 1 do
                  ignore(L.build_store (L.build_fadd (extract_element e1' i j builder) (extract_element result 0 j builder) "tmp" builder) (get_element_address result 0 j builder) builder);
                  ignore(L.build_store (L.build_fadd (L.build_load len "len" builder) (L.const_float double_t 1.0) "add" builder) len builder);
                done;
                ignore(L.build_store (L.build_fdiv (extract_element result 0 j builder) (L.build_load len "l" builder) "tmp" builder) (get_element_address result 0 j builder) builder);
            done;
          result
      | SCall ("inv", [e]) -> let e' = expr builder e in
          let m = (get_matrix_row e' builder) and n = (get_matrix_col e' builder) in
          if m <> n then raise (Failure ("Codegen: Cannot compute inverse for given matrix. Row size and col size must be same."))
          else (
          let matrix_pointer = L.pointer_type (array_t m n) in
          let inverse_t : L.lltype = L.function_type matrix_pointer [| matrix_pointer; i32_t|] in
          let inverse_func : L.llvalue = L.declare_function "inverse" inverse_t the_module in
          L.build_call inverse_func [| e'; (L.const_int i32_t m)|] "inverse" builder)
      | SCall (f, args) ->
        let (fdef, fdecl) = StringMap.find f function_decls in
           let cast value ((t, _), _) =
              (match t with
              A.Matrix -> L.build_bitcast value (ltype_of_typ t) "tmp" builder
            |_ -> value) in
           let llargs = List.rev (List.map (expr builder) (List.rev args)) in
           let save_to_map e (t, n, _, _) = (match t with
           A.Matrix -> (let e' = expr builder e in
              let row = get_matrix_row e' builder and col = get_matrix_col e' builder in
              Hashtbl.add matrix_size (String.concat "" [fdecl.sfname; n]) (row, col))
          |_ -> ()) in
           let llargs_cast = ignore(List.iter2 save_to_map args fdecl.sargs); List.map2 cast llargs args in
           let result = (match fdecl.styp with
              A.Void -> ""
            | _ -> f ^ "_result") in
        let result = L.build_call fdef (Array.of_list llargs_cast) result builder in
        (match fdecl.styp with
        A.Matrix -> (L.build_bitcast result (L.pointer_type (array_t g_row g_col)) "tmp" builder)
        | _ -> result)
      | _ -> raise (Failure ("Codegen: expr not supported"))
    in

(* Define each global variable in the module and record it in global_vars *)
  let global_var (t, n, (row_size, col_size), ((_, e) as e')) =
    (* Define the LLVM global for one declaration [(t, n, (rows, cols), init)]
       and record it in [global_vars] under its source name [n].
       - With no initializer (SNoexpr) the global starts as: the constant
         string "0" for String, 0.0 for Double, a null pointer to a
         (rows x cols) array of doubles for Matrix, integer 0 otherwise.
       - Otherwise the initializer is lowered with [expr].  NOTE(review):
         L.define_global expects a constant, so this presumably relies on
         semant only accepting constant initializers - confirm.
       This binding lives inside build_function_body, so the List.iter below
       runs once per function.  Re-defining an existing global would make
       LLVM create a fresh, renamed global each time, and each function would
       then use its own private copy.  So a global already present in
       [global_vars] is reused instead of being defined again. *)
    if not (Hashtbl.mem global_vars n) then begin
      let init =
        match e with
        | SNoexpr ->
            (match t with
             | A.String -> L.const_string context "0"
             | A.Double -> L.const_float (ltype_of_typ t) 0.0
             | A.Matrix ->
                 L.const_pointer_null (L.pointer_type (array_t row_size col_size))
             | _ -> L.const_int (ltype_of_typ t) 0)
        | _ -> expr builder e'
      in
      Hashtbl.add global_vars n (L.define_global n init the_module)
    end
  in
  List.iter global_var globals;

let add_formal (t, n, (_, _), _) p =
    (* Spill formal parameter [p] (declared as [t n]) into a stack slot and
       register that slot in [local_vars] under [n].
       A Matrix parameter is first bitcast to a pointer to a (rows x cols)
       array of doubles.  The dimensions are read from [matrix_size] under
       the key fdecl.sfname ^ n; the SCall case of [expr] records them at
       call sites.  Hashtbl.find raises Not_found if nothing was recorded.
       NOTE(review): that can presumably happen when a matrix-taking
       function's body is built before any of its callers - confirm. *)
    let size_key = String.concat "" [fdecl.sfname; n] in
    let incoming, slot_ty =
      match t with
      | A.Matrix ->
          let (rows, cols) = Hashtbl.find matrix_size size_key in
          let mat_ptr_ty = L.pointer_type (array_t rows cols) in
          (L.build_bitcast p mat_ptr_ty "tmp" builder, mat_ptr_ty)
      | _ -> (p, ltype_of_typ t)
    in
    (* Name the incoming value after the source parameter before the alloca
       so the IR reads naturally. *)
    L.set_value_name n incoming;
    let slot = L.build_alloca slot_ty n builder in
    ignore (L.build_store incoming slot builder);
    Hashtbl.add local_vars n slot
  in
  Array.to_list (L.params the_function) |> List.iter2 add_formal fdecl.sargs;

      (* Allocate space for any locally declared variables and add the
       * resulting registers to our map *)
  let add_local (t, n, (row_size, col_size), ((_, e) as e')) =
    (* Allocate a stack slot for the local declaration
       [(t, n, (rows, cols), init)], store the initializer when there is one,
       and register the slot in [local_vars] under [n].
       A Matrix local holds a pointer to a (rows x cols) array of doubles.
       A String initializer is stored through a bitcast of the slot to a
       pointer to an [len x i8]-array pointer, where [len] is taken from
       get_matrix_row applied to the initializer value. *)
    let slot =
      match t with
      | A.Matrix ->
          L.build_alloca (L.pointer_type (array_t row_size col_size)) n builder
      | _ -> L.build_alloca (ltype_of_typ t) n builder
    in
    (match e with
     | SNoexpr -> ()
     | _ ->
         let init_val = expr builder e' in
         (match t with
          | A.String ->
              let len = get_matrix_row init_val builder in
              let str_slot_ty =
                L.pointer_type (L.pointer_type (L.array_type i8_t len))
              in
              let dest = L.build_bitcast slot str_slot_ty "tmp" builder in
              ignore (L.build_store init_val dest builder)
          | _ -> ignore (L.build_store init_val slot builder)));
    Hashtbl.add local_vars n slot
  in
  List.iter add_local fdecl.slocals;

    (* LLVM insists each basic block end with exactly one "terminator"
       instruction that transfers control.  This function runs "instr builder"
       if the current block does not already have a terminator.  Used,
       e.g., to handle the "fall off the end of the function" case. *)
    let add_terminal builder instr =
      (* Emit [instr builder] only when the builder's current block has no
         terminator yet; otherwise leave the block untouched. *)
      let current_block = L.insertion_block builder in
      match L.block_terminator current_block with
      | None -> ignore (instr builder)
      | Some _ -> ()
    in

    (* Build the code for the given statement; return the builder for
       the statement's successor (i.e., the next instruction will be built
       after the one generated by this call) *)

    let rec stmt builder = function
      (* Lower one statement starting at [builder].  Return the builder where
         the statement's successor should be emitted.
         - A block is folded left to right; an expression statement is
           emitted for its side effects.
         - return emits ret void, or ret of the value.  A matrix value is
           first bitcast to ltype_of_typ A.Matrix.  The same builder is
           returned.
         - if/while append basic blocks to the_function and return a builder
           at the end of a new "merge" block.
         - for is rewritten as: init; while (cond) { body; step }.
         Any other statement raises Failure. *)
      | SBlock body -> List.fold_left stmt builder body
      | SExpr e ->
          ignore (expr builder e);
          builder
      | SReturn e ->
          (match fdecl.styp with
           | A.Void -> ignore (L.build_ret_void builder)
           | A.Matrix ->
               let ret_val = expr builder e in
               let as_matrix =
                 L.build_bitcast ret_val (ltype_of_typ A.Matrix) "tmp" builder
               in
               ignore (L.build_ret as_matrix builder)
           | _ -> ignore (L.build_ret (expr builder e) builder));
          builder
      | SIf (predicate, then_stmt, else_stmt) ->
          stmt_if builder predicate then_stmt else_stmt
      | SWhile (predicate, body) -> stmt_while builder predicate body
      | SFor (init, cond, step, body) ->
          stmt builder
            (SBlock [ SExpr init; SWhile (cond, SBlock [ body; SExpr step ]) ])
      | _ -> raise (Failure "Codegen: Other stmts not supported yet.")

    (* if/else: the blocks are appended in the order merge, then, else.
       Each branch falls through to merge unless it already ends in a
       terminator.  The conditional branch goes at the end of the
       incoming block. *)
    and stmt_if builder predicate then_stmt else_stmt =
      let cond = expr builder predicate in
      let merge_bb = L.append_block context "merge" the_function in
      let branch_to_merge = L.build_br merge_bb in
      let then_bb = L.append_block context "then" the_function in
      add_terminal (stmt (L.builder_at_end context then_bb) then_stmt)
        branch_to_merge;
      let else_bb = L.append_block context "else" the_function in
      add_terminal (stmt (L.builder_at_end context else_bb) else_stmt)
        branch_to_merge;
      ignore (L.build_cond_br cond then_bb else_bb builder);
      L.builder_at_end context merge_bb

    (* while: the blocks are appended in the order while (the predicate),
       while_body, merge.  The body loops back to the predicate block, which
       branches to the body or to merge. *)
    and stmt_while builder predicate body =
      let pred_bb = L.append_block context "while" the_function in
      ignore (L.build_br pred_bb builder);
      let body_bb = L.append_block context "while_body" the_function in
      add_terminal (stmt (L.builder_at_end context body_bb) body)
        (L.build_br pred_bb);
      let pred_builder = L.builder_at_end context pred_bb in
      let cond = expr pred_builder predicate in
      let merge_bb = L.append_block context "merge" the_function in
      ignore (L.build_cond_br cond body_bb merge_bb pred_builder);
      L.builder_at_end context merge_bb
    in

    (* Build the code for each statement in the function *)
    let builder = stmt builder (SBlock fdecl.sbody) in

    (* If the last block can fall off the end, add a default return.  The
       returned constant must have the function's LLVM return type:
       - Matrix: a null value of ltype_of_typ A.Matrix.  This is the type
         SReturn bitcasts explicit matrix returns to.  A "null pointer" of
         the array type itself is not a valid constant.
       - String: a null pointer of ltype_of_typ A.String.
       - Double: 0.0.  An integer constant of a floating type is invalid.
       - Int/Bool (and any other type): integer 0 of ltype_of_typ t.
       NOTE(review): this assumes the function was declared with
       ltype_of_typ of its return type.  That declaration is outside this
       view - confirm. *)
    add_terminal builder (match fdecl.styp with
        | A.Void -> L.build_ret_void
        | (A.Matrix | A.String) as t ->
            L.build_ret (L.const_pointer_null (ltype_of_typ t))
        | A.Double -> L.build_ret (L.const_float double_t 0.0)
        | t -> L.build_ret (L.const_int (ltype_of_typ t) 0))
  in

  (* Emit a body for every function, then hand back the finished module. *)
  let () = List.iter build_function_body functions in
  the_module
