(* Semantic checking for the nodable compiler *)

(* Tasks:
   Takes an abstract syntax tree (Ast) and returns a semantically checked AST
   (Sast), performing type checking and other semantic checks.
   Goal: pair each expression with its type in a semantically checked
   expression, raising Failure on any type error or other semantic violation.
*)

open Ast 
open Sast

module StringMap = Map.Make(String)

let check (globals, functions) =

  (* Verify a list of bindings: no binding may have type Void and no name
     may appear twice. [kind] ("global", "formal", ...) only labels the
     error message. Void bindings are rejected before duplicates are
     looked for; the duplicate reported is the first one in sorted name
     order. Raises Failure on the first violation found. *)
  let check_binds (kind : string) (binds : bind list) =
    let reject_void = function
      | (Void, name) -> raise (Failure ("illegal void " ^ kind ^ " " ^ name))
      | _ -> ()
    in
    List.iter reject_void binds;
    let by_name = List.sort (fun (_, a) (_, b) -> compare a b) binds in
    (* After sorting, equal names are adjacent. *)
    let rec find_dup = function
      | (_, a) :: ((_, b) :: _ as rest) ->
          if a = b then raise (Failure ("duplicate " ^ kind ^ " " ^ a))
          else find_dup rest
      | [] | [_] -> ()
    in
    find_dup by_name
  in

  check_binds "global" globals;

  (* Built-in functions, written as (name, formals, return type). Each one
     becomes a declaration with an empty body, so calls to built-ins are
     looked up and type-checked exactly like calls to user functions. *)
  let built_in_decls =
    let signatures = [
      ("print",       [(Int, "x")],                                Void);
      ("prints",      [(String, "x")],                             Void);
      ("printf",      [(Float, "x")],                              Void);
      ("printb",      [(Bool, "x")],                               Void);
      ("size",        [(List (Node Int), "x")],                    Int);
      ("append",      [(List (Node Int), "x"); (Node Int, "y")],   Void);
      ("update_elem", [(Node Int, "x"); (List (Node Int), "y"); (Int, "z")],
                      List (Node Int));
      ("add_left",    [(Node Int, "x"); (Node Int, "y")],          Void);
      ("add_right",   [(Node Int, "x"); (Node Int, "y")],          Void);
      ("get_left",    [(Node Int, "x")],                           Node Int);
      ("get_right",   [(Node Int, "x")],                           Node Int);
    ] in
    List.fold_left
      (fun table (name, params, ret) ->
         StringMap.add name
           { typ = ret; fname = name; formals = params; body = [] } table)
      StringMap.empty signatures
  in

  (* Add the user-defined function declaration [fd] to the function table
     [map]. The name may not reuse a built-in's name, nor the name of a
     function already in [map]; either case raises Failure. The built-in
     check is made first. *)
  let add_func map fd =
    let n = fd.fname in
    if StringMap.mem n built_in_decls then
      raise (Failure ("function " ^ n ^ " may not be defined"))
    else if StringMap.mem n map then
      raise (Failure ("duplicate function " ^ n))
    else
      StringMap.add n fd map
  in

(* Function table: the built-ins plus every user-defined function. *)
let function_decls = List.fold_left add_func built_in_decls functions
in

(* Look up a function (built-in or user-defined) by name; raises Failure
   "unrecognized function <name>" when there is none. *)
let find_func s =
  match StringMap.find s function_decls with
  | fd -> fd
  | exception Not_found -> raise (Failure ("unrecognized function " ^ s))
in

(* Every program must define "main"; the lookup raises otherwise. *)
let () = ignore (find_func "main") in

let check_function func =
  (* Formal parameters may be neither void nor duplicated. Locals are not a
     field of the function declaration; they are collected from the Declare
     statements of the body further below and are not checked here. *)
  check_binds "formal" func.formals;
  
    (* Return the type produced by assigning an rvalue of type [rvaluet] to
       an lvalue of type [lvaluet], raising Failure [err] when the types are
       incompatible. Two relaxations apply: a List Any rvalue (an empty list
       literal) fits any list type and yields the lvalue type, and any node
       fits any node, yielding the rvalue's node type. Otherwise the types
       must be equal. *)
    let check_assign lvaluet rvaluet err =
      match lvaluet, rvaluet with
      | List elem, List Any -> List elem
      | Node _, (Node _ as rnode) -> rnode
      | _ when lvaluet = rvaluet -> lvaluet
      | _ -> raise (Failure err)
    in

    (* Collect the (type, name) pair of every variable declared in a list of
       statements. This includes declarations nested inside blocks, if/else
       branches and loop bodies, so variables declared there can be found by
       type_of_identifier. All locals share one function-wide scope. Pairs
       are consed onto [locals], so the result is in reverse source order. *)
    let rec declarations locals = function
        [] -> locals
      | hd :: tl -> declarations (declared_in locals hd) tl
    (* Add the declarations made by one statement and its sub-statements. *)
    and declared_in locals = function
        Declare (t, id, _) -> (t, id) :: locals
      | Block sl -> declarations locals sl
      | If (_, s1, s2) -> declared_in (declared_in locals s1) s2
      | For (_, _, _, s) | While (_, s) -> declared_in locals s
      | Expr _ | Return _ -> locals
    in

    (* Symbol table for this function: globals, then formals, then locals
       declared in the body. StringMap.add replaces an existing binding, so
       a later entry in this list shadows an earlier one with the same name. *)
    let symbols =
      globals @ func.formals @ declarations [] func.body
      |> List.fold_left (fun tbl (ty, name) -> StringMap.add name ty tbl)
           StringMap.empty
    in

  (* Type of a variable in scope; raises Failure if it was never declared. *)
  let type_of_identifier s =
    match StringMap.find s symbols with
    | ty -> ty
    | exception Not_found -> raise (Failure ("undeclared identifier " ^ s))
  in

  (* Type-check the elements of a list literal with [expr]. Every element
     must have the same type as the first; any mismatch raises Failure.
     Returns (element type, checked elements); the element type of an empty
     literal is Any.
     NOTE(review): the checked elements come back in reverse source order
     (accumulated by consing); confirm the code generator relies on this
     before changing it. *)
  let expr_list lst expr =
    let elem_typ = match lst with
      | [] -> Any
      | first :: _ -> fst (expr first)
    in
    let check_elem acc e =
      let (t, _) as checked = expr e in
      if t <> elem_typ then raise (Failure ("Type inconsistency with list "))
      else checked :: acc
    in
    (elem_typ, List.fold_left check_elem [] lst)
  in

  (* Type-check an expression, returning (type, sx) where sx is the
     semantically checked expression. Raises Failure with a descriptive
     message on any type error. Type comparisons use structural (<>)
     rather than physical (!=) inequality. *)
  let rec expr = function
    | Lit_Int l -> (Int, SLiteral l)
    | Lit_Flt l -> (Float, SFliteral l)
    | Lit_Bool l -> (Bool, SBoolLit l)
    | Lit_Str s -> (String, SStrLit s)
    | Lit_List l ->
        (* The element type comes from the elements; [] has type List Any. *)
        let (t, l) = expr_list l expr in
        (List t, SListLit l)
    | Lit_Node n ->
        let (t, d) = expr n in
        (Node t, SNodeLit (t, d))
    | List_Access (l, e) ->
        let (tl, _) as l' = expr l in
        let (te, _) as e' = expr e in
        if te <> Int then raise (Failure ("list index must be an integer"))
        else (match tl with
            | List x -> (x, SListAccess (l', e'))
            | _ -> raise (Failure ("not iterable")))
    | Noexpr -> (Void, SNoexpr)
    | Id s -> (type_of_identifier s, SId s)
    | Attr (e, p) ->
        let (et, _) as e' = expr e in
        (* NOTE(review): "left" and "right" are typed as the node's data
           type t, whereas the get_left/get_right built-ins return a node.
           Confirm against the code generator whether Node t was intended. *)
        let pt = (match (et, p) with
            | (Node t, "data") -> t
            | (Node t, "left") -> t
            | (Node t, "right") -> t
            | (_, _) -> raise (Failure ("no such property"))) in
        (pt, SAttr (e', p))
    | Assign (var, e) as ex ->
        let lt = type_of_identifier var
        and (rt, e') = expr e in
        let err = "illegal assignment " ^ string_of_typ lt ^ " = " ^
                  string_of_typ rt ^ " in " ^ string_of_expr ex
        in (check_assign lt rt err, SAssign (var, (rt, e')))
    | Unop (op, e) as ex ->
        let (t, e') = expr e in
        let ty = match op with
            Neg when t = Int || t = Float -> t
          | Not when t = Bool -> Bool
          | _ -> raise (Failure ("illegal unary operator " ^
                                 string_of_uop op ^ string_of_typ t ^
                                 " in " ^ string_of_expr ex))
        in (ty, SUnop (op, (t, e')))
    | Binop (e1, op, e2) as e ->
        let (t1, e1') = expr e1
        and (t2, e2') = expr e2 in
        (* All binary operators require operands of the same type *)
        let same = t1 = t2 in
        (* Determine expression type based on operator and operand types *)
        let ty = match op with
            Add | Sub | Mult | Div | Mod when same && t1 = Int   -> Int
          | Add | Sub | Mult | Div | Mod when same && t1 = Float -> Float
          | Eq | Neq when same -> Bool
          | Less | Leq | Greater | Geq
            when same && (t1 = Int || t1 = Float) -> Bool
          | And | Or when same && t1 = Bool -> Bool
          | _ -> raise (
              Failure ("illegal binary operator " ^
                       string_of_typ t1 ^ " " ^ string_of_op op ^ " " ^
                       string_of_typ t2 ^ " in " ^ string_of_expr e))
        in (ty, SBinop ((t1, e1'), op, (t2, e2')))
    | Call (fname, args) as call ->
        let fd = find_func fname in
        let param_length = List.length fd.formals in
        if List.length args <> param_length then
          raise (Failure ("expecting " ^ string_of_int param_length ^
                          " arguments in " ^ string_of_expr call))
        else
          (* Each argument must be assignable to its formal's type. *)
          let check_call (ft, _) e =
            let (et, e') = expr e in
            let err = "illegal argument found " ^ string_of_typ et ^
                      " expected " ^ string_of_typ ft ^ " in " ^
                      string_of_expr e
            in (check_assign ft et err, e')
          in
          let args' = List.map2 check_call fd.formals args
          in (fd.typ, SCall (fname, args'))
  in

    (* Type-check [e] and require it to have type Bool (used for if/loop
       conditions). Returns the checked (type, sx) pair; raises Failure
       otherwise. Uses structural <> rather than physical != on types. *)
    let check_bool_expr e =
      let (t', e') = expr e
      and err = "expected Boolean expression in " ^ string_of_expr e
      in if t' <> Bool then raise (Failure err) else (t', e')
    in
    
    (* Return a semantically-checked statement, i.e. one containing sexprs.
       Raises Failure on any semantic error. *)
    let rec check_stmt = function
        Expr e -> SExpr (expr e)
      | If (p, b1, b2) -> SIf (check_bool_expr p, check_stmt b1, check_stmt b2)
      | For (e1, e2, e3, st) ->
          SFor (expr e1, check_bool_expr e2, expr e3, check_stmt st)
      | While (p, s) -> SWhile (check_bool_expr p, check_stmt s)
      | Declare (t, id, Noexpr) ->
          (* No initializer: there is nothing to type-check. *)
          SDeclare (t, id, expr Noexpr)
      | Declare (t, id, asn) ->
          (* The initializer must be assignable to the declared type, under
             the same rules as an assignment expression. *)
          let (rt, e') = expr asn in
          let err = "illegal initializer " ^ string_of_typ rt ^ " for " ^
                    string_of_typ t ^ " " ^ id ^ " in " ^ string_of_expr asn
          in
          ignore (check_assign t rt err);
          SDeclare (t, id, (rt, e'))
      | Return e ->
          let (t, e') = expr e in
          if t = func.typ then SReturn (t, e')
          else raise
              (Failure ("return gives " ^ string_of_typ t ^ " expected " ^
                        string_of_typ func.typ ^ " in " ^ string_of_expr e))
      (* A block is correct if each statement is correct and nothing
         follows any Return statement. Nested blocks are flattened. *)
      | Block sl ->
          let rec check_stmt_list = function
              [Return _ as s] -> [check_stmt s]
            | Return _ :: _ -> raise (Failure "nothing may follow return")
            | Block sl :: ss -> check_stmt_list (sl @ ss)
            | s :: ss -> check_stmt s :: check_stmt_list ss
            | [] -> []
          in SBlock (check_stmt_list sl)
    in
    (* The checked function: the body is checked as one block and
       flattened back into a statement list. *)
    { styp = func.typ;
      sfname = func.fname;
      sformals = func.formals;
      sbody = (match check_stmt (Block func.body) with
          SBlock sl -> sl
        | _ -> raise (Failure ("internal error: block didn't become a block?")))
    }
  in (globals, List.map check_function functions)





