(* Semantic Checker for the Tree++ Programming Language
   PLT Fall 2018
   Authors:
   Allison Costa
   Laura Matos
   Laura Smerling
   Jacob Penn
*)

open Ast
open Sast

module StringMap = Map.Make(String)

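(* Semantic checking of a program. If successful, returns the checked
   program as a pair (checked function declarations, checked top-level
   statements); throws an exception if something is wrong.
   Checks each function declaration, then type-checks every function body
   and every top-level statement *)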

let check_program program =

  (* Raise Failure (exceptf n) if [list] contains some element n more
     than once. *)
  let report_duplicate exceptf list =
    let rec helper = function
        n1 :: n2 :: _ when n1 = n2 -> raise (Failure (exceptf n1))
      | _ :: t -> helper t
      | [] -> ()
    in helper (List.sort compare list)
  in

  (* Top-level statements of the program, in source order. *)
  let stmt_list =
    List.fold_right (fun item acc -> match item with
        Ast.Stmt s -> s :: acc
      | _ -> acc) program []
  in

  (* User-defined functions of the program, in source order. *)
  let functions =
    List.fold_right (fun item acc -> match item with
        Ast.Function fd -> fd :: acc
      | _ -> acc) program []
  in

  (* Globals: the (type, name) bindings of variable declarations that are
     themselves top-level statements.  They are visible in every function
     and every top-level statement, wherever they appear in the program. *)
  let globals =
    List.fold_right (fun s acc -> match s with
        VarDec (binding, _) -> binding :: acc
      | _ -> acc) stmt_list []
  in

  (* Bindings declared by a statement sequence.  Nested Seq blocks are
     looked into because check_seq below flattens them into the enclosing
     sequence; declarations inside if/while/for bodies are not included. *)
  let rec declared_in = function
      [] -> []
    | VarDec (binding, _) :: rest -> binding :: declared_in rest
    | Seq inner :: rest -> declared_in inner @ declared_in rest
    | _ :: rest -> declared_in rest
  in

  (* Extend [scope] with [bindings]; a later binding for a name shadows an
     earlier one. *)
  let add_bindings scope bindings =
    List.fold_left (fun m (ty, name) -> StringMap.add name ty m) scope bindings
  in

  let symbols = add_bindings StringMap.empty globals in

  let type_of_identifier scope s =
    try StringMap.find s scope
    with Not_found -> raise (Failure ("undeclared identifier " ^ s))
  in

  (* Return [lvaluet] if an rvalue of type [rvaluet] may be assigned to it,
     otherwise raise [err].  Types are compared structurally (=), not by
     physical identity (==). *)
  let check_assign lvaluet rvaluet err =
    if lvaluet = rvaluet then lvaluet else raise err
  in

  (* Raise Failure (exceptf n) if the binding (typ, n) has type Void. *)
  let check_not_void exceptf = function
      (Void, n) -> raise (Failure (exceptf n))
    | _ -> ()
  in

  (* Built-in functions.  print and println are special-cased in expr
     (any number of arguments, of any type), so their formals are unused.
     print must be registered here, otherwise function_decl rejects every
     call to it before the special case is reached. *)
  let built_in_decls =
    List.fold_left (fun m fd -> StringMap.add fd.fname fd m) StringMap.empty
      [ { typ = Void; fname = "print"; formals = []; body = [] };
        { typ = Void; fname = "println"; formals = []; body = [] };
        { typ = Int; fname = "printbig"; formals = [(Int, "x")]; body = [] } ]
  in

  let function_decls =
    List.fold_left (fun m fd -> StringMap.add fd.fname fd m)
      built_in_decls functions
  in

  let function_decl s =
    try StringMap.find s function_decls
    with Not_found -> raise (Failure ("unrecognized function " ^ s))
  in

  let function_names = List.map (fun fd -> fd.fname) functions in

  (* Program-wide restrictions on user function names: library names and
     main are reserved, and no name may be defined twice. *)
  let check_function_names () =
    List.iter (fun reserved ->
        if List.mem reserved function_names then
          raise (Failure ("function " ^ reserved ^ " may not be defined")))
      ["print"; "println"; "printf"];
    report_duplicate (fun n -> "duplicate function " ^ n) function_names;
    if List.mem "main" function_names then
      raise (Failure ("function main may not be defined"))
  in

  (* Check a function's signature.  check_function_names is idempotent; it
     runs here (between the two formal checks) so the error reported first
     is the same as when these checks were written out inline. *)
  let check_function func =
    report_duplicate (fun n -> "duplicate formal " ^ n ^ " in " ^ func.fname)
      (List.map snd func.formals);
    check_function_names ();
    List.iter (check_not_void (fun n -> "illegal null formal " ^ n ^
      " in " ^ func.fname)) func.formals
  in

  (* Type-check [e] against the variables visible in [scope]; return its
     type paired with its checked form. *)
  let rec expr scope = function
      Literal l -> (Int, SLiteral l)
    | Fliteral l -> (Float, SFliteral l)
    | BoolLit l -> (Bool, SBoolLit l)
    | Sliteral l -> (String, SSliteral l)
    | Id s -> (type_of_identifier scope s, SId s)
    | Assign (var, e) as ex ->
        let lt = type_of_identifier scope var
        and (rt, e') = expr scope e in
        let err = Failure ("illegal assignement " ^ string_of_typ lt ^ " = " ^
                           string_of_typ rt ^ " in " ^ string_of_expr ex)
        in (check_assign lt rt err, SAssign (var, (rt, e')))
    | Binop (e1, op, e2) as e ->
        let (t1, e1') = expr scope e1 and (t2, e2') = expr scope e2 in
        let ty = match op with
            Add | Sub | Mult | Div | Mod when t1 = Int && t2 = Int -> Int
          | Equal | Neq when t1 = t2 -> Bool
          | Less | Leq | Greater | Geq when t1 = Int && t2 = Int -> Bool
          | And | Or when t1 = Bool && t2 = Bool -> Bool
          | _ -> raise (Failure ("illegal binary operator " ^
                   string_of_typ t1 ^ " " ^ string_of_op op ^ " " ^
                   string_of_typ t2 ^ " in " ^ string_of_expr e))
        in (ty, SBinop ((t1, e1'), op, (t2, e2')))
    | FunCall (fname, actuals) as call ->
        let fd = function_decl fname in
        if fname = "print" || fname = "println" then
          (* Arguments are type-checked but neither counted nor matched
             against formals. *)
          (fd.typ, SFunCall (fname, List.map (expr scope) actuals))
        else if List.length actuals <> List.length fd.formals then
          raise (Failure ("expecting " ^ string_of_int
            (List.length fd.formals) ^ " arguments in " ^ string_of_expr call))
        else
          (* Check each actual against its formal, left to right, type-
             checking every argument exactly once. *)
          let check_actual (ft, _) e =
            let (t, e') = expr scope e in
            ignore (check_assign ft t
              (Failure ("illegal actual argument: found " ^ string_of_typ t ^
                " ; expected " ^ string_of_typ ft ^ " in " ^ string_of_expr e)));
            (t, e')
          in
          (fd.typ, SFunCall (fname, List.map2 check_actual fd.formals actuals))
    | Unop (op, e) as ex ->
        let (t, e') = expr scope e in
        (match op with
           Neg when t = Int -> (Int, SUnop (op, (t, e')))
         | Not when t = Bool -> (Bool, SUnop (op, (t, e')))
         | _ -> raise (Failure ("illegal unary operator " ^ string_of_uop op ^
                                string_of_typ t ^ " in " ^ string_of_expr ex)))
    | Noexpr -> (Void, SNoexpr)
  in

  (* Type-check [e] once and require it to be Bool. *)
  let check_bool_expr scope e =
    let (t, e') = expr scope e in
    if t <> Bool then
      raise (Failure ("expected Boolean expression in " ^ string_of_expr e))
    else (t, e')
  in

  (* Check a statement against the variables visible in [scope].  A Seq
     block adds the variables it declares to the scope of its statements.
     NOTE(review): return expressions are not yet checked against the
     enclosing function's declared return type. *)
  let rec check_stmt scope = function
      Expr e -> SExpr (expr scope e)
    | VarDec ((t, s), e) ->
        let (rt, e') = expr scope e in
        (* A declaration without an initializer (Noexpr) is accepted as-is;
           otherwise the initializer follows the same rule as assignment. *)
        (match e with
           Noexpr -> ()
         | _ ->
             let err = Failure ("illegal assignement " ^ string_of_typ t ^
               " = " ^ string_of_typ rt ^ " in " ^
               string_of_expr (Assign (s, e)))
             in ignore (check_assign t rt err));
        SVarDec ((t, s), (rt, e'))
    | If (p, b1, b2) ->
        SIf (check_bool_expr scope p, check_stmt scope b1, check_stmt scope b2)
    | For (e1, e2, e3, body) ->
        SFor (expr scope e1, expr scope e2, expr scope e3, check_stmt scope body)
    | While (p, body) -> SWhile (check_bool_expr scope p, check_stmt scope body)
    | Return e -> SReturn (expr scope e)
    | Seq sl ->
        let inner = add_bindings scope (declared_in sl) in
        let rec check_seq = function
            [Return _ as x] -> [check_stmt inner x]
          | Return _ :: _ -> raise (Failure "nothing may follow a return")
          | Seq sl :: ss -> check_seq (sl @ ss)
          | s :: ss -> check_stmt inner s :: check_seq ss
          | [] -> []
        in SSeq (check_seq sl)
  in

  (* Check a function body.  Its scope is the globals, shadowed by its
     formals, shadowed in turn by the variables its body declares. *)
  let check_function_body func =
    let scope =
      add_bindings (add_bindings symbols func.formals) (declared_in func.body)
    in
    { styp = func.typ;
      sfname = func.fname;
      sformals = func.formals;
      sbody = List.map (check_stmt scope) func.body }
  in

  List.iter check_function functions;
  (List.map (fun fd -> SFunction (check_function_body fd)) functions,
   List.map (fun s -> SStmt (check_stmt symbols s)) stmt_list)

  (* ignore(List.iter check_function functions);
  ignore(List.map check_stmt stmt_list);*)
  (* List.iter stmt stmt_list; *)
 (*report_duplicate (fun n -> "Duplicate declaration or assignment for " ^ n) (List.map snd globals);*)
