(* Semantic checking for the M/s compiler *)

open Ast

(* Finite maps keyed by strings; used below for the function table, the
   struct table, and each struct's field-name -> field-type table. *)
module StringMap = Map.Make(String)


(* [check (functions, struct_decls)] is the semantic checker for an M/s
   program.  It returns [()] when the program is well formed and otherwise
   raises [Failure msg] describing the first error it finds.

   Global checks: no user function may be named print, printd, printb or
   prints; function names, struct names, and the field names within each
   struct must be unique; a function named "master" must exist.

   Per-function checks: formals are non-void and distinct; locals are
   non-void, are not redeclared within the function, and only use declared
   struct types; expressions, assignments, calls, returns and conditions are
   well typed; nothing follows a return in a block; vectors and strings are
   not declared inside for/while loop bodies. *)
let check (functions, struct_decls) =

  (* Raise [Failure (exceptf n)] for the smallest element [n] that occurs
     more than once in [list]; do nothing if all elements are distinct. *)
  let report_duplicate exceptf list =
    let rec helper = function
      | n1 :: n2 :: _ when n1 = n2 -> raise (Failure (exceptf n1))
      | _ :: t -> helper t
      | [] -> ()
    in
    helper (List.sort compare list)
  in

  (* Struct names, and the field names of each struct, must be unique;
     otherwise later bindings would silently shadow earlier ones in the
     maps built below. *)
  report_duplicate (fun n -> "duplicate struct " ^ n)
    (List.map (fun sd -> sd.sname) struct_decls);
  List.iter
    (fun sd ->
       report_duplicate
         (fun n -> "duplicate field " ^ n ^ " in struct " ^ sd.sname)
         (List.map snd sd.blist))
    struct_decls;

  (* Field table of one struct: field name -> field type. *)
  let fields_of_struct fields =
    List.fold_left (fun m (typ, name) -> StringMap.add name typ m)
      StringMap.empty fields
  in

  (* Struct table: struct name -> field table. *)
  let global_map =
    List.fold_left
      (fun m sd -> StringMap.add sd.sname (fields_of_struct sd.blist) m)
      StringMap.empty struct_decls
  in

  (* Field table of struct [s], or [Failure] if no such struct exists. *)
  let struct_decl s =
    try StringMap.find s global_map
    with Not_found -> raise (Failure ("undeclared struct " ^ s))
  in

  (* Raise [Failure (exceptf n)] if the binding [(t, n)] has type Void. *)
  let check_not_void exceptf = function
    | (Void, n) -> raise (Failure (exceptf n))
    | _ -> ()
  in

  (* Return [lvaluet] if a value of type [rvaluet] may be assigned to it,
     otherwise raise [err].  Only identical types are compatible. *)
  let check_assign lvaluet rvaluet err =
    if lvaluet = rvaluet then lvaluet else raise err
  in

  (* Declarations of the built-in functions, listed in the order in which
     user redefinitions are reported. *)
  let built_ins =
    [ { typ = Void; fname = "print"; formals = [(Int, "x")]; body = [] };
      { typ = Void; fname = "printd"; formals = [(Double, "x")]; body = [] };
      { typ = Void; fname = "printb"; formals = [(Bool, "x")]; body = [] };
      { typ = Void; fname = "prints"; formals = [(Vector(Char), "x")];
        body = [] } ]
  in
  let function_names = List.map (fun fd -> fd.fname) functions in

  (* User programs may not redefine a built-in. *)
  List.iter
    (fun bi ->
       if List.mem bi.fname function_names then
         raise (Failure ("function " ^ bi.fname ^ " may not be defined")))
    built_ins;

  report_duplicate (fun n -> "duplicate function " ^ n) function_names;

  (* Function table: built-ins plus every user function, keyed by name.
     Names are distinct at this point, so insertion order is irrelevant. *)
  let function_decls =
    List.fold_left (fun m fd -> StringMap.add fd.fname fd m)
      StringMap.empty (built_ins @ functions)
  in

  let function_decl s =
    try StringMap.find s function_decls
    with Not_found -> raise (Failure ("unrecognized function " ^ s))
  in

  (* Every program must define an entry point named "master". *)
  ignore (function_decl "master");

  (* Check the formals and body of a single function declaration. *)
  let check_function func =
    List.iter
      (check_not_void
         (fun n -> "illegal void formal " ^ n ^ " in " ^ func.fname))
      func.formals;
    report_duplicate (fun n -> "duplicate formal " ^ n ^ " in " ^ func.fname)
      (List.map snd func.formals);

    (* Types of the formals and locals of [func].  The table is flat: a
       local stays visible for the rest of the function once declared,
       whichever block it was declared in. *)
    let symbols = Hashtbl.create 5000 in
    List.iter (fun (t, n) -> Hashtbl.replace symbols n t) func.formals;

    let type_of_identifier s =
      try Hashtbl.find symbols s
      with Not_found -> raise (Failure ("undeclared identifier " ^ s))
    in

    (* Require identifier [job] to have a Job type; raise [Failure msg]
       otherwise. *)
    let require_job job msg =
      match type_of_identifier job with
      | Job _ -> ()
      | _ -> raise (Failure msg)
    in

    (* [check_call call fname actuals] checks that [fname] is declared and
       that [actuals] match its formals in number and type, returning the
       callee's declaration.  [call] is the whole call, used in messages. *)
    let rec check_call call fname actuals =
      let fd = function_decl fname in
      let arity = List.length fd.formals in
      if List.length actuals <> arity then
        raise (Failure ("expecting " ^ string_of_int arity ^
                        " arguments in " ^ string_of_expr call));
      List.iter2
        (fun (ft, _) actual ->
           let et = expr actual in
           ignore (check_assign ft et
                     (Failure ("illegal actual argument found " ^
                               string_of_typ et ^ " expected " ^
                               string_of_typ ft ^ " in " ^
                               string_of_expr actual))))
        fd.formals actuals;
      fd

    (* Return the type of expression [e], or raise [Failure]. *)
    and expr e =
      match e with
      | Literal _ -> Int
      | DoubleLit _ -> Double
      | StringLit _ -> Vector(Char)
      | BoolLit _ -> Bool
      | Id s -> type_of_identifier s
      | Binop (e1, op, e2) ->
          let t1 = expr e1 in
          let t2 = expr e2 in
          (match op with
           | Add | Sub | Mult | Div when t1 = Int && t2 = Int -> Int
           | Add | Sub | Mult | Div when t1 = Double && t2 = Double -> Double
           | Equal | Neq when t1 = t2 -> Bool
           | Less | Leq | Greater | Geq
             when (t1 = Int && t2 = Int) || (t1 = Double && t2 = Double) ->
               Bool
           | And | Or when t1 = Bool && t2 = Bool -> Bool
           | _ ->
               raise (Failure ("illegal binary operator " ^
                               string_of_typ t1 ^ " " ^ string_of_op op ^
                               " " ^ string_of_typ t2 ^ " in " ^
                               string_of_expr e)))
      | Unop (op, operand) ->
          let t = expr operand in
          (match op with
           | Neg when t = Int -> Int
           | Neg when t = Double -> Double
           | Not when t = Bool -> Bool
           | _ ->
               raise (Failure ("illegal unary operator " ^ string_of_uop op ^
                               string_of_typ t ^ " in " ^ string_of_expr e)))
      | Noexpr -> Void
      | Assign (var, rhs) ->
          let lt = type_of_identifier var in
          let rt = expr rhs in
          check_assign lt rt
            (Failure ("illegal assignment " ^ string_of_typ lt ^ " = " ^
                      string_of_typ rt ^ " in " ^ string_of_expr e))
      | RemoteCall (fname, actuals) ->
          (* A remote call yields a job wrapping the callee's return type. *)
          let fd = check_call e fname actuals in
          Job fd.typ
      | Call (fname, actuals) ->
          let fd = check_call e fname actuals in
          fd.typ
      | Get job ->
          (match type_of_identifier job with
           | Job t -> t
           | _ -> raise (Failure "Getting something not a job"))
      | Cancel job ->
          require_job job "Cancelling something not a job";
          Void
      | Running job ->
          require_job job "Checking running of something not a job";
          Bool
      | Finished job ->
          require_job job "Checking finished of something not a job";
          Bool
      | Failed job ->
          require_job job "Checking failed of something not a job";
          Bool
      | VectorAccess (id, indices) ->
          (* Each index peels one Vector layer off the identifier's type,
             and every index expression must itself be an int. *)
          let rec step t = function
            | [] -> t
            | idx :: rest ->
                (match t with
                 | Vector elt ->
                     let it = expr idx in
                     if it <> Int then
                       raise (Failure ("vector index must be int, found " ^
                                       string_of_typ it ^ " in " ^
                                       string_of_expr idx));
                     step elt rest
                 | _ ->
                     raise (Failure "Vector access of something not a vector"))
          in
          step (type_of_identifier id) indices
      | VectorAssign (id, indices, rhs) ->
          let lt = expr (VectorAccess (id, indices)) in
          let rt = expr rhs in
          check_assign lt rt
            (Failure ("illegal vector assignment found " ^ string_of_typ rt ^
                      " expected " ^ string_of_typ lt))
      (* NOTE(review): the operand of VectorSize is not inspected, so its
         type (and whether it is even declared) is never verified here. *)
      | VectorSize _ -> Int
      | StructFieldAccess (target, fieldname) ->
          (match expr target with
           | Struct struct_name ->
               let fields = struct_decl struct_name in
               (try StringMap.find fieldname fields
                with Not_found ->
                  raise (Failure ("cannot find fieldname: " ^ fieldname)))
           | _ -> raise (Failure "accessing field of something not a struct"))
      | StructFieldAssign (target, fieldname, rhs) ->
          let lt = expr (StructFieldAccess (target, fieldname)) in
          let rt = expr rhs in
          check_assign lt rt (Failure "illegal struct field assignment")
      | Concat (e1, e2) ->
          (* Concatenation is only defined on vectors (strings are
             Vector(Char)), and both operands must have the same type. *)
          let lt = expr e1 in
          let rt = expr e2 in
          (match lt with
           | Vector _ -> check_assign lt rt (Failure "illegal string concat")
           | _ -> raise (Failure "illegal string concat"))
      | _ -> raise (Failure "unrecognized expr")
    in

    let check_bool_expr e =
      if expr e <> Bool then
        raise (Failure ("expected Boolean expression in " ^ string_of_expr e))
    in

    (* Checks shared by both declaration forms for a local [id] of type
       [typ]: no vectors/strings inside a loop body, no redeclaration in
       this function, any struct type must exist, and no void locals.
       Does not add [id] to [symbols]. *)
    let check_new_local inside_loop typ id =
      (match typ with
       | Vector _ when inside_loop ->
           raise (Failure "cannot declare vectors or strings inside for loop")
       | _ -> ());
      if Hashtbl.mem symbols id then
        raise (Failure "duplicate local variable declarations");
      (match typ with
       | Struct name -> ignore (struct_decl name)
       | _ -> ());
      check_not_void
        (fun n -> "illegal void local " ^ n ^ " in " ^ func.fname) (typ, id)
    in

    (* Check one statement; [inside_loop] is true within for/while bodies. *)
    let rec stmt inside_loop = function
      | Block sl ->
          (* Nested blocks are spliced into the enclosing list, so "nothing
             may follow a return" also applies across block boundaries. *)
          let rec check_block = function
            | [ (Return _ as s) ] -> stmt inside_loop s
            | Return _ :: _ -> raise (Failure "nothing may follow a return")
            | Block sl :: ss -> check_block (sl @ ss)
            | s :: ss -> stmt inside_loop s; check_block ss
            | [] -> ()
          in
          check_block sl
      | VarDecl (typ, id) ->
          check_new_local inside_loop typ id;
          Hashtbl.replace symbols id typ
      | VarDeclAssign (typ, id, e) ->
          check_new_local inside_loop typ id;
          (* Type the initializer before binding [id], so that an
             initializer mentioning [id] itself is rejected. *)
          let rt = expr e in
          ignore (check_assign typ rt
                    (Failure ("illegal assignment " ^ string_of_typ typ ^
                              " = " ^ string_of_typ rt)));
          Hashtbl.replace symbols id typ
      | Expr e -> ignore (expr e)
      | Return e ->
          let t = expr e in
          if t <> func.typ then
            raise (Failure ("return gives " ^ string_of_typ t ^ " expected " ^
                            string_of_typ func.typ ^ " in " ^
                            string_of_expr e))
      | If (p, b1, b2) ->
          check_bool_expr p;
          stmt inside_loop b1;
          stmt inside_loop b2
      | For (e1, e2, e3, st) ->
          ignore (expr e1);
          check_bool_expr e2;
          ignore (expr e3);
          stmt true st
      | While (p, s) ->
          check_bool_expr p;
          stmt true s
      | VectorPushBack (v, e) ->
          (match expr v with
           | Vector lt ->
               let rt = expr e in
               ignore (check_assign lt rt
                         (Failure ("Illegal vector pushback found " ^
                                   string_of_typ rt ^ " expected " ^
                                   string_of_typ lt)))
           | _ -> raise (Failure "pushing back to something not a vector"))
    in
    stmt false (Block func.body)
  in
  List.iter check_function functions
