open Ast
open Sast

module StringMap = Map.Make(String)

(* Semantic checker.  [check (globals, functions)] checks the global
   bindings and every function declaration of a program and returns them
   in checked (Sast) form as the pair (globals, functions).  Semantic
   errors are reported by raising [Failure msg]. *)
let check (globals, functions) =
  (* Reject bindings declared void, then names bound more than once.
     [kind] only names the kind of binding in the error message; void
     bindings are reported before duplicates. *)
  let check_binds (kind : string) (binds : bind list) =
    List.iter
      (fun (ty, name, _, _) ->
         match ty with
         | Void -> raise (Failure ("illegal void " ^ kind ^ " " ^ name))
         | _ -> ())
      binds;
    (* Once sorted by name, any duplicate sits right next to its twin. *)
    let sorted = List.sort (fun (_, a, _, _) (_, b, _, _) -> compare a b) binds in
    let rec scan = function
      | (_, n1, _, _) :: ((_, n2, _, _) :: _ as rest) ->
          if n1 = n2 then raise (Failure ("duplicate " ^ kind ^ " " ^ n1))
          else scan rest
      | _ -> ()
    in
    scan sorted
  in

  (**** Check global variables ****)

  check_binds "global" globals;

  (* Declarations of the built-in functions, keyed by name.  They have no
     locals and no body; every parameter has unknown size (-1, -1) and no
     default (Noexpr). *)
  let built_in_decls =
    let param ty name = (ty, name, (-1, -1), Noexpr) in
    let matrix_arg = [param Matrix "x"] in
    let signatures = [
      (Void,   "print",        [param String "x"]);
      (Double, "sqrt",         [param Double "x"]);
      (Double, "log",          [param Double "x"]);
      (Matrix, "fill",         [param Int "r"; param Int "c"; param Double "num"]);
      (Matrix, "inv",          matrix_arg);
      (Double, "det",          matrix_arg);
      (Double, "tr",           matrix_arg);
      (Double, "max_eigvalue", matrix_arg);
      (Double, "norm2",        matrix_arg);
      (Int,    "sizeof_row",   matrix_arg);
      (Int,    "sizeof_col",   matrix_arg);
      (Double, "norm1",        matrix_arg);
      (Matrix, "sum_row",      matrix_arg);
      (Matrix, "sum_col",      matrix_arg);
      (Matrix, "mean_row",     matrix_arg);
      (Matrix, "mean_col",     matrix_arg);
    ] in
    List.fold_left
      (fun decls (ret_ty, fname, params) ->
         StringMap.add fname
           { data_type = ret_ty; function_name = fname; arguments = params;
             local_vars = []; body = [] }
           decls)
      StringMap.empty signatures
  in

  (* Insert [fd] into the function table [map].  A user function may not
     reuse a built-in's name, nor the name of a function already added. *)
  let add_func map fd =
    let name = fd.function_name in
    if StringMap.mem name built_in_decls then
      raise (Failure ("function " ^ name ^ " may not be defined"))
    else if StringMap.mem name map then
      raise (Failure ("duplicate function " ^ name))
    else
      StringMap.add name fd map
  in

  (* Every callable function: the built-ins plus all user functions. *)
  let function_decls = List.fold_left add_func built_in_decls functions in

  (* Look a function up by name; an unknown name is a semantic error. *)
  let find_func s =
    if StringMap.mem s function_decls then StringMap.find s function_decls
    else raise (Failure ("unrecognized function " ^ s))
  in

  (* Every program must define "main". *)
  ignore (find_func "main");
  (* Int, Double and Bool are treated as numeric types. *)
  let is_number = function
    | Int | Double | Bool -> true
    | _ -> false
  in
  (* The wider of two numeric types, ordered Double > Int > Bool. *)
  let higher_type t1 t2 =
    match t1, t2 with
    | Double, _ | _, Double -> Double
    | Int, _ | _, Int -> Int
    | _ -> Bool
  in
  (* Check that an rvalue of (type, size) [rvaluet] may be assigned to an
     lvalue of (type, size) [lvaluet] and return the resulting (type, size).
     Sizes are (rows, cols), where -1 means "unknown".
     - same non-matrix type: the rvalue's (type, size)
     - matrices of identical size: the lvalue's
     - matrices whose dimensions agree wherever both are known: Matrix with
       each dimension taken from whichever side knows it
     - two numeric types (Int/Double/Bool): the wider type, rvalue's size
     - otherwise, when [overload] is true (print's argument): the rvalue's
     Raises [Failure err] when none of these apply. *)
  let check_assign lvaluet rvaluet err overload =
    let (t1, s1) = lvaluet and (t2, s2) = rvaluet in
    let (r1, c1) = s1 and (r2, c2) = s2 in
    (* Two dimensions are compatible if equal or if either is unknown. *)
    let dim_ok a b = a = b || a = -1 || b = -1 in
    if t1 = t2 && t1 <> Matrix then rvaluet
    else if t1 = t2 && t1 = Matrix && s1 = s2 then lvaluet
    else if t1 = t2 && t1 = Matrix && dim_ok r1 r2 && dim_ok c1 c2 then
      (* Compatible dimensions are equal or -1, so [max] picks the known
         value for both rows and columns. *)
      (t1, (max r1 r2, max c1 c2))
    else if is_number t1 && is_number t2 then (higher_type t1 t2, s2)
    else if overload then rvaluet
    else raise (Failure err)
  in

  (* The (type, size) recorded for variable [s] in [symbols]; an unknown
     name is a semantic error. *)
  let type_of_identifier symbols s =
    if StringMap.mem s symbols then StringMap.find s symbols
    else raise (Failure ("undeclared identifier " ^ s))
  in
  
  (* Integer value of an integer-literal sexpr.  Used for fill's row and
     column arguments so the call's result size can be computed here; any
     non-literal argument is rejected. *)
  let get_num_from_expr e =
    match e with
      SIntLit l -> l
    | _ -> raise (Failure("IntLit expected"))
  in
  (* [get_builtin_size fname args] is the result size of a call whose size
     can be read off its checked arguments [args] (a list of
     ((typ, (rows, cols)), sexpr)).  Returns (true, size) for the names
     matched below and (false, (-1, -1)) for every other name.
     NOTE(review): "size" is not among built_in_decls, so that arm only
     applies to a user-defined function named "size" -- confirm intent. *)
  let get_builtin_size fname args =
    match fname with
      "size" -> (true, (1, 2))
    | "inv" | "det" | "tr" ->
        (* The argument's size; for det/tr the call itself has type Double. *)
        let ((_, s), _) = List.hd args in
        (true, s)
    | "sum_col" | "mean_col" ->
        let ((_, (_, c)), _) = List.hd args in
        (true, (1, c))
    | "sum_row" | "mean_row" ->
        let ((_, (r, _)), _) = List.hd args in
        (true, (r, 1))
    | "fill" ->
        let (_, e1) = List.hd args and (_, e2) = List.hd (List.tl args) in
        (true, (get_num_from_expr e1, get_num_from_expr e2))
    | _ -> (false, (-1, -1))
  in
  (* Names of the user-defined matrix functions whose result size is being
     inferred right now (see the Call case).  A call that reaches one of
     them again gets the unknown size (-1, -1) instead of re-entering
     check_function without end. *)
  let size_inference_stack = ref [] in
  (* Reject a literal 1D index that is not below the known column count [c]
     of the indexed matrix.  [ex] is the whole access, used in the message.
     Non-literal indices and unknown sizes (-1) are not checked. *)
  let check_1d_bound ex c idx =
    match idx with
    | SIntLit l when c <> -1 && l >= c ->
        raise (Failure ("expression " ^ string_of_expr ex ^
                        " out of boundary, matrix size: (" ^ string_of_int c ^ ")"))
    | _ -> ()
  in
  (* 2D counterpart of [check_1d_bound] for a matrix of size [(r, c)]:
     the row index is bounded by [r], the column index by [c]. *)
  let check_2d_bounds ex (r, c) row col =
    let out_of_bounds () =
      raise (Failure ("expression " ^ string_of_expr ex ^
                      " out of boundary, matrix size: (" ^ string_of_int r ^
                      ", " ^ string_of_int c ^ ")"))
    in
    (match row with
     | SIntLit l when r <> -1 && l >= r -> out_of_bounds ()
     | _ -> ());
    (match col with
     | SIntLit l when c <> -1 && l >= c -> out_of_bounds ()
     | _ -> ())
  in
  (* Return a semantically-checked expression, i.e., with a type.
     Results have the shape ((typ, (rows, cols)), sexpr), where a size
     component of -1 means "unknown". *)
  let rec expr symbols = function
      IntLit l -> ((Int, (-1, -1)), SIntLit l)
    | DoubleLit l -> ((Double, (-1, -1)), SDoubleLit l)
    | BoolLit l -> ((Bool, (-1, -1)), SBoolLit l)
    | StrLit l -> ((String, (-1, -1)), SStrLit l)
    | MatrixLit l ->
        (* Size is the number of rows and the length of the first row.
           NOTE(review): an empty literal would make l.(0) raise
           Invalid_argument, and ragged rows are not rejected here --
           confirm the parser rules both out. *)
        let row_size = Array.length l in
        let col_size = Array.length l.(0) in
        ((Matrix, (row_size, col_size)), SMatrixLit l)
    | Noexpr -> ((Void, (-1, -1)), SNoexpr)
    | Id s -> (type_of_identifier symbols s, SId s)
    | Matrix1DElement(m, i) as ex ->
        let (t, s) = type_of_identifier symbols m in
        let ((t2, s2), e2') = expr symbols i in
        if t2 <> Int then
          raise (Failure ("illegal 1D matrix index type " ^ string_of_typ t2 ^
                          " for matrix " ^ m))
        else if t <> Matrix then
          raise (Failure ("illegal 1D matrix operation, matrix type expected, get "
                          ^ string_of_typ t))
        else begin
          (* A 1D index is bounded by the column count. *)
          check_1d_bound ex (snd s) e2';
          ((Double, (-1, -1)), SMatrix1DElement(m, ((t2, s2), e2')))
        end
    | Matrix2DElement(m, i1, i2) as ex ->
        let (t, s) = type_of_identifier symbols m in
        let ((t1, s1), e1') = expr symbols i1 in
        let ((t2, s2), e2') = expr symbols i2 in
        if t1 <> Int || t2 <> Int then
          raise (Failure ("illegal 2D matrix index type [" ^ string_of_typ t1 ^ ", " ^
                          string_of_typ t2 ^ "] for matrix " ^ m))
        else if t <> Matrix then
          raise (Failure ("illegal 2D matrix operation, matrix type expected, get " ^
                          string_of_typ t))
        else begin
          check_2d_bounds ex s e1' e2';
          ((Double, (-1, -1)), SMatrix2DElement(m, ((t1, s1), e1'), ((t2, s2), e2')))
        end
    | Matrix1DModify(m, i, e) as ex ->
        let (t, s) = type_of_identifier symbols m in
        let ((t1, s1), e1') = expr symbols i in
        let ((t2, s2), e2') = expr symbols e in
        if t1 <> Int then
          raise (Failure ("illegal 1D matrix index type " ^ string_of_typ t1 ^
                          " for matrix " ^ m))
        else if t <> Matrix then
          raise (Failure ("illegal 1D matrix operation, matrix type expected, get "
                          ^ string_of_typ t))
        else if not (is_number t2) then
          raise (Failure ("illegal 1D matrix assignment, number expected, get " ^
                          string_of_typ t2))
        else begin
          check_1d_bound ex (snd s) e1';
          ((Double, (-1, -1)), SMatrix1DModify(m, ((t1, s1), e1'), ((t2, s2), e2')))
        end
    | Matrix2DModify(m, (i1, i2), e) as ex ->
        let (t, s) = type_of_identifier symbols m in
        let ((t1, s1), e1') = expr symbols i1 in
        let ((t2, s2), e2') = expr symbols i2 in
        let ((t3, s3), e3') = expr symbols e in
        if t1 <> Int || t2 <> Int then
          raise (Failure ("illegal 2D matrix index type [" ^ string_of_typ t1 ^ ", " ^
                          string_of_typ t2 ^ "] for matrix " ^ m))
        else if t <> Matrix then
          raise (Failure ("illegal 2D matrix operation, matrix type expected, get " ^
                          string_of_typ t))
        else if not (is_number t3) then
          raise (Failure ("illegal 2D matrix assignment, number expected, get " ^
                          string_of_typ t3))
        else begin
          check_2d_bounds ex s e1' e2';
          ((Double, (-1, -1)),
           SMatrix2DModify(m, (((t1, s1), e1'), ((t2, s2), e2')), ((t3, s3), e3')))
        end
    | Range(e1, e2) as ex ->
        let ((t1, s1), e1') = expr symbols e1
        and ((t2, s2), e2') = expr symbols e2 in
        if t1 <> Int || t2 <> Int then
          raise (Failure ("illegal Range type, int : int expected, get "
                          ^ string_of_typ t1 ^ " : " ^ string_of_typ t2 ^ " in " ^
                          string_of_expr e1 ^ " : " ^ string_of_expr e2))
        else begin
          (* A range is statically rejected only when both bounds are
             literals; a non-literal bound is unknown until run time. *)
          (match e1', e2' with
           | SIntLit a1, SIntLit a2 when a1 >= a2 ->
               raise (Failure ("Invalid argument in " ^ string_of_expr ex))
           | _ -> ());
          ((Int, (-1, -1)), SRange(((t1, s1), e1'), ((t2, s2), e2')))
        end
    | Assign(var, e) as ex ->
        let (lt, s1) = type_of_identifier symbols var
        and ((rt, s2), e') = expr symbols e in
        let err = "illegal assignment " ^ string_of_typ lt ^ string_of_size s1 ^ " = " ^
                  string_of_typ rt ^ string_of_size s2 ^ " in " ^ string_of_expr ex in
        let check_res = check_assign (lt, s1) (rt, s2) err false in
        (* NOTE: [symbols] is an immutable map passed by value, so the size
           computed here is not recorded for later statements. *)
        (check_res, SAssign(var, ((rt, s2), e')))
    | Unop(op, e) as ex ->
        let ((t, s), e') = expr symbols e in
        let (r, c) = s in
        let ty = match op with
            Neg when t = Int || t = Double -> (t, s)
          | Not when t = Bool -> (Bool, s)
          | Abs when t = Int || t = Double || t = Matrix -> (t, s)
          | Transpose when t = Matrix -> (t, (c, r))
          | _ -> raise (Failure ("illegal unary operator: " ^
                                 string_of_uop op ^ string_of_typ t ^
                                 " in " ^ string_of_expr ex))
        in (ty, SUnop(op, ((t, s), e')))
    | Binop(e1, op, e2) as e ->
        let ((t1, s1), e1') = expr symbols e1
        and ((t2, s2), e2') = expr symbols e2 in
        let (r1, c1) = s1 and (r2, c2) = s2 in
        let same = t1 = t2 in
        (* Element-wise operands: both matrices of equal size, or either
           size entirely unknown (then the result size is unknown). *)
        let sameMsize = t1 = Matrix && t2 = Matrix && (s1 = s2 || s1 = (-1, -1) || s2 = (-1, -1)) in
        let sameRTsize = if s1 = (-1, -1) || s2 = (-1, -1) then (-1, -1) else s1 in
        (* Matrix product: inner dimensions agree (c1 = r2), or either size
           entirely unknown; the result is r1 x c2 when both are known. *)
        let matchMsize = t1 = Matrix && t2 = Matrix && (c1 = r2 || s1 = (-1, -1) || s2 = (-1, -1)) in
        let matchRTsize = if s1 = (-1, -1) || s2 = (-1, -1) then (-1, -1) else (r1, c2) in
        (* Determine expression type based on operator and operand types *)
        let ty = match op with
            Add | Sub | Mult | Div   when is_number t1 && is_number t2 -> ((higher_type t1 t2), s1)
          | Add | Sub                when same && t1 = Matrix && sameMsize -> (Matrix, sameRTsize)
          | Mult                     when same && t1 = Matrix && matchMsize -> (Matrix, matchRTsize)
          | Add | Sub | Mult         when (t1 = Double || t1 = Int) && t2 = Matrix -> (Matrix, s2)
          | Add | Sub | Mult | Div   when t1 = Matrix && (t2 = Double || t2 = Int) -> (Matrix, s1)
          | Pow                      when (t1 = Int || t1 = Double) && (t2 = Int || t2 = Double) -> (higher_type t1 t2, s1)
          | Dotmul | Dotdiv | Dotpow when t1 = Matrix && t2 = Matrix && sameMsize -> (Matrix, sameRTsize)
          | Equal | Neq              when same -> (Bool, s1)
          | Less | Leq | Greater | Geq
                                     when same && (t1 = Int || t1 = Double) -> (Bool, s1)
          | And | Or                 when same && t1 = Bool -> (Bool, s1)
          | _ -> raise (
                   Failure ("illegal binary operator: " ^
                      string_of_typ t1 ^ " of size (" ^ string_of_int r1 ^ "," ^ string_of_int c1 ^ ") " ^ string_of_op op ^ " " ^
                      string_of_typ t2 ^ " of size (" ^ string_of_int r2 ^ "," ^ string_of_int c2 ^ ") in " ^ string_of_expr e))
        in (ty, SBinop(((t1, s1), e1'), op, ((t2, s2), e2')))
    | MatrixOp(e1, op, e2) as e ->
        (* Concatenation of two matrix variables: Comma requires equal row
           counts and adds the column counts; Semi requires equal column
           counts and adds the row counts.  Unknown (-1) stays unknown. *)
        let (t1, (r1, c1)) = type_of_identifier symbols e1
        and (t2, (r2, c2)) = type_of_identifier symbols e2 in
        let getnum n1 n2 = if n1 = -1 || n2 = -1 then -1 else n1 in
        let getsum n1 n2 = if n1 = -1 || n2 = -1 then -1 else n1 + n2 in
        let ty = match op with
            Comma when t1 = Matrix && t2 = Matrix && (r1 = r2 || r1 = -1 || r2 = -1) -> (Matrix, ((getnum r1 r2), (getsum c1 c2)))
          | Semi  when t1 = Matrix && t2 = Matrix && (c1 = c2 || c1 = -1 || c2 = -1) -> (Matrix, ((getsum r1 r2), (getnum c1 c2)))
          | _ -> raise (Failure ("illegal Matrix Concat operator: " ^
                          string_of_typ t1 ^ " of size (" ^ string_of_int r1 ^ "," ^ string_of_int c1 ^ ") " ^ string_of_op op ^ " " ^
                          string_of_typ t2 ^ " of size (" ^ string_of_int r2 ^ "," ^ string_of_int c2 ^ ") in " ^ string_of_expr e))
        in (ty, SMatrixOp(e1, op, e2))
    | Call(fname, args) as call ->
        let fd = find_func fname in
        let param_length = List.length fd.arguments in
        if List.length args <> param_length then
          raise (Failure ("expecting " ^ string_of_int param_length ^
                          " arguments in " ^ string_of_expr call))
        else
          (* Each actual is checked against its formal as an assignment
             would be; print's argument is accepted whatever its type. *)
          let check_call (ft, _, (row_size, col_size), _) e =
            let ((et, es), e') = expr symbols e in
            let err = "illegal argument found " ^ string_of_typ et ^ " " ^
                      string_of_expr e ^ " expected " ^ string_of_typ ft ^
                      " in " ^ string_of_expr call in
            let overload = (fname = "print") in
            (check_assign (ft, (row_size, col_size)) (et, es) err overload, e')
          in
          let args' = List.map2 check_call fd.arguments args in
          let (is_builtin, builtin_size) = get_builtin_size fname args' in
          let ret_size =
            if is_builtin then builtin_size
            else if fd.data_type <> Matrix then (-1, -1)
            else if List.mem fname !size_inference_stack then
              (* Recursive call reached while this function's own result
                 size is being inferred: that size is not known yet. *)
              (-1, -1)
            else begin
              (* Re-check the callee with every formal given the type and
                 size of the matching actual; the size of its final return
                 statement is the size of this call. *)
              let specialize (_, name, _, default) ((t, size), _) =
                (t, name, size, default) in
              let specialized =
                { fd with arguments = List.map2 specialize fd.arguments args' } in
              let saved = !size_inference_stack in
              size_inference_stack := fname :: saved;
              let sf =
                try check_function true specialized
                with failure -> size_inference_stack := saved; raise failure
              in
              size_inference_stack := saved;
              sf.ssize
            end
          in
          ((fd.data_type, ret_size), SCall(fname, args'))
  and
  (* Check [e] and require it to have type Bool; returns the checked
     expression unchanged. *)
  check_bool_expr symbols e =
    match expr symbols e with
    | ((Bool, _), _) as checked -> checked
    | _ -> raise (Failure ("expected Boolean expression in " ^ string_of_expr e))
  and
  (* Check the initializer of a binding [(typ, name, size, e)].  Without an
     initializer (Noexpr) the checked initializer is SNoexpr of type Void;
     otherwise [e] is checked as the assignment [name = e] would be and its
     checked right-hand side is kept.  The declared type and size are
     returned unchanged. *)
  get_assign_expr symbols (typ, name, size, e) =
    let init =
      match e with
      | Noexpr -> ((Void, (-1, -1)), SNoexpr)
      | _ ->
          (match snd (expr symbols (Assign (name, e))) with
           | SAssign (_, rhs) -> rhs
           | _ -> raise (Failure ("internal error: declare and assign")))
    in
    (typ, name, size, init)
  and
  (**** Check functions ****)
  (* Semantically check the function [func].
     With [infer_type] = false (the normal pass) the body is checked and
     returned in [sbody]; a non-void function whose last top-level statement
     is not a return is rejected; [ssize] is (-1, -1).
     With [infer_type] = true (used by Call to learn a matrix function's
     result size) [sbody] is left empty and, for a Matrix function, [ssize]
     is the size of the expression in the body's final return statement. *)
  check_function infer_type func =
    (* Make sure no formals or locals are void or duplicates *)
    check_binds "arguments" func.arguments;
    check_binds "local_vars" func.local_vars;

    (* Local symbol table of (type, size) by name: globals, then formals,
       then locals, so a later binding shadows an earlier one. *)
    let symbols =
      List.fold_left
        (fun m (ty, name, (row_size, col_size), _) ->
           StringMap.add name (ty, (row_size, col_size)) m)
        StringMap.empty (globals @ func.arguments @ func.local_vars)
    in
    (* Return a semantically-checked statement, i.e. one containing sexprs *)
    let rec check_stmt symbols = function
        Expr e -> SExpr (expr symbols e)
      | If(p, b1, b2) ->
          SIf(check_bool_expr symbols p, check_stmt symbols b1, check_stmt symbols b2)
      | For(e1, e2, e3, st) ->
          SFor(expr symbols e1, check_bool_expr symbols e2, expr symbols e3,
               check_stmt symbols st)
      | ForRange(var, e, st) ->
          (* NOTE(review): [var] is passed through without being looked up
             or typed. *)
          SForRange(var, expr symbols e, check_stmt symbols st)
      | While(p, s) -> SWhile(check_bool_expr symbols p, check_stmt symbols s)
      | Return e ->
          let ((t, s), e') = expr symbols e in
          if t = func.data_type then SReturn ((t, s), e')
          else raise (
              Failure ("return gives " ^ string_of_typ t ^ " expected " ^
                       string_of_typ func.data_type ^ " in " ^ string_of_expr e))
      (* A block is correct if each statement is correct and nothing
         follows any Return statement.  Nested blocks are flattened. *)
      | Block sl ->
          let rec check_stmt_list symbols = function
              [Return _ as s] -> [check_stmt symbols s]
            | Return _ :: _   -> raise (Failure "nothing may follow a return")
            | Block sl :: ss  -> check_stmt_list symbols (sl @ ss) (* Flatten blocks *)
            | s :: ss         -> check_stmt symbols s :: check_stmt_list symbols ss
            | []              -> []
          in SBlock(check_stmt_list symbols sl)
      | Break e ->
          let ((t, s), e') = expr symbols e in
          (match e' with
           | SNoexpr -> SBreak((t, s), e')
           | _ -> raise (Failure("break stmt should include Noexpr, " ^
                                 string_of_expr e ^ " found")))
      | Continue e ->
          let ((t, s), e') = expr symbols e in
          (match e' with
           | SNoexpr -> SContinue((t, s), e')
           | _ -> raise (Failure("continue stmt should include Noexpr, " ^
                                 string_of_expr e ^ " found")))
    in
    (* Check the whole body as one block and return its statements. *)
    let checked_body () =
      match check_stmt symbols (Block func.body) with
      | SBlock sl -> sl
      | _ -> raise (Failure ("internal error: block didn't become a block?"))
    in
    { styp = func.data_type;
      sfname = func.function_name;
      sargs = List.map (get_assign_expr symbols) func.arguments;
      slocals = List.map (get_assign_expr symbols) func.local_vars;
      sbody =
        if infer_type then []
        else begin
          (* True for an empty body or one not ending in a return. *)
          let no_return =
            match List.rev func.body with
            | Return _ :: _ -> false
            | _ -> true
          in
          if func.data_type <> Void && no_return then
            raise (Failure ("function " ^ func.function_name ^ " has no return"))
          else checked_body ()
        end;
      ssize =
        if infer_type && func.data_type = Matrix then
          (* An empty body also lands in the error case, instead of
             escaping as List.hd's uninformative Failure "hd". *)
          (match List.rev (checked_body ()) with
           | SReturn ((_, s), _) :: _ -> s
           | _ -> raise (Failure ("last stmt in func is not return in func " ^
                                  func.function_name)))
        else (-1, -1);
    }
  in
  (* Global initializers are checked against a symbol table that holds
     only the globals themselves. *)
  let global_symbols =
    List.fold_left
      (fun table (ty, name, size, _) -> StringMap.add name (ty, size) table)
      StringMap.empty globals
  in
  (List.map (get_assign_expr global_symbols) globals,
   List.map (check_function false) functions)
