(* Abstract Syntax Tree and functions for printing it *)

(* Binary operators. Their surface syntax ("+", "==", "&&", "%", ...) is
   given by string_of_op. *)
type op = Add | Sub | Mult | Div | Equal | Neq | Less | Leq | Greater | Geq |
          And | Or | Mod

(* Unary prefix operators: Neg prints as "-", Not prints as "!". *)
type uop = Neg | Not

(* Assignment operators: Asn prints as "=", ModAsn prints as "%=". *)
type aop = Asn | ModAsn

(* Types of the source language. Struct types are referred to by their name;
   Pointer t is a pointer to t (printed as t followed by "*").
   TODO: Remove String, replace with Char, do strings with Charptr *)
type typ = Int | Size_t | Bool | Char | String | Void | Struct of string | Pointer of typ
(* Not yet supported: Array of typ * int *)

(* Expressions. The comment on each constructor gives the form that
   string_of_expr prints for it. *)
type expr =
    Literal of int                        (* integer literal *)
  | SizeLit of int64                      (* presumably a size_t literal *)
  | StringLit of string                   (* contents printed without quotes *)
  | CharLit of char                       (* printed without quotes *)
  | BoolLit of bool                       (* true / false *)
  | Id of string                          (* variable reference by name *)
  | Binop of expr * op * expr             (* e1 op e2 *)
  | Unop of uop * expr                    (* op e *)
  | Assign of expr * aop * expr           (* target, operator, value *)
  | Call of expr * expr list              (* callee expression, arguments *)
  | ArrayAccess of expr * expr            (* a[i] *)
  | StructAccess of expr * string         (* s.field *)
  | StructPointerAccess of expr * string  (* p->field *)
  | BuildArray of string * expr           (* name[n]; NOTE(review): exact meaning is defined by the parser/semant, not visible here *)
  | BuiltInCall of string * expr list     (* call to a built-in, identified by name *)
  | Cast of typ * expr                    (* <typ> e *)
  | Address of expr                       (* address-of: & e *)
  | Dereference of expr                   (* dereference: * e *)
  | Sizeof of typ                         (* sizeof(typ) *)
  | Nullexpr                              (* NULL *)
  | Noexpr                                (* absent expression; prints as the empty string *)

(* Statements. *)
type stmt =
    Block of stmt list
  | Expr of expr                        (* expression statement: e; *)
  | Return of expr
  | If of expr * stmt * stmt            (* condition, then, else; an empty Block else-branch is not printed *)
  | For of expr * expr * expr * stmt    (* printed as for (e1 ; e2 ; e3) body *)
  | While of expr * stmt                (* condition, body *)

(* A variable binding. It is used for globals, function formals and locals,
   and struct members.
   typ    - declared type
   string - name
   expr   - initial value (presumably Noexpr when there is none; confirm in
            the parser)
   bool   - flag for delaying evaluation *)
type bind = typ * string * expr * bool

(* A function definition. *)
type func_decl = {
    typ : typ;             (* return type *)
    fname : string;        (* function name *)
    formals : bind list;   (* parameters *)
    locals : bind list;    (* local variable declarations *)
    body : stmt list;      (* statements of the function body *)
  }

(* A struct declaration. *)
type struct_decl = {
    members: bind list;     (* member declarations *)
    struct_name: string;    (* the name that Struct types refer to *)
  }

(* A whole program. The three kinds of top-level declaration are kept in
   separate lists, so their relative order in the source is not recorded. *)
type program = {
    globals: bind list;
    functions: func_decl list;
    structs: struct_decl list;
}

(* Evaluation bookkeeping for a variable.
   name - the variable's name
   flag - is this variable lazy?
   eval - have we yet evaluated it? (mutable so it can be updated in place
          once the value has been computed)
   expr - the expr to evaluate *)
type lazy_record = {
  name : string;
  flag : bool;
  mutable eval : bool;
  expr : expr;
}

(* Pretty-printing functions *)
(* TODO: string and char literals are printed without quotes or escaping *)

(* Render a type as it is written in source. Base types print as their
   keyword and struct types as their bare name. A pointer type prints as its
   pointee followed by one "*" per level of indirection. *)
let rec string_of_typ t =
  match t with
  | Pointer pointee -> string_of_typ pointee ^ "*"
  | Struct name -> name
  | Int -> "int"
  | Size_t -> "size_t"
  | Bool -> "bool"
  | Char -> "char"
  | String -> "string"
  | Void -> "void"

(* Surface symbol for each binary operator. *)
let string_of_op o =
  match o with
  (* arithmetic *)
  | Add -> "+" | Sub -> "-" | Mult -> "*" | Div -> "/" | Mod -> "%"
  (* comparison *)
  | Equal -> "==" | Neq -> "!="
  | Less -> "<" | Leq -> "<=" | Greater -> ">" | Geq -> ">="
  (* logical *)
  | And -> "&&" | Or -> "||"

(* Prefix symbol for a unary operator. *)
let string_of_uop u =
  match u with
  | Not -> "!"
  | Neg -> "-"

(* Symbol for an assignment operator. *)
let string_of_aop a =
  match a with
  | ModAsn -> "%="
  | Asn -> "="

(* Render an expression in a C-like surface syntax, mainly for debugging.
   - Binary and assignment operators are surrounded by single spaces
     ("a + b", "x = y", "x %= y").
   - No parentheses are inserted, so the text does not always show the tree's
     grouping. For example, Binop (Binop (a, Add, b), Mult, c) prints as
     "a + b * c".
   - String and character literals are emitted as their raw contents, with
     no quotes and no escaping.
   - Noexpr prints as "". *)
let rec string_of_expr = function
    Literal(l) -> string_of_int l
  | SizeLit(n) -> Int64.to_string n
  | StringLit(s) -> s
  | CharLit(c) -> String.make 1 c
  | BoolLit(b) -> string_of_bool b
  | Id(s) -> s
  | Binop(e1, o, e2) ->
      string_of_expr e1 ^ " " ^ string_of_op o ^ " " ^ string_of_expr e2
  | Unop(o, e) -> string_of_uop o ^ string_of_expr e
  | Assign(v, o, e) ->
      (* Spaced like Binop so "x = -1" is not run together as "x=-1". *)
      string_of_expr v ^ " " ^ string_of_aop o ^ " " ^ string_of_expr e
  | BuiltInCall(f, el) ->
      f ^ "(" ^ String.concat ", " (List.map string_of_expr el) ^ ")"
  | Call(f, el) ->
      string_of_expr f ^ "(" ^ String.concat ", " (List.map string_of_expr el) ^ ")"
  | ArrayAccess(a, i) -> string_of_expr a ^ "[" ^ string_of_expr i ^ "]"
  | StructAccess(s, n) -> string_of_expr s ^ "." ^ n
  | StructPointerAccess(s, n) -> string_of_expr s ^ "->" ^ n
  | BuildArray(a, n) -> a ^ "[" ^ string_of_expr n ^ "]"
  | Cast(t, e) -> "<" ^ string_of_typ t ^ "> " ^ string_of_expr e
  | Sizeof(t) -> "sizeof(" ^ string_of_typ t ^ ")"
  | Address(v) -> "&" ^ string_of_expr v
  | Dereference(p) -> "*" ^ string_of_expr p
  | Noexpr -> ""
  | Nullexpr -> "NULL"

(* Render a statement, including its trailing newline. A block prints as its
   statements wrapped in braces. An If whose else-branch is an empty Block is
   printed without an "else". *)
let rec string_of_stmt stmt =
  (* Parenthesized condition, as used by if and while. *)
  let cond c = "(" ^ string_of_expr c ^ ")" in
  match stmt with
  | Block body ->
      Printf.sprintf "{\n%s}\n" (String.concat "" (List.map string_of_stmt body))
  | Expr e -> Printf.sprintf "%s;\n" (string_of_expr e)
  | Return e -> Printf.sprintf "return %s;\n" (string_of_expr e)
  | If (c, then_s, Block []) ->
      Printf.sprintf "if %s\n%s" (cond c) (string_of_stmt then_s)
  | If (c, then_s, else_s) ->
      Printf.sprintf "if %s\n%selse\n%s"
        (cond c) (string_of_stmt then_s) (string_of_stmt else_s)
  | For (init, test, step, body) ->
      Printf.sprintf "for (%s ; %s ; %s) %s"
        (string_of_expr init) (string_of_expr test) (string_of_expr step)
        (string_of_stmt body)
  | While (c, body) -> Printf.sprintf "while %s %s" (cond c) (string_of_stmt body)

(* Render a variable declaration followed by a newline. It prints as
   "<type> <name>;" when there is no initializer (Noexpr), and as
   "<type> <name> = <value>;" otherwise. The fourth component of the binding
   (the delayed-evaluation flag) is not printed. *)
let string_of_vdecl (t, id, value, _) =
  let decl = string_of_typ t ^ " " ^ id in
  match value with
  | Noexpr -> decl ^ ";\n"
  | v -> decl ^ " = " ^ string_of_expr v ^ ";\n"

(* Render a function definition. It prints the return type and the name,
   then the comma-separated formal parameter names (their types are not
   printed), then a braced body with the local declarations before the
   statements. *)
let string_of_fdecl fdecl =
  let formal_names = List.map (fun (_, name, _, _) -> name) fdecl.formals in
  let local_decls = List.map string_of_vdecl fdecl.locals in
  let stmts = List.map string_of_stmt fdecl.body in
  Printf.sprintf "%s %s(%s)\n{\n%s%s}\n"
    (string_of_typ fdecl.typ) fdecl.fname
    (String.concat ", " formal_names)
    (String.concat "" local_decls) (String.concat "" stmts)

(* Render a struct declaration as

     struct <name> {
     <member declarations>
     }

   Each member is printed with string_of_vdecl, which already ends in a
   newline, so members are simply concatenated. Previously "{\n" was passed
   as the *separator* to String.concat, so the opening brace was missing
   before the first member and stray braces appeared between members. There
   was also no space after "struct". *)
let string_of_sdecl sdecl =
  "struct " ^ sdecl.struct_name ^ " {\n"
  ^ String.concat "" (List.map string_of_vdecl sdecl.members)
  ^ "}\n"

(* Render a whole program as three sections in this order:
   1. struct declarations, separated by blank lines;
   2. global variable declarations, one per line;
   3. function definitions, separated by blank lines.
   Structs come first so the type names they introduce appear before any
   global or function that uses them. Consecutive non-empty sections are
   separated by a blank line, and empty sections are omitted so they leave
   no stray blank lines. *)
let string_of_program prg =
  let sections = [
    String.concat "\n" (List.map string_of_sdecl prg.structs);
    String.concat "" (List.map string_of_vdecl prg.globals);
    String.concat "\n" (List.map string_of_fdecl prg.functions);
  ] in
  String.concat "\n" (List.filter (fun s -> s <> "") sections)
