(*
    Semantic checking for Casper
    Based on MicroC
    File: semant.ml
    Michael Makris, mm3443
    PLT Fall 2018
*)

open Ast
open Sast

module StringMap = Map.Make(String)

(* Semantic checking of the AST. Returns an SAST if successful,
   raises Failure with a descriptive message if something is wrong.

   The global variables are checked first, then each function in turn. *)

let check (globals, functions) =

  (* Validate a list of (type, name) bindings.
     kind is only used to label error messages ("global", "formal", ...).
     Raises Failure for the first binding of type Void, and otherwise for
     the first name that occurs more than once. *)
  let check_binds (kind : string) (binds : casperVariable list) =
    let reject_void = function
      | (Void, name) -> failwith ("illegal void " ^ kind ^ " " ^ name)
      | _ -> ()
    in
    List.iter reject_void binds;
    let by_name (_, a) (_, b) = compare a b in
    (* After sorting by name, any duplicates sit next to each other *)
    let rec find_dup = function
      | (_, n1) :: (((_, n2) :: _) as rest) ->
          if n1 = n2 then failwith ("duplicate " ^ kind ^ " " ^ n1)
          else find_dup rest
      | [_] | [] -> ()
    in
    find_dup (List.sort by_name binds)
  in

  (**** Check global variables ****)

  (* Globals must not be void and must have pairwise distinct names;
     check_binds raises Failure otherwise. *)
  check_binds "global" globals;

  (**** Check functions ****)

  (* Collect function declarations for built-in functions: no bodies.
     The signatures are tabulated by arity below; every entry becomes a
     declaration with empty locals and statements.  All formals are named
     "x": only the number and types of the formals are consulted when calls
     are checked, never their names. *)
  let built_in_decls =
    (* Add one body-less declaration given its name, the types of its
       formals in order, and its return type *)
    let declare map (name, formal_types, fty) = StringMap.add name {
      functionType = fty;
      functionName = name;
      functionFormals = List.map (fun ty -> (ty, "x")) formal_types;
      functionLocals = []; functionStatements = [] } map
    in
    (* Add a whole table; to_sig turns an entry of any arity into the
       (name, formal types, return type) shape expected by declare *)
    let declare_all to_sig entries map =
      List.fold_left (fun m entry -> declare m (to_sig entry)) map entries
    in
    (* (name, argument type, return type) *)
    let unary = [ ("printi", Int, Void);
      ("printinl", Int, Void); ("printfnl", Float, Void); ("printf", Float, Void);
      ("prints", String, Void); ("printsnl", String, Void); ("printc", Char, Void);
      ("printcnl", Char, Void); ("printbig", Int, Void);
      ("printbasi", Bool, Void); ("printb", Bool, Void); ("printbnl", Bool, Void);
      ("sin", Float, Float); ("cos", Float, Float); ("tan", Float, Float);
      ("asin", Float, Float); ("acos", Float, Float); ("atan", Float, Float);
      ("sinh", Float, Float); ("cosh", Float, Float); ("tanh", Float, Float);
      ("exp", Float, Float); ("log", Float, Float); ("log10", Float, Float);
      ("sqrt", Float, Float); ("floor", Float, Float); ("ceil", Float, Float);
      ("abs", Float, Float); ("srand", Int, Void); ("strlen", String, Int);
      ("strerror", Int, String); ("isalnum", Char, Int); ("isalpha", Char, Int);
      ("iscntrl", Char, Int); ("isdigit", Char, Int); ("isgraph", Char, Int);
      ("islower", Char, Int); ("isprint", Char, Int); ("ispunct", Char, Int);
      ("isspace", Char, Int); ("isupper", Char, Int); ("isxdigit", Char, Int); ]
    (* (name, return type) *)
    and nullary = [ ("rand", Int) ]
    (* (name, first argument type, second argument type, return type) *)
    and binary = [ ("fmod", Float, Float, Float);
      ("strcpy", String, String, String); ("strcat", String, String, String);
      ("strcmp", String, String, Int); ("strchr", String, String, String);
      ("strrchr", String, String, String); ("strspn", String, String, Int);
      ("strcspn", String, String, Int); ("strpbrk", String, String, String);
      ("strstr", String, String, String); ("strtok", String, String, String) ]
    (* (name, three argument types, return type) *)
    and ternary = [
      ("strncpy", String, String, Int, String);
      ("strncat", String, String, Int, String);
      ("strncmp", String, String, Int, Int) ]
    in
    (* Tables are added in the same order as before: unary, nullary,
       binary, ternary *)
    let decls =
      declare_all (fun (n, ty, fty) -> (n, [ty], fty)) unary StringMap.empty in
    let decls =
      declare_all (fun (n, fty) -> (n, [], fty)) nullary decls in
    let decls =
      declare_all (fun (n, ty1, ty2, fty) -> (n, [ty1; ty2], fty)) binary decls in
    declare_all
      (fun (n, ty1, ty2, ty3, fty) -> (n, [ty1; ty2; ty3], fty)) ternary decls
  in

  (* Add a user-defined function to the symbol table map, keyed by its name.
     Raises Failure when the name belongs to a built-in, or when it is
     already present in map; the built-in check takes precedence. *)
  let add_func map fd =
    let name = fd.functionName in
    if StringMap.mem name built_in_decls then
      failwith ("function " ^ name ^ " may not be defined")
    else if StringMap.mem name map then
      failwith ("duplicate function " ^ name)
    else
      StringMap.add name fd map
  in

  (* Collect all function names into one symbol table: start from the
     built-ins and add every user-defined function through add_func, which
     raises Failure on a duplicate name or on a clash with a built-in *)
  let function_decls = List.fold_left add_func built_in_decls functions
  in

  (* Look up the declaration of function s (built-in or user-defined).
     Raises Failure "unrecognized function <s>" when s is not declared. *)
  let find_func s =
    if StringMap.mem s function_decls then StringMap.find s function_decls
    else failwith ("unrecognized function " ^ s)
  in

  (* Ensure "main" is defined: find_func raises Failure when it is missing.
     Only the lookup matters here; the declaration it returns is discarded. *)
  let _ = find_func "main" in

  let check_function func =
    (* Make sure no formals or locals are void or duplicates.
       Formals and locals are validated as two separate lists, so this step
       does not reject a local that reuses the name of a formal. *)
    check_binds "formal" func.functionFormals;
    check_binds "local" func.functionLocals;

    (* Type-compatibility check used for call arguments and return values:
       yields the lvalue type when it equals the rvalue type, and raises
       Failure carrying err otherwise *)
    let check_assign lvaluet rvaluet err =
      match lvaluet = rvaluet with
      | true -> lvaluet
      | false -> failwith err
    in

    (* Symbol table (name -> type) for this function's scope.  Bindings are
       added in the order globals, formals, locals, and a later binding
       replaces an earlier one of the same name: locals shadow formals,
       which shadow globals. *)
    let symbols =
      let bind table (ty, name) = StringMap.add name ty table in
      let in_scope = globals @ func.functionFormals @ func.functionLocals in
      List.fold_left bind StringMap.empty in_scope
    in

    (* Look up the declared type of identifier s in this function's scope.
       Raises Failure "undeclared identifier <s>" when s is not bound. *)
    let type_of_identifier s =
      if StringMap.mem s symbols then StringMap.find s symbols
      else failwith ("undeclared identifier " ^ s)
    in

    (* validate Arrays  
    let validateArray lst el =
        let rec helper typ tlist = function
            [] -> (typ, tlist)
          | hd :: _ when typ <> fst (el hd) -> raise (Failure ("Type inconsistency with array"))
          | hd :: tl -> helper typ (el hd :: tlist) tl
        in
      helper (fst (el (List.hd lst))) [] lst 
    in*)
    
    (* Return a semantically-checked expression, i.e., with a type.
       Maps an AST expression to a pair (type, typed SAST expression).
       Raises Failure on an undeclared identifier, an unrecognized function,
       a wrong number of call arguments, an argument of the wrong type, or
       an operator applied to operand types it does not accept. *)
    let rec casperExpression = function
        Epsilon       -> (Void, SEpsilon)
      | IntLIT i      -> (Int, SIntLIT i)
      | FltLIT f      -> (Float, SFltLIT f)
      | StrLIT s      -> (String, SStrLIT s)
      | ChrLIT c      -> (Char, SChrLIT c)
      | BoolLIT b     -> (Bool, SBoolLIT b)
      | VoidLIT       -> (Void, SVoidLIT)
      (* The null literal is given type Void *)
      | NullLIT       -> (Void, SNullLIT)
      | Identifier s  -> (type_of_identifier s, SIdentifier s)
      | BinOP(e1, op, e2) as e ->
          let (t1, e1') = casperExpression e1
          and (t2, e2') = casperExpression e2 in
          (* Most binary operators require operands of the same type; the
             exceptions are Con (String with Int, Float or Bool) and Cat
             (String with Int) *)
          let same = t1 = t2 in
          (* Determine expression type based on operator and operand types *)
          let ty = match op with
            Add | Sub | Mul | Div | Mod when same && t1 = Int -> Int
          (* Exp is accepted for Float operands only *)
          | Add | Sub | Mul | Div | Exp | Mod when same && t1 = Float -> Float
          | Con when same && t1 = String -> String
          | Con when t1 = String && t2 = Int -> String
          | Con when t1 = String && t2 = Float -> String
          | Con when t1 = String && t2 = Bool -> String
          | Cat when t1 = String && t2 = Int -> String
          (* Equality is allowed on any type as long as both sides agree *)
          | Eql | Neq when same -> Bool
          | Grt | Gre | Lst | Lse when same && (t1 = Int || t1 = Float || t1 = String) -> Bool
          | And | Or when same && t1 = Bool -> Bool
          | _ -> raise (Failure ("illegal binary operator " ^ string_of_casperType t1 ^ " " ^ string_of_casperBinaryOperator op ^ " " ^ string_of_casperType t2 ^ " in " ^ string_of_casperExpression e))
          in (ty, SBinOP((t1, e1'), op, (t2, e2')))
      | UnrOP(op, e) as ex ->
          let (t, e') = casperExpression e in
          let ty = match op with
            Neg when t = Int || t = Float -> t
          | Not when t = Bool -> Bool
          | ItoF when t = Int -> Float
          (* NOTE(review): ItoF on a Float operand yields Int, i.e. the same
             operator appears to serve as the float-to-int conversion;
             confirm against the parser and code generator *)
          | ItoF when t = Float -> Int
          | _ -> raise (Failure ("illegal unary operator " ^ string_of_casperUnaryOperator op ^ string_of_casperType t ^ " in " ^ string_of_casperExpression ex))
          in (ty, SUnrOP(op, (t, e')))
      | AsgnOP(var, op, e) as ex ->
          let lt = type_of_identifier var
          and (rt, e') = casperExpression e in
          let same = lt = rt in
          (* ConAsgn mirrors Con: a String variable accepts a String, Int,
             Float or Bool right-hand side; the other assignment operators
             need matching types *)
          let ty = match op with
            ConAsgn when same && lt = String -> String
          | ConAsgn when lt = String && rt = Int -> String
          | ConAsgn when lt = String && rt = Float -> String
          | ConAsgn when lt = String && rt = Bool -> String
          | Asgn when same -> lt
          | AddAsgn when same && (lt = Int || lt = Float) -> lt
          | SubAsgn when same && (lt = Int || lt = Float) -> lt
          | _ -> raise (Failure ("illegal assignment " ^ string_of_casperType lt ^ string_of_casperAssignment op ^ string_of_casperType rt ^ " in " ^ string_of_casperExpression ex))
          in (ty, SAsgnOP(var, op, (rt, e')))
      | FunctionCall(fname, args) as call ->
          let fd = find_func fname in
          let param_length = List.length fd.functionFormals in
          (* The argument count is checked first, so List.map2 below always
             receives two lists of equal length.  The counts are compared
             with structural <> rather than physical != *)
          if List.length args <> param_length then
            raise (Failure ("expecting " ^ string_of_int param_length ^ " arguments in " ^ string_of_casperExpression call))
          else let check_call (ft, _) e =
            (* Each argument must have exactly the type of its formal *)
            let (et, e') = casperExpression e in
            let err = "illegal argument found " ^ string_of_casperType et ^ " expected " ^ string_of_casperType ft ^ " in " ^ string_of_casperExpression e
            in (check_assign ft et err, e')
          in
          let args' = List.map2 check_call fd.functionFormals args
          in (fd.functionType, SFunctionCall(fname, args'))
    in

    (* Check an expression that must be Boolean (a condition of an if or a
       loop).  Returns the typed expression, or raises Failure when its type
       is not Bool.  The type is compared with structural <>, consistent
       with the = comparisons on types used elsewhere in this checker,
       instead of physical inequality !=. *)
    let check_bool_expr e =
      let (t', e') = casperExpression e
      and err = "expected Boolean expression in " ^ string_of_casperExpression e
      in if t' <> Bool then raise (Failure err) else (t', e')
    in

    (* Return a semantically-checked statement i.e. containing sexprs.
       Conditions of if, for, while, do-until and do-while must be Boolean
       (check_bool_expr); a returned expression must have exactly the
       function's declared type.  Raises Failure on any violation.
       Break and Continue are accepted as-is: this function does not check
       that they appear inside a loop. *)
    let rec check_stmt = function
        Expression e -> SExpression (casperExpression e)
      | IfCondition(p, b1, b2) -> SIfCondition(check_bool_expr p, check_stmt b1, check_stmt b2)
      (* Only the middle expression of a for loop is required to be Boolean *)
      | ForLoop(e1, e2, e3, st) -> SForLoop(casperExpression e1, check_bool_expr e2, casperExpression e3, check_stmt st)
      | WhileLoop(p, s) -> SWhileLoop(check_bool_expr p, check_stmt s)
      | DoUntilLoop(s, p) -> SDoUntilLoop(check_stmt s, check_bool_expr p)
      | DoWhileLoop(s, p) -> SDoWhileLoop(check_stmt s, check_bool_expr p)
      | Break -> SBreak
      | Continue -> SContinue
      (* The returned expression's type must equal func.functionType *)
      | Return e -> let lt = func.functionType and (rt, e') = casperExpression e in
            let err = "return gives " ^ string_of_casperType rt ^ " expected " ^ string_of_casperType func.functionType ^ " in " ^ string_of_casperExpression e
            in (SReturn (check_assign lt rt err, e') )
      (*
        if t = func.functionType then SReturn (t, e')
        else raise (Failure ("return gives " ^ string_of_casperType t ^ " expected " ^ string_of_casperType func.functionType ^ " in " ^ string_of_casperExpression e))
      *)
        (* A block is correct if each statement is correct and nothing
        follows any Break, Continue or Return statement.  Nested blocks are
        flattened into the enclosing statement list. *)
      | StatementBlock sl ->
          let rec check_stmt_list = function
            (* Break, Continue and Return are only legal as the last
               statement of the flattened list; the singleton patterns must
               come before the corresponding "followed by anything" ones *)
              [Break as s] -> [check_stmt s]
            | Break :: _   -> raise (Failure "nothing may follow a break")
            | [Continue as s] -> [check_stmt s]
            | Continue :: _   -> raise (Failure "nothing may follow a continue")
            | [Return _ as s] -> [check_stmt s]
            | Return _ :: _   -> raise (Failure "nothing may follow a return")
            | StatementBlock sl :: ss  -> check_stmt_list (sl @ ss) (* Flatten blocks *)
            | s :: ss         -> check_stmt s :: check_stmt_list ss
            | []              -> []
          in SStatementBlock(check_stmt_list sl)

    in (* body of check_function *)
    (* The whole body is checked as one statement block, so the
       "nothing may follow ..." rules and block flattening apply to it *)
    let checked_body =
      match check_stmt (StatementBlock func.functionStatements) with
      | SStatementBlock stmts -> stmts
      | _ -> failwith "internal error: block didn't become a block?"
    in
    (* Type, name, formals and locals are carried over unchanged; only the
       statements are replaced by their checked form *)
    { sFunctionType = func.functionType;
      sFunctionName = func.functionName;
      sFunctionFormals = func.functionFormals;
      sFunctionLocals = func.functionLocals;
      sFunctionStatements = checked_body;
    }
  in
  (* Result of check: the globals unchanged, paired with the checked form
     of every user-defined function (built-ins are not included) *)
  (globals, List.map check_function functions)
