(** Binary operators that appear in a [Binop] expression.
    The concrete symbol printed for each one is given by [string_of_op]. *)
type op =
    Add       (** addition, printed as [+] *)
  | Subtract  (** subtraction, printed as [-] *)
  | Multiply  (** multiplication, printed as [*] *)
  | Divide    (** division, printed as [/] *)
  | Modulo    (** remainder, printed as [%] *)
  | Equal     (** equality comparison, printed as [==] *)
  | Neq       (** inequality comparison, printed as [!=] *)
  | Less      (** printed as [<] *)
  | Leq       (** printed as [<=] *)
  | Greater   (** printed as [>] *)
  | Geq       (** printed as [>=] *)
  | And       (** logical conjunction, printed as [&&] *)
  | Or        (** logical disjunction, printed as [||] *)

(** Literal constant values appearing in expressions. *)
type literal =
    Int of int
  | Float of float
  | String of string  (** printed by [string_of_literal] wrapped in double quotes, contents unescaped *)
  | Bool of bool
  | Null              (** printed as [null] *)

(** Expressions that may appear as the target of an [Assign]. *)
type assignable =
    Id of string                (** a named variable *)
  | ListIndex of expr * expr    (** list expression, index expression *)
(** Expressions. Mutually recursive with [assignable]. *)
and expr =
    Literal of literal
  | Assignable of assignable    (** a variable or indexed element read as a value *)
  | Binop of expr * op * expr   (** left operand, operator, right operand *)
  | Assign of assignable * expr (** assignment target, assigned value *)
  | Call of string * expr list  (** function name, arguments.
                                    NOTE(review): [string_of_expr] prints these
                                    arguments in reverse of their stored order,
                                    so the parser presumably stores them
                                    back-to-front; confirm against the parser. *)
  | DList of expr list          (** list literal *)
  | EmptyExpr                   (** absent expression; printed as the empty string *)

(** Statements. *)
type stmt =
    ExprStmt of expr                  (** an expression used as a statement *)
  | Return of expr
  | Block of stmt list                (** brace-delimited statement sequence *)
  | If of expr * stmt * stmt          (** condition, then-branch, else-branch *)
  | For of stmt * expr * stmt * stmt  (** init, condition, step, body *)
  | While of expr * stmt              (** condition, body *)
  | ForEach of string * expr * stmt   (** loop variable name, iterated expression, body *)

(** A function declaration. *)
type func_decl = {
    name: string;       (** function name *)
    args: string list;  (** parameter names *)
    body: stmt list;    (** body statements, in order *)
}

(** A single kernel parameter, identified by name. *)
type kernel_arg =
    Input of string     (** printed with an [in] prefix *)
  | Output of string    (** printed with an [out] prefix *)
  | BasicArg of string  (** printed as the bare name *)

(** A kernel declaration. *)
type kernel_decl = {
    kname: string;            (** kernel name *)
    kargs: kernel_arg list;   (** parameters, in order *)
    kbody: stmt list;         (** body statements, in order *)
}

(** A single top-level item of a program. *)
type program_part =
    FuncDecl of func_decl
  | KernelDecl of kernel_decl
  | Stmt of stmt                (** a statement at top level *)
  | Channels of string list     (** channel declaration listing channel names *)

(** A whole program: its top-level parts, printed in list order by
    [string_of_program]. *)
type program = program_part list

(* Pretty printing functions *)
(** [string_of_op op] is the concrete source symbol for operator [op]. *)
let string_of_op op =
    match op with
      Add      -> "+"
    | Subtract -> "-"
    | Multiply -> "*"
    | Divide   -> "/"
    | Modulo   -> "%"
    | And      -> "&&"
    | Or       -> "||"
    | Equal    -> "=="
    | Neq      -> "!="
    | Less     -> "<"
    | Leq      -> "<="
    | Greater  -> ">"
    | Geq      -> ">="

(** [string_of_literal l] renders literal [l] as source-like text.
    A [String] is wrapped in double quotes with its contents inserted
    verbatim (no escaping is performed); [Null] prints as [null]. *)
(* No [rec]: nothing here recurses, and an unused rec flag is warning 39,
   which dune's default dev profile treats as an error. *)
let string_of_literal = function
    Int(i)      -> string_of_int i
    (* NOTE(review): [string_of_float 1.0] yields [1.] (OCaml syntax);
       confirm this is acceptable for the target language. *)
  | Float(f)    -> string_of_float f
  | String(s)   -> "\"" ^ s ^ "\""
  | Bool(b)     -> string_of_bool b
  | Null        -> "null"

(** [string_of_expr e] renders expression [e] as source-like text.

    Operands of a binary operator, and the list being indexed, are wrapped
    in parentheses when they are themselves compound ([Binop] or [Assign]),
    so the printed text reflects the tree's grouping: the tree for
    (a + b) * c prints with its parentheses rather than as the ambiguous
    a + b * c. Simple operands are printed without extra parentheses. *)
let rec string_of_expr = function
    Literal(l)          -> string_of_literal l
  | Assignable(a)       -> string_of_assignable a
  | Binop(lhs, op, rhs) ->
        string_of_operand lhs ^ " " ^ string_of_op op ^ " " ^
        string_of_operand rhs
  | Assign(lhs, rhs)    ->
        string_of_assignable lhs ^ " = " ^ string_of_expr rhs
  | Call(f, args)       ->
        (* NOTE(review): arguments are printed in reverse of their stored
           order, presumably because the parser accumulates them
           back-to-front. DList elements and function parameters are not
           reversed; confirm against the parser. *)
        f ^ "(" ^
        String.concat ", " (List.map string_of_expr (List.rev args)) ^
        ")"
  | DList(exprs)        ->
        "[" ^ String.concat ", " (List.map string_of_expr exprs) ^ "]"
  | EmptyExpr           -> ""
(* Render a subexpression in operand position, parenthesizing it when it is
   compound so that operator grouping survives printing. *)
and string_of_operand = function
    (Binop _ | Assign _) as e ->
        "(" ^ string_of_expr e ^ ")"
  | (Literal _ | Assignable _ | Call _ | DList _ | EmptyExpr) as e ->
        string_of_expr e
(** [string_of_assignable a] renders an assignment target. *)
and string_of_assignable = function
    Id(id)              -> id
  | ListIndex(lst, idx) ->
        string_of_operand lst ^ "[" ^ string_of_expr idx ^ "]"

(** [string_of_stmt s] renders statement [s] as source-like text, with a
    trailing newline. Expression statements and returns end in a semicolon.
    Blocks are brace-delimited with one statement per line. *)
let rec string_of_stmt = function
    ExprStmt(expr)                  -> string_of_expr expr ^ ";\n"
  | Return(expr)                    -> "return " ^ string_of_expr expr ^ ";\n"
  | Block(stmts)                    ->
        "{\n" ^ String.concat "" (List.map string_of_stmt stmts) ^ "}\n"
  | If(cond, then_stmt, else_stmt)  ->
        "if(" ^ string_of_expr cond ^ ")" ^ string_of_stmt then_stmt ^
        "else\n" ^ string_of_stmt else_stmt ^ "\n"
  | For(start, term, inc, body)     ->
        (* The init and step clauses are statements. Printing them with
           string_of_stmt would embed their ";\n" terminator inside the
           header, so they go through string_of_for_clause instead. *)
        "for(" ^ string_of_for_clause start ^ "; " ^
        string_of_expr term ^ "; " ^
        string_of_for_clause inc ^ ")\n" ^
        string_of_stmt body ^ "\n"
  | While(cond, body)               ->
          "while(" ^ string_of_expr cond ^ ")\n" ^ string_of_stmt body ^ "\n"
  | ForEach(id, lst, body)          ->
          "foreach(" ^ id ^ " : " ^ string_of_expr lst ^ ")\n" ^
          string_of_stmt body ^ "\n"
(* Render the init or step clause of a for-loop header. An expression
   statement prints as its bare expression (no terminator); any other
   statement kind falls back to string_of_stmt. *)
and string_of_for_clause = function
    ExprStmt(expr) -> string_of_expr expr
  | (Return _ | Block _ | If _ | For _ | While _ | ForEach _) as stmt ->
        string_of_stmt stmt
(** [string_of_elseif (cond, body)] renders an elseif clause with its body
    wrapped in braces.
    NOTE(review): not referenced in this section; presumably used
    elsewhere. *)
and string_of_elseif (expr, stmt) =
    "elseif(" ^ string_of_expr expr ^ ")\n{" ^ string_of_stmt stmt ^ "}\n"

(** [string_of_func_decl f] renders a function declaration. It prints the
    [function] keyword, the name, and the comma-separated parameter list,
    then the body statements inside braces, followed by a blank line. *)
let string_of_func_decl f =
    let { name; args; body } = f in
    String.concat ""
      [ "function "; name; "("; String.concat ", " args; ")\n{\n";
        String.concat "" (List.map string_of_stmt body); "}\n\n" ]

(** [string_of_kernel_decl k] renders a kernel declaration. It prints the
    [kernel] keyword, the name, and the comma-separated parameters, where
    input/output parameters carry an [in]/[out] prefix. The body statements
    follow inside braces, then a blank line. *)
let string_of_kernel_decl k =
    let arg_to_string = function
        BasicArg(a) -> a
      | Input(a)    -> "in " ^ a
      | Output(a)   -> "out " ^ a
    in
    let { kname; kargs; kbody } = k in
    String.concat ""
      [ "kernel "; kname; "(";
        String.concat ", " (List.map arg_to_string kargs); ")\n{\n";
        String.concat "" (List.map string_of_stmt kbody); "}\n\n" ]

(** [string_of_program p] renders every top-level part of [p] in order and
    concatenates the results. A channel declaration prints as the [channel]
    keyword followed by comma-separated names and a semicolon. *)
let string_of_program p =
    let render_part part =
        match part with
          Channels(names) -> "channel " ^ String.concat ", " names ^ ";\n"
        | Stmt(s)         -> string_of_stmt s
        | KernelDecl(k)   -> string_of_kernel_decl k
        | FuncDecl(f)     -> string_of_func_decl f
    in
    p |> List.map render_part |> String.concat ""
