(* LOON ast.ml. Written by Chelci, Jack, Niles, Kyle and Habin *)
module L = Llvm

(** Binary operators. Printed by [string_of_op]. *)
type op =
	(* arithmetic operators, plus equality ([Equal] prints as "==") *)
	| Add | Sub | Mult | Div | Equal
	(* remaining relational operators *)
	| Neq | Less | Leq | Greater | Geq
	(* boolean operators *)
	| And | Or

(** Unary prefix operators: arithmetic negation ("-"), logical not ("!")
    and dereference ("*"), as printed by [string_of_uop]. *)
type uop =
    | Neg | Not | Deref

(** LOON types. Only [Pair] is parameterized (by a single [typ]);
    [Array] and [Json] carry no element-type information at this level. *)
type typ =
    | Int | Bool | Void | String | Pair of typ | Char | Array | Json

(** A declaration: a type paired with a variable name. *)
type bind = typ * string

(** Expressions. *)
type expr =
    | Literal of int
    | BoolLit of bool
    | CharLit of char
    | StringLit of string
    | PairLit of expr * expr                  (* two components, printed "k, v" *)
    | Id of string
    | Noexpr                                  (* empty expression, e.g. an omitted for-clause *)
    | Binop of expr * op * expr
    | Unop of uop * expr
	  | Assign of string * expr list * expr   (* variable name, index expressions, assigned value *)
    | Call of string * expr list              (* function name, arguments *)
    | Access of string * expr list            (* variable name, index expressions *)
    | ArrayLit of expr list
    | JsonLit of (expr * expr) list           (* presumably (key, value) entries — confirm against parser *)

(** Statements. *)
type stmt =
    | Block of stmt list
    | Expr of expr
    | If of expr * stmt * stmt             (* condition, then, else; [Block []] means no else *)
    | For of expr * expr * expr * stmt     (* init, condition, step, body *)
    | While of expr * stmt
    | Return of expr

(** A function definition. *)
type func_decl = {
    primitive	: typ;         (* declared return type (printed before the name) *)
    fname 		: string;
    formals 	:	bind list;   (* parameters; string_of_fdecl prints only their names *)
    locals		: bind list;   (* local variable declarations *)
    body		: stmt list;
}

(** A whole program: global variable declarations and function definitions. *)
type program = bind list * func_decl list

(* Pretty-printing functions *)

(** Source-level spelling of a binary operator. *)
let string_of_op op =
  match op with
  (* arithmetic *)
  | Add -> "+"
  | Sub -> "-"
  | Mult -> "*"
  | Div -> "/"
  (* comparison *)
  | Equal -> "=="
  | Neq -> "!="
  | Less -> "<"
  | Leq -> "<="
  | Greater -> ">"
  | Geq -> ">="
  (* logical *)
  | And -> "&&"
  | Or -> "||"

(** Prefix spelling of a unary operator. *)
let string_of_uop u =
  match u with
  | Not -> "!"
  | Neg -> "-"
  | Deref -> "*"

(** Render an expression as source-like text for debugging output.
    Index lists of [Assign] and [Access] are printed as one bracketed
    subscript per index (e.g. [a[i][j]]); an empty index list prints just
    the variable name, so a plain assignment renders as [v = e].
    JSON literals render as [json: {k: v, ...}]. *)
let rec string_of_expr expr =
  (* Render each index expression as its own "[...]" subscript. *)
  let subscripts indices =
    String.concat "" (List.map (fun i -> "[" ^ string_of_expr i ^ "]") indices)
  in
  match expr with
  | Literal(l) -> string_of_int l
  | CharLit(c) -> Char.escaped c
  | StringLit(s) -> s
  | BoolLit(true) -> "true"
  | BoolLit(false) -> "false"
  | PairLit(k, v) -> string_of_expr k ^ ", " ^ string_of_expr v
  | Id(s) -> s
  | Binop(e1, o, e2) ->
      string_of_expr e1 ^ " " ^ string_of_op o ^ " " ^ string_of_expr e2
  | Unop(o, e) -> string_of_uop o ^ string_of_expr e
  | Assign(v, indices, e) ->
      v ^ subscripts indices ^ " = " ^ string_of_expr e
  | Access(id, indices) -> id ^ subscripts indices
  | Call(f, el) ->
      f ^ "(" ^ String.concat ", " (List.map string_of_expr el) ^ ")"
  | Noexpr -> ""
  | ArrayLit(l) ->
      "array: [" ^ String.concat ", " (List.map string_of_expr l) ^ "]"
  | JsonLit(l) ->
      (* ": " inside an entry keeps keys/values distinguishable from the
         ", " that separates entries. *)
      let string_of_entry (key, value) =
        string_of_expr key ^ ": " ^ string_of_expr value
      in
      "json: {" ^ String.concat ", " (List.map string_of_entry l) ^ "}"

(** Render a statement as source-like text, one statement per line.
    Expression statements are terminated with ";" just like [return]
    statements. An [If] whose else branch is [Block []] is printed
    without an [else]. *)
let rec string_of_stmt = function
  | Block(stmts) ->
      "{\n" ^ String.concat "" (List.map string_of_stmt stmts) ^ "}\n"
  | Expr(expr) -> string_of_expr expr ^ ";\n"
  | Return(expr) -> "return " ^ string_of_expr expr ^ ";\n"
  | If(e, s, Block([])) -> "if (" ^ string_of_expr e ^ ")\n" ^ string_of_stmt s
  | If(e, s1, s2) ->
      "if (" ^ string_of_expr e ^ ")\n" ^
      string_of_stmt s1 ^ "else\n" ^ string_of_stmt s2
  | For(e1, e2, e3, s) ->
      "for (" ^ string_of_expr e1 ^ " ; " ^ string_of_expr e2 ^ " ; " ^
      string_of_expr e3 ^ ") " ^ string_of_stmt s
  | While(e, s) -> "while (" ^ string_of_expr e ^ ") " ^ string_of_stmt s

(** Short textual name of a type, used by the pretty-printers.
    The parameter type of a [Pair] is not included. *)
let string_of_typ t =
  match t with
  | Int -> "int"
  | Bool -> "bool"
  | Char -> "char"
  | String -> "string"
  | Void -> "void"
  | Array -> "array"
  | Json -> "json"
  | Pair _ -> "pair"

(** Render a [(type, name)] declaration as ["<type> <name>\n"]. *)
let string_of_vdecl (t, id) =
  String.concat "" [ string_of_typ t; " "; id; "\n" ]

(** Render a function definition: a header line with the return type, name
    and parameter names, followed by a braced body containing the local
    declarations and then the statements. *)
let string_of_fdecl fdecl =
  let header =
    string_of_typ fdecl.primitive ^ " " ^ fdecl.fname
    ^ "(" ^ String.concat ", " (List.map snd fdecl.formals) ^ ")\n"
  in
  let locals_text = String.concat "" (List.map string_of_vdecl fdecl.locals) in
  let body_text = String.concat "" (List.map string_of_stmt fdecl.body) in
  header ^ "{\n" ^ locals_text ^ body_text ^ "}\n"

(** Render a whole program: the global declarations, a blank line, then the
    function definitions separated by newlines. *)
let string_of_program (globals, functions) =
  let globals_text = String.concat "" (List.map string_of_vdecl globals) in
  let functions_text = String.concat "\n" (List.map string_of_fdecl functions) in
  globals_text ^ "\n" ^ functions_text

(** Zero-value expression for each type.
    Note that the zero values for [Array] and [Json] are one-element
    literals, not empty ones. [Void] and [Pair _] have no dedicated zero
    value and fall back to [Literal(0)]. *)
let zero_of_typ = function
  | Int -> Literal(0)
  | Bool -> BoolLit(false)
  | String -> StringLit("")
  | Char -> CharLit(Char.chr 0)
  (* | Pair(p_type) -> PairLit(StringLit(""), Literal(0)) *)
  | Array -> ArrayLit([Literal(0)])
  | Json -> JsonLit([(StringLit(""), Literal(0))])
  (* Listed explicitly instead of a wildcard so that adding a new [typ]
     constructor triggers a non-exhaustive-match warning here. *)
  | Void | Pair _ -> Literal(0)

(** printf conversion for a value whose LLVM type prints as [lltype_name]:
    "i8*" -> "%s", "i8" -> "%c", and every other type (including "i32")
    -> "%d". *)
let fmt_of_lltype lltype_name =
  if lltype_name = "i8*" then "%s"
  else if lltype_name = "i8" then "%c"
  else "%d"

(** Wrapper around array value types: either a single LLVM type ([Val]),
    or, for a nested array, the list of its per-element types ([Val_list]),
    which [get_val_type] indexes positionally. *)
type val_type =
		| Val of L.lltype
		| Val_list of val_type list

(** Check if array type is value or nested array *)
(*let is_val = function
	Val(v) -> ignore(v); true
	| Val_list(v_list) -> ignore(v_list); false
	| _ -> false *)

(** Get the LLVM type of the value at the given index path.
    [indx_list]: list of ints, one index per nesting level, outermost first.
    - Reaching a [Val] returns its type; any remaining indices are ignored.
    - Running out of indices while still on a [Val_list] returns [i8***]
      instead of descending further.
    @raise Failure if an index is past the end of a [Val_list]
    @raise Invalid_argument if an index is negative
    (both from [List.nth]). *)
let rec get_val_type context indx_list = function
  | Val(v) -> v
  | Val_list(v_list) ->
    (match indx_list with
     | [] ->
       (* Debug trace. It goes to stderr so it cannot interleave with
          anything the compiler writes to stdout (NOTE(review): e.g. emitted
          IR, if the driver prints it there). *)
       prerr_endline "GET_VAL: Not accessing further - return i8***";
       L.pointer_type (L.pointer_type (L.pointer_type (L.i8_type context)))
     | this_indx :: rest ->
       get_val_type context rest (List.nth v_list this_indx))

(** Set the type of the value at the given index path.
    [types_lst]: current list of element types at this nesting level.
    [new_type]: replacement type; expected to be a [Val v]. The stored
      element becomes [Val (pointer_type v)].
    [indxs_int_lst]: index path, outermost first; must be non-empty.
    Returns a new list; [types_lst] itself is not modified. An index that is
    out of range at some level leaves that level unchanged. Type errors are
    reported on stderr and a best-effort fallback type is used.
    @raise Failure if [indxs_int_lst] is empty. *)
let rec set_val_type context types_lst new_type indxs_int_lst =
  match indxs_int_lst with
  | [] -> failwith "set_val_type: empty index list"
  | cur_indx :: rem_indxs ->
    (* Rebuild this level, touching only the element at [cur_indx]. *)
    let update i orig_elem_type =
      if i <> cur_indx then orig_elem_type
      else if rem_indxs = [] then begin
        (* Final index: replace this element with a pointer to the new type. *)
        let elem_ty =
          match new_type with
          | Val v -> v
          | Val_list _ ->
            prerr_endline "SET_VAL_TYPE: Error: Reached the final index, but the new type is a nested array rather than a single value type";
            L.i32_type context
        in
        Val (L.pointer_type elem_ty)
      end
      else begin
        (* More indices remain, so the element must itself be a nested array. *)
        let nested =
          match orig_elem_type with
          | Val_list nxt_lst -> nxt_lst
          | Val _ ->
            prerr_endline "SET_VAL_TYPE: Error: Matched with current index, and still have more indexing to do, but value is not indexable";
            []
        in
        Val_list (set_val_type context nested new_type rem_indxs)
      end
    in
    List.mapi update types_lst


(** read JSON value from a string *)
(*val from_string : ?buf:Bi_outbuf.t -> ?fname:string -> ?lnum:int -> string -> json *)
(** read JSON value from a file *)
(*val from_file : ?buf:Bi_outbuf.t -> ?fname:string -> ?lnum:int -> string -> json*)
(** read JSON value from channel *)
(*val from_channel : ?buf:Bi_outbuf.t -> ?fname:string -> ?lnum:int -> in_channel -> json
val from_string  : string     -> json
val from_file    : string     -> json
val from_channel : in_channel -> json
*)
