open Ast
open Parser

(* Concatenate [words], inserting [delim] between consecutive elements.
   An empty list yields the empty string and a singleton yields its only
   element.  Delegates to String.concat, which sizes the result once and
   fills it in a single pass: time is linear in the output length and the
   stack does not grow with the length of the list (a hand-written
   word ^ delim ^ rest recursion re-copies the suffix at every step). *)
let string_join delim words = String.concat delim words;;

(* Source spelling of a binary operator.  Logical operators are spelled
   as words; arithmetic and comparison operators as symbols. *)
let string_of_binop = function
  | Add -> "+"
  | Sub -> "-"
  | Mul -> "*"
  | Div -> "/"
  | Mod -> "%"
  (* Element-wise operators, not implemented yet:
  | ElAdd -> ".+"
  | ElSub -> ".-"
  | ElMul -> ".*"
  | ElDiv -> "./"
  | ElMod -> ".%"
  *)
  | Eq -> "=="
  | Neq -> "!="
  | Leq -> "<="
  | Geq -> ">="
  | Lt -> "<"
  | Gt -> ">"
  | And -> "and"
  | Or -> "or"
  | Xor -> "xor"
  (* Bitwise operators, not implemented yet:
  | BitAnd -> "&&"
  | BitOr -> "||"
  | BitXor -> "*|"
  | BitLs -> "<<"
  | BitRs -> ">>"
  *);;

(* Source spelling of a unary operator. *)
let string_of_unop = function
  | Neg -> "-"
  | Not -> "not"
  (* Bitwise negation, not implemented yet:
  | BitNot -> "!"
  *);;

(* Source spelling of a data type.  Container types recurse into their
   element type and print as Name<element>. *)
let rec string_of_datatype dtype =
  let generic name inner = name ^ "<" ^ string_of_datatype inner ^ ">" in
  match dtype with
  | Int -> "int"
  | Float -> "float"
  | Char -> "char"
  | Bool -> "bool"
  | String -> "String"
  | Tup inner -> generic "Tup" inner
  | Fun -> "fun"
  | List inner -> generic "List" inner
  | Matrix inner -> generic "Matrix" inner;;

(* Numeric literal text, using the stdlib printers string_of_int and
   string_of_float (so a whole float prints with a trailing dot, e.g. 2.). *)
let string_of_scalar = function
  | LitInt n -> string_of_int n
  | LitFloat x -> string_of_float x;;

(* A formal argument, printed as its name, a spaced colon, and its type.
   Only plain data-typed arguments (ArgId) are handled. *)
let rec string_of_arg = function
  | ArgId (name, dtype) -> name ^ " : " ^ string_of_datatype dtype
  (* Function-typed arguments, not implemented yet:
  | ArgFunId(f, ftype) -> f ^ " : " ^ string_of_fun_type ftype
  *)

(* Render an expression as source text.
   Tuple, call and list elements are comma-separated with no spaces;
   matrix literals separate rows with semicolons.  Operators are printed
   with one space on each side and are never parenthesized.
   Character literals go through Char.escaped and get no surrounding quotes;
   string literal contents are emitted as-is between double quotes.
   NOTE(review): since BinOp output has no parentheses, trees that differ
   only in grouping print identically (a - b - c for both nestings);
   confirm the printer is not expected to round-trip through the parser. *)
and string_of_expr e =
  (* Comma-separated rendering of a list of expressions. *)
  let comma_sep exprs = string_join "," (List.map string_of_expr exprs) in
  (* One bracketed matrix/vector subscript. *)
  let subscript index = "[" ^ string_of_mat_index index ^ "]" in
  match e with
  | NumLit scalar -> string_of_scalar scalar
  | CharLit c -> Char.escaped c
  | StringLit s -> "\"" ^ s ^ "\""
  | BoolLit b -> string_of_bool b
  | TupLit exprs -> "(" ^ comma_sep exprs ^ ")"
  | ListLit exprs -> "{" ^ comma_sep exprs ^ "}"
  | FunLit (fun_type, stmts) ->
      let body = string_join "\n" (List.map string_of_stmt stmts) in
      string_of_fun_type fun_type ^ " {" ^ body ^ "}"
  | MatrixLit rows -> "[" ^ string_join ";" (List.map comma_sep rows) ^ "]"
  (* Not implemented:
  | MatrixFunDef(num_rows, num_cols, func_item) -> "Matrix((" ^ string_of_expr num_rows ^ "," ^ string_of_expr num_cols ^ "), " ^ string_of_func_item func_item ^ ")"
  *)
  | MatrixInit (num_rows, num_cols) ->
      "Matrix(" ^ string_of_expr num_rows ^ "," ^ string_of_expr num_cols ^ ")"
  | BinOp (lhs, rhs, op) ->
      string_of_expr lhs ^ " " ^ string_of_binop op ^ " " ^ string_of_expr rhs
  | UnOp (operand, op) -> string_of_unop op ^ " " ^ string_of_expr operand
  | Id name -> name
  | Attribute (obj, attr) -> obj ^ "." ^ attr
  | Call (name, args) -> name ^ "(" ^ comma_sep args ^ ")"
  (* Not implemented:
  | Pipe(expr_1, expr_2) -> string_of_expr(expr_1) ^ " | " ^ string_of_expr(expr_2)
  *)
  | MatAcc (name, row, col) -> name ^ subscript row ^ subscript col
  | VecAcc (name, index) -> name ^ subscript index

(* A matrix/vector subscript: either a single expression, or two
   expressions joined by a colon for a slice. *)
and string_of_mat_index = function
  | MatIndex index -> string_of_expr index
  | MatSlice (e1, e2) -> string_of_expr e1 ^ ":" ^ string_of_expr e2

(* Return type of a function signature; only the ReturnData form is
   handled here. *)
and string_of_return_type = function
  | ReturnData dtype -> string_of_datatype dtype
  (* Function return types, not implemented yet:
  | ReturnFun(f) -> string_of_fun_type f
  *)

(* Function signature: the arguments, comma-separated inside parentheses,
   followed by an arrow and the return type. *)
and string_of_fun_type (FunType (args, ret_type)) =
  let arg_list = string_join "," (List.map string_of_arg args) in
  "(" ^ arg_list ^ ") -> " ^ string_of_return_type ret_type

(* A function reference; only the named form FunId is handled. *)
and string_of_func_item = function
  | FunId name -> name
  (* Inline function items, not implemented yet:
  | FunItem(fun_type, stmts) -> string_of_expr (FunLit(fun_type, stmts))
  *)

(* Render a statement as source text.
   Simple statements end with a semicolon.  Compound statements (if and for)
   open a brace at the end of the header line, print each body statement
   on its own line, and close the brace on a line by itself.  If and for
   bodies share the same layout; the for header previously added a stray
   space before the first body statement only. *)
and string_of_stmt s =
  (* Body statements, one per line. *)
  let block stmts = string_join "\n" (List.map string_of_stmt stmts) in
  match s with
  | Decl d -> string_of_vardecl d
  | Assign v -> v.assign_name ^ " = " ^ string_of_expr v.new_val ^ ";"
  | VecAssign v ->
      string_of_expr (VecAcc (v.vec_name, v.index))
      ^ " = " ^ string_of_expr v.vec_el_val ^ ";"
  | MatAssign v ->
      string_of_expr (MatAcc (v.mat_name, v.row_index, v.col_index))
      ^ " = " ^ string_of_expr v.mat_el_val ^ ";"
  | Return e -> "return " ^ string_of_expr e ^ ";"
  | If (cond, then_stmts, else_stmts) ->
      "if (" ^ string_of_expr cond ^ ") {\n" ^ block then_stmts
      ^ "\n} else {\n" ^ block else_stmts ^ "\n}"
  | For (var, it, stmts) ->
      "for (" ^ var ^ " in " ^ string_of_iterable it ^ ") {\n"
      ^ block stmts ^ "\n}"

(* Render the iterable of a for loop.
   Every form except a bare identifier is printed by building the
   corresponding expression and reusing string_of_expr.  A list literal
   therefore prints in braces with comma separators, exactly as it would as
   an expression; previously its elements were joined by a semicolon and
   newline with no braces, which matched no other printer in this file. *)
and string_of_iterable i =
  match i with
  | ItId s -> s
  | ItListLit expr_list -> string_of_expr (ListLit expr_list)
  | ItMatrixLit m -> string_of_expr (MatrixLit m)
  (* Not implemented:
  | ItMatrixFunDef(expr_1, expr_2, func_item) -> "Matrix((" ^ string_of_expr expr_1 ^ "," ^ string_of_expr expr_2 ^ "), " ^ string_of_func_item func_item ^")"
  *)
  | ItCall (s, e_list) -> string_of_expr (Call (s, e_list))
  | ItAttribute (id_1, id_2) -> string_of_expr (Attribute (id_1, id_2))

(* Variable declaration: the declared type, the name, an equals sign and
   the initializer expression, terminated by a semicolon. *)
and string_of_vardecl v =
  Printf.sprintf "%s %s = %s;"
    (string_of_datatype v.return_type) v.var_name (string_of_expr v.body);;

(* A whole program: its top-level statements separated by blank lines. *)
let string_of_program p =
  p |> List.map string_of_stmt |> string_join "\n\n";;

(* Constructor name of a Parser token, ignoring any payload.  Tokens with
   no case below fall through to the ![INVALID TOKEN] marker. *)
let string_of_token = function
  (* Literals with payloads *)
  | STRING_LIT _ -> "STRING_LIT"
  | CHAR_LIT _ -> "CHAR_LIT"
  (* Arithmetic *)
  | PLUS -> "PLUS"
  | MINUS -> "MINUS"
  | MULTIPLY -> "MULTIPLY"
  | DIVIDE -> "DIVIDE"
  | MODULO -> "MODULO"
  (* Element-wise arithmetic, not implemented:
  | ELEM_PLUS -> "ELEM_PLUS"
  | ELEM_MINUS -> "ELEM_MINUS"
  | ELEM_MULTIPLY -> "ELEM_MULTIPLY"
  | ELEM_DIVIDE -> "ELEM_DIVIDE"
  | ELEM_MODULO -> "ELEM_MODULO"
  *)
  (* Comparison and logic *)
  | EQ -> "EQ"
  | NEQ -> "NEQ"
  | GEQ -> "GEQ"
  | LEQ -> "LEQ"
  | GT -> "GT"
  | LT -> "LT"
  | AND -> "AND"
  | OR -> "OR"
  | XOR -> "XOR"
  | NOT -> "NOT"
  (* Bitwise operators, not supported:
  | BIT_AND -> "BIT_AND"
  | BIT_OR -> "BIT_OR"
  | BIT_XOR -> "BIT_XOR"
  | BIT_LS -> "BIT_LS"
  | BIT_RS -> "BIT_RS"
  | BIT_NOT -> "BIT_NOT"
  *)
  (* Punctuation *)
  | ASSIGN -> "ASSIGN"
  | PRODUCES -> "PRODUCES"
  | PIPE -> "PIPE"
  | LBRACE -> "LBRACE"
  | RBRACE -> "RBRACE"
  | LBRACKET -> "LBRACKET"
  | RBRACKET -> "RBRACKET"
  | LPAREN -> "LPAREN"
  | RPAREN -> "RPAREN"
  | SEMI -> "SEMI"
  | COLON -> "COLON"
  | COMMA -> "COMMA"
  | PERIOD -> "PERIOD"
  (* Keywords *)
  | FOR -> "FOR"
  | IN -> "IN"
  | IF -> "IF"
  | ELSE -> "ELSE"
  | RETURN -> "RETURN"
  (* Type names *)
  | INT -> "INT"
  | FLOAT -> "FLOAT"
  | CHAR -> "CHAR"
  | FUN -> "FUN"
  | STRING -> "STRING"
  | TUP -> "TUP"
  | LIST -> "LIST"
  | MATRIX -> "MATRIX"
  (* Remaining literals and identifiers *)
  | FLOAT_LIT _ -> "FLOAT_LIT"
  | INT_LIT _ -> "INT_LIT"
  | TRUE -> "TRUE"
  | FALSE -> "FALSE"
  | ID _ -> "ID"
  | _ -> "![INVALID TOKEN]";;