(* signed: Yanlin Duan, Emily Meng *)

open Ast
open Sast

(* Immutable maps keyed by strings; used below for the function-declaration table. *)
module StringMap = Map.Make(String)

(* Mutable hash tables keyed by strings: keys are compared structurally and
   hashed with the generic polymorphic hash. Used for scopes and structs. *)
module StringHash = Hashtbl.Make (struct
  type t = string
  let equal (a : t) (b : t) = (a = b)
  let hash (key : t) = Hashtbl.hash key
end)

(* Semantic checking of a program: each function is checked in turn. Returns the list of checked (Sast) function declarations if successful; raises Failure as soon as something is wrong. *)

(* A lexical scope: the variables declared directly in it, plus a link to the
   enclosing scope ([None] for a function's outermost scope).
   Each varMap entry is the tuple
     (mutability,            -- Immut / Mut
      initializer sexpr,     -- the declaration's expression, or SId for a formal
      declared type,
      owned?,                -- false once the value has been moved out
      has immutable borrow?,
      has mutable borrow?) *)
type symbol_table = {
  parent: symbol_table option;
  varMap: (borrowType*sexpr*typ*bool*bool*bool) StringHash.t; 
}

(* Context threaded through statement and expression checking. *)
type env = {
  scope : symbol_table;   (* innermost scope in effect *)
  return_type : typ;      (* declared return type of the function being checked *)
  in_for : bool;          (* NOTE(review): never set to true anywhere in this file *)
  in_while : bool;        (* true while checking a while-loop condition/body *)
}


(* Check every function declaration of a program.
   Returns one checked Sast function declaration per element of [functions],
   in the same order; raises [Failure] on the first error found. *)
let check (functions) =

  (* Struct definitions seen so far, keyed by struct name (value: field list).
     Created once and shared by all functions, so a struct defined while
     checking one function is visible while checking the functions after it. *)
  let structMap = StringHash.create 20 in

  (* Check a single function [func] and convert it to its Sast form. *)
  let check_function func =
    (* Look [name] up in [scope], then in each enclosing scope in turn.
       Returns the variable's entry paired with the scope that declares it;
       raises [Failure] if no enclosing scope declares [name]. *)
    let rec findVariable scope name =
      if StringHash.mem scope.varMap name then
        (StringHash.find scope.varMap name, scope)
      else
        match scope.parent with
          None -> raise (Failure ("undeclared identifier " ^ name))
        | Some(outer) -> findVariable outer name
    in

  (* Raise [Failure] if [name] has been moved out of (its ownership flag is
     false), or if no enclosing scope declares it. *)
  let rec checkOwnership scope name =
    if StringHash.mem scope.varMap name then begin
      let (_, _, _, owned, _, _) = StringHash.find scope.varMap name in
      if not owned then raise (Failure ("use of moved value: " ^ name))
    end else
      match scope.parent with
        None -> raise (Failure ("undeclared identifier " ^ name))
      | Some(outer) -> checkOwnership outer name
  in

  (* Raise [Failure] if [name] is already declared in [scope] or in any
     enclosing scope; otherwise return unit. Called before declaring a
     struct-typed variable. *)
  let rec findDupVar scope name =
    if StringHash.mem scope.varMap name then
      (* the name IS bound here: report a duplicate, not an undeclared id *)
      raise (Failure ("duplicate variable " ^ name))
    else
      match scope.parent with
        Some(parent) -> findDupVar parent name
      | None -> ()
  in
    
  (* Mark [name] as moved by clearing the ownership flag in the scope that
     declares it; the rest of the entry is kept. Raises [Failure] if no
     enclosing scope declares [name]. *)
  let rec toggleOwnership scope name =
    if StringHash.mem scope.varMap name then begin
      let (mut, init, ty, _, immut_b, mut_b) = StringHash.find scope.varMap name in
      StringHash.replace scope.varMap name (mut, init, ty, false, immut_b, mut_b)
    end else
      match scope.parent with
        None -> raise (Failure ("undeclared identifier " ^ name))
      | Some(outer) -> toggleOwnership outer name
  in
  
  (* Move out of a plain variable: it must still be owned, and is marked as
     moved afterwards. Any other expression is left untouched. *)
  let move sexpr env =
    match sexpr with
    | SId(name, _) ->
      checkOwnership env.scope name;
      toggleOwnership env.scope name
    | _ -> ()
  in

  (* Raise [Failure] if [list] contains some element twice. The message is
     [exceptf x ^ " duplicate variable"] for the smallest repeated [x]
     (the list is sorted before adjacent elements are compared). *)
  let report_duplicate exceptf list =
    let rec scan = function
      | a :: (b :: _ as rest) ->
        if a = b then raise (Failure (exceptf a ^ " duplicate variable"))
        else scan rest
      | _ -> ()
    in
    scan (List.sort compare list)
  in

  (* Raise [Failure] if the binding [(n, t)] has void type; [exceptf n]
     supplies the start of the message. A separator is inserted so the
     message no longer fuses words (was e.g. "... in mainvoid type"). *)
  let check_not_void exceptf = function
      (n, DataT(VoidT)) -> raise (Failure (exceptf n ^ ": void type"))
    | _ -> ()
  in

  (* Return [lvaluet] when it equals [rvaluet] (structurally); otherwise
     raise the caller-supplied exception [err]. *)
  let check_assign lvaluet rvaluet err =
    if lvaluet <> rvaluet then raise err;
    lvaluet
  in

  (* Return [formalt] when it equals [actualt]; otherwise raise [Failure].
     Used to compare struct field names and field types, and a variable's
     declared type against the type of its expression. *)
  let check_fields formalt actualt =
    match formalt = actualt with
    | true -> formalt
    | false -> raise (Failure ("illegal struct field found"))
  in

  (* For a plain variable, return its name paired with its declared
     mutability; any other expression is rejected with [Failure]. *)
  let findMutability scope sexpr =
    match sexpr with
    | SId(name, _) ->
      let ((mutability, _, _, _, _, _), _) = findVariable scope name in
      (name, mutability)
    | _ -> raise (Failure ("deal with the mutability check later!"))
  in

  (* Name of a plain-variable expression; any other expression is rejected
     with [Failure]. *)
  let findName sexpr =
    match sexpr with
    | SId(name, _) -> name
    | _ -> raise (Failure ("findName not implemented yet!"))
  in

  (* Overwrite both borrow flags of [name] in the scope that declares it,
     keeping the rest of its entry. Raises [Failure] if [name] is not
     declared in any enclosing scope. *)
  let rec toggleBorrow scope name hasImmutBorrow hasMutBorrow =
    if StringHash.mem scope.varMap name then begin
      let (mut, init, ty, owned, _, _) = StringHash.find scope.varMap name in
      StringHash.replace scope.varMap name (mut, init, ty, owned, hasImmutBorrow, hasMutBorrow)
    end else
      match scope.parent with
        None -> raise (Failure ("undeclared identifier " ^ name))
      | Some(outer) -> toggleBorrow outer name hasImmutBorrow hasMutBorrow
  in

  (* Type-check [operand1 operator operand2] (operands already converted)
     and return the typed SBinop; raise [Failure] when the operator does not
     accept the operand types.
     - Add/Sub/Mult: int or float operands; an int/float mix has float type
       and an int literal in such a mix is rewritten as a float literal.
     - Div/Mod: same operand types (no literal rewrite); a literal zero
       divisor is rejected.
     - Equal/Neq/Less/Leq/Greater/Geq: two bools, two ints or two floats;
       the result is bool.
     - And/Or: two bools.  AS: always rejected.
     - Assign: the left side must be an lvalue of the right side's type and
       must be mutable; a plain-variable right side must still be owned.
       The result has the right side's type. *)
  let check_binop_type env operand1 operator operand2 =
    (* shared "unsupported operands" error; always " and " between operands *)
    let unsupported prefix =
      raise (Failure (string_of_op operator ^ prefix ^
                      string_of_sexpr operand1 ^ " and " ^ string_of_sexpr operand2))
    in
    match operator with
      Add | Sub | Mult ->
        (match (get_type_from_sexpr operand1, get_type_from_sexpr operand2) with
            (DataT(IntT), DataT(IntT))     -> SBinop(operand1, operator, operand2, DataT(IntT))
          | (DataT(FloatT), DataT(FloatT)) -> SBinop(operand1, operator, operand2, DataT(FloatT))
          | (DataT(IntT), DataT(FloatT))   ->
              (* widen an int literal on the left to a float literal *)
              (match operand1 with
                  SIntLit(n, _) -> SBinop(SFloatLit(float_of_int n, DataT(FloatT)), operator, operand2, DataT(FloatT))
                | _             -> SBinop(operand1, operator, operand2, DataT(FloatT)))
          | (DataT(FloatT), DataT(IntT))   ->
              (* widen an int literal on the right to a float literal *)
              (match operand2 with
                  SIntLit(n, _) -> SBinop(operand1, operator, SFloatLit(float_of_int n, DataT(FloatT)), DataT(FloatT))
                | _             -> SBinop(operand1, operator, operand2, DataT(FloatT)))
          | _ -> unsupported " does not support operations between ")
    | Div | Mod ->
        (match (get_type_from_sexpr operand1, get_type_from_sexpr operand2) with
            (DataT(IntT), DataT(IntT)) ->
              (match operand2 with
                  SIntLit(0, _) -> raise (Failure("division by zero!"))
                | _             -> SBinop(operand1, operator, operand2, DataT(IntT)))
          | (DataT(FloatT), DataT(FloatT))
          | (DataT(IntT), DataT(FloatT)) | (DataT(FloatT), DataT(IntT)) ->
              (match operand2 with
                  SIntLit(0, _) | SFloatLit(0.0, _) -> raise (Failure("division by zero!"))
                | _                                 -> SBinop(operand1, operator, operand2, DataT(FloatT)))
          | _ -> unsupported " does not support operations between ")
    | Equal | Neq | Less | Leq | Greater | Geq ->
        (match (get_type_from_sexpr operand1, get_type_from_sexpr operand2) with
            (DataT(BoolT), DataT(BoolT))
          | (DataT(IntT), DataT(IntT))
          | (DataT(FloatT), DataT(FloatT)) -> SBinop(operand1, operator, operand2, DataT(BoolT))
          | _ -> unsupported " does not support operations between these operators ")
    | And | Or ->
        (match (get_type_from_sexpr operand1, get_type_from_sexpr operand2) with
            (DataT(BoolT), DataT(BoolT)) -> SBinop(operand1, operator, operand2, DataT(BoolT))
          | _ -> unsupported " does not support operations between these operators ")
    | AS -> raise (Failure("AS not supported"))
    | Assign ->
        (match operand1 with
          | SNoexpr | SIntLit(_,_) | SBoolLit(_,_) | SFloatLit(_,_) | SCharLit(_,_) | SArrayLit(_,_,_)
          | SStringLit(_,_) | SCast(_,_) | SCall(_,_,_) | SBinop(_,_,_,_)
          | SUnop(Neg,_,_) | SUnop(Not,_,_) | SUnop(Borrow(_),_,_) ->
              raise (Failure(string_of_sexpr operand1 ^ " cannot be used as lvalue!"))
          | _ ->
              let t = get_type_from_sexpr operand1 in
              let rt = get_type_from_sexpr operand2 in
              ignore (check_assign t rt (Failure ("illegal assignment " ^ string_of_typ t ^ " = " ^ string_of_typ rt ^ " in " ^
                string_of_sexpr (SBinop(operand1, operator, operand2, rt)))));
              (* mutability of the place being assigned to *)
              let (s, b) = (match operand1 with
                    SId(s,_) -> let ((b,_,_,_,_,_),_) = findVariable env.scope s in (s,b)
                  | SUnop(Deref,se,_) -> findMutability env.scope se
                  (* NOTE(review): struct fields are always treated as mutable *)
                  | SStructAccess(_, _, _) -> ("a", Mut)
                  | _ -> raise (Failure("not supported"))) in
              (match b with
                | Immut -> raise (Failure ("trying to modify an immutable value " ^ s))
                | _ -> ());
              (* a plain variable on the right must not have been moved out of *)
              (match operand2 with
                  SId(s,_) -> checkOwnership env.scope s
                | _ -> ());
              SBinop(operand1, operator, operand2, rt))
  in

  (**** Checking functions ****)

  (* Whole-program checks on function names: the built-in println may not be
     redefined, and no name may be defined twice.
     NOTE(review): since this sits inside check_function, it runs once per
     function. *)
  let declared_names = List.map (fun fd -> fd.fname) functions in
  if List.mem "println" declared_names then
    raise (Failure ("function println cannot be overwritten!"));
  report_duplicate (fun n -> "duplicate function " ^ n) declared_names;

  (* Declarations of the built-in functions: only println, declared with a
     single string formal, a void result and no body. *)
  let built_in_decls =
    StringMap.add "println"
      { fname = "println"; formals = [("s", StringT)]; outputType = DataT(VoidT); body = [] }
      StringMap.empty
  in

  (* Name -> declaration map: the built-ins plus every user-defined function. *)
  let function_decls =
    List.fold_left
      (fun decls fd -> StringMap.add fd.fname fd decls)
      built_in_decls functions
  in

  (* Look a function declaration up by name; raises [Failure] if unknown. *)
  let function_decl s =
    if StringMap.mem s function_decls then StringMap.find s function_decls
    else raise (Failure ("unrecognized function " ^ s))
  in

  (* Every program must define a function named main. *)
  ignore (function_decl "main");

  (* Verify an expression and convert it to its typed Sast form, or raise
     [Failure] on the first error.
     Besides typing, this enforces the ownership/borrow rules: using a
     variable requires that it has not been moved out of; plain-variable
     arguments of a user-function call are moved; borrowing and dereferencing
     update the borrow flags of the variable named in the operand and reject
     conflicting borrows. Because those checks mutate the scope tables, every
     subexpression is converted exactly once, left to right. *)
  let rec expr_to_sexpr env = function
      Noexpr -> SNoexpr
    | IntLit(n) -> SIntLit(n,DataT(IntT))
    | BoolLit(b) -> SBoolLit(b,DataT(BoolT))
    | FloatLit(f) -> SFloatLit(f,DataT(FloatT))
    | CharLit(c) -> SCharLit(c,DataT(CharT))
    | StringLit(s) -> SStringLit(s,StringT)
    | ArrayLit(exp) ->
      (* every element must have the same type as the first one *)
      let rec iter t sel = function
          []      -> sel, t
        | e :: el ->
            let se = expr_to_sexpr env e in
            let se_t = get_type_from_sexpr se in
            if t = se_t
            then iter t (se :: sel) el
            else raise (Failure ("multiple types in array " ^ string_of_expr (ArrayLit(exp))))
      in
      (match exp with
          [] ->
            (* there is no element to take the element type from *)
            raise (Failure ("empty array literal has no element type"))
        | first :: rest ->
            let se = expr_to_sexpr env first in
            let sel, t = iter (get_type_from_sexpr se) [se] rest in
            SArrayLit(List.rev sel, List.length sel, ArrayT(t,List.length sel)))

    | Id(s) -> checkOwnership env.scope s; let ((_,_,t,_,_,_),_) = findVariable env.scope s in SId(s,t)

    | ArrayAccess(e1, e2) -> let e1' = expr_to_sexpr env e1 in
      (match get_type_from_sexpr e1' with
          ArrayT(t,_) -> SArrayAccess(e1', expr_to_sexpr env e2, t)
        | _ -> raise(Failure("semant line 268, array access error")))

    | StructCreate(field_decl_list) ->
      (* field names must be distinct *)
      let name_list = List.map fst field_decl_list in
      report_duplicate (fun n -> "field " ^ n) name_list;
      let s_arg_list = List.map (expr_to_sexpr env) (List.map snd field_decl_list) in
      SStructCreate(List.combine name_list s_arg_list)

    | StructAccess(e,s) ->
      (* the accessed expression must be a declared variable whose type
         matches its declaration *)
      let e' = expr_to_sexpr env e in
      let et' = get_type_from_sexpr e' in
      let ((_, _, t, _,_,_),_) = findVariable env.scope (findName e') in
      let _ = check_fields t et' in
      (* the struct type must be defined and must have field s *)
      let fld_list = try StringHash.find structMap (string_of_typ t)
        with Not_found -> raise(Failure("undeclared struct " ^ string_of_typ t)) in
      if List.mem s (List.map fst fld_list)
      then SStructAccess(e', s, List.assoc s fld_list)
      else raise(Failure("no field " ^ s ^ " in struct type " ^ string_of_typ t))

    | Unop(o, e) -> let e' = expr_to_sexpr env e in
      let oldt = get_type_from_sexpr e' in
      let newt = (match o with
        | Neg ->
          (match get_type_from_sexpr e' with
              DataT(IntT) -> oldt
            | DataT(FloatT) -> oldt
            | _ -> raise(Failure("don't support this operator in " ^ string_of_sexpr e')))
        | Not ->
          (match get_type_from_sexpr e' with
              DataT(BoolT) -> oldt
            | _ -> raise(Failure("don't support not in " ^ string_of_sexpr e')))
        | Deref -> (match oldt with
            RefT(bt,t) ->
              (* record the kind of access made through the reference variable *)
              let name = findName e' in
              let ((_,_,_,_,hasImmutBorrow,hasMutBorrow),scope) = findVariable env.scope name in
              (match bt with
                  Immut -> if hasMutBorrow = true then raise(Failure("mutable and immutable borrow at the same time!")) else
                      toggleBorrow scope name true false
                | Mut -> if hasImmutBorrow = true then raise(Failure("mutable and immutable borrow at the same time!")) else ();
                   toggleBorrow scope name false true); t
          | _ -> raise(Failure("Dereferencing a non-borrow at " ^ string_of_sexpr e')))
        | Borrow(bt) ->
          if snd (findMutability env.scope e') = Immut && bt = Mut then raise(Failure("cannot mutably borrow a imuutable value!" ^ string_of_sexpr e')) else ();
          (* record the new borrow on the borrowed variable *)
          let name = findName e' in
          let ((_,_,_,_,hasImmutBorrow,hasMutBorrow),scope) = findVariable env.scope name in
          (match bt with
              Immut -> if hasMutBorrow = true then raise(Failure("mutable and immutable borrow at the same time!")) else
                  toggleBorrow scope name true false
            | Mut -> if hasMutBorrow = true then raise(Failure("at most one mutable borrow!")) else ();
               if hasImmutBorrow = true then raise(Failure("mutable and immutable borrow at the same time!")) else ();
               toggleBorrow scope name false true);
          RefT(bt,oldt)) in
      SUnop(o, e', newt)
    | Cast(e, t) -> SCast(expr_to_sexpr env e, t)
    | Binop(operand1, operator, operand2) ->
        (* convert the left operand first, then the right one *)
        let so1 = expr_to_sexpr env operand1 in
        let so2 = expr_to_sexpr env operand2 in
        check_binop_type env so1 operator so2
    | Call(fname, actuals) as call ->
      (match fname with
        | "println" ->
          if List.length actuals <> 1
          then raise (Failure ("expecting 1 arguments in println!"))
          else
            (* convert the single argument once and reuse it in the SCall *)
            let actual' = expr_to_sexpr env (List.hd actuals) in
            let actualType = get_type_from_sexpr actual' in
            (match actualType with
                DataT(_) | StringT ->
                  let fd = function_decl "println" in
                  SCall(fname, [actual'], fd.outputType)
              | _ -> raise (Failure ("illegal actual argument found " ^ string_of_typ actualType ^ " in println")))
        | _ -> let fd = function_decl fname in
          if List.length actuals <> List.length fd.formals
          then raise (Failure ("expecting " ^ string_of_int (List.length fd.formals) ^ " arguments in " ^ string_of_expr call))
          else
            (* Convert and type-check each argument exactly once, left to right.
               Converting twice would run the borrow bookkeeping twice, making
               every `&mut x` argument fail with "at most one mutable borrow!". *)
            let sactuals = List.map2 (fun (_, ft) e ->
                let e' = expr_to_sexpr env e in
                let et = get_type_from_sexpr e' in
                ignore (check_assign ft et (Failure ("illegal actual argument found " ^ string_of_typ et ^ " expected " ^ string_of_typ ft ^ " in " ^ fname)));
                e') fd.formals actuals in
            (* plain-variable arguments are moved into the callee *)
            List.iter (fun x -> move x env) sactuals;
            SCall(fname,sactuals,fd.outputType))
    | _ -> SNoexpr

  in

    (* Convert [e] in [env] and require it to be boolean; returns the converted
       expression so callers never convert [e] a second time (conversion moves
       call arguments and updates borrow flags, so repeating it reports bogus
       "use of moved value" / borrow errors). *)
    let check_bool_expr env e =
      let e' = expr_to_sexpr env e in
      let et = get_type_from_sexpr e' in
      ignore (check_assign (DataT(BoolT)) et (Failure ("expected Boolean expression in " ^ string_of_sexpr e')));
      e'
    in

    (* Verify a statement and convert it to its Sast form, or raise [Failure].
       Sub-parts are converted with explicit sequential lets in the order they
       run: OCaml leaves the evaluation order of constructor arguments
       unspecified (right-to-left in practice), and conversion mutates the
       scope tables, so e.g. a move inside an else-branch must not be visible
       while checking the condition or the then-branch. *)
    let rec stmt_to_sstmt env = function
        Block(sl) ->
          (* a block opens a fresh, initially empty scope *)
          let new_env =
          {
            env with scope =
              {
                parent = Some(env.scope);
                varMap = StringHash.create 20
              }
          } in SBlock(List.map (fun x -> stmt_to_sstmt new_env x) sl)
      | Expr(e) -> let e' = expr_to_sexpr env e in let et = get_type_from_sexpr e' in SExpr(e',et)
      | Return(e) ->
        let e' = expr_to_sexpr env e in
        let t = get_type_from_sexpr e' in
        if t <> func.outputType then
          raise (Failure ("return gives " ^ string_of_typ t ^ " expected " ^ string_of_typ func.outputType ^ " in " ^ string_of_expr e));
        SReturn(e',t)
      | If(p, b1, b2) ->
        let p' = check_bool_expr env p in
        let b1' = stmt_to_sstmt env b1 in
        let b2' = stmt_to_sstmt env b2 in
        SIf(p', b1', b2')
      | While(p, s) ->
        let new_env =
        {
          env with scope =
            {
              parent = Some(env.scope);
              varMap = StringHash.create 20
            };
          in_while = true;
        } in
        let p' = check_bool_expr new_env p in
        let s' = stmt_to_sstmt new_env s in
        SWhile(p', s')
      | For(e1, e2, e3, st) ->
        (* assumes the C-style reading For(init, cond, update, body): the
           update runs after the body, so it is checked last *)
        let e1' = expr_to_sexpr env e1 in
        let e2' = expr_to_sexpr env e2 in
        let st' = stmt_to_sstmt env st in
        let e3' = expr_to_sexpr env e3 in
        SFor(e1', e2', e3', st')
      | Declaration(b, (s, oldt), e) ->
        let e' = expr_to_sexpr env e in let rt = get_type_from_sexpr e' in
            (* If the initializer is valid, add s to varMap and return the SDeclaration *)
            let t =
              (match oldt with
                ArrayT(_,_) -> ignore(check_assign oldt rt (Failure ("illegal assignment " ^ string_of_typ oldt ^ " = " ^ string_of_typ rt ^ " in " ^ s))); oldt
              | ArrayTD(lt,ns) ->
                  (* the length comes from the int variable ns, which must have
                     been initialized with an int literal *)
                  checkOwnership env.scope ns; let ((_,var,t,_,_,_),_) = findVariable env.scope ns in
                  if t <> DataT(IntT) then raise (Failure (ns ^ " should be an int!")) else ();
                  let var = (match var with
                      SIntLit(var,_) -> var
                    | _ -> raise (Failure (ns ^ " should be an int!")))
                  in let lt = ArrayT(lt,var) in
                  ignore(check_assign lt rt (Failure ("illegal assignment " ^ string_of_typ oldt ^ " = " ^ string_of_typ rt ^ " in " ^ s ))); rt
              | StructT(_) ->
                  (* the struct must be defined, and s must not be declared in
                     this or any enclosing scope *)
                  let sfields_list = try StringHash.find structMap (string_of_typ oldt)
                    with Not_found -> raise(Failure("undeclared struct " ^ string_of_typ oldt)) in
                  ignore(findDupVar env.scope s);
                  (* a struct literal must list the defined field names, in order *)
                  let def_name_list = List.map fst sfields_list in
                  (match e' with
                      SStructCreate(a) ->
                        let name_list = List.map fst a in
                        let _ = try List.map2 check_fields def_name_list name_list
                          with Invalid_argument(_) -> raise (Failure("number of struct fields do not match!")) in ()
                    | _ -> () );
                  (* ... and its field values must have the defined types *)
                  let def_type_list = List.map snd sfields_list in
                  (match e' with
                      SStructCreate(a) ->
                        let type_list = List.map get_type_from_sexpr (List.map snd a) in
                        let _ = try List.map2 check_fields def_type_list type_list
                          with Invalid_argument(_) -> raise (Failure("number of struct field types do not match!")) in ()
                    | _ -> () );
                  oldt
              | _ ->
                  (* a reference's borrow kind must match the declaration's *)
                  (match oldt with
                    RefT(bt,_)-> if bt<>b then raise(Failure("borrow type unmatched at " ^ s )) else ()
                  | _ -> ());
                  ignore(check_assign oldt rt (Failure ("illegal assignment " ^ string_of_typ oldt ^ " = " ^ string_of_typ rt ^ " in " ^ s))); oldt) in
            (* a plain-variable initializer is moved into s *)
            move e' env;
            StringHash.add env.scope.varMap s (b, e', t,true, false,false);
            SDeclaration(b,(s, t),e')
      | StructDef(s, sl) ->
        (* a struct name may be defined only once, with distinct field names *)
        if StringHash.mem structMap s then raise (Failure ("already defined struct name " ^ s));
        report_duplicate (fun n -> "field " ^ n) (List.map fst sl);
        StringHash.add structMap s sl;
        SStructDef(s,sl)
      | _ -> SBreak

  (* Convert a whole function declaration, checking its body in [env]. *)
  and convert_fdecl_to_sfdecl env fdecl =
  {
      sfname = fdecl.fname;
      sformals = fdecl.formals;
      soutputType = fdecl.outputType;
      sbody = List.map (fun x -> stmt_to_sstmt env x) fdecl.body;
  } in

  (* Formals of [func] may not be void, and their names must be distinct. *)
  List.iter
    (check_not_void (fun n -> "illegal void formal " ^ n ^ " in " ^ func.fname))
    func.formals;
  report_duplicate
    (fun n -> "duplicate formal " ^ n ^ " in " ^ func.fname)
    (List.map fst func.formals);

  (* The function's outermost scope holds its formals. Each starts owned and
     unborrowed; a reference-typed formal takes its borrow kind as its
     mutability, every other formal is immutable. *)
  let varMap = StringHash.create 20 in
  List.iter
    (fun (s, t) ->
      let mutability = (match t with RefT(bt, _) -> bt | _ -> Immut) in
      StringHash.add varMap s (mutability, SId(s, t), t, true, false, false))
    func.formals;
  let env = {
    scope = { parent = None; varMap = varMap };
    return_type = func.outputType;
    in_for = false;
    in_while = false;
  } in
  convert_fdecl_to_sfdecl env func
  in
 
  (* Check and convert every function, keeping their order. *)
  functions |> List.map check_function
