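(* Simplification pass:
   convert the checked AST into a flat intermediate representation.
   - simplify expressions into three-address form using temporaries
   - flatten nested blocks, hoisting declarations to the top of each function
   - replace loops and conditionals with labels and branches
*)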

open Ast
open Check

(* The simplified IR reuses the AST's representations for types, variable
   declarations and function declarations. Within this file a simple_var is
   built as a (name, type, int) triple, and a simple_fdecl is destructured as a
   4-tuple whose second component is used as the call's result type. *)
type simple_type = Ast.var_type
type simple_var = Ast.var_decl
type simple_fdecl = Ast.func_decl

(* Literal constants that can be loaded into a temporary via [Lit]. *)
type simple_lit =
	StrLit of string
	| NumLit of int

(* Three-address expressions. Operands are always variables (never nested
   expressions). The first component is the variable being written; for
   [Alias] it is the container whose element [b] is assigned. *)
type simple_expr =
	Bin of simple_var * Ast.bop * simple_var * simple_var (* a = b op c *)
	| Un of simple_var * Ast.uop * simple_var (* a = op b *)
	| Call of simple_var * simple_fdecl * simple_var list (* a = b(c..d..) *)
	| Lit of simple_var * simple_lit (* a = "b" *)
	| Deref of simple_var * simple_var * simple_var (* a = b[c] *)
	| Alias of simple_var * simple_var * simple_var (* a[b] = c *)

(* Flat statements. After simplification a function body is a linear list of
   these; structured control flow is expressed with labels and jumps. *)
type simple_stmt =
	| Decl of simple_var (* variable declaration; simplify_func moves all of these to the front of the body *)
	| If of simple_var * string (* branch to the label when the condition variable holds (presumably true/non-zero — confirm with codegen) *)
	| Jmp of string (* unconditional jump to the label *)
	| Label of string (* jump target *)
	| Ret of simple_var (* return the variable's value *)
	| Expr of simple_expr (* a single three-address operation *)

(* One simplified function: its formal arguments, its declaration header, and
   its flattened body. *)
and simple_func = {
	args: simple_var list;
	header: simple_fdecl;
	code: simple_stmt list;
}

(* A whole simplified program, as produced by simplify_program. *)
type simple_program = {
	gvars: simple_var list; (* global variables, copied from c_globals *)
	fdecls: simple_fdecl list; (* header of every function *)
	funcs: simple_func list; (* simplified function bodies *)
	blocks: int; (* copied from c_block_count; presumably the number of lexical blocks — confirm *)
}

(* Global counters backing fresh temporary names and fresh label names.
   They are never reset here, so names stay unique across the whole program. *)
let tmp_reg_id : int ref = ref 0
let label_id : int ref = ref 0

(* Allocate a fresh temporary of type [t].
   Returns the triple (name, t, -1). The name is "__reg_str_N", "__reg_num_N"
   or "__reg_map_N", where N is taken from the global [tmp_reg_id] counter.
   The counter is only advanced when the type is supported.
   Raises [Failure "unsupported type"] for any other type, e.g. Simple(None).
   NOTE(review): -1 presumably marks a compiler temporary rather than a
   block-scoped variable — confirm against codegen. *)
let gen_tmp_var t =
	let prefix =
		match t with
		| Simple(Str) -> "__reg_str_"
		| Simple(Num) -> "__reg_num_"
		| Map(_, _) -> "__reg_map_"
		| _ -> failwith "unsupported type"
	in
	let id = !tmp_reg_id in
	incr tmp_reg_id;
	(prefix ^ string_of_int id, t, -1)

(* Produce a fresh, unique jump-target name "__LABEL_N" using the global
   [label_id] counter. *)
let gen_tmp_label () =
	let n = !label_id in
	incr label_id;
	"__LABEL_" ^ string_of_int n

(* [is_vdecl s] holds exactly when [s] is a variable declaration. Every
   constructor is listed so that adding a new statement kind forces a decision
   here. *)
let is_vdecl : simple_stmt -> bool = function
	| Decl _ -> true
	| If _ | Jmp _ | Label _ | Ret _ | Expr _ -> false

(* Complement of [is_vdecl]. *)
let is_not_vdecl (s:simple_stmt) =
	is_vdecl s |> not

let rec simplify_rvalue (t:simple_type) (l:c_lvalue) =
	(* Read an lvalue as a value of type [t].
	   A plain variable (index NoExpr) is used directly and emits no code.
	   An indexed access b[c] is always copied into a fresh temporary via Deref,
	   so indexed (map) reads are always passed around through a temporary. *)
	let (var, index) = l in
	match index with
	| NoExpr -> ([], var)
	| _ ->
		let (index_code, index_reg) = simplify_expr index in
		let result = gen_tmp_var t in
		(Decl(result) :: index_code @ [Expr(Deref(result, var, index_reg))], result)

and simplify_binop (t:simple_type) (e1:c_expr) (e2:c_expr) (op:bop) =
	(* a = b op c: emit the left operand's code, then the right operand's code,
	   then the operation into a fresh temporary of type [t]. *)
	let (lhs_code, lhs) = simplify_expr e1 in
	let (rhs_code, rhs) = simplify_expr e2 in
	let dst = gen_tmp_var t in
	let code = List.concat [[Decl(dst)]; lhs_code; rhs_code; [Expr(Bin(dst, op, lhs, rhs))]] in
	(code, dst)

and simplify_unop (t:simple_type) (e1:c_expr) (op:uop) =
	(* a = op b: emit the operand's code, then apply [op] into a fresh
	   temporary of type [t]. *)
	let (operand_code, operand) = simplify_expr e1 in
	let dst = gen_tmp_var t in
	(Decl(dst) :: operand_code @ [Expr(Un(dst, op, operand))], dst)

and simplify_assign (_t:simple_type) (l:c_lvalue) (e:c_expr) =
	(* Lower an assignment to an Alias store. The right-hand side is emitted
	   first, then the index (if any). The value of the whole assignment
	   expression is the right-hand side's register.
	   NOTE(review): an unindexed store uses the sentinel index
	   ("", Simple(None), 1), which differs from the ("none", ..., -1)
	   placeholder used elsewhere; presumably codegen recognises it — confirm. *)
	let (rhs_code, rhs) = simplify_expr e in
	let (target, index) = l in
	match index with
	| NoExpr ->
		(rhs_code @ [Expr(Alias(target, ("", Simple(None), 1), rhs))], rhs)
	| _ ->
		let (index_code, index_reg) = simplify_expr index in
		(rhs_code @ index_code @ [Expr(Alias(target, index_reg, rhs))], rhs)

and simplify_call (fdecl:simple_fdecl) (el:c_expr list)
	(rl:simple_var list) (sl:simple_stmt list) =
	(* Lower a call: evaluate each argument into a register, then emit
	   Call(result, fdecl, args).
	   [rl] accumulates argument registers in reverse order (reversed once at
	   the end). [sl] accumulates the argument-evaluation statements in source
	   order, so argument side effects run left to right, consistent with
	   simplify_binop. Callers start with [rl = []] and [sl = []]. *)
	match el with
	[] ->
		let (_,t,_,_) = fdecl in
		(* A call returning None still needs a destination. gen_tmp_var rejects
		   Simple(None), so the fixed placeholder "__none" is used instead. *)
		let tmp = (match t with
			Simple(None) -> ("__none", t, -1) | _ -> gen_tmp_var t) in
		let c = Call(tmp, fdecl, List.rev rl) in
		([Decl(tmp)] @ sl @ [Expr(c)], tmp)
	| head :: tail ->
		let (se, r) = simplify_expr head in
		(* Append, don't prepend: earlier arguments' code must run first. *)
		simplify_call fdecl tail (r :: rl) (sl @ se)

and simplify_expr (e:c_expr) =
	(* Lower a checked expression to (statements, result register).
	   Literals are materialised into fresh temporaries. NoExpr yields no code
	   and the placeholder ("none", Simple(None), -1). Every other form is
	   delegated to its dedicated helper. *)
	let load_literal ty lit =
		let reg = gen_tmp_var ty in
		([Decl(reg); Expr(Lit(reg, lit))], reg)
	in
	match e with
	| StrLiteral(s) -> load_literal (Simple(Str)) (StrLit(s))
	| NumLiteral(n) -> load_literal (Simple(Num)) (NumLit(n))
	| NoExpr -> ([], ("none", Simple(None), -1))
	| Rvalue(t, l) -> simplify_rvalue t l
	| Binop(t, e1, op, e2) -> simplify_binop t e1 e2 op
	| Unop(t, e1, op) -> simplify_unop t e1 op
	| Assign(t, l, e1) -> simplify_assign t l e1
	| FuncCall(fdecl, el) -> simplify_call fdecl el [] []

(* Fallback return that simplify_func appends after every function body.
   For a None-returning function it returns the placeholder "none". Otherwise
   it declares a fresh temporary of type [t] and returns it without assigning
   it. NOTE(review): the returned value is whatever the backend initialises
   declarations to — confirm. *)
let gen_default_ret (t:simple_type) =
	match t with
	| Simple(None) -> [Ret(("none", t, -1))]
	| _ ->
		let reg = gen_tmp_var t in
		[Decl(reg); Ret(reg)]

let rec simplify_stmt (s:c_stmt) =
	(* Lower one checked statement to a flat list of simple_stmt. *)
	match s with
	CodeBlock(b) -> simplify_block b
	| Conditional(e, b1, b2) ->
		(* Lowered as:
		     <cond code>; if r goto start; <b2>; goto end; start: <b1>; end:
		   so b1 runs when the condition holds and b2 otherwise. *)
		let (se, r) = simplify_expr e in
		let sb1 = simplify_block b1 in
		let sb2 = simplify_block b2 in
		let startlabel = gen_tmp_label () in
		let endlabel = gen_tmp_label () in
		se @ [If(r, startlabel)] @ sb2 @ [Jmp(endlabel); Label(startlabel)] @ sb1 @ [Label(endlabel)]
	| Loop(e, b) ->
		(* Lowered as:
		     goto end; start: <body>; end: <cond code>; if r goto start
		   The initial jump makes the condition get tested before the first
		   iteration; afterwards it is re-tested after each iteration. *)
		let (se, r) = simplify_expr e in
		let sb = simplify_block b in
		let startlabel = gen_tmp_label () in
		let endlabel = gen_tmp_label () in
		[Jmp(endlabel); Label(startlabel)] @ sb @ [Label(endlabel)] @ se @ [If(r, startlabel)]
	| Return(e) ->
		let (se, r) = simplify_expr e in
		se @ [Ret(r)]
	| Expression(e) ->
		(* Only the emitted statements matter; the result register is discarded. *)
		fst (simplify_expr e)

and simplify_stmtlist (slist:c_stmt list) =
	(* Lower a statement list in source order and concatenate the results. *)
	match slist with
	[] -> []
	| head :: tail ->
		(* Bind the head first. OCaml leaves the evaluation order of operator
		   arguments unspecified (ocamlopt evaluates right to left), and
		   simplify_stmt draws temporaries and labels from global counters.
		   Sequencing explicitly keeps the numbering in source order and
		   independent of the compiler. *)
		let first = simplify_stmt head in
		first @ simplify_stmtlist tail

and simplify_block (b:c_block) =
	(* Flatten a block: a Decl for each local comes first, followed by the
	   block's simplified statements. *)
	let { c_locals; c_statements; _ } = b in
	let local_decls = List.map (fun v -> Decl(v)) c_locals in
	List.append local_decls (simplify_stmtlist c_statements)

and simplify_fdecls funcs =
	(* Extract the declaration header of each checked function, preserving order. *)
	List.map (fun f -> f.c_header) funcs

and simplify_func (f:c_func) =
	(* Simplify one function. The body is flattened and followed by a default
	   return. Then every declaration is moved to the front; both the
	   declarations and the remaining statements keep their relative order. *)
	let lowered = simplify_block f.c_body in
	let epilogue = gen_default_ret (Check.get_ret_of_fdecl f.c_header) in
	let (decls, rest) = List.partition is_vdecl (lowered @ epilogue) in
	{ header = f.c_header; args = f.c_formals; code = decls @ rest }

and simplify_funclist (flist:c_func list) =
	(* Simplify each function in source order. *)
	match flist with
	[] -> []
	| head :: tail ->
		(* Bind the head first. Constructor arguments have unspecified evaluation
		   order, and simplify_func consumes the global temp/label counters, so
		   explicit sequencing keeps numbering in source order. *)
		let first = simplify_func head in
		first :: simplify_funclist tail

(* Entry point of the pass: convert a checked program into its simplified
   form. Globals and the block count are copied through unchanged; function
   headers and bodies are simplified. The function is not recursive, so no
   [rec] (avoids warning 39). *)
let simplify_program (p:c_program) =
	{ gvars = p.c_globals;
	  fdecls = simplify_fdecls p.c_functions;
	  funcs = simplify_funclist p.c_functions;
	  blocks = p.c_block_count }