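(*
File: AST.ML
Description: Abstract syntax tree types built by the parser, plus
pretty-printing functions that render the AST back to source-like text.
*)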
(** Binary operators.  The source spelling of each is given by
    [string_of_op] (e.g. [Mult] is "*", [Neq] is "!=", [Mod] is "%"). *)
type op = Add | Sub | Mult | Div | Equal | Neq | Less | Leq | Greater | Geq |
          And | Or | Mod

(** Unary prefix operators: [Neg] prints as "-" and [Not] prints as "!"
    (see [string_of_uop]). *)
type uop = Neg | Not


(** Source-language types.  [Matrix] is a matrix type with no dimensions
    attached; [DimMatrix (a, b)] carries two integer dimensions and prints
    as [matrix[a, b]].  NOTE(review): the pair is presumably
    (rows, columns) - confirm against the parser / semantic checker. *)
type typ = Int | Bool | Void | Double | String | Image
          | Matrix | DimMatrix of int * int

(** A (type, name) pair; used for a function's formal parameters
    ([func_decl.formals]). *)
type bind = typ * string

(** Expressions.  The printed form of every constructor is defined by
    [string_of_expr]. *)
type expr =
    IntLit of int
  | StrLit of string                      (* printed verbatim, no quotes are added *)
  | DblLit of float
  | BoolLit of bool
  | Id of string                          (* variable reference *)
  | Binop of expr * op * expr
  | Unop of uop * expr
  | Assign of string * expr               (* [name = e]; the target is a bare identifier *)
  | Call of string * expr list            (* [f(arg1, arg2, ...)] *)
  | Noexpr                                (* empty expression; prints as "" *)
  | Noassign of typ                       (* NOTE(review): appears to mark a declaration without an initializer; prints as the type name *)
  | MatLit of expr list list              (* inner lists joined by ", ", outer list by "; " - presumably a list of rows *)
  | MatAccess of string * expr * expr     (* [m[i, j]] *)
  | ImageLit of string * string * string  (* [image(x, y, z)] built from three identifiers *)
  | ImageRedAccess of string (* Channel access applies only to an image identifier, not an arbitrary expression. *)
  | ImageGreenAccess of string            (* [img.green] *)
  | ImageBlueAccess of string             (* [img.blue] *)
  | MatrixRowSize of string               (* [m.rowsize] *)
  | MatrixColSize of string               (* [m.colsize] *)
  | Cast of typ * expr                    (* [(t) e] *)

(** Statements.  The printed form of each constructor is defined by
    [string_of_stmt]. *)
type stmt =
    Block of stmt list
  | Expr of expr                          (* expression statement *)
  | Return of expr
  | If of expr * stmt * stmt              (* an else branch of [Block []] is printed as no else at all *)
  | For of expr * expr * expr * stmt      (* printed as [for (e1 ; e2 ; e3) body] - presumably C-style init/condition/step *)
  | While of expr * stmt
  | Local of typ * string * expr          (* local declaration: type, name, initializer *)

(** A function definition. *)
type func_decl = {
    typ : typ;            (* return type, printed after "->" *)
    fname : string;
    formals : bind list;  (* parameters as (type, name) pairs *)
    body : stmt list;
  }

(** A whole program: a list consisting solely of function declarations. *)
type program = func_decl list

(* Pretty-printing functions *)

(** [string_of_op op] is the concrete-syntax spelling of binary operator
    [op]. *)
let string_of_op op =
  match op with
  (* arithmetic *)
  | Add -> "+"
  | Sub -> "-"
  | Mult -> "*"
  | Div -> "/"
  | Mod -> "%"
  (* comparison *)
  | Equal -> "=="
  | Neq -> "!="
  | Less -> "<"
  | Leq -> "<="
  | Greater -> ">"
  | Geq -> ">="
  (* logical *)
  | And -> "&&"
  | Or -> "||"


(** [string_of_uop u] is the prefix symbol written for unary operator
    [u]. *)
let string_of_uop u =
  match u with
  | Not -> "!"
  | Neg -> "-"

(** [string_of_typ t] renders type [t] as it is written in source.  A
    dimensioned matrix prints as [matrix[a, b]]. *)
let string_of_typ t =
  match t with
  | Int -> "int"
  | Double -> "double"
  | Bool -> "bool"
  | String -> "string"
  | Void -> "void"
  | Image -> "image"
  | Matrix -> "matrix"
  | DimMatrix (d1, d2) -> Printf.sprintf "matrix[%d, %d]" d1 d2

(** [string_of_expr e] pretty-prints expression [e] on a single line.
    Operands of binary operators are not parenthesised, string literals are
    emitted verbatim, and [Noexpr] prints as the empty string. *)
let rec string_of_expr expr =
  match expr with
  | IntLit n -> string_of_int n
  | DblLit x -> string_of_float x
  | BoolLit b -> string_of_bool b
  | StrLit str -> str
  | Id name -> name
  | Binop (lhs, op, rhs) ->
      String.concat " " [ string_of_expr lhs; string_of_op op; string_of_expr rhs ]
  | Unop (op, operand) -> string_of_uop op ^ string_of_expr operand
  | Assign (name, rhs) -> name ^ " = " ^ string_of_expr rhs
  | Call (fname, args) ->
      fname ^ "(" ^ String.concat ", " (List.map string_of_expr args) ^ ")"
  | Noexpr -> ""
  | Noassign t -> string_of_typ t
  | MatLit rows ->
      (* Elements within an inner list are joined with ", " and the inner
         lists themselves with "; ", giving e.g. [1, 2; 3, 4]. *)
      let string_of_row row = String.concat ", " (List.map string_of_expr row) in
      "[" ^ String.concat "; " (List.map string_of_row rows) ^ "]"
  | MatAccess (name, i, j) ->
      name ^ "[" ^ string_of_expr i ^ ", " ^ string_of_expr j ^ "]"
  | ImageLit (m1, m2, m3) -> "image(" ^ String.concat ", " [ m1; m2; m3 ] ^ ")"
  | ImageRedAccess name -> name ^ ".red"
  | ImageGreenAccess name -> name ^ ".green"
  | ImageBlueAccess name -> name ^ ".blue"
  | MatrixRowSize name -> name ^ ".rowsize"
  | MatrixColSize name -> name ^ ".colsize"
  | Cast (t, e) -> "(" ^ string_of_typ t ^ ")" ^ string_of_expr e


(** [string_of_stmt s] pretty-prints statement [s]; the result always ends
    with a newline.
    - An [If] whose else branch is [Block []] is printed without [else].
    - A [Local] whose initializer is [Noexpr] or [Noassign _] is printed as
      a bare declaration [t x;]; any other initializer gives [t x = e;]. *)
let rec string_of_stmt = function
  | Block stmts ->
      "{\n" ^ String.concat "" (List.map string_of_stmt stmts) ^ "}\n"
  | Expr expr -> string_of_expr expr ^ ";\n"
  | Return expr -> "return " ^ string_of_expr expr ^ ";\n"
  | If (e, s, Block []) -> "if (" ^ string_of_expr e ^ ")\n" ^ string_of_stmt s
  | If (e, s1, s2) ->
      "if (" ^ string_of_expr e ^ ")\n" ^ string_of_stmt s1 ^ "else\n"
      ^ string_of_stmt s2
  | For (e1, e2, e3, s) ->
      "for (" ^ string_of_expr e1 ^ " ; " ^ string_of_expr e2 ^ " ; "
      ^ string_of_expr e3 ^ ") " ^ string_of_stmt s
  | While (e, s) -> "while (" ^ string_of_expr e ^ ") " ^ string_of_stmt s
  | Local (t, name, init) ->
      let decl = string_of_typ t ^ " " ^ name in
      (* Decide structurally rather than by comparing the printed
         initializer to the empty string: [Noassign t] prints as the type
         name (never empty), so the old test emitted "int x = int;", and an
         initializer that happens to print as empty would lose its "= ...". *)
      (match init with
       | Noexpr | Noassign _ -> decl ^ ";\n"
       | _ -> decl ^ " = " ^ string_of_expr init ^ ";\n")

(** [string_of_fdecl fdecl] pretty-prints a function definition as
    [func name(a, b) -> ret] followed by its body statements in braces.
    Formal parameters are printed by name only; their types are omitted. *)
let string_of_fdecl fdecl =
  let { typ; fname; formals; body } = fdecl in
  let params = String.concat ", " (List.map snd formals) in
  let stmts = String.concat "" (List.map string_of_stmt body) in
  Printf.sprintf "func %s(%s) -> %s\n{\n%s}\n" fname params (string_of_typ typ) stmts

(** [string_of_program funcs] pretty-prints every function definition in
    order; because each one ends in a newline, joining them with "\n"
    leaves a blank line between consecutive functions. *)
let string_of_program funcs =
  funcs |> List.map string_of_fdecl |> String.concat "\n"
