(* Semantic checking *)

open Ast

exception Error of string

module StringMap = Map.Make(String)

(* Semantic checking of a program. Returns unit if successful,
    raises Failure if something is wrong.

    Checks each global variable, then each class, then each function *)

let check ((_, globals), (functions, classes)) =

  (* [report_duplicate exceptf list] raises [Failure (exceptf x)] for the
     smallest element [x] that occurs more than once in [list], and returns
     unit when every element is distinct. Sorting first places equal
     elements side by side, so one pass over adjacent pairs is enough. *)
  let report_duplicate exceptf list =
    let rec scan = function
      | first :: (second :: _ as rest) ->
          if first = second then raise (Failure (exceptf first)) else scan rest
      | [ _ ] | [] -> ()
    in
    scan (List.sort compare list)
  in

  (* [check_not_none exceptf binding] rejects a three-component binding
     (type, name, third component) whose type is [Ast.None] by raising
     [Failure (exceptf name)]. Any other binding is accepted and unit is
     returned. *)
  let check_not_none exceptf (t, name, _) =
    if t = Ast.None then raise (Failure (exceptf name))
  in

  (* Same check as for three-component bindings, but for (type, name) pairs
     such as function formals: raises [Failure (exceptf name)] when the type
     is [Ast.None], returns unit otherwise. *)
  let check_not_none_formals exceptf (t, name) =
    if t = Ast.None then raise (Failure (exceptf name))
  in
  
  (* [check_assign lvaluet rvaluet err] returns [lvaluet] when the two types
     are structurally equal, and raises [err] if the given rvalue type
     cannot be assigned to the given lvalue type. *)
  let check_assign lvaluet rvaluet err =
    if lvaluet <> rvaluet then raise err else lvaluet
  in

  (**** Checking Global Variables ****)
  (* No global may have NoneType, and global names must be pairwise
     distinct. The NoneType check runs first, so its error is the one
     reported when both problems are present. *)
  let global_names = List.map (fun (_, name, _) -> name) globals in
  List.iter
    (check_not_none (fun n -> "Illegal NoneType global variable " ^ n))
    globals;
  report_duplicate (fun n -> "Duplicate global " ^ n) global_names;

  (* TODO: Check assign on global variables *)

  

  (* Declarations of the built-in functions, keyed by name. Each built-in
     takes a single formal named "x" and has an empty body:
       print   : Int    -> None
       prints  : String -> None
       toint   : Float  -> Int
       tofloat : Int    -> Float
     A printb built-in (Bool -> None) previously existed here only as
     commented-out code and is not registered. *)
  let built_in_decls =
    let stub ret name arg =
      { typ = ret; fname = name; formals = [(arg, "x")];
        fbody = { f_vdecls = []; f_stmts = [] } }
    in
    List.fold_left
      (fun decls fd -> StringMap.add fd.fname fd decls)
      StringMap.empty
      [ stub Ast.None "prints" String;
        stub Ast.None "print" Int;
        stub Ast.Int "toint" Float;
        stub Ast.Float "tofloat" Int ]
  in
     
  (* Every callable function by name: the built-ins plus the user-defined
     [functions]. User functions are added on top of the built-in table, so
     a user function that reuses a built-in name replaces that entry. *)
  let function_decls =
    let register decls fd = StringMap.add fd.fname fd decls in
    List.fold_left register built_in_decls functions
  in

  (* [function_decl s] returns the declaration of the function named [s].
     Raises [Failure] with an "Unrecognized function" message when neither a
     built-in nor a user-defined function has that name. *)
  let function_decl s =
    if StringMap.mem s function_decls then StringMap.find s function_decls
    else raise (Failure ("Unrecognized function " ^ s))
  in

  (****    Check functions    ****)
  (* [check_function func] validates one function declaration and raises
     [Failure] on the first problem found. This first part checks the
     declarations, in this order: duplicate top-level function names,
     NoneType formals, duplicate formals, NoneType locals, duplicate locals.
     NOTE(review): the duplicate-function check inspects the top-level
     [functions] list and is repeated on every call, including calls made
     for class methods -- confirm whether it should run only once. *)
  let check_function func =
    let where = " in " ^ func.fname in
    let formal_names = List.map snd func.formals in
    let locals = func.fbody.f_vdecls in
    let local_names = List.map (fun (_, n, _) -> n) locals in

    report_duplicate (fun n -> "Duplicate function " ^ n)
      (List.map (fun fd -> fd.fname) functions);

    List.iter
      (check_not_none_formals (fun n -> "Illegal nonetype formal " ^ n ^ where))
      func.formals;

    report_duplicate (fun n -> "Duplicate formal " ^ n ^ where) formal_names;

    List.iter
      (check_not_none (fun n -> "Illegal NoneType local variable " ^ n ^ where))
      locals;

    report_duplicate (fun n -> "Duplicate local variable " ^ n ^ where)
      local_names;

    (* Type of each variable visible in the function body (global, formal,
       or local). Bindings are added in that order and a later binding
       replaces an earlier one, so a formal hides a global of the same name
       and a local hides both. *)
    let symbols =
      let add_decl table (t, n, _) = StringMap.add n t table in
      let add_formal table (t, n) = StringMap.add n t table in
      let globals_only = List.fold_left add_decl StringMap.empty globals in
      let with_formals = List.fold_left add_formal globals_only func.formals in
      List.fold_left add_decl with_formals func.fbody.f_vdecls
    in

    (* [type_of_identifier s] returns the declared type of the variable [s]
       from the symbol table, or raises [Failure] with an
       "Undeclared identifier" message when [s] is not bound. *)
    let type_of_identifier s =
      if StringMap.mem s symbols then StringMap.find s symbols
      else raise (Failure ("Undeclared identifier " ^ s))
    in

    (* [expr e] returns the type of expression [e] or raises [Failure] with a
       message describing the first typing error found.

       Each sub-expression is typed exactly once; re-typing the same
       sub-expression repeatedly would make deeply nested assignments and
       calls cost exponential time. Types are compared with structural
       equality. *)
    let rec expr = function
        IntLit _ -> Int
      | FloatLit _ -> Float
      | BoolLit _ -> Bool
      | StringLit _ -> String
      | Id s -> type_of_identifier s
      | TupleLit l as tuple ->
        (* The tuple type records a single element type and the length, so
           every element must be well typed and agree with the first one. *)
        (match l with
            [] -> raise (Failure ("Empty tuple literal has no element type"))
          | first :: rest ->
            let type_el = expr first in
            List.iter (fun el ->
                let t = expr el in
                if t <> type_el then
                  raise (Failure ("Tuple element " ^ string_of_expr el ^
                                  " has type " ^ string_of_typ t ^ " but " ^
                                  string_of_typ type_el ^ " was expected in " ^
                                  string_of_expr tuple)))
              rest;
            Tuple (type_el, List.length l))
      | Binop(e1, op, e2) as e -> let t1 = expr e1 and t2 = expr e2 in
        (match op with
            Add | Sub | Mult | Div | Mod when t1 = Int && t2 = Int -> Int
          | Add | Sub | Mult | Div | Mod when t1 = Float && t2 = Float -> Float
          (* String + String is accepted only when both operands are string
             literals. *)
          | Add when t1 = String && t2 = String ->
                (match (e1, e2) with
                   (Ast.StringLit(_), Ast.StringLit(_)) -> String
                 | _ -> raise (Failure ("Only raw strings can be concatenated")))
          | Equal | Neq when t1 = t2 -> Bool
          | Less | Leq | Greater | Geq when t1 = t2 -> Bool
          | And | Or when t1 = Bool && t2 = Bool -> Bool
          | _ -> raise (Failure ("Illegal binary operator " ^ string_of_typ t1 ^
                  " " ^ string_of_op op ^ " " ^ string_of_typ t2 ^ " in " ^
                  string_of_expr e))
          )
      | Unop(op, e) as ex -> let t = expr e in
        (match op with
            Neg when t = Int -> Int
          | Neg when t = Float -> Float
          | Not when t = Bool -> Bool
          | _ -> raise (Failure ("Illegal unary operator " ^ string_of_uop op ^
                    string_of_typ t ^ " in " ^ string_of_expr ex)))
      | Assign(var, e) as ex -> let lt = type_of_identifier var
                                and rt = expr e in
                  check_assign lt rt
                  (Failure ("Illegal assignment " ^ string_of_typ lt ^ " = " ^
                            string_of_typ rt ^ " in " ^ string_of_expr ex))
      | Call(fname, actuals) as call ->
          if fname = "print" then
            (* The print function can have different types of arguments, so
               its actuals are not inspected at all.
               NOTE(review): a print call is typed String here although the
               registered print built-in declares return type None, and
               errors inside its arguments go unnoticed -- confirm intent. *)
            String
          else begin
            let fd = function_decl fname in
            if List.length actuals <> List.length fd.formals then
              raise (Failure ("Expecting " ^ string_of_int
                (List.length fd.formals) ^ " arguments in " ^ string_of_expr call));
            List.iter2 (fun (ft, _) e ->
                (* A tuple actual is compared with its length reset to 0, so
                   the tuple length is ignored when matching a formal. *)
                let et = (match expr e with
                      Ast.Tuple(ty, _) -> Ast.Tuple(ty, 0)
                    | t -> t) in
                ignore (check_assign ft et
                  (Failure ("Illegal actual argument found " ^ string_of_typ et ^
                  " expected " ^ string_of_typ ft ^ " in " ^ string_of_expr e))))
              fd.formals actuals;
            fd.typ
          end
      | Element(_, el) as element ->
        (* Only the index expression is checked; the result is always typed
           Int. NOTE(review): the accessed value is not consulted -- confirm
           this is intended for non-Int element types. *)
        if expr el <> Int then
        raise (Failure ("Invalid element access in " ^ string_of_expr element))
        else Int
      | Noexpr -> Ast.None
    in

    (* [check_variable_assign (t, n, ex)] checks that the initializer [ex] of
       the declaration of [n] has the type recorded for [n] in the symbol
       table, raising [Failure] with an "Illegal assignment" message
       otherwise. The initializer is typed once and the result reused.
       NOTE(review): an initializer of Noexpr is typed as Ast.None, and
       NoneType locals are rejected earlier, so a declaration without an
       initializer appears to always fail here -- confirm this is intended. *)
    let check_variable_assign (t, n, ex) =
      let rt = expr ex in
      ignore (check_assign (type_of_identifier n) rt
                (Failure ("Illegal assignment " ^ string_of_typ t ^ " := " ^ string_of_typ rt ^
                  " in \'" ^ string_of_typ t ^ " " ^ n ^ " = " ^ string_of_expr ex ^ "\'")))
    in

    (* [check_bool_expr e] raises [Failure] unless [e] has type Bool. The
       type is compared structurally ([<>]) rather than by physical identity
       ([!=]), which is only reliable for constant constructors. *)
    let check_bool_expr e =
      if expr e <> Bool
      then raise (Failure ("Expected Boolean expression in " ^ string_of_expr e))
    in

    (* [stmt s] verifies statement [s], returning unit or raising [Failure]
       on the first error found. Each expression is typed once per check and
       types are compared structurally. *)
    let rec stmt = function
        Block sl ->
          (* Nested blocks are flattened into the enclosing statement list; a
             return is accepted only as the last statement of that list. *)
          let rec check_block = function
              [Return _ as s] -> stmt s
            | Return _ :: _ -> raise (Failure "nothing may follow a return")
            | Block sl :: ss -> check_block (sl @ ss)
            | s :: ss -> stmt s ; check_block ss
            | [] -> ()
          in check_block sl
      | Expr e -> ignore (expr e)
      | Return e ->
          (* A returned tuple is compared with its length reset to 0, so the
             tuple length is ignored when matching the declared return
             type. *)
          let t = (match expr e with
                Ast.Tuple(ty, _) -> Ast.Tuple(ty, 0)
              | t -> t) in
          if t <> func.typ then
            raise (Failure (func.fname ^ " returns " ^ string_of_typ t ^ " but " ^
                            string_of_typ func.typ ^ " was expected."))
      | If(p, b1, b2, b3) -> check_bool_expr p; stmt b1; stmt b2; stmt b3
      | Elif(p, b1) -> check_bool_expr p; stmt b1
      | For(e1, e2, e3, st) ->  ignore (expr e1); check_bool_expr e2;
                                ignore (expr e3); stmt st
      | ForIn(_, e2, st) ->
          (* Only the iterated expression is checked: it must be a tuple or a
             string. The first component of the loop is not inspected. *)
          (match expr e2 with
              Tuple(_, _) | String -> ()
            | _ -> raise (Failure ("Expected a tuple or a str in " ^ string_of_expr e2)));
          stmt st
      | While(p, s) -> check_bool_expr p; stmt s
      | HiddenWhile(e, s1, s2) -> ignore(expr e); stmt s1; stmt s2
      | In(e1, e2) ->
          (* Only the right-hand side is typed; it must be a tuple. *)
          (match expr e2 with
              Tuple(_, _) -> ()
            | _ -> raise (Failure ("Expected a tuple in " ^ string_of_expr e1 ^ " " ^ string_of_expr e2)))
      | Break -> ()
      | Continue -> ()
      | Nostmt -> ()
      | Declaration(_) -> ()
    in
      (* Check every local initializer in declaration order, then the body
         statements as one block. List.iter runs the checks for their effect
         without building a throwaway list of units. *)
      List.iter check_variable_assign func.fbody.f_vdecls;
      stmt (Block func.fbody.f_stmts)
  in

  (****    Check classes    ****)
  (* [check_class c] validates one class: no member variable may have
     NoneType, member names must be distinct, and every member function
     must pass the per-function check. Raises [Failure] on the first
     problem found.
     NOTE(review): the per-function check builds its symbol table from
     globals, formals and locals only, so class member variables do not
     appear to be visible inside member functions -- confirm. *)
  let check_class c =
    let members = c.cbody.vdecls in
    let where = " in " ^ c.cname in

    List.iter
      (check_not_none (fun n -> "Illegal NoneType class local " ^ n ^ where))
      members;

    report_duplicate (fun n -> "Duplicate class local " ^ n ^ where)
      (List.map (fun (_, n, _) -> n) members);

    List.iter check_function c.cbody.funcs
in
  (* Run the checks: every class, then every top-level function, and
     finally make sure that no two classes share a name. *)
  List.iter check_class classes;
  List.iter check_function functions;
  let class_names = List.map (fun cd -> cd.cname) classes in
  report_duplicate (fun n -> "Duplicate class " ^ n) class_names