(* Code generation: translate takes a semantically checked AST and
produces LLVM IR

LLVM tutorial: Make sure to read the OCaml version of the tutorial

http://llvm.org/docs/tutorial/index.html

Detailed documentation on the OCaml LLVM library:

http://llvm.moe/
http://llvm.moe/ocaml/

*)

module L = Llvm
module A = Ast
open Sast 

module StringMap = Map.Make(String)

(* translate : Sast.sstmt -> Llvm.llmodule *)
let translate (sstmt) =
  (* Every LLVM object below is created in the global context. *)
  let context = L.global_context () in

  (* The compilation module that receives all generated code; it is the
     value returned by [translate]. *)
  let the_module = L.create_module context "MQL" in

  (* LLVM types used throughout code generation. *)
  let i32_t = L.i32_type context in
  let i8_t = L.i8_type context in
  let i1_t = L.i1_type context in
  let float_t = L.float_type context in
  let str_t = L.pointer_type i8_t in
  let void_t = L.void_type context in

  (* [array_t typ length] is the LLVM array type of [length] [typ]s. *)
  let array_t = L.array_type in

  (* Byte width charged for each primitive column when computing column
     offsets; non-primitive types count as 0.
     NOTE(review): Bool counts as 4 bytes although its LLVM type is i1;
     presumably this mirrors the runtime's row layout — confirm. *)
  let size_of_primitive t =
    match t with
    | A.Int | A.Bool | A.Float -> 4
    | A.String -> 8
    | _ -> 0
  in

  (* Map an MQL type to its LLVM representation.  Strings and columns are
     i8 pointers.  A table is a two-field struct: a pointer to a row struct
     whose fields are the column types, followed by an i32 (presumably the
     row count — confirm against the runtime). *)
  let rec ltype_of_typ typ =
    match typ with
    | A.Int -> i32_t
    | A.Bool -> i1_t
    | A.Float -> float_t
    | A.Void -> void_t
    | A.String | A.Column -> str_t
    | A.Table decl -> table_t (List.map fst decl)
  (* [table_t column_types] builds the table struct for the given columns. *)
  and table_t column_types =
    let row_fields = Array.of_list (List.map ltype_of_typ column_types) in
    let row_ptr_t = L.pointer_type (L.struct_type context row_fields) in
    L.struct_type context [| row_ptr_t; i32_t |]
  in
  (* External C [printf]: int printf(char *fmt, ...). *)
  let printf_t : L.lltype = L.var_arg_function_type i32_t [| str_t |] in
  let printf_func : L.llvalue = L.declare_function "printf" printf_t the_module in
  (* Runtime [string_concat]: two string parameters, string result.
     NOTE(review): declared variadic here; if the runtime defines it as a
     plain two-argument function, a non-variadic type would match better. *)
  let string_concat_t : L.lltype =
    L.var_arg_function_type str_t [| str_t; str_t |] in
  let string_concat_func : L.llvalue =
    L.declare_function "string_concat" string_concat_t the_module
  in
  (* All generated code goes into [int main()], starting at its entry block;
     [f] is that function and [builder] points at its end. *)
  let f = L.define_function "main" (L.function_type i32_t [||]) the_module in
  let builder = L.builder_at_end context (L.entry_block f) in

  (* printf formats for one float / one int followed by a newline.
     NOTE(review): float_format_str is not referenced in the visible code;
     floats are printed via float_to_string instead. *)
  let float_format_str = L.build_global_stringptr "%f\n" "fmt" builder in
  let int_format_str = L.build_global_stringptr "%d\n" "fmt" builder in

  (* Lower a literal cell to an LLVM value.  Numbers and booleans become
     constants; a string literal becomes a global string and the value is a
     pointer to it. *)
  let cell builder (c : scell) =
    match c with
    | SIntLit i -> L.const_int i32_t i
    | SStrLit s -> L.build_global_stringptr s "" builder
    | SBoolLit b -> L.const_int i1_t (if b then 1 else 0)
    | SFloatLit x -> L.const_float float_t x
  in

  (* Return the storage bound to variable [n] in map [m].
     Raises [Failure "Variable never assigned value"] if [n] is unbound. *)
  let lookup n m =
    try StringMap.find n m
    with Not_found -> failwith "Variable never assigned value"
  in

  (* [calculate_offset current_offset old_cols new_col] returns the byte
     offset of column [new_col] (a [(typ, name)] pair) within a row laid out
     as [old_cols], counting from [current_offset].  Columns are matched by
     name; every column before the match contributes [size_of_primitive] of
     its type.
     Raises [Failure "Column not exist!"] if no column has that name. *)
  let rec calculate_offset current_offset old_cols new_col =
    match old_cols with
    | [] -> raise (Failure "Column not exist!")
    | (col_typ, col_name) :: rest ->
      (* Structural equality: names are strings, and physical equality [==]
         only holds when both sides are the very same heap object. *)
      if col_name = snd new_col then current_offset
      else
        calculate_offset (current_offset + size_of_primitive col_typ) rest new_col
  in

  (* Emit the terminator built by [instr] at [builder]'s current block,
     unless that block already ends in a terminator. *)
  let add_terminal builder instr =
    let block = L.insertion_block builder in
    match L.block_terminator block with
    | None -> ignore (instr builder)
    | Some _ -> ()
  in

  (* Declarations of the table runtime's entry points.  Each [*_t] builds
     the LLVM function type for the given column types ([src] = input table,
     [dst] = result table); [len] is the element count of the arrays passed
     by pointer.  Each [*_func] declares the external symbol with that type. *)
  let create_table_t cols : L.lltype =
    L.function_type (table_t cols) [| i32_t |] in
  let create_table_func cols : L.llvalue =
    L.declare_function "create_table" (create_table_t cols) the_module in

  let import_table_t cols len : L.lltype =
    let type_names = L.pointer_type (array_t str_t len) in
    L.function_type (table_t cols) [| str_t; i32_t; type_names |] in
  let import_table_func cols len : L.llvalue =
    L.declare_function "import_table" (import_table_t cols len) the_module in

  let select_table_t src dst len : L.lltype =
    let offsets = L.pointer_type (array_t i32_t len) in
    let type_names = L.pointer_type (array_t str_t len) in
    L.function_type (table_t dst) [| table_t src; offsets; type_names; i32_t |] in
  let select_table_func src dst len : L.llvalue =
    L.declare_function "select_table" (select_table_t src dst len) the_module in

  let insert_table_t cols len : L.lltype =
    let strings = L.pointer_type (array_t str_t len) in
    L.function_type (table_t cols) [| table_t cols; strings; strings; i32_t |] in
  let insert_table_func cols len : L.llvalue =
    L.declare_function "insert_table" (insert_table_t cols len) the_module in

  let distinct_table_t src dst len : L.lltype =
    let offsets = L.pointer_type (array_t i32_t len) in
    let type_names = L.pointer_type (array_t str_t len) in
    L.function_type (table_t dst) [| table_t src; offsets; type_names; i32_t |] in
  let distinct_table_func src dst len : L.llvalue =
    L.declare_function "distinct_table" (distinct_table_t src dst len) the_module in
  
  (* [print_table] runtime entry point: called with the table, a pointer to
     its column type names, and the column count; returns i32.  The type
     builder is named [print_table_t] like the other [*_t] builders; it used
     to share the name [print_table_func] and be shadowed immediately. *)
  let print_table_t cols len : L.lltype =
    L.function_type i32_t [| table_t cols; L.pointer_type (array_t str_t len); i32_t |] in
  let print_table_func cols len : L.llvalue =
    L.declare_function "print_table" (print_table_t cols len) the_module in

  (* [where_table] / [delete_table]: called with the table, an i32 column
     byte offset, then three strings (operator, comparison value, column
     type name); both return a table of the same columns. *)
  let where_table_t cols : L.lltype =
    L.function_type (table_t cols) [| table_t cols; i32_t; str_t; str_t; str_t |] in
  let where_table_func cols : L.llvalue =
    L.declare_function "where_table" (where_table_t cols) the_module in
  let delete_table_t cols : L.lltype =
    L.function_type (table_t cols) [| table_t cols; i32_t; str_t; str_t; str_t |] in
  let delete_table_func cols : L.llvalue =
    L.declare_function "delete_table" (delete_table_t cols) the_module in

  (* Runtime conversion helpers, used when printing scalars and when passing
     scalars to the table runtime as strings.  Declared in the same order as
     before so the module's declaration order is unchanged. *)
  let int_to_string_t : L.lltype = L.function_type str_t [| i32_t |] in
  let bool_to_string_t : L.lltype = L.function_type str_t [| i1_t |] in
  let bool_to_int_t : L.lltype = L.function_type i32_t [| i1_t |] in
  let float_to_string_t : L.lltype = L.function_type str_t [| float_t |] in
  let declare_runtime name fn_t = L.declare_function name fn_t the_module in
  let int_to_string_func : L.llvalue = declare_runtime "int_to_string" int_to_string_t in
  let bool_to_string_func : L.llvalue = declare_runtime "bool_to_string" bool_to_string_t in
  let bool_to_int_func : L.llvalue = declare_runtime "bool_to_int" bool_to_int_t in
  let float_to_string_func : L.llvalue = declare_runtime "float_to_string" float_to_string_t in
   
  (* Lower a table-valued expression [st] to an LLVM value of the table
     struct type (see [table_t]), emitting instructions at [builder].  [m]
     maps variable names to their storage.  Scalar operands are lowered with
     [expr]; any bindings they add to the map are not propagated.
     Raises [Failure] when an operand is not a table, when a referenced
     column does not exist, for unsupported element types, and ("TO DO")
     for table expressions not handled here. *)
  let rec table_expr builder m (st: stable) = (
    (* The [(typ, name)] column list of table type [typ]; [what] names the
       operation in the error raised for a non-table. *)
    let table_columns what typ =
      match typ with
      | A.Table cols -> cols
      | _ -> raise (Failure (what ^ " can only be applied to a table"))
    in
    (* Store a constant array of pointers to global copies of [strs] into a
       fresh stack slot named [name]; returns the slot's address. *)
    let store_string_array name strs =
      let to_string_ptr s = L.build_global_stringptr s s builder in
      let arr = L.const_array str_t (Array.of_list (List.map to_string_ptr strs)) in
      let slot = L.build_alloca (array_t str_t (List.length strs)) name builder in
      ignore (L.build_store arr slot builder);
      slot
    in
    (* Convert [v], a scalar of MQL type [typ], to a string value for the
       filter runtime functions.
       NOTE(review): SInsert below stores booleans via bool_to_int +
       int_to_string, whereas this uses bool_to_string; if the two produce
       different text, filters on bool columns may never match — confirm
       against the runtime. *)
    let scalar_to_string typ v =
      match typ with
      | A.String -> v
      | A.Int -> L.build_call int_to_string_func [| v |] "int_to_string" builder
      | A.Bool -> L.build_call bool_to_string_func [| v |] "boolean_to_string" builder
      | A.Float -> L.build_call float_to_string_func [| v |] "float_to_string" builder
      | _ -> raise (Failure "Only string, int, bool and float are supported in table elements.")
    in
    (* WHERE and DELETE: call [func_of decls] with the table, the byte
       offset of the tested column, the operator, the comparison value as a
       string, and the column's type name. *)
    let lower_filter what func_of call_name (t_decls, t) (col_dec, op, se) =
      let tt = table_expr builder m (t_decls, t) in
      let cols = table_columns what t_decls in
      let offset = calculate_offset 0 (List.rev cols) col_dec in
      let op_str = L.build_global_stringptr (A.string_of_op op) "op_str" builder in
      let (se_v, _) = expr builder m se in
      let col_typ = fst col_dec in
      let typ_str = L.build_global_stringptr (A.string_of_typ col_typ) "typ_str" builder in
      let se_str = scalar_to_string col_typ se_v in
      L.build_call (func_of (List.map fst cols))
        [| tt; L.const_int i32_t offset; op_str; se_str; typ_str |]
        call_name builder
    in
    (* SELECT and DISTINCT: call [func_of decls new_decls len] with the
       table, a pointer to the chosen columns' byte offsets, a pointer to
       their type names (both arrays in reverse column order), and the
       number of chosen columns. *)
    let lower_projection what func_of call_name (t_decls, t) columns =
      let tt = table_expr builder m (t_decls, t) in
      let cols = table_columns what t_decls in
      let new_decls = List.map fst columns in
      let columns_type = List.rev (List.map A.string_of_typ new_decls) in
      let len = List.length columns_type in
      let types_ptr = store_string_array "arr_p" columns_type in
      let offsets =
        List.rev (List.map (calculate_offset 0 (List.rev cols)) columns) in
      let offsets_const =
        L.const_array i32_t (Array.of_list (List.map (L.const_int i32_t) offsets)) in
      let offsets_ptr = L.build_alloca (array_t i32_t len) "offset_p" builder in
      ignore (L.build_store offsets_const offsets_ptr builder);
      L.build_call (func_of (List.map fst cols) new_decls len)
        [| tt; offsets_ptr; types_ptr; L.const_int i32_t len |]
        call_name builder
    in
    let (_, e) = st in
    match e with
    | STableLit l -> L.build_load (lookup l m) l builder

    | SWhere (src, cond) ->
      lower_filter "Where" where_table_func "where_table" src cond

    | SDelete (src, cond) ->
      lower_filter "Delete" delete_table_func "delete_table" src cond

    | SDistinct (src, columns) ->
      lower_projection "Distinct" distinct_table_func "distinct_table" src columns

    | SSelect (src, columns) ->
      lower_projection "Select" select_table_func "select_table" src columns

    | SReadFile (file, decls) ->
      let file_name = L.build_global_stringptr file "" builder in
      let len = List.length decls in
      let types_ptr =
        store_string_array "arr_p" (List.rev (List.map A.string_of_typ decls)) in
      L.build_call (import_table_func decls len)
        [| file_name; L.const_int i32_t len; types_ptr |] "import_table" builder

    | SCreate decls ->
      let table_type = List.map fst decls in
      (* [L.size_of] yields i64 constants, so the sum starts from an i64
         zero (adding them to an i32 zero mixes operand types) and is
         narrowed to i32 only at the end.  Padding is not included. *)
      let struct_size =
        List.fold_left
          (fun acc ty -> L.const_add acc (L.size_of (ltype_of_typ ty)))
          (L.const_int (L.i64_type context) 0) table_type
      in
      L.build_call (create_table_func table_type)
        [| L.const_intcast struct_size i32_t ~is_signed:false |] "create_table" builder

    | SInsert ((t_decls, t), sc_lst) ->
      let tt = table_expr builder m (t_decls, t) in
      (* Lower one inserted value and convert it to its string form. *)
      let value_to_string ((typ, _) as sc) =
        let (v, _) = expr builder m sc in
        match typ with
        | A.String -> v
        | A.Int -> L.build_call int_to_string_func [| v |] "" builder
        | A.Bool ->
          let i = L.build_call bool_to_int_func [| v |] "" builder in
          L.build_call int_to_string_func [| i |] "" builder
        | A.Float -> L.build_call float_to_string_func [| v |] "" builder
        | _ -> raise (Failure "Only string, int, bool and float are supported in table elements.")
      in
      let decls = List.map fst (table_columns "Insert" t_decls) in
      let len = List.length sc_lst in
      let types_ptr =
        store_string_array "type_p"
          (List.rev (List.map (fun (ty, _) -> A.string_of_typ ty) sc_lst)) in
      (* Heap array of [len] string pointers.  Element [i] of [sc_lst] goes
         to index [len - 1 - i], i.e. reverse order, matching the reversed
         type-name array above. *)
      let values_ptr = L.build_malloc (array_t str_t len) "note_ptr" builder in
      List.iteri
        (fun i sc ->
          let slot = L.build_struct_gep values_ptr (len - 1 - i) "pitch_ptr" builder in
          ignore (L.build_store (value_to_string sc) slot builder))
        sc_lst;
      L.build_call (insert_table_func decls len)
        [| tt; values_ptr; types_ptr; L.const_int i32_t len |] "insert_table" builder

    | _ -> raise (Failure "TO DO")
  )

  (* Lower a scalar expression to an LLVM value, emitting instructions at
     [builder].  [m] maps variable names to their storage (LLVM globals
     created by assignments); the result pairs the value with the map, which
     [SAssign] extends.
     Raises [Failure] for operators that are illegal on the operand type and
     for expression forms not handled here. *)
  and expr builder m ((_, e) : sexpr) =
    match e with
    | SCell c -> (cell builder c, m)
    | STableExpr te -> (table_expr builder m te, m)

    (* For binary operators, the type of the left operand selects the
       instruction family. *)
    | SBinop (((A.Float, _) as e1), op, e2) ->
      let (e1', m) = expr builder m e1 in
      let (e2', m) = expr builder m e2 in
      let build =
        match op with
        | A.Add -> L.build_fadd
        | A.Sub -> L.build_fsub
        | A.Mul -> L.build_fmul
        | A.Div -> L.build_fdiv
        | A.Equ -> L.build_fcmp L.Fcmp.Oeq
        (* "Unordered or not equal", like C's [!=]: true whenever the
           operands differ, including when either is NaN.  [Ueq] ("equal or
           unordered") is the opposite for ordinary operands. *)
        | A.Neq -> L.build_fcmp L.Fcmp.Une
        | A.Lt -> L.build_fcmp L.Fcmp.Olt
        | A.Lteq -> L.build_fcmp L.Fcmp.Ole
        | A.Gt -> L.build_fcmp L.Fcmp.Ogt
        | A.Gteq -> L.build_fcmp L.Fcmp.Oge
        | _ -> raise (Failure "error: operation is illegal")
      in
      (build e1' e2' "tmp" builder, m)

    | SBinop (((A.Bool, _) as e1), op, e2) ->
      let (e1', m) = expr builder m e1 in
      let (e2', m) = expr builder m e2 in
      let build =
        match op with
        (* Boolean operands always land in this branch, so the logical
           connectives must be lowered here. *)
        | A.And -> L.build_and
        | A.Or -> L.build_or
        | A.Equ -> L.build_icmp L.Icmp.Eq
        | A.Neq -> L.build_icmp L.Icmp.Ne
        (* Booleans are i1: compare unsigned so that false (0) < true (1);
           a signed i1 compare reads true as -1. *)
        | A.Lt -> L.build_icmp L.Icmp.Ult
        | A.Lteq -> L.build_icmp L.Icmp.Ule
        | A.Gt -> L.build_icmp L.Icmp.Ugt
        | A.Gteq -> L.build_icmp L.Icmp.Uge
        | _ -> raise (Failure "error: operation is illegal")
      in
      (build e1' e2' "tmp" builder, m)

    | SBinop (((A.String, _) as e1), op, e2) ->
      let (e1', m) = expr builder m e1 in
      let (e2', m) = expr builder m e2 in
      let v =
        match op with
        | A.Add -> L.build_call string_concat_func [| e1'; e2' |] "string_concat" builder
        | _ -> raise (Failure "error: operation is illegal")
      in
      (v, m)

    | SBinop (e1, op, e2) ->
      let (e1', m) = expr builder m e1 in
      let (e2', m) = expr builder m e2 in
      let build =
        match op with
        | A.And -> L.build_and
        | A.Or -> L.build_or
        | A.Add -> L.build_add
        | A.Sub -> L.build_sub
        | A.Mul -> L.build_mul
        | A.Div -> L.build_sdiv
        | A.Equ -> L.build_icmp L.Icmp.Eq
        | A.Neq -> L.build_icmp L.Icmp.Ne
        | A.Lt -> L.build_icmp L.Icmp.Slt
        | A.Lteq -> L.build_icmp L.Icmp.Sle
        | A.Gt -> L.build_icmp L.Icmp.Sgt
        | A.Gteq -> L.build_icmp L.Icmp.Sge
      in
      (build e1' e2' "tmp" builder, m)

    | SVal v -> (L.build_load (lookup v m) v builder, m)

    | SAssign (t, v, e) ->
      let (new_v, m) = expr builder m e in
      (* Each variable is backed by an LLVM global defined here.  [init] is
         only its static initializer; the store below writes the assigned
         value when this code runs.  Strings start out pointing at a global
         copy of the variable's name. *)
      let init =
        match t with
        | A.Float -> L.const_float (ltype_of_typ t) 0.0
        | A.String -> L.build_global_stringptr v "str" builder
        (* All-zero value of the variable's LLVM type.  Tables are structs,
           so they need a zero struct ([const_inttoptr] requires a pointer
           destination), and pointer-typed columns need a null pointer. *)
        | _ -> L.const_null (ltype_of_typ t)
      in
      let llval = L.define_global v init the_module in
      let m = StringMap.add v llval m in
      ignore (L.build_store new_v llval builder);
      (new_v, m)

    | SReassign (v, e) ->
      let (new_v, m) = expr builder m e in
      let llval = lookup v m in
      ignore (L.build_store new_v llval builder);
      (new_v, m)

    | SNot ((t, _) as e) ->
      let (e', m) = expr builder m e in
      (match t with
       | A.Bool -> (L.build_not e' "tmp" builder, m)
       | _ -> raise (Failure "error: operation is illegal"))

    | _ -> raise (Failure "NOT IMPLEMENTED")
  in

  (* Lower statement [s] into [main], emitting code at [builder].  Returns
     the builder where subsequent code should go (control flow opens new
     blocks) together with the updated variable map.
     Raises [Failure "TO DO"] for unsupported statements or print types. *)
  let rec translate_stmt m builder s =
    match s with
    | SPrint e ->
      (* Data strings are passed as the argument of a "%s" format rather
         than as the format itself, so a '%' in the data is printed
         literally instead of being interpreted by printf. *)
      let print_str str =
        let str_format_str = L.build_global_stringptr "%s" "fmt" builder in
        L.build_call printf_func [| str_format_str; str |] "printf" builder
      in
      ignore (
        match e with
        | (A.Int, _) ->
          L.build_call printf_func [| int_format_str; fst (expr builder m e) |] "printf" builder
        | (A.String, _) -> print_str (fst (expr builder m e))
        | (A.Bool, _) ->
          print_str
            (L.build_call bool_to_string_func [| fst (expr builder m e) |] "bool_to_string" builder)
        | (A.Float, _) ->
          print_str
            (L.build_call float_to_string_func [| fst (expr builder m e) |] "float_to_string" builder)
        | (A.Table decls, _) ->
          let len = List.length decls in
          let t = List.map fst decls in
          let to_string_ptr s = L.build_global_stringptr s s builder in
          let table_type = List.rev (List.map A.string_of_typ t) in
          let const_array =
            L.const_array (L.pointer_type i8_t) (Array.of_list (List.map to_string_ptr table_type)) in
          let pointer = L.build_alloca (array_t (L.pointer_type i8_t) len) "arr_p" builder in
          ignore (L.build_store const_array pointer builder);
          L.build_call (print_table_func t len)
            [| fst (expr builder m e); pointer; L.const_int i32_t len |] "printTable" builder
        | _ -> raise (Failure "TO DO"));
      (builder, m)
    | SExpr e ->
      let (_, m) = expr builder m e in
      (builder, m)
    | SSemi (s1, s2) ->
      let (builder, m) = translate_stmt m builder s1 in
      translate_stmt m builder s2
    | SSemi1 s1 -> translate_stmt m builder s1
    | SWhile (e, body) ->
      let pred_bb = L.append_block context "while" f in
      ignore (L.build_br pred_bb builder);

      let body_bb = L.append_block context "while_body" f in
      let (new_builder, m) = translate_stmt m (L.builder_at_end context body_bb) body in
      add_terminal new_builder (L.build_br pred_bb);

      let pred_builder = L.builder_at_end context pred_bb in
      let (bool_val, m) = expr pred_builder m e in

      let merge_bb = L.append_block context "merge" f in
      ignore (L.build_cond_br bool_val body_bb merge_bb pred_builder);
      (L.builder_at_end context merge_bb, m)
    | SCondition (b, s1) -> lower_if m builder b s1 None
    | SConditionWithElse (b, s1, s2) -> lower_if m builder b s1 (Some s2)
    | _ -> raise (Failure "TO DO")

  (* Shared lowering of if / if-else: branch on [cond] to a "then" block
     running [then_s] and an "else" block running [else_s] (empty when
     [None]); both fall through to a "merge" block, where the returned
     builder is positioned.  The variable map is threaded through the then
     branch and then the else branch. *)
  and lower_if m builder cond then_s else_s =
    let (bool_val, m) = expr builder m cond in
    let merge_bb = L.append_block context "merge" f in
    let branch_instr = L.build_br merge_bb in
    let then_bb = L.append_block context "then" f in
    let (then_builder, m) =
      translate_stmt m (L.builder_at_end context then_bb) then_s in
    add_terminal then_builder branch_instr;
    let else_bb = L.append_block context "else" f in
    let else_start = L.builder_at_end context else_bb in
    let (else_builder, m) =
      match else_s with
      | None -> (else_start, m)
      | Some s2 -> translate_stmt m else_start s2
    in
    add_terminal else_builder branch_instr;
    ignore (L.build_cond_br bool_val then_bb else_bb builder);
    (L.builder_at_end context merge_bb, m)
  in
  (* Lower the whole program into [main], starting with no variables, then
     close the final block with [return 0] unless it already ends in a
     terminator. *)
  let (final_builder, _) = translate_stmt StringMap.empty builder sstmt in
  add_terminal final_builder (L.build_ret (L.const_int i32_t 0));
  the_module