(* Abstract Syntax Tree and functions for printing it *) 

(* Binary operators that can appear in a Binop expression.
   string_of_op prints them as:
     arithmetic  Add "+", Sub "-", Mult "*", Div "/", Mod "%"
     comparison  Eq "==", Neq "!=", Less "<", Leq "<=", Greater ">", Geq ">="
     logical     And "&&", Or "||" *)
type op = Add | Sub | Mult | Div | Eq | Neq | Mod | Geq | Leq | 
          Greater | Less | And | Or

(* Unary operators that can appear in a Unop expression.
   string_of_uop prints Not as "!" and Neg as "-", both in prefix position. *)
type uop = Not | Neg

(* Types of the language. Node and List each carry the type they are
   parameterised by. Any is printed as "*" by string_of_typ; presumably it
   acts as a wildcard type -- confirm against the semantic checker. *)
type typ = Int | Bool | Float | String | Node of typ | List of typ | Void | Any

(* A binding: a type paired with a name. Used for the formals of a func_decl
   and for the first component of a program. *)
type bind = typ * string

(* Expressions. The note on each constructor describes its payload and,
   where useful, how string_of_expr prints it. *)
type expr = 
    Id of string                  (* identifier, printed as its name *)
  | Lit_Int of int                (* integer literal *)
  | Lit_Flt of string             (* float literal, stored and printed as text *)
  | Lit_Str of string             (* string literal, printed exactly as stored *)
  | Lit_Bool of bool              (* boolean literal, printed as true or false *)
  | Lit_List of expr list         (* list literal, printed as [e1,e2,...] *)
  | Lit_Node of expr              (* node literal, printed as the inner expression in single quotes *)
  | Binop of expr * op * expr     (* left operand, operator, right operand *)
  | Unop of uop * expr            (* prefix operator and its operand *)
  | Assign of string * expr       (* target variable name and the assigned value *)
  | Call of string * expr list    (* function name and argument list *)
  | List_Access of expr * expr    (* list expression and index expression, printed as l[e] *)
  | Attr of expr * string         (* expression and attribute name, printed as e.a *)
  | Noexpr                        (* empty expression, printed as the empty string *)


(* Statements. The note on each constructor describes its payload. *)
type stmt =
    Block of stmt list                (* sequence of statements, printed between braces *)
  | Expr of expr                      (* expression used as a statement *)
  | Return of expr                    (* return carrying a value expression *)
  | If of expr * stmt * stmt          (* condition, then-branch, else-branch; string_of_stmt omits an else-branch equal to Block [] *)
  | For of expr * expr * expr * stmt  (* three header expressions (presumably init, condition, step) and the loop body *)
  | While of expr * stmt              (* condition and loop body *)
  | Declare of typ * string * expr    (* declared type, variable name, initializer; Noexpr is special-cased by string_of_stmt, presumably meaning no initializer *)

(* A function definition. *)
type func_decl = {
  typ : typ;            (* printed before the function name by string_of_fdecl; presumably the return type *)
  fname : string;       (* function name *)
  formals : bind list;  (* formal parameters as (type, name) pairs *)
  body : stmt list;     (* statements making up the function body *)
}

(* A whole program: a list of bindings followed by a list of function
   definitions. string_of_program prints the bindings first, one declaration
   per line (presumably these are the global variables -- confirm against the
   parser). *)
type program = bind list * func_decl list

(* Pretty-printing functions, for debugging purposes. Adapted from MicroC's ast.ml and refactored. *)

(** [string_of_op o] returns the symbol used to print the binary operator
    [o], e.g. [Add] gives ["+"] and [Leq] gives ["<="]. *)
let string_of_op (o : op) : string =
  match o with
  (* arithmetic *)
  | Add -> "+"
  | Sub -> "-"
  | Mult -> "*"
  | Div -> "/"
  | Mod -> "%"
  (* comparison *)
  | Eq -> "=="
  | Neq -> "!="
  | Less -> "<"
  | Leq -> "<="
  | Greater -> ">"
  | Geq -> ">="
  (* logical *)
  | And -> "&&"
  | Or -> "||"

(** [string_of_uop u] returns the symbol used to print the unary operator
    [u]: ["-"] for [Neg] and ["!"] for [Not]. *)
let string_of_uop (u : uop) : string =
  match u with
  | Neg -> "-"
  | Not -> "!"

(** [string_of_expr e] renders the expression [e] as text.

    Sub-expressions are printed without added parentheses, so the grouping of
    nested [Binop]/[Unop] nodes is not visible in the output. String and float
    literals are emitted exactly as stored; [Noexpr] yields the empty string. *)
let rec string_of_expr (e : expr) : string =
  (* Print every expression of [es] and join the results with [sep]. *)
  let joined sep es = String.concat sep (List.map string_of_expr es) in
  match e with
  | Noexpr -> ""
  | Id name -> name
  | Lit_Int n -> string_of_int n
  | Lit_Flt text -> text
  | Lit_Str text -> text
  | Lit_Bool b -> string_of_bool b
  (* List literals use a bare comma, call arguments a comma plus space. *)
  | Lit_List elems -> Printf.sprintf "[%s]" (joined "," elems)
  | Call (fn, args) -> Printf.sprintf "%s(%s)" fn (joined ", " args)
  | Lit_Node inner -> Printf.sprintf "'%s'" (string_of_expr inner)
  | Attr (target, field) -> string_of_expr target ^ "." ^ field
  | List_Access (lst, idx) ->
      Printf.sprintf "%s[%s]" (string_of_expr lst) (string_of_expr idx)
  (*| Lit_Char(c) -> String.make 1 c*)
  | Binop (lhs, o, rhs) ->
      String.concat " " [ string_of_expr lhs; string_of_op o; string_of_expr rhs ]
  | Unop (o, operand) -> string_of_uop o ^ string_of_expr operand
  | Assign (target, value) -> target ^ " = " ^ string_of_expr value

  (** [string_of_typ t] renders the type [t] as text. [Any] is printed as
      ["*"]; [Node] and [List] print their parameter between angle brackets.

      Defined ahead of [string_of_stmt] because declarations print their
      type. *)
  let rec string_of_typ = function
    | Any -> "*"
    | Int -> "int"
    | Bool -> "bool"
    | Float -> "float"
    | String -> "string"
    | Void -> "void"
    | Node t -> "node <" ^ string_of_typ t ^ ">"
    | List t -> "list<" ^ string_of_typ t ^ ">"

  (** [string_of_stmt s] renders the statement [s] as text, ending with a
      newline.

      - An [If] whose else-branch is [Block []] is printed without an [else].
      - A [Declare] is printed as [type name;] or [type name = init;]. The
        initializer may be stored either as the bare value or as an
        [Assign] to the declared name; both forms print the same way, so the
        name is never duplicated or dropped. *)
  let rec string_of_stmt = function
    | Block stmts ->
        "{\n" ^ String.concat "" (List.map string_of_stmt stmts) ^ "}\n"
    | Expr e -> string_of_expr e ^ ";\n"
    | Return e -> "return " ^ string_of_expr e ^ ";\n"
    | If (cond, then_s, Block []) ->
        "if (" ^ string_of_expr cond ^ ")\n" ^ string_of_stmt then_s
    | If (cond, then_s, else_s) ->
        "if (" ^ string_of_expr cond ^ ")\n"
        ^ string_of_stmt then_s ^ "else\n" ^ string_of_stmt else_s
    | For (e1, e2, e3, body) ->
        "for (" ^ string_of_expr e1 ^ " ; " ^ string_of_expr e2 ^ " ; "
        ^ string_of_expr e3 ^ ") " ^ string_of_stmt body
    | While (cond, body) ->
        "while (" ^ string_of_expr cond ^ ") " ^ string_of_stmt body
    | Declare (t, id, init) ->
        (* Always print the declared type and name; previously the type was
           dropped and the name only survived inside an Assign initializer. *)
        let decl = string_of_typ t ^ " " ^ id in
        let text =
          match init with
          | Noexpr -> decl
          | Assign (target, value) when target = id ->
              decl ^ " = " ^ string_of_expr value
          | value -> decl ^ " = " ^ string_of_expr value
        in
        text ^ ";\n"


(** [string_of_vdecl (t, id)] prints a binding as a declaration line: the
    type, a space, the name, then [";"] and a newline. *)
let string_of_vdecl (t, id) = Printf.sprintf "%s %s;\n" (string_of_typ t) id

(** [string_of_fdecl fdecl] prints a function definition: its [typ] and name,
    the formals in parentheses, then the body statements between braces.

    Each formal is printed as [type name] so the signature keeps the type
    information held in the AST (previously only the names were printed). *)
let string_of_fdecl fdecl =
  let string_of_formal (t, name) = string_of_typ t ^ " " ^ name in
  let formals = String.concat ", " (List.map string_of_formal fdecl.formals) in
  let body = String.concat "" (List.map string_of_stmt fdecl.body) in
  string_of_typ fdecl.typ ^ " " ^ fdecl.fname
  ^ "(" ^ formals ^ ")\n{\n" ^ body ^ "}\n"

(** [string_of_program (vars, funcs)] prints every binding in [vars] as a
    declaration line, then a blank line, then the function definitions in
    [funcs] separated by newlines. *)
let string_of_program (vars, funcs) =
  let declarations = vars |> List.map string_of_vdecl |> String.concat "" in
  let functions = funcs |> List.map string_of_fdecl |> String.concat "\n" in
  declarations ^ "\n" ^ functions