(*
Author: Mingye Chen
Copyright 2018, MathLight
*)

(* Binary operators.  They appear in [Binop] and [MatrixOp] nodes and are
   rendered by [string_of_op].  The Dot* variants render as ".*", "./" and
   ".^" (presumably element-wise matrix operations — confirm against the
   code generator).
   NOTE(review): [string_of_op] renders [Semi] as ":" rather than ";" —
   confirm this is intended. *)
type operator = Add | Sub | Mult | Div | Pow
              | Dotmul | Dotdiv | Dotpow
              | Greater | Geq | Leq | Less
              | Neq | Equal
              | And | Or
              | Comma | Semi

(* Unary operators carried by [Unop] nodes; rendered by [string_of_uop]. *)
type unary_operator = Not | Neg | Abs | Transpose

(* Declarable types: used for variable declarations ([bind]) and for function
   return types ([function_declare.data_type]).  [Matrix] is the only type
   whose declarations carry a size: the parser helpers below reject a size on
   any other type and require one for [Matrix]. *)
type data_type = Int | Double | String | Void | Bool | Matrix

(* Expression AST.  Nodes that refer to a matrix variable ([MatrixOp],
   [Matrix*Element], [Matrix*Modify]) carry that variable's name as a
   string rather than a sub-expression. *)
type expr =
    Binop of expr * operator * expr
  | Unop of unary_operator * expr
  | IntLit of int
  | DoubleLit of float
  | StrLit of string
  | BoolLit of bool
  | Range of expr * expr                        (* two bounds; presumably consumed by ForRange — confirm *)
  | MatrixLit of float array array              (* presumably outer array = rows — confirm *)
  | MatrixOp of string * operator * string      (* operator applied to two variables, by name *)
  | Matrix1DElement of string * expr            (* element access, rendered name[e] *)
  | Matrix2DElement of string * expr * expr     (* element access, rendered name[e1, e2] *)
  | Matrix1DModify of string * expr * expr      (* element store, rendered name[e1] = e2 *)
  | Matrix2DModify of string * (expr * expr) * expr  (* element store, rendered name[e1, e2] = e3 *)
  | Id of string
  | Assign of string * expr                     (* rendered name = e *)
  | Call of string * expr list                  (* function name and argument list *)
  | Noexpr                                      (* absent expression; renders as "" *)

(* A variable declaration: (type, name, size, initializer).
   The size pair is only meaningful for [Matrix]; the parser helper
   [check_size_normal_return_bind] stores (-1, -1) for every other type.
   NOTE(review): the pair is presumably (rows, columns), and the initializer
   is presumably [Noexpr] when absent — confirm against the parser. *)
type bind = data_type * string * (int * int) * expr

(* Statement AST. *)
type statement =
	  Block of statement list
  | Expr of expr
  | If of expr * statement * statement     (* cond, then, else; string_of_stmt treats an empty Block else-branch as "no else" *)
  | For of expr * expr * expr * statement  (* C-style header, presumably init; condition; step — then body *)
  | ForRange of string * expr * statement  (* loop variable name, iterated expression (presumably a Range), body *)
  | While of expr * statement
  | Continue of expr                       (* NOTE(review): carries an expr, presumably always Noexpr — confirm *)
  | Break of expr                          (* NOTE(review): same as Continue — confirm *)
  | Return of expr

(* A function definition. *)
type function_declare = {
  data_type : data_type;           (* return type *)
	function_name : string;
	arguments : bind list;           (* formal parameters, in order *)
  local_vars : bind list;          (* local declarations, printed before the body *)
	body : statement list;
}

(* A whole program: top-level declarations (presumably globals — confirm)
   followed by the function definitions. *)
type program = bind list * function_declare list

(* Parser helper functions *)

(* Build a [bind] for a declaration that states an explicit size.
   Only [Matrix] declarations may carry a size.
   @raise Failure "only matrix type can declare its size" for any other
   type. *)
let check_size_matrix_return_bind data_type variable_name matrix_size expr =
  match data_type with
  | Int | Double | String | Void | Bool ->
      failwith "only matrix type can declare its size"
  | Matrix -> (data_type, variable_name, matrix_size, expr)

(* Build a [bind] for a declaration without an explicit size.  Non-matrix
   types get the placeholder size (-1, -1).
   @raise Failure "should assign the size for matrix type" when
   [data_type] is [Matrix], which must always declare its size. *)
let check_size_normal_return_bind data_type variable_name expr =
  let no_size = (-1, -1) in
  match data_type with
  | Int | Double | String | Void | Bool ->
      (data_type, variable_name, no_size, expr)
  | Matrix -> failwith "should assign the size for matrix type"

(* Pretty-printing functions *)

(* Source-level spelling of each binary operator, listed in declaration
   order.  Note that [Semi] is rendered as ":". *)
let string_of_op = function
  (* arithmetic *)
  | Add -> "+"
  | Sub -> "-"
  | Mult -> "*"
  | Div -> "/"
  | Pow -> "^"
  (* dotted (matrix) variants *)
  | Dotmul -> ".*"
  | Dotdiv -> "./"
  | Dotpow -> ".^"
  (* comparisons *)
  | Greater -> ">"
  | Geq -> ">="
  | Leq -> "<="
  | Less -> "<"
  | Neq -> "!="
  | Equal -> "=="
  (* logical *)
  | And -> "&&"
  | Or -> "||"
  (* separators *)
  | Comma -> ","
  | Semi -> ":"

(* Printed form of a unary operator.  [Abs] and [Transpose] include a
   parenthesised tag naming the operation. *)
let string_of_uop uop =
  match uop with
  | Not -> "!"
  | Neg -> "-"
  | Abs -> "| |(Abs)"
  | Transpose -> "'(Transpose)"

(* Render an expression as text for debugging / pretty-printing.
   - Binary expressions are not parenthesised, so the output does not
     disambiguate precedence.
   - String literals are printed without surrounding quotes.
   - Matrix-specific nodes are prefixed with their constructor name
     (e.g. "Matrix1DElement") so the AST shape stays visible. *)
let rec string_of_expr = function
    IntLit(l) -> string_of_int l
  | DoubleLit(l) -> string_of_float l
  | BoolLit(true) -> "true"
  | BoolLit(false) -> "false"
  | StrLit(s) -> s
  | Id(s) -> s
  | Binop(e1, o, e2) ->
      string_of_expr e1 ^ " " ^ string_of_op o ^ " " ^ string_of_expr e2
  | Unop(o, e) -> string_of_uop o ^ string_of_expr e
  | Assign(v, e) -> v ^ " = " ^ string_of_expr e
  | Call(f, el) ->
      f ^ "(" ^ String.concat ", " (List.map string_of_expr el) ^ ")"
  | Noexpr -> ""
  (* Only the tag is printed; the literal's contents are not rendered. *)
  | MatrixLit(_) -> "MatrixLit"
  | Matrix1DElement(s, e) ->
      "Matrix1DElement " ^ s ^ "[" ^ string_of_expr e ^ "]"
  | Matrix2DElement(s, e1, e2) ->
      (* Tag matches the constructor name, like the 1-D case above. *)
      "Matrix2DElement " ^ s ^ "[" ^ string_of_expr e1 ^ ", "
      ^ string_of_expr e2 ^ "]"
  | Range(e1, e2) -> "Range: " ^ string_of_expr e1 ^ ": " ^ string_of_expr e2
  | MatrixOp(m1, o, m2) ->
      (* Space the operator as Binop does, so operands don't run into it. *)
      "MatrixOp " ^ m1 ^ " " ^ string_of_op o ^ " " ^ m2
  | Matrix1DModify(s, e1, e2) ->
      s ^ "[" ^ string_of_expr e1 ^ "]" ^ " = " ^ string_of_expr e2
  | Matrix2DModify(s, (e1, e2), e3) ->
      s ^ "[" ^ string_of_expr e1 ^ ", " ^ string_of_expr e2 ^ "]" ^ " = "
      ^ string_of_expr e3

(* Render a statement as source-like text.  Every rendered statement ends
   with a newline.  An [If] whose else-branch is the empty block is printed
   without an "else" part.  The match is deliberately exhaustive (no
   catch-all), so adding a statement constructor forces this printer to be
   updated instead of silently printing a placeholder. *)
let rec string_of_stmt = function
    Block(stmts) ->
      "{\n" ^ String.concat "" (List.map string_of_stmt stmts) ^ "}\n"
  | Expr(expr) -> string_of_expr expr ^ ";\n"
  | Return(expr) -> "return " ^ string_of_expr expr ^ ";\n"
  (* Continue/Break carry an expression in the AST; print it the same way
     Return does. *)
  | Continue(expr) -> "continue " ^ string_of_expr expr ^ ";\n"
  | Break(expr) -> "break " ^ string_of_expr expr ^ ";\n"
  | If(e, s, Block([])) -> "if (" ^ string_of_expr e ^ ")\n" ^ string_of_stmt s
  | If(e, s1, s2) ->
      "if (" ^ string_of_expr e ^ ")\n" ^
      string_of_stmt s1 ^ "else\n" ^ string_of_stmt s2
  | For(e1, e2, e3, s) ->
      "for (" ^ string_of_expr e1 ^ " ; " ^ string_of_expr e2 ^ " ; " ^
      string_of_expr e3 ^ ") " ^ string_of_stmt s
  (* NOTE(review): the concrete source syntax of ranged loops is not visible
     here; "for (v in e)" is a readable rendering — adjust if the grammar
     spells it differently. *)
  | ForRange(v, e, s) ->
      "for (" ^ v ^ " in " ^ string_of_expr e ^ ") " ^ string_of_stmt s
  | While(e, s) -> "while (" ^ string_of_expr e ^ ") " ^ string_of_stmt s

(* Source keyword for each declarable type, in declaration order. *)
let string_of_typ typ =
  match typ with
  | Int -> "int"
  | Double -> "double"
  | String -> "string"
  | Void -> "void"
  | Bool -> "bool"
  | Matrix -> "matrix"

(* Render a variable declaration [bind] as "type name;" or, when an
   initializer is present, "type name = init;", followed by a newline.
   The size component of the bind is not printed.  Previously the
   initializer was appended with no "=" ("int x 5;"), and an absent
   initializer left a dangling space ("int x ;"). *)
let string_of_vdecl (t, id, _, expr) =
  let init =
    match expr with
    | Noexpr -> ""
    | e -> " = " ^ string_of_expr e
  in
  string_of_typ t ^ " " ^ id ^ init ^ ";\n"

(* Render a function definition as

     <return type> <name>(<arg names>)
     {
     <local declarations><body statements>}

   Only the argument names are printed, not their types. *)
let string_of_fdecl fdecl =
  let { data_type; function_name; arguments; local_vars; body } = fdecl in
  let arg_names = List.map (fun (_, name, _, _) -> name) arguments in
  let header =
    string_of_typ data_type ^ " " ^ function_name
    ^ "(" ^ String.concat ", " arg_names ^ ")\n{\n"
  in
  let locals = String.concat "" (List.map string_of_vdecl local_vars) in
  let stmts = String.concat "" (List.map string_of_stmt body) in
  header ^ locals ^ stmts ^ "}\n"

(* Render a whole program: all top-level declarations, one extra newline,
   then the function definitions joined by newlines. *)
let string_of_program (vars, funcs) =
  let globals = vars |> List.map string_of_vdecl |> String.concat "" in
  let functions = funcs |> List.map string_of_fdecl |> String.concat "\n" in
  globals ^ "\n" ^ functions

(* Render a matrix size pair (first, second) as "(first, second)". *)
let string_of_size (r, c) = Printf.sprintf "(%d, %d)" r c
