(* Semantic checking for the MicroC compiler *)

open Ast
open Sast

module StringMap = Map.Make(String)

(* A lexical scope used while checking a function. Scopes form a chain
   through [parent]; name lookup walks that chain outwards. *)
type environment = {
  vars: svdecl list; (* variables declared directly in this scope *)
  parent: environment option; (* enclosing scope; None for the outermost scope *)
  in_loop: bool; (* set for scopes opened by for/while; gates break and continue *)
}

(* Built-in functions. Each one returns void, takes exactly one argument and
   has an empty body, so the list is generated from a table of
   (function name, argument type, argument name) triples. *)
let builtins =
  let single_arg_builtin (name, arg_type, arg_name) =
    let arg =
      { v_type = arg_type; v_name = arg_name; v_val = Noexpr; node_val = false }
    in
    { fname = name; rtype = Void; args_list = [ arg ]; body = [] }
  in
  List.map single_arg_builtin
    [ ("print", String, "str");
      ("print_int", Int, "int");
      ("print_float", Float, "float") ]

(* Symbol table of the built-in functions, keyed by function name. *)
let built_in_decls =
  List.fold_left
    (fun table decl -> StringMap.add decl.fname decl table)
    StringMap.empty builtins

(* True exactly when the given type is a [Node] type. *)
let is_node = function
  | Node _ -> true
  | _ -> false

(* [check_types s (t1, t2) e] returns the expected type [t1] when the received
   type [t2] is acceptable: either [t2] is [Void], which is accepted for any
   expected type, or [t2] equals [t1]. Otherwise it raises [Failure] with a
   message naming the context [s], both types and the expression [e]. *)
let check_types s (t1, t2) e =
  if t2 = Void || t1 = t2 then t1
  else
    raise
      (Failure
         (Printf.sprintf
            "type error in %s: expected type %s but received type %s in expr %s"
            s (string_of_vtype t1) (string_of_vtype t2) (string_of_sexpr e)))

(* Reject a list of declarations that contains a void-typed entry or two
   entries with the same name. [kind] only labels the error message.
   Void entries are reported first, in list order; duplicates are then found
   by scanning the sorted names for equal neighbours. *)
let check_binds (kind : string) (types : vdecl list) =
  List.iter
    (fun (decl : vdecl) ->
      if decl.v_type = Void then
        raise (Failure ("illegal void " ^ kind ^ " " ^ decl.v_name)))
    types;
  let sorted_names =
    List.sort compare (List.map (fun (decl : vdecl) -> decl.v_name) types)
  in
  let rec report_dup = function
    | first :: (second :: _ as rest) ->
        if first = second then
          raise (Failure ("duplicate " ^ kind ^ " " ^ first))
        else report_dup rest
    | [ _ ] | [] -> ()
  in
  report_dup sorted_names

(* Insert a function declaration into the symbol table [map], keyed by its
   name. Raises [Failure] when the name belongs to a built-in or is already
   present in [map]; the built-in check takes precedence. *)
let add_func map fd =
  let name = fd.fname in
  if StringMap.mem name built_in_decls then
    raise (Failure ("function " ^ name ^ " may not be defined"))
  else if StringMap.mem name map then
    raise (Failure ("duplicate function " ^ name))
  else StringMap.add name fd map

(* Return the declaration bound to the function name [s] in [funcs], or raise
   [Failure] when no such function exists. *)
let find_func funcs s =
  match StringMap.find s funcs with
  | decl -> decl
  | exception Not_found -> raise (Failure ("could not find function " ^ s))

(* Function checking: scopes, expressions and statements *)

(* Semantically check one function declaration and return its SAST form.

   [func_decls] maps function names to declarations and is used to resolve
   calls. [global_env] is the outermost scope; it becomes the parent of the
   scope holding the formals. [func] is the function being checked.

   Raises [Failure] with a descriptive message on the first semantic error
   found. *)
let check_function func_decls global_env func =
  (* Make sure no formals are void or duplicates *)
  check_binds "formal" func.args_list;
  (* Formals enter the function scope without an initializer. *)
  let formals = List.map (fun v -> {
    sv_type = v.v_type;
    sv_name = v.v_name;
    sv_val = SNoexpr;
    sv_node_val = false;
  }) func.args_list in

  (* Function-level scope: holds the formals and is chained to the globals. *)
  let local_env = { vars = formals; parent = Some(global_env); in_loop = false } in

  (* First declaration called [name] in one scope's variable list, if any. *)
  let rec var_exists vars name = match vars with
  | var :: rest -> if var.sv_name = name then Some(var) else var_exists rest name
  | [] -> None in

  (* Resolve [name] by walking from the given scope outwards through the
     parent links; fails once the chain is exhausted. *)
  let rec find_var env_option name = match env_option with
    | None -> raise (Failure ("variable " ^ name ^ " is not in scope"))
    | Some(env) ->
    match var_exists env.vars name with
    | Some(var) -> var
    | None -> find_var env.parent name in

  (* Check an expression in scope [env]; returns the checked expression
     together with its type. *)
  let rec check_expr (e: Ast.expr) env = match e with
      Int_lit(l) -> SInt_lit(l), Int
    | Float_lit(l) -> SFloat_lit(l), Float
    | Bool_lit(l) -> SBool_lit(l), Bool
    | Char_lit(l) -> SChar_lit(l), Char
    | String_lit(l) -> SString_lit(l), String
    | Unop(op, e1) as e ->
      let (e1', t) = check_expr e1 env in
      let ty = match op with
        Neg when t = Int || t = Float -> t
      | Not when t = Bool -> Bool
      | _ -> raise (Failure ("illegal unary operator " ^ 
                            string_of_unop op ^ string_of_vtype t ^
                            " in " ^ string_of_expr e))
      in (SUnop(op, e1', ty), ty)
    | Binop(e1, op, e2) as e -> 
            let (e1', t1) = check_expr e1 env
            and (e2', t2) = check_expr e2 env in
            (* Most binary operators require operands of the same type *)
            let same = t1 = t2 in
            (* Determine expression type based on operator and operand types *)
            let ty = match op with
            | Plus | Minus | Times | Divide | Mod when same && t1 = Int   -> Int
            | Plus | Minus | Times | Divide | Mod when same && t1 = Float -> Float
            | Plus | Minus | Times | Divide | Mod when same && t1 = Char ->  Char
            | Not_Equals | Equals                 when same               -> Bool
            | Greater | Greater_Eq | Less  | Less_Eq
                      when same && (t1 = Int || t1 = Float) -> Bool
            | And | Or when same && t1 = Bool -> Bool
            (* A node may also be compared against a Void-typed operand. *)
            | Not_Equals | Equals when (is_node t1 && t2 = Void) || (is_node t2 && t1 = Void) -> Bool
            | _ -> raise (
          Failure ("illegal binary operator " ^
                        string_of_vtype t1 ^ " " ^ string_of_op op ^ " " ^
                        string_of_vtype t2 ^ " in " ^ string_of_expr e))
            in (SBinop(e1', op, e2', t1, t2, ty), ty)
    | Call(fname, args) as call ->
        let fd = find_func func_decls fname in
        let param_length = List.length fd.args_list in
        (* Structural inequality: (<>) compares values, (!=) compares identity. *)
        if List.length args <> param_length then
              raise (Failure ("expecting " ^ string_of_int param_length ^ 
                              " arguments in " ^ string_of_expr call))
        else let check_call vdecl e = 
                let (e', et) = check_expr e env in 
                let _ = check_types "function call" (vdecl.v_type, et) e'
                in e'
              in 
              let args' = List.map2 check_call fd.args_list args
              in SCall(fname, args', fd.rtype), fd.rtype
    | Id(s) -> let sv = find_var (Some env) s in SId(s, sv.sv_type), sv.sv_type
    | Assign(s, e) -> let sv = find_var (Some env) s in let (e', et) = check_expr e env in
        let t = check_types "assignment" (sv.sv_type, et) e' in SAssign(s, e', t), t
    | Nodeop(node_op, s) -> let sv = find_var (Some env) s in (match sv.sv_type with 
        | Node(t) -> (match node_op with
            (* Child access yields the node type itself; Dref yields the
               element type carried by the node. *)
            | Get_left_child | Get_right_child -> SNodeop(node_op, s, sv.sv_type), sv.sv_type
            | Dref -> SNodeop(node_op, s, t), t
          )
        | _ -> raise (Failure ("expected Node type for var " ^ s)))
    | Node_assign(s, e) -> let sv = find_var (Some env) s in let (e', et) = check_expr e env in 
        (match sv.sv_type with
          | Node(t1) -> let t = check_types "node assign" (t1, et) e' in SNode_assign(s, e', t), t
          | _ -> raise (Failure ("expected Node type for var " ^ s 
            ^ " for node assignment in expression " ^ string_of_sexpr e'))
        )
    | Noexpr -> SNoexpr, Void in

  (* Convert a declaration into its checked form, type-checking the
     initializer. With [node_val] set, the initializer is checked against the
     element type of the declared Node type rather than the Node type. *)
  let to_svdecl v env = let (e', et) = check_expr v.v_val env in match v.node_val with
    | true -> (match v.v_type with
      | Node(t1) -> let _ = check_types ("node assignment " ^ v.v_name) (t1, et) e' in 
        {
          sv_type = v.v_type;
          sv_name = v.v_name;
          sv_val = e';
          sv_node_val = true
        }
      | _ -> raise (Failure ("internal parsing error"))
      )
    | false -> let t = check_types ("variable declaration " ^ v.v_name) (v.v_type, et) e' in {
        sv_type = t;
        sv_name = v.v_name;
        sv_val = e';
        sv_node_val = false
      } in

  (* Extend scope [env] with declaration [v]. Rejects void variables, names
     already declared in this same scope, and node-valued initializers on a
     variable that is not of Node type. *)
  let add_id env v = 
    match v.v_type with
    | Void -> raise (Failure ("illegal void type for variable " ^ v.v_name))
    | _ ->
    match var_exists env.vars v.v_name with
    | Some(_) -> raise (Failure ("variable " ^ v.v_name ^ " declared previously in this scope"))
    | None -> 
    match v.node_val with 
    | false -> { env with vars = to_svdecl v env :: env.vars }
    | true ->
    match v.v_type with
    | Node(_) -> { env with vars = to_svdecl v env :: env.vars }
    | _ -> raise (Failure ("illegal declaration for variable " ^ v.v_name 
        ^ ": expected node type with node assignment operator")) in 
   
  (* Check [e] and require it to have type Bool. *)
  let check_bool_expr e env = 
    let (e', t') = check_expr e env
      and err = "expected Boolean expression in " ^ string_of_expr e
    (* Structural inequality on types; (!=) would compare identity. *)
    in if t' <> Bool then raise (Failure err) else e' in

  (* Scope seen by the statements that follow [s]: only a variable
     declaration changes it. *)
  let add_var s curr_env = match s with
    | Variable(v) -> add_id curr_env v
    | _ -> curr_env in 

  (* Return a semantically-checked statement i.e. containing sexprs *)
  let rec check_stmt (s: Ast.stmt) env = match s with
    | If(p, b1, b2) -> SIf(check_bool_expr p env, check_stmt b1 env, check_stmt b2 env)
    | For(e1, e2, e3, st) -> 
      (* Loops open a fresh scope flagged as being inside a loop. *)
      let new_env = { vars = []; parent = Some(env); in_loop = true } in
      SFor(fst (check_expr e1 new_env), check_bool_expr e2 new_env, fst (check_expr e3 new_env), check_stmt st new_env)
    | While(p, s) -> 
      let new_env = { vars = []; parent = Some(env); in_loop = true } in 
      SWhile(check_bool_expr p new_env, check_stmt s new_env)
    | Return e -> let (e', t) = check_expr e env in
      if t = func.rtype then SReturn (e', t) 
      else raise (
        Failure ("return gives " ^ string_of_vtype t ^ " expected " ^
          string_of_vtype func.rtype ^ " in " ^ string_of_expr e))
    | Continue -> (match env.in_loop with
      | true -> SContinue
      | false -> raise (Failure "continue statement must occur within loop"))
    | Break -> (match env.in_loop with
      | true -> SBreak
      | false -> raise (Failure "break statement must occur within loop"))
    | Block sl -> 
        (* A block opens a fresh scope and inherits the loop flag. *)
        let new_env = { vars = []; parent = Some(env); in_loop = env.in_loop } in
        let rec check_stmt_list s curr_env : Sast.sstmt list = (match s with
          | [Return _ as s] -> [check_stmt s curr_env]
          (* Statements after a break or continue are dropped. *)
          | Break as s :: _ -> [check_stmt s curr_env]
          | Continue as s :: _ -> [check_stmt s curr_env]
          | Return _ :: _   -> raise (Failure "nothing may follow a return")
          | Block _ :: _  -> raise (Failure "cannot nest block inside a block")
          | s :: ss ->
              (* Sequence the three steps explicitly. The operands of (::)
                 have no specified evaluation order, so writing them inline
                 lets an error in a later statement be raised before an
                 error in this one. Declaration checks run first, then this
                 statement, then the rest of the list. *)
              let next_env = add_var s curr_env in
              let checked = check_stmt s curr_env in
              checked :: check_stmt_list ss next_env
          | []              -> [])
        in SBlock(check_stmt_list sl new_env)
    | Expr e -> let (sexpr, _) = check_expr e env in SExpr(sexpr)
    | Variable(v) -> (match v.v_name with 
      (* The name null is reserved. *)
      | "null" -> raise (Failure ("not allowed to create variable named null")) 
      | _ -> SVariable(to_svdecl v env))
    | Node_child(n, o, e) -> let (e', et) = check_expr e env in (match et with
      | Node(_) -> let sv = (find_var (Some env) n) in 
        let t = check_types "node child operation" (sv.sv_type, et) e' in SNode_child(n, o, e', t)
      | _ -> raise (Failure ("expected node type for expression " ^ string_of_sexpr e')))

  in (* body of check_function *)
  { srtype = func.rtype;
    sfname = func.fname;
    sargs_list  = List.map (fun x -> (x.v_type, x.v_name)) func.args_list;
    (* The body is checked as one block so it gets its own scope. *)
    sbody = match (check_stmt (Block func.body) local_env) with
      SBlock(sl) -> sl
      | _ -> raise (Failure ("internal error: block didn't become a block?"))
  }


(* Entry point of semantic checking. Takes the program as a pair
   (globals, functions) and returns the checked globals paired with the
   checked functions. Raises [Failure] when something is wrong.

   Globals are validated first; then every function is checked against a
   symbol table holding the built-ins plus all user functions. A function
   named [main] must exist. *)
let check (globals, functions) =
  (* Globals: no void types, no duplicate names. *)
  check_binds "global" globals;

  (* Checked form of a global; the initializer slot is left empty. *)
  let svdecl_of_global g =
    { sv_type = g.v_type; sv_name = g.v_name; sv_val = SNoexpr; sv_node_val = false }
  in
  let global_svdecls = List.map svdecl_of_global globals in

  (* Void-typed variable named null, placed in the outermost scope. *)
  let null_var =
    { sv_type = Void; sv_name = "null"; sv_val = SNoexpr; sv_node_val = false }
  in
  let global_env =
    { vars = null_var :: global_svdecls; parent = None; in_loop = false }
  in

  (* Symbol table: built-ins extended with every user-defined function. *)
  let function_decls = List.fold_left add_func built_in_decls functions in

  (* The lookup is performed only for its failure when main is missing. *)
  ignore (find_func function_decls "main");

  (global_svdecls, List.map (check_function function_decls global_env) functions)
