open Jast

(* Global environment that sticks around for type checking of all modules.
   It has no parent scope, the Global context and no variables, and is
   pre-populated with the signatures of the built-in functions.  Overloads
   share a name and differ only in their parameter types.  Built-ins have
   anonymous formals and an empty body; presumably their implementations
   come from the code generator or a runtime library - TODO confirm. *)
let genv =
	(* [builtin ret name params] declares built-in [name] : params -> ret *)
	let builtin ret name params =
		{
			rettype = ret;
			fname = name;
			formals =
				List.map (fun t -> { vtype = t; vname = ""; vscope = Formal }) params;
			locals = [];
			body = [];
		}
	in
	ref
		{
			Sym.parent = None;
			Sym.context = Global;
			Sym.scope = [];
			Sym.funcs =
				[
					builtin Types.Int "print" [ Types.Int ];
					builtin Types.Int "print" [ Types.Mol ];
					builtin Types.Int "print" [ Types.String ];
					builtin Types.Int "print" [ Types.Boolean ];
					builtin Types.Int "getWeight" [ Types.Atom ];
					builtin (Types.Tuple Types.Atom) "getAtoms" [ Types.Mol ];
					builtin (Types.Tuple Types.Atom) "getAtoms" [ Types.Atom ];
					builtin (Types.Tuple Types.Bond) "getBonds" [ Types.Atom ];
					builtin Types.Int "getIdx" [ Types.Atom ];
					builtin Types.Atom "getNbr" [ Types.Bond; Types.Atom ];
					builtin Types.Boolean "setMarked" [ Types.Atom ];
					builtin Types.Boolean "setMarked" [ Types.Bond ];
					builtin Types.Boolean "isMarked" [ Types.Atom ];
					builtin Types.Boolean "isMarked" [ Types.Bond ];
					(* newBond and joinFrags are also the targets that check_expr
					   desugars the chemical "-" and "+" operators into *)
					builtin Types.Mol "newBond" [ Types.Atom; Types.Atom ];
					builtin Types.Mol "newBond" [ Types.Mol; Types.Atom ];
					builtin (Types.Tuple Types.Mol) "joinFrags"
						[ Types.Tuple Types.Mol; Types.Tuple Types.Mol ];
				];
		}

(* Type-check the AST expression [expr] in environment [env].  Returns the
   corresponding Jast expression paired with its MDL type.

   Operators on chemical types are desugared into calls to built-ins:
     - [atom - atom] and [mol - atom] become [newBond];
     - [+] on any mix of [mol] and [mol] tuples becomes [joinFrags], with a
       lone molecule wrapped in a one-element list literal.
   Arithmetic ([-], [*], [/]) and ordering comparisons require int
   operands; [=] and [!=] require operands of the same type; [+] also
   accepts int + int and string concatenation with any other operand.

   Raises Types.MDLException for undeclared identifiers, unsupported operand
   types, type-incompatible assignments and empty or heterogeneous list
   literals.  Whatever Sym.find_env_func raises when no overload matches
   propagates unchanged. *)
let rec check_expr env expr =
	(* resolve a variable, turning a lookup miss into a user-facing error *)
	let lookup id =
		try Sym.find_env_sym env id
		with Not_found ->
			raise (Types.MDLException ("Undeclared identifier: " ^ id))
	in
	(* build a call to [fname] from arguments that are already checked;
	   reusing typed operands avoids re-checking subtrees, which would be
	   exponential in the nesting depth of desugared operators *)
	let call fname typed_args =
		let fdecl = Sym.find_env_func env fname (List.map snd typed_args) in
		(Call (fdecl, typed_args), fdecl.rettype)
	in
	(* wrap a checked expression in a one-element list literal *)
	let singleton te = (List [ te ], Types.Tuple (snd te)) in
	let unsupported typ1 op typ2 =
		raise (Types.MDLException ("Unsupported types for operator: " ^
					(Types.string_of_type typ1) ^ " " ^ (Types.string_of_operator op) ^
					" " ^ (Types.string_of_type typ2)))
	in
	match expr with
	(* primary expressions *)
	| Ast.Id id ->
			let vdecl = lookup id in
			(Id vdecl, vdecl.vtype)
	| Ast.IntLiteral v ->
			(IntLiteral v, Types.Int)
	| Ast.ChemLiteral s ->
			(* a one-character literal is an atom, anything longer a molecule.
			   NOTE(review): two-letter element symbols such as "Cl" are
			   therefore typed Mol - confirm this is intended *)
			if String.length s = 1
			then (ChemLiteral s, Types.Atom)
			else (ChemLiteral s, Types.Mol)
	| Ast.StringLiteral s ->
			(StringLiteral s, Types.String)
	| Ast.BooleanLiteral b ->
			(BooleanLiteral b, Types.Boolean)
	| Ast.Call (fname, exprs) ->
			call fname (List.map (check_expr env) exprs)

	(* Unary operators: only boolean operands are accepted *)
	| Ast.Unop (op, e) ->
			let typed_e = check_expr env e in
			let expr_type = snd typed_e in
			if expr_type = Types.Boolean
			then
				(Unop (op, typed_e), Types.Boolean)
			else
				raise (Types.MDLException ("Invalid expression type for unary operator: " ^
							(Types.string_of_type expr_type)))

	(* Binary operators *)
	| Ast.Binop (e1, op, e2) ->
			let te1 = check_expr env e1
			and te2 = check_expr env e2 in
			let typ1 = snd te1
			and typ2 = snd te2 in
			let int_op result_type =
				if typ1 = Types.Int && typ2 = Types.Int
				then (Binop (te1, op, te2), result_type)
				else unsupported typ1 op typ2
			in
			(
				match op with
				| Types.Sub ->
						(match (typ1, typ2) with
						| ((Types.Mol | Types.Atom), Types.Atom) ->
								call "newBond" [ te1; te2 ]
						| _ -> int_op Types.Int)
				| Types.Add ->
						(match (typ1, typ2) with
						| (Types.Int, Types.Int) ->
								(Binop (te1, op, te2), Types.Int)
						| (Types.String, _) | (_, Types.String) ->
								(Binop (te1, op, te2), Types.String)
						| (Types.Mol, Types.Mol) ->
								call "joinFrags" [ singleton te1; singleton te2 ]
						| (Types.Tuple Types.Mol, Types.Mol) ->
								call "joinFrags" [ te1; singleton te2 ]
						| (Types.Mol, Types.Tuple Types.Mol) ->
								call "joinFrags" [ singleton te1; te2 ]
						| (Types.Tuple Types.Mol, Types.Tuple Types.Mol) ->
								call "joinFrags" [ te1; te2 ]
						| _ ->
								raise (Types.MDLException ("Unsupported types for + operator: " ^
											(Types.string_of_type typ1) ^ " & " ^
											(Types.string_of_type typ2))))
				| Types.Mult | Types.Div ->
						int_op Types.Int
				| Types.Less | Types.Leq | Types.Greater | Types.Geq ->
						int_op Types.Boolean
				| Types.Equal | Types.Neq ->
						if typ1 = typ2
						then (Binop (te1, op, te2), Types.Boolean)
						else unsupported typ1 op typ2
				| _ -> unsupported typ1 op typ2
			)
	| Ast.Covers (e1, e2) ->
			let te1 = check_expr env e1
			and te2 = check_expr env e2 in
			let typ1 = snd te1
			and typ2 = snd te2 in
			if (typ1 = Types.Tuple(Types.Mol) && typ2 = Types.Mol)
			then
				(Covers (te1, te2), Types.Boolean)
			else
				raise (Types.MDLException ("Unsupported types for covers operator: " ^
							(Types.string_of_type typ1) ^ " & " ^ (Types.string_of_type typ2)))

	(* Assignment operator: the value must have the variable's exact type *)
	| Ast.Assign (id, e) ->
			let vdecl = lookup id
			and typed_e = check_expr env e in
			let etype = snd typed_e in
			if vdecl.vtype = etype
			then
				(Assign (vdecl, typed_e), etype)
			else
				raise (Types.MDLException ("Incompatible types for variable " ^ id ^
							": expected " ^ (Types.string_of_type vdecl.vtype) ^
							", got " ^ (Types.string_of_type etype)))

	(* List literal: all elements must share one type, which fixes the
	   tuple type; an empty literal has no element type to infer *)
	| Ast.List l ->
			(match List.map (check_expr env) l with
			| [] ->
					raise (Types.MDLException
							"Cannot infer the element type of an empty list")
			| ((_, etype) :: rest) as typed ->
					List.iter
						(fun (_, t) ->
							if t <> etype then
								raise (Types.MDLException ("Inconsistent element types in list: " ^
											(Types.string_of_type etype) ^ " & " ^
											(Types.string_of_type t))))
						rest;
					(List typed, Types.Tuple etype))

(* Type-check the AST statement [stmt] in environment [env], returning the
   corresponding Jast statement.
   - while/if conditions must be boolean;
   - in a for loop the declared variable type [t] must equal
     Types.type_of_tuple of the iterated expression's type (presumably its
     element type); the body is checked in a new ForLoop scope that holds
     the loop variable.
   Raises Types.MDLException for a non-boolean condition, naming the
   offending type, and Sym.Type_mismatch for a mismatched loop variable.
   NOTE(review): return statements are not checked against the enclosing
   function's return type here. *)
let rec check_stmt env stmt =
	(* check a while/if condition; [kind] names the statement in errors *)
	let check_cond kind e =
		let checked_e = check_expr env e in
		let cond_type = snd checked_e in
		if cond_type = Types.Boolean
		then checked_e
		else
			raise (Types.MDLException (kind ^ " condition is not of boolean type: " ^
						(Types.string_of_type cond_type)))
	in
	match stmt with
	| Ast.Block stmts ->
			Block (List.map (check_stmt env) stmts)
	| Ast.Expr e ->
			Expr (check_expr env e)
	| Ast.Return e ->
			Return (check_expr env e)
	| Ast.While (e, s) ->
			let checked_e = check_cond "While" e in
			let checked_s = check_stmt env s in
			While (checked_e, checked_s)
	| Ast.IfThenElse (e, s1, s2) ->
			let checked_e = check_cond "If" e in
			let checked_s1 = check_stmt env s1
			and checked_s2 = check_stmt env s2 in
			IfThenElse (checked_e, checked_s1, checked_s2)
	| Ast.IfThen (e, s1) ->
			let checked_e = check_cond "If" e in
			let checked_s1 = check_stmt env s1 in
			IfThen (checked_e, checked_s1)
	| Ast.For (t, id, e, s) ->
			let checked_e = check_expr env e in
			let e_type = Types.type_of_tuple (snd checked_e) in
			if (t <> e_type)
			then
				raise (Sym.Type_mismatch (t, e_type))
			else
				let senv = Sym.newscope ForLoop env in
				let senv = Sym.add_sym senv (t, id) in
				let var = { vtype = t; vname = id; vscope = senv.Sym.context } in
				let checked_s = check_stmt senv s in
				For (var, checked_e, checked_s)

(* Check the function definition [fdef] and return [env] extended (via
   Sym.add_func) with its checked Jast form.
   The body is checked in a new Method scope containing the formals and then
   the locals.  A body-less copy of the function is added to that scope
   first, so the body can call the function recursively.  Formals are
   recorded with vscope Formal, locals with vscope Method. *)
let check_and_add_fdef env fdef =
	let fenv =
		List.fold_left Sym.add_sym (Sym.newscope Method env)
			(fdef.Ast.formals @ fdef.Ast.locals)
	in
	let to_var scope (t, n) = { vtype = t; vname = n; vscope = scope } in
	let myformals = List.map (to_var Formal) fdef.Ast.formals
	and mylocals = List.map (to_var Method) fdef.Ast.locals in
	let signature =
		{
			rettype = fdef.Ast.rettype;
			fname = fdef.Ast.fname;
			formals = myformals;
			locals = mylocals;
			body = [];
		}
	in
	(* the body sees its own signature, enabling recursion *)
	let body_env = Sym.add_func fenv signature in
	Sym.add_func env
		{ signature with body = List.map (check_stmt body_env) fdef.Ast.body }

(* Entry point of this module: type-check a whole program, given as its
   global variable declarations [vars] and AST function definitions
   [funcs], and build the resulting Jast class.
   Globals are added on top of the environment held in [genv].  Functions
   are then checked one by one in the order of [List.rev funcs], each in the
   environment returned by the previous step.  The class receives the
   globals' scope and the final environment's function list, which
   presumably also includes the built-in declarations from [genv].
   NOTE(review): the List.rev presumably restores source order from a
   parser that accumulates definitions in reverse - confirm. *)
let check_prog ((vars : (Types.mdl_type * string) list), funcs) =
	let env = List.fold_left Sym.add_sym !genv vars in
	let checked_env = List.fold_left check_and_add_fdef env (List.rev funcs) in
	JClass (env.Sym.scope, checked_env.Sym.funcs)
