(* LOON Compiler Semantic Checking *)
(* Authors: Kyle Hughes *)

open Ast

(* Finite maps keyed by string names; used by [check] for the function
   table and for each function's variable-type table. *)
module StringMap = Map.Make(String)

(* Semantic checker entry point.

   [check (globals, functions)] validates a LOON program and returns ()
   when it is well formed. Every violation is reported by raising
   [Failure msg]; the first problem found aborts checking.

   Checks performed:
   - no global, formal, or local is declared with type Void;
   - no duplicate names among globals, among function names, and among
     the formals / locals of each function;
   - the reserved names "print" and "printJSON" are not user-defined;
   - a function named "main" exists;
   - every expression is well typed, conditions are Bool, call arities and
     argument types match, and return values match the function's type. *)
let check (globals, functions) =

  (* Raise [Failure (exceptf n)] for the first name [n] that occurs more
     than once in [list]. Sorting puts equal names next to each other. *)
  let report_duplicate exceptf list =
    let rec helper = function
      | n1 :: n2 :: _ when n1 = n2 -> raise (Failure (exceptf n1))
      | _ :: t -> helper t
      | [] -> ()
    in
    helper (List.sort compare list)
  in

  (* Raise [Failure (exceptf n)] if the binding [(t, n)] has type Void. *)
  let check_not_void exceptf = function
    | (Void, n) -> raise (Failure (exceptf n))
    | _ -> ()
  in

  (* Return [lvaluet] if a value of type [rvaluet] may be assigned to a
     location of type [lvaluet]; otherwise raise [err].
     Structural equality ([=]) is required: types such as [Pair t] carry an
     argument and are allocated separately each time they are built, so
     physical equality ([==]) would reject two identical pair types. *)
  let check_assign lvaluet rvaluet err =
    if lvaluet = rvaluet then lvaluet else raise err
  in

  (**** Checking Globals ****)
  List.iter (check_not_void (fun n -> "illegal void global " ^ n)) globals;
  report_duplicate (fun n -> "duplicate global " ^ n) (List.map snd globals);

  (**** Checking Functions ****)
  let function_names = List.map (fun fd -> fd.fname) functions in

  (* Names user programs may not define; checked in this order. *)
  List.iter
    (fun reserved ->
      if List.mem reserved function_names then
        raise (Failure ("function " ^ reserved ^ " may not be defined")))
    [ "print"; "printJSON" ];

  (* TODO: also reject user definitions of the other LOON library
     functions (e.g. loon_scanf). Today such a definition silently
     replaces the built-in entry in [function_decls]. *)
  report_duplicate (fun n -> "duplicate function " ^ n) function_names;

  (* Declarations of the built-in library functions. Their [formals] are
     empty; "printJSON" is special-cased in [expr] below so that it accepts
     any number of arguments of any type. *)
  let built_in_decls =
    List.fold_left
      (fun map (name, attr) -> StringMap.add name attr map)
      StringMap.empty
      [ ("printJSON",
         { primitive = Void; fname = "printJSON"; formals = []; locals = []; body = [] });
        ("loon_scanf",
         { primitive = Void; fname = "loon_scanf"; formals = []; locals = []; body = [] }) ]
  in

  (* All callable functions; user definitions are added after the built-ins. *)
  let function_decls =
    List.fold_left (fun m fd -> StringMap.add fd.fname fd m) built_in_decls functions
  in

  (* Look up the declaration of function [s], or raise [Failure]. *)
  let function_decl s =
    try StringMap.find s function_decls
    with Not_found ->
      if s = "main" then raise (Failure "main function must be defined")
      else raise (Failure ("function " ^ s ^ " unrecognized!"))
  in

  (* Every program must define main. *)
  ignore (function_decl "main");

  (* Check one function definition: its formals and locals, then every
     expression and statement in its body. *)
  let check_function func =
    List.iter
      (check_not_void (fun n -> "illegal void formal " ^ n ^ " in " ^ func.fname))
      func.formals;
    report_duplicate
      (fun n -> "duplicate formal " ^ n ^ " in " ^ func.fname)
      (List.map snd func.formals);
    List.iter
      (check_not_void (fun n -> "illegal void local " ^ n ^ " in " ^ func.fname))
      func.locals;
    report_duplicate
      (fun n -> "duplicate local " ^ n ^ " in " ^ func.fname)
      (List.map snd func.locals);

    (* Variable types visible in the body. Later bindings win, so locals
       shadow formals, which shadow globals. *)
    let symbols =
      List.fold_left (fun m (t, n) -> StringMap.add n t m)
        StringMap.empty (globals @ func.formals @ func.locals)
    in

    let type_of_identifier s =
      try StringMap.find s symbols
      with Not_found -> raise (Failure ("undeclared identifier " ^ s))
    in

    (* Return the type of an expression or raise [Failure]. *)
    let rec expr = function
      | Literal _ -> Int
      | BoolLit _ -> Bool
      | CharLit _ -> Char
      | StringLit _ -> String
      | PairLit (_, e) -> Pair (expr e)
      | JsonLit _ -> Json
      | ArrayLit _ -> Array
      | Id s -> type_of_identifier s
      | Binop (e1, op, e2) as e ->
          let t1 = expr e1 and t2 = expr e2 in
          let illegal () =
            raise (Failure ("illegal binary operator " ^
              string_of_typ t1 ^ " " ^ string_of_op op ^ " " ^
              string_of_typ t2 ^ " in " ^ string_of_expr e))
          in
          begin match op with
          | Add ->
              begin match t1, t2 with
              | Int, Int -> Int
              | String, String -> String (* concatenation *)
              | Array, Array -> Array
              | Pair _, Pair _ | Pair _, Json | Json, Pair _ -> Json
              | _ -> illegal ()
              end
          | Sub | Mult | Div ->
              begin match t1, t2 with
              | Int, Int -> Int
              | String, String -> String
              | _ -> illegal ()
              end
          | Equal | Neq when t1 = t2 -> Bool
          | Less | Leq | Greater | Geq when t1 = Int && t2 = Int -> Bool
          | And | Or when t1 = Bool && t2 = Bool -> Bool
          | _ -> illegal ()
          end
      | Unop (op, e) as ex ->
          let t = expr e in
          begin match op with
          | Neg when t = Int -> Int
          | Not when t = Bool -> Bool
          (* NOTE(review): Deref yields the operand's own type; confirm
             this is the intended result type. *)
          | Deref -> t
          | _ ->
              raise (Failure ("illegal unary operator " ^ string_of_uop op ^
                " " ^ string_of_typ t ^ " in " ^ string_of_expr ex))
          end
      | Noexpr -> Void
      (* NOTE(review): the assigned value is type-checked but not compared
         with the target's declared type; presumably left to codegen. *)
      | Assign (_, _, e) -> expr e
      | Access (_, []) -> raise (Failure "Trying to access on empty array")
      (* NOTE(review): yields the type of the first index expression, not
         an element type; confirm intent. *)
      | Access (_, x :: _) -> expr x
      | Call (fname, actuals) as call ->
          let fd = function_decl fname in
          if fname = "printJSON" then begin
            (* printJSON is variadic: check each argument, accept any type. *)
            List.iter (fun e -> ignore (expr e)) actuals;
            Void
          end else begin
            let n_formals = List.length fd.formals in
            if List.length actuals <> n_formals then
              raise (Failure ("Expecting " ^ string_of_int n_formals ^
                " arguments in " ^ string_of_expr call));
            List.iter2
              (fun (ft, _) e ->
                let et = expr e in
                ignore (check_assign ft et
                  (Failure ("Illegal actual argument found " ^ string_of_typ et ^
                    " expected " ^ string_of_typ ft ^ " in " ^ string_of_expr call))))
              fd.formals actuals;
            fd.primitive
          end
    in

    (* Raise [Failure] unless [e] has type Bool. *)
    let check_bool_expr e =
      if expr e <> Bool then
        raise (Failure ("expected Bool expression in " ^ string_of_expr e))
    in

    (* Check a statement. [in_loop] records whether the statement is
       nested inside a for/while body. *)
    let rec stmt in_loop = function
      | Block sl ->
          (* Nested blocks are flattened so that a return followed by any
             statement, even across a block boundary, is rejected. *)
          let rec check_block = function
            | [ (Return _ as s) ] -> stmt in_loop s
            | Return _ :: _ -> raise (Failure "nothing may follow a return")
            | Block sl :: ss -> check_block (sl @ ss)
            | s :: ss -> stmt in_loop s; check_block ss
            | [] -> ()
          in
          check_block sl
      | Expr e -> ignore (expr e)
      | Return e ->
          let t = expr e in
          if t <> func.primitive then
            raise (Failure ("return gives " ^ string_of_typ t ^ " expected " ^
              string_of_typ func.primitive ^ " in " ^ string_of_expr e))
      | If (p, b1, b2) ->
          check_bool_expr p;
          (* Both branches stay inside whatever loop encloses the if. *)
          stmt in_loop b1;
          stmt in_loop b2
      | For (e1, e2, e3, st) ->
          ignore (expr e1);
          check_bool_expr e2;
          ignore (expr e3);
          stmt true st
      | While (p, s) ->
          check_bool_expr p;
          stmt true s
    in

    stmt false (Block func.body)
  in
  List.iter check_function functions
