(* MatCV Semantic Checker *)

open Ast

(*module ReservedWords = Set.Make(String)*)

(* Reserved words of the MatCV language. *)
let keywords =
  [ "row"; "col"; "ele"; "pixel";
    "var"; "const";
    "if"; "else"; "for"; "break"; "continue"; "exit"; "while"; "return";
    "function";
    "true"; "false" ]

(* Names of functions provided by the language itself. *)
let builtInFunctions = [ "print"; "main" ]

(* State of the fresh-name generator [generateTypeForAnnotation]: the
 * characters of the most recently generated annotation name. *)
let code : char list ref = ref []

(* Number of elements in [lst]. Kept as a named helper for existing callers;
 * it is exactly [List.length]. *)
let getListSize lst = List.length lst

(* Produce a fresh type variable [Annotation name] for type inference.
 * The name comes from the global counter [code]. The first character
 * advances on every call: "a", "b", ..., "z". When it wraps from 'z' back to
 * 'a', the remaining characters are advanced by the same routine applied to
 * their reversal. The sequence therefore runs "a" .. "z", "aa", "ba", ...,
 * "za", "ab", ... and grows by one character after an all-'z' name.
 * [code] is updated on every call, so successive calls yield different names. *)
let generateTypeForAnnotation () =
  let rec next = function
    | [] -> [ 'a' ]
    | 'z' :: rest -> 'a' :: List.rev (next (List.rev rest))
    | ch :: rest -> Char.chr (Char.code ch + 1) :: rest
  in
  code := next !code;
  Annotation (String.concat "" (List.map Char.escaped !code))


(* Type carried by an annotated expression; it is always the last
 * component of the constructor. *)
let typeOfAexpr aexpr =
  match aexpr with
  | ALiteral (_, t) | ABoolLit (_, t) | AId (_, t) -> t
  | AMatPlus (_, _, _, _, t) | AMatMinus (_, _, _, _, t) -> t
  | AUnboundedAccessRead (_, _, _, t) -> t
  | AUnboundedAccessWrite (_, _, _, _, t) -> t
  | AMatAccess (_, _, _, _, t) -> t
  | ABinaryOp (_, _, _, t) | ACall (_, _, _, t) -> t
  | AUnop (_, _, t) -> t
  | ANoexpr t -> t


(* Type carried by an annotated declaration (its last component). *)
let typeOfAvarDecl aVarDecl =
  match aVarDecl with
  | ANodecl t -> t
  | AExprAssign (_, _, _, t) -> t
  | ADimAssign (_, _, _, _, t) -> t
  | AMatrix (_, _, _, _, _, t) | AMatElementAssign (_, _, _, _, _, t) -> t



(* Type carried by an annotated statement (its last component). *)
let typeOfStatement aStatement =
  match aStatement with
  | ABlock (_, t) | AExpr (_, t) | AVarDecl (_, t) -> t
  | AReturn (_, _, t) | AWhile (_, _, t) -> t
  | AIf (_, _, _, t) -> t
  | AFor (_, _, _, _, t) -> t
  | AForEachLoop (_, _, _, _, _, _, t) -> t
  | AExit t | ABreak t | AContinue t -> t

(* Old Code *)
(* Select the [VarDecl] statements from [statements]. The result lists them
 * in reverse order of appearance. *)
let getVariableDeclFromStatement statements =
  List.fold_left
    (fun found statement ->
      match statement with
      | VarDecl _ -> statement :: found
      | _ -> found)
    [] statements


(* Handle errors *)
(* Print "\nError: <message>" on standard output, then terminate the process
 * with exit status 1. It never returns, so its result type is polymorphic
 * and it can stand in for a value of any type. *)
let printError message =
  Printf.printf "\nError: %s" message;
  exit 1

(* Print "\nWarning: <message>" on standard output; execution continues. *)
let printWarning message = Printf.printf "\nWarning: %s" message

(* Report that [name] is a reserved word, then exit via [printError]. *)
let printReservedError name = printError (Printf.sprintf "Name: %s is reserved." name)

 (* Report a repeated definition of function [m], then exit. When [typ] is
  * [Keyword], the reserved-name error is reported instead. *)
 let printDuplicateFunctionError typ m =
   if typ = Keyword then printReservedError m
   else printError ("Multiple definitions of function: " ^ m)

(* Report a reference to a name found in no symbol table, then exit. *)
let printUndefinedVariableError m =
  printError (Printf.sprintf "Undefined variable: %s" m)

(* Report bad matrix dimensions (e.g. rows of differing length), then exit. *)
let printInvalidDimensionsError m =
  printError (Printf.sprintf "Invalid dimensions were specified for Matrix: %s" m)

(* Report assigning type [t2] to [id], which previously had type [t1];
 * then exit. *)
let printTypeMismatchError id t1 t2 =
  printError
    (Printf.sprintf "Type Mismatch: Cannot assign type %s to %s. Previously had type: %s"
       (string_of_builtInType t2) id (string_of_builtInType t1))


(* Annotate an expression with types for inference.
 * - Integer and boolean literals get Int / Bool.
 * - Identifiers get the type recorded for them, searching [localSymbolTable]
 *   before [globalSymbolTable]. An unknown name is reported by
 *   [printUndefinedVariableError], which exits.
 * - Every other compound expression gets a fresh type variable from
 *   [generateTypeForAnnotation], to be solved later by unification.
 * Lookups and sub-expressions are processed left to right, and always
 * before the node's own fresh type is generated. The order in which fresh
 * names are handed out depends on this, so the sequencing below uses
 * explicit [let]s rather than relying on argument evaluation order. *)
let rec annotateExpression globalSymbolTable localSymbolTable =
  (* Type of a name already in scope, innermost scope first. *)
  let lookupIdType id =
    if Hashtbl.mem localSymbolTable id then Hashtbl.find localSymbolTable id
    else if Hashtbl.mem globalSymbolTable id then
      Hashtbl.find globalSymbolTable id
    else
      let _ = printUndefinedVariableError id in
      Void
  in
  let annotate expr = annotateExpression globalSymbolTable localSymbolTable expr in
  function
  | Literal l -> ALiteral (l, Int)
  | BoolLit b -> ABoolLit (b, Bool)
  | Id id ->
      let idType = lookupIdType id in
      AId (id, idType)
  | MatPlus (id1, id2) ->
      let idType1 = lookupIdType id1 in
      let idType2 = lookupIdType id2 in
      AMatPlus (id1, idType1, id2, idType2, generateTypeForAnnotation ())
  | MatMinus (id1, id2) ->
      let idType1 = lookupIdType id1 in
      let idType2 = lookupIdType id2 in
      AMatMinus (id1, idType1, id2, idType2, generateTypeForAnnotation ())
  | UnboundedAccessRead (id, expr) ->
      let idType = lookupIdType id in
      let aexpr = annotate expr in
      AUnboundedAccessRead (id, idType, aexpr, generateTypeForAnnotation ())
  | UnboundedAccessWrite (id, expr1, expr2) ->
      let idType = lookupIdType id in
      let aexpr1 = annotate expr1 in
      let aexpr2 = annotate expr2 in
      AUnboundedAccessWrite (id, idType, aexpr1, aexpr2, generateTypeForAnnotation ())
  | MatAccess (id, exprList) ->
      let idType = lookupIdType id in
      let aExprList = List.map annotate exprList in
      (* The number of subscripts is the dimensionality being accessed. *)
      let nDimensions = List.length exprList in
      AMatAccess (id, idType, aExprList, nDimensions, generateTypeForAnnotation ())
  | BinaryOp (expr1, op, expr2) ->
      let aexpr1 = annotate expr1 in
      let aexpr2 = annotate expr2 in
      ABinaryOp (aexpr1, op, aexpr2, generateTypeForAnnotation ())
  | Unop (uop, expr) ->
      let aexpr = annotate expr in
      AUnop (uop, aexpr, generateTypeForAnnotation ())
  | Call (id, exprList) ->
      let idType = lookupIdType id in
      let aExprList = List.map annotate exprList in
      ACall (id, idType, aExprList, generateTypeForAnnotation ())
  | Noexpr -> ANoexpr Void



(* Annotate a declaration or assignment, registering new names.
 * [Matrix], [ExprAssign] and [DimAssign] may introduce a name: if it is in
 * neither table, it is bound in [localSymbolTable] to a fresh type variable.
 * [MatElementAssign] writes into an existing matrix, so an unknown name
 * there is an undefined-variable error. In every case, a name whose recorded
 * type is [Keyword] is rejected as reserved. Errors go through the
 * print*Error helpers, which exit. *)
let annotateVarDecl globalSymbolTable localSymbolTable =
  (* Find [id] in local then global scope; declare it locally if absent. *)
  let lookupOrDeclare id =
    if Hashtbl.mem localSymbolTable id then Hashtbl.find localSymbolTable id
    else if Hashtbl.mem globalSymbolTable id then
      Hashtbl.find globalSymbolTable id
    else begin
      let idt = generateTypeForAnnotation () in
      Hashtbl.add localSymbolTable id idt;
      idt
    end
  in
  (* Find [id] in local then global scope; it must already exist. *)
  let lookupExisting id =
    if Hashtbl.mem localSymbolTable id then Hashtbl.find localSymbolTable id
    else if Hashtbl.mem globalSymbolTable id then
      Hashtbl.find globalSymbolTable id
    else
      let _ = printUndefinedVariableError id in
      Empty
  in
  let rejectReserved id idType =
    if idType = Keyword then printReservedError id
  in
  let annotate expr = annotateExpression globalSymbolTable localSymbolTable expr in
  function
  | Nodecl -> ANodecl Void
  | Matrix (id, exprListList) ->
      let idType = lookupOrDeclare id in
      rejectReserved id idType;
      let nRows = List.length exprListList in
      let nCols =
        match exprListList with
        | firstRow :: _ -> List.length firstRow
        | [] -> 0
      in
      (* Every row must have as many elements as the first one. *)
      List.iter
        (fun row -> if List.length row <> nCols then printInvalidDimensionsError id)
        exprListList;
      let aExprListList = List.map (List.map annotate) exprListList in
      (* Row and column counts are stored with the matrix for code generation. *)
      AMatrix (id, idType, aExprListList, nRows, nCols, Void)
  | ExprAssign (id, expr) ->
      let idType = lookupOrDeclare id in
      rejectReserved id idType;
      let aExpr = annotate expr in
      AExprAssign (id, idType, aExpr, Void)
  | DimAssign (id, exprList) ->
      let idType = lookupOrDeclare id in
      rejectReserved id idType;
      let aExprList = List.map annotate exprList in
      ADimAssign (id, idType, aExprList, List.length exprList, Void)
  | MatElementAssign (id, exprList, expr) ->
      let idType = lookupExisting id in
      rejectReserved id idType;
      let aExprList = List.map annotate exprList in
      let aExpr = annotate expr in
      AMatElementAssign (id, idType, aExprList, aExpr, List.length exprList, Void)


(* Build the symbol table seen by a nested scope. It contains every name
 * visible from [localSymbolTable] or [globalSymbolTable], with local bindings
 * shadowing global ones. The result is a fresh table; neither input is
 * modified. For each name, only the binding [Hashtbl.find] would return is
 * copied, so the merged table answers lookups exactly like "search local,
 * then global". *)
let mergeSymbolTables globalSymbolTable localSymbolTable =
  let mergedSymbolTable = Hashtbl.create 100 in
  (* [Hashtbl.iter] passes the most recent binding of a key first, which is
   * the one [Hashtbl.find] returns. Skipping keys already copied keeps
   * exactly that binding. Previously the local table was copied with plain
   * [add], which left its *oldest* binding on top. *)
  let copyVisible table =
    Hashtbl.iter
      (fun key value ->
        if not (Hashtbl.mem mergedSymbolTable key) then
          Hashtbl.add mergedSymbolTable key value)
      table
  in
  copyVisible localSymbolTable;
  copyVisible globalSymbolTable;
  mergedSymbolTable




(* Annotate a statement, threading scopes through nested statements.
 * [inFunction] is the name of the enclosing function ("main" for top-level
 * code); [return] looks up its signature in [globalSymbolTable].
 * [isControlFlowAllowed] is true inside loop bodies, the only place where
 * [break] and [continue] are accepted.
 * Blocks, if-branches and for-each loops open a new scope: the current
 * tables are merged into a new "global" table, and nested declarations go
 * into a fresh local table. Each if-branch gets its own local table, so a
 * name declared in the then-branch is not visible in the else-branch. *)
let rec annotateStatement globalSymbolTable localSymbolTable ?isControlFlowAllowed:(isCFA = false) inFunction = function
  | Block statementList ->
      let newGlobalSymbolTable = mergeSymbolTables globalSymbolTable localSymbolTable in
      let newLocalSymbolTable = Hashtbl.create 100 in
      let aStatementList =
        List.map
          (fun statement ->
            annotateStatement newGlobalSymbolTable newLocalSymbolTable
              ~isControlFlowAllowed:isCFA inFunction statement)
          statementList
      in
      ABlock (aStatementList, Void)
  | Expr expr ->
      let aExpr = annotateExpression globalSymbolTable localSymbolTable expr in
      AExpr (aExpr, Void)
  | VarDecl varDecl ->
      let aVarDecl = annotateVarDecl globalSymbolTable localSymbolTable varDecl in
      AVarDecl (aVarDecl, Void)
  | Return expr ->
      if inFunction = "main" then printError "Cannot use return outside functions.";
      let aExpr = annotateExpression globalSymbolTable localSymbolTable expr in
      let funcType = Hashtbl.find globalSymbolTable inFunction in
      (match funcType with
       | FuncSignature (returnTypeSig, _) -> AReturn (returnTypeSig, aExpr, Void)
       | _ ->
           let _ = printError "Invalid use of return statement." in
           AReturn (Void, aExpr, Void))
  | For (varDecl1, expr, varDecl2, statement) ->
      let aVarDecl1 = annotateVarDecl globalSymbolTable localSymbolTable varDecl1 in
      let aVarDecl2 = annotateVarDecl globalSymbolTable localSymbolTable varDecl2 in
      let aExpr = annotateExpression globalSymbolTable localSymbolTable expr in
      let aStatement =
        annotateStatement globalSymbolTable localSymbolTable
          ~isControlFlowAllowed:true inFunction statement
      in
      AFor (aVarDecl1, aExpr, aVarDecl2, aStatement, Void)
  | While (expr, statement) ->
      let aExpr = annotateExpression globalSymbolTable localSymbolTable expr in
      let aStatement =
        annotateStatement globalSymbolTable localSymbolTable
          ~isControlFlowAllowed:true inFunction statement
      in
      AWhile (aExpr, aStatement, Void)
  | If (expr, statement1, statement2) ->
      let newGlobalSymbolTable = mergeSymbolTables globalSymbolTable localSymbolTable in
      (* One local scope per branch: a declaration made in one branch must
         not leak into the other. The tables used to be shared. *)
      let thenSymbolTable = Hashtbl.create 100 in
      let elseSymbolTable = Hashtbl.create 100 in
      let aExpr = annotateExpression newGlobalSymbolTable thenSymbolTable expr in
      let aStatement1 =
        annotateStatement newGlobalSymbolTable thenSymbolTable
          ~isControlFlowAllowed:isCFA inFunction statement1
      in
      let aStatement2 =
        annotateStatement newGlobalSymbolTable elseSymbolTable
          ~isControlFlowAllowed:isCFA inFunction statement2
      in
      AIf (aExpr, aStatement1, aStatement2, Void)
  | Exit -> AExit Void
  | Break ->
      if not isCFA then printError "Invalid use of break.";
      ABreak Void
  | ForEachLoop (id, objName, statement, loopType) ->
      let newGlobalSymbolTable = mergeSymbolTables globalSymbolTable localSymbolTable in
      let newLocalSymbolTable = Hashtbl.create 100 in
      (* The loop variable reuses an existing binding of the same name if one
         is visible; otherwise it is declared in the loop's own (still empty)
         local scope with a fresh type. *)
      let idType =
        if Hashtbl.mem newGlobalSymbolTable id then Hashtbl.find newGlobalSymbolTable id
        else begin
          let idt = generateTypeForAnnotation () in
          Hashtbl.add newLocalSymbolTable id idt;
          idt
        end
      in
      let objType =
        if Hashtbl.mem newLocalSymbolTable objName then
          Hashtbl.find newLocalSymbolTable objName
        else if Hashtbl.mem newGlobalSymbolTable objName then
          Hashtbl.find newGlobalSymbolTable objName
        else
          let _ = printUndefinedVariableError objName in
          Empty
      in
      let aStatement =
        annotateStatement newGlobalSymbolTable newLocalSymbolTable
          ~isControlFlowAllowed:true inFunction statement
      in
      AForEachLoop (id, idType, objName, objType, aStatement, loopType, Void)
  | Continue ->
      if not isCFA then printError "Invalid use of continue.";
      AContinue Void


(* Generate the type-equality constraints [(t1, t2)] implied by an annotated
 * expression and all of its sub-expressions; they are solved by [unify].
 * Leaves (literals, booleans, identifiers, empty expressions) add nothing. *)
let rec collectExpr = function
  | ALiteral _ | ABoolLit _ | AId _ | ANoexpr _ -> []
  (* Unbounded read: the subscript and the result are Int. *)
  | AUnboundedAccessRead (_, _, aexpr, exprType) ->
      let constraints = [ (exprType, Int) ] in
      let exprConstr = collectExpr aexpr in
      [ (typeOfAexpr aexpr, Int) ] @ constraints @ exprConstr
  (* Unbounded write: both sub-expressions and the result are Int. *)
  | AUnboundedAccessWrite (_, _, aexpr1, aexpr2, exprType) ->
      let constraints = [ (exprType, Int) ] in
      let exprConstr1 = collectExpr aexpr1 in
      let exprConstr2 = collectExpr aexpr2 in
      [ (typeOfAexpr aexpr1, Int); (typeOfAexpr aexpr2, Int) ]
      @ constraints @ exprConstr1 @ exprConstr2
  (* Matrix +/-: the result and both operands share one type. *)
  | AMatPlus (_, idType1, _, idType2, exprType)
  | AMatMinus (_, idType1, _, idType2, exprType) ->
      [ (exprType, idType1); (idType1, idType2) ]
  (* Element access: the accessed id is an nDim matrix, every subscript is
   * an Int, and the result is an Int (no matrices of functions for now). *)
  | AMatAccess (_, idType, aExprList, nDim, exprType) ->
      let constraints = [ (idType, Mat nDim); (exprType, Int) ] in
      let exprConstraints =
        List.fold_left
          (fun constraintAcc expr ->
            let exprConstr = collectExpr expr in
            (typeOfAexpr expr, Int) :: exprConstr @ constraintAcc)
          [] aExprList
      in
      constraints @ exprConstraints
  (* Comparison and logical operators relate the operand types and yield
   * Bool; arithmetic operators take and yield Int. *)
  | ABinaryOp (aexpr1, op, aexpr2, exprType) ->
      let t1 = typeOfAexpr aexpr1 in
      let t2 = typeOfAexpr aexpr2 in
      let constraints =
        match op with
        | Equal | Neq | Less | Leq | Greater | Geq | And | Or ->
            [ (t1, t2); (exprType, Bool) ]
        | Add | Sub | Mul | Div | Exp | Mod -> [ (t1, Int); (t2, Int); (exprType, Int) ]
      in
      constraints @ (collectExpr aexpr1) @ (collectExpr aexpr2)
  | AUnop (uop, aexpr, exprType) ->
      let t = typeOfAexpr aexpr in
      let constraints =
        match uop with
        | Neg -> [ (t, Int); (exprType, Int) ]
        | Not -> [ (exprType, Bool) ]
      in
      constraints @ (collectExpr aexpr)
  | ACall (id, idType, aExprList, exprType) ->
      (* Constraints required by the argument expressions themselves. *)
      let collectArgs () =
        List.fold_left
          (fun constraintAcc expr -> collectExpr expr @ constraintAcc)
          [] aExprList
      in
      (match idType with
       | FuncSignature (returnTypeSig, formalTypeList) ->
           let exprConstraints = collectArgs () in
           let size1 = List.length formalTypeList in
           let size2 = List.length aExprList in
           if size1 <> size2 then
             printError
               ("Function: " ^ id ^ " called with: " ^ string_of_int size2
                ^ " arguments. While the function expects: " ^ string_of_int size1
                ^ " arguments.");
           (* Each actual matches its formal; the call yields the return type. *)
           let formalConstraints =
             List.map2 (fun typ1 aExpr -> (typ1, typeOfAexpr aExpr)) formalTypeList aExprList
           in
           let retConstraint = [ (returnTypeSig, exprType) ] in
           retConstraint @ formalConstraints @ exprConstraints
       | Keyword when id = "print" ->
           (* print accepts any argument type, but the arguments must still
              be well-typed and have their type variables resolved. *)
           collectArgs ()
       | _ -> printError "Invalid use of function call.")
          


(* Generate the type constraints implied by an annotated declaration:
 * - matrix literals are 2-D matrices with Int elements;
 * - [x = e] gives [x] the type of [e];
 * - dimension declarations make [x] an nDimensions matrix with Int sizes;
 * - element assignment requires a matrix, Int subscripts and an Int value. *)
let collectVarDecl aVarDecl =
  (* Prepend "[e] is an Int" and [e]'s own constraints onto [acc]. *)
  let intElement acc e = (typeOfAexpr e, Int) :: (collectExpr e @ acc) in
  match aVarDecl with
  | ANodecl _ -> []
  | AMatrix (_, idType, rows, _, _, _) ->
      let elementConstraints = List.fold_left (List.fold_left intElement) [] rows in
      (idType, Mat 2) :: elementConstraints
  | AExprAssign (_, idType, aExpr, _) ->
      collectExpr aExpr @ [ (idType, typeOfAexpr aExpr) ]
  | ADimAssign (_, idType, dims, nDimensions, _) ->
      let dimConstraints = List.fold_left intElement [] dims in
      (idType, Mat nDimensions) :: dimConstraints
  | AMatElementAssign (_, idType, subscripts, aExpr, nDimensions, _) ->
      let head = [ (idType, Mat nDimensions); (typeOfAexpr aExpr, Int) ] in
      let subscriptConstraints = List.fold_left intElement [] subscripts in
      head @ subscriptConstraints @ (collectExpr aExpr)

            

            (* All statements have type Void *)
(* Collect the type constraints found in an annotated statement. They come
 * from the contained expressions and declarations, plus:
 * - for/while/if conditions must be Bool;
 * - a returned value must match the recorded return type;
 * - for-each loops fix the types of the loop variable and the iterated
 *   object according to the loop kind (Row / Ele / Pixel). *)
let rec collectStatement aStatement =
  (* "[cond] is a Bool", followed by [cond]'s own constraints. *)
  let condition cond = (typeOfAexpr cond, Bool) :: collectExpr cond in
  match aStatement with
  | AContinue _ | ABreak _ | AExit _ -> []
  | ABlock (body, _) ->
      List.fold_left (fun acc s -> collectStatement s @ acc) [] body
  | AExpr (aExpr, _) -> collectExpr aExpr
  | AVarDecl (aVarDecl, _) -> collectVarDecl aVarDecl
  | AReturn (retType, aExpr, _) -> (retType, typeOfAexpr aExpr) :: collectExpr aExpr
  | AFor (init, cond, step, body, _) ->
      let initC = collectVarDecl init in
      let stepC = collectVarDecl step in
      let condC = collectExpr cond in
      let bodyC = collectStatement body in
      (typeOfAexpr cond, Bool) :: initC @ stepC @ condC @ bodyC
  | AIf (cond, thenBranch, elseBranch, _) ->
      let condC = condition cond in
      let thenC = collectStatement thenBranch in
      let elseC = collectStatement elseBranch in
      condC @ thenC @ elseC
  | AWhile (cond, body, _) ->
      let condC = condition cond in
      let bodyC = collectStatement body in
      condC @ bodyC
  | AForEachLoop (_, idType, _, objType, body, loopType, _) ->
      let loopC =
        match loopType with
        | Row -> [ (idType, Mat 2); (objType, Mat 3) ]
        | Ele -> [ (idType, Int); (objType, Mat 3) ]
        | Pixel -> [ (idType, Mat 1); (objType, Mat 3) ]
      in
      loopC @ (collectStatement body)


(* Apply the single binding [t1 := t2] to [t]:
 * - an annotation equal to [t1] becomes [t2];
 * - other annotations and concrete types are unchanged;
 * - a [FuncSignature] collapses to plain [Func].
 * Any other type is reported as "Unknown type error." via [printError],
 * which exits. *)
let substitute t1 t2 t =
  match t with
  | FuncSignature _ -> Func
  | Annotation _ when t = t1 -> t2
  | Annotation _ | Mat _ | Void | Int | Bool | Func | Keyword -> t
  | _ -> printError "Unknown type error."

(* Apply a whole substitution list to [typ]. Bindings are applied from the
 * last one to the first one, so the head of the list is applied last. *)
let apply substitutionList typ =
  List.fold_left
    (fun acc (t1, t2) -> substitute t1 t2 acc)
    typ (List.rev substitutionList)


(* Unification of type constraints.
 * [unifyOne s t] returns the substitution (a list of (annotation, type)
 * pairs) that makes [s] and [t] equal:
 * - nothing if they are already equal;
 * - otherwise a binding of whichever side is an annotation.
 * Two distinct non-annotation types are reported via [printError], which
 * exits.
 * [unify constraints] solves the list from the back: it solves the tail,
 * applies the tail's substitution to the head pair, and prepends the
 * head's bindings to the tail's. *)
let rec unifyOne s t =
  if s = t then []
  else
    match s, t with
    | Annotation _, _ -> [ (s, t) ]
    | _, Annotation _ -> [ (t, s) ]
    | _ ->
        printError
          ("Mismatched types:" ^ string_of_builtInType s ^ "," ^ string_of_builtInType t ^ "\n")

and unify = function
  | [] -> []
  | (lhs, rhs) :: rest ->
      let restSubst = unify rest in
      let headSubst = unifyOne (apply restSubst lhs) (apply restSubst rhs) in
      headSubst @ restSubst

(* Collect the constraints of a list of statements, in statement order (the
 * first statement's constraints come first). Each statement's constraints
 * are reverse-prepended onto an accumulator and the result is reversed once,
 * instead of appending with [@] to an ever-growing accumulator (which was
 * quadratic in the total number of constraints). *)
let collectStatementList astatements =
  List.rev
    (List.fold_left
       (fun acc astatement -> List.rev_append (collectStatement astatement) acc)
       [] astatements)


(* Annotate one function definition.
 * The function's type is looked up in [globalSymbolTable]. It must be a
 * [FuncSignature]; otherwise an error naming the function is reported and
 * the process exits.
 * Each formal is bound in [localSymbolTable] to the matching formal type;
 * a repeated formal name is an error. The body is then annotated with the
 * function's name as [inFunction]. *)
let annotateFunctionHelper globalSymbolTable localSymbolTable func =
  let funcType = Hashtbl.find globalSymbolTable func.fname in
  match funcType with
  | FuncSignature (returnTypeSig, formalTypeList) ->
      let aFormals =
        List.map2
          (fun id typ ->
            if Hashtbl.mem localSymbolTable id then
              printError
                ("Two or more formals have same name: " ^ id ^ " in function: " ^ func.fname);
            Hashtbl.add localSymbolTable id typ;
            (id, typ))
          func.formals formalTypeList
      in
      (* Formals are registered above, before the body is annotated. *)
      {
        afname = (func.fname, funcType);
        aformals = aFormals;
        abody =
          List.map
            (fun statement ->
              annotateStatement globalSymbolTable localSymbolTable func.fname statement)
            func.body;
        retType = returnTypeSig;
      }
  | _ ->
      (* The argument is parenthesised so the function name is actually part of
         the message. Previously this parsed as (printError "...") ^ func.fname.
         printError never returns, so no placeholder record is needed. *)
      printError ("Incorrect use of function: " ^ func.fname)



(* Annotate one function definition, giving it a fresh, empty local symbol
 * table. *)
let annotateFunction globalSymbolTable func =
  annotateFunctionHelper globalSymbolTable (Hashtbl.create 100) func



(* Collect the type constraints arising from one annotated function body. *)
let collectFunction func = collectStatementList func.abody

(* Collect the constraints of every function, in order. Uses
 * reverse-prepend + a single final reverse rather than appending to the
 * accumulator with [@], which was quadratic. *)
let collectFunctionList functions =
  List.rev
    (List.fold_left
       (fun acc func -> List.rev_append (collectFunction func) acc)
       [] functions)


(* Substitute the solved types in [unifiedConstraints] into every type slot
 * of an annotated expression, recursively. Literal, boolean and empty
 * expressions are returned unchanged. *)
let rec applyExpression unifiedConstraints aexpr =
  let resolve typ = apply unifiedConstraints typ in
  let recur e = applyExpression unifiedConstraints e in
  match aexpr with
  | ALiteral _ | ABoolLit _ | ANoexpr _ -> aexpr
  | AId (id, idType) -> AId (id, resolve idType)
  | AUnboundedAccessRead (id, idType, inner, exprType) ->
      AUnboundedAccessRead (id, resolve idType, recur inner, resolve exprType)
  | AUnboundedAccessWrite (id, idType, first, second, exprType) ->
      AUnboundedAccessWrite (id, resolve idType, recur first, recur second, resolve exprType)
  | AMatPlus (id1, idType1, id2, idType2, exprType) ->
      AMatPlus (id1, resolve idType1, id2, resolve idType2, resolve exprType)
  | AMatMinus (id1, idType1, id2, idType2, exprType) ->
      AMatMinus (id1, resolve idType1, id2, resolve idType2, resolve exprType)
  | AMatAccess (id, idType, aexprs, nDim, exprType) ->
      let resolvedExprs = List.map recur aexprs in
      AMatAccess (id, resolve idType, resolvedExprs, nDim, resolve exprType)
  | ABinaryOp (lhs, op, rhs, exprType) ->
      ABinaryOp (recur lhs, op, recur rhs, resolve exprType)
  | AUnop (uop, operand, exprType) -> AUnop (uop, recur operand, resolve exprType)
  | ACall (id, idType, args, exprType) ->
      let resolvedArgs = List.map recur args in
      ACall (id, resolve idType, resolvedArgs, resolve exprType)



(* Substitute the solved types in [unifiedConstraints] into an annotated
 * declaration, including all of its sub-expressions. *)
let applyVarDecl unifiedConstraints aVarDecl =
  let resolve typ = apply unifiedConstraints typ in
  let resolveExpr e = applyExpression unifiedConstraints e in
  let resolveExprs es = List.map resolveExpr es in
  match aVarDecl with
  | ANodecl _ -> aVarDecl
  | AMatrix (id, idType, rows, nRows, nCols, declType) ->
      let resolvedRows = List.map resolveExprs rows in
      AMatrix (id, resolve idType, resolvedRows, nRows, nCols, resolve declType)
  | AExprAssign (id, idType, aExpr, declType) ->
      AExprAssign (id, resolve idType, resolveExpr aExpr, resolve declType)
  | ADimAssign (id, idType, dims, nDimensions, declType) ->
      let resolvedDims = resolveExprs dims in
      ADimAssign (id, resolve idType, resolvedDims, nDimensions, resolve declType)
  | AMatElementAssign (id, idType, subscripts, aExpr, nDimensions, declType) ->
      let resolvedSubscripts = resolveExprs subscripts in
      AMatElementAssign
        (id, resolve idType, resolvedSubscripts, resolveExpr aExpr, nDimensions, resolve declType)


(* Apply the unified substitution [unifiedConstraints] throughout one
   annotated statement.  Type slots go through [apply], expressions through
   [applyExpression], declarations through [applyVarDecl], and nested
   statements are resolved recursively.  [AContinue], [ABreak] and [AExit]
   are returned untouched (their type is not passed through [apply]).  In
   [AForEachLoop] the [objName] and [loopType] payloads are copied as-is. *)
let rec applyStatement unifiedConstraints statement =
  let ty t = apply unifiedConstraints t
  and expr e = applyExpression unifiedConstraints e
  and decl d = applyVarDecl unifiedConstraints d
  and stmt s = applyStatement unifiedConstraints s in
  match statement with
  | AContinue _ | ABreak _ | AExit _ -> statement
  | ABlock (body, t) ->
      let resolvedBody = List.map stmt body in
      ABlock (resolvedBody, ty t)
  | AExpr (e, t) -> AExpr (expr e, ty t)
  | AVarDecl (d, t) -> AVarDecl (decl d, ty t)
  | AReturn (rt, e, t) -> AReturn (ty rt, expr e, ty t)
  | AFor (init, cond, step, body, t) ->
      AFor (decl init, expr cond, decl step, stmt body, ty t)
  | AIf (cond, thenBranch, elseBranch, t) ->
      AIf (expr cond, stmt thenBranch, stmt elseBranch, ty t)
  | AWhile (cond, body, t) -> AWhile (expr cond, stmt body, ty t)
  | AForEachLoop (id, idType, objName, objType, body, loopType, t) ->
      AForEachLoop (id, ty idType, objName, ty objType, stmt body, loopType, ty t)


(* Resolve each annotated statement of [aStatementList] with [applyStatement],
   preserving order.  Declared without [rec]: the body never refers to itself,
   so the flag was unnecessary (compiler warning 39, "unused rec flag"). *)
let applyStatementList unifiedConstraints aStatementList =
  List.map (fun astatement -> applyStatement unifiedConstraints astatement) aStatementList



(* Resolve all types in one annotated function using [unifiedConstraints]:
   each formal's type and the return type go through [apply], and the body
   goes through [applyStatementList].  The function name keeps its string,
   but its type component is replaced with [Func].
   NOTE: an earlier revision mapped an unresolved [Annotation] return type to
   [Void]; that mapping is disabled, so [apply]'s result is used as-is. *)
let applyFunction unifiedConstraints func =
  let resolve t = apply unifiedConstraints t in
  let (name, _) = func.afname in
  {
    afname = (name, Func);
    aformals = List.map (fun (formalId, formalType) -> (formalId, resolve formalType)) func.aformals;
    abody = applyStatementList unifiedConstraints func.abody;
    retType = resolve func.retType;
  }

(* Resolve every annotated function in [functions] with [applyFunction],
   keeping their original order. *)
let applyFunctions unifiedConstraints functions =
  functions |> List.map (applyFunction unifiedConstraints)


    (* Check program semantics *)
(* Entry point of the checker: annotate, constrain, unify and resolve a
   program given as (global statements, function definitions).
   Steps, in order (order matters because [generateTypeForAnnotation] and the
   annotators draw fresh names from shared mutable state):
   1. seed a global symbol table with [keywords] and [builtInFunctions];
   2. register each function with fresh annotation types, reporting name
      clashes through [printDuplicateFunctionError];
   3. annotate the global statements in the "main" scope, then merge the
      local table into the global one;
   4. annotate the functions, collect all constraints, [unify] them and
      apply the result to statements and functions.
   Returns the resolved global statements and resolved functions. *)
let check_semantics (gstatements, functions) =
  let globalSymbolTable = Hashtbl.create 100 in
  (* [let ()] rather than [let _]: asserts the side effect returns unit, so a
     partially applied call cannot be silently discarded. *)
  let () =
    List.iter (fun ele -> Hashtbl.add globalSymbolTable ele Keyword) (keywords @ builtInFunctions)
  in
  (* Check for duplicate functions; a name already present (keyword,
     built-in, or earlier function) is reported instead of registered. *)
  let () =
    List.iter
      (fun ele ->
        if Hashtbl.mem globalSymbolTable ele.fname then
          printDuplicateFunctionError (Hashtbl.find globalSymbolTable ele.fname) ele.fname
        else
          let formalTypes = List.map (fun _ -> generateTypeForAnnotation ()) ele.formals in
          Hashtbl.add globalSymbolTable ele.fname
            (FuncSignature (generateTypeForAnnotation (), formalTypes)))
      functions
  in
  let localSymbolTable = Hashtbl.create 100 in
  let agstatements =
    List.map
      (fun statement -> annotateStatement globalSymbolTable localSymbolTable "main" statement)
      gstatements
  in
  (* Overwrite globalSymbolTable with localSymbolTable (via mergeSymbolTables;
     the exact merge policy lives in that helper). *)
  let globalSymbolTable = mergeSymbolTables globalSymbolTable localSymbolTable in
  let gconstraints = collectStatementList agstatements in
  let afunctions = List.map (fun func -> annotateFunction globalSymbolTable func) functions in
  let fconstraints = collectFunctionList afunctions in
  let constraints = gconstraints @ fconstraints in
  let unifiedConstraints = unify constraints in
  let resolvedGStatements = applyStatementList unifiedConstraints agstatements in
  let resolvedFunctions = applyFunctions unifiedConstraints afunctions in
  (resolvedGStatements, resolvedFunctions)
