(* Semantic checking for the MicroC compiler *)

open Ast

module StringMap = Map.Make(String)

(* Semantic checking of a program. Returns unit if successful,
   raises a Failure exception if something is wrong.

   Check each global variable, then check each function *)

let check program =

  (* Raise [Failure (exceptf x)] for the first element [x] that occurs more
     than once in [list]. The list is sorted first so that equal elements
     become adjacent; the input list itself is left untouched. *)
  let report_duplicate exceptf list =
    let rec scan = function
      | [] | [_] -> ()
      | x :: (y :: _ as rest) ->
        if x = y then raise (Failure (exceptf x)) else scan rest
    in
    scan (List.sort compare list)
  in

  (* Reject a binding whose declared type is Void: raises
     [Failure (exceptf name)] where [name] is the binding's identifier.
     Any other type is accepted and yields unit. *)
  let check_not_void exceptf (declared, name, _, _) =
    match declared with
    | Void -> raise (Failure (exceptf name))
    | _ -> ()
  in

  (* Peel one level of pointer off a type: Pointer(t) becomes t, while any
     non-pointer type is returned unchanged *)
  let strip = function
    | Pointer inner -> inner
    | other -> other
  in

  (* Raise the exception [err] if the given rvalue type cannot be assigned to
     the given lvalue type; otherwise return the lvalue type.
     Two types are assignable only when they are structurally equal.
     Plain [=] is used rather than [Pervasives.(=)]: the [Pervasives] module
     is deprecated and no longer exists in recent OCaml releases. *)
  let check_assign lvaluet rvaluet err =
    if lvaluet = rvaluet then lvaluet else raise err
  in

  (**** Checking Global Variables ****)

  (* No global may have type void, and global names must be unique *)
  let global_names = List.map (fun (_, name, _, _) -> name) program.globals in
  List.iter (check_not_void (fun n -> "illegal void global " ^ n)) program.globals;
  report_duplicate (fun n -> "duplicate global " ^ n) global_names;

  (**** Checking program.functions ****)

  (* "print" is reserved, and function names must be unique *)
  let function_names = List.map (fun fd -> fd.fname) program.functions in
  if List.mem "print" function_names
  then raise (Failure ("function print may not be defined"));
  report_duplicate (fun n -> "duplicate function " ^ n) function_names;

  (* Declarations of the built-in functions, keyed by function name.
     Every built-in takes exactly one formal named "x" and has no locals
     and an empty body.
     "atoi" is declared as returning Int so that this table agrees with the
     BuiltInCall case of the expression checker, which types an atoi call
     as Int. *)
  let built_in_decls =
    (* Build the declaration of a one-argument built-in *)
    let builtin ret name arg =
      { typ = ret; fname = name; formals = [(arg, "x", Noexpr, false)];
        locals = []; body = [] }
    in
    List.fold_left (fun m fd -> StringMap.add fd.fname fd m) StringMap.empty
      [ builtin Void "print" Int;
        builtin Void "printb" Bool;
        builtin Void "printf" String;
        builtin (Pointer(Void)) "malloc" Size_t;
        builtin Void "free" (Pointer(Void));
        builtin Int "atoi" (Pointer(Char));
        builtin (Pointer(Char)) "strdup" (Pointer(Char));
        builtin Void "printbig" Int ]
  in

  (* Table of the program's struct declarations, keyed by struct name.
     When two structs share a name, the one declared later wins. *)
  let struct_decls =
    let register table sd = StringMap.add sd.struct_name sd table in
    List.fold_left register StringMap.empty program.structs
  in

  (* Struct declaration for a named struct; raises Failure when no struct
     of that name was declared *)
  let struct_decl s =
    if StringMap.mem s struct_decls then StringMap.find s struct_decls
    else raise (Failure ("SEMANT: unrecognized struct " ^ s))
  in

  (* Table of every known function: the built-ins, extended with the
     functions defined in the program (a program function replaces a
     built-in entry of the same name) *)
  let function_decls =
    let register table fd = StringMap.add fd.fname fd table in
    List.fold_left register built_in_decls program.functions
  in

  (* Function declaration for a named function; raises Failure when the
     name is neither a built-in nor defined by the program *)
  let function_decl s =
    if StringMap.mem s function_decls then StringMap.find s function_decls
    else raise (Failure ("unrecognized function " ^ s))
  in

  (* Ensure "main" is defined: the lookup raises if it is missing *)
  ignore (function_decl "main");

  let check_function func =

    (* Formals are validated before locals; within each group the void
       check runs before the duplicate-name check *)
    let where = " in " ^ func.fname in
    let validate void_msg dup_msg bindings =
      List.iter (check_not_void (fun n -> void_msg ^ n ^ where)) bindings;
      report_duplicate (fun n -> dup_msg ^ n ^ where)
        (List.map (fun (_, name, _, _) -> name) bindings)
    in
    validate "illegal void formal " "duplicate formal " func.formals;
    validate "illegal void local " "duplicate local " func.locals;

    (* Type of each variable (global, formal, or local). Bindings are added
       in that order, so a local shadows a formal of the same name and a
       formal shadows a global. *)
    let symbols =
      let bind table (typ, name, _, _) = StringMap.add name typ table in
      let with_globals = List.fold_left bind StringMap.empty program.globals in
      let with_formals = List.fold_left bind with_globals func.formals in
      List.fold_left bind with_formals func.locals
    in

    (* Declared type of the variable named [s]; raises Failure when [s] is
       not a global, formal, or local visible in this function *)
    let type_of_identifier s =
      if StringMap.mem s symbols then StringMap.find s symbols
      else raise (Failure ("undeclared identifier " ^ s))
    in

    (* Return the type of an expression or throw an exception.
       Every rejection is reported by raising Failure with a message that
       names the offending expression or type. Lookups of unknown
       identifiers, functions and structs raise through the helper
       functions used below. *)
    let rec expr = function
        Literal _ -> Int
      | CharLit _ -> Char
      | SizeLit _ -> Size_t
      | StringLit _ -> String
      | BoolLit _ -> Bool
      | Id s -> type_of_identifier s
      | Sizeof _ -> Size_t
      | Binop(e1, op, e2) as e ->
        let t1 = expr e1 and t2 = expr e2 in
        let err = (Failure ("illegal binary operator " ^
                            string_of_typ t1 ^ " " ^ string_of_op op ^ " " ^
                            string_of_typ t2 ^ " in " ^ string_of_expr e)) in
        (match op with
           (* Adding or subtracting an Int keeps the left operand's type,
              which must be a pointer or an Int *)
           Add | Sub when t2 = Int ->
             (match t1 with
                Pointer(t) -> Pointer(t)
              | Int -> Int
              | _ -> raise err
             )
         | Mult | Div | Mod when t1 = Int && t2 = Int -> Int
           (* Comparing anything against the null expression is accepted *)
         | Equal | Neq when ((t1 = t2) || (e2 = Nullexpr)) -> Bool
         | Less | Leq | Greater | Geq when t1 = Int && t2 = Int -> Bool
         | And | Or when t1 = Bool && t2 = Bool -> Bool
         | _ -> raise err
        )
      | Unop(op, e) as ex -> let t = expr e in
        (match op with
           Neg when t = Int -> Int
         | Not when t = Bool -> Bool
         | _ -> raise (Failure ("illegal unary operator " ^ string_of_uop op ^
                                string_of_typ t ^ " in " ^ string_of_expr ex)))
      | Address(e) -> let t = expr e in (match t with
            Int -> Pointer(Int)
          | Bool -> Pointer(Bool)
          | Pointer(Int) -> Pointer(Pointer(Int))
          | Pointer(Bool) -> Pointer(Pointer(Bool))
          | _ -> raise (Failure ("trying to get address of expr " ^ string_of_expr e ^ " but it has illegal type " ^ string_of_typ t))
        )
      | Dereference(e) -> let t = expr e in (match t with
            Pointer(Int) -> Int
          | Pointer(Bool) -> Bool
          | Pointer(Char) -> Char
            (* A pointer to a char pointer dereferences to String *)
          | Pointer(Pointer(t)) -> if t = Char then String else Pointer(t)
          | Pointer(Struct(n)) -> Struct(n)
          | _ -> raise (Failure ("trying to get dereference of non-pointer " ^ string_of_expr e ^ " of type " ^ string_of_typ t))
        )
      | Noexpr -> Void
      | Nullexpr -> Void
      | Assign(lhs, op, rhs) as ex ->
        (* Pointer(Char) and String are interchangeable in assignments.
           Each operand is type-checked exactly once: checking it twice
           doubles the work at every level of a nested assignment. *)
        let as_string t = if t = Pointer(Char) then String else t in
        let rt = as_string (expr rhs)
        and lt = as_string (expr lhs) in
        let err a b = (Failure ("illegal assignment " ^ string_of_typ a ^ string_of_aop op ^ string_of_typ b ^ " in " ^ string_of_expr ex)) in
        (match op with
           (* An array access has the type of the array itself (see the
              ArrayAccess case), so one pointer level is stripped on the
              side that is an array access *)
           Asn -> (match lhs with
               ArrayAccess(_,_) -> check_assign (strip lt) rt (err (strip lt) rt)
             | _ -> (match rhs with
                   ArrayAccess(_,_) -> check_assign lt (strip rt) (err lt (strip rt))
                 | _ -> check_assign lt rt (err lt rt)
               )
           )
         | ModAsn -> if (lt = Bool) || (rt = Bool) then raise (err lt rt) else check_assign lt rt (err lt rt))

      | BuiltInCall(fname, actuals) as bcall ->
        let fd = function_decl fname in
        if List.length actuals <> List.length fd.formals then
          raise (Failure ("expecting " ^ string_of_int
                            (List.length fd.formals) ^ " arguments in " ^ string_of_expr bcall))
        else  (match fname with
              "free" -> (* HACK: For the built-in function "free" we need to coerce the pointers to void* *)
                  (match expr (List.hd actuals) with
                     Pointer(_) -> Void
                   | t -> raise (Failure ("built-in function free(): attempting to free non-pointer " ^ string_of_typ t ^ "."))
                ) (* end match expr *)
            | "malloc" -> let err = (Failure ("SEMANT: improper arguments to malloc-call " ^ string_of_expr bcall ^ ".")) in
              (match (List.hd actuals) with
                 Binop(e1, Mult, e2) -> (match (expr e1, expr e2) with
                     ((Int, Size_t) | (Size_t, Int)) -> Pointer(Void)
                   | _ -> raise err) (* FIXME: need to extract actual pointer type *)
                | Sizeof(t) -> Pointer(t)
                | _ -> raise err) (* end match expr *)
            | "atoi" as func ->
              let input = expr (List.hd actuals) in
              if input = Pointer(Char) then Int else raise (Failure ("SEMANT: improper input type " ^ string_of_typ input ^ " to " ^ func))
            | "strdup" as func ->
              let input = expr (List.hd actuals) in
              if input = Pointer(Char) then input else raise (Failure ("SEMANT: improper input type " ^ string_of_typ input ^ " to " ^ func))
        (* FIXME: Add ANY Semantic checking for the rest of the builtin functions *)
            (* The arguments of the print family are not type-checked here *)
            | "print" -> Void
            | "printb" -> Void
            | "printf" -> Void
            | "printbig" -> Void
            | n -> raise (Failure("SEMANT: Illegal built-in name " ^ n ^ "."))
          ) (* end match fname *)
      | Call(fname, actuals) as call ->
        (* FIXME: A.string_of_expr will only work for non-compound functions, will break for function pointers *)
         let fname' = string_of_expr fname in
         let fd = function_decl fname' in
         if List.length actuals <> List.length fd.formals then
           raise (Failure ("expecting " ^ string_of_int
             (List.length fd.formals) ^ " arguments in " ^ string_of_expr call))
         else
           (* Each actual must have exactly the type of its formal *)
           List.iter2 (fun (ft, _, _, _) e -> let et = expr e in
              ignore (check_assign ft et
                (Failure ("illegal actual argument found " ^ string_of_typ et ^
                " expected " ^ string_of_typ ft ^ " in " ^ string_of_expr e))))
             fd.formals actuals;
         fd.typ
      | ArrayAccess(arr, _) ->
            (* This is C, we are fortunate not to have to do boundary checking *)
            (* FIXME: Add some actual semantic checking here *)
            (* The index expression is not checked; the result is the type
               of the array expression itself *)
            expr arr
      | StructAccess(s, m) as sacc ->
        let sd = struct_decl (string_of_typ (strip (strip (expr s))))in
          let members = List.fold_left (fun m (t,n,_,_) -> StringMap.add n t m) StringMap.empty sd.members in
          (* Iterate through the members of the struct; if name found, return its type, else fail *)
          (try StringMap.find m members
           with Not_found -> raise (Failure ("illegal member " ^ m ^ " of struct " ^ string_of_expr sacc)))
      | StructPointerAccess(s, m) as spacc ->
        let sd = (match s with
            | ArrayAccess(a,_) -> struct_decl (string_of_typ (strip (expr (Dereference a))))
            | _ -> struct_decl (string_of_typ (strip (expr s)))
          ) in
          let members = List.fold_left (fun m (t,n,_,_) -> StringMap.add n t m) StringMap.empty sd.members in
          (* Iterate through the members of the struct; if name found, return its type, else fail *)
          (try StringMap.find m members
           with Not_found -> raise (Failure ("illegal member " ^ m ^ " of struct " ^ string_of_expr spacc)))
      | BuildArray(_,_) -> Pointer(Void)
        (* The operand of a cast must itself type-check; the result is the
           target type *)
      | Cast(t, e) -> ignore( expr e ) ; t
      (* | _ as ex -> raise (Failure ("unknown expression " ^ (string_of_expr ex) ^ ".")) *)

    in (* END expr CHECKER *)

    (* Require [e] to have type Bool; raises Failure otherwise (and
       propagates any failure raised while typing [e] itself).
       The comparison uses structural inequality [<>]: physical inequality
       [!=] is not a reliable way to compare type values. *)
    let check_bool_expr e = if expr e <> Bool
     then raise (Failure ("expected Boolean expression in " ^ string_of_expr e))
     else () in

    (* Verify a statement or throw an exception *)
    let rec stmt s =
      match s with
      | Block body -> check_block body
      | Expr e -> ignore (expr e)
      | Return e ->
        (* The returned expression must have exactly the function's type *)
        let actual = expr e in
        if actual <> func.typ then
          raise (Failure ("return gives " ^ string_of_typ actual ^ " expected " ^
                          string_of_typ func.typ ^ " in " ^ string_of_expr e))
      | If (cond, then_branch, else_branch) ->
        check_bool_expr cond;
        stmt then_branch;
        stmt else_branch
      | For (init, cond, step, body) ->
        ignore (expr init);
        check_bool_expr cond;
        ignore (expr step);
        stmt body
      | While (cond, body) ->
        check_bool_expr cond;
        stmt body

    (* Verify a statement list: a return is only allowed as the last
       statement, and nested blocks are flattened into the enclosing list *)
    and check_block stmts =
      match stmts with
      | [] -> ()
      | [Return _ as last] -> stmt last
      | Return _ :: _ -> raise (Failure "nothing may follow a return")
      | Block inner :: rest -> check_block (inner @ rest)
      | first :: rest -> stmt first; check_block rest
    in

    (* The function body is checked as one block *)
    stmt (Block func.body)

  in
  (* Check each function in program.functions, in order; the built-in
     declarations are not part of that list and are not checked *)
  List.iter check_function program.functions
