(* Written by Zachary Salzbank *)

open Ast
open Sast

(* One lexical scope of the translator: the variables and functions declared
   directly in it, plus a link to the enclosing scope.  find_variable and
   find_function search this scope first and then walk the [parent] chain
   outward. *)
type symbol_table = {
  parent : symbol_table option;   (* enclosing scope; None for the global scope *)
  variables : Sast.v_decl list;   (* add_local prepends, so newest declaration first *)
  functions: Sast.func_decl list; (* functions declared in this scope *)
}

(* Environment threaded through translation.  At present it carries only the
   current (innermost) scope; the "(*env with*)" comments at its construction
   sites suggest more fields were anticipated. *)
type trans_env = {
  scope : symbol_table; (* innermost scope visible at this point *)
}

(* Resolve the variable called [name], searching [scope] first and then each
   enclosing scope in turn; the first (most recently added) match wins.
   Raises Failure ("variable <name> not defined") if no scope declares it. *)
let rec find_variable (scope : symbol_table) name =
  let rec scan = function
      [] ->
        (match scope.parent with
            None -> raise (Failure("variable " ^ name ^ " not defined"))
          | Some(outer) -> find_variable outer name)
    | v :: rest -> if v.vvname = name then v else scan rest
  in
  scan scope.variables

(* Whether [name] resolves to a variable in [scope] or any enclosing scope.
   The Failure raised by find_variable on a miss becomes [false]. *)
let var_exists scope name =
  try ignore (find_variable scope name); true
  with Failure _ -> false

(* Validate that an integer literal fits in a signed 16-bit range
   [-32768, 32767] and return it unchanged.
   Raises Failure ("invalid value for integer") when out of range. *)
let integer_check i =
  let lowest = -32768 and highest = 32767 in
  if lowest <= i && i <= highest then i
  else raise (Failure("invalid value for integer"))

(* Enforce the 16-character limit on identifiers; returns [name] unchanged.
   Raises Failure if the identifier is longer than 16 characters. *)
let id_check name =
  let max_len = 16 in
  if String.length name <= max_len then name
  else raise (Failure("identifiers must be 16 characters or less"))

(* Resolve the function called [name], searching [scope] first and then each
   enclosing scope in turn; the first match in list order wins.
   Raises Failure ("function <name> not defined") if no scope declares it. *)
let rec find_function (scope : symbol_table) name =
  let rec scan = function
      [] ->
        (match scope.parent with
            None -> raise (Failure("function " ^ name ^ " not defined"))
          | Some(outer) -> find_function outer name)
    | f :: rest -> if f.ffname = name then f else scan rest
  in
  scan scope.functions

(* Find the function named "print" whose first formal has type [t],
   searching [scope] first and then each enclosing scope.
   Raises Failure ("function print(<t>) not defined") when none is visible. *)
let rec find_print (scope : symbol_table) t =
  (* A "print" entry without formals simply does not match, instead of
     escaping with List.hd's Failure "hd". *)
  let prints_type f =
    f.ffname = "print"
    && (match f.fformals with
          first :: _ -> first.vvtype = t
        | [] -> false)
  in
  try
    List.find prints_type scope.functions
  with Not_found ->
    match scope.parent with
      Some(parent) -> find_print parent t
    | None -> raise (Failure("function print(" ^ string_of_obj_type t ^ ") not defined"))

(* Whether a function called [name] is declared directly in [scope].
   Unlike var_exists, enclosing scopes are not consulted. *)
let func_exists scope name =
  let has_name f = f.ffname = name in
  List.exists has_name scope.functions

(* Whether a value of type [rt] may be stored where type [lt] is expected:
   the types must be equal, except that null may be stored in any node. *)
let assign_allowed lt rt =
  lt = rt
  || (match lt with
        NodeType(_) -> rt = NullType
      | _ -> false)

(* Check that the typed expression [rval], an (expression, type) pair, may be
   assigned to type [lt].  Returns [rval] unchanged on success.
   Raises Failure naming both types otherwise. *)
let can_assign lt rval =
  let rt = snd rval in
  if not (assign_allowed lt rt) then
    raise (Failure("type " ^ string_of_obj_type rt ^ " cannot be put into type " ^ string_of_obj_type lt))
  else
    rval

(* Type obtained by looking inside a node: NodeType(x) yields x.
   Raises Failure for any non-node type. *)
let inner_type = function
    NodeType(elt) -> elt
  | other -> raise (Failure("accessor cannot be used on " ^ string_of_obj_type other))


(* Whether [t] is some node type, whatever its element type. *)
let is_node t =
  match t with
    NodeType(_) -> true
  | _ -> false

(* Type-check the binary operation [lval op rval] and return its result type.
   [lval] and [rval] are (expression, type) pairs.
   - Add, Sub: both int or both char; the result has the operand type.
   - Mult, Div: both int; the result is int.
   - Equal, Neq: matching int, char or boolean operands, or a node comparison
     (two identical node types, or a node against null); the result is boolean.
   - Less, Leq, Greater, Geq: both int or both char; the result is boolean.
   - BoolAnd, BoolOr: both boolean; the result is boolean.
   Raises Failure when [op] does not accept the operand types. *)
let can_op lval op rval =
  let (_, lt) = lval
  and (_, rt) = rval in
  let type_match = (lt = rt) in
  let int_or_char = (lt = IntType || lt = CharType) in
  (* Structural [=] rather than physical [==]: these compare type values. *)
  let node =
    (is_node lt && rt = NullType)
    || (lt = NullType && is_node rt)
    || (type_match && is_node lt)
  in
  let allowed, result_type = match op with
      Ast.Add | Ast.Sub -> (type_match && int_or_char), lt
    | Ast.Mult | Ast.Div -> (type_match && lt = IntType), lt
    | Ast.Equal | Ast.Neq ->
        (type_match && (int_or_char || lt = BooleanType)) || node, BooleanType
    | Ast.Less | Ast.Leq | Ast.Greater | Ast.Geq ->
        (type_match && int_or_char), BooleanType
    | Ast.BoolAnd | Ast.BoolOr -> (type_match && lt = BooleanType), BooleanType
  in
  if allowed then
    result_type
  else
    raise (Failure("operator " ^ string_of_op op ^ " cannot be used on types " ^
      string_of_obj_type lt ^ " and " ^ string_of_obj_type rt))

(* Type-check the program and translate it from the AST into the SAST.
   [globals] is the list of global variable declarations and [funcs] the list
   of function declarations.  Both are reversed before being folded, so they
   presumably arrive from the parser in reverse source order (TODO: confirm
   against the parser).
   Returns [(globals', funcs')]: the translated global variables, and every
   function visible at global scope, which means the user functions plus the
   built-in print overloads for int, char and boolean.  Each function has been
   checked by [validate_func].
   Raises Failure with a descriptive message on any semantic error, or if no
   function named "root" is defined. *)
let translate (globals, funcs) =
  (* lvalues, accessors and expressions are mutually recursive because a
     child accessor's index is itself an expression. *)
  let rec trans_lval env = function
      Ast.Id(n) ->
        let vdecl = find_variable env.scope n in
        Sast.Id(vdecl), vdecl.vvtype
    | Ast.Unop(lval, op) ->
        let l, t = trans_lval env lval in
        (* both accessors need a node operand; inner_type raises otherwise *)
        let inner = inner_type t in
        let newt = match op with
            Ast.Child(_) -> t
          | Ast.ValueOf -> inner
        in
        Sast.Unop((l, t), trans_unop env op), newt
  and trans_unop env = function
      Ast.Child(e) ->
        let e, t = trans_expr env e in
        if t = IntType then
          Sast.Child(e, t)
        else
          raise (Failure("index must be of type int"))
    | Ast.ValueOf -> Sast.ValueOf
  and trans_expr env = function
      Ast.Literal(l) ->
        (match l with
            Integer(i) -> Literal(Integer(integer_check i)), IntType
          | Character(c) -> Literal(Character(c)), CharType
          | Boolean(b) -> Literal(Boolean(b)), BooleanType
          | Null -> Literal(Null), NullType)
    | Ast.Node(e) ->
        let e, t = trans_expr env e in
        Sast.Node(e), NodeType(t)
    | Ast.LValue(l) ->
        let lv, t = trans_lval env l in
        Sast.LValue(lv, t), t
    | Ast.Binop(e1, op, e2) ->
        let e1 = trans_expr env e1
        and e2 = trans_expr env e2 in
        let rtype = can_op e1 op e2 in
        Sast.Binop(e1, op, e2), rtype
    | Ast.Call(n, a) ->
        let args = List.map (fun s -> trans_expr env s) a in
        (* print is overloaded on the type of its first argument; a call with
           no arguments cannot select an overload, so report it as an arity
           error rather than crashing in List.hd. *)
        let fdecl =
          if n = "print" then
            (match args with
                (_, t) :: _ -> find_print env.scope t
              | [] -> raise (Failure("invalid number of arguments")))
          else
            find_function env.scope n
        in
        let types = List.map (fun v -> v.vvtype) fdecl.fformals in
        (* List.map2 raises Invalid_argument when the lengths differ *)
        let checked_args =
          try List.map2 can_assign types args
          with Invalid_argument _ ->
            raise (Failure("invalid number of arguments"))
        in
        Sast.Call(fdecl, checked_args), fdecl.fftype
    | Ast.Assign(lv, e) ->
        let lval, t = trans_lval env lv in
        let aval = trans_expr env e in
        Sast.Assign((lval, t), can_assign t aval), t
    | Ast.Neg(e) ->
        let e, t = trans_expr env e in
        if t = IntType then
          Sast.Neg(e, t), t
        else
          raise (Failure("cannot negate type " ^ string_of_obj_type t))
    | Ast.Bang(e) ->
        let e, t = trans_expr env e in
        if t = BooleanType then
          Sast.Bang(e, t), t
        else
          raise (Failure("cannot get logical opposite of type " ^ string_of_obj_type t))
    | Ast.Noexpr ->
        Sast.Noexpr, VoidType
  in
  (* Declare [v] in the innermost scope of [env] and return the extended
     environment.  Raises Failure if the name is too long, if the name is
     already visible in this or any enclosing scope, or if the default value
     is not assignable to the declared type. *)
  let add_local env v =
    let evalue =
      if var_exists env.scope (id_check v.vname) then
        raise (Failure("redeclaration of " ^ v.vname))
      else
        match v.vdefault with
          None -> None
        | Some(e) -> Some(can_assign v.vtype (trans_expr env e))
    in
    let new_v = { vvname = v.vname; vvtype = v.vtype; vvdefault = evalue } in
    { scope = { env.scope with variables = new_v :: env.scope.variables } }
  in
  let rec trans_stmt env = function
      Ast.Block(v, s) ->
        (* every block opens a fresh scope nested in the current one *)
        let scope' = { parent = Some(env.scope); variables = []; functions = [] } in
        let block_env = List.fold_left add_local { scope = scope' } (List.rev v) in
        let s' = List.map (fun s -> trans_stmt block_env s) s in
        Sast.Block(block_env.scope.variables, s')
    | Ast.Expr(e) ->
        Sast.Expr(trans_expr env e)
    | Ast.Return(e) ->
        Sast.Return(trans_expr env e)
    | Ast.If(e, s1, s2) ->
        let e' = trans_expr env e in
        Sast.If(can_assign BooleanType e', trans_stmt env s1, trans_stmt env s2)
    | Ast.While(e, s) ->
        let e' = trans_expr env e in
        Sast.While(can_assign BooleanType e', trans_stmt env s)
  in
  (* Register the signature of [f] (name, return type, formals) with an
     empty, unparsed body.  Raises Failure if the name already belongs to a
     visible variable or to a function in the current scope. *)
  let add_func env f =
    let new_f =
      if var_exists env.scope f.fname || func_exists env.scope f.fname then
        raise (Failure("redeclaration of " ^ f.fname))
      else
        let scope' = { parent = Some(env.scope); variables = []; functions = [] } in
        let env' = List.fold_left add_local { scope = scope' } (List.rev f.formals) in
        {
          fftype = f.ftype;
          ffname = id_check f.fname;
          fformals = env'.scope.variables;
          flocals = [];
          fbody = [];
          parsed = false;
        }
    in
    { scope = { env.scope with functions = new_f :: env.scope.functions } }
  in
  (* Translate the locals and body of [f] and replace its unparsed
     declaration in [env] with the parsed one. *)
  let trans_func env (f : Ast.func_decl) =
    let sf = find_function env.scope f.fname in
    (* Compare names structurally: the old physical [!=] only worked because
       of how the strings happened to be shared. *)
    let functions' = List.filter (fun g -> g.ffname <> sf.ffname) env.scope.functions in
    let scope' = { parent = Some(env.scope); variables = sf.fformals; functions = [] } in
    let env' = { scope = scope' } in
    let formals' = env'.scope.variables in
    let env' = List.fold_left add_local env' f.locals in
    (* The scope now holds the locals prepended to the formals; keep only
       the locals. *)
    let is_local v = not (List.exists (fun fv -> fv.vvname = v.vvname) formals') in
    let locals' = List.filter is_local env'.scope.variables in
    let body' = List.map (fun s -> trans_stmt env' s) f.body in
    let new_f = {
      sf with
      fformals = formals';
      flocals = locals';
      fbody = body';
      parsed = true;
    } in
    { scope = { env.scope with functions = new_f :: functions' } }
  in
  (* Check a translated function and return it unchanged:
     - every return statement, including those nested in blocks, ifs and
       whiles, must produce a value assignable to the declared return type;
     - a non-void function must have at least one top-level return;
     - a function may take at most 8 formals. *)
  let validate_func f =
    let check_return e =
      if not (assign_allowed f.fftype (snd e)) then
        raise (Failure(f.ffname ^ " must return type " ^
                       string_of_obj_type f.fftype ^
                       ", not " ^ string_of_obj_type (snd e)))
    in
    let rec check_stmt = function
        Sast.Return(e) -> check_return e
      | Sast.Block(_, stmts) -> List.iter check_stmt stmts
      | Sast.If(_, s1, s2) -> check_stmt s1; check_stmt s2
      | Sast.While(_, s) -> check_stmt s
      | Sast.Expr(_) -> ()
    in
    let is_return = function
        Sast.Return(_) -> true
      | _ -> false
    in
    List.iter check_stmt f.fbody;
    if not (List.exists is_return f.fbody) && f.fftype <> VoidType then
      raise (Failure(f.ffname ^ " must have a return type of " ^ string_of_obj_type f.fftype))
    else if List.length f.fformals > 8 then
      raise (Failure(f.ffname ^ " must have 8 or fewer formals"))
    else
      f
  in
  (* Built-in print overload taking a single argument of type [t]. *)
  let make_print t =
    {
      fftype = VoidType;
      ffname = "print";
      fformals = [{
        vvname = "val";
        vvtype = t;
        vvdefault = None;
      }];
      flocals = [];
      fbody = [];
      parsed = false;
    }
  in
  let global_scope = {
    parent = None;
    variables = [];
    functions = List.map make_print [IntType; CharType; BooleanType];
  } in
  let genv = { scope = global_scope } in
  (* Declare the globals first, then every function signature so that bodies
     may call functions declared later, and only then translate the bodies. *)
  let genv = List.fold_left add_local genv (List.rev globals) in
  let genv = List.fold_left add_func genv (List.rev funcs) in
  let genv = List.fold_left trans_func genv (List.rev funcs) in
  if func_exists genv.scope "root" then
    (genv.scope.variables, List.map validate_func genv.scope.functions)
  else
    raise (Failure("no root function defined"))
