(* COMS W4115, COAL, Eliot Scull, CUID: C000056091 *)

open Types
open Sym
open Printf

(* When true, trace type checking to stdout (each expression kind visited and
   each unification step) and dump the global environment if [check] fails. *)
let debug = false

(* Global environment that persists across the type checking of all modules.
   Builtins registered with [addbuiltin] and the types of top-level function
   definitions checked by [check] live here, so invocations can be unified
   across modules.  This root scope has no parent. *)
let genv = ref {parent=None; context="*GLOBAL*"; scope=[]}

(* Register the builtin [name] with type [t] in the current global
   environment, discarding whatever [addsym] returns. *)
let addbuiltin name t =
  let global_env = !genv in
  ignore (addsym global_env name t)
      

(* Unify the type nodes [s] and [t], destructively binding unbound type
   variables ([Var] cells holding [Tbd]) so that both sides denote the same
   type.  Based on Algorithm 6.16, "Type inference for polymorphic functions",
   page 393, and Algorithm 6.19, "Unification of a pair of nodes in a type
   graph", page 397, of the Golden Dragon Book.  Instead of returning true or
   false, it returns unit and raises [Type_mismatch] on failure: incompatible
   constructors, function types of different arity, or an infinite type (a
   variable unified with a function type that contains that variable). *)
let rec unify s t =

    (* [contains a b] follows the chain of bound type variables starting at
       [a] and reports whether it reaches the physical node [b], i.e. whether
       [a] is merely an alias of [b].  Checked before one type variable is
       bound to another so that binding a variable to itself does not create
       a cycle; this matters for recursion in COAL (such cycles were first seen
       during the semantic check of the gcd function in the recursion tests). *)
    let rec contains a b =
      if debug then printf "does %s contain %s?\n" (string_of_type a) (string_of_type b);
      let res =
        (a == b) ||
        (match a with
           Types.Var({contents=inner_a}) -> contains inner_a b
         | _ -> false)
      in
      if debug then printf "%s.\n" (if res then "yes" else "no");
      res
    in

    (* [occurs v ty] is the occurs check: true when the node [v] appears
       anywhere inside [ty], following bound variables and descending into
       function return and argument types.  Binding [v] to such a [ty] would
       build a cyclic (infinite) type. *)
    let rec occurs v ty =
      (v == ty) ||
      (match ty with
         Types.Var({contents=inner}) -> occurs v inner
       | Types.Func(rt, ats) -> occurs v rt || List.exists (occurs v) ats
       | _ -> false)
    in

    if debug then printf "unify %s %s\n" (string_of_type s) (string_of_type t);
    match s, t with

      (* physically the same type node *)
      _, _ when (s == t) -> ()

      (* same basic type *)
    | Types.Num, Types.Num
    | Types.NumArr, Types.NumArr -> ()

      (* function "op-node": the union of the Func nodes themselves happens
         implicitly through the union of their children *)
    | Types.Func(rt1, at1), Types.Func(rt2, at2) ->
        (* reject differing arity up front, before any argument is unified *)
        if List.length at1 <> List.length at2 then raise (Type_mismatch(s, t));
        List.iter2 unify at1 at2;
        unify rt1 rt2

      (* s is an unbound variable *)
    | Types.Var({contents=Types.Tbd} as var_s), _ ->
        (* union where s becomes t; nothing to do if t is just an alias of s *)
        if not (contains t s) then begin
          if occurs s t then raise (Type_mismatch(s, t));
          var_s := t
        end

      (* t is an unbound variable *)
    | _, Types.Var({contents=Types.Tbd} as var_t) ->
        (* union where t becomes s; nothing to do if s is just an alias of t *)
        if not (contains s t) then begin
          if occurs t s then raise (Type_mismatch(s, t));
          var_t := s
        end

      (* s or t is a bound variable: strip off the Var and unify *)
    | Types.Var({contents=inner_s}), _ -> unify inner_s t
    | _, Types.Var({contents=inner_t}) -> unify s inner_t

      (* cannot unify *)
    | _, _ -> raise (Type_mismatch(s, t))


(* Type check the AST expression [ast] in environment [env].
   Returns a pair [(sexp, typ)]: the SAST node (whose children carry their own
   inferred types) and the inferred type of the whole expression.  Type errors
   surface as exceptions: [Type_mismatch] from [unify], or lookup failures from
   [findtyp] (e.g. [Symbol_not_found]).

   Sub-expressions are checked strictly left to right using sequential
   [let ... in] bindings.  [let ... and ...] would leave the order unspecified,
   and the order matters: checking an assignment adds the assigned name to
   [env], so in [a <- 1; a] the left operand must be checked before the right. *)
let rec check_exp env = function

  (* leaf expression nodes *)
  | Ast.Real(fps, isint) ->
    if debug then printf "%s\n" fps;
    Sast.Real(fps, isint), Types.Num

  | Ast.Imag(fps) ->
    if debug then printf "%s\n" fps;
    Sast.Imag(fps), Types.Num

  | Ast.Id(name) ->
    if debug then printf "Id(%s)\n" name;
    Sast.Id(name), findtyp env name

  | Ast.Negate(e) ->
    if debug then printf "Negate\n";
    let se = check_exp env e in
    unify Types.Num (snd se);
    Sast.Negate(se), Types.Num

  (* a + b *)
  | Ast.Binop(e1, op, e2) ->
    if debug then printf "Binop\n";
    let se1 = check_exp env e1 in
    let se2 = check_exp env e2 in
    unify Types.Num (snd se1);
    unify Types.Num (snd se2);
    Sast.Binop(se1, op, se2), Types.Num

  (* a <- 1 *)
  | Ast.Assign(id, e) ->
    if debug then printf "Assign\n";
    (* check the rhs before looking up / creating the assigned name, so that a
       use of the variable on the rhs before it is first assigned is caught *)
    let se = check_exp env e in
    let idt =
      (* if this symbol hasn't been defined yet, add it to the environment *)
      try findtyp env id
      with Symbol_not_found(_) -> addsym env id (Types.fresh ())
    in
    unify idt (snd se);
    Sast.Assign(id, se), snd se

  (* a[1] *)
  | Ast.GetElem(e1, e2) ->
    if debug then printf "GetElem\n";
    let se1 = check_exp env e1 in
    let se2 = check_exp env e2 in
    unify Types.NumArr (snd se1);
    unify Types.Num (snd se2);
    Sast.GetElem(se1, se2), Types.Num

  (* a[1] <- 5 *)
  | Ast.PutElem(e1, e2, e3) ->
    if debug then printf "PutElem\n";
    let se1 = check_exp env e1 in
    let se2 = check_exp env e2 in
    let se3 = check_exp env e3 in
    unify Types.NumArr (snd se1);
    unify Types.Num (snd se2);
    unify Types.Num (snd se3);
    Sast.PutElem(se1, se2, se3), Types.Num

  (* f{a}: the callee must map Num -> Num and the operand must be NumArr *)
  | Ast.Map(a_callee, e) ->
    if debug then printf "Map\n";
    let (id, s_callee) = handle_lambda env a_callee in
    let idt = findtyp env id in
    let se = check_exp env e in
    unify (Types.Func(Types.Num, [Types.Num])) idt;
    unify Types.NumArr (snd se);
    Sast.Map(s_callee, se), Types.NumArr

  (* a..b\c *)
  | Ast.Range(e1, e2, e3) ->
    if debug then printf "Range\n";
    let se1 = check_exp env e1 in
    let se2 = check_exp env e2 in
    let se3 = check_exp env e3 in
    unify Types.Num (snd se1);
    unify Types.Num (snd se2);
    unify Types.Num (snd se3);
    Sast.Range(se1, se2, se3), Types.NumArr

  (* f(1,2) *)
  | Ast.Invoke(a_callee, actual_elist) ->
    if debug then printf "Invoke\n";
    let (id, s_callee) = handle_lambda env a_callee in
    (* check actual argument expressions *)
    let actual_selist = List.map (check_exp env) actual_elist in
    (* the call site's view of the callee: actual argument types, fresh result *)
    let rt = Types.fresh () in
    let invoke_ft = Types.Func(rt, List.map snd actual_selist) in
    (* retrieve the callee's declared type and unify it with the call site *)
    let def_ft = findtyp env id in
    unify def_ft invoke_ft;
    Sast.Invoke(s_callee, actual_selist), rt

  (* f{0, a}: callee has type (accum, Num) -> accum, initial value is accum *)
  | Ast.Reduce(a_callee, e1, e2) ->
    if debug then printf "Reduce\n";
    let (id, s_callee) = handle_lambda env a_callee in
    let idt = findtyp env id in
    let accum_t = Types.fresh () in
    let se1 = check_exp env e1 in
    let se2 = check_exp env e2 in
    unify (Types.Func(accum_t, [accum_t; Types.Num])) idt;
    unify accum_t (snd se1);
    unify Types.NumArr (snd se2);
    Sast.Reduce(s_callee, se1, se2), accum_t

  (* if a then b else c *)
  | Ast.IfThenElse(e1, e2, e3) ->
    if debug then printf "IfThenElse\n";
    let se1 = check_exp env e1 in
    let se2 = check_exp env e2 in
    let se3 = check_exp env e3 in
    if debug then printf "--unify condition\n";
    unify Types.Num (snd se1);
    if debug then printf "--unify condition done\n";
    if debug then printf "--unify then/else\n";
    unify (snd se2) (snd se3);
    if debug then printf "--unify then/else done\n";
    Sast.IfThenElse(se1, se2, se3), snd se3

  (* a;b;c;d... : the type of a sequence is the type of its last element *)
  | Ast.Sequence(e1, e2) ->
    if debug then printf "Sequence\n";
    let se1 = check_exp env e1 in
    let se2 = check_exp env e2 in
    Sast.Sequence(se1, se2), snd se2

(* Type check a function definition and recurse down its body.
   The function's type, [Func(rt, arg types)] built from fresh type variables,
   is added to [env] (the global table for top-level definitions) before the
   body is checked, so recursive calls and invocations from other modules
   unify against it.  The formal arguments are bound in a new scope, which is
   returned as [flocals] of the SAST definition. *)
and check_func_def env ast_fdef =
  if debug then printf "Func %s\n" ast_fdef.Ast.fname;
  (* one (name, fresh type) pair per formal argument, in declaration order *)
  let n_t_pairs =
    List.rev
      (List.fold_left (fun pairs arg -> (arg, Types.fresh ()) :: pairs)
         [] ast_fdef.Ast.fargs)
  in
  let rt = Types.fresh () in
  let ft = Types.Func(rt, List.map snd n_t_pairs) in

  (* store the function type so invocations can be unified against it *)
  ignore (addsym env ast_fdef.Ast.fname ft);

  (* store arguments in a local symbol table *)
  let env = newscope ast_fdef.Ast.fname env in
  List.iter (fun (n, t) -> ignore (addsym env n t)) n_t_pairs;

  (* type check the body, then unify the return type with the body's type *)
  let sast_body = check_exp env ast_fdef.Ast.fbody in
  unify rt (snd sast_body);

  { Sast.fname = ast_fdef.Ast.fname;
    Sast.fargs = ast_fdef.Ast.fargs;
    Sast.fbody = sast_body;
    Sast.flocals = env }

(* Resolve a callee to the name to look up in the environment and its SAST
   form.  A named callee passes through unchanged; a lambda is first type
   checked as a function definition, which registers its name in [env]. *)
and handle_lambda env = function
  (* not a lambda *)
  | Ast.Named(fn) -> (fn, Sast.Named(fn))

  (* establish definition for lambda *)
  | Ast.Lambda(ast_fdef) ->
    let sast_fdef = check_func_def env ast_fdef in
    (sast_fdef.Sast.fname, Sast.Lambda(sast_fdef));;

(* Entry point for a module: type check every AST function definition and
   return the corresponding list of SAST function definitions.  [ast_fdefs] is
   reversed before checking (presumably the parser accumulates definitions in
   reverse source order; TODO confirm).  All definitions are checked in the
   shared global environment [!genv], so their types persist across modules.
   Any exception raised while checking is re-raised unchanged, after dumping
   the global environment when [debug] is set. *)
let check ast_fdefs =
  try
    List.map (check_func_def !genv) (List.rev ast_fdefs)
  with e ->
    if debug then dump !genv;
    raise e
