(* Semantic checking for the Pixel compiler *)

open Ast
open Sast

module StringMap = Map.Make(String)

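(* Semantic checking of the AST. Returns an SAST (the list of checked
   function declarations) if successful; raises [Failure] if something
   is wrong.

   Checks each global variable, then each function. (The global-variable
   check is currently commented out inside [check].) *)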

let check functions =
  (* [functions]: the function declarations to check. Each one is read
     through its [typ], [fname], [formals] and [body] fields. The result
     is [List.map check_function functions]. The first semantic error
     found raises [Failure].

     NOTE(review): as written, this definition does not appear to
     type-check. For example, [globals] is never bound, bindings are
     destructured both as pairs and as triples, and [expr] is called with
     one argument and with two. The NOTE(review) comments below mark each
     such spot. The intended fix depends on the Ast/Sast definitions,
     which are not visible in this file. *)

  (* Verify a list of bindings has no void types or duplicate names.
     [kind] is only used in the error messages (for example formal).
     Bindings are destructured as triples [(typ, name, _)]. Duplicates are
     found by sorting on the name and comparing neighbours.
     Raises [Failure] on the first offending binding. *)
  let check_binds (kind : string) (binds : bind list) =
    List.iter (function
	(Void, b, _) -> raise (Failure ("illegal void " ^ kind ^ " " ^ b))
      | _ -> ()) binds;
    let rec dups = function
        [] -> ()
      |	((_,n1,_) :: (_,n2,_) :: _) when n1 = n2 ->
	  raise (Failure ("duplicate " ^ kind ^ " " ^ n1))
      | _ :: t -> dups t
    in dups (List.sort (fun (_,a,_) (_,b,_) -> compare a b) binds)
  in

  (**** Check global variables ****)

  (* NOTE(review): the global check below is disabled. [globals] is
     nevertheless still referenced when [check_function] builds its local
     symbol table, and no binding for it is visible in this file. *)
  (* check_binds "global" globals; *)

  (**** Check functions ****)

  (* Collect function declarations for built-in functions: no bodies.
     [add_bind] registers [name] as a function that returns [Void] and
     takes a single formal of type [ty] named x.

     NOTE(review): the list below mixes pairs [(name, ty)] with triples,
     presumably [(name, formal types, return type)]. A list must be
     homogeneous and [add_bind] only accepts pairs, so this does not
     type-check. Once the shape is unified, note that [join] appears
     twice. [StringMap.add] replaces an existing binding, so only the
     last [join] entry would survive. The formal built here is a pair,
     which also disagrees with the triples that [check_binds] expects. *)
  let built_in_decls = 
    let add_bind map (name, ty) = StringMap.add name {
      typ = Void;
      fname = name; 
      formals = [(ty, "x")];
      body = [] } map
    in List.fold_left add_bind StringMap.empty [ ("print", Int);
			                         ("printf", Float);
                               ("image_in", [String; String], Image);
                               ("image_out", [String; Image; String], Void);
                               ("convolute", [Matrix; Matrix], Matrix);
                               ("join", [Matrix; Matrix; Matrix], Image);
                               ("join", [Matrix], Image) ]
  in

  (* Add function name to symbol table.
     Raises [Failure] if [fd] tries to redefine a built-in.
     NOTE(review): the duplicate check is commented out. A later
     declaration with the same name therefore silently replaces an earlier
     one ([StringMap.add] overwrites it), and [dup_err] is unused. *)
  let add_func map fd = 
    let built_in_err = "function " ^ fd.fname ^ " may not be defined"
    and dup_err = "duplicate function " ^ fd.fname
    and make_err er = raise (Failure er)
    and n = fd.fname (* Name of the function *)
    in match fd with (* No duplicate functions or redefinitions of built-ins *)
         _ when StringMap.mem n built_in_decls -> make_err built_in_err
       (* | _ when StringMap.mem n map -> make_err dup_err   *)
       | _ ->  StringMap.add n fd map 
  in

  (* Collect all function names into one symbol table. The table is seeded
     with the built-ins so that calls to them resolve through [find_func]. *)
  let function_decls = List.fold_left add_func built_in_decls functions
  in
  
  (* Return a function from our symbol table.
     Raises [Failure] (unrecognized function) if [s] is not declared. *)
  let find_func s = 
    try StringMap.find s function_decls
    with Not_found -> raise (Failure ("unrecognized function " ^ s))
  in

  let _ = find_func "main" in (* Ensure "main" is defined *)

  (* Check one function declaration and build its SAST record.
     Raises [Failure] on the first semantic error. *)
  let check_function func =
    (* Make sure no formals are void or duplicates *)
    check_binds "formal" func.formals;

    (* Raise an exception if the given rvalue type cannot be assigned to
       the given lvalue type. The types must be structurally equal. On
       success, the lvalue type is returned. *)
    let check_assign lvaluet rvaluet err =
       if lvaluet = rvaluet then lvaluet else raise (Failure err)
    in   

    (* Dimensions of a nested matrix literal, outermost first. At each
       level they are taken from the first element: a literal holding two
       row literals of three scalars each gives [2; 3]. Any expression
       that is not a literal gives [].
       NOTE(review): [List.hd] raises on an empty literal. Also, this
       function uses the constructor [MatrixLit] while [expr] below matches
       [MLiteral]; only one of them can be the Ast constructor. *)
    let rec get_dims = function
        MatrixLit l -> List.length l :: get_dims (List.hd l)
      | _ -> []
    in
    (* Raise an exception if dimensions of Matrix are not balanced.
       [flatten d rows] concatenates nested matrix literals row by row.
       It checks that each nested literal has length [List.hd d] and
       recurses into its elements with [List.tl d]. At the first element
       that is not a matrix literal, the walk stops and the rest of the
       list is returned unchanged, without further checks.
       NOTE(review): [!=] is physical inequality. [<>] is the idiomatic
       structural test; it is equivalent here because both sides are ints. *)
    let rec flatten d = function
      [] -> []
      | MatrixLit hd::tl -> if List.length hd != List.hd d then raise (Failure("Invalid dims")) else List.append (flatten (List.tl d) hd) (flatten d tl)
      | a -> a
    in

    (* Build local symbol table of variables for this function, mapping
       each name to its type. Formals are added after globals, so a formal
       replaces a global with the same name.
       NOTE(review): bindings are destructured as pairs here but as
       triples in [check_binds]. [globals] is not bound in this file. *)
    let symbols = List.fold_left (fun m (ty, name) -> StringMap.add name ty m)
	                StringMap.empty (globals @ func.formals )
    in

    (* Return a variable from our local symbol table.
       Raises [Failure] (undeclared identifier) if [s] is unknown.
       NOTE(review): this takes one argument and always reads the [symbols]
       table built above. However, several call sites in [expr] pass a
       second argument, as in [type_of_identifier v symbols]. *)
    let type_of_identifier s =
      try StringMap.find s symbols
      with Not_found -> raise (Failure ("undeclared identifier " ^ s))
    in

    (* Element type of an indexed matrix: [Float] for any pair argument.
       NOTE(review): the first pattern matches every pair, so the error
       case can never be reached (warning 11). *)
    let matrix_access_type = function
        (_, _) -> Float
      | _ -> raise (Failure ("Illegal matrix access"))
    in

    (* Return a semantically-checked expression, i.e., with a type.
       Most cases return a [(typ, sx)] pair.
       NOTE(review): [expr] is declared with a [symbols] parameter, yet most
       recursive calls and all callers below pass only the expression.
       Several cases also return a bare type instead of a pair. *)
    let rec expr symbols = function
        Literal  l -> (Int, SLiteral l)
      | Fliteral l -> (Float, SFliteral l)
      (* NOTE(review): [SLiteral] is built as an Sast value in the [Literal]
         case above. This pattern presumably meant the Ast string-literal
         constructor, and the result discards [l]. *)
      | SLiteral l -> (String, SSLiteral)
      | Noexpr     -> (Void, SNoexpr)
      | Id s       -> (type_of_identifier s, SId s)
      (* Matrix literal. All top-level elements must have the same type
         (each is compared with its neighbour). The checked literal is
         tagged with its row and column counts taken from [get_dims]. The
         1-D and 2-D shapes are passed through [flatten]. Literals with
         more than two dimensions keep only the first two counts and are
         not flattened. *)
      | MLiteral l -> 
          let d = get_dims (MLiteral l) in
          let rec all_match = function
            [] -> ignore()
            | hd::tl -> if tl != [] then
                          let (t1, _) = expr hd in let (t2, _) = expr (List.hd tl) in
                          if t1 = t2 then all_match tl else raise (Failure ("Data Mismatch in MLiteral: " ^ string_of_typ t1 ^ " does not match " ^ string_of_typ t2))
                        else ignore()
          in
          all_match l;
          if List.length d > 2 then (Matrix, SMLiteral ((List.map expr l), List.hd d, List.hd (List.tl d)))
          else if List.length d = 2 then (Matrix, SMLiteral ( (List.map expr (flatten (List.tl d) l)), List.hd d, List.hd (List.tl d)))
          else if List.length d = 1 then (Matrix, SMLiteral ( (List.map expr (flatten (List.tl d) l)), List.hd d, 1))
          else (Matrix, SMLiteral ( (List.map expr l), 0,0))
      (* Matrix element read: both index expressions must be Int.
         NOTE(review): [expr symbols e1] yields a [(typ, sx)] pair, which
         cannot match the type constructor [Int]. This case also returns a
         bare type rather than a [(typ, sx)] pair. *)
      | MatrixAccess(v, e1, e2) ->
          let _ = (match (expr symbols e1) with
                Int -> Int
              | _ -> raise (Failure ("Attempting to access with a non-integer type")))
            and _ = (match (expr symbols e2) with
                Int -> Int
              | _ -> raise (Failure ("Attempting to access with a non-integer type")))
            in matrix_access_type (type_of_identifier v symbols)
      (* Channel accessors on an Image variable yield a Matrix. Membership
         is tested against the [symbols] parameter of [expr].
         NOTE(review): these cases, and [MatrixRows]/[MatrixCols] below,
         return a bare type instead of a [(typ, sx)] pair. They also call
         [type_of_identifier] with two arguments. *)
      | ImageRedAccess v ->
          if StringMap.mem v symbols
            then if (type_of_identifier v symbols) = Image then Matrix
            else raise (Failure ("Cannot call .red on non image datatype"))
          else raise (Failure ("Image variable does not exist"))
      | ImageGreenAccess (v) ->
          if StringMap.mem v symbols
            then if (type_of_identifier v symbols) = Image then Matrix
            else raise (Failure ("Cannot call .green on non image datatype"))
          else raise (Failure ("Image variable does not exist"))
      | ImageBlueAccess (v) ->
          if StringMap.mem v symbols
            then if (type_of_identifier v symbols) = Image then Matrix
            else raise (Failure ("Cannot call .blue on non image datatype"))
          else raise (Failure ("Image variable does not exist"))
      | ImageGrayscaleAccess (v) ->
          if StringMap.mem v symbols
            then if (type_of_identifier v symbols) = Image then Matrix
            else raise (Failure ("Cannot call .grayscale on non image datatype"))
          else raise (Failure("Image variable does not exist"))
      (* Row and column counts of a Matrix variable have type Int.
         NOTE(review): the MatrixCols error message spells catatype,
         presumably a typo for datatype. *)
      | MatrixRows (v) ->
          if StringMap.mem v symbols
            then if (type_of_identifier v symbols) = Matrix then Int
            else raise (Failure ("Cannot call .rowsize on non matrix datatype"))
          else raise (Failure ("Matrix variable does not exist"))
      | MatrixCols (v) ->
          if StringMap.mem v symbols
            then if (type_of_identifier v symbols) = Matrix then Int
            else raise (Failure ("Cannot call .colsize on non matrix catatype"))
          else raise (Failure ("Matrix variable does not exist"))
      (* Assignment: the right-hand side type must equal the declared type
         of [var]. The checked node records the right-hand side type. *)
      | Assign(var, e) as ex -> 
          let lt = type_of_identifier var
          and (rt, e') = expr e in
          let err = "illegal assignment " ^ string_of_typ lt ^ " = " ^ 
            string_of_typ rt ^ " in " ^ string_of_expr ex
          in (check_assign lt rt err, SAssign(var, (rt, e')))
      (* Neg accepts Int or Float; Not requires Bool.
         NOTE(review): comparisons and [And]/[Or] below produce [Int], and
         [check_bool_expr] expects [Int]. Requiring [Bool] here is
         therefore inconsistent with the rest of the checker. *)
      | Unop(op, e) as ex -> 
          let (t, e') = expr e in
          let ty = match op with
            Neg when t = Int || t = Float -> t
          | Not when t = Bool -> Bool
          | _ -> raise (Failure ("illegal unary operator " ^ 
                                 string_of_uop op ^ string_of_typ t ^
                                 " in " ^ string_of_expr ex))
          in (ty, SUnop(op, (t, e')))
      (* Binary operators.
         NOTE(review): [t1] and [t2] are whole [(typ, sx)] pairs. As a
         result, [same] also compares the checked subexpressions, and tests
         such as [t1 = Int] compare a pair with a type. The catch-all arms
         of the [t1_dim]/[t2_dim] matches are unreachable.
         [hd] and [tl] are unqualified (no [List] open is visible) and are
         applied to tuples. The guard [t2 = hd t2_dim, tl t2_dim] parses as
         a tuple. This case returns a bare type; the intended
         [(ty, SBinop ...)] result is in the commented-out code after the
         match. *)
      | Binop(e1, op, e2) as e -> 
          let t1 = expr e1 
          and t2 = expr e2 in
          (* All binary operators require operands of the same type *)
          let same = t1 = t2 in
          let t1_dim = (match t1 with 
            (m,n) -> (m, n)
            | _ -> (-1, -1)
          ) in
          let t2_dim = (match t2 with 
            (m,n) -> (m, n)
            | _ -> (-1, -1)
          )
          (* in let t1_dim = match t1 with *)

          (* Determine expression type based on operator and operand types *)
          in (match op with
            Add | Sub | Mult | Div when same && t1 = Int   -> Int
          | Add | Sub | Mult | Div when same && t1 = Float -> Float
          | Equal | Neq            when same               -> Int
          | Less | Leq | Greater | Geq
                     when same && (t1 = Int || t1 = Float) -> Int
          | And | Or when same && t1 = Int -> Int
          (* matrix op matrix *)
          (* | Add when t1 = (hd t1_dim, tl t1_dim)
                      && t2 = (hd t2_dim, tl t2_dim)
                      if same then (hd t1_dim, tl t1_dim)
                      else raise (Failure ("Cannot add/subtract matrices with different dimensions")) *)
          | Add when same && t1 = (hd t1_dim, tl t1_dim) && t2 = (hd t2_dim, tl t2_dim) ->
                    (hd t1_dim, tl t1_dim)
                  (* else raise (Failure ("Cannot add/subtract matrices with different dimensions")) *)
          | Mult when t1 = (hd t1_dim, tl t1_dim)
                  && t2 = (hd t2_dim, tl t2_dim) -> 
                  if (tl t1_dim = hd t2_dim) then (hd t1_dim, tl t2_dim)
                  else raise (Failure ("Matrices cannot be multiplied given their dimensions"))
          (* matrix op scalar *)
          | Add | Exp | Mult when not same 
                              && t1 = (hd t1_dim, tl t1_dim) && ((t2 = Int) || (t2 = Float))
                                -> (hd t1_dim, tl t1_dim)
                              (* else raise (Failure ("Scalar ops with matrices can only use ints or floats")) *)
          (* Scalar op matrix *)
          | Add | Exp | Mult when not same
                              && t2 = hd t2_dim, tl t2_dim ->
                              if ((t1 = Int) || (t1 = Float)) then (hd t2_dim, tl t2_dim)
                              else raise (Failure ("Scalar ops with matrices can only use ints or doubles"))
          | _ -> raise (Failure ("illegal binary operator")))
	      (* Failure ("illegal binary operator " ^
                       string_of_typ t1 ^ " " ^ string_of_op op ^ " " ^
                       string_of_typ t2 ^ " in " ^ string_of_expr e))
          in (ty, SBinop((t1, e1'), op, (t2, e2'))) *)
      (* Function call. The argument count must match the formals of the
         callee, and each argument type must equal the corresponding formal
         type. The call has the return type of the callee. *)
      | Call(fname, args) as call -> 
          let fd = find_func fname in
          let param_length = List.length fd.formals in
          if List.length args != param_length then
            raise (Failure ("expecting " ^ string_of_int param_length ^ 
                            " arguments in " ^ string_of_expr call))
          else let check_call (ft, _) e = 
            let (et, e') = expr e in 
            let err = "illegal argument found " ^ string_of_typ et ^
              " expected " ^ string_of_typ ft ^ " in " ^ string_of_expr e
            in (check_assign ft et err, e')
          in 
          let args' = List.map2 check_call fd.formals args
          in (fd.typ, SCall(fname, args'))
    in

    (* Condition check for If, For and While. The expression must have
       type [Int]; despite the wording of the error message, no separate
       Boolean type is accepted.
       NOTE(review): [!=] is physical inequality. Use [<>] for a structural
       comparison of types. *)
    let check_bool_expr e = 
      let (t', e') = expr e
      and err = "expected Boolean expression in " ^ string_of_expr e
      in if t' != Int then raise (Failure err) else (t', e')
    in

    (* Return a semantically-checked statement i.e. containing sexprs *)
    let rec check_stmt = function
        Expr e -> SExpr (expr e)
      | If(p, b1, b2) -> SIf(check_bool_expr p, check_stmt b1, check_stmt b2)
      | For(e1, e2, e3, st) ->
	  SFor(expr e1, check_bool_expr e2, expr e3, check_stmt st)
      | While(p, s) -> SWhile(check_bool_expr p, check_stmt s)
      (* The returned expression must have exactly the declared return type
         of the function. *)
      | Return e -> let (t, e') = expr e in
        if t = func.typ then SReturn (t, e') 
        else raise (
	  Failure ("return gives " ^ string_of_typ t ^ " expected " ^
		   string_of_typ func.typ ^ " in " ^ string_of_expr e))
	    
	    (* A block is correct if each statement is correct and nothing
	       follows any Return statement.  Nested blocks are flattened. *)
      | Block sl -> 
          let rec check_stmt_list = function
              [Return _ as s] -> [check_stmt s]
            | Return _ :: _   -> raise (Failure "nothing may follow a return")
            | Block sl :: ss  -> check_stmt_list (sl @ ss) (* Flatten blocks *)
            | s :: ss         -> check_stmt s :: check_stmt_list ss
            | []              -> []
          in SBlock(check_stmt_list sl)

      (* Local variable declaration with an initializer.
         NOTE(review): [check_stmt] takes a single statement, but this case
         and the two below call it with an extra symbol-table argument. They
         also return a table ([symbols]) rather than an sstmt. [expr_type]
         is compared with bare types, although [expr] returns a pair. *)
      | Variable (typ, varname, e) as call ->
          if StringMap.mem varname symbols then
            (ignore (expr symbols e) ; symbols)
          else
            let expr_type = expr symbols e
            in if expr_type = Void then
              check_stmt (StringMap.add varname expr_type symbols) call
            else if typ = Matrix then
              check_stmt (StringMap.add varname expr_type symbols) call (* Adds the dimmat *)
            else if typ = expr_type then
              check_stmt (StringMap.add varname expr_type symbols) call
            else raise (Failure ("Local var type does not match"))
      (* Matrix declaration: the initializer must be a Matrix and the
         declared element type must be Float. *)
      | MatrixAssign (typ, varname, e) as call ->
          if StringMap.mem varname symbols then
            (ignore (expr symbols e) ; symbols)
          else
            let expr_type = expr symbols e in
            if expr_type = Matrix && typ = Float then
              check_stmt (StringMap.add varname expr_type symbols) call
              (* SMatrixAssign(typ, varname, expr e, hd get_dims expr e, tl get_dims expr e) *)
            else raise (Failure ("Local var type does not match"))
            (* else if typ = Float then
              check_stmt (StringMap.add varname expr_type symbols) call (* Adds the dimmat *)
            else if typ = expr_type then
              check_stmt (StringMap.add varname expr_type symbols) call *)
      (* Assignment to a single matrix element.
         NOTE(review): the condition looks inverted. When [mat_name] is
         already declared, only the value is evaluated. The index and value
         type checks run only when it is NOT declared, after which it is
         added with the value type. *)
      | MatrixAccessAssign (mat_name, idx_row, idx_col, vall) as call ->
          if StringMap.mem mat_name symbols then
          (ignore (expr symbols vall) ; symbols)
            else 
              let expr1_type = expr symbols idx_row 
              in let expr2_type = expr symbols idx_col
              in let expr3_type = expr symbols vall in
              if (expr1_type = Int && expr2_type = Int && expr3_type = Float)
              then check_stmt (StringMap.add mat_name expr3_type symbols) call
              else 
              raise (Failure ("Invalid type for Matrix Access Assignment"))

    in (* body of check_function *)
    { styp = func.typ;
      sfname = func.fname;
      sformals = func.formals;
      sbody = match check_stmt (Block func.body) with
	SBlock(sl) -> sl
      | _ -> raise (Failure ("internal error: block didn't become a block?"))
    }
  in List.map check_function functions