(*
File: SEMANT.ML
Description: Semantically checks the AST
*)

open Ast

module StringMap = Map.Make(String)

(* [check functions] semantically checks a list of function declarations.
   It registers the built-in functions and every user function in one
   symbol table, requires a function named "main" to exist, and then checks
   each user function in turn. The result is one value per function: the
   variable symbol table produced by checking that function's body.
   Every semantic error is reported by raising [Failure] with a message. *)
let check (functions) =
  (* The only matrix dimensions accepted when building an image literal
     from three matrices (see the ImageLit case of check_expr) *)
  let image_row_size = 50
  and image_col_size = 50 in

  (* Validate a list of (type, name) bindings: none may have type Void and
     no two may share a name. [kind] only labels the error message.
     Raises Failure on the first violation found. *)
  let check_binds (kind : string) (binds : bind list) =
    let reject_void = function
      | (Void, name) -> raise (Failure ("illegal void " ^ kind ^ " " ^ name))
      | _ -> ()
    in
    List.iter reject_void binds;
    (* After sorting by name, any duplicate names sit next to each other *)
    let by_name (_, a) (_, b) = compare a b in
    let rec find_dup = function
      | (_, n1) :: ((_, n2) :: _ as rest) ->
          if n1 = n2 then raise (Failure ("duplicate " ^ kind ^ " " ^ n1))
          else find_dup rest
      | [_] | [] -> ()
    in
    find_dup (List.sort by_name binds)
  in

  (* Symbol table of the built-in functions, keyed by name. Built-ins are
     ordinary declarations with an empty body. *)
  let built_in_decls =
    let builtin ret name params =
      { typ = ret; fname = name; formals = params; body = [] }
    in
    [
      builtin Void "print" [(String, "x")];
      builtin Image "load" [(String, "x")];
      builtin Image "blur" [(String, "x")];
      builtin Image "grayscale" [(String, "x")];
      builtin Image "brighten" [(String, "x")];
      builtin Image "edgedetect" [(String, "x")];
      builtin Matrix "dim" [(String, "x")];
      builtin Int "row_size" [(String, "filename")];
      builtin Double "int2double" [(Int, "i")];
      builtin Double "dbl_arr" [];
      builtin Bool "save" [(Image, "x")];
    ]
    |> List.fold_left
         (fun table decl -> StringMap.add decl.fname decl table)
         StringMap.empty
  in

  (* Insert one user function into the function table [map].
     Raises Failure when the function is declared to return Matrix or Image,
     when its name is that of a built-in, or when the name is already
     present in [map]. *)
  let add_func map fd =
    let name = fd.fname in
    let fail reason = raise (Failure ("Adding func failed with " ^ reason)) in
    match fd.typ with
    | Matrix -> raise (Failure ("Function cannot return matrix type"))
    | Image -> raise (Failure ("Function cannot return image type"))
    | _ ->
        (* The built-in check comes first so redefining a built-in gets the
           more specific message *)
        if StringMap.mem name built_in_decls then
          fail ("function " ^ name ^ " may not be defined")
        else if StringMap.mem name map then
          fail ("duplicate function " ^ name)
        else
          StringMap.add name fd map
  in

  (* One table holding the built-ins plus every user-defined function *)
  let function_decls =
    functions |> List.fold_left add_func built_in_decls
  in

  (* Look up a function declaration by name; an unknown name raises Failure *)
  let find_func s =
    match StringMap.find s function_decls with
    | decl -> decl
    | exception Not_found -> raise (Failure ("unrecognized function " ^ s))
  in

  (* A function named "main" must exist; the lookup raises Failure otherwise.
     The declaration itself is not needed here. *)
  ignore (find_func "main");

  (* Check a single function declaration [func]: validate its formals, then
     check its body as one block starting from a symbol table that contains
     only the formals. Evaluates to the symbol table returned by checking
     the body. Raises Failure on any semantic error. *)
  let check_function func =
    (* Make sure no formals are void or duplicates *)
    check_binds "formal" func.formals;

    (* Assignment compatibility test: the value type must be structurally
       equal to the target type. Returns the target type on success and
       raises Failure, with [err] appended to a fixed prefix, otherwise. *)
    let check_assign lvaluet rvaluet err =
      match lvaluet = rvaluet with
      | true -> lvaluet
      | false -> raise (Failure ("Check_assign failed with " ^ err))
    in

    (* Initial variable table for this function: each formal's name mapped
       to its declared type *)
    let formals =
      let bind_formal table (ty, name) = StringMap.add name ty table in
      List.fold_left bind_formal StringMap.empty func.formals
    in

    (* Type bound to variable [s] in [symbols]; raises Failure when [s] has
       no binding *)
    let type_of_identifier s symbols =
      match StringMap.find s symbols with
      | ty -> ty
      | exception Not_found -> raise (Failure ("undeclared identifier " ^ s))
    in

    (* Element type produced by indexing a value of type [ty]: only a
       dimensioned matrix may be indexed, and its elements are Double.
       Any other type raises Failure. *)
    let matrix_access_type ty =
      match ty with
      | DimMatrix _ -> Double
      | _ -> raise (Failure ("Illegal matrix access"))
    in

    (* Shared check for the member-style accessors (.red, .rowsize, ...):
       [varname] must be bound in [symbols] and have exactly type [expected];
       the accessor then has type [result]. Raises Failure with
       [missing_msg] when the variable is unbound and with [wrong_type_msg]
       when it has another type. *)
    let member_type symbols varname expected result wrong_type_msg missing_msg =
      if not (StringMap.mem varname symbols) then raise (Failure missing_msg)
      else if type_of_identifier varname symbols = expected then result
      else raise (Failure wrong_type_msg)
    in

    (* Compute the type of an expression under the variable table [symbols].
       Raises Failure as soon as any sub-expression is ill-typed, refers to
       an unknown variable or function, or uses an operator on unsupported
       operand types. *)
    let rec check_expr symbols = function
      | IntLit _ -> Int
      | StrLit _ -> String
      | DblLit _ -> Double
      | BoolLit _ -> Bool
      | Binop (e1, op, e2) ->
          let t1 = check_expr symbols e1 in
          let t2 = check_expr symbols e2 in
          let is_scalar t = (t = Int || t = Double) in
          (match op, t1, t2 with
           (* arithmetic on two ints or two doubles *)
           | (Add | Sub | Mult | Div | Mod), Int, Int -> Int
           | (Add | Sub | Mult | Div | Mod), Double, Double -> Double
           (* comparisons only need both sides to have the same type *)
           | (Equal | Neq | Less | Leq | Greater | Geq), _, _ when t1 = t2 -> Bool
           | (And | Or), Bool, Bool -> Bool
           (* matrix op matrix *)
           | (Add | Sub), DimMatrix _, DimMatrix _ ->
               if t1 = t2 then t1
               else raise (Failure ("Cannot add/subtract matrices with different dimensions"))
           | Mult, DimMatrix (r1, c1), DimMatrix (r2, c2) ->
               (* inner dimensions must agree; result is r1 x c2 *)
               if c1 = r2 then DimMatrix (r1, c2)
               else raise (Failure ("Matrices cannot be multiplied given their dimensions"))
           (* matrix op scalar: reached only when the right side is not a
              dimensioned matrix *)
           | (Add | Sub | Mult), DimMatrix _, _ ->
               if is_scalar t2 then t1
               else raise (Failure ("Scalar ops with matrices can only use ints or doubles"))
           (* scalar op matrix *)
           | (Add | Sub | Mult), _, DimMatrix _ ->
               if is_scalar t1 then t2
               else raise (Failure ("Scalar ops with matrices can only use ints or doubles"))
           | _ -> raise (Failure ("illegal binary operator")))
      | Unop (op, e) ->
          let t = check_expr symbols e in
          if op = Neg then
            (match t with
             | Int | Double -> t
             | _ -> raise (Failure ("Cannot negate a non-int/double")))
          else if t = Bool then Bool (* any operator other than Neg is treated as Not *)
          else raise (Failure ("Cannot NOT a non-bool"))
      | Assign (varname, e) ->
          let lt = type_of_identifier varname symbols in
          let rt = check_expr symbols e in
          (* Only scalar/string right-hand sides are compared against the
             variable's type; any other right-hand type (matrix, image, ...)
             is accepted as-is and the variable keeps its recorded type.
             NOTE(review): this also lets a matrix value be assigned to a
             scalar variable - confirm whether that is intended. *)
          (match rt with
           | Int | Bool | Double | String -> check_assign lt rt "illegal assignment"
           | _ -> lt)
      | Call (fname, actuals) ->
          let func_call = find_func fname in (* raises if the function is unknown *)
          let expected = List.length func_call.formals in
          if List.length actuals <> expected then
            raise (Failure ("expecting " ^ string_of_int expected ^ " arguments in " ^ fname))
          else begin
            (* "print" accepts an argument of any type, so its actuals are
               skipped entirely.
               NOTE(review): this means an ill-formed expression passed to
               print is not reported by this pass. *)
            if not (String.equal fname "print") then
              List.iter2
                (fun (ft, _) actual ->
                  ignore (check_assign ft (check_expr symbols actual) "illegal actual argument"))
                func_call.formals actuals;
            func_call.typ
          end
      | Noexpr -> Void
      | Noassign (t) -> t
      | MatLit m ->
          (* Only the shape is checked here: every row must have the same
             length as the first one. *)
          (match m with
           | [] -> DimMatrix (0, 0) (* no rows at all; previously crashed in List.hd *)
           | first_row :: other_rows ->
               let cols = List.length first_row in
               if List.exists (fun row -> List.length row <> cols) other_rows then
                 raise (Failure ("Not all rows in matrix are the same length"))
               else
                 let rows = List.length m in
                 (* a single empty row is how the empty matrix shows up *)
                 if rows = 1 && cols = 0 then DimMatrix (0, 0)
                 else DimMatrix (rows, cols))
      | MatAccess (mname, row, col) ->
          let require_int_index idx =
            match check_expr symbols idx with
            | Int -> ()
            | _ -> raise (Failure ("Attempting to access with a non-integer type"))
          in
          require_int_index row;
          require_int_index col;
          matrix_access_type (type_of_identifier mname symbols)
      | Cast (typ, s1) ->
          (* Only Double -> Int and Int -> Double casts exist *)
          let source = check_expr symbols s1 in
          (match typ with
           | Int ->
               if source = Double then Int
               else raise (Failure ("Cannot cast non-double to int"))
           | Double ->
               if source = Int then Double
               else raise (Failure ("Cannot cast non-int to a double"))
           | _ -> raise (Failure("Cannot cast to that type")))
      | Id varname ->
          (match StringMap.find varname symbols with
           | t -> t
           | exception Not_found -> raise (Failure ("Variable not found: " ^ varname)))
      | ImageLit (m1, m2, m3) ->
          (* All three channel variables must have the same type, and that
             type must be a matrix of exactly the fixed image dimensions *)
          let m1t = type_of_identifier m1 symbols in
          let m2t = type_of_identifier m2 symbols in
          let m3t = type_of_identifier m3 symbols in
          if m1t <> m2t || m2t <> m3t then
            raise(Failure ("Matrix sizes don't match"))
          else if m3t = DimMatrix (image_row_size, image_col_size) then Image
          else raise(Failure ("Can't create an image with the wrong dimensions"))
      | ImageRedAccess varname ->
          member_type symbols varname Image Matrix
            "Cannot call .red on non image datatype"
            "Image variable does not exist"
      | ImageGreenAccess (varname) ->
          member_type symbols varname Image Matrix
            "Cannot call .green on non image datatype"
            "Image variable does not exist"
      | ImageBlueAccess (varname) ->
          member_type symbols varname Image Matrix
            "Cannot call .blue on non image datatype"
            "Image variable does not exist"
      | MatrixRowSize (varname) ->
          (* NOTE(review): only a variable recorded as plain Matrix passes;
             one recorded as DimMatrix is rejected - confirm intended. *)
          member_type symbols varname Matrix Int
            "Cannot call .rowsize on non matrix datatype"
            "Matrix variable does not exist"
      | MatrixColSize (varname) ->
          member_type symbols varname Matrix Int
            "Cannot call .colsize on non matrix datatype"
            "Matrix variable does not exist"
    in

    (* Require expression [e] to have type Bool under [symbols]; raises
       Failure otherwise (or whatever check_expr raises for [e] itself).
       Uses structural inequality: physical inequality on variant values is
       only accidentally correct for constant constructors. *)
    let check_bool_check_expr symbols e =
      if check_expr symbols e <> Bool
        then raise (Failure ("expected Boolean expression")) else ()
    in

    (* Check one statement under the variable table [symbols] and return the
       table to use for whatever statement follows it. Raises Failure on the
       first semantic error. *)
    let rec check_stmt symbols = function
      | Expr e -> ignore (check_expr symbols e); symbols
      | If (cond, b1, b2) ->
          check_bool_check_expr symbols cond;
          (* NOTE(review): declarations made in the then-branch are dropped
             while those made in the else-branch are returned to the caller;
             confirm the intended scoping before changing this. *)
          ignore (check_stmt symbols b1);
          check_stmt symbols b2
      | For (e1, cond, e2, st) ->
          ignore (check_expr symbols e1);
          check_bool_check_expr symbols cond;
          ignore (check_expr symbols e2);
          check_stmt symbols st
      | While (cond, st) ->
          check_bool_check_expr symbols cond;
          check_stmt symbols st
      | Return e ->
          (match check_expr symbols e with
           | DimMatrix _ ->
               raise (Failure ("Cannot return matrix type (don't know how to allocate result)"))
           | t ->
               if t = func.typ then symbols
               else raise (Failure ("Return type does not match method signature")))
      | Block sl ->
          (* Thread the table through the statements in order; a return may
             only appear as the last statement, and nested blocks are
             flattened into the enclosing one. *)
          let rec walk scope = function
            | [] -> scope
            | [Return _ as last] -> check_stmt scope last
            | Return _ :: _ -> raise (Failure "nothing may follow a return")
            | Block inner :: rest -> walk scope (inner @ rest)
            | s :: rest -> walk (check_stmt scope s) rest
          in
          walk symbols sl
      | Local (typ, varname, e) ->
          (* The initializer is checked exactly once, before the new name is
             visible. (It used to be checked a second time through a
             recursive call on the same statement.) *)
          let expr_type = check_expr symbols e in
          if StringMap.mem varname symbols then
            (* NOTE(review): redeclaring an existing name is accepted and
               keeps the earlier binding; confirm whether this should be an
               error. *)
            symbols
          else if expr_type = Void || typ = Matrix || typ = expr_type then
            (* The initializer's type is what gets recorded, so a variable
               declared Matrix carries its initializer's dimensioned type *)
            StringMap.add varname expr_type symbols
          else raise (Failure ("Local var type does not match"))
  in check_stmt formals (Block func.body)
in (List.map check_function functions)
