open Ast
open Bytecode

module StringMap = Map.Make(String)

(* Kind of scope the compiler is currently generating code for. The context
 * decides whether a newly assigned name becomes a global or a local variable,
 * whether a return statement is permitted (Function only), and whether the
 * kernel IO builtins get their special handling (Kernel only). *)
type context =
    Global
  | Function
  | Kernel

(* Direction of a channel argument as seen from inside a kernel body: read
 * is only accepted on In channels and write only on Out channels. *)
type channel_type =
    In
  | Out

(* One argument supplied when a kernel is instantiated: either the
 * program-wide index of a declared channel, or an ordinary expression that
 * is compiled and evaluated like any other. *)
type kernel_instance_arg =
    ChannelArg of int
  | ExprArg of expr

(* Description of a single kernel instance.
 * NOTE(review): this record is not referenced anywhere in the visible part of
 * this file; presumably index identifies the kernel being instantiated and
 * instance_args holds its arguments - confirm against other users. *)
type kernel_instance = {
    index           : int;
    instance_args   : kernel_instance_arg list;
}

(* Symbol table threaded through every compile_* function. Updates are made
 * by record copy, with the exception of max_locals, which is a mutable cell
 * shared by every copy that does not explicitly replace it. *)
type symbol_table = {
    (* Map each function name to (unique index, number of arguments) *)
    function_index   : (int * int) StringMap.t;
    (* Map each kernel decl name to (unique index, list of arguments) *)
    kernel_index     : (int * kernel_arg list) StringMap.t;
    (* Map each global variable name to its unique index *)
    global_index     : int StringMap.t;
    (* Map each local variable name (kernel or function) to its index *)
    local_index      : int StringMap.t;
    (* When compiling a kernel, map each channel name to its index in the
     * stack and its direction (in or out) *)
    channel_index    : (int * channel_type) StringMap.t;
    (* Helper to keep track of the maximum number of local variables declared
     * in a function or kernel decl for 'Ent' and 'Kent' opcodes *)
    max_locals       : int ref;
    (* Offset local variable indices by this amount *)
    local_offset     : int;
    (* Enum specifying current compilation context (global, kernel,
     * function) *)
    context          : context;
    (* Unique index for all defined channels in the program *)
    channels         : int StringMap.t;
    (* Helper map for verifying channel connections - (bool input * bool
     * output) where value is true if that end of the channel is connected *)
    channel_usage    : (bool * bool) StringMap.t;
    (* Number of kernel instances in the program *)
    kernel_instances : int;
}

(* Label every element of a list with an integer. The first element is paired
 * with 'n' and each following label grows by 'stride' (which may be
 * negative). The order of the elements is preserved and the empty list maps
 * to the empty list. *)
let enum stride n items =
    (* Accumulate the labelled pairs in reverse, then flip them once. *)
    let rec label acc current = function
        [] -> List.rev acc
      | item :: rest -> label ((current, item) :: acc) (current + stride) rest
    in
    label [] n items

(* Extend 'map' with a list of (value, key) pairs - note that the value comes
 * first in each pair. Pairs are inserted front to back, so when a key is
 * repeated the binding nearest the end of the list is the one that remains. *)
let string_map_pairs map pairs =
    let rec insert acc = function
        [] -> acc
      | (value, key) :: rest -> insert (StringMap.add key value acc) rest
    in
    insert map pairs

(* Bind 'item' to the next free index of 'map', i.e. the number of bindings
 * the map currently holds. Raises Failure when 'item' is already bound. *)
let string_map_append map item =
    match StringMap.mem item map with
        true -> raise (Failure ("Item " ^ item ^ " already in map"))
      | false -> StringMap.add item (StringMap.cardinal map) map

(* Compile the program `prog` in Ast form to bytecode. `prog` is a list of
 * top-level parts: function decls, kernel decls, channel decls and global
 * statements. The result is a record holding the text segment (an array of
 * bytecodes) together with the number of globals, channels and kernel
 * instances. Every compilation error is reported by raising Failure with a
 * descriptive message. *)
let compile prog =
    (* Builtin functions, visible in every context. Each table entry is
     * ((index, arity), name). Builtins carry negative indices, which keeps
     * them out of the address patching that is applied to non-negative
     * 'Call' indices at the end of compilation. *)
    let builtin_funcs =
        let table = [
            ((-1, 1), "print");
            ((-4, 2), "append");
            ((-5, 1), "length");
        ] in
        string_map_pairs StringMap.empty table
    in
    (* IO builtins that only become visible while a kernel body is being
     * compiled. Entries are ((index, arity), name). *)
    let kernel_funcs =
        let table = [
            ((-2, 1), "read");
            ((-3, 2), "write");
        ] in
        string_map_pairs StringMap.empty table
    in
    (* Channels that exist in every program, with their fixed indices. User
     * declared channels are numbered after these. *)
    let builtin_channels =
        let table = [
            (0, "stdin");
            (1, "stdout");
        ] in
        string_map_pairs StringMap.empty table
    in
    (* Connection state of the builtin channels as (input, output) flags. One
     * end of each builtin channel starts out connected, leaving exactly one
     * end available to a kernel instance. *)
    let builtin_channel_usage =
        let table = [
            ((true, false), "stdin");
            ((false, true), "stdout");
        ] in
        string_map_pairs StringMap.empty table
    in

    (* Compile an AST expression. Returns a pair, (bytecode list,
     * symbol_table), since an expression (assignment) could change the
     * symbol table. Raises Failure for undefined functions, unknown
     * variables, arity mismatches and misuse of kernel channels. *)
    let rec compile_expr env = function
        (* Literals are easy, we just push a value on the stack *)
        Literal(l) -> ([Push(value_of_literal l)], env)
        (* Binops compile the left hand side, then the right hand side, and
         * finally output the appropriate opcode *)
      | Binop(lhs, op, rhs) ->
            (* TODO: do we need the symbol_table here in case lhs/rhs is an
             * assignment? *)
            (* NOTE: both operands are compiled against the incoming env and
             * the symbol tables they return are dropped, so a variable that
             * is first assigned inside an operand is not recorded. *)
            let (lhs_code, _) = compile_expr env lhs
            and (rhs_code, _) = compile_expr env rhs
            in
            (lhs_code @ rhs_code @
             (* Ast operator on the left, bytecode opcode on the right *)
             [match op with
                Add -> Add
              | Subtract -> Sub
              | Multiply -> Mul
              | Divide -> Div
              | Modulo -> Mod
              | Equal -> Ceq
              | Neq -> Cne
              | Less -> Clt
              | Leq -> Cle
              | Greater -> Cgt
              | Geq -> Cge
              | And -> And
              | Or -> Or], env)
      | Assignable(a) -> compile_assignable env a
        (* Assignments are tricky. First, like the Ids, we need to determine if
         * the variable being assigned to is a local, a global, or new. In the
         * case of the first two, just store a value at that index. Otherwise,
         * it's a new declaration. If we are in a kernel or function context,
         * it's a local variable and should be added to the local map and the
         * symbol table's max_locals might need to be incremented. In a global
         * context, we need to allocate a new global variable. *)
      | Assign(lhs, rhs) -> compile_assignment env lhs rhs
        (* For function calls, a couple of things could happen. If we are
         * actually calling a function, we check that the function is in the
         * function_index of the symbol_table, check that the number of args
         * matches, compile the code for the args in declaration order, and
         * output the 'Call' bytecode to jump to the beginning of the function
         * body. Since we are still compiling the program at this point, we
         * don't actually know the absolute address of the function, so the
         * 'Call' bytecode gets the index of the function in the symbol table
         * to be patched up later.
         *
         * Kernel functions (read, write) get handled a bit differently in the
         * 'compile_kernel_io_func' function since they have some slight
         * differences. *)
      | Call(f, args) ->
            (* Look up the function in the symbol table *)
            let (func_index, nargs) =
                (try (StringMap.find f env.function_index)
                with Not_found ->
                    raise (Failure ("Undefined function '" ^ f ^ "'")))
            in
            (* Verify the number of args matches. *)
            if nargs = List.length args then
                (* Kernel read/write calls get handled differently. The
                 * argument list is reversed before it is handed over;
                 * presumably the parser delivers call arguments in reverse
                 * source order - TODO confirm. *)
                if env.context = Kernel && (StringMap.mem f kernel_funcs) then
                    (compile_kernel_io_func env f (List.rev args), env)
                else
                    (* Compile the arguments in order. *)
                    (* TODO: preserve env? Allow assignments in calls? *)
                    ((List.concat
                        (List.map fst
                            (List.map (compile_expr env) args))) @
                    [Call(func_index)], env)
            else
                raise (Failure (Printf.sprintf
                "Arity mismatch when calling '%s': got %d args, expected %d"
                f (List.length args) nargs))
      | EmptyExpr -> ([], env) (* Do nothing *)
        (* List literal: compile each element in order, then emit 'Mlst' with
         * the number of elements. Symbol table changes made by the elements
         * are dropped. *)
      | DList(exprs) ->
            let exprs_code = List.map (fun e -> fst (compile_expr env e)) exprs
            in
            ((List.concat exprs_code) @ [Mlst(List.length exprs)], env)
    (* Compile the read of an assignable (a plain name or a list element).
     * Returns (bytecode list, symbol_table). *)
    and compile_assignable env = function
        (* Names get looked up in the local variables, then failing that, the
         * globals. Once found, we load the value at the corresponding index
         * on to the top of the stack ('Ldl' for locals, 'Ldg' for globals) *)
        Id(id) ->
            let load = (try [Ldl (StringMap.find id env.local_index)]
            with Not_found -> try [Ldg (StringMap.find id env.global_index)]
            with Not_found -> raise (Failure ("unknown variable '" ^ id ^ "'")))
            in
            (load, env)
        (* Indexing: code for the list, code for the index, then 'Ilst' *)
      | ListIndex(list_expr, idx_expr) ->
            let (list_expr_code, env) = compile_expr env list_expr in
            let (idx_code, env) = compile_expr env idx_expr in
            (list_expr_code @ idx_code @ [Ilst], env)
    (* Compile 'lhs = rhs'. The right hand side is compiled first and its
     * symbol table is used for the rest of the assignment. Returns
     * (bytecode list, symbol_table). *)
    and compile_assignment env lhs rhs =
        let (rhs_code, env) = compile_expr env rhs in
        match lhs with
        Id(id) ->
            let (store_code, env) =
                (* (is a known local, is a known global) *)
                let location = (StringMap.mem id env.local_index,
                                StringMap.mem id env.global_index) in
                match location with
                    (true, _) -> (* existing local - get index *)
                        ([Strl(StringMap.find id env.local_index)], env)
                  | (false, true) -> (* global - get global index *)
                        ([Strg(StringMap.find id env.global_index)], env)
                  | (false, false) -> (* new var, update env *)
                        let (map, offset) = match env.context with
                            Global -> (env.global_index, 1) (* 0-index *)
                          | _      -> (env.local_index, env.local_offset)
                        in
                        (* The index of the new variable gets the next highest
                         * integer. Since variable allocation is FIFO, we don't
                         * need to worry about holes in the allocation. *)
                        let new_idx = ((StringMap.cardinal map) + 1 - offset)
                        in
                        (* NOTE: this counter is bumped in the global context
                         * as well, not only for locals. *)
                        env.max_locals := max !(env.max_locals) new_idx;
                        (* Save the new index back to the appropriate slot in
                         * the symbol table. *)
                        let new_map = StringMap.add id new_idx map in
                        match env.context with
                            Global -> ([Strg(new_idx)],
                                        { env with global_index = new_map })
                          | _      -> ([Strl(new_idx)],
                                        { env with local_index = new_map })
            in
            (* Put the RHS on the top of the stack, and store that value. *)
            (rhs_code @ store_code, env)
        (* Assignment to a list element: list, index, value, then 'Alst'.
         * The symbol table returned for the list expression is dropped.
         * NOTE(review): the two 'Halt' opcodes emitted here look like
         * debugging leftovers that would stop execution before the store -
         * confirm against the VM before relying on list element
         * assignment. *)
      | ListIndex(lst, idx) ->
              let (lst_code, _) = compile_expr env lst
              and (idx_code, env) = compile_expr env idx
              in
              ([Halt] @ lst_code @ [Halt] @ idx_code @ rhs_code @ [Alst], env)
    (* Kernel IO functions are a lot like regular functions except that we
     * check to make sure that they are called on the correct type of channel.
     * E.g. it is an error to call "read" on a channel declared as "out".
     * Returns only the bytecode list; the caller has already verified the
     * number of arguments. *)
    and compile_kernel_io_func env name args =
        (* This function verifies that the first argument to the function (the
         * channel) is just an identifier matching a channel and not part of an
         * expression. It returns (int * channel_type) *)
        let find_channel env args =
            match (List.hd args) with
                Assignable(Id(id)) ->
                    (try (StringMap.find id env.channel_index)
                    with Not_found ->
                        raise (Failure ("Unknown channel '" ^ id ^ "'")))
              | _ -> raise (Failure ("Cannot use expression as channel name"))
        in
        match name with
            (* read(channel): load the channel slot, then 'Read' *)
            "read" -> let (idx, channel_type) = find_channel env args
                in
                if channel_type = In then
                    [Ldl(idx); Read]
                else
                    raise (Failure ("read() requires an input channel"))
            (* write(channel, value): load the channel slot, compile the
             * value (second argument), then 'Write' *)
          | "write" -> let (idx, channel_type) = find_channel env args
                in
                if channel_type = Out then
                    [Ldl(idx)] @
                    fst (compile_expr env (List.nth args 1)) @
                    [Write]
                else
                    raise (Failure ("write() requires an output channel"))
          | _ -> raise (Failure ("Unknown error compiling kernel IO"))
    in

    (* Compile a statement recursively. Returns a pair of (bytecode list,
     * symbol_table) since the symbol table might change. Only expression
     * statements and returns hand back an updated symbol table; blocks,
     * conditionals and loops return the table they were given, so names first
     * assigned inside them go out of scope afterwards. Branch operands are
     * relative instruction counts. Raises Failure for a return outside of a
     * function and for any error raised while compiling an expression. *)
    let rec compile_stmt env = function
        (* Expressions use the output of the 'compile_expr' function above. We
         * also pop the value off the top of the stack since every expression
         * pushes one unused value. This is to support chaining of calls
         * without much hassle, e.g. 'f(x);' needs to have its return value
         * popped since it's never used. *)
        ExprStmt(expr) ->
            let (expr_code, env) = compile_expr env expr in
              (expr_code @ [Pop], env)
        (* Returns are only valid in a function context. The 'Ret' value is set
         * to 0 to be set to the number of function arguments later (so we know
         * where to place the return value) *)
      | Return(expr) ->
            if env.context <> Function
            then raise (Failure "Can only return from a function")
            else let (expr_code, env) = compile_expr env expr in
                (expr_code @ [Ret(0)], env)
        (* A block is just a sequence of statements so compile them in order *)
      | Block(stmts) ->
            (fst (List.fold_left stmt_sequence_helper ([], env) stmts), env)
      | If(test, body, else_stmt) ->
            let (test_code, _) = compile_expr env test
            and (body_code, _) = compile_stmt env body
            and (else_body_code, _) = compile_stmt env else_stmt
            in
            (test_code @
            (* Branch to else if cond is false *)
            [Bz(List.length body_code + 2)] @
            body_code @
            (* Branch out of if statement *)
            [Ba(1 + List.length else_body_code)] @
            (* else *)
            else_body_code,
            env)
        (* TODO: rewrite this as while loop in Block()? *)
      | For(init, test, inc, body) ->
            (* Names introduced by the initialiser are visible to the test,
             * the increment and the body, but not after the loop. *)
            let (init_code, for_env) = compile_stmt env init in
            let (test_code, _) = compile_expr for_env test
            and (inc_code, _) = compile_stmt for_env inc
            and (body_code, _) = compile_stmt for_env body in
            (init_code @
            test_code @
            (* if the test returns false, skip out of the loop *)
            [Bz(2 + List.length body_code + List.length inc_code)] @
            body_code @
            inc_code @
            (* Branch back to the test *)
            [Ba(-
                (List.length inc_code + List.length body_code +
                List.length test_code + 1))], env)
      | While(test, body) ->
            let (test_code, _) = compile_expr env test
            and (body_code, _) = compile_stmt env body
            in
            (test_code @
            (* Skip the body if the test is false *)
            [Bz(2 + List.length body_code)] @
            body_code @
            (* Branch back to the test *)
            [Ba(- (List.length body_code + List.length test_code + 1))], env)
        (* A foreach loop is compiled as a for loop over a hidden index:
         * foreach(id : lst) { body } is equivalent to
         * for(i = 0; i < length(lst); i = i + 1) { id = lst[i]; body }
         * Note that 'lst' is therefore evaluated twice per iteration.
         *
         * The hidden index needs a name that no enclosing foreach is using,
         * otherwise an inner loop would reset the counter of the outer one.
         * The name is derived from the number of variables currently visible:
         * an enclosing foreach has already added its own index to the symbol
         * table, so the count (and with it the name) is strictly larger in
         * every nested loop. The slot allocated for the index does not depend
         * on its name, so non-nested loops compile exactly as before. *)
      | ForEach(id, lst, body) ->
            let visible =
                StringMap.cardinal env.local_index +
                StringMap.cardinal env.global_index
            in
            let idx = Id(Printf.sprintf "__i%d" visible) in
            compile_stmt env (For(
                ExprStmt(Assign(idx, Literal(Int(0)))),
                Binop(Assignable(idx), Less, Call("length", [lst])),
                ExprStmt(Assign(idx,
                                Binop(Assignable(idx), Add, Literal(Int(1))))),
                Block([
                    ExprStmt(Assign(Id(id), Assignable(ListIndex(lst, Assignable(idx)))));
                    body])))
    (* Small helper function for fold_left when compiling a sequence of
     * statements. Passes the resulting symbol table of each call to
     * compile_stmt to the next and accumulates the bytecodes of each. *)
    and stmt_sequence_helper res stmt =
        let (code, env) = compile_stmt (snd res) stmt in
        (fst res @ code, env)
    in

    (* Compile the statement list that forms the body of a function or kernel
     * decl. Each statement sees the symbol table left behind by the previous
     * one. Returns (bcode list * symbol_table). *)
    let compile_body env body =
        let step (code_so_far, current_env) stmt =
            let (stmt_code, next_env) = compile_stmt current_env stmt in
            (code_so_far @ stmt_code, next_env)
        in
        List.fold_left step ([], env) body
    in

    (* Compile a function declaration 'f' and return its bytecode: an 'Ent'
     * reserving room for the locals, the compiled body, and a trailing
     * 'return null'. 'env' supplies the functions and globals visible to the
     * body; it is not modified. *)
    let compile_func env f =
        let num_args = List.length f.args in
        (* Construct the local symbol table. Arguments are numbered starting at
         * -2 and decreasing by one in declaration order. We offset the first
         * local variable by the number of arguments. *)
        let func_env = { env with
            local_index =
                string_map_pairs StringMap.empty (enum (-1) (-2) f.args);
            local_offset = num_args;
            context = Function;
            (* Every function counts its own locals. Sharing the caller's
             * counter would size 'Ent' by the largest index seen in any
             * earlier function or global assignment. *)
            max_locals = ref 0 }
        in
        (* Compile the body and get the updated symbol table *)
        let (code, body_env) = compile_body func_env f.body in
        (* Patch the 'Ret' opcodes with the number of arguments. *)
        let fixed_code = List.map
            (function
                Ret(0) -> Ret(num_args)
              | instr -> instr)
            code
        in
        (* Functions start with an 'Ent' to make room for the locals, and
         * always end in a 'return null;', even if it's unreachable. *)
        [Ent !(body_env.max_locals)] @ fixed_code @ [Push(Null); Ret(num_args)]
    in

    (* Compile a kernel declaration 'k' and return its bytecode: a 'Kent'
     * reserving room for the locals, the compiled body, and a final 'Term'. *)
    let compile_kernel env k =
        (* Each kernel instance has its own stack, so unlike functions the
         * arguments simply occupy slots 0, 1, 2, ... in declaration order. *)
        let numbered = enum 1 0 k.kargs in
        (* Sort the arguments into plain locals and channels, keeping their
         * slots, and count the channels along the way. The list is walked
         * from the right so that, should a name be repeated, the earliest
         * declaration is the binding that remains. *)
        let classify (slot, karg) (locals, chans, nchans) = match karg with
            Input(id) ->
                (locals, StringMap.add id (slot, In) chans, nchans + 1)
          | Output(id) ->
                (locals, StringMap.add id (slot, Out) chans, nchans + 1)
          | BasicArg(id) ->
                (StringMap.add id slot locals, chans, nchans)
        in
        let (arg_locals, arg_channels, num_channels) =
            List.fold_right classify numbered
                (StringMap.empty, StringMap.empty, 0)
        in
        (* Symbol table used for the kernel body *)
        let kernel_env = { env with
            (* The kernel IO functions join the functions already visible; a
             * visible function of the same name takes precedence. *)
            function_index =
                StringMap.fold
                    (fun name entry acc -> StringMap.add name entry acc)
                    env.function_index kernel_funcs;
            (* The non-channel arguments are the initial local variables *)
            local_index = arg_locals;
            (* Channels, with their stack slots and directions *)
            channel_index = arg_channels;
            (* Channels are not part of the locals map, so new locals are
             * shifted by the number of channels to keep them from landing
             * on a channel slot *)
            local_offset = -num_channels;
            context = Kernel;
            max_locals = ref 0; }
        in
        let (body_code, body_env) = compile_body kernel_env k.kbody in
        (* A 'Kent' rather than an 'Ent' allocates the locals because there is
         * no current fp, etc. to save. *)
        Kent(!(body_env.max_locals)) :: (body_code @ [Term])
    in

    (* Compile a single statement in the global context. This includes kernel
     * instance creation: a call statement whose name is not a known function
     * is treated as the instantiation of a kernel. Returns (bytecode list,
     * symbol_table). Raises Failure on any compilation error. *)
    let compile_global_stmt env stmt =
        (* This function does all the work for instantiating a kernel. *)
        let instantiate_kernel env id args =
            (* Verify a single kernel arg - channel arguments get names of
             * channels, non-channel arguments are normal expressions. *)
            let construct_arg arg arg_def = match (arg, arg_def) with
                (* Channel arg - make sure the channel exists *)
                (Assignable(Id(name)), Input(_))
              | (Assignable(Id(name)), Output(_)) ->
                    let channel = (try (StringMap.find name env.channels)
                        with Not_found -> raise (Failure
                            ("Unknown channel \"" ^ name ^ "\"")))
                    in
                    ChannelArg(channel)
                (* Channel argument, but got expr *)
              | (_, Input(_ as def)) | (_, Output(_ as def)) -> raise
                    (Failure (Printf.sprintf
                        "Cannot use expression \"%s\" as channel argument %s"
                        (string_of_expr arg) def))
                (* Non-channel argument, just pass the expression. *)
              | (_ as expr, BasicArg(_)) -> ExprArg(expr)
            in

            (* This function makes sure that we haven't already connected this
             * end of the channel. If we haven't, store that it is now
             * connected. A kernel input takes the second flag of the usage
             * pair and a kernel output the first. *)
            let verify_channel_arg env arg arg_def = match (arg, arg_def) with
                (Assignable(Id(name)), Input(_))
              | (Assignable(Id(name)), Output(_)) ->
                    (let usage = StringMap.find name env.channel_usage in
                    match (usage, arg_def) with
                        ((_, false), Input(_)) ->
                            { env with channel_usage =
                                StringMap.add name ((fst usage), true)
                                env.channel_usage }
                      | ((false, _), Output(_)) ->
                            { env with channel_usage =
                                StringMap.add name (true, (snd usage))
                                env.channel_usage }
                      | _ -> raise (Failure
                            ("Channel " ^ name ^ " is already bound")))
              | _ -> env
            in

            (* Make sure that the kernel we are trying to instantiate has
             * been declared. If so, get its index and arguments from the
             * symbol table. *)
            let (index, arg_defs) = (try (StringMap.find id env.kernel_index)
                with Not_found -> raise (Failure
                    ("Unknown kernel \"" ^ id ^ "\"")))
            in

            (* First argument pass: classify every argument. List.map2 raises
             * Invalid_argument when the two lists differ in length. *)
            let kernel_args = (try (List.map2 construct_arg args arg_defs)
                with Invalid_argument(_) -> raise
                    (Failure "Kernel arity mismatch")) (* TODO: error msg *)
            in

            (* Verify the channel connections *)
            let env = List.fold_left2 verify_channel_arg env args arg_defs
            in

            (* Generate code for the arguments *)
            let gen_arg_code = function
                ChannelArg(chan) -> [Push(Channel(chan))]
              | ExprArg(expr)    -> fst (compile_expr env expr)
            in

            let arg_code = List.concat (List.map gen_arg_code kernel_args)
            in

            (* 'Par' to spawn the kernel instance, preceded by arguments *)
            (arg_code @ [Par(index, List.length kernel_args)],
            { env with kernel_instances = env.kernel_instances + 1 })
        in

        (* Copy the environment, but set the context to global so all variables
         * are global variables, etc. *)
        let global_env = { env with context = Global; }
        in

        (* Check if this is a kernel instantiation or a normal function call *)
        match stmt with
            (* A call statement naming a known function is compiled as a
             * normal call, and any error in it is reported as is. Every other
             * name is treated as a kernel to instantiate. The decision is
             * made by lookup rather than by catching a failed compilation, so
             * that a mistake in a real function call (wrong arity, unknown
             * variable, ...) is not replaced by an unknown-kernel error.
             * NOTE: the consequence is that if we have a function and a kernel
             * with the same name, you can never instantiate that kernel!
             * TODO: above should never happen *)
            ExprStmt(Call(id, args)) ->
                if StringMap.mem id env.function_index then
                    compile_stmt global_env stmt
                else
                    instantiate_kernel env id (List.rev args)
            (* It's some other statement - just compile it as usual. *)
          | _                        -> compile_stmt global_env stmt
    in

    (* Assign a unique integer id to each kernel and function decl, starting
     * at 0 and increasing by one in program order. The ids are used later to
     * look up entry points, so they must line up one to one with the decl
     * bodies produced when the program is compiled; every decl is therefore
     * numbered, whatever its name. Each entry is
     * (id, (name, number of args, kernel args)) where the last component is
     * None for a function and Some args for a kernel. *)
    let decl_indices =
        let rev_decls = List.fold_left
            (fun acc part -> match part with
                FuncDecl(f) ->
                    (f.name, List.length f.args, None) :: acc
              | KernelDecl(k) ->
                    (k.kname, List.length k.kargs, Some(k.kargs)) :: acc
              | _ -> acc)
            [] prog
        in
        enum 1 0 (List.rev rev_decls)
    in

    (* Function table: the function decls from 'decl_indices' with the
     * builtins folded in on top (a builtin wins over a decl of the same
     * name). Decls are told apart by their tag and not by the number of
     * kernel arguments, so a kernel declared without arguments is not
     * mistaken for a function. *)
    let function_indices =
        let declared = List.fold_left
            (fun table (idx, (name, nargs, kargs)) -> match kargs with
                None -> StringMap.add name (idx, nargs) table
              | Some(_) -> table)
            StringMap.empty decl_indices
        in
        StringMap.fold StringMap.add builtin_funcs declared
    in

    (* Kernel table: the kernel decls from 'decl_indices', each mapped to its
     * id and its argument list. *)
    let kernel_indices =
        List.fold_left
            (fun table (idx, (name, _, kargs)) -> match kargs with
                Some(args) -> StringMap.add name (idx, args) table
              | None -> table)
            StringMap.empty decl_indices
    in

    (* Symbol table in effect before any part of the program is compiled *)
    let env = {
        (* Compilation starts at global scope with no variables declared *)
        context          = Global;
        global_index     = StringMap.empty;
        local_index      = StringMap.empty;
        local_offset     = 0;
        max_locals       = ref 0;
        (* Every function and kernel decl is known up front *)
        function_index   = function_indices;
        kernel_index     = kernel_indices;
        (* Only the builtin channels exist so far *)
        channel_index    = StringMap.empty;
        channels         = builtin_channels;
        channel_usage    = builtin_channel_usage;
        kernel_instances = 0;
    }
    in

    (*StringMap.iter (fun name (idx, nargs) ->*)
    (*    Printf.printf "%s -> (idx: %d, nargs: %d)\n" name idx nargs) env.function_index;*)

    (* Fold step that compiles one top-level part of the program. The
     * accumulator is ((decl bodies, global bodies), symbol table); both body
     * lists are built in reverse. Kernel and function decl code is kept
     * separately from global-scope code. Only global statements and channel
     * declarations change the symbol table. *)
    let compile_part ((decls, globals), env) part =
        match part with
            FuncDecl(f) ->
                let body = compile_func env f in
                ((body :: decls, globals), env)
          | KernelDecl(k) ->
                let body = compile_kernel env k in
                ((body :: decls, globals), env)
          | Stmt(s) ->
                let (body, new_env) = compile_global_stmt env s in
                ((decls, body :: globals), new_env)
          | Channels(c) ->
                (* Number the new channels (fails on a duplicate name) and
                 * mark both of their ends as unconnected. *)
                let channels =
                    List.fold_left string_map_append env.channels c
                in
                let channel_usage =
                    List.fold_left
                        (fun usage name ->
                            StringMap.add name (false, false) usage)
                        env.channel_usage c
                in
                ((decls, globals),
                 { env with
                     channels = channels;
                     channel_usage = channel_usage })
    in

    (* Run the compilation over every part of the program. 'bodies' holds the
     * decl and global code in reverse program order; 'env' is the final
     * symbol table. *)
    let (bodies, env) =
        List.fold_left
            (fun acc part -> compile_part acc part)
            (([], []), env)
            prog
    in

    (* Put both sequences back into program order *)
    let (decl_bodies, global_stmts) =
        let (rev_decls, rev_globals) = bodies in
        (List.rev rev_decls, List.rev rev_globals)
    in

    (* Address of the first instruction of each decl, indexed by decl id. The
     * first decl starts at address 1 because address 0 holds the initial
     * branch; every further decl starts where the previous one ends. *)
    let decl_offsets =
        let rec starts addr = function
            [] -> []
          | body :: rest -> addr :: starts (addr + List.length body) rest
        in
        Array.of_list (starts 1 decl_bodies)
    in

    (*Array.iter (fun o -> Printf.printf "%d\n" o) decl_offsets;*)

    (* Total number of instructions in all function and kernel bodies, which
     * tells us where to branch when the program starts. *)
    let decl_bodies_length =
        List.fold_left
            (fun total body -> total + List.length body)
            0 decl_bodies
    in

    (* Build the program text segment: all decl bodies followed by the global
     * code, with every 'Call' and 'Par' that carries a decl id (a
     * non-negative operand) rewritten to the actual address of that decl.
     * Negative operands denote builtins and are left alone. *)
    let text =
        let patch instr = match instr with
            Call(idx) ->
                if idx >= 0 then Call(decl_offsets.(idx)) else instr
          | Par(idx, nargs) ->
                if idx >= 0 then Par(decl_offsets.(idx), nargs) else instr
          | _ -> instr
        in
        List.map patch (List.concat decl_bodies @ List.concat global_stmts)
    in

    (* Highest global index stored to anywhere in the program (0 when there is
     * no global store at all) *)
    let num_globals =
        let highest best instr = match instr with
            Strg(i) -> max best i
          | _ -> best
        in
        List.fold_left highest 0 text
    in

    (* Output the final program. The text starts with a branch over the decl
     * bodies to the beginning of the global code, and the last bytecode after
     * the global statements is the special 'Run' that starts the parallel
     * processing. *)
    let program_text = Ba(decl_bodies_length + 1) :: (text @ [Run]) in
    {
        text = Array.of_list program_text;
        num_globals = num_globals + 1; (* 0-indexed! *)
        num_channels = StringMap.cardinal env.channels;
        num_kernel_instances = env.kernel_instances;
    }
