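(* Semantic analysis for the Java-targeting compiler: type-checks global
   declarations, functions, statements and expressions of an Ast program. *)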

(* TODO: Binop in expr, and the whole stmt *)

open Ast

(* Symbol tables keyed by identifier name: used below to map variable names
   to their declared types and function names to their declarations. *)
module StringMap = Map.Make(String)

(* Raised with a human-readable diagnostic on any semantic/type error. *)
exception Fail of string

(* Look up the declared type of variable [s].
   Scopes are searched innermost first (locals, then formals, then globals),
   so a local or formal that shares a name with an outer declaration gets
   its own type rather than the outer one.
   @raise Fail if [s] is not declared in any scope. *)
let type_of_var global_map formal_map local_map s =
  try StringMap.find s local_map
  with Not_found ->
     try StringMap.find s formal_map
     with Not_found ->
        try StringMap.find s global_map
        with Not_found ->
           raise(Fail("undefined variable " ^ s ^ "."))

(* Type of an assignable location.
   - [Id s]: the declared type of variable [s].
   - [PolyElmt (s, _)]: an element access on polynomial variable [s]; the
     element type is Float. [s] must be a declared Poly variable. The element
     expression is not typed here.
   @raise Fail if a variable is undeclared or an element access is applied
   to a non-poly variable. *)
let type_of_lvalue global_map formal_map local_map = function
    Id(s) -> type_of_var global_map formal_map local_map s
  | PolyElmt(s, _) ->
      if type_of_var global_map formal_map local_map s <> Poly then
         raise(Fail("cannot index non-polynomial variable " ^ s ^ "."))
      else
         Float

(* Type of a polynomial literal: every coefficient must be a float literal,
   optionally wrapped in a single negation. The empty literal is accepted.
   @raise Fail on any other coefficient expression. *)
let typecheck_poly_literal coefficients =
  let is_float_constant = function
      FloatLiteral(_) -> true
    | Negate(FloatLiteral(_)) -> true
    | _ -> false
  in
  if List.for_all is_float_constant coefficients then
     Poly
  else
     raise(Fail("polynomial initializer can only use float literals"))

(* Infer the type of expression [expr] in the scope described by the three
   variable maps.
   [func_map] maps user-defined function names to their declarations. Calls to
   names not in it are checked against the built-ins print, println (at most
   one argument, result UnknownType) and deg (exactly one Poly argument,
   result Int).
   @raise Fail with a diagnostic on any type error. *)
let rec type_of_expr func_map global_map formal_map local_map expr =
  (* Type a sub-expression in the same scope. *)
  let type_of e = type_of_expr func_map global_map formal_map local_map e
  in
  let is_numeric t = (t = Int || t = Float)
  in
  (* Type of an assignable location. For a polynomial element access the
     index expression is also checked, since type_of_lvalue cannot type it. *)
  let type_of_target lv =
     let t = type_of_lvalue global_map formal_map local_map lv
     in
     (match lv with
        Id(_) -> ()
      | PolyElmt(_, idx) ->
          if type_of idx <> Int then
             raise(Fail("polynomial element index must have type int: " ^ string_of_expr idx)));
     t
  in
  match expr with
    IntLiteral(_) -> Int
  | FloatLiteral(_) -> Float
  | BooleanLiteral(_) -> Boolean
  | StringLiteral(_) -> String
  | PolyLiteral(l) -> typecheck_poly_literal l
  | PolyInit(e) ->
      if type_of e = Int then
         Poly
      else
         raise(Fail("PolyInit must take an expression of type int"))
  | Lvalue(lv) -> type_of_target lv
  | Binop(e1, o, e2) ->
      let t1 = type_of e1
      in
      let t2 = type_of e2
      in
      let mismatch () =
         raise(Fail("type mismatch in binop: " ^ string_of_expr e1 ^ " " ^ string_of_binop o ^ " " ^ string_of_expr e2))
      in
      let both_numeric = is_numeric t1 && is_numeric t2
      in
      (* Int op Int stays Int; any other Int/Float mix widens to Float. *)
      let numeric_result = if t1 = Int && t2 = Int then Int else Float
      in
      (match o with
         Add | Sub | Mult ->
           if t1 = Poly && t2 = Poly then Poly
           else if both_numeric then numeric_result
           else mismatch ()
       | Div ->
           (* A polynomial may be divided by a scalar, not by a polynomial. *)
           if t1 = Poly && is_numeric t2 then Poly
           else if both_numeric then numeric_result
           else mismatch ()
       | Lshift | Rshift ->
           if t1 = Poly && t2 = Int then Poly else mismatch ()
       | Equal | Neq ->
           if t1 = t2 || both_numeric then Boolean else mismatch ()
       | Less | Greater | Leq | Geq ->
           if both_numeric then Boolean else mismatch ())
  | Negate(e) ->
      (match type_of e with
         Poly -> Poly
       | Int -> Int
       | Float -> Float
       | _ -> raise(Fail("Cannot negate expression: " ^ string_of_expr e)))
  | Assign(v, e) ->
      let t1 = type_of_target v
      in
      let t2 = type_of e
      in
      if t1 <> t2 then
         raise(Fail("type mismatch in Assignment: " ^ string_of_lvalue v ^ " = " ^ string_of_expr e))
      else
         t1
  | Call(f, el) ->
      (* Only the function lookup can signal "not user-defined"; keep the
         Not_found handler that narrow. *)
      let fdecl_opt =
         try Some (StringMap.find f func_map) with Not_found -> None
      in
      (match fdecl_opt with
         Some fdecl ->
           if List.length fdecl.formals <> List.length el then
              raise(Fail("Argument/Formals count mismatch in function call: " ^ f))
           else
              let actual_types = List.map type_of el
              in
              List.iter2
                (fun formal actual ->
                   if formal.vtype <> actual then
                      raise(Fail("types of formal and actual parameters don't match at call to function: " ^ f)))
                fdecl.formals actual_types;
              fdecl.rtype
       | None ->
           (match f with
              "print" | "println" ->
                if List.length el > 1 then
                   raise(Fail(f ^ " cannot take multiple arguments"))
                else begin
                   (* Any type may be printed, but the argument must still be
                      well-typed (e.g. no undefined variables). *)
                   List.iter (fun e -> ignore (type_of e)) el;
                   UnknownType
                end
            | "deg" ->
                (match el with
                   [e] ->
                     if type_of e <> Poly then
                        raise(Fail("deg function can only take an argument of type poly"))
                     else
                        Int
                 | _ -> raise(Fail("deg function must take exactly one argument")))
            | _ -> raise(Fail("call to undefined function: " ^ f ^ "."))))
  | Noexpr -> raise(Fail("Noexpr."))

(* Type-check a statement list in order; the first error aborts via [Fail]. *)
let rec typecheck_stmts func_map fdecl global_map formal_map local_map stmts =
  List.iter (typecheck_stmt func_map fdecl global_map formal_map local_map) stmts

(* Type-check one statement appearing in the body of function [fdecl].
   - Expressions are typed and their type discarded.
   - A returned expression must have exactly [fdecl.rtype].
   - if/for/while conditions must be Boolean.
   @raise Fail on any type error. *)
and typecheck_stmt func_map fdecl global_map formal_map local_map stmt =
  let type_of e = type_of_expr func_map global_map formal_map local_map e
  in
  let check_stmt s = typecheck_stmt func_map fdecl global_map formal_map local_map s
  in
  (* [what] names the construct ("if", "for", "while") in the diagnostic. *)
  let check_condition what e =
     if type_of e <> Boolean then
        raise(Fail(what ^ " condition " ^ string_of_expr e ^ " does not have boolean type."))
  in
  match stmt with
    Block(stmts) ->
      typecheck_stmts func_map fdecl global_map formal_map local_map stmts
  | Expr(e) ->
      ignore (type_of e)
  | Return(e) ->
      if type_of e <> fdecl.rtype then
         raise(Fail("declared return type and returned type mismatch in function " ^ fdecl.fname))
  | If(e, s1, s2) ->
      (* An absent else-branch, if represented as Block [], checks trivially. *)
      check_condition "if" e;
      check_stmt s1;
      check_stmt s2
  | For(e1, e2, e3, s) ->
      (* NOTE(review): init and step are typed unconditionally, and
         type_of_expr rejects Noexpr - confirm the parser never puts one here. *)
      ignore (type_of e1);
      check_condition "for" e2;
      ignore (type_of e3);
      check_stmt s
  | While(e, s) ->
      check_condition "while" e;
      check_stmt s

(* Discard a value; no longer used above, kept for compatibility with any
   external callers. *)
and do_nothing e =
  ignore e

(* Type-check a single function declaration.
   Builds the formal-parameter and local-variable symbol tables, rejecting:
   - duplicate formals, or formals that shadow a global;
   - duplicate locals, or locals that shadow a formal or a global.
   It then type-checks the function body against those tables.
   @raise Fail on any declaration or type error. *)
let typecheck_fdecl global_map func_map fdecl =
  let fail what name problem =
    raise (Fail (what ^ ": '" ^ name ^ "' in function " ^ fdecl.fname ^ " " ^ problem ^ "."))
  in
  let formal_map = List.fold_left
      (fun formal_map fv_decl ->
         if StringMap.mem fv_decl.vname formal_map then
           fail "formal parameter" fv_decl.vname "has already been defined"
         else if StringMap.mem fv_decl.vname global_map then
           fail "formal parameter" fv_decl.vname "shadows a global variable"
         else
           StringMap.add fv_decl.vname fv_decl.vtype formal_map)
      StringMap.empty fdecl.formals
  in
  let local_map = List.fold_left
      (fun local_map lv_decl ->
         if StringMap.mem lv_decl.vname local_map then
           fail "local variable" lv_decl.vname "has already been defined"
         else if StringMap.mem lv_decl.vname formal_map then
           fail "local variable" lv_decl.vname "shadows a formal parameter"
         else if StringMap.mem lv_decl.vname global_map then
           (* Previously this branch re-tested formal_map, so a local could
              silently shadow a global. *)
           fail "local variable" lv_decl.vname "shadows a global variable"
         else
           StringMap.add lv_decl.vname lv_decl.vtype local_map)
      StringMap.empty fdecl.locals
  in
  typecheck_stmts func_map fdecl global_map formal_map local_map fdecl.body
  
(* Type-check every function declaration in source order; the first error
   aborts via [Fail]. *)
let typecheck_fdecls global_map func_map fdecls =
  List.iter (fun fd -> typecheck_fdecl global_map func_map fd) fdecls
  
(* Type-check a whole program given as (global variables, functions).
   Builds the global-variable table (name -> type) and the function table
   (name -> declaration), rejecting duplicate names in either. It then checks
   each function body.
   @raise Fail on the first error found. *)
let typecheck_program (vars, funcs) =
  let add_global table var =
    if StringMap.mem var.vname table then
      raise (Fail ("global variable: '" ^ var.vname ^ "' has already been defined."))
    else
      StringMap.add var.vname var.vtype table
  in
  let add_function table fd =
    if StringMap.mem fd.fname table then
      raise (Fail ("function: '" ^ fd.fname ^ "' has already been defined."))
    else
      StringMap.add fd.fname fd table
  in
  let global_map = List.fold_left add_global StringMap.empty vars in
  let func_map = List.fold_left add_function StringMap.empty funcs in
  typecheck_fdecls global_map func_map funcs
  