open Ast

module StringMap = Map.Make(String)
module CounterHash = Hashtbl.Make(struct
  type t = string
  let equal x y = x = y
  let hash = Hashtbl.hash
end)

type function_info = {
	function_index : int;
	parameter_count : int;
  }

type condition = CGreaterThan | CLessThan | CGreaterEqual | CLessEqual | CEqual | CNotEqual

type var_type = IntType | ArrayType

type instruction =
	Label of string
  |	Move of string * string
  | CondMove of condition * string * string
  | Jump of string
  | CondJump of string
  |	Push of string
  | Pop of string
  | Invoke of string
  | AddLiteral of int * string
  | AddValues of string * string
  | SubValues of string * string
  | MulValues of string * string
  | DivValues of string
  | Sarl of string * string
  | Cmp of string * string
  | FuncHeader of string
  | Ret
  | None

let string_of_asm = function
	Label(s) -> s ^ ":\n"
  | Move(s, t) -> "    movl " ^ s ^ ", " ^ t ^ "\n"
  | CondMove(c, s, t) -> "    " ^
      (match c with
        CGreaterThan -> "cmovg"
      | CLessThan -> "cmovl"
      | CGreaterEqual -> "cmovge"
      | CLessEqual -> "cmovle"
      | CEqual -> "cmove"
      | CNotEqual -> "cmovne")
      ^ " " ^ s ^ ", " ^ t ^ "\n"
  | Jump(s) -> "    jmp " ^ s ^ "\n"
  | CondJump(s) -> "    je " ^ s ^ "\n"
  |	Push(a) -> "    pushl " ^ a ^ "\n"
  | Pop(a) -> "    popl " ^ a ^ "\n"
  | Invoke(n) -> "    call " ^ n ^ "\n"
  | AddLiteral(l, a) -> "    addl $" ^ string_of_int l ^ ", %" ^ a ^ "\n"
  | AddValues(s, d) -> "    addl " ^ s ^ ", " ^ d ^ "\n"
  | SubValues(s, d) -> "    subl " ^ s ^ ", " ^ d ^ "\n"
  | MulValues(s, d) -> "    imul " ^ s ^ ", " ^ d ^ "\n"
  | DivValues(s) -> "    idivl " ^ s ^ "\n"
  | Sarl(b, t) -> "    sarl " ^ b ^ ", " ^ t ^ "\n"
  | Cmp(s, d) -> "    cmp " ^ s ^ ", " ^ d ^ "\n"
  | FuncHeader(n) -> ".type " ^ n ^ ", @function\n" ^ n ^ ":\n"
  | Ret -> "    ret\n"
  | None -> ""

let accumulate_statements a statement_list =
  a @ statement_list

let translate (globals, functions) =
  let counters = CounterHash.create 1 in
  CounterHash.add counters "loop_label" 0;
  CounterHash.add counters "cond_label" 0;
  let alloc_functions (functions, count) fdecl =
	if StringMap.mem fdecl.fname functions
	then raise (Failure (fdecl.fname ^ " used twice as a function name."))
	else (StringMap.add
			fdecl.fname
			{ function_index = count; parameter_count = List.length fdecl.formals }
			functions,
		  count + 1) in
  let (func_info, _) =
	List.fold_left alloc_functions
	  (StringMap.add "printf" { function_index = -1; parameter_count = 2 }
		 StringMap.empty, 0)
	  functions in
  let alloc_vars start stride (vars, count) = function
	  Int(n) -> if StringMap.mem n vars
		  then raise (Failure (n ^ " used as a variable twice in the same scope."))
		  else (StringMap.add n (IntType, (count * stride) + start) vars, count + 1)
	| Array(n, l) -> if StringMap.mem n vars
		  then raise (Failure (n ^ " used as a variable twice in the same scope."))
		  else if l < 1 then raise (Failure ("Arrays must be of size greater than 1."))
		  else (StringMap.add n (ArrayType, (count * stride) + start) vars, count + l) in
  let (bss_vars, bss_words) =
	List.fold_left (alloc_vars 1 1) (StringMap.empty, 0) globals in
  let translate_cond_expression cond =
      [ Cmp("%ebx", "%eax");
        Move("$0", "%eax");
        Move("$1", "%ebx");
        CondMove(cond, "%ebx", "%eax") ] in
  let rec translate_expr env = function
      Literal(l) -> [ Move("$" ^ string_of_int l, "%eax") ]
    | VarRef(v) ->
		(match v with
		  IntRef(s) ->
			(try let (t, o) = (StringMap.find s env) in
			if ArrayType = t then
			  raise (Failure ("attempt to use array variable " ^ s ^ " as scalar"))
			else [ Move(string_of_int (-4 * o) ^ "(%ebp)", "%eax") ]
			with Not_found ->
			  (try let (t, o) = (StringMap.find s bss_vars) in
			  if ArrayType = t then
				raise (Failure ("attempt to use array variable " ^ s ^ " as scalar"))
			  else [ Move("$global_buffer", "%edi");
					 Move(string_of_int (4 * o) ^ "(%edi)", "%eax") ]
			  with Not_found -> raise (Failure ("undefined variable " ^ s))))
		  | ArrayRef(s, e) ->
			(try let (t, o) = (StringMap.find s env) in
			if IntType = t then
			  raise (Failure ("attempt to use scalar variable " ^ s ^ " as an array"))
			else (translate_expr env e) @
			  [ MulValues("$-4", "%eax");
				AddValues("%ebp", "%eax");
				AddValues("$" ^ string_of_int (-4 * o), "%eax");
				Move("%eax", "%ebx");
				Move("0(%ebx)", "%eax"); ]
			with Not_found ->
			  (try let (t, o) = (StringMap.find s bss_vars) in
			  if IntType = t then
				raise (Failure ("attempt to use scalar variable " ^ s ^ " as an array"))
			  else
				(translate_expr env e) @
				[ MulValues("$4", "%eax");
				  Move("$global_buffer", "%ebx");
				  AddValues("%ebx", "%eax");
				  AddValues("$" ^ string_of_int (4 * o), "%eax");
				  Move("%eax", "%ebx");
				  Move("0(%ebx)", "%eax"); ]
			  with Not_found -> raise (Failure ("undefined variable " ^ s)))))
    | Binop(e1, o, e2) ->
        (translate_expr env e2) @
        Push("%eax") ::
        (translate_expr env e1) @
        Pop("%ebx") ::
        (match o with
          Add -> [ AddValues("%ebx", "%eax") ]
        | Sub -> [ SubValues("%ebx", "%eax") ]
        | Mult -> [ MulValues("%ebx", "%eax") ]
        | Div -> [ Move("%eax", "%edx"); Sarl("$31", "%edx"); DivValues("%ebx") ]
        | Equal -> translate_cond_expression CEqual
        | NotEqual -> translate_cond_expression CNotEqual
        | LessThan -> translate_cond_expression CLessThan
        | GreaterThan -> translate_cond_expression CGreaterThan
        | LessEqual -> translate_cond_expression CLessEqual
        | GreaterEqual -> translate_cond_expression CGreaterEqual)
    | Assign(v, e) ->
		(translate_expr env e) @
		(match v with
		  IntRef(s) ->
			(try let (t, o) = (StringMap.find s env) in
			if ArrayType = t then
			  raise (Failure ("attempt to use array variable " ^ s ^ " as a scalar"))
			else [ Move("%eax", string_of_int (-4 * o) ^ "(%ebp)") ]
			with Not_found ->
			  (try let (t, o) = (StringMap.find s bss_vars) in
			  if ArrayType = t then
				raise (Failure ("attempt to use array variable " ^ s ^ " as scalar"))
			  else [ Move("$global_buffer", "%edi");
					 Move("%eax", string_of_int (4 * o) ^ "(%edi)" ) ]
			  with Not_found -> raise (Failure ("undefined variable " ^ s))))
		| ArrayRef(s, e) ->
			Push("%eax") :: (translate_expr env e) @
			(try let (t, o) = (StringMap.find s env) in
			if IntType = t then
			  raise (Failure ("attempt to use scalar variable " ^ s ^ " as an array"))
			else
			  [ MulValues("$-4", "%eax");
				AddValues("%ebp", "%eax");
				AddValues("$" ^ string_of_int (-4 * o), "%eax");
				Move("%eax", "%ebx");
				Pop("%eax");
				Move("%eax", "0(%ebx)"); ]
			with Not_found ->
			  (try let (t, o) = (StringMap.find s bss_vars) in
			  if IntType = t then
				raise (Failure ("attempt to use scalar variable " ^ s ^ " as an array"))
			  else
				[ MulValues("$4", "%eax");
				  Move("$global_buffer", "%ebx");
				  AddValues("%ebx", "%eax");
				  AddValues("$" ^ string_of_int (4 * o), "%eax");
				  Move("%eax", "%ebx");
				  Pop("%eax");
				  Move("%eax", "0(%ebx)"); ]
			  with Not_found -> raise (Failure ("undefined variable " ^ s)))))
    | Call(f, el) ->
		(try let finfo = StringMap.find f func_info in
		if finfo.parameter_count == List.length el then
		  (* Evaluate and push the arguments, and do it in reverse order. *)
		  List.fold_left accumulate_statements []
			(List.map (fun exp -> exp @ [Push("%eax")])
			   (List.map (translate_expr env) (List.rev el))) @
		  (* Call the function. *)
		  [ Invoke(f);
			(* Pop the stack for the number of arguments pushed. *)
			AddLiteral(4 * List.length el, "esp"); ]
		else raise (Failure ("call to function " ^ f ^ " which takes " ^
							 string_of_int finfo.parameter_count ^ " args with " ^
							 string_of_int (List.length el) ^ " args"))
		with Not_found -> raise (Failure ("call to undeclared function " ^ f)))
	| Data(s) -> [ Move(s, "%eax") ]
    | Noexpr -> [] in
  let rec translate_stmt env = function
      Block(stmts) -> translate_stmts env stmts
    | Expression(expr) -> translate_expr env expr
    | Return(expr) ->
		translate_expr env expr @
		[ Move("%ebp", "%esp"); Pop("%ebp"); Ret ]
    | If(e, s1, s2) ->
		let num = CounterHash.find counters "cond_label" in
		let false_label = "cond_label_" ^ string_of_int num in
		let true_label = "cond_label_" ^ string_of_int (num + 1) in
		CounterHash.replace counters "cond_label" (num + 2);
		translate_expr env e @
		[ Cmp("$0", "%eax") ;
		  CondJump(false_label); ] @
		translate_stmt env s1 @
		Jump(true_label) ::
		Label(false_label) ::
		translate_stmt env s2 @
		[ Label(true_label) ]
    | While(e, s) ->
		let num = CounterHash.find counters "loop_label" in
		let start_label = "loop_label_" ^ string_of_int num in
		let done_label = "loop_label_" ^ string_of_int (num + 1) in
		CounterHash.replace counters "loop_label" (num + 2);
		Label(start_label) ::
		translate_expr env e @
		[ Cmp("$0", "%eax");
		  CondJump(done_label) ] @
		translate_stmt env s @
		[ Jump(start_label);
		  Label(done_label) ]
  and translate_stmts env stmts =
	List.fold_left (fun a l -> a @ l) [] (List.map (translate_stmt env) stmts) in
  let translate_fdecl fdecl =
	let (parameter_vars, _) =
	  List.fold_left (alloc_vars (-2) (-1)) (StringMap.empty, 0) fdecl.formals in
	let (local_vars, local_words) =
	  List.fold_left (alloc_vars 1 1) (parameter_vars, 0) fdecl.locals in
	FuncHeader(fdecl.fname) ::
	(* Save the stack pointer in %ebp, for accessing data, and preserve the old
	   value of %ebp *)
	Push("%ebp") :: Move("%esp", "%ebp") ::
	(* Add space for local vars. *)
	AddLiteral(-4 * local_words, "esp") ::
	translate_stmts local_vars fdecl.body @
    (* Restore the %ebp and stack pointer, then return. *)
    [ Move("%ebp", "%esp"); Pop("%ebp"); Ret ] in
  let func_defs = List.map translate_fdecl functions in
  let asm = List.fold_left accumulate_statements [] func_defs in
  let text = List.map string_of_asm asm in
  ".section .data\n" ^
  "output:\n" ^
  "    .asciz \"%d\\n\"\n" ^
  (if bss_words > 0 then
	".section .bss\n" ^
	"    .lcomm global_buffer, " ^ string_of_int (4 * bss_words) ^ "\n"
  else
	"") ^
  ".section .text\n" ^
  ".globl _start\n" ^
  "_start:\n" ^
  "    call main\n" ^
  "    pushl $0\n" ^
  "    call exit\n\n" ^
  String.concat "" text
