(* Semantic checking for the MicroC compiler *)

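(*
System functions (registered as stubs by DeclarationCh.add_system_func_decls):
main ()
print (number n), print (string s)
printnl (number n), printnl (string s)
shape (matrix m) -> matrix
len (matrix m) -> number
strcat (string s, string s2) -> string
*)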

(* author: Joseph Isaac Baker *)

open Ast
open Sast

module StringMap = Map.Make (String)
module StringSet = Set.Make (String)

(* Key identifying a function overload: its name plus its parameter types *)
type func_id = {  func_name: string;
                  func_params: typ list; }

(* Total order on func_id keys (polymorphic structural compare) *)
module FuncIdOrd =
  struct
    type t = func_id
    let compare = compare
  end
module FuncMap = Map.Make (FuncIdOrd)
module FuncSet = Set.Make (FuncIdOrd)

(* Key identifying an operator overload: its symbol plus its operand types *)
type oper_id = {  oper_symbol: op;
                  oper_params: typ list; }

(* Total order on oper_id keys (polymorphic structural compare) *)
module OperIdOrd =
  struct
    type t = oper_id
    let compare = compare
  end
module OperMap = Map.Make (OperIdOrd)

module Transforms = 
  struct
    (* Overload key of a function declaration: its name plus its parameter types *)
    let func_id_of_func_decl func_decl =
      let param_typs = List.map (fun (t, _) -> t) func_decl.fparams in
      { func_name = func_decl.fname; func_params = param_typs; }

    (* "name(t1,t2,...)" rendering of a function key, used in error messages *)
    let string_of_func_id func_id =
      let typ_names = List.map string_of_typ func_id.func_params in
      Printf.sprintf "%s(%s)" func_id.func_name (String.concat "," typ_names)

    (* Overload key of an operator declaration: its symbol plus its operand types *)
    let oper_id_of_oper_decl oper_decl =
      let param_typs = List.map (fun (t, _) -> t) oper_decl.opparams in
      { oper_symbol = oper_decl.operator; oper_params = param_typs; }

    (* "opname(t1,t2,...)" rendering of an operator key, in function-call form *)
    let func_string_of_oper_id oper_id =
      let typ_names = List.map string_of_typ oper_id.oper_params in
      Printf.sprintf "%s(%s)" (name_of_op oper_id.oper_symbol) (String.concat "," typ_names)

    (* Source-like rendering of an operator key: "t1 op t2" for binary operators,
       prefix form for Neg/Not and postfix form for Tran. List.nth / List.hd raise
       if the key carries fewer operand types than the operator needs. *)
    let string_of_oper_id oper_id =
      let params = oper_id.oper_params in
      match oper_id.oper_symbol with
        | Bop(o)                -> String.concat " " [ string_of_typ (List.nth params 0);
                                                       string_of_op o;
                                                       string_of_typ (List.nth params 1) ]
        | Uop((Neg | Not) as o) -> string_of_uop o ^ string_of_typ (List.hd params)
        | Uop((Tran) as o)      -> string_of_typ (List.hd params) ^ string_of_uop o

    (* Name of the generated function that implements an operator overload *)
    let operator_overload_name op_symbol = "_" ^ name_of_op op_symbol

    (* Views an operator declaration as an ordinary function declaration *)
    let func_of_oper oper =
      { fname   = operator_overload_name oper.operator;
        fdtype  = oper.opdtype;
        fparams = oper.opparams;
        fbody   = oper.opbody; }
  end
open Transforms

module AppState = 
  struct
    (* Snapshot of what the declaration checker knows at one point of its walk:
       vars                - global variables and their types
       funcs               - declared function overloads, keyed by name + param types
       func_names          - names of declared functions (any overload)
       func_bodies_checked - overloads whose bodies have already been walked
       func_locals         - per-overload local variable tables
       opers               - declared operator overloads
       decl_allowed        - whether a variable declaration is legal right now
       local_context       - overload whose body is being walked; None at global scope *)
    type app_state = {  vars: typ StringMap.t; 
                        funcs: func_decl FuncMap.t; 
                        func_names: StringSet.t;
                        func_bodies_checked: FuncSet.t;
                        func_locals: typ StringMap.t FuncMap.t;
                        opers: oper_decl OperMap.t;
                        decl_allowed: bool;
                        local_context: func_id option; }

    (* Starting state: nothing declared, global scope, declarations allowed *)
    let empty = { vars                = StringMap.empty;
                  funcs               = FuncMap.empty;
                  func_names          = StringSet.empty;
                  func_bodies_checked = FuncSet.empty;
                  func_locals         = FuncMap.empty;
                  opers               = OperMap.empty;
                  decl_allowed        = true;
                  local_context       = None; }

    (* Records [var_id : var_typ] as a global variable *)
    let add_global_var prev_state var_id var_typ =
      let vars = StringMap.add var_id var_typ prev_state.vars in
      { prev_state with vars = vars; }

    (* Records [var_id : var_typ] as a local of overload [func_id]. Raises Failure
       when [func_id] has no locals table (tables are created by add_func). *)
    let add_local_for_func prev_state func_id var_id var_typ =
      let current_locals =
        try FuncMap.find func_id prev_state.func_locals
        with Not_found -> raise (Failure ("add_local_for_func: Can't find function " ^ string_of_func_id func_id))
      in
      let updated_locals = StringMap.add var_id var_typ current_locals in
      { prev_state with func_locals = FuncMap.add func_id updated_locals prev_state.func_locals; }

    (* Type of local [var_id] of overload [func_id]; raises Not_found if either is unknown *)
    let get_local_var state func_id var_id =
      let locals = FuncMap.find func_id state.func_locals in
      StringMap.find var_id locals

    (* Declares a variable in the current scope: the active overload's locals, or the globals *)
    let add_var prev_state var_id var_typ =
      match prev_state.local_context with
        | None          -> add_global_var prev_state var_id var_typ
        | Some(func_id) -> add_local_for_func prev_state func_id var_id var_typ

    (* Registers overload [func_sig] -> [func_decl], records its name, and installs an
       empty locals table for it (replacing any existing table for that signature) *)
    let add_func prev_state func_sig func_decl =
      { prev_state with
        funcs       = FuncMap.add func_sig func_decl prev_state.funcs;
        func_names  = StringSet.add func_sig.func_name prev_state.func_names;
        func_locals = FuncMap.add func_sig StringMap.empty prev_state.func_locals; }

    (* Registers operator overload [oper_id] -> [oper_decl] *)
    let add_oper prev_state oper_id oper_decl =
      let opers = OperMap.add oper_id oper_decl prev_state.opers in
      { prev_state with opers = opers; }

    (* Records [func_name] as a known function name *)
    let add_func_name prev_state func_name =
      let func_names = StringSet.add func_name prev_state.func_names in
      { prev_state with func_names = func_names; }

    (* Marks overload [func_sig] as having had its body checked *)
    let add_func_body_checked prev_state func_sig =
      let checked = FuncSet.add func_sig prev_state.func_bodies_checked in
      { prev_state with func_bodies_checked = checked; }

    (* Copies of [prev_state] with variable declarations forbidden / permitted *)
    let prohibit_decl prev_state = { prev_state with decl_allowed = false; }
    let enable_decl prev_state = { prev_state with decl_allowed = true; }

    (* Chooses where new variables go: Some overload for its locals, None for globals *)
    let set_local_context prev_state func_id = { prev_state with local_context = func_id; }

    (* Human-readable dump of [state], for debugging *)
    let print state =
      let join_typs typs = String.concat "_" (List.map string_of_typ typs) in
      let describe_func f = f.func_name ^ "(" ^ join_typs f.func_params ^ ")" in
      let describe_oper o = name_of_op o.oper_symbol ^ "(" ^ join_typs o.oper_params ^ ")" in
      let buf = Buffer.create 256 in
      let add = Buffer.add_string buf in
      add "Vars: ";
      StringMap.iter (fun name t -> add (string_of_typ t ^ " " ^ name ^ "; ")) state.vars;
      add "\nFuncs: ";
      FuncMap.iter (fun f _ -> add (describe_func f ^ "; ")) state.funcs;
      add "\nlocal_context: ";
      add (match state.local_context with
             | Some(f) -> describe_func f
             | None    -> "None");
      add "\nDeclarations Allowed: ";
      add (string_of_bool state.decl_allowed);
      add "\nFunction Locals: \n";
      FuncMap.iter (fun f locals ->
          add "\t";
          add (describe_func f);
          add ": ";
          StringMap.iter (fun name t -> add (string_of_typ t ^ " " ^ name ^ ", ")) locals;
          add "\n") state.func_locals;
      add "Checked Functions: ";
      FuncSet.iter (fun f -> add (describe_func f ^ "; ")) state.func_bodies_checked;
      add "\nKnown Operators: \n";
      OperMap.iter (fun o _ -> add (describe_oper o ^ "; ")) state.opers;
      Buffer.contents buf
  end
open AppState

module ExceptionMessages = 
  struct
    (* Builders for the Failure messages raised by the checkers. Each takes the
       offending name or description [n] (or a type string [t]) and returns the
       complete message text. *)
    let except_void_var n               = "Illegal void variable " ^ n ^ "!" 
    let except_dupe_var n               = "Duplicate variable " ^ n ^ " for this scope!"
    let except_dupe_func n              = "Duplicate function " ^ n ^ "!"
    let except_dupe_oper n              = "Duplicate operator definition " ^ n ^ "!"
    let except_undeclared_var n         = "Use of undeclared variable " ^ n ^ "!"
    let except_undeclared_func n        = "Use of undeclared function " ^ n ^ "!"
    let except_unknown_operator n       = "Use of unknown operator " ^ n ^ "!"
    let except_declaration_prohibited n = "Declaration of variable " ^ n ^ " prohibited in this context!"
    let except_type_binary_unknown n    = "Don't recognize types for arithmetic operator " ^ n ^ "!"
    let except_type_unary_unknown n     = "Don't recognize types for unary operator " ^ n ^ "!"
    (* [t] is the expected type's name, [n] the offending expression's text *)
    let except_type_req_type t n        = t ^ " type expected from " ^ n ^ "!"
    let except_type_mat_init t          = "Only number type allowed in matrix init block. " ^ t ^ " used!"
    let except_jagged_mat t             = "Matrix initialization cannot be jagged! " ^ t
    (* Constant message; takes no argument *)
    let except_stmt_after_return        = "No statements are allowed after return inside a block!"
  end
open ExceptionMessages

module TypeCh = 
  struct
    (* Type-checks lowered global variables and function declarations.
       [global_vars] and each function's [cparams]/[clocal_vars] are (typ, name)
       pairs; functions are looked up by [cname] and typed by [cdtype];
       [app_state.opers] supplies the operator signatures used for Binop/Unop.
       Returns unit; raises Failure on the first type error found. *)
    let check_types global_vars func_decls app_state = 
      let bind_var map (t, i) = StringMap.add i t map in
      let global_var_map = List.fold_left bind_var StringMap.empty global_vars in
      let func_map = List.fold_left (fun map func_decl -> StringMap.add func_decl.cname func_decl map) StringMap.empty func_decls in

      (* Confirm no global, local, or parameter variables are marked as 'void' type *)
      let check_void_vars var_list = List.iter (fun (t, i) -> if t = VoidTyp then raise (Failure (except_void_var i))) var_list in
      check_void_vars global_vars;
      List.iter (fun func_decl -> check_void_vars func_decl.clocal_vars; check_void_vars func_decl.cparams) func_decls;

      (* Raise unless [t] equals [target_typ]; [e] is only used in the message.
         Structural <> (not physical !=) keeps this correct for any typ value. *)
      let confirm_typ t target_typ e =
        if t <> target_typ then raise (Failure (except_type_req_type (string_of_typ target_typ) (string_of_expr e)))
      in
      (* Performs various type checks on the expression and returns the expression's type *)
      let rec check_expr_typ avail_vars = function
          Number(_,_,_)       ->  NumberTyp
        | Str(_)              ->  StringTyp
        | Noexpr              ->  VoidTyp
        | Id(i)               ->  (try StringMap.find i avail_vars with Not_found -> raise (Failure ("Can't find " ^ i)))
        (* Every element of a matrix literal must be a number *)
        | MatInit(ell)        ->  check_expr_list_list_typ avail_vars ell NumberTyp;
                                  MatrixTyp
        (* Both dimensions must be numbers *)
        | MatEmptyInit(e1,e2) ->  confirm_typ (check_expr_typ avail_vars e1) NumberTyp e1;
                                  confirm_typ (check_expr_typ avail_vars e2) NumberTyp e2;
                                  MatrixTyp
        (* Both indices must be numbers; a matrix cell is a number *)
        | MatAcc(_, e1, e2)   ->  confirm_typ (check_expr_typ avail_vars e1) NumberTyp e1;
                                  confirm_typ (check_expr_typ avail_vars e2) NumberTyp e2;
                                  NumberTyp
        (* The operator is looked up by its operand types; its declared return type is the result *)
        | Binop(e1,op,e2)     ->  let oper_id = { oper_symbol=Bop(op);
                                                  oper_params=[ check_expr_typ avail_vars e1;
                                                                check_expr_typ avail_vars e2]} in
                                  (try (OperMap.find oper_id app_state.opers).opdtype
                                   with Not_found -> raise (Failure ("check_expr_typ: Can't find operator " ^ string_of_oper_id oper_id)))
        | Unop(op, e)         ->  let oper_id = { oper_symbol=Uop(op);
                                                  oper_params=[check_expr_typ avail_vars e]} in
                                  (try (OperMap.find oper_id app_state.opers).opdtype
                                   with Not_found -> raise (Failure ("check_expr_typ: Can't find operator " ^ string_of_oper_id oper_id)))
        (* Confirm that the result of the assignment expression matches what is being assigned to *)
        | Assign(l, e)        ->  check_assignment avail_vars l e
        (* Arguments are type-checked; the result is the callee's declared return type.
           NOTE(review): argument types are not compared against the callee's cparams. *)
        | Func(s, es)         ->  List.iter (fun e -> ignore (check_expr_typ avail_vars e)) es;
                                  (try (StringMap.find s func_map).cdtype
                                   with Not_found -> raise (Failure ("check_expr_typ: Can't find " ^ s)))

      (* Every expression of every row must have type [target_typ] *)
      and check_expr_list_list_typ avail_vars ell target_typ =
        List.iter (List.iter (fun e -> confirm_typ (check_expr_typ avail_vars e) target_typ e)) ell

      (* Type that a value assigned to the lvalue must have *)
      and l_value_target_typ avail_vars = function
          VarDecl (t, _)      ->  t
        | IdAsn(i)            ->  (try StringMap.find i avail_vars
                                   with Not_found -> raise (Failure ("check_assignment: Can't find " ^ i)))
        | MatCellAsn(_,e1,e2) ->  confirm_typ (check_expr_typ avail_vars e1) NumberTyp e1;
                                  confirm_typ (check_expr_typ avail_vars e2) NumberTyp e2;
                                  NumberTyp

      (* Checks [lv = expr] and returns the type of [expr] *)
      and check_assignment avail_vars lv expr =
        let target_typ = l_value_target_typ avail_vars lv in
        let expr_typ = check_expr_typ avail_vars expr in
        confirm_typ expr_typ target_typ expr;
        expr_typ
      in

      (* A Return inside a block must be that block's last statement *)
      let confirm_return_block_end stmts =
        let last_index = List.length stmts - 1 in
        List.iteri (fun i stmt ->
            match stmt with
              | Return(_) -> if i <> last_index then raise (Failure except_stmt_after_return)
              | _         -> ()) stmts
      in

      (* Checks a statement of a function whose declared return type is [func_target_typ] *)
      let rec check_stmt func_target_typ avail_vars = function
          Block(stmts)      ->  confirm_return_block_end stmts;
                                List.iter (check_stmt func_target_typ avail_vars) stmts
        | Expr(e)           ->  ignore(check_expr_typ avail_vars e)
        | Return(e)         ->  confirm_typ (check_expr_typ avail_vars e) func_target_typ e
        | If(e, s1, s2)     ->  confirm_typ (check_expr_typ avail_vars e) NumberTyp e;
                                check_stmt func_target_typ avail_vars s1;
                                check_stmt func_target_typ avail_vars s2
        | While(e, s)       ->  confirm_typ (check_expr_typ avail_vars e) NumberTyp e;
                                check_stmt func_target_typ avail_vars s
        (* The loop variable must be a number and the iterated expression a matrix *)
        | For(lv,e,s,_)     ->  confirm_typ (l_value_target_typ avail_vars lv) NumberTyp (Assign(lv, Noexpr));
                                confirm_typ (check_expr_typ avail_vars e) MatrixTyp e;
                                check_stmt func_target_typ avail_vars s
        (* Any other statement form (e.g. VDecl) has nothing to type-check here *)
        | _                 -> ()
      in

      (* Params and locals shadow globals of the same name *)
      let check_function func_decl = 
        let avail_vars = List.fold_left bind_var global_var_map func_decl.cparams in
        let avail_vars = List.fold_left bind_var avail_vars func_decl.clocal_vars in
        check_stmt func_decl.cdtype avail_vars func_decl.cbody
      in

      List.iter check_function func_decls
  end

module DeclarationCh = 
  struct
    (* Registers the built-in functions as stub declarations (empty bodies) so
       that calls to them resolve during declaration checking. *)
    let add_system_func_decls app_state = 
      let stub_func name param_bindings return_typ = { fdtype = return_typ; fname = name; fparams = param_bindings; fbody = Block([]); } in
      let add_func prev_state (name, params, return_typ) = 
        let func_decl = stub_func name params return_typ in
        AppState.add_func prev_state (func_id_of_func_decl func_decl) func_decl  in
      let system_funcs = 
        [("main",  [], VoidTyp);
        ("print",  [ (NumberTyp, "n") ], VoidTyp); 
        ("print",  [ (StringTyp, "s") ], VoidTyp);
        ("printnl",  [ (NumberTyp, "n") ], VoidTyp); 
        ("printnl",  [ (StringTyp, "s") ], VoidTyp);
        ("shape",  [ (MatrixTyp, "m") ], MatrixTyp);
        ("len",    [ (MatrixTyp, "m") ], NumberTyp);
        ("strcat", [ (StringTyp, "s"); (StringTyp, "s2"); ], StringTyp); ]
      in List.fold_left add_func app_state system_funcs

    (* Registers the built-in operator signatures: each operator listed below
       takes number operands and returns a number. *)
    let add_system_oper_decls app_state = 
      let stub_oper op param_bindings return_typ = 
      { opdtype = return_typ; operator = op; opparams = param_bindings; opbody = Block([]); } in
      let add_oper prev_state (op, params, return_typ) = 
        let oper_decl = stub_oper op params return_typ in
        AppState.add_oper prev_state (oper_id_of_oper_decl oper_decl) oper_decl  in
      let system_opers = 
        [(Bop(Add),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Sub),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Mul),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Div),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Mod),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Equal),   [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Neq),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Less),    [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Leq),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Greater), [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Geq),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(And),     [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Bop(Or),      [(NumberTyp, "n"); NumberTyp, "n2"], NumberTyp);
         (Uop(Not),     [(NumberTyp, "n")], NumberTyp);
         (Uop(Neg),     [(NumberTyp, "n")], NumberTyp); ]
      in List.fold_left add_oper app_state system_opers

    (* Returns [app_state] unchanged when [func_id] is a declared signature;
       raises Failure (undeclared function) otherwise. *)
    let check_func_param_types app_state func_id = 
      if FuncMap.mem func_id app_state.funcs then app_state
      else raise (Failure (except_undeclared_func (string_of_func_id func_id)))

    (* Computes the func_id that the call [func_name(func_params)] resolves to, by
       typing each argument expression in the scope described by [app_state]
       (the current function's locals first, then globals). Raises Failure for
       undeclared variables, unknown operator signatures, or nested calls that
       do not resolve to a declared function. *)
    let rec func_id_of_name_params app_state func_name func_params = 
      (* Type of a variable: locals of the current function shadow globals *)
      let get_var_typ var_name = 
        let local_typ =
          match app_state.local_context with
            | Some(f) -> let locals = (try FuncMap.find f app_state.func_locals
                                       with Not_found -> raise (Failure ("get_var_typ: Can't find local context"))) in
                         (try Some (StringMap.find var_name locals) with Not_found -> None)
            | None    -> None
        in
        match local_typ with
          | Some(t) -> t
          | None    -> (try StringMap.find var_name app_state.vars
                        with Not_found -> raise (Failure (except_undeclared_var var_name)))
      in
      (* Return type of operator [oper] applied to operands of types [params] *)
      let get_oper_typ oper params = 
        let oper_id = { oper_symbol=oper; oper_params=params } in
        try (OperMap.find oper_id app_state.opers).opdtype
        with Not_found -> raise (Failure (except_unknown_operator (string_of_oper_id oper_id)))
      in
      let rec get_expr_typ = function
        (* literals have a fixed type *)
          Number(_,_,_)       ->  NumberTyp
        | Str(_)              ->  StringTyp
        | Noexpr              ->  VoidTyp
        (* id has the type that it was declared as *)
        | Id(i)               ->  get_var_typ i
        (* matrix literals and empty-matrix constructors are matrices *)
        | MatInit(_)          ->  MatrixTyp
        | MatEmptyInit(_,_)   ->  MatrixTyp
        (* a matrix cell is a number *)
        | MatAcc(_, _, _)     ->  NumberTyp
        (* return type of the operator selected by the operand types *)
        | Binop(e1,op,e2)     ->  get_oper_typ (Bop(op)) [(get_expr_typ e1); (get_expr_typ e2)]
        | Unop(op, e)         ->  get_oper_typ (Uop(op)) [(get_expr_typ e)]
        (* assign returns the same type as the expression *)
        | Assign(_, e)        ->  get_expr_typ e
        (* a call has the return type of the declaration it resolves to *) 
        | Func(s, es)         ->  let func_id = func_id_of_name_params app_state s es in
                                  (try (FuncMap.find func_id app_state.funcs).fdtype
                                   with Not_found -> raise (Failure ("get_expr_typ: Can't find function " ^ string_of_func_id func_id)))
      in
      { func_name=func_name; func_params=(List.map get_expr_typ func_params) }

    (* Resolves an operator application. Raises Failure if no operator with these
       operand types is known; returns Some func_id when a user overload function
       implements it, or None for a built-in operator without one. *)
    let func_id_of_symbol_params app_state oper_symbol oper_params =
      let possible_func_id = func_id_of_name_params app_state (operator_overload_name oper_symbol) oper_params in
      let oper_id = { oper_symbol=oper_symbol; oper_params=possible_func_id.func_params; } in
      if not (OperMap.mem oper_id app_state.opers)
      then raise (Failure (except_unknown_operator (string_of_oper_id oper_id)))
      else if FuncMap.mem possible_func_id app_state.funcs
      then Some(possible_func_id)
      else None (* Is just a system function that has no corresponding overload function *)

    (*  Iterate over the global statements, as if we're evaluating them.
        Build an application declarations object, ensuring that whenever we 
        try to use a variable, function, or oper overload, it has been 
        declared. Function bodies are checked when first invoked; any never
        invoked are checked at the end. Returns the final app_state. *)  
    let check_declarations (global_statements) = 
      let confirm_variable_declared app_state var_name = 
        let is_global_var = StringMap.mem var_name app_state.vars in
        let is_local_var = 
          match app_state.local_context with
            | Some(f) -> StringMap.mem var_name (try FuncMap.find f app_state.func_locals
                                                 with Not_found -> raise (Failure ("confirm_variable_declared: Can't find local context")))
            | None    -> false in
        if (is_local_var || is_global_var) then app_state
        else raise (Failure (except_undeclared_var var_name))          
      in
      let confirm_func_declared app_state func_name = 
        if StringSet.mem func_name app_state.func_names then app_state
        else raise (Failure (except_undeclared_func func_name))
      in
      (* Declares [var_name] in the current scope. Fails when declarations are
         prohibited here or the name already exists in that scope (a local may
         shadow a global). *)
      let try_add_var_decl app_state var_name var_typ = 
        if not app_state.decl_allowed
        then raise (Failure (except_declaration_prohibited var_name))
        else
          let already_in_scope =
            match app_state.local_context with
              | Some(f) -> StringMap.mem var_name (try FuncMap.find f app_state.func_locals
                                                   with Not_found -> raise (Failure ("try_add_var_decl: Can't find local context")))
              | None    -> StringMap.mem var_name app_state.vars
          in
          (* NOTE(review): a duplicate reports "prohibited" rather than except_dupe_var *)
          if already_in_scope
          then raise (Failure (except_declaration_prohibited var_name))
          else AppState.add_var app_state var_name var_typ
      in
      let try_add_func_decl app_state func_sig func_decl = 
        if not (FuncMap.mem func_sig app_state.funcs)
        then AppState.add_func app_state func_sig func_decl
        else raise (Failure (except_dupe_func (string_of_func_id func_sig)))
      in
      let try_add_oper_decl app_state oper_sig oper_decl = 
        if not (OperMap.mem oper_sig app_state.opers)
        then AppState.add_oper app_state oper_sig oper_decl
        else raise (Failure (except_dupe_oper (func_string_of_oper_id oper_sig)))
      in
      (* Registers the generated function "_mat_init_R_by_C" that builds a matrix
         literal of this shape from R*C number parameters c1..c(R*C). *)
      let handle_mat_init app_state ell = 
        let row_count = List.length ell in
        let col_count = List.length (List.hd ell) in
        (* Reject jagged initializers: every row must have col_count entries *)
        let () =
          List.iter (fun r ->
              if List.length r <> col_count
              then raise (Failure (except_jagged_mat (string_of_expr_list_list ell)))) ell
        in
        let fname = "_mat_init_" ^ string_of_int row_count ^ "_by_" ^ string_of_int col_count in
        let mat_id = "tmp_mat" in
        let flattened_el = List.concat ell in
        (* Every element must type as a number *)
        let () =
          let param_types = (func_id_of_name_params app_state "" flattened_el).func_params in
          List.iter (fun t -> if t <> NumberTyp then raise (Failure (except_type_mat_init (string_of_typ t)))) param_types
        in
        let params = 
          let rec build k acc =
            if k = 0 then acc
            else build (k - 1) ((NumberTyp, "c" ^ string_of_int k) :: acc)
          in
          build (row_count * col_count) []
        in
        (* NOTE(review): element i is written to cell (i+1, 0); presumably the
           backend treats this index pair as a flat 1-based position — confirm *)
        let assigns = List.mapi (fun i (_, id) -> Expr(Assign(
            MatCellAsn(mat_id, Number(IntTyp, (i + 1), 0.0), Number(IntTyp, 0, 0.0)), Id(id)))) params in
        let ret = [Return(Id(mat_id))] in
        let body = Block([
          Expr(Assign(VarDecl((MatrixTyp), mat_id), MatEmptyInit(Number(IntTyp, row_count, 0.0), Number(IntTyp, col_count, 0.0))));
        ] @ assigns @ ret) in
        let func_decl = { fdtype = MatrixTyp; fname = fname; fparams = params; fbody = body; } in
        let func_id = 
          let func_params = List.map (fun _ -> Number(IntTyp, 0, 0.0)) params in
          func_id_of_name_params app_state fname func_params in
        (* The initializer depends only on the shape, so keep an existing entry:
           re-adding it would reset its func_locals even if its body was checked. *)
        if FuncMap.mem func_id app_state.funcs
        then app_state
        else AppState.add_func app_state func_id func_decl 
      in

      (* Declares the hidden counter and length variables used by a for loop *)
      let handle_for_loop app_state id = 
        let temp_var_name = "_tmp_for_i" ^ id in
        let mat_len_name = "_tmp_mat_len" ^ id in
        AppState.add_var (AppState.add_var app_state temp_var_name NumberTyp) mat_len_name NumberTyp
      in
      let rec check_expression app_state = function
        (* literals reference no declarations *)
          Number(_,_,_) | Str(_) | Noexpr  -> app_state
        (* confirm variable is declared *)
        | Id(i)               ->  confirm_variable_declared app_state i
        (* check the elements with declarations prohibited (that state is discarded),
           then register the generated initializer function for this shape *)
        | MatInit(ell)        ->  ignore(check_expression_list_list (AppState.prohibit_decl app_state) ell);
                                  handle_mat_init app_state ell
        (* check both dimensions with declarations prohibited; return the previous state *)
        | MatEmptyInit(e1,e2) ->  ignore(check_expression (check_expression (AppState.prohibit_decl app_state) e1) e2);
                                  app_state
        (* confirm variable is declared, check both indices with declarations
           prohibited; return the previous state *)
        | MatAcc(i, e1, e2)   ->  ignore(confirm_variable_declared app_state i);
                                  ignore(check_expression (check_expression (AppState.prohibit_decl app_state) e1) e2);
                                  app_state
        (* recurse operands, confirm that we have operator overload if appropriate *)
        | Binop(e1,op,e2)     ->  check_operator app_state (Bop(op)) [e1; e2]
        | Unop(op, e)         ->  check_operator app_state (Uop(op)) [e]
        (* check the value first, then the lvalue *)
        | Assign(l, e)        ->  check_lvalue (check_expression app_state e) l
        (* confirm function is declared then check function body.
           NOTE(review): the state from checking the arguments is discarded here,
           unlike Binop/Unop operands — confirm this is intended *) 
        | Func(s, es)         ->  ignore(check_expression_list app_state es);
                                  ignore(confirm_func_declared app_state s);
                                  check_func_body_by_name app_state s es

      and check_expression_list app_state el = 
        List.fold_left check_expression app_state el 
      and check_expression_list_list app_state ell = 
        List.fold_left check_expression_list app_state ell

      and check_lvalue app_state = function
        (* add variable declaration *)
          VarDecl((t, i))       ->  try_add_var_decl app_state i t
        (* check variable declared *)
        | IdAsn(i)              ->  confirm_variable_declared app_state i
        (* check variable declared and the indices with declarations prohibited *)
        | MatCellAsn(i, e1, e2) ->  ignore(confirm_variable_declared app_state i);
                                    ignore(check_expression (check_expression (AppState.prohibit_decl app_state) e1) e2);
                                    app_state
      (* Checks the operands, then the body of the user overload implementing the
         operator, if there is one *)
      and check_operator app_state oper params = 
        let app_state = List.fold_left check_expression app_state params in
        let oper_func_id = func_id_of_symbol_params app_state oper params in
        match oper_func_id with
          | Some(func_id) -> (let func_decl = (try FuncMap.find func_id app_state.funcs 
                                               with Not_found -> raise (Failure ("check_operator: Can't find function " ^ string_of_func_id func_id))) in
                              check_func_body app_state func_id func_decl)
          | None          -> app_state

      (* Check the body statements of a function and confirm everything is declared before use *)
      and check_func_body_by_name app_state func_name func_params = 
        let func_id = func_id_of_name_params app_state func_name func_params in
        let func_decl = 
          ignore(check_func_param_types app_state func_id);
          (try FuncMap.find func_id app_state.funcs 
           with Not_found -> raise (Failure ("check_func_body_by_name: Can't find function " ^ string_of_func_id func_id))) in    
        check_func_body app_state func_id func_decl

      (* Walks the body of overload [func_id] once, in its own local scope.
         Declarations are always allowed inside the body, even when the call
         site (e.g. an if/while condition) prohibits them; the caller's scope and
         declaration permission are restored afterwards. *)
      and check_func_body app_state func_id func_decl = 
        let add_params_to_state state params = 
          List.fold_left (fun a (t, name) -> AppState.add_var a name t) state params in

        if FuncSet.mem func_id app_state.func_bodies_checked
        then app_state
        else 
          ( (* Mark as checked before walking the body so recursive calls terminate *)
            let func_app_state = AppState.add_func_body_checked app_state func_id in
            let func_app_state = AppState.set_local_context func_app_state (Some(func_id)) in
            let func_app_state = AppState.enable_decl func_app_state in
            let func_app_state = add_params_to_state func_app_state func_decl.fparams in 
            let func_app_state = d_check_statement func_app_state func_decl.fbody in
            (* Restore the caller's context and declaration permission *)
            let restored = AppState.set_local_context func_app_state app_state.local_context in
            { restored with decl_allowed = app_state.decl_allowed })

      (* debug hook for check_statement; currently a plain pass-through *)
      and d_check_statement app_state s = check_statement app_state s

      and check_statement app_state = function
        (* check each statement in the block *)
          Block(stmts)      ->  List.fold_left d_check_statement app_state stmts
        (* check the expression *)
        | Expr(e)           ->  check_expression app_state e
        (* add the variable declared to the app_state *)
        | VDecl((t, i))     ->  try_add_var_decl app_state i t
        (* check the returned expression *)
        | Return(e)         ->  check_expression app_state e
        (* check the expression and both statements, 
           note decl is prohibited inside the expression but not inside the statement blocks *)
        | If(e, s1, s2)     ->  ignore(check_expression (AppState.prohibit_decl app_state) e);
                                d_check_statement (d_check_statement app_state s1) s2
        (* check the expression and statement, 
           note decl is prohibited inside the expression but not inside the statement block *)
        | While(e, s)       ->  ignore(check_expression (AppState.prohibit_decl app_state) e);
                                d_check_statement app_state s
        (* check the matrix expression, declare the loop variable and hidden loop
           variables, then check the body *)
        | For(lv,e,s,id)    ->  ignore(check_expression (AppState.prohibit_decl app_state) e);
                                d_check_statement (handle_for_loop (check_lvalue app_state lv) id) s
      in

      let add_func_decl app_state func_decl = 
        try_add_func_decl app_state (func_id_of_func_decl func_decl) func_decl 
      in

      (* Add the operator to the operator map and also add it as a function entry *)
      let add_oper_decl app_state oper_decl = 
        (* A one-parameter Sub declaration is unary negation *)
        let fix_unary_minus oper_decl = 
          if (oper_decl.operator = Bop(Sub) && (List.length oper_decl.opparams) = 1)
          then { oper_decl with operator=Uop(Neg); }
          else oper_decl
        in
        let oper_decl = fix_unary_minus oper_decl in
        let func_decl = (func_of_oper oper_decl) in
        let new_app_state = add_func_decl app_state func_decl in
        try_add_oper_decl new_app_state (oper_id_of_oper_decl oper_decl) oper_decl
      in

      let check_global_statement app_state = function
          Stmt(st)    ->  d_check_statement app_state st
        | FuncDecl(f) ->  add_func_decl app_state f
        | OperDecl(o) ->  add_oper_decl app_state o
      in 

      let check_remaining_funcs app_state = 
        let remaining_func_ids = FuncMap.filter (fun id _ -> not (FuncSet.mem id app_state.func_bodies_checked)) app_state.funcs in
        FuncMap.fold (fun func_id func_decl app_state -> check_func_body app_state func_id func_decl) remaining_func_ids app_state
      in

      let initial_app_state = (add_system_oper_decls 
                              (add_system_func_decls AppState.empty)) in
      (* Check variable and function declarations based on invocation ordering *)
      let checked_app_state = List.fold_left check_global_statement initial_app_state global_statements in
      (* Check any functions that were declared but not checked by previous step *)
      let checked_app_state = check_remaining_funcs checked_app_state in

      checked_app_state
  end

module SASTGenerator = 
  struct
    (* Lower a declaration-checked program into the SAST.

       [app_state] is the state returned by DeclarationCh.check_declarations.
       It supplies the resolved function table (funcs), each function's local
       variables (func_locals) and the global variables (vars).
       [global_statements] is the original top-level program.

       The result holds every global variable binding, one function per entry
       of app_state.funcs, and a synthesized "main" whose body is the
       top-level statements.

       Lowering does three things:
       - Function calls and resolved operator overloads are rewritten to
         mangled names (name, then "_", then the parameter types joined by "_").
       - for-loops are desugared into while-loops.
       - Variable-declaration statements are dropped.

       Raises Failure when a function has no local-variable table, or when a
       for-loop iterates over something that is not a named variable.
       Errors raised by DeclarationCh lookups propagate unchanged. *)
    let generate_sast app_state global_statements = 
      (* Mangled name of a resolved function: the name, "_", then the
         parameter types joined by "_".  A parameterless "f" becomes "f_". *)
      let gen_sast_func_name_of_func_id func_id = 
        let param_string = String.concat "_" (List.map string_of_typ func_id.func_params) in
        func_id.func_name ^ "_" ^ param_string
      in

      (* Resolve the overload of [name] selected by the source-level argument
         expressions [exprs], in the scope of [func_context]. *)
      let gen_sast_func_name func_context name exprs = 
        let func_app_state = AppState.set_local_context app_state (Some(func_context)) in
        let func_id = DeclarationCh.func_id_of_name_params func_app_state name exprs in 
        gen_sast_func_name_of_func_id func_id
      in

      (* Lower an operator application.
         [exprs] are the already-lowered operands; [orig_exprs] are the
         source-level operands, which overload resolution works on.
         A resolved overload becomes a call to its mangled function;
         otherwise the built-in Binop/Unop node is kept. *)
      let gen_sast_oper func_context op_symbol exprs orig_exprs =
        let func_app_state = AppState.set_local_context app_state (Some(func_context)) in
        let func_id = DeclarationCh.func_id_of_symbol_params func_app_state op_symbol orig_exprs in
        match func_id, op_symbol with
          | Some(f), _    ->  Func(gen_sast_func_name_of_func_id f, exprs)
          | None, Bop(op) ->  Binop((List.nth exprs 0), op, (List.nth exprs 1))
          | None, Uop(op) ->  Unop(op, (List.hd exprs))
      in

      let rec gen_sast_expr_list func_context el = List.map (gen_sast_expr func_context) el 

      (* Lower one expression.
         Note: this is NOT idempotent, because call names get mangled.
         Every source expression must pass through here exactly once. *)
      and gen_sast_expr func_context = function 
          MatAcc(s,e1,e2)     ->  MatAcc(s, gen_sast_expr func_context e1, gen_sast_expr func_context e2)
        | MatInit(ell)        ->  gen_sast_mat_init func_context ell
        | MatEmptyInit(e1,e2) ->  MatEmptyInit(gen_sast_expr func_context e1, gen_sast_expr func_context e2)
        | Binop(e1,op,e2)     ->  gen_sast_oper func_context (Bop(op)) [gen_sast_expr func_context e1; gen_sast_expr func_context e2] [e1; e2]
        | Unop(op, e)         ->  gen_sast_oper func_context (Uop(op)) [gen_sast_expr func_context e] [e]
        | Assign(l, e)        ->  Assign(gen_sast_lvalue func_context l, gen_sast_expr func_context e)
        | Func(s, es)         ->  Func(gen_sast_func_name func_context s es, gen_sast_expr_list func_context es)
        | _ as e              ->  e

      and gen_sast_lvalue func_context = function
        (* No var decls in SAST because the declarations are all in the global and local var lists *)
          VarDecl(_,i)          ->  IdAsn(i)
        | IdAsn(i)              ->  IdAsn(i)
        | MatCellAsn(s,e1,e2)   ->  MatCellAsn(s, gen_sast_expr func_context e1, gen_sast_expr func_context e2)

      (* A matrix literal with r rows and c columns becomes a call to the
         helper "_mat_init_<r>_by_<c>".  The helper takes one number argument
         per cell, and the cells are passed in row order. *)
      and gen_sast_mat_init func_context ell = 
        let row_count = List.length ell in
        let col_count = List.length (List.hd ell) in
        let fname = "_mat_init_" ^ string_of_int row_count ^ "_by_" ^ string_of_int col_count in
        let params = 
          let rec numbers acc n = if n <= 0 then acc else numbers (NumberTyp :: acc) (n - 1) in
          numbers [] (row_count * col_count)
        in
        let func_id = { func_name=fname; func_params=params } in
        let func_name = gen_sast_func_name_of_func_id func_id in
        (* List.concat flattens the rows in order without the quadratic cost of repeated appends *)
        Func(func_name, gen_sast_expr_list func_context (List.concat ell))
      in
      
      let rec gen_sast_stmt func_context = function
          Block(stmts)      ->  Block(List.map (gen_sast_stmt func_context) (remove_vdecl_stmts stmts))
        | Expr(e)           ->  Expr(gen_sast_expr func_context e)
        | Return(e)         ->  Return(gen_sast_expr func_context e)
        | If(e, s1, s2)     ->  If(gen_sast_expr func_context e, gen_sast_stmt func_context s1, gen_sast_stmt func_context s2)
        | While(e, s)       ->  While(gen_sast_expr func_context e, gen_sast_stmt func_context s)
        | For(lv,e,s,id)    ->  gen_sast_for_statement func_context lv e s id
        | _ as s            ->  s

      and remove_vdecl_stmts stmts = 
        let is_not_vdecl = function VDecl(_) -> false | _ -> true in
        List.filter is_not_vdecl stmts

      (* Desugar a for-loop over the matrix variable [expr] into:
           _tmp_for_i<id> = 1
           _tmp_mat_len<id> = len(expr)
           while _tmp_for_i<id> <= _tmp_mat_len<id>:
             lvalue = expr[_tmp_for_i<id>, 0]
             statement
             _tmp_for_i<id> = _tmp_for_i<id> + 1
         The block is then lowered by a single gen_sast_stmt call.

         The lvalue, expr and statement are embedded in source form on
         purpose.  Lowering the lvalue up front would lower it twice,
         re-resolving already-mangled call names inside a matrix-cell
         lvalue's index expressions. *)
      and gen_sast_for_statement func_context lvalue expr statement id = 
        let mat_id = match expr with
            Id(s) -> s
          | _     -> raise (Failure ("SAST for loop: iteration target must be a named variable"))
        in
        let temp_var_name = "_tmp_for_i" ^ id in
        let mat_len_name = "_tmp_mat_len" ^ id in
        let num n = Number(IntTyp, n, 0.0) in
        let new_statement_block = 
          Block([
            Expr(Assign(VarDecl(NumberTyp, temp_var_name), num 1));
            Expr(Assign(VarDecl(NumberTyp, mat_len_name), Func("len", [expr])));
            While(Binop(Id(temp_var_name), Leq, Id(mat_len_name)),
              Block([
                Expr(Assign(lvalue, MatAcc(mat_id, Id(temp_var_name), num 0)));
                statement;
                Expr(Assign(IdAsn(temp_var_name), Binop(Id(temp_var_name), Add, num 1)));
              ]))
          ]) in
        gen_sast_stmt func_context new_statement_block
      in
      
      (* Keep only the top-level statements, in order.
         FuncDecl and OperDecl entries are dropped, since they are emitted as
         separate functions. *)
      let remove_func_decls glbl_statements =
        let keep_stmt acc = function Stmt(s) -> s :: acc | _ -> acc in
        List.rev (List.fold_left keep_stmt [] glbl_statements)
      in

      (* Locals minus the function parameters, which are listed separately in cparams. *)
      let remove_params_from_locals locals params = 
        let is_not_param p = not(List.mem p params) in
        List.filter is_not_param locals
      in

      (* Turn a name -> type map into a list of (type, name) bindings. *)
      let var_bindings_of_stringmap map = 
        StringMap.fold (fun i t l -> (t, i)::l) map []
      in

      let main = { cdtype=VoidTyp; cname="main"; 
        cparams=[]; clocal_vars=[]; cbody=gen_sast_stmt {func_name="main"; func_params=[]} (Block(remove_func_decls global_statements)) }
      in

      let global_vars = var_bindings_of_stringmap app_state.vars in

      let func_decls = 
        let gen_sast_func_decl func_id func_decl local_vars =
          let fname = gen_sast_func_name_of_func_id func_id in
          { cdtype=func_decl.fdtype; 
            cname=fname; 
            cparams=func_decl.fparams; 
            clocal_vars=remove_params_from_locals (var_bindings_of_stringmap local_vars) func_decl.fparams;
            cbody=gen_sast_stmt func_id func_decl.fbody; }
        in
        let local_vars func_id = (try FuncMap.find func_id app_state.func_locals 
                                  with Not_found -> raise (Failure ("SAST func_decls: Can't find function " ^ string_of_func_id func_id))) in
        let process_func func_id func_decl l = (gen_sast_func_decl func_id func_decl (local_vars func_id))::l in
        FuncMap.fold process_func app_state.funcs []
      in

      { global_vars=global_vars; func_decls=main::func_decls }

    (* Remove functions that must not be emitted, matched by mangled name:
       - "main_", a parameterless function named main (distinct from the
         synthesized "main");
       - the built-in system functions, which presumably come from the
         backend/runtime (TODO confirm). *)
    let clean_up_sast program = 
      let sys_function_names = [
        "main_"; "print_number"; "print_string"; "printnl_string";
        "printnl_number"; "strcat_string_string"; "shape_matrix"; "len_matrix" ]
      in
      let is_not_sys_function f = not (List.mem f.cname sys_function_names) in
      { program with func_decls = List.filter is_not_sys_function program.func_decls; }

  end

(* Semantic-analysis entry point.
   1. Check declarations in invocation order.
   2. Lower the program to SAST.
   3. Type-check the lowered globals and functions against the checked state.
   4. Return the SAST with the system functions filtered out. *)
let check global_statements =
  let declared_state = DeclarationCh.check_declarations global_statements in
  let program = SASTGenerator.generate_sast declared_state global_statements in
  let { global_vars; func_decls } = program in
  TypeCh.check_types global_vars func_decls declared_state;
  SASTGenerator.clean_up_sast program
