open Ast
open Utility
open Class
open Environment
open Lists


(* State threaded through the semantic check alongside the variable
   environment. *)
type sast_environment = {
    sast_typed  : literal NameMap.t;   (* function name -> literal of the first return statement checked for it; later returns must be compatible *)
    sast_untyped: literal NameMap.t;   (* functions checked with no recorded return; check_fdecl maps them to IntLiteral(0) *)
    call_list   : string list;         (* names of the functions currently being checked, innermost (current) first *)
  }

(* Compare two literals to see whether they can be used interchangeably
   (assignment, return values, list items).  A [false] result is reported by
   the caller as a semantic error.

   - int/bool/string/double literals are compatible only with the same kind;
   - class literals are compatible when [get_class_name] gives the same name;
   - lists are compatible when either side is empty, otherwise only their
     first elements are compared;
   - key/value pairs compare their values (keys are ignored);
   - an [LhsLiteral] wrapper on either side is unwrapped before comparing. *)
let rec types_compatible lit1 lit2 =
	match lit1, lit2 with
	| IntLiteral _, IntLiteral _
	| BoolLiteral _, BoolLiteral _
	| StringLiteral _, StringLiteral _
	| DoubleLiteral _, DoubleLiteral _ -> true
	| ClassLiteral _, ClassLiteral _ ->
		get_class_name lit1 = get_class_name lit2
	(* structural emptiness test instead of [List.length l == 0] *)
	| ListLiteral [], ListLiteral _
	| ListLiteral _, ListLiteral [] -> true
	| ListLiteral (first1 :: _), ListLiteral (first2 :: _) ->
		types_compatible first1 first2
	| KVPLiteral (_, value1), KVPLiteral (_, value2) ->
		types_compatible value1 value2
	| LhsLiteral (_, inner), other -> types_compatible inner other
	| other, LhsLiteral (_, inner) -> types_compatible other inner
	| _, _ -> false

(* Typed functions must end in a return.
   [statment_end_in_return stmt] holds when [stmt] is a [Return], a non-empty
   [Block] whose last statement holds, or an [If] whose two branches both hold.
   Any other statement, including an empty block or a loop, yields false. *)
let rec statment_end_in_return stmt =
	match stmt with
	| Return _ -> true
	| If (_, then_branch, else_branch) ->
		statment_end_in_return then_branch && statment_end_in_return else_branch
	| Block stmts ->
		(match List.rev stmts with
		 | last :: _ -> statment_end_in_return last
		 | [] -> false)
	| _ -> false


(* Evaluate a call's actual arguments with [eval], threading the
   (env, sast_env) pair through each evaluation.  Arguments are evaluated
   from the last one to the first, but the returned value list is in source
   order.  Returns (values, env, sast_env). *)
let eval_actual_list (actuals, env, sast_env, eval) =
	let rec eval_from_last acc env sast_env = function
		| [] -> acc, env, sast_env
		| actual :: earlier ->
			let value, (env, sast_env) = eval (env, sast_env) actual in
			eval_from_last (value :: acc) env sast_env earlier
	in
	eval_from_last [] env sast_env (List.rev actuals)

(* Check a call to [fdecl] and report the literal the call evaluates to,
   ensuring function return types match up.
   While [check] walks the body, [fdecl.fname] is pushed onto [call_list];
   the caller's [call_list] is restored afterwards and the returned
   environment has its locals cleared.
   - A function with a recorded return literal ([sast_typed]) must end in a
     return statement, otherwise [SemanticFailure] is raised.
   - A function already known to be untyped yields its stored literal.
   - Otherwise the function is recorded as untyped with [IntLiteral 0]. *)
let check_fdecl (check, fdecl, actuals, env, sast_env) =
	let caller_calls = sast_env.call_list in
	let pushed = { sast_env with call_list = fdecl.fname :: caller_calls } in
	let env, checked = check fdecl actuals env pushed in
	let sast_env = { checked with call_list = caller_calls } in
	let env = { env with env_locals = NameMap.empty } in
	let name = fdecl.fname in
	if NameMap.mem name sast_env.sast_typed then begin
		let return_lit = NameMap.find name sast_env.sast_typed in
		(* only the final statement of the body has to end in a return *)
		match List.rev fdecl.body with
		| last :: _ when not (statment_end_in_return last) ->
			raise (SemanticFailure ("Function contains return statements, but does not end in one: " ^ name))
		| _ -> return_lit, (env, sast_env)
	end
	else if NameMap.mem name sast_env.sast_untyped then
		NameMap.find name sast_env.sast_untyped, (env, sast_env)
	else
		let untyped = NameMap.add name (IntLiteral 0) sast_env.sast_untyped in
		IntLiteral 0, (env, { sast_env with sast_untyped = untyped })

(* Check a call to the class method [func_name] on the object [cls_lit].
   The method body is checked with empty locals and with the object's
   variable name and member map as the call context.  [call_fdecl] performs
   the body check (callers in this file pass [check_fdecl]); previously the
   parameter was ignored and [check_fdecl] was hard-coded.
   [update_env_class_call] (defined elsewhere) then combines the caller's
   [env] with the post-call environment — presumably copying the object's
   updated members back; confirm against Environment.
   Returns the method's literal and the updated (env, sast_env) pair. *)
let cls_check_func (check, env, func_decls, func_name, actuals, cls_lit, call_fdecl, sast_env) =
	let fdecl = find_fdecl (func_name, func_decls) in
	let cls_context = (get_lhs_varname cls_lit, get_class_value_map cls_lit) in
	let call_env = { env with env_locals = NameMap.empty; env_context = cls_context } in
	let call_lit, (post_call_env, sast_env) =
		call_fdecl (check, fdecl, actuals, call_env, sast_env) in
	let err_msg = "Unknown context: " ^ fst call_env.env_context in
	let return_env = update_env_class_call (env, post_call_env, err_msg) in
	call_lit, (return_env, sast_env)


(*
	Semantic Check:
	all return statements in a function return the same type
	assignments are of the correct type
	all list items are of the same type
*)

(* Run the semantic check over a whole program: [vars] are the global
   declarations, [funcs] and [classes] the function and class declarations.
   Starting from [main], function bodies are walked like an interpreter,
   threading the variable environment and a [sast_environment] that records
   each function's return literal.  [main] itself is on [call_list], so its
   return statements are checked like any other function's.
   Raises [Failure "did not find the main() function"] when [main] is
   missing; semantic errors surface as [SemanticFailure] or as exceptions
   from the evaluation helpers. *)
let semantic_check (vars, funcs, classes) =
	let func_decls = map_funcs (funcs, classes) in
	(* Check [fdecl] called with already-evaluated [actuals]; returns the
	   updated (environment, sast_environment) pair. *)
	let rec check fdecl actuals call_env sast_env =
		(* Evaluate an expression and return (value, (env, sast_env)) *)
		let rec eval (env, sast_env) = function
			Literal(i) -> i, (env, sast_env)
			| Noexpr -> IntLiteral(1), (env, sast_env)
			| Id(var) ->
				let ret_lit, env = id_variable_lookup (var, env, "ID: undeclared identifier ") in
				ret_lit, (env, sast_env)
			| Cast(v_type, e1) ->
				let v1, (env, sast_env) = eval (env, sast_env) e1 in
				(cast_of_literal (v_type, v1)), (env, sast_env)
			| Negate(e1) ->
				let v1, (env, sast_env) = eval (env, sast_env) e1 in
				(negate_of_literal v1), (env, sast_env)
			| Binop(e1, op, e2) ->
				let v1, (env, sast_env) = eval (env, sast_env) e1 in
				let v2, (env, sast_env) = eval (env, sast_env) e2 in
				let bool_compare_exception = (fun _ _ -> raise (Failure("Boolean greater/less than operators not supported"))) in
				(match op with
					Add ->
						if (is_list_literal v1) && (is_list_literal v2) then
							list_concat (v1, v2)
						else
							do_literal_operation (+) (+.) (^) (v1,v2)
					| Sub -> do_literal_operation (-) (-.) (fun _ _ -> raise (Failure("String subtract not supported"))) (v1,v2)
					| Mult -> do_literal_operation ( * ) ( *. ) (fun _ _ -> raise (Failure("String multiply not supported"))) (v1,v2)
					| Div -> do_literal_operation ( / ) ( /. ) (fun _ _ -> raise (Failure("String divide not supported"))) (v1,v2)
					(* structural (=)/(<>) rather than physical (==)/(!=) *)
					| Equal -> do_literal_compare (=) (=) (=) (=) (v1,v2)
					| Neq -> do_literal_compare (<>) (<>) (<>) (<>) (v1,v2)
					| Less -> do_literal_compare (<) (<) (<) (bool_compare_exception) (v1,v2)
					| Leq -> do_literal_compare (<=) (<=) (<=) (bool_compare_exception) (v1,v2)
					| Greater -> do_literal_compare (>) (>) (>) (bool_compare_exception) (v1,v2)
					| Geq -> do_literal_compare (>=) (>=) (>=) (bool_compare_exception) (v1,v2)), (env, sast_env)
			| ListItems(items) ->
				(* An item with key "_" is a plain value; otherwise it is a key/value pair. *)
				let check_list_item (env, sast_env) item =
					let v1, (env, sast_env) = eval (env, sast_env) item.lvalue in
					if item.lkey = "_" then
						v1, (env, sast_env)
					else
						KVPLiteral(item.lkey, v1), (env, sast_env)
				in
				(match items with
				 | [] -> ListLiteral([]), (env, sast_env)
				 | first :: rest ->
					let first_lit, (env, sast_env) = check_list_item (env, sast_env) first in
					(* Every remaining item is checked too and must be compatible
					   with the first; the list's type is represented by the first. *)
					let (env, sast_env) = List.fold_left
						(fun (env, sast_env) item ->
							let item_lit, (env, sast_env) = check_list_item (env, sast_env) item in
							if types_compatible first_lit item_lit then
								(env, sast_env)
							else
								raise (SemanticFailure("All list items must be of the same type")))
						(env, sast_env) rest
					in
					ListLiteral([first_lit]), (env, sast_env))
			| Assign(e_var, e_val) ->
				let v_var, (env, sast_env) = eval (env, sast_env) e_var in
				let v_val, (env, sast_env) = eval (env, sast_env) e_val in
				if is_lhs_literal v_var then
					let var_name = (get_lhs_varname v_var) in
					if types_compatible v_var v_val then
						let ret_lit, env = assign_variable_lookup (var_name, v_val, env, "Assign: undeclared identifier ") in
						ret_lit, (env, sast_env)
					else
						raise(SemanticFailure("Cannot assign variable to an incompatible type: " ^ var_name))
				else
					raise(SemanticFailure("Cannot assign to the expression: " ^ (string_of_expr e_var)))
			| Call("print", [e]) ->
				let _, (env, sast_env) = eval (env, sast_env) e in
				IntLiteral(0), (env, sast_env)
			| Call("cos", [e]) ->
				let v, (env, sast_env) = eval (env, sast_env) e in
				DoubleLiteral(cos (double_of_literal v)), (env, sast_env)
			| Call("sin", [e]) ->
				let v, (env, sast_env) = eval (env, sast_env) e in
				DoubleLiteral(sin (double_of_literal v)), (env, sast_env)
			| Call("sqrt", [e]) ->
				let v, (env, sast_env) = eval (env, sast_env) e in
				DoubleLiteral(sqrt (double_of_literal v)), (env, sast_env)
			| ClassCall (e, cls_func, actuals) ->
				let cls_lit, (env, sast_env) = eval (env, sast_env) e in
				let actuals, env, sast_env = eval_actual_list (actuals, env, sast_env, eval) in
				if is_list_literal cls_lit then
					let ret_lit, env = check_list_func (cls_lit, cls_func, actuals, env) in
					ret_lit, (env, sast_env)
				else
					let class_name = (get_class_name cls_lit) in
					let func_name = class_name ^ "." ^ cls_func in
					cls_check_func (check, env, func_decls, func_name, actuals, cls_lit, check_fdecl, sast_env)
			| Access (e, member) ->
				let cls_lit, (env, sast_env) = eval (env, sast_env) e in
				if is_lhs_literal cls_lit then
					id_cls_variable_lookup ((get_lhs_varname cls_lit), member, env, "undeclared access identifier: "), (env, sast_env)
				else
					raise(SemanticFailure("Cannot use access operator on expression: " ^ (string_of_expr e)))
			| Call(func, actuals) ->
				(* Inside a class context, prefer the context class's method of that name. *)
				let get_context_cls_func =
					if NameMap.cardinal (snd env.env_context) = 0 then
						""
					else
						let cls_name = (get_class_name (ClassLiteral(snd env.env_context))) in
						let cls_func_name = cls_name ^ "." ^ func in
						if NameMap.mem cls_func_name func_decls then
							cls_func_name
						else
							"" in
				let cls_func_name = get_context_cls_func in
				if String.length cls_func_name > 0 then
					let cls_lit = LhsLiteral((fst env.env_context), ClassLiteral(snd env.env_context)) in
					let actuals, env, sast_env = eval_actual_list (actuals, env, sast_env, eval) in
					cls_check_func (check, env, func_decls, cls_func_name, actuals, cls_lit, check_fdecl, sast_env)
				else begin
					let actuals, env, sast_env = eval_actual_list (actuals, env, sast_env, eval) in
					(* Guard only the lookup of [func]: a FunctionNotFoundException raised
					   while checking the callee's body must propagate unchanged instead of
					   triggering the class-constructor fallback for [func]. *)
					let lookup =
						try `Found (find_fdecl (func, func_decls))
						with FunctionNotFoundException s -> `Missing s
					in
					match lookup with
					| `Found fdecl ->
						let call_env = { env with env_locals = NameMap.empty; env_context = get_def_env_context } in
						let call_lit, (post_call_env, sast_env) = check_fdecl (check, fdecl, actuals, call_env, sast_env) in
						call_lit, ({env with env_globals = post_call_env.env_globals}, sast_env)
					| `Missing s ->
						(try
							let ret_lit, env = try_func_as_class_constructor (actuals, env, classes, func, init_var) in
							ret_lit, (env, sast_env)
						with Failure _ ->
							raise (FunctionNotFoundException(s)))
				end
		in
		(* Execute a statement and return the updated (env, sast_env) pair *)
		let rec exec (env, sast_env) = function
			Block(stmts) -> List.fold_left exec (env, sast_env) stmts
			| Expr(e) -> let _, (env, sast_env) = eval (env, sast_env) e in (env, sast_env)
			| If(e, s1, s2) ->
				(* both branches are checked, one after the other *)
				let _, (env, sast_env) = eval (env, sast_env) e in
				let (env, sast_env) = exec (env, sast_env) s1 in
				exec (env, sast_env) s2
			| While(e, s) ->
				(* Single pass: the body is checked once, and only when the
				   condition evaluates to true.
				   NOTE(review): a body behind an initially-false condition is
				   never checked — confirm this is intended. *)
				let v, (env, sast_env) = eval (env, sast_env) e in
				if bool_of_literal v then
					exec (env, sast_env) s
				else
					(env, sast_env)
			| For(e1, e2, e3, s) ->
				(* Single pass, same caveat as While. *)
				let _, (env, sast_env) = eval (env, sast_env) e1 in
				let v, (env, sast_env) = eval (env, sast_env) e2 in
				if bool_of_literal v then
					let (env, sast_env) = exec (env, sast_env) s in
					let _, (env, sast_env) = eval (env, sast_env) e3 in
					(env, sast_env)
				else
					(env, sast_env)
			| Return(e) ->
				let current_func = List.hd sast_env.call_list in
				let matches = List.find_all (fun name -> current_func = name) sast_env.call_list in
				(* Once the current function appears three or more times on
				   call_list its return expressions are no longer evaluated, which
				   cuts off recursion that happens inside a return expression. *)
				if List.length matches < 3 then
					let v, (env, sast_env) = eval (env, sast_env) e in
					if NameMap.mem current_func sast_env.sast_typed then
						if types_compatible v (NameMap.find current_func sast_env.sast_typed) then
							(env, sast_env)
						else
							raise (SemanticFailure("Incompatible return types found for function: " ^current_func))
					else
						let sast_env = {sast_env with sast_typed = NameMap.add current_func v sast_env.sast_typed } in
						(env, sast_env)
				else
					(env, sast_env)
		in

		(* Enter the function: bind actual values to formal arguments *)
		let locals =
			try List.fold_left2
				(fun locals formal actual -> NameMap.add formal.vname actual locals)
				NameMap.empty fdecl.formals actuals
			with Invalid_argument(_) -> raise (Failure ("wrong number of arguments passed to " ^ fdecl.fname))
		in
		(* Initialize declared locals with init_var *)
		let locals = List.fold_left
			(fun accum local -> NameMap.add local.vname (init_var classes local.vtype) accum)
			locals fdecl.locals
		in
		let locals = NameMap.fold
			(fun key value accum -> NameMap.add key value accum)
			call_env.env_locals locals
		in
		(* Execute each statement in sequence, return the updated environments *)
		(List.fold_left exec ({ call_env with env_locals = locals }, sast_env) fdecl.body)

	(* Check the program: initialize globals with init_var, find and check "main" *)
	in let globals = List.fold_left
		(fun globals vdecl -> NameMap.add vdecl.vname (init_var classes vdecl.vtype) globals)
		NameMap.empty vars
	in
	(* Only the lookup of main is guarded: a Not_found raised while checking
	   the program must not be misreported as a missing main(). *)
	let main_fdecl =
		try NameMap.find "main" func_decls
		with Not_found -> raise (Failure ("did not find the main() function"))
	in
	check main_fdecl []
		{ env_globals = globals;
		  env_locals = NameMap.empty;
		  env_context = get_def_env_context }
		{ sast_typed = NameMap.empty;
		  sast_untyped = NameMap.empty;
		  (* main is the current function, so a return in main has a
		     function to check against (List.hd would fail on []) *)
		  call_list = ["main"] }
