(* Semantic checking for the GoBackwards compiler *)

open Ast

module StringMap = Map.Make(String)


(* Semantic checking of a program. Returns unit if successful,
   raises a Failure exception if something is wrong.

   Check each global variable, then check each function *)

let check (globals, functions) =

  (* Raise Failure, with the message built by exceptf, if the given list
     contains the same element twice.  The list is sorted first so that
     equal elements end up adjacent; the first duplicate in sorted order
     is the one reported. *)
  let report_duplicate exceptf list =
    let rec scan = function
      | [] | [_] -> ()
      | first :: (second :: _ as rest) ->
          if first = second then raise (Failure (exceptf first))
          else scan rest
    in
    scan (List.sort compare list)
  in

  (* Reject a (type, name) binding whose type is Void: raises Failure with
     the message exceptf builds from the bound name.  Any other binding is
     accepted. *)
  let check_not_void exceptf binding =
    match binding with
    | (Void, name) -> raise (Failure (exceptf name))
    | _ -> ()
  in

  (* Raise the exception err if the given rvalue type cannot be assigned to
     the given lvalue type; otherwise return the lvalue type.

     The comparison is structural (=).  Physical equality (==) only happens
     to work for constant constructors; two separately built ArrayType
     values describing the same type would never compare equal with it. *)
  let check_assign lvaluet rvaluet err =
    if lvaluet = rvaluet then lvaluet else raise err
  in

  (**** Checking Global Variables ****)

  (* No global may be declared with type void. *)
  let void_global_msg name = "illegal void global " ^ name in
  List.iter (check_not_void void_global_msg) globals;

  (* Global names must be pairwise distinct. *)
  let duplicate_global_msg name = "duplicate global " ^ name in
  report_duplicate duplicate_global_msg (List.map snd globals);

  (**** Checking Functions ****)

  (* Names of all user-defined functions, in declaration order. *)
  let declared_names = List.map (fun fd -> fd.fname) functions in

  (* "print" and "ascii" are reserved and may not be redefined.
     NOTE(review): "println" and "printb" are also built in but are not
     rejected here -- confirm whether that is intentional. *)
  List.iter
    (fun reserved ->
       if List.mem reserved declared_names
       then raise (Failure ("function " ^ reserved ^ " may not be defined")))
    ["print"; "ascii"];

  (* No two user-defined functions may share a name. *)
  report_duplicate (fun n -> "duplicate function " ^ n) declared_names;

  (* Declarations of the built-in functions, keyed by name.  Every built-in
     returns Void, takes a single formal named "x" of the listed type, and
     has an empty body. *)
  let built_in_decls =
    let declare table (name, arg_typ) =
      StringMap.add name
        { fname = name;
          signature = { ret_typ = Void; formals = [(arg_typ, "x")] };
          body = { locals = []; stmts = [] } }
        table
    in
    List.fold_left declare StringMap.empty
      [ ("ascii", String);
        ("printb", Bool);
        ("println", String);
        ("print", Int) ]

    in

    (* Every callable function, keyed by name: the built-ins plus each
       user-defined function.  A user definition is added after the
       built-ins, so it replaces a built-in entry with the same name. *)
    let function_decls =
      let register table fd = StringMap.add fd.fname fd table in
      List.fold_left register built_in_decls functions
  in

  (* Look up the declaration of the function named s.
     Raises Failure "unrecognized function <s>" if no such function exists. *)
  let function_decl s =
    if StringMap.mem s function_decls then StringMap.find s function_decls
    else raise (Failure ("unrecognized function " ^ s))
  in

  (* Ensure "main" is defined: the lookup raises if it is missing. *)
  ignore (function_decl "main");

  let check_function func =

    (* Formals: none may be void.  The check must run over the formals
       (not the locals) and name the enclosing function in its message. *)
    List.iter (check_not_void (fun n -> "illegal void formal " ^ n ^
      " in " ^ func.fname)) func.signature.formals;

    (* Formals: names must be pairwise distinct. *)
    report_duplicate (fun n -> "duplicate formal " ^ n ^ " in " ^ func.fname)
      (List.map snd func.signature.formals);

    (* Locals: none may be void. *)
    List.iter (check_not_void (fun n -> "illegal void local " ^ n ^
      " in " ^ func.fname)) func.body.locals;

    (* Locals: names must be pairwise distinct. *)
    report_duplicate (fun n -> "duplicate local " ^ n ^ " in " ^ func.fname)
      (List.map snd func.body.locals);

    (* Type of each variable (global, formal, or local), keyed by name.
       Bindings are added in that order, so a local replaces a formal of the
       same name, and either replaces a global. *)
    let symbols =
      let bind table (ty, name) = StringMap.add name ty table in
      List.fold_left bind StringMap.empty
        (globals @ func.signature.formals @ func.body.locals)
    in

    (* Type of the variable named s in the current function's scope.
       Raises Failure "undeclared identifier <s>" if s is not bound. *)
    let type_of_identifier s =
      if StringMap.mem s symbols then StringMap.find s symbols
      else raise (Failure ("undeclared identifier " ^ s))
    in

    (* Element type of an array type.  Raises Failure if the given type is
       not an ArrayType. *)
    let check_array_access array_typ =
      match array_typ with
      | ArrayType (elem_typ, _) -> elem_typ
      | _ -> raise (Failure "illegal attempt at accessing array")
    in

    (* Return the type of an expression, or raise Failure describing the
       typing error that was found.

       - Literals have their obvious type; an identifier has its declared type.
       - Arithmetic and ordering operators need two Int operands; And/Or need
         two Bool operands; Equal/Neq need operands of the same type.
       - An array access needs an Int subscript and yields the element type.
       - An assignment yields the type of its left-hand side.
       - A call yields the declared return type of the callee. *)
    let rec expr = function
      | Literal _ -> Int
      | Strlit _ -> String
      | BoolLit _ -> Bool
      | Id s -> type_of_identifier s
      | Binop(e1, op, e2) as e ->
          let t1 = expr e1 and t2 = expr e2 in
          (match op with
           | Add | Sub | Mult | Div when t1 = Int && t2 = Int -> Int
           | Equal | Neq when t1 = t2 -> Bool
           | Less | Leq | Greater | Geq when t1 = Int && t2 = Int -> Bool
           | And | Or when t1 = Bool && t2 = Bool -> Bool
           | _ ->
               raise (Failure ("illegal binary operator " ^
                 string_of_typ t1 ^ " " ^ string_of_op op ^ " " ^
                 string_of_typ t2 ^ " in " ^ string_of_expr e)))
      | Unop(op, e) as ex ->
          let t = expr e in
          (match op with
           | Neg when t = Int -> Int
           | Not when t = Bool -> Bool
           | _ ->
               raise (Failure ("illegal unary operator " ^ string_of_uop op ^
                 string_of_typ t ^ " in " ^ string_of_expr ex)))
      | AccessArray(name, index) ->
          (* The subscript is checked before the array itself is looked up. *)
          (match expr index with
           | Int -> ()
           | _ ->
               raise (Failure "arrays can only be accessed through integers"));
          check_array_access (type_of_identifier name)
      | Noexpr -> Void
      | Assign(e1, e2) as ex ->
          let lt =
            (match e1 with
             | AccessArray(name, index) ->
                 (* Only arrays of Int, Bool or String may be assigned into. *)
                 let elem_t =
                   (match type_of_identifier name with
                    | ArrayType(t, _) ->
                        (match t with
                         | Int | Bool | String -> t
                         | _ -> raise (Failure "illegal array assignment"))
                    | _ -> raise (Failure "illegal left assignment"))
                 in
                 (* The subscript of an assigned element must be an Int, just
                    as it must be when the element is read. *)
                 (match expr index with
                  | Int -> elem_t
                  | _ ->
                      raise (Failure
                        "arrays can only be accessed through integers"))
             | _ -> expr e1)
          and rt = expr e2 in
          check_assign lt rt
            (Failure ("Illegal assignment " ^ string_of_typ lt ^ " = " ^
              string_of_typ rt ^ " in " ^ string_of_expr ex))
      | Call(fname, actuals) as call ->
          let fd = function_decl fname in
          let formals = fd.signature.formals in
          (* The arity is checked first so that List.iter2 below always sees
             lists of equal length. *)
          if List.length actuals <> List.length formals then
            raise (Failure ("expecting " ^ string_of_int (List.length formals) ^
              " arguments in " ^ string_of_expr call))
          else
            List.iter2
              (fun (ft, _) e ->
                 let et = expr e in
                 ignore (check_assign ft et
                   (Failure ("illegal actual argument found " ^
                     string_of_typ et ^ " expected " ^ string_of_typ ft ^
                     " in " ^ string_of_expr e))))
              formals actuals;
          fd.signature.ret_typ
    in

    (* Require expression e to have type Bool; raises Failure otherwise.
       The type is compared structurally (<>), not physically (!=), in line
       with the other type comparisons in this checker. *)
    let check_bool_expr e =
      if expr e <> Bool then
        raise (Failure ("expected Boolean expression in " ^ string_of_expr e))
    in

    (* Verify a statement, raising Failure on the first problem found.
       Conditions of If, For and While must be Boolean, and a Return must
       give the enclosing function's declared return type. *)
    let rec stmt s =
      match s with
      | Block body ->
          (* Walk the statements in order.  A nested block is spliced into
             the remaining list, so a return inside it is still required to
             be the last statement overall. *)
          let rec walk remaining =
            match remaining with
            | [] -> ()
            | [(Return _ as last)] -> stmt last
            | Return _ :: _ -> raise (Failure "nothing may follow a return")
            | Block inner :: rest -> walk (inner @ rest)
            | first :: rest -> stmt first; walk rest
          in
          walk body
      | Expr e -> ignore (expr e)
      | Return e ->
          let actual = expr e in
          let expected = func.signature.ret_typ in
          if actual <> expected then
            raise (Failure ("return gives " ^ string_of_typ actual ^
                            " expected " ^ string_of_typ expected ^
                            " in " ^ string_of_expr e))
      | If(cond, then_branch, else_branch) ->
          check_bool_expr cond;
          stmt then_branch;
          stmt else_branch
      | For(init, cond, step, body) ->
          ignore (expr init);
          check_bool_expr cond;
          ignore (expr step);
          stmt body
      | While(cond, body) ->
          check_bool_expr cond;
          stmt body
    in

    (* The function body is checked as a single block, so the rule about
       statements following a return covers the whole body. *)
    let whole_body = Block func.body.stmts in
    stmt whole_body

  in
  (* Check every user-defined function, in declaration order. *)
  List.iter check_function functions
