open Ast
open Symbol

(* Raised for every semantic violation; the human-readable reason is
   stored in [error_msg] first (see [raise_error]). *)
exception SemanticError

(* Description of the most recent semantic error. *)
let error_msg = ref ""

(* 1 when the most recently evaluated expression is a vector, 0 otherwise. *)
let v_flag = ref 0

(* 1 while the statements of a loop body are being emitted, 0 otherwise;
   consulted to validate [break] / [continue]. *)
let loop_flag = ref 0
(* Data type of the expression most recently evaluated (NULL before any). *)
let cur_type = ref NULL

(* Records [msg] in [error_msg] and aborts analysis with SemanticError.
   Never returns, so it can be used in any expression position. *)
let raise_error (msg : string) =
  error_msg := msg;
  raise SemanticError

(* Accessors for the triples returned by [symbol_detail]; the components are
   presumably (type, value, size) — confirm against Symbol.  A size > 0
   denotes a vector of that many elements. *)
let get_type (t, _, _) = t
let get_vector (_, _, s) = s > 0
let get_size (_, _, s) = s

(* Scratch element count written by eval_type and read/updated when
   emitting vector assignments. *)
let size = ref 0

(* Maps a source data type to the C type name used in generated code.
   Any other type (for example StrType) raises SemanticError with
   "Invalid data type". *)
let dtype_eval t =
  match t with
  | Int -> "int"
  | Int8 -> "unsigned char"
  | Float -> "float"
  | _ -> raise_error "Invalid data type"

(* Type-checks the binary operation [v1 op v2] and returns its C text.

   Both operands are looked up with [symbol_detail]:
   - same type, both scalars: the infix form "(v1 op v2)";
   - same type, at least one vector: a call
     "f((T*)alloca(n*sizeof(T)), v1, v2, n)".  In this case [v_flag] is set
     to 1 and the global [size] to n.  n is the vector operand's size, or the
     smaller size when both are vectors;
   - different types: raises SemanticError ("data type mismatch"). *)
let eval_type v1 v2 f op =
  let a = symbol_detail v1 in
  let b = symbol_detail v2 in
  (* Size -1 marks a plain scalar, -2 a literal constant (see const_size). *)
  let non_vector s = s = -1 || s = -2 in
  if get_type a <> get_type b then raise_error "data type mismatch"
  else if not (get_vector a || get_vector b) then
    "(" ^ v1 ^ " " ^ op ^ " " ^ v2 ^ ")"
  else begin
    v_flag := 1;
    let sa = get_size a in
    let sb = get_size b in
    size :=
      (if non_vector sa then sb
       else if non_vector sb then sa
       else min sa sb);
    let ctype = dtype_eval (get_type a) in
    let n = string_of_int !size in
    f ^ "((" ^ ctype ^ "*)alloca(" ^ n ^ "*sizeof(" ^ ctype ^ ")), "
    ^ v1 ^ ", " ^ v2 ^ ", " ^ n ^ ")"
  end

(* Emits "(v1 op v2)" for a comparison whose operands are both scalars of
   the same type.  Vectors or differing types raise SemanticError
   ("Invalid operation or data type mismatch"). *)
let check_scalar v1 v2 op =
  let a = symbol_detail v1 in
  let b = symbol_detail v2 in
  (* Dropping the type-equality test would allow implicit conversion or
     promotion between operand types. *)
  let both_scalar = not (get_vector a || get_vector b) in
  if both_scalar && get_type a = get_type b then
    "(" ^ v1 ^ " " ^ op ^ " " ^ v2 ^ ")"
  else raise_error "Invalid operation or data type mismatch"

(* Checks that variable [v1] can receive the value of the expression just
   evaluated, using [v_flag] and [cur_type]:
   - a scalar LHS (size -1) cannot receive a vector RHS;
   - the types must match, except that an Int8 LHS accepts an Int RHS.
   Raises SemanticError on violation. *)
let check_lhs v1 =
  let a = symbol_detail v1 in
  let lhs_type = get_type a in
  let int_to_int8 = lhs_type = Int8 && !cur_type = Int in
  if get_size a = -1 && !v_flag = 1 then
    raise_error "Mismatch in LHS and RHS type"
  else if lhs_type <> !cur_type && not int_to_int8 then
    raise_error "Mismatch in LHS and RHS data type"

(* Records the shape and type of variable [v1] as the current expression.
   [v_flag] becomes 1 for a vector, 0 otherwise, and [cur_type] becomes the
   variable's type. *)
let check_var v1 =
  let a = symbol_detail v1 in
  v_flag := (if get_vector a then 1 else 0);
  cur_type := get_type a

(* Renders float literals (already in textual form) as a C/C++ brace
   initializer, e.g. ["1.0"; "2.5"] -> "{1.0, 2.5}".
   NOTE(review): an empty list yields "" rather than "{}". *)
let print_float_array x =
  match x with
  | [] -> ""
  | items -> "{" ^ String.concat ", " items ^ "}"
		
(* Renders integers as a C/C++ brace initializer, e.g. [1; 2] -> "{1, 2}".
   NOTE(review): an empty list yields "" rather than "{}". *)
let print_int_array x =
  match x with
  | [] -> ""
  | items -> "{" ^ String.concat ", " (List.map string_of_int items) ^ "}"
(* 1 once expr_eval has evaluated an array literal; reset to 0 at the start
   of each assignment. *)
let array = ref 0

(* Size recorded in the symbol table for literal constants. *)
let const_size = -2

(* Function to walk expressions and output relevant code.

   [expr_eval e] returns the C text for [e].  As side effects it records
   whether the value is a vector ([v_flag] = 1) or a scalar (0), and its
   element type in [cur_type].  Numeric literals are registered with
   [insert_symbol_const] under their own text with size [const_size]
   (presumably so later [symbol_detail] lookups on that text succeed).
   Array literals also set [array] to 1.
   Raises SemanticError (via [raise_error]) on misuse. *)
let rec expr_eval = function
  | ExprEmpty -> ""
  | Literal x ->
      let text = string_of_int x in
      insert_symbol_const Int text const_size;
      v_flag := 0;
      cur_type := Int;
      text
  | FloatL x ->
      insert_symbol_const Float x const_size;
      v_flag := 0;
      cur_type := Float;
      x
  | String x ->
      v_flag := 0;
      cur_type := StrType;
      "\"" ^ x ^ "\""
  | Id x ->
      lookup_symbol x;
      check_var x;
      x
  | IntArray xs ->
      array := 1;
      cur_type := Int;
      v_flag := 1;
      print_int_array xs
  | FloatArray xs ->
      array := 1;
      cur_type := Float;
      v_flag := 1;
      print_float_array xs
  | Uop (id, op1) ->
      lookup_symbol id;
      let a = symbol_detail id in
      if get_vector a then
        raise_error "Cannot use unary operators on vectors";
      (* The result is a scalar of id's type.  Record that, so callers such
         as assignment, return and argument checks do not see the stale
         v_flag / cur_type left by an earlier expression. *)
      v_flag := 0;
      cur_type := get_type a;
      if op1 = PlusP then id ^ "++" else id ^ "--"
  | Binop (x, op, y) ->
      (* Each call overwrites v_flag / cur_type, and the evaluation order of
         [let ... and ...] is unspecified.  Sequence left then right
         explicitly. *)
      let v1 = expr_eval x in
      let v2 = expr_eval y in
      (match op with
       | Plus -> eval_type v1 v2 "add" "+"
       | Minus -> eval_type v1 v2 "sub" "-"
       | Mul -> eval_type v1 v2 "mul" "*"
       | Div -> eval_type v1 v2 "div" "/"
       | Or -> eval_type v1 v2 "bit_or" "|"
       | And -> eval_type v1 v2 "bit_and" "&"
       | Xor -> eval_type v1 v2 "bit_xor" "^"
       | Lt -> check_scalar v1 v2 "<"
       | Gt -> check_scalar v1 v2 ">"
       | Lteq -> check_scalar v1 v2 "<="
       | Gteq -> check_scalar v1 v2 ">="
       | Eq -> check_scalar v1 v2 "=="
       | NotEq -> check_scalar v1 v2 "!=")

	
(* Evaluates [expr] only for its effect on [v_flag] / [cur_type] and returns
   a (vector marker, type) pair: 0 for a vector, -1 for a scalar.  The
   generated C text is discarded.  Used to check call arguments against a
   function's recorded signature. *)
let expr_to_tuple expr =
  ignore (expr_eval expr);
  let marker = if !v_flag = 1 then 0 else -1 in
  (marker, !cur_type)
	
	(* Function to evaluate statements.

	   [stmt_eval s] type-checks statement [s] and prints the corresponding C
	   code to stdout.  Expression state ([v_flag], [cur_type], [array],
	   [size]) is read right after sub-expressions are evaluated, so the
	   statement order below matters.  Raises SemanticError via
	   [raise_error] on violations. *)
	let rec stmt_eval =
	  (* Comma-separated C text for call arguments.  The tail is evaluated
	     before the head, so expression side effects run right to left. *)
	  let rec exprlist_eval = function
	    | [] -> ""
	    | [x] -> expr_eval x
	    | x :: rest ->
	        let tail = exprlist_eval rest in
	        let head = expr_eval x in
	        head ^ "," ^ tail
	  in
	  (* Runs [emit] with [loop_flag] set, then restores the enclosing value.
	     Previously the flag was reset to 0, so after a nested loop a
	     break/continue in the outer loop body was wrongly rejected. *)
	  let in_loop emit =
	    let outer = !loop_flag in
	    loop_flag := 1;
	    emit ();
	    loop_flag := outer
	  in
	  function
	  | Dummy_str x -> print_endline x
	  | Assign (x, y1, iflag) -> (
	      match y1 with
	      | Expr (y, _) ->
	          lookup_symbol x;
	          v_flag := 0;
	          array := 0;
	          cur_type := NULL;
	          let v1 = expr_eval y in
	          let rhs_is_vector = !v_flag <> 0 in
	          check_lhs x;
	          let a = symbol_detail x in
	          if not rhs_is_vector then begin
	            if get_size a > 0 then
	              (* scalar RHS copied into every element of a vector LHS *)
	              print_string ("copy(" ^ x ^ ", " ^ v1 ^ "," ^ string_of_int (get_size a) ^ ");\n")
	            else begin
	              print_string (x ^ " = " ^ v1);
	              (* iflag is presumably 0 inside for-loop headers, which print
	                 their own separators — confirm against the parser. *)
	              if iflag = 1 then print_string ";\n"
	            end
	          end
	          else if !array = 1 then begin
	            (* array literal: materialise it in a temporary, then copy *)
	            let elt = if !cur_type = Int then "int" else "float" in
	            print_string ("{" ^ elt ^ " tmp[] = " ^ v1 ^ ";" ^ "copy_vector(" ^ x ^ ", "
	                          ^ string_of_int (get_size a) ^ ",tmp);}\n")
	          end
	          else begin
	            (* NOTE(review): when v1 is not itself a vector symbol, [size]
	               keeps the value eval_type left behind for this expression. *)
	            let b = symbol_detail v1 in
	            if get_size b > 0 then size := get_size b;
	            size := (if !size < 1 then get_size a else min (get_size a) !size);
	            print_string ("copy(" ^ x ^ ", " ^ v1 ^ "," ^ string_of_int !size ^ ");\n")
	          end
	      | FunCall (_, _) ->
	          (* function results are only assignable to scalar Int variables *)
	          lookup_symbol x;
	          let a = symbol_detail x in
	          if get_type a = Int && not (get_vector a) then begin
	            print_string (x ^ " = ");
	            stmt_eval y1
	          end
	          else raise_error (x ^ " should be of type Int")
	      | _ -> raise_error "Invalid assignment")
	  | Return x ->
	      let v1 = expr_eval x in
	      if !cur_type = Int then print_string ("return " ^ v1 ^ ";")
	      else raise_error "Return type should be an integer"
	  | FunCall (fname, args) ->
	      if not (f_find_sym fname) then raise_error ("Cannot find function " ^ fname)
	      else begin
	        (* arguments are evaluated once for the signature check and again
	           below to produce their C text *)
	        let arg_types = List.map expr_to_tuple args in
	        fun_arg_type_eval fname (List.length args) arg_types;
	        let v1 = exprlist_eval args in
	        print_string (fname ^ "(" ^ v1 ^ ");\n")
	      end
	  | Expr (x, iflag) ->
	      print_string (expr_eval x);
	      if iflag = 1 then print_string ";\n"
	  | If (cond, then_stmts, else_stmts) ->
	      let v1 = expr_eval cond in
	      print_endline ("\nif (" ^ v1 ^ ") {");
	      List.iter stmt_eval then_stmts;
	      print_endline "}";
	      (match else_stmts with
	       | [] -> ()
	       | _ ->
	           print_string " else {";
	           List.iter stmt_eval else_stmts;
	           print_endline "}\n")
	  | For (init, cond, step, body) ->
	      in_loop (fun () ->
	          print_string "for (";
	          stmt_eval init;
	          print_string ";";
	          let e1 = expr_eval cond in
	          print_string (e1 ^ ";");
	          stmt_eval step;
	          print_string ")\n {";
	          List.iter stmt_eval body;
	          print_endline "}\n")
	  | While (cond, body) ->
	      in_loop (fun () ->
	          print_string "while (";
	          let e1 = expr_eval cond in
	          print_string (e1 ^ ") {\n");
	          List.iter stmt_eval body;
	          print_string "\n}\n")
	  | Print x ->
	      lookup_symbol x;
	      let x1 = symbol_detail x in
	      if get_vector x1 then
	        print_string ("print(" ^ x ^ "," ^ string_of_int (get_size x1) ^ ");\n")
	      else print_string ("print(" ^ x ^ ");\n")
	  | Print_const (x, y) ->
	      (* dtype_eval rejects unsupported constant types *)
	      ignore (dtype_eval y);
	      print_string ("print(" ^ x ^ ");\n")
	  | Print_string x -> print_string ("print_string (\"" ^ x ^ "\");")
	  | Continue ->
	      if !loop_flag = 1 then print_endline "continue;"
	      else raise_error "Misplaced continue statement. Not within a loop"
	  | Break ->
	      if !loop_flag = 1 then print_endline "break;"
	      else raise_error "Misplaced break statement. Not within a loop"
	  | Empty -> ()

(* Function to evaluate declarations.

   [decl_eval (Var_list (dtype, vars))] registers every declarator with
   [insert_symbol] and prints one C declaration line, e.g. "int a,b[4];".
   Raises SemanticError (via [dtype_eval]) for an unsupported [dtype]. *)
let decl_eval =
  (* Registers one declarator and returns its C text: "x" for a scalar
     (size -1) or "x[n]" otherwise. *)
  let vdecl_str dtype = function
    | Var_Decl (x, y) ->
        insert_symbol dtype x y;
        if y <> -1 then x ^ "[" ^ string_of_int y ^ "]" else x
  in
  (* Joins declarators with ','.  fold_right registers them last-to-first,
     the same order as the previous tail-first recursion.  If insert_symbol
     rejects duplicates (not visible here), it still fires on the same
     symbol. *)
  let vdecl_eval dtype vars =
    String.concat "," (List.fold_right (fun v acc -> vdecl_str dtype v :: acc) vars [])
  in
  function
  | Var_list (dtype, var_list) ->
      let dstr = dtype_eval dtype in
      let vstr = vdecl_eval dtype var_list in
      print_endline (dstr ^ " " ^ vstr ^ ";")

(* Not referenced elsewhere in the visible code. *)
let fun_size = 0

(* Third component of a (dtype, id, size) triple. *)
let find_size (_, _, size) = size

(* (vector marker, dtype) pairs for the formals of the function currently
   being emitted; 0 marks a vector, -1 a scalar. *)
let fun_type_list = ref []

(* Function to evaluate function arguments.

   For each formal, in list order, this:
   - prepends a (vector marker, dtype) pair to [fun_type_list], with 0 for
     size > 0 and -1 otherwise;
   - registers the formal with [insert_symbol].

   It then prints the C parameter list, or "void" when [args] is empty.
   NOTE(review): parameters are printed last-to-first.  Presumably the
   parser builds [arg_list] in reverse and this restores source order —
   confirm against the grammar. *)
let fun_eval_list_print args =
  let render = function
    | FunArg (dtype, id, size) ->
        let vbool = if size > 0 then 0 else -1 in
        fun_type_list := (vbool, dtype) :: !fun_type_list;
        insert_symbol dtype id size;
        let dstr = dtype_eval dtype in
        if size = -1 then dstr ^ " " ^ id ^ " "
        else dstr ^ " " ^ id ^ "[" ^ string_of_int size ^ "] "
  in
  match args with
  | [] -> print_string "void"
  | _ ->
      (* fold_left renders head-first; consing yields the reversed order *)
      let rendered_rev = List.fold_left (fun acc arg -> render arg :: acc) [] args in
      print_string (String.concat "," rendered_rev)


(* Emits one function definition.  It first resets the symbol table, the
   global-scope flag and [fun_type_list].  It then prints the header and
   records the signature with [f_insert_sym] (name, arity, argument types).
   Finally it emits the local declarations and statements. *)
let fun_put fundecl =
  clear_symtable ();
  global_scope := 0;
  fun_type_list := [];
  let name = fundecl.fun_name in
  let args = fundecl.arg_list in
  print_string ("int " ^ name ^ " ( ");
  fun_eval_list_print args;
  f_insert_sym name (List.length args) !fun_type_list;
  print_string ") {\n ";
  List.iter decl_eval fundecl.decl_list;
  List.iter stmt_eval fundecl.stmt_list;
  print_endline "\n}\n\n"

(* Fold step counting functions named "main": returns [count] + 1 when
   [fundecl] is named "main", else [count]. *)
let find_main count fundecl =
  count + (if fundecl.fun_name = "main" then 1 else 0)

(* Begin semantic analysis.

   Requires exactly one function named "main" in [program] and raises
   SemanticError otherwise.  It then prints the C prelude and the global
   declarations [gdecl] (emitted with [global_scope] = 1).  Finally it
   emits every function in [program]. *)
let eval (gdecl, program) =
  match List.fold_left find_main 0 program with
  | 0 -> raise_error "Unable to find main function"
  | 1 ->
      global_scope := 1;
      print_endline "#include <stub-print.h>";
      print_endline "#include <stub.h>\n";
      List.iter decl_eval gdecl;
      global_scope := 0;
      List.iter fun_put program
  | _ -> raise_error "Cannot have more than one main function"
