open Ast
open Parser

(* [string_join delim words] concatenates [words], inserting [delim] between
   consecutive elements: [] gives "", a single word is returned unchanged.
   Delegates to [String.concat], which builds the result in one pass instead
   of the quadratic, non-tail-recursive chain of [^] a hand-written recursion
   produces. *)
let string_join delim words = String.concat delim words;;

(* Spelling of each binary operator, used when printing [BinOp] nodes. *)
let string_of_binop = function
  (* arithmetic *)
  | Add -> "+"
  | Sub -> "-"
  | Mul -> "*"
  | Div -> "/"
  | Mod -> "%"
  (* dotted El* variants *)
  | ElAdd -> ".+"
  | ElSub -> ".-"
  | ElMul -> ".*"
  | ElDiv -> "./"
  | ElMod -> ".%"
  (* comparisons *)
  | Eq -> "=="
  | Neq -> "!="
  | Leq -> "<="
  | Geq -> ">="
  (* keyword operators *)
  | And -> "and"
  | Or -> "or"
  | Xor -> "xor"
  (* bit operators; NOTE(review): BitXor prints as "*|" — confirm this
     matches the lexer's spelling of the BIT_XOR token *)
  | BitAnd -> "&&"
  | BitOr -> "||"
  | BitXor -> "*|"
  | BitLs -> "<<"
  | BitRs -> ">>";;

(* Spelling of each unary operator, used when printing [UnOp] nodes. *)
let string_of_unop = function
  | Neg -> "-"
  | BitNot -> "!"
  | Not -> "not"
  | ToFl -> "tofl"
  | ToInt -> "toint"
  | Print -> "print";;

(* Render a type annotation. Container types print their element type in
   angle brackets, e.g. [List<Int>] or [Matrix<Float>]. *)
let rec string_of_datatype dt =
  let parameterized head inner = head ^ "<" ^ string_of_datatype inner ^ ">" in
  match dt with
  | Int -> "Int"
  | Float -> "Float"
  | Char -> "Char"
  | Bool -> "Bool"
  | String -> "String"
  | Struct -> "Struct"
  | Tup -> "Tup"
  | Fun -> "Fun"
  | List inner -> parameterized "List" inner
  | Matrix inner -> parameterized "Matrix" inner
  | NoType -> "NoType";;

(* Render a numeric literal using the stdlib int/float formatters
   (so a float like 1.0 prints as "1."). *)
let string_of_scalar = function
  | LitFloat f -> string_of_float f
  | LitInt i -> string_of_int i;;

(* Mutually recursive printers for the expression-level AST.
   [string_of_expr] is the main entry point; the other functions handle the
   argument, struct-entry, function-item, statement and iterable nodes that
   can appear inside expressions. *)

(* A function argument: [name : Type], or [lhs = default : Type]. *)
let rec string_of_arg a =
  match a with
  | ArgId (s, dtype) -> s ^ " : " ^ string_of_datatype dtype
  | ArgAssign (expr_1, dtype, expr_2) ->
      string_of_expr expr_1 ^ " = " ^ string_of_expr expr_2 ^ " : "
      ^ string_of_datatype dtype

(* A struct-literal entry: a bare name or a [lhs = rhs] initializer. *)
and string_of_struct_entry e =
  match e with
  | IdItem s -> s
  | IdAssign (expr_1, expr_2) ->
      string_of_expr expr_1 ^ " = " ^ string_of_expr expr_2

and string_of_expr e =
  match e with
  | NumLit scalar -> string_of_scalar scalar
  | CharLit c -> Char.escaped c
  | StringLit s -> s
  | BoolLit b -> string_of_bool b
  | StructLit entry_list ->
      "{" ^ string_join "\n" (List.map string_of_struct_entry entry_list) ^ "}"
  | TupLit exprs -> "(" ^ string_join "," (List.map string_of_expr exprs) ^ ")"
  | ListLit exprs -> "[" ^ string_join "," (List.map string_of_expr exprs) ^ "]"
  | FunLit (args, datatype, stmts) ->
      (* The body gets its own braces, as in If/For. Without them the return
         type and the first statement ran together (e.g. "Intreturn x"). *)
      string_join "," (List.map string_of_arg args)
      ^ " -> " ^ string_of_datatype datatype
      ^ " {\n" ^ string_join "\n" (List.map string_of_stmt stmts) ^ "\n}"
  | MatrixLit expr_rows ->
      (* Columns are separated by ',' and rows by ';'. *)
      let string_of_row row = string_join "," (List.map string_of_expr row) in
      "[" ^ string_join ";" (List.map string_of_row expr_rows) ^ "]"
  | MatrixFunDef (num_rows, num_cols, func_item) ->
      "Matrix(" ^ string_of_scalar num_rows ^ "," ^ string_of_scalar num_cols
      ^ "," ^ string_of_func_item func_item ^ ")"
  | BinOp (expr_1, expr_2, bop) ->
      string_of_expr expr_1 ^ " " ^ string_of_binop bop ^ " "
      ^ string_of_expr expr_2
  | UnOp (expr, unop) -> string_of_unop unop ^ " " ^ string_of_expr expr
  | Id s -> s
  | Assign (expr_1, expr_2) ->
      string_of_expr expr_1 ^ " = " ^ string_of_expr expr_2
  | Attribute (expr, s) -> string_of_expr expr ^ "." ^ s
  | Call (s, exprs) -> s ^ string_of_expr (TupLit exprs)
  | Pipe (expr_1, expr_2) ->
      string_of_expr expr_1 ^ " | " ^ string_of_expr expr_2

(* A function reference: either a name or an inline function literal. *)
and string_of_func_item f =
  match f with
  | FunId s -> s
  | FunItem (args, datatype, stmts) ->
      string_of_expr (FunLit (args, datatype, stmts))

and string_of_stmt s =
  match s with
  | Expr e -> string_of_expr e
  | Return e -> "return " ^ string_of_expr e
  | If (expr, stmt_1, stmt_2) ->
      "if (" ^ string_of_expr expr ^ ") {\n" ^ string_of_stmt stmt_1
      ^ "\n} else {\n" ^ string_of_stmt stmt_2 ^ "\n}"
  | For (var, it, stmt) ->
      "for " ^ var ^ " in " ^ string_of_iterable it ^ ": {\n"
      ^ string_of_stmt stmt ^ "\n}"
  | Block stmts -> string_join "\n" (List.map string_of_stmt stmts)

(* The thing a [for] loop ranges over. Where an iterable has an expression
   counterpart, delegate to it so both print the same way. *)
and string_of_iterable i =
  match i with
  | ItId s -> s
  | ItListLit expr_list -> string_of_expr (ListLit expr_list)
  | ItMatrixLit m -> string_of_expr (MatrixLit m)
  | ItMatrixFunDef (scalar_1, scalar_2, func_item) ->
      (* NOTE(review): this prints "Matrix((r,c), f)" while MatrixFunDef
         prints "Matrix(r,c,f)"; confirm which matches the surface syntax. *)
      "Matrix((" ^ string_of_scalar scalar_1 ^ "," ^ string_of_scalar scalar_2
      ^ "), " ^ string_of_func_item func_item ^ ")"
  | ItCall (s, e_list) -> string_of_expr (Call (s, e_list));;


(* Render a top-level declaration as [name (Type) = body], with exactly one
   space on each side of the '='. *)
let string_of_vardecl v =
  v.var_name ^ " (" ^ string_of_datatype v.return_type ^ ") = "
  ^ string_of_expr v.body;;

(* Render a whole program: each declaration printed in turn, with a blank
   line between consecutive declarations. *)
let string_of_program p =
  p |> List.map string_of_vardecl |> string_join "\n\n";;

(* Constructor name of a parser token as a string. Tokens that carry a
   payload (literals, identifiers) print only their kind, not the payload. *)
let string_of_token t =
  match t with
  | STRING_LIT _ -> "STRING_LIT"
  | CHAR_LIT _ -> "CHAR_LIT"
  | PLUS -> "PLUS"
  | MINUS -> "MINUS"
  | MULTIPLY -> "MULTIPLY"
  | DIVIDE -> "DIVIDE"
  | MODULO -> "MODULO"
  | ELEM_PLUS -> "ELEM_PLUS"
  | ELEM_MINUS -> "ELEM_MINUS"
  | ELEM_MULTIPLY -> "ELEM_MULTIPLY"
  | ELEM_DIVIDE -> "ELEM_DIVIDE"
  | ELEM_MODULO -> "ELEM_MODULO"
  | EQ -> "EQ"
  | NEQ -> "NEQ"
  | GEQ -> "GEQ"
  | LEQ -> "LEQ"
  | AND -> "AND"
  | OR -> "OR"
  | XOR -> "XOR"
  | NOT -> "NOT"
  | BIT_AND -> "BIT_AND"
  | BIT_OR -> "BIT_OR"
  | BIT_XOR -> "BIT_XOR"
  | BIT_LS -> "BIT_LS"
  | BIT_RS -> "BIT_RS"
  | BIT_NOT -> "BIT_NOT"
  | TOFL -> "TOFL"
  | TOINT -> "TOINT"
  | PRINT -> "PRINT"
  | ASSIGN -> "ASSIGN"
  | PRODUCES -> "PRODUCES"
  | PIPE -> "PIPE"
  | LBRACE -> "LBRACE"
  | RBRACE -> "RBRACE"
  | LBRACKET -> "LBRACKET"
  | RBRACKET -> "RBRACKET"
  | LPAREN -> "LPAREN"
  | RPAREN -> "RPAREN"
  | SEMI -> "SEMI"
  | COLON -> "COLON"
  | COMMA -> "COMMA"
  | PERIOD -> "PERIOD"
  | FOR -> "FOR"
  | IN -> "IN"
  | IF -> "IF"
  | ELSE -> "ELSE"
  | RETURN -> "RETURN"
  | INT -> "INT"
  | FLOAT -> "FLOAT"
  | CHAR -> "CHAR"
  | FUN -> "FUN"
  | STRUCT -> "STRUCT"
  | STRING -> "STRING"
  | TUP -> "TUP"
  | LIST -> "LIST"
  | MATRIX -> "MATRIX"
  | FLOAT_LIT _ -> "FLOAT_LIT"
  | INT_LIT _ -> "INT_LIT"
  | TRUE -> "TRUE"
  | FALSE -> "FALSE"
  | ID _ -> "ID"
  | MAP -> "MAP"
  | FREAD -> "FREAD"
  | FWRITE -> "FWRITE"
  | RANGE -> "RANGE"
  | EOF -> "EOF"
  (* Fallback for any token constructor not listed above.
     NOTE(review): the token type is defined in Parser, which is not visible
     here. If every constructor is already listed, this arm is unused and can
     be removed so the compiler flags newly added tokens. *)
  | _ -> "![INVALID TOKEN]";;



