(* Abstract Syntax Tree and functions for printing it *)

(** Binary operators. Constructor order is significant for polymorphic
    comparison, so it must not be changed casually. *)
type op =
  (* arithmetic *)
  | Add
  | Sub
  | Mult
  | Div
  (* comparison *)
  | Equal
  | Neq
  | Less
  | Leq
  | Greater
  | Geq
  (* logical *)
  | And
  | Or

type uop = Neg | Not

(** Source-language types. [Vector Char] is the representation used for
    strings (it is printed as "string"); [Job t] wraps a result of type [t]. *)
type typ =
  | Int
  | Double
  | Char
  | Bool
  | Void
  | Vector of typ
  | Job of typ
  | Struct of string
  (* NOTE(review): presumably a struct referenced by value rather than by
     pointer, judging from the name — confirm against the checker/codegen. *)
  | NoPtrStruct of string

type bind = typ * string

(** Expressions. Constructor order is preserved from the original
    declaration. *)
type expr =
  (* literals and names *)
  | Literal of int
  | DoubleLit of float
  | BoolLit of bool
  | StringLit of string
  | Id of string
  (* operators and assignment *)
  | Binop of expr * op * expr
  | Unop of uop * expr
  | Assign of string * expr
  (* local call, and a call written as "remote f(args)" *)
  | Call of string * expr list
  | RemoteCall of string * expr list
  (* job operations; the string names the job *)
  | Get of string
  | Cancel of string
  | Running of string
  | Finished of string
  | Failed of string
  (* vectors: multi-index access/assignment is id[e1][e2]...;
     range access is id[lo:hi] *)
  | ListLiteral of expr list
  | VectorAccess of string * expr list
  | VectorAssign of string * expr list * expr
  | VectorRangeAccess of string * expr * expr
  | VectorSize of expr
  (* struct fields, written s->f *)
  | StructFieldAccess of expr * string
  | StructFieldAssign of expr * string * expr
  (* e1 << e2 *)
  | Concat of expr * expr
  (* empty expression, e.g. an omitted for-loop clause *)
  | Noexpr

(** Statements. An [If] whose else-branch is [Block []] has no else. *)
type stmt =
  | Block of stmt list
  | Expr of expr
  | Return of expr
  | If of expr * stmt * stmt
  | For of expr * expr * expr * stmt
  | While of expr * stmt
  | VarDecl of typ * string
  | VarDeclAssign of typ * string * expr
  (* NOTE(review): presumably pushes the second expr onto the vector
     denoted by the first — confirm against the checker. *)
  | VectorPushBack of expr * expr

(** A function declaration. *)
type func_decl = {
  typ : typ;            (* return type *)
  fname : string;       (* function name *)
  formals : bind list;  (* parameters, in declaration order *)
  body : stmt list;     (* statements of the function body *)
}

(** A struct type declaration. *)
type struct_type_decl = {
  sname : string;     (* struct name *)
  blist : bind list;  (* fields, in declaration order *)
}

type program = func_decl list * struct_type_decl list


(* Pretty-printing functions *)

(** [string_of_op o] is the source-level spelling of binary operator [o]. *)
let string_of_op (o : op) : string =
  match o with
  | Add -> "+"
  | Sub -> "-"
  | Mult -> "*"
  | Div -> "/"
  | Equal -> "=="
  | Neq -> "!="
  | Less -> "<"
  | Leq -> "<="
  | Greater -> ">"
  | Geq -> ">="
  | And -> "&&"
  | Or -> "||"

(** [string_of_uop u] is the prefix spelling of unary operator [u]. *)
let string_of_uop (u : uop) : string =
  match u with
  | Neg -> "-"
  | Not -> "!"

(** [string_of_typ t] renders type [t] in source-like syntax.
    [Vector Char] is shown as "string"; other vectors as "vector<...>".
    The match is exhaustive (no catch-all), so adding a [typ] constructor
    makes the compiler flag this function instead of silently printing a
    placeholder. *)
let rec string_of_typ = function
  | Int -> "int"
  | Double -> "double"
  (* NOTE(review): "char" is assumed to be the surface keyword — confirm. *)
  | Char -> "char"
  | Bool -> "bool"
  | Void -> "void"
  | Job x -> "job<" ^ string_of_typ x ^ ">"
  | Vector Char -> "string"
  | Vector x -> "vector<" ^ string_of_typ x ^ ">"
  | Struct id -> "struct " ^ id
  (* NOTE(review): printed like Struct; confirm whether the by-value form
     has distinct surface syntax. *)
  | NoPtrStruct id -> "struct " ^ id

let map_and_concat f list = String.concat ";\n" (List.map f list)

let string_of_vdecl (t, id) = string_of_typ t ^ " " ^ id ^ ";\n"
let string_of_bind (t, id) = string_of_typ t ^ " " ^ id
let string_of_bind_list list = map_and_concat string_of_bind list

(** [string_of_expr e] renders expression [e] in source-like syntax.
    Binary operators are printed without parentheses, so the grouping of
    nested [Binop]s is not recoverable from the output. *)
let rec string_of_expr expr =
  (* Comma-separated list used for call arguments and list literals. *)
  let args el = String.concat ", " (List.map string_of_expr el) in
  (* One "[...]" subscript per index expression. *)
  let indices idx =
    String.concat "" (List.map (fun i -> "[" ^ string_of_expr i ^ "]") idx)
  in
  match expr with
  | Literal l -> string_of_int l
  | DoubleLit f -> string_of_float f
  (* Printed verbatim; no quotes are added here. *)
  | StringLit s -> s
  | BoolLit b -> string_of_bool b
  | Id s -> s
  | Binop (e1, o, e2) ->
      string_of_expr e1 ^ " " ^ string_of_op o ^ " " ^ string_of_expr e2
  | Unop (o, e) -> string_of_uop o ^ string_of_expr e
  | Assign (v, e) -> v ^ " = " ^ string_of_expr e
  | ListLiteral el -> "[" ^ args el ^ "]"
  | Call (f, el) -> f ^ "(" ^ args el ^ ")"
  | RemoteCall (f, el) -> "remote " ^ f ^ "(" ^ args el ^ ")"
  (* A space separates the keyword from the job name; without it the
     output would be e.g. "getj". *)
  | Get j -> "get " ^ j
  | Cancel j -> "cancel " ^ j
  | Running j -> j ^ ".running"
  | Finished j -> j ^ ".finished"
  | Failed j -> j ^ ".failed"
  | VectorAccess (id, idx) -> id ^ indices idx
  | VectorRangeAccess (id, e1, e2) ->
      id ^ "[" ^ string_of_expr e1 ^ ":" ^ string_of_expr e2 ^ "]"
  | VectorAssign (id, idx, e) -> id ^ indices idx ^ " = " ^ string_of_expr e
  | StructFieldAccess (s, f) -> string_of_expr s ^ "->" ^ f
  | StructFieldAssign (s, f, v) ->
      string_of_expr s ^ "->" ^ f ^ " = " ^ string_of_expr v
  | VectorSize e -> "size " ^ string_of_expr e
  | Concat (e1, e2) -> string_of_expr e1 ^ " << " ^ string_of_expr e2
  | Noexpr -> ""

(** [string_of_stmt s] renders statement [s] in source-like syntax.
    Every simple statement ends with ";\n"; blocks end with "}\n".
    An [If] whose else-branch is [Block []] is printed without "else". *)
let rec string_of_stmt = function
  | Block stmts ->
      "{\n" ^ String.concat "" (List.map string_of_stmt stmts) ^ "}\n"
  | Expr expr -> string_of_expr expr ^ ";\n"
  | Return expr -> "return " ^ string_of_expr expr ^ ";\n"
  | If (e, s, Block []) -> "if (" ^ string_of_expr e ^ ")\n" ^ string_of_stmt s
  | If (e, s1, s2) ->
      "if (" ^ string_of_expr e ^ ")\n" ^ string_of_stmt s1 ^ "else\n"
      ^ string_of_stmt s2
  | For (e1, e2, e3, s) ->
      "for (" ^ string_of_expr e1 ^ " ; " ^ string_of_expr e2 ^ " ; "
      ^ string_of_expr e3 ^ ") " ^ string_of_stmt s
  | While (e, s) -> "while (" ^ string_of_expr e ^ ") " ^ string_of_stmt s
  | VarDecl (typ, id) -> string_of_vdecl (typ, id)
  (* The initializer goes before the terminator: "<type> <name> = <e>;".
     string_of_vdecl cannot be reused because it already appends ";\n". *)
  | VarDeclAssign (typ, id, e) ->
      string_of_bind (typ, id) ^ " = " ^ string_of_expr e ^ ";\n"
  | VectorPushBack (v, e) -> string_of_expr v ^ " @ " ^ string_of_expr e ^ ";\n"

(** [string_of_fdecl fdecl] renders a function as
    "<ret> <name>(<type> <param>, ...)\n{\n<body>}\n".
    Parameters keep their declared types; printing only the names would
    make the header ambiguous and unparsable. *)
let string_of_fdecl fdecl =
  let { typ; fname; formals; body } = fdecl in
  string_of_typ typ ^ " " ^ fname ^ "("
  ^ String.concat ", " (List.map string_of_bind formals)
  ^ ")\n{\n"
  ^ String.concat "" (List.map string_of_stmt body)
  ^ "}\n"

(** [string_of_struct_type_decl sdecl] renders a struct declaration as
    "struct <name> {\n<type> <field>;\n...}\n". Every field, including the
    last, is terminated by ";\n". The trailing newline matches the layout
    of [string_of_fdecl]. *)
let string_of_struct_type_decl sdecl =
  "struct " ^ sdecl.sname ^ " {\n"
  ^ String.concat "" (List.map string_of_vdecl sdecl.blist)
  ^ "}\n"

(** [string_of_program (funcs, structs)] renders every function declaration
    followed by every struct declaration. All declarations are joined by
    "\n", including at the boundary between the function group and the
    struct group. *)
let string_of_program (funcs, structs) =
  let decls =
    List.map string_of_fdecl funcs @ List.map string_of_struct_type_decl structs
  in
  String.concat "\n" decls
