open Ast
open To_string

(* Maps keyed by identifier names; one such map holds each scope level. *)
module StringMap = Map.Make (String)

(* Types used during semantic analysis. They are richer than the datatypes in
   Ast: function types carry their full argument/return signature. *)

(* Numeric scalar types. *)
type scalar_type =
  | IntT
  | FloatT;;

(* Every type the checker can assign to an expression or a name. *)
type twister_type =
  | ScalarT of scalar_type
  | CharT
  | BoolT
  | TupT of twister_type                    (* tuple; single (shared) element type *)
  | ListT of twister_type                   (* list; element type (strings are ListT CharT) *)
  | MatrixT of twister_type                 (* matrix; element type *)
  | FunT of twister_type list * twister_type  (* parameter types, return type *)
  | MatchT;;                                (* element type of an empty literal *)

(* One level of lexical scope: the names declared at this level, plus the
   enclosing level to consult on a miss ([None] at the outermost level). *)
type symbol_table = {
  parent : symbol_table option;
  vars : twister_type StringMap.t;
};;

(* Checking context for a statement list: the current scope and the type that
   return statements in it must produce. *)
type translation_environment = {
  scope : symbol_table;
  loc_ret_type : twister_type;
};;

(* Pretty printing of types for error messages. *)

(* Spell a scalar type the way error messages show it. *)
let string_of_scalar_type = function
  | IntT -> "int"
  | FloatT -> "float";;

(* Render any checker type for error messages, e.g. "List<int>". Function
   types render as "Fun (<params>) -> <return>". *)
let rec string_of_twister_type t =
  let wrap label inner = label ^ "<" ^ string_of_twister_type inner ^ ">" in
  match t with
  | ScalarT s -> string_of_scalar_type s
  | CharT -> "Char"
  | BoolT -> "Bool"
  | TupT inner -> wrap "Tup" inner
  | ListT inner -> wrap "List" inner
  | MatrixT inner -> wrap "Matrix" inner
  | FunT (params, ret) ->
    let params_str = To_string.string_join "," (List.map string_of_twister_type params) in
    "Fun (" ^ params_str ^ ") -> " ^ string_of_twister_type ret
  | MatchT -> "Flex_Match_Type";;

(* True when [name] is declared at this scope level or at any enclosing one. *)
let rec var_is_defined (scope: symbol_table) name =
  StringMap.mem name scope.vars
  || (match scope.parent with
      | Some outer -> var_is_defined outer name
      | None -> false);;

(* True when the innermost declaration of [name] has a function type. A
   non-function binding at a nearer level hides any function of the same name
   further out (the search stops at the first level that declares [name]). *)
let rec fun_is_defined (scope: symbol_table) name =
  match StringMap.find name scope.vars with
  | FunT _ -> true
  | _ -> false
  | exception Not_found ->
    (match scope.parent with
     | Some outer -> fun_is_defined outer name
     | None -> false);;

(* Type bound to [name] by its innermost declaration, searching outward
   through enclosing levels. Raises Invalid_argument if no level declares it. *)
let rec type_of_var_name (scope: symbol_table) name =
  match StringMap.find name scope.vars with
  | found -> found
  | exception Not_found ->
    (match scope.parent with
     | Some outer -> type_of_var_name outer name
     | None ->
       raise (Invalid_argument ("Attempt to access to undeclared variable " ^ name)));;

(* Function signature bound to [name] by its innermost declaration.
   Raises Invalid_argument if that declaration is not a function type, or if
   no level declares [name]. *)
let rec type_of_fun_name (scope: symbol_table) name =
  match StringMap.find name scope.vars with
  | FunT _ as signature -> signature
  | other ->
    raise (Invalid_argument (name ^ " is not a function, but has type " ^ string_of_twister_type other))
  | exception Not_found ->
    (match scope.parent with
     | Some outer -> type_of_fun_name outer name
     | None -> raise (Invalid_argument ("Attempt to call undeclared function " ^ name)));;

(* True when every element of the list is structurally equal to the first
   (vacuously true for the empty list). *)
let all_equal = function
  | [] -> true
  | first :: rest -> List.for_all (( = ) first) rest;;

(* Type of a tuple literal from its entry types. Only the first entry is
   inspected (callers check beforehand that all entries agree); it must be a
   scalar. An empty tuple gets the placeholder element type MatchT. *)
let type_of_tuple types =
  match types with
  | [] -> TupT MatchT
  | (ScalarT _ as entry) :: _ -> TupT entry
  | entry :: _ ->
    raise (Invalid_argument (" invalid Tuple entry type " ^ string_of_twister_type entry));;

(* Type of a list literal from its entry types: the first entry's type decides
   (callers check that all entries agree); an empty list gets MatchT. *)
let type_of_list = function
  | [] -> ListT MatchT
  | first_type :: _ -> ListT first_type;;

(* Type of a Matrix literal from its (flattened) entry types. The first entry
   decides and must be a scalar or a tuple; an empty matrix gets MatchT. *)
let type_of_matrix types =
  match types with
  | [] -> MatrixT MatchT
  | ((ScalarT _ | TupT _) as entry) :: _ -> MatrixT entry
  | entry :: _ ->
    raise (Invalid_argument (" invalid Matrix entry type " ^ string_of_twister_type entry));;

(* Scalar type of an Ast numeric literal. *)
let type_of_ast_scalar = function
  | LitInt _ -> IntT
  | LitFloat _ -> FloatT;;

(* Translate a declared Ast datatype into a checker type. String becomes
   List<Char>. The bare Ast Fun datatype has no signature to translate, so it
   is rejected with Invalid_argument. *)
let rec type_of_ast_dtype = function
  | Int -> ScalarT IntT
  | Float -> ScalarT FloatT
  | Char -> CharT
  | Bool -> BoolT
  | String -> ListT CharT
  | Tup inner -> TupT (type_of_ast_dtype inner)
  | List inner -> ListT (type_of_ast_dtype inner)
  | Matrix inner -> MatrixT (type_of_ast_dtype inner)
  | Fun -> raise (Invalid_argument "Ast.Fun type in unexpected location");;

(* Type of one declared function argument.
   Function-typed arguments are not implemented yet:
| ArgFunId(_, funtype) -> type_of_fun_type funtype
*)
let type_of_fun_arg = function
  | ArgId (_, dtype) -> type_of_ast_dtype dtype

(* Type named by a function's declared return type.
   Function-valued returns are not implemented yet:
  ReturnFun(f) -> type_of_fun_type f
*)
let type_of_fun_return = function
  | ReturnData d -> type_of_ast_dtype d

(* Declared return type of a function signature. *)
let return_type_of_fun_type = function
  | FunType (_, declared_ret) -> type_of_fun_return declared_ret

(* Full checker signature of a function declaration: parameter types and
   return type. *)
let type_of_fun_type = function
  | FunType (params, declared_ret) ->
    FunT (List.map type_of_fun_arg params, type_of_fun_return declared_ret);;

(* Signature of a function referenced in a function-valued position (for
   example a matrix constructor). Only references by name are supported;
   in-line anonymous functions are not implemented yet:
| FunItem(fun_type, _) -> type_of_fun_type fun_type;;
*)
let type_of_func_item (scope: symbol_table) = function
  | FunId fun_name -> type_of_fun_name scope fun_name;;

(* True exactly for the integer scalar type. *)
let is_type_int t = (t = ScalarT IntT);;

(* A function-style Matrix constructor needs a function of exactly two int
   parameters. *)
let matrix_fun_args_ok = function
  | [row_t; col_t] -> is_type_int row_t && is_type_int col_t
  | _ -> false;;

(* Result of the constructor function used by a function-style Matrix
   definition. The function must take exactly two ints. Raises
   Invalid_argument otherwise, or if [f] is not a function. *)
let matrix_construct_type (scope: symbol_table) f =
  let fail prefix = raise (Invalid_argument (prefix ^ To_string.string_of_func_item f)) in
  match type_of_func_item scope f with
  | FunT (arg_types, fun_ret_type) ->
    if matrix_fun_args_ok arg_types then fun_ret_type
    else fail "Non-scalar matrix constructor function arguments: "
  | _ -> fail "Bad matrix constructor function type: ";;

(* Result type of unary operator [op] on an operand of type [t]:
   - scalars and matrices keep their type;
   - tuples are checked via their element type;
   - anything else is rejected.
   [expr] is the operand and is used only in the error message. *)
let rec type_of_unop t op expr =
  match t with
  | ScalarT _ | MatrixT _ -> t
  | TupT elem -> TupT (type_of_unop elem op expr)
  | _ ->
    let op_name = To_string.string_of_unop op in
    let type_name = string_of_twister_type t in
    let error_details =
      "Unary operator " ^ op_name ^ " cannot be applied to type " ^ type_name
      ^ " but expression " ^ To_string.string_of_expr expr ^ " has this type."
    in
    raise (Invalid_argument error_details);;

(* Result type of binary operator [op] applied to scalar operands of types
   [t1] and [t2]:
   - boolean connectives (And, Or, Xor) are rejected on scalars;
   - both operands must have the same scalar type. There is no implicit
     int/float promotion; conversions are explicit via the tofl/toint
     built-ins. This also applies to comparisons, which previously accepted
     mismatched operands such as an int compared with a float;
   - comparisons (Eq, Neq, Geq, Leq, Lt, Gt) yield BoolT, and every other
     operator yields the operand type.
   Raises Invalid_argument on any violation. *)
let scalar_binop_type op t1 t2 =
  if List.mem op [And; Or; Xor] then
    begin
      let error_details = "Cannot use boolean operator " ^ To_string.string_of_binop op ^ " on scalars." in
      raise (Invalid_argument error_details)
    end
  else if t1 <> t2 then
    begin
      let error_details = "Cannot use operator " ^ To_string.string_of_binop op ^ " on mismatched types " ^
        string_of_twister_type t1 ^ " and " ^ string_of_twister_type t2 in
      raise (Invalid_argument error_details)
    end
  else if List.mem op [Eq; Neq; Geq; Leq; Lt; Gt] then BoolT
  else t1

(* Reject the binary operation [e1 op e2] (operand types [t1], [t2]) with an
   Invalid_argument describing the operator, operands and types. Never
   returns. *)
let complain_binop t1 t2 op e1 e2 =
  let op_part = "Cannot perform operation " ^ To_string.string_of_binop op ^ " on values " in
  let values_part = To_string.string_of_expr e1 ^ " and " ^ To_string.string_of_expr e2 in
  let types_part = " of types " ^ string_of_twister_type t1 ^ " and " ^ string_of_twister_type t2 in
  raise (Invalid_argument (op_part ^ values_part ^ types_part));;

(* Result type of [op] applied to two booleans: BoolT for equality and the
   boolean connectives; Invalid_argument for any other operator. *)
let bool_binop_type op =
  match op with
  | Eq | Neq | And | Or | Xor -> BoolT
  | _ ->
    raise (Invalid_argument
      ("Cannot use non-boolean operator " ^ To_string.string_of_binop op ^ " on booleans."));;

(* Map elementwise operations to their regular counterparts *)
(* Not implemented 
let elop_map op =
  match op with
  ElAdd -> Add
| ElSub -> Sub
| ElMul -> Mul
| ElDiv -> Div
| ElMod -> Mod
| _ -> raise (Invalid_argument ("Operator " ^ To_string.string_of_binop op ^ " is not elementwise"));;
*)

(* Result type of the binary operation [e1 op e2], where [t1] and [t2] are the
   operand types; [e1] and [e2] only feed error messages. Supported operand
   pairs are two scalars (scalar_binop_type) and two booleans
   (bool_binop_type). Every other combination is rejected via complain_binop.

   Elementwise operators and Matrix / Tup / string operands are not
   implemented yet. The draft that handled them is kept here for reference
   (it is why the function is declared [rec]):

  let elop = List.mem op (ElAdd :: ElSub :: ElMul :: ElDiv :: ElMod :: []) in
  let op = (if elop then
    begin
      match t1, t2 with
      | MatrixT(_), MatrixT(_) -> elop_map op
      | _, _ -> complain_binop t1 t2 op e1 e2
    end
    else op)
  in

| MatrixT(s1), MatrixT(s2) -> if not (elop || (op = Mul)) then complain_binop t1 t2 op e1 e2
    else
    if s1 = s2 then MatrixT(type_of_binop s1 s2 op e1 e2) else
    (match s1, s2 with
      | ScalarT(_), ScalarT(_) -> MatrixT(scalar_binop_type op s1 s2)
      | TupT(s3), TupT(s4) -> MatrixT(TupT(type_of_binop s3 s4 op e1 e2))
      | _ -> complain_binop t1 t2 op e1 e2
| MatrixT(s1), TupT(s2) -> (match s1 with
  | TupT(_) -> MatrixT(TupT((type_of_binop s1 s2 op e1 e2)))
  | _ -> complain_binop t1 t2 op e1 e2
)
| TupT(_), MatrixT(_) -> type_of_binop t2 t1 op e1 e2
| MatrixT(s1), s2 -> MatrixT(type_of_binop s1 s2 op e1 e2)
| s1, MatrixT(s2) -> MatrixT(type_of_binop s1 s2 op e1 e2)
| TupT(s1), TupT(s2) -> TupT(type_of_binop s1 s2 op e1 e2)
| TupT(s1), s2 -> TupT(type_of_binop s1 s2 op e1 e2)
| s1, TupT(s2) -> TupT(type_of_binop s1 s2 op e1 e2)
| ListT(CharT), ListT(CharT) -> if op = Add then ListT(CharT) else complain_binop t1 t2 op e1 e2
*)
let rec type_of_binop t1 t2 op e1 e2 =
  match t1, t2 with
  | ScalarT _, ScalarT _ -> scalar_binop_type op t1 t2
  | BoolT, BoolT -> bool_binop_type op
  | _ -> complain_binop t1 t2 op e1 e2;;

(* Type of the attribute access [obj_name.attr_name]. Every valid attribute
   is an int:
   - a Matrix has "num_rows" and "num_cols";
   - a List or Tup has "len".
   Raises Invalid_argument if the object is undeclared, if its type has no
   attributes, or if the attribute is not one of those listed.
   NOTE(review): type_of_iterable accepts "rows"/"cols" for matrices rather
   than "num_rows"/"num_cols" — confirm which spelling is intended. *)
let type_of_attr scope obj_name attr_name =
  let fail reason =
    let error_details =
      "Cannot access attribute " ^ attr_name ^ " of object " ^ obj_name ^ " because " ^ reason in
    raise (Invalid_argument error_details)
  in
  if not (var_is_defined scope obj_name) then fail (obj_name ^ " is not declared")
  else
    let obj_type = type_of_var_name scope obj_name in
    let valid_attrs =
      match obj_type with
      | MatrixT _ -> ["num_rows"; "num_cols"]
      | ListT _ | TupT _ -> ["len"]
      | _ ->
        fail (obj_name ^ " has type " ^ string_of_twister_type obj_type ^ " which has no attributes.")
    in
    if List.mem attr_name valid_attrs then ScalarT IntT
    else fail (attr_name ^ " is not a valid attribute");;

(* The first [n] elements of [elts], or all of them if the list is shorter.
   A negative [n] never reaches zero and so yields the whole list. Uses an
   accumulator, so it is tail-recursive. *)
let first n elts =
  let rec take remaining acc rest =
    match rest with
    | [] -> List.rev acc
    | x :: tl -> if remaining = 0 then List.rev acc else take (remaining - 1) (x :: acc) tl
  in
  take n [] elts

(* [elts] without its first [n] elements ([] if the list is shorter). A
   negative [n] never reaches zero and so drops everything. *)
let rec drop_first n elts =
  if n = 0 then elts
  else
    match elts with
    | [] -> []
    | _ :: rest -> drop_first (n - 1) rest

(* Result type of a partial application that leaves the parameters
   [args_tail] unapplied: a function over them, or [ret_type] itself when
   nothing is left. *)
let partial_call_type args_tail ret_type =
  match args_tail with
  | _ :: _ -> FunT (args_tail, ret_type)
  | [] -> ret_type

(* Element type of a List, Matrix or Tup type; Invalid_argument for any other
   type. *)
let elem_type complex_type =
  match complex_type with
  | ListT etype | MatrixT etype | TupT etype -> etype
  | _ ->
    raise (Invalid_argument
      ("Invalid complex type for element type unpacking: " ^ string_of_twister_type complex_type));;

(* Reject a collection literal [expr] whose entries have differing [types].
   Never returns. *)
let type_list_mismatch_err types expr =
  let rendered = List.map string_of_twister_type types in
  let types_str = "[" ^ To_string.string_join "," rendered ^ "]" in
  raise (Invalid_argument
    ("Inconsistent element types " ^ types_str ^ " in: " ^ To_string.string_of_expr expr));;

(* Check that the index expression [index_lit] has type int. Returns true on
   success and raises Invalid_argument otherwise (it never returns false;
   callers use it only for the check). *)
let rec check_index_lit (scope: symbol_table) index_lit =
  let t = type_of_expr scope index_lit in
    if t = ScalarT(IntT) then true else
    begin
      let t_str = string_of_twister_type t in
      let type_desc = "Expression " ^ (To_string.string_of_expr index_lit) ^ " has type " ^ t_str in
      let error_details = type_desc ^ " and cannot be used as an index." in
      raise (Invalid_argument error_details)
    end

(* Return true if [index] is a slice (MatSlice) and false if it is a single
   index (MatIndex), after checking that every bound has type int. Because
   check_index_lit returns true or raises, both slice bounds are always checked
   despite the short-circuiting [&&]. *)
and is_slice (scope: symbol_table) index = 
  match index with
  MatIndex(e) -> let _ = check_index_lit scope e in false
| MatSlice(e_1, e_2) -> let _ = (check_index_lit scope e_1) && (check_index_lit scope e_2) in true

(* Type of [id] accessed with one index: the element type for a single index,
   or [id]'s own type for a slice. Only List and Tup variables can be accessed
   this way. The index is validated before [id] is looked up. *)
and type_of_vec_acc (scope: symbol_table) id index =
  let slice = is_slice scope index in
  let id_type = type_of_var_name scope id in
  match id_type with
| ListT(etype) -> if slice then id_type else etype
| TupT(etype) -> if slice then id_type else etype
| _ -> let act_str = "Variable " ^ id ^ " is of type " ^ (string_of_twister_type id_type) in
      let error_details = act_str ^ " and cannot be accessed with one index." in
      raise (Invalid_argument error_details)

(* Type of [id] accessed with a row index and a column index: the element
   type when both are single indices, or the Matrix type itself when either is
   a slice. Only Matrix variables can be accessed this way. *)
and type_of_mat_acc (scope: symbol_table) id index_1 index_2 =
  let row_slice = is_slice scope index_1 in
  let col_slice = is_slice scope index_2 in
  let id_type = type_of_var_name scope id in
  match id_type with
  MatrixT(etype) -> if row_slice || col_slice then id_type else etype
| _ -> let act_str = "Variable " ^ id ^ " is of type " ^ (string_of_twister_type id_type) in
      let error_details = act_str ^ ", not of type Matrix<element_type>, and cannot be indexed into." in
      raise (Invalid_argument error_details)

(* Result type of the built-in reduce(f, xs), given its argument expressions.
   Exactly two arguments are required. [xs] must be a List, Matrix or Tup, and
   [f] is checked as if called with ONE argument of the element type; the
   result is f's return type.
   NOTE(review): because f is checked against a single element-typed argument
   (and currying is not implemented), a two-argument combining function would
   be rejected here — confirm the intended arity of reduce's function. *)
and reduce_call_type (scope: symbol_table) args = 
  match args with
  [] -> raise (Invalid_argument "Built-in function reduce called without arguments.")
| _ :: [] -> raise (Invalid_argument "Built-in function reduce called with only one argument.")
| red_fun :: red_arg :: [] -> let arg_type = type_of_expr scope red_arg in
  let red_fun_type = type_of_expr scope red_fun in
  (* Name used in error messages: the identifier if f is a plain name. *)
  let red_fun_name = 
    (match red_fun with
      Id(name) -> name
    | _ -> To_string.string_of_expr red_fun
    ) in 
    (match arg_type with
    ListT(e_type) -> check_fun_call_type red_fun_name red_fun_type (e_type :: [])
  | MatrixT(e_type) -> check_fun_call_type red_fun_name red_fun_type (e_type :: [])
  | TupT(e_type) -> check_fun_call_type red_fun_name red_fun_type (e_type :: [])
  | _ -> raise (Invalid_argument ("Reduce cannot be applied to type " ^ string_of_twister_type arg_type))
  )
| _ -> raise (Invalid_argument "Built-in function reduce called with more than two arguments.")

(* Result type of the built-in map(f, xs). Exactly two arguments are
   required. [f] is checked as if called with one element of [xs]. The result
   is a container of the same kind (List, Matrix or Tup) whose element type is
   f's return type. *)
and map_call_type (scope: symbol_table) args = 
  match args with
  [] -> raise (Invalid_argument "Built-in function map called without arguments.")
| _ :: [] -> raise (Invalid_argument "Built-in function map called with only one argument.")
| map_fun :: map_arg :: [] -> let arg_type = type_of_expr scope map_arg in
  let map_fun_type = type_of_expr scope map_fun in
  (* Name used in error messages: the identifier if f is a plain name. *)
  let map_fun_name = 
    (match map_fun with
      Id(name) -> name
    | _ -> To_string.string_of_expr map_fun
    ) in 
    (match arg_type with
    ListT(e_type) -> ListT(check_fun_call_type map_fun_name map_fun_type (e_type :: []))
  | MatrixT(e_type) -> MatrixT(check_fun_call_type map_fun_name map_fun_type (e_type :: []))
  | TupT(e_type) -> TupT(check_fun_call_type map_fun_name map_fun_type (e_type :: []))
  | _ -> raise (Invalid_argument ("Map cannot be applied to type " ^ string_of_twister_type arg_type))
  )
| _ -> raise (Invalid_argument "Built-in function map called with more than two arguments.")
      
(* Result type of the call fun_name(args). The built-ins reduce and map are
   typed specially. Any other name must be bound to a function type in scope.
   The argument types are computed before the function name is looked up. *)
and call_type (scope: symbol_table) fun_name args = 
  match fun_name with
  "reduce" -> reduce_call_type scope args
| "map" -> map_call_type scope args
| _ -> let arg_types = List.map (type_of_expr scope) args in
  let ftype = type_of_fun_name scope fun_name in
  check_fun_call_type fun_name ftype arg_types

(* Return type of applying [fun_name], of type [ftype], to arguments of types
   [arg_types]. The argument types must equal the parameter types exactly:
   there is no partial application and no MatchT relaxation for empty literals.
   NOTE(review): declarations and assignments accept an empty literal through
   match_types, but passing one as an argument is rejected here — confirm this
   is intended. *)
and check_fun_call_type fun_name ftype arg_types =
  match ftype with
  FunT(arg_t_list, ret_t) -> (* Currying not implemented
    let arg_count = List.length arg_types in
    let types_head = first arg_count arg_t_list in
    let types_tail = drop_first arg_count arg_t_list in
    if types_head = arg_types then partial_call_type types_tail ret_t else
    *)
    if arg_t_list = arg_types then ret_t else
    begin
      let expected_type = To_string.string_join "," (List.map string_of_twister_type arg_t_list) in
      let received_type = To_string.string_join "," (List.map string_of_twister_type arg_types) in
      let error_details = " Invalid argument types (" ^ received_type ^ ") for " ^ fun_name ^ 
        " which expects (" ^ expected_type ^ ")" in
      raise (Invalid_argument error_details)
    end
| _ -> let error_details = "Unexpected function type " ^ string_of_twister_type ftype ^ " for " ^ fun_name in
  raise (Invalid_argument error_details)
(* Type of the value an expression evaluates to in [scope]. Raises
   Invalid_argument for ill-typed expressions. Entries of collection literals
   must all have the same type; empty literals get element type MatchT. *)
and type_of_expr (scope: symbol_table) expr =
  match expr with
  NumLit(s) -> ScalarT(type_of_ast_scalar s)
| CharLit(_) -> CharT
| StringLit(_) -> ListT(CharT)
| BoolLit(_) -> BoolT
| TupLit(e_list) -> let types = List.map (type_of_expr scope) e_list in
  if all_equal types then type_of_tuple types else type_list_mismatch_err types expr
| ListLit(e_list) -> let types = List.map (type_of_expr scope) e_list in
  if all_equal types then type_of_list types else type_list_mismatch_err types expr
| FunLit(fun_type, _) -> type_of_fun_type fun_type
| MatrixLit(entries) -> let types = List.map (type_of_expr scope) (List.flatten entries) in
  if all_equal types then type_of_matrix types else type_list_mismatch_err types expr
(* Matrix of the given dimensions; its element type is always int here.
   NOTE: OCaml does not specify tuple evaluation order, so if both dimension
   expressions are ill-typed, which error surfaces first is compiler-dependent. *)
| MatrixInit(num_rows, num_cols) -> let (t1, t2) = (type_of_expr scope num_rows, type_of_expr scope num_cols) in
  if is_type_int t1 && is_type_int t2
    then MatrixT(ScalarT(IntT)) else
      begin
        let error_details = "Non-integer matrix dimensions " ^ To_string.string_of_expr num_rows ^
          " and " ^ To_string.string_of_expr num_cols ^ " of types " ^ string_of_twister_type t1 ^
          " and " ^ string_of_twister_type t2
        in raise (Invalid_argument error_details)
      end
(* Not implemented
| MatrixFunDef(num_rows, num_cols, func_item) -> let (t1, t2) = (type_of_expr scope num_rows, type_of_expr scope num_cols) in
  if is_type_int t1 && is_type_int t2
    then matrix_construct_type scope func_item else
      begin
        let error_details = "Non-integer matrix dimensions " ^ To_string.string_of_expr num_rows ^
          " and " ^ To_string.string_of_expr num_cols ^ " of types " ^ string_of_twister_type t1 ^
          " and " ^ string_of_twister_type t2
        in raise (Invalid_argument error_details)
      end
*)
| BinOp(expr_1, expr_2, binop) -> let t1 = type_of_expr scope expr_1 in
  let t2 = type_of_expr scope expr_2 in type_of_binop t1 t2 binop expr_1 expr_2
| UnOp(expr, unop) -> let expr_type = type_of_expr scope expr in type_of_unop expr_type unop expr
| Id(name) -> type_of_var_name scope name
| Attribute(obj_name, attr_name) -> type_of_attr scope obj_name attr_name
| Call(fun_name, args) -> call_type scope fun_name args
(* Not implemented
| Pipe(expr_1, expr_2) -> let _ = type_of_expr scope expr_1 in type_of_expr scope expr_2
*)
| MatAcc(id, index_1, index_2) -> type_of_mat_acc scope id index_1 index_2
| VecAcc(id, index) -> type_of_vec_acc scope id index;;

(* Type given to the loop variable of a for-loop over [it]:
   - for a List-, Matrix- or Tup-valued iterable, its element type;
   - for a matrix "rows" or "cols" attribute, int.
   Raises Invalid_argument otherwise.
   NOTE(review): type_of_attr spells the matrix attributes "num_rows" and
   "num_cols" — confirm which spelling loops are meant to use. *)
let type_of_iterable (scope: symbol_table) it =
  match it with
  | ItId name -> elem_type (type_of_var_name scope name)
  | ItListLit e_list -> elem_type (type_of_expr scope (ListLit e_list))
  | ItMatrixLit e_list -> elem_type (type_of_expr scope (MatrixLit e_list))
  (* Not implemented
  | ItMatrixFunDef(width, height, func_item) -> let mtype = type_of_expr scope (MatrixFunDef(width, height, func_item)) in
      elem_type mtype
  *)
  | ItCall (fun_name, args) -> elem_type (type_of_expr scope (Call (fun_name, args)))
  | ItAttribute (id, attr) ->
    (match type_of_var_name scope id with
     | MatrixT _ ->
       if attr = "rows" || attr = "cols" then ScalarT IntT
       else raise (Invalid_argument ("Attempt to access invalid iterable attribute " ^ attr))
     | other ->
       raise (Invalid_argument
         ("Invalid attempt to access iterable " ^ attr ^ " on type " ^ string_of_twister_type other)));;

(* Bind [d_name] to [dtype] at the innermost level of [scope]. This replaces
   any binding of that name at the same level; enclosing levels are left
   unchanged. *)
let declare (scope: symbol_table) dtype d_name =
  { scope with vars = StringMap.add d_name dtype scope.vars };;

(* Whether a value of type [type_2] may be stored where [type_1] is expected.
   This holds when the types are equal, or when [type_2] is the type of an
   empty Tup/List/Matrix literal (element MatchT) of the same container kind.
   Only the outermost element type is relaxed, and the relation is not
   symmetric. *)
let match_types type_1 type_2 =
  type_1 = type_2
  || (match type_1, type_2 with
      | TupT _, TupT MatchT
      | ListT _, ListT MatchT
      | MatrixT _, MatrixT MatchT -> true
      | _ -> false);;

(* Check an assignment of a new value to an already-declared name. The value
   is type-checked first; the name must then be declared and its type must
   accept the value's type (match_types). Returns [env] unchanged; raises
   Invalid_argument otherwise. *)
let assign_val (env: translation_environment) assign_rec =
  let name = assign_rec.assign_name in
  let value = assign_rec.new_val in
  let val_type = type_of_expr env.scope value in
  let reject suffix =
    let val_str = To_string.string_of_expr value in
    let error_details = "Value: ''\n" ^ val_str ^ "\n''\n of type " ^ string_of_twister_type val_type ^
      " cannot be assigned to name " ^ name ^ suffix in
    raise (Invalid_argument error_details)
  in
  if not (var_is_defined env.scope name) then reject " which has no declared type."
  else
    let target_type = type_of_var_name env.scope name in
    if match_types target_type val_type then env
    else reject (" of type " ^ string_of_twister_type target_type);;

(* Check an assignment to one element (or slice) of an already-declared List
   or Tup. The value is type-checked first. The target's type comes from
   type_of_vec_acc and must accept the value's type (match_types). Returns
   [env] unchanged; raises Invalid_argument otherwise. *)
let vec_assign_val (env: translation_environment) assign_rec =
  let name = assign_rec.vec_name in
  let value = assign_rec.vec_el_val in
  let index = assign_rec.index in
  let val_type = type_of_expr env.scope value in
  let reject suffix =
    let val_str = To_string.string_of_expr value in
    let error_details = "Value: ''\n" ^ val_str ^ "\n''\n of type " ^ string_of_twister_type val_type ^
      " cannot be assigned to name " ^ name ^ suffix in
    raise (Invalid_argument error_details)
  in
  if not (var_is_defined env.scope name) then reject " which has no declared type."
  else
    let target_type = type_of_vec_acc env.scope name index in
    if match_types target_type val_type then env
    else reject (" of type " ^ string_of_twister_type target_type);;

(* Check an assignment to one element (or slice) of an already-declared
   Matrix. The value is type-checked first. The target's type comes from
   type_of_mat_acc and must accept the value's type (match_types). Returns
   [env] unchanged; raises Invalid_argument otherwise. *)
let mat_assign_val (env: translation_environment) assign_rec =
  let name = assign_rec.mat_name in
  let value = assign_rec.mat_el_val in
  let row_index = assign_rec.row_index in
  let col_index = assign_rec.col_index in
  let val_type = type_of_expr env.scope value in
  let reject suffix =
    let val_str = To_string.string_of_expr value in
    let error_details = "Value: ''\n" ^ val_str ^ "\n''\n of type " ^ string_of_twister_type val_type ^
      " cannot be assigned to name " ^ name ^ suffix in
    raise (Invalid_argument error_details)
  in
  if not (var_is_defined env.scope name) then reject " which has no declared type."
  else
    let target_type = type_of_mat_acc env.scope name row_index col_index in
    if match_types target_type val_type then env
    else reject (" of type " ^ string_of_twister_type target_type);;

(* Reject a declaration that reuses [name], reporting the type it already has
   in [scope]. Never returns.
   TODO: Should we allow this at different stack levels?? *)
let name_reuse_err scope name =
  let existing = string_of_twister_type (type_of_var_name scope name) in
  raise (Invalid_argument ("Variable name " ^ name ^ " already used with type " ^ existing));;

(* Reject a default value [expr] of type [etype] given for the argument [id]
   of type [dtype] in a function declaration. Never returns. *)
let default_type_err id expr etype dtype =
  let message =
    "Cannot default expression ''\n" ^ To_string.string_of_expr expr
    ^ "\n''\n of type " ^ string_of_twister_type etype
    ^ " for variable " ^ id ^ " of type " ^ string_of_twister_type dtype
  in
  raise (Invalid_argument message);;

(* Add one function argument to the innermost level of [scope]. If the name is
   already declared at that level, raises via name_reuse_err. *)
let declare_an_arg (scope: symbol_table) arg =
  match arg with
  | ArgId (name, dtype) ->
    let this_level = {parent = None; vars = scope.vars} in
    if var_is_defined this_level name then name_reuse_err this_level name
    else declare scope (type_of_ast_dtype dtype) name;;
(*
| ArgFunId(name, ftype) -> if var_is_defined local_scope name then name_reuse_err local_scope name else
  let argtype = type_of_fun_type ftype in declare scope argtype name;;
*)

(* Declare each argument into [scope], left to right, via declare_an_arg; a
   repeated argument name is therefore rejected. *)
let declare_args (scope: symbol_table) args =
  List.fold_left declare_an_arg scope args;;

(* Scope for checking a function body: a new child level of [scope] that holds
   the function's arguments. *)
let declare_fun_args (scope: symbol_table) fun_type =
  match fun_type with
  | FunType (args, _) ->
    let arg_level : symbol_table = {parent = Some scope; vars = StringMap.empty} in
    declare_args arg_level args;;

(* Re-raise [err] as Invalid_argument. When [on] is set (verbose debugging),
   the message is prefixed with the names declared at the innermost level of
   [scope]. Never returns. *)
let wrap_err on (scope: symbol_table) err =
  let message =
    if not on then err
    else
      let names = List.map fst (StringMap.bindings scope.vars) in
      "SCOPE: [" ^ To_string.string_join "," names ^ "]" ^ "; ERR: " ^ err
  in
  raise (Invalid_argument message);;

(* Check one statement against [env] and return the environment that holds
   after it. Only a declaration changes the environment: it adds the new name
   to the innermost scope level. Every other statement returns [env] unchanged
   once it checks. Names declared inside if/else bodies and loop bodies do not
   escape them. Raises Invalid_argument on any semantic error. *)
let rec add_statement (env: translation_environment) stmt =
  match stmt with
  | Decl d ->
    let dtype = decl_type env.scope d in
    let new_scope = declare env.scope dtype d.var_name in
    {scope = new_scope; loc_ret_type = env.loc_ret_type}
  | Assign v -> assign_val env v
  | VecAssign v -> vec_assign_val env v
  | MatAssign v -> mat_assign_val env v
  | Return ret_expr ->
    let rtype = type_of_expr env.scope ret_expr in
    (* Same compatibility rule as declarations and assignments, so an empty
       literal (element type MatchT) may be returned where, for example, a
       List<int> is expected. Plain equality rejected that case. *)
    if match_types env.loc_ret_type rtype then env
    else
      begin
        let rtype_str = string_of_twister_type rtype in
        let loc_type_str = string_of_twister_type env.loc_ret_type in
        let stmt_str = To_string.string_of_stmt stmt in
        let error_details = "Statement ''\n" ^ stmt_str ^ "\n''\n would return invalid type " ^ rtype_str ^
          " where return values should be of type " ^ loc_type_str in
        raise (Invalid_argument error_details)
      end
  | If (condition, if_block, else_block) ->
    let cond_type = type_of_expr env.scope condition in
    if cond_type = BoolT then
      begin
        (* Both bodies are checked at the current scope level (not a child
           level), and their declarations are discarded afterwards.
           check_statements returns true or raises, so the else branch below
           is purely defensive. *)
        if check_statements env if_block && check_statements env else_block then env
        else
          begin
            let error_details = "Invalid statements in ''\n" ^ To_string.string_of_stmt stmt ^ "\n''\n" in
            raise (Invalid_argument error_details)
          end
      end
    else
      begin
        let cond_str = To_string.string_of_expr condition in
        let error_details = "Condition ''\n" ^ cond_str ^ "\n''\n is not a valid Boolean expression." in
        raise (Invalid_argument error_details)
      end
  | For (loop_var, iterable, stmts) ->
    (* The loop variable may not reuse a name declared at the current level;
       names from enclosing levels may be shadowed. The body is checked in a
       new child level that holds only the loop variable at first. *)
    let local_scope = {parent = None; vars = env.scope.vars} in
    if var_is_defined local_scope loop_var then name_reuse_err local_scope loop_var
    else
      let it_type = type_of_iterable env.scope iterable in
      let new_scope = {parent = Some env.scope; vars = StringMap.empty} in
      let loop_scope = declare new_scope it_type loop_var in
      let new_env = {scope = loop_scope; loc_ret_type = env.loc_ret_type} in
      if check_statements new_env stmts then env
      else
        begin
          let stmt_str = To_string.string_of_stmt stmt in
          raise (Invalid_argument ("Invalid syntax found in ''\n" ^ stmt_str ^ "\n''\n"))
        end

(* Check one statement, prefixing any error with the statement's source text.
   An error inside a nested statement therefore gets one prefix per enclosing
   statement. *)
and process_statement env stmt =
  try add_statement env stmt
  with Invalid_argument err ->
    let err_new = "Error: \"\n" ^ err ^ "\n\" @ statement: " ^ To_string.string_of_stmt stmt in
    raise (Invalid_argument err_new)

(* Check a statement list in order, passing the environment from each
   statement on to the next. Returns true when every statement checks;
   otherwise raises Invalid_argument for the first failing statement. Only
   process_statement is inside the handler, so the recursion stays
   tail-recursive and each error passes through wrap_err once. *)
and check_statements (env: translation_environment) stmts =
  match stmts with
  | [] -> true
  | h :: t ->
    (match process_statement env h with
     | next_env -> check_statements next_env t
     | exception Invalid_argument err -> wrap_err false env.scope err)

(* Type of the function declaration [d] (declared type Fun):
   - the name must not be visible at any level of [scope];
   - the body must be a function literal;
   - the literal's statements must check in a child level holding its
     arguments, with its declared return type as the expected return type.
   NOTE(review): the function's own name is not in scope while its body is
   checked, so a call to itself is reported as a call to an undeclared
   function — confirm whether recursion is meant to be supported. *)
and fun_decl_type scope d =
  if var_is_defined scope d.var_name then
    begin
      let type_str = string_of_twister_type (type_of_var_name scope d.var_name) in
      let error_details = "Function name " ^ d.var_name ^ " is already declared with type " ^ type_str in
      raise (Invalid_argument error_details)
    end
  else
    begin
      match d.body with
      | FunLit (ftype, stmts) ->
        let fun_ret_type = return_type_of_fun_type ftype in
        let fun_full_type = type_of_fun_type ftype in
        let new_scope = declare_fun_args scope ftype in
        let new_env: translation_environment = {scope = new_scope; loc_ret_type = fun_ret_type} in
        (* check_statements returns true or raises; the else is defensive. *)
        if check_statements new_env stmts then fun_full_type
        else
          begin
            let fun_str = "Unexpected error in " ^ d.var_name in
            let body_str = " function body ''\n" ^ To_string.string_of_expr d.body ^ "\n''\n" in
            let error_details = fun_str ^ body_str in
            raise (Invalid_argument error_details)
          end
      | _ ->
        let body_str = To_string.string_of_expr d.body in
        let body_type_str = string_of_twister_type (type_of_expr scope d.body) in
        let error_details = "Expression ''\n" ^ body_str ^ "\n''\n of type " ^ body_type_str ^
          " is not a valid function declaration body." in
        raise (Invalid_argument error_details)
    end

(* Type to record for the declaration [d] in [scope]. The name must not
   already be declared at the current level; outer levels may be shadowed.
   Function declarations are delegated to fun_decl_type. For any other
   declaration, the body's type must be compatible (match_types) with the
   declared type. The DECLARED type is what gets recorded: recording the
   body's type would give a variable initialised with an empty literal the
   element type MatchT, so later assignments of non-empty values to it would
   be rejected. *)
and decl_type (scope: symbol_table) d =
  let local_scope = {parent = None; vars = scope.vars} in
  if var_is_defined local_scope d.var_name then
    begin
      let type_str = string_of_twister_type (type_of_var_name local_scope d.var_name) in
      let error_details = "Name " ^ d.var_name ^ " is already declared with type " ^ type_str in
      raise (Invalid_argument error_details)
    end
  else if d.return_type = Fun then fun_decl_type scope d
  else
    begin
      let val_type = type_of_expr scope d.body in
      let declared_type = type_of_ast_dtype d.return_type in
      if not (match_types declared_type val_type) then
        begin
          let val_str = To_string.string_of_expr d.body in
          let val_type_str = string_of_twister_type val_type in
          let type_str = To_string.string_of_datatype d.return_type in
          let error_details = "Type mismatch: Could not assign expression ''\n" ^ val_str ^
            "\n''\n of type " ^ val_type_str ^ " to name " ^ d.var_name ^ " of type " ^ type_str in
          raise (Invalid_argument error_details)
        end
      else declared_type
    end;;

(* The initial (global) scope: a single level holding the built-in functions
   and their signatures. Despite its name this is a value, computed once, not
   a function. *)
let get_init_scope =
  let builtins = [
    ("tofl", FunT ([ScalarT IntT], ScalarT FloatT));
    ("toint", FunT ([ScalarT FloatT], ScalarT IntT));
    ("print", FunT ([ListT CharT], ScalarT IntT));
    ("print_c", FunT ([CharT], ScalarT IntT));
    ("println", FunT ([ListT CharT], ScalarT IntT));
    ("print_int", FunT ([ScalarT IntT], ScalarT IntT));
    ("println_int", FunT ([ScalarT IntT], ScalarT IntT));
    ("print_fl", FunT ([ScalarT FloatT], ScalarT IntT));
    ("fread", FunT ([ListT CharT], BoolT));
    ("fwrite", FunT ([ListT CharT], BoolT));
    ("range", FunT ([ScalarT IntT; ScalarT IntT], ListT (ScalarT IntT)));
  ] in
  let empty : symbol_table = {parent = None; vars = StringMap.empty} in
  List.fold_left (fun scope (name, signature) -> declare scope signature name) empty builtins


(* Check a whole program (a statement list) in a fresh global environment that
   contains only the built-ins. Top-level return statements are checked
   against Bool. Returns true, or raises Invalid_argument describing the first
   failing statement. *)
let valid_program p =
  let global_env : translation_environment = {scope = get_init_scope; loc_ret_type = BoolT} in
  check_statements global_env p;;
