module L = Llvm
module Fcmp = Llvm.Fcmp
module A = Ast
module StringMap = Map.Make(String)

let translate ((_, globals), (functions, classes)) =
  (* Generate an LLVM module for a whole program: one LLVM function per
     source function and one LLVM global per source global variable.
     [classes] are accepted but not code-generated yet (see
     [build_class_body]). Errors are reported with [Failure]. *)

  let context = L.global_context () in
  let the_module = L.create_module context "Scolkam"
  and i32_t  = L.i32_type context
  and i8_t   = L.i8_type context
  and i1_t   = L.i1_type context
  and flt_t  = L.double_type context
  and str_t  = L.pointer_type (L.i8_type context)
  and void_t = L.void_type context in

  (* Declare printf(), which the print/printb built-ins call *)
  let printf_t = L.var_arg_function_type i32_t [| L.pointer_type i8_t |] in
  let printf_func = L.declare_function "printf" printf_t the_module in
  (* Declare puts(), which the prints built-in calls. Its C prototype is
     int puts(const char *s): it returns an int and is not variadic. *)
  let prints_t = L.function_type i32_t [| str_t |] in
  let prints_func = L.declare_function "puts" prints_t the_module in

  (* printf format strings; each call emits a fresh global constant *)
  let int_format_str    b = L.build_global_stringptr "%d\n" "fmt" b
  and float_format_str  b = L.build_global_stringptr "%f\n" "fmt" b
  and string_format_str b = L.build_global_stringptr "%s\n" "fmt" b in

  (* Targets of 'continue' (the loop's increment block) and 'break' (the
     block following the loop) for the innermost loop being generated;
     [None] outside of any loop. Saved/restored around each loop body so
     nested loops do not clobber the enclosing loop's targets. *)
  let before_block : L.llbasicblock option ref = ref None
  and after_block  : L.llbasicblock option ref = ref None in

  (* Global variables: name -> (source type, llvalue of the global) *)
  let global_vars = ref StringMap.empty in

  (* Current function and its locals (formals + declared variables):
     name -> (source type, llvalue of the stack slot) *)
  let local_vars = ref StringMap.empty in
  let currentf = ref (List.hd functions) in

  (* Look a name up in the locals first, then in the globals.
     Raises Not_found when the name is bound in neither. *)
  let lookup n =
    try StringMap.find n !local_vars
    with Not_found -> StringMap.find n !global_vars
  in
  let name_to_llval n : L.llvalue = snd (lookup n) in
  let name_to_type n : A.typ = fst (lookup n) in

  (* LLVM representation of a source type. Tuples are a pointer to their
     heap-allocated elements. *)
  let rec ltype_of_typ = function
      A.Int         -> i32_t
    | A.Float       -> flt_t
    | A.Bool        -> i1_t
    | A.String      -> str_t
    | A.None        -> void_t
    | A.Tuple(t, _) -> L.pointer_type (ltype_of_typ t)

  (* Best-effort source type of an expression, used to choose printf
     formats and tuple element types. For tuple-typed names and calls it
     yields the element type. *)
  and gen_type = function
      A.IntLit _          -> A.Int
    | A.FloatLit _        -> A.Float
    | A.BoolLit _         -> A.Bool
    | A.StringLit _       -> A.String
    | A.TupleLit x        -> gen_type (List.hd x)
    | A.Element (e, _)    -> gen_type (A.Id e)
    | A.Id s              -> element_type (name_to_type s)
    (* Built-ins are not in [functions]; printf and puts return an int *)
    | A.Call (("print" | "printb" | "prints" | "toint"), _) -> A.Int
    | A.Call ("tofloat", _) -> A.Float
    | A.Call (s, _) ->
        let fdecl = List.find (fun x -> x.A.fname = s) functions in
        element_type fdecl.A.typ
    (* Comparisons produce an i1 whatever the operand type *)
    | A.Binop (_, (A.Equal | A.Neq | A.Less | A.Leq | A.Greater | A.Geq), _) ->
        A.Bool
    | A.Binop (e1, _, _)  -> gen_type e1
    | A.Unop (_, e1)      -> gen_type e1
    | A.Assign (s, _)     -> gen_type (A.Id s)
    | A.Noexpr            -> raise (Failure "corrupted tree - Noexpr as a statement")

  (* Element type of a tuple type; other types are returned unchanged *)
  and element_type = function
      A.Tuple (t, _) -> t
    | ty -> ty

  (* Tuples are already pointer-represented by [ltype_of_typ] *)
  and lreturn_type ty = ltype_of_typ ty

  and find_type (t, _) = ltype_of_typ t

  (* printf format for a value of source type [x_type]. Booleans print as
     integers (the print built-in widens them to i32 first). *)
  and format_str x_type b =
    match x_type with
      A.Int | A.Bool -> int_format_str b
    | A.Float        -> float_format_str b
    | A.String       -> string_format_str b
    | _ -> raise (Failure "Invalid printf type")
  in

  (* Declare every source function up front (name -> (llvalue, fdecl)) so
     calls can be generated before the callee's body is built. *)
  let function_decls =
    let function_decl m fdecl =
      let name         = fdecl.A.fname
      and formal_types = Array.of_list (List.map find_type fdecl.A.formals) in
      let ftype = L.function_type (lreturn_type fdecl.A.typ) formal_types in
      StringMap.add name (L.define_function name ftype the_module, fdecl) m
    in
    List.fold_left function_decl StringMap.empty functions
  in

  (* Invoke "f builder" if the current block doesn't already
     have a terminator (e.g., a branch). *)
  let rec add_terminal builder f =
    match L.block_terminator (L.insertion_block builder) with
      Some _ -> ()
    | None -> ignore (f builder)

  (* Emit the code of an expression and return its llvalue *)
  and expr builder =
    let b = builder in
    function
      A.IntLit i -> L.const_int i32_t i
    | A.FloatLit f -> L.const_float flt_t f
    | A.BoolLit v -> L.const_int i1_t (if v then 1 else 0)
    | A.StringLit sl -> L.build_global_stringptr sl "string" b
    | A.TupleLit [] -> raise (Failure "empty tuple literal")
    | A.TupleLit elts ->
        (* Element i is stored at index i + 1 (matching [A.Element]);
           slot 0 is allocated but not written here. *)
        let elt_ty = ltype_of_typ (gen_type (List.hd elts)) in
        let size   = L.const_int i32_t (List.length elts + 1) in
        let arr    = L.build_array_malloc elt_ty size "init1" b in
        let store_elt i v =
          let slot = L.build_gep arr [| L.const_int i32_t (i + 1) |] "init4" b in
          ignore (L.build_store v slot b)
        in
        List.iteri store_elt (List.map (expr b) elts);
        arr
    | A.Element (s, e) ->
        let idx = expr b e in
        let idx = L.build_add idx (L.const_int i32_t 1) "access1" b in
        let arr = expr b (A.Id s) in
        let res = L.build_gep arr [| idx |] "access2" b in
        L.build_load res "access3" b
    | A.Noexpr -> L.const_int i32_t 0
    | A.Id s -> L.build_load (name_to_llval s) s b
    | A.Binop (e1, op, e2) ->
        let e1' = expr b e1 in
        let e2' = expr b e2 in
        let float_op = match op with
            A.Add     -> L.build_fadd
          | A.Sub     -> L.build_fsub
          | A.Mult    -> L.build_fmul
          | A.Div     -> L.build_fdiv
          | A.Mod     -> L.build_frem
          | A.And     -> L.build_and
          | A.Or      -> L.build_or
          | A.Equal   -> L.build_fcmp L.Fcmp.Oeq
          | A.Neq     -> L.build_fcmp L.Fcmp.One
          | A.Less    -> L.build_fcmp L.Fcmp.Olt
          | A.Leq     -> L.build_fcmp L.Fcmp.Ole
          | A.Greater -> L.build_fcmp L.Fcmp.Ogt
          | A.Geq     -> L.build_fcmp L.Fcmp.Oge
        in
        let int_op = match op with
            A.Add     -> L.build_add
          | A.Sub     -> L.build_sub
          | A.Mult    -> L.build_mul
          | A.Div     -> L.build_sdiv
          (* signed remainder, consistent with the signed division above *)
          | A.Mod     -> L.build_srem
          | A.And     -> L.build_and
          | A.Or      -> L.build_or
          | A.Equal   -> L.build_icmp L.Icmp.Eq
          | A.Neq     -> L.build_icmp L.Icmp.Ne
          | A.Less    -> L.build_icmp L.Icmp.Slt
          | A.Leq     -> L.build_icmp L.Icmp.Sle
          | A.Greater -> L.build_icmp L.Icmp.Sgt
          | A.Geq     -> L.build_icmp L.Icmp.Sge
        in
        (* Only evaluated for string operands, so other binops do not emit
           an unused global string.
           NOTE(review): "+" is folded at compile time from the printed
           source of both operands, so it is presumably only meaningful when
           both are literals; confirm against A.string_of_expr. *)
        let str_op () = match op with
            A.Add -> expr b (A.StringLit (A.string_of_expr e1 ^ A.string_of_expr e2))
          | _ -> L.const_int i32_t 0
        in
        let t1 = L.type_of e1' and t2 = L.type_of e2' in
        if t1 = flt_t || t2 = flt_t then float_op e1' e2' "tmp" b
        else if t1 = str_t && t2 = str_t then str_op ()
        else int_op e1' e2' "tmp" b
    | A.Unop (op, e) ->
        let e' = expr b e in
        let build = match op with
            A.Neg -> if L.type_of e' = flt_t then L.build_fneg else L.build_neg
          | A.Not -> L.build_not
        in
        build e' "tmp" b
    | A.Assign (s, e) ->
        let e' = expr b e in
        ignore (L.build_store e' (name_to_llval s) b);
        e'
    | A.Call ("print", [e]) | A.Call ("printb", [e]) ->
        let e' = expr b e in
        let fmt = format_str (gen_type e) b in
        (* An i1 is not a valid C variadic argument; widen booleans for "%d" *)
        let arg =
          if L.type_of e' = i1_t then L.build_zext e' i32_t "bool_to_int" b
          else e'
        in
        L.build_call printf_func [| fmt; arg |] "printf" b
    | A.Call ("prints", [e]) ->
        L.build_call prints_func [| expr b e |] "puts" b
    | A.Call ("toint", [e]) ->
        let e' = expr b e in
        let e_type = L.string_of_lltype (L.type_of e') in
        if e_type = L.string_of_lltype i32_t
        then raise (Failure "You converted int to int")
        else if e_type = L.string_of_lltype flt_t
        then L.build_fptosi e' i32_t "cast" b
        else raise (Failure "You cannot convert from this type to int")
    | A.Call ("tofloat", [e]) ->
        let e' = expr b e in
        let e_type = L.string_of_lltype (L.type_of e') in
        if e_type = L.string_of_lltype flt_t
        then raise (Failure "You converted float to float")
        else if e_type = L.string_of_lltype i32_t
        then L.build_sitofp e' flt_t "castint" b
        else raise (Failure "You cannot convert from this type to float")
    | A.Call (f, act) ->
        let (fdef, f_decl) = StringMap.find f function_decls in
        (* Arguments are evaluated right to left *)
        let actuals = List.rev (List.map (expr b) (List.rev act)) in
        let result = match f_decl.A.typ with
            A.None -> ""
          | _ -> f ^ "_result"
        in
        L.build_call fdef (Array.of_list actuals) result b

  (* Build the code for the given statement; return the builder for
     the statement's successor *)
  and stmt builder =
    let b = builder in
    let (the_function, _) = StringMap.find !currentf.A.fname function_decls in
    function
      A.Block sl -> List.fold_left stmt b sl
    | A.Expr e -> ignore (expr b e); b
    | A.Return e ->
        ignore (match !currentf.A.typ with
            A.None -> L.build_ret_void b
          | _ -> L.build_ret (expr b e) b);
        b
    | A.If (predicate, then_stmts, else_if_stmts, else_stmts) ->
        (* Desugar the elifs into a right-nested chain of plain if/else.
           Each generated If carries an empty elif list, so every
           predicate is emitted (and evaluated) exactly once. *)
        let rec desugar_elifs elifs else_part =
          match elifs with
            A.Block (A.Elif (cond, body) :: rest) ->
              A.If (cond, body, A.Block [], desugar_elifs (A.Block rest) else_part)
          | A.Block (_ :: _) -> raise (Failure "Corrupted tree - Elseif problem")
          | _ -> else_part
        in
        let new_else_stmts = desugar_elifs else_if_stmts else_stmts in

        let bool_val = expr b predicate in
        let merge_bb = L.append_block context "merge" the_function in

        (* Emit 'then' value. *)
        let then_bb = L.append_block context "then" the_function in
        add_terminal (stmt (L.builder_at_end context then_bb) then_stmts)
          (L.build_br merge_bb);

        (* Emit 'else' value. *)
        let else_bb = L.append_block context "else" the_function in
        add_terminal (stmt (L.builder_at_end context else_bb) new_else_stmts)
          (L.build_br merge_bb);

        (* Add the conditional branch. *)
        ignore (L.build_cond_br bool_val then_bb else_bb b);
        L.builder_at_end context merge_bb
    | A.While (predicate, body) ->
        stmt b (A.HiddenWhile (predicate, body, A.Block [A.Nostmt]))
    | A.HiddenWhile (predicate, body, increment) ->
        let pred_bb = L.append_block context "while" the_function in
        ignore (L.build_br pred_bb b);

        let body_bb = L.append_block context "while_body" the_function in

        let pred_b = L.builder_at_end context pred_bb in
        let bool_val = expr pred_b predicate in

        let increment_bb = L.append_block context "increment" the_function in
        let increment_b = L.builder_at_end context increment_bb in
        let merge_bb = L.append_block context "merge" the_function in

        (* Point break/continue at this loop while its body is generated,
           then restore the enclosing loop's targets. *)
        let outer_continue = !before_block and outer_break = !after_block in
        before_block := Some increment_bb;
        after_block  := Some merge_bb;
        add_terminal (stmt (L.builder_at_end context body_bb) body)
          (L.build_br increment_bb);
        before_block := outer_continue;
        after_block  := outer_break;

        add_terminal (stmt increment_b increment) (L.build_br pred_bb);
        ignore (L.build_cond_br bool_val body_bb merge_bb pred_b);
        L.builder_at_end context merge_bb
    | A.For (e1, e2, e3, body) ->
        stmt b (A.Block [A.Expr e1; A.HiddenWhile (e2, A.Block [body], A.Expr e3)])
    | A.Nostmt -> b
    | A.Continue ->
        (match !before_block with
           Some target -> ignore (L.build_br target b); b
         | None -> raise (Failure "continue outside of a loop"))
    | A.Break ->
        (match !after_block with
           Some target -> ignore (L.build_br target b); b
         | None -> raise (Failure "break outside of a loop"))
    | A.Declaration _ -> raise (Failure "Corrupted Tree")
    | A.Elif (_, _) | A.ForIn (_, _, _) | A.In (_) ->
        raise (Failure "Corrupted Tree")

  (* Define global [n] of type [t] and add it to [m]. The initializer is
     [e] itself when it is a scalar literal, otherwise a default value for
     [t]. Constants are built with a builder positioned in main's entry
     block (string literals need one); raises Not_found without a main. *)
  and global_var m (t, n, e) =
    let (f, _) = StringMap.find "main" function_decls in
    let builder = L.builder_at_end context (L.entry_block f) in
    let rec init t e = match e with
        A.IntLit _ | A.FloatLit _ | A.BoolLit _ | A.StringLit _ -> expr builder e
      | A.TupleLit _ ->
          (* NOTE(review): the literal's elements are not stored anywhere;
             the global tuple starts out as a null pointer. *)
          L.const_null (ltype_of_typ t)
      | _ ->
          (match t with
             A.Int    -> expr builder (A.IntLit 0)
           | A.Float  -> expr builder (A.FloatLit 0.0)
           | A.String -> expr builder (A.StringLit "")
           (* NOTE(review): booleans default to true here; confirm intended *)
           | A.Bool   -> expr builder (A.BoolLit true)
           | A.Tuple (_, _) -> init t (A.TupleLit [])
           | A.None   -> expr builder A.Noexpr)
    in
    StringMap.add n (t, L.define_global n (init t e) the_module) m

  (* Fill in the body of the given function *)
  and build_function_body fdecl =
    let (the_function, _) = StringMap.find fdecl.A.fname function_decls in
    let builder = L.builder_at_end context (L.entry_block the_function) in

    currentf := fdecl;

    (* Formal arguments: one stack slot each, initialised from the
       incoming parameter value *)
    let add_formal m (t, n) p =
      L.set_value_name n p;
      let local = L.build_alloca (find_type (t, n)) n builder in
      ignore (L.build_store p local builder);
      StringMap.add n (t, local) m
    in

    (* Declared locals: one stack slot each, initialised below *)
    let add_local m (t, n) =
      let local_var = L.build_alloca (find_type (t, n)) n builder in
      StringMap.add n (t, local_var) m
    in

    let formals = List.fold_left2 add_formal StringMap.empty fdecl.A.formals
                    (Array.to_list (L.params the_function)) in

    local_vars := List.fold_left add_local formals
                    (List.map (fun (t, n, _) -> (t, n)) fdecl.A.fbody.A.f_vdecls);

    (* Initialise each declared local in declaration order. A declaration
       without an initialiser gets the zero value of its own LLVM type
       (storing an i32 0 into a float/string/tuple slot is invalid IR). *)
    let assign_variable (t, n, e) =
      let v = match e with
          A.Noexpr -> L.const_null (find_type (t, n))
        | _ -> expr builder e
      in
      ignore (L.build_store v (name_to_llval n) builder)
    in
    List.iter assign_variable fdecl.A.fbody.A.f_vdecls;

    (* Build the code for each statement in the function *)
    let builder = stmt builder (A.Block fdecl.A.fbody.A.f_stmts) in

    (* Add a return if the last block falls off the end; const_null gives
       a well-typed zero for int, float, bool, string and tuple results. *)
    add_terminal builder (match fdecl.A.typ with
        A.None -> L.build_ret_void
      | t -> L.build_ret (L.const_null (ltype_of_typ t)))
  in

  (* Classes are not code-generated yet *)
  let build_class_body _ = () in

  (* Declare each global variable and remember it in [global_vars].
     [globals] is reversed first; presumably the parser collects them in
     reverse source order (confirm). *)
  List.iter (fun g -> global_vars := global_var !global_vars g) (List.rev globals);
  List.iter build_class_body classes;
  List.iter build_function_body functions;
  the_module
