open Bytecode

(* Registers and private operand stack of one kernel instance. run_program
   keeps one of these per kernel and replaces it after every scheduling
   turn. *)
type state = {
    stack    : value array; (* the kernel's own operand stack *)
    fp       : int;         (* frame pointer *)
    sp       : int;         (* stack pointer: index of the first free slot *)
    pc       : int;         (* index of the next instruction in the text *)
    finished : bool;        (* true once the kernel reported it finished *)
}

(* Number of value slots allocated for every operand stack: the main
   program's and each kernel's. *)
let stack_size = 1 lsl 10

(* Run the bytecode program [prog], starting at instruction 0 on a fresh
   main stack.

   The main program executes on its own stack. Each [Par] instruction
   seeds the next of [prog.num_kernel_instances] kernel states. [Run]
   then executes those kernels in turn, repeatedly, until one of them
   reports that it has finished. Channel 0 reads tokens from stdin,
   channel 1 prints to stdout, and every other channel is an in-memory
   FIFO queue.

   If [debug] is true, the live stack, the registers and the current
   instruction are printed before every step.

   Returns the [(fp, sp, pc, finished)] tuple at which the main program
   stopped (on [Term], on [Run], or on a [Read] that could not proceed).
   Raises [Failure] on [Halt] and on runtime errors, such as invalid
   operand types, out-of-range list indices, reading from stdout or
   writing to stdin. *)
let run_program prog debug =
    let globals = Array.make prog.num_globals Null
    and text = prog.text
    (* Channels 0 (stdin) and 1 (stdout) are handled directly by
       read_channel / write_channel; their queues are never used. *)
    and channels = Array.init prog.num_channels (fun _ -> Queue.create ())
    (* Slots for the kernels started by Par; next_state indexes the next
       free slot. *)
    and kernel_states =
        Array.init prog.num_kernel_instances
            (fun _ -> {
                stack = Array.make stack_size Null;
                fp = 0; sp = 0; pc = 0; finished = false;
            })
    and next_state = ref 0
    and global_stack = Array.make stack_size Null
    (* Becomes true once reading stdin has reported end of file. *)
    and got_eof = ref false
    and stdin_lexbuf = Lexing.from_channel stdin
    in

    (* Read the next value token from stdin.
       Raises End_of_file when the scanner reports EOF. *)
    let read_token_from_stdin () =
        (*print_string "input> "; flush stdout;*)
        match Stdin_scanner.token stdin_lexbuf with
            Stdin_scanner.EOF -> raise End_of_file
          | Stdin_scanner.Value v -> v
    in

    (* Decode a saved frame pointer or return address, which the calling
       convention stores on the stack as an Int. *)
    let ptr_of_value = function
        Int(i) -> i
      | _      -> failwith "Bad stack pointer"
    in

    (* Print the slots of [stack] whose index lies in [fp, sp). *)
    let dump_stack stack fp sp =
        print_endline "STACK";
        Array.iteri (fun i v -> if (i < sp) && (i >= fp) then
            Printf.printf "%3d: %s\n" i (string_of_value v)) stack;
        print_endline "END STACK"
    in

    (* Raise Failure unless 0 <= idx < List.length lst. Shared by the list
       read (Ilst) and write (Alst) paths so both report the same error. *)
    let check_list_index lst idx =
        if idx < 0 || idx >= List.length lst then
            raise (Failure (Printf.sprintf "List index %d is out of range" idx))
    in

    (* Element [idx] of [lst]; raises Failure if [idx] is out of range. *)
    let list_get lst idx =
        check_list_index lst idx;
        List.nth lst idx
    in

    (* Copy of [lst] with element [idx] replaced by [value].
       Raises Failure if [idx] is out of range, negative indices included
       (these used to be silently ignored). *)
    let list_assign lst idx value =
        check_list_index lst idx;
        List.mapi (fun i v -> if i = idx then value else v) lst
    in

    (* Replace the channel on top of the stack with the next value read
       from it, and return true. Return false, leaving the stack unchanged,
       when no value is available: either stdin reached end of file (which
       also sets got_eof) or an in-memory channel is empty. *)
    let read_channel stack sp = match stack.(sp - 1) with
        Channel(0) -> (try (
            let token = read_token_from_stdin () in
            (*print_endline ("GOT TOKEN: \"" ^ string_of_value token ^ "\"");*)
            stack.(sp - 1) <- token; true)
            with End_of_file -> got_eof := true; false)
      | Channel(1) -> raise (Failure "Cannot read from stdout")
      | Channel(i) -> let channel = channels.(i) in
            (if Queue.is_empty channel then
                false
            else
                (stack.(sp - 1) <- Queue.pop channel; true))
      | v -> raise (Failure (Printf.sprintf
            "Cannot read from non-channel \"%s\""
            (string_of_value v)))
    in

    (* Write the value on top of the stack to the channel just below it.
       The caller is responsible for adjusting sp. *)
    let write_channel stack sp = match stack.(sp - 2) with
        Channel(0) -> raise (Failure "Cannot write to stdin")
      | Channel(1) -> print_endline (string_of_value stack.(sp - 1))
      | Channel(i) -> Queue.push stack.(sp - 1) channels.(i)
      | _          -> raise (Failure "Cannot write to non-channel")
    in

    (* Coerce two operands to a common type for arithmetic: Bool widens to
       Int or Float, Int widens to Float, and anything paired with a String
       is converted to a String. Raises Failure for any other combination. *)
    let normalize_arith_types lhs rhs = match (lhs, rhs) with
        (Int(_), Int(_))        -> (lhs, rhs)
      | (Int(l), Float(_))      -> (Float(float_of_int l), rhs)
      | (Int(_), Bool(r))       -> (lhs, Int(if r then 1 else 0))
      | (Float(_), Float(_))    -> (lhs, rhs)
      | (Float(_), Int(r))      -> (lhs, Float(float_of_int r))
      | (Float(_), Bool(r))     -> (lhs, Float(if r then 1.0 else 0.0))
      | (Bool(l), Int(_))       -> (Int(if l then 1 else 0), rhs)
      | (Bool(l), Float(_))     -> (Float(if l then 1.0 else 0.0), rhs)
      | (Bool(_), Bool(_))      -> (lhs, rhs)
      | (String(_), _)          -> (lhs, String(string_of_value rhs))
      | (_, String(_))          -> (String(string_of_value lhs), rhs)
      | _                       -> raise (Failure
            (Printf.sprintf "Unknown conversion: lhs = %s, rhs = %s\n"
                            (string_of_value lhs) (string_of_value rhs)))
    in

    (* Evaluate the binary operator [op] on [lhs] and [rhs].
       Arithmetic and ordering operators first coerce both operands with
       normalize_arith_types, raising Failure if no coercion exists or the
       coerced types do not support [op]. Equality operators fall back to
       comparing the raw operands when no coercion exists. *)
    let handle_binop op lhs rhs =
        let raise_binop_failure op (lhs, rhs) = raise
            (Failure (Printf.sprintf "Invalid binop %s for %s, %s"
                (string_of_bcode op)
                (string_of_value lhs)
                (string_of_value rhs)))
        in
        (* Equality is meaningful for every pair of values, so a missing
           arithmetic coercion (e.g. Null, lists, channels) is not an error:
           the operands are then compared as they are. *)
        let equality_args () =
            try normalize_arith_types lhs rhs with Failure _ -> (lhs, rhs)
        in
        match op with
            Ceq -> let (l, r) = equality_args () in Bool(l = r)
          | Cne -> let (l, r) = equality_args () in Bool(l <> r)
          | _ ->
            let normalized_args = normalize_arith_types lhs rhs in
            (match op with
                Add -> (match normalized_args with
                    (Int(l), Int(r))       -> Int(l + r)
                  | (Float(l), Float(r))   -> Float(l +. r)
                  | (String(l), String(r)) -> String(l ^ r)
                  | _                  -> raise_binop_failure op normalized_args)
              | Sub -> (match normalized_args with
                    (Int(l), Int(r))     -> Int(l - r)
                  | (Float(l), Float(r)) -> Float(l -. r)
                  | _                    -> raise_binop_failure op normalized_args)
              | Mul -> (match normalized_args with
                    (Int(l), Int(r))     -> Int(l * r)
                  | (Float(l), Float(r)) -> Float(l *. r)
                  | _                    -> raise_binop_failure op normalized_args)
              | Div -> (match normalized_args with
                    (Int(l), Int(r))     -> Int(l / r)
                  | (Float(l), Float(r)) -> Float(l /. r)
                  | _                    -> raise_binop_failure op normalized_args)
              | Mod -> (match normalized_args with
                    (Int(l), Int(r))     -> Int(l mod r)
                  (* Float operands are truncated and the result is an Int. *)
                  | (Float(l), Float(r)) ->
                        Int((int_of_float l) mod (int_of_float r))
                  | _                    -> raise_binop_failure op normalized_args)
              | Clt -> Bool(match normalized_args with
                    (Int(l), Int(r))     -> l < r
                  | (Float(l), Float(r)) -> l < r
                  | _                    -> raise_binop_failure op normalized_args)
              | Cle -> Bool(match normalized_args with
                    (Int(l), Int(r))     -> l <= r
                  | (Float(l), Float(r)) -> l <= r
                  | _                    -> raise_binop_failure op normalized_args)
              | Cgt -> Bool(match normalized_args with
                    (Int(l), Int(r))     -> l > r
                  | (Float(l), Float(r)) -> l > r
                  | _                    -> raise_binop_failure op normalized_args)
              | Cge -> Bool(match normalized_args with
                    (Int(l), Int(r))     -> l >= r
                  | (Float(l), Float(r)) -> l >= r
                  | _                    -> raise_binop_failure op normalized_args)
              | _ -> raise (Failure ("Unknown binop: " ^ string_of_bcode op)))
    in

    (* Execute instructions on [stack] from [pc], with frame pointer [fp]
       and stack pointer [sp] (the index of the first free slot), until
       execution has to stop. Returns (fp, sp, pc, finished):
       - a Read that cannot proceed stops at the Read itself, and finished
         is true only if it was reading stdin and stdin hit end of file;
       - Term stops with finished = true;
       - Run first runs all kernels, then stops just past the Run with
         finished = true.
       Kernels are executed by this same function: run_kernels calls it
       with each kernel's own stack and registers. *)
    let rec exec stack fp sp pc =
        if debug then (
            dump_stack stack 0 sp;
            Printf.printf "%d %d %d %s\n" fp sp pc (string_of_bcode text.(pc));
            flush stdout);
        match text.(pc) with
        Halt        -> raise (Failure "Halt")
      | Push(value) -> stack.(sp) <- value; exec stack fp (sp + 1) (pc + 1)
      | Pop         ->                      exec stack fp (sp - 1) (pc + 1)
      (* Binary operators replace the top two values with the result. *)
      | Add | Sub | Mul | Div | Mod
      | Ceq | Cne | Clt | Cle | Cgt | Cge as op ->
            let lhs = stack.(sp - 2) and rhs = stack.(sp - 1) in
            stack.(sp - 2) <- handle_binop op lhs rhs;
            exec stack fp (sp - 1) (pc + 1)
      | And         -> stack.(sp - 2) <-
            Bool((bool_of_value stack.(sp - 2) &&
                 (bool_of_value stack.(sp - 1))));
            exec stack fp (sp - 1) (pc + 1)
      | Or          -> stack.(sp - 2) <-
            Bool((bool_of_value stack.(sp - 2) ||
                 (bool_of_value stack.(sp - 1))));
            exec stack fp (sp - 1) (pc + 1)
      (* Stores leave the stored value on the stack. *)
      | Strg(idx)   -> globals.(idx) <- stack.(sp - 1);
                       exec stack fp sp (pc + 1)
      | Ldg(idx)    -> stack.(sp) <- globals.(idx);
                       exec stack fp (sp + 1) (pc + 1)
      | Strl(idx)   -> stack.(fp + idx) <- stack.(sp - 1);
                       exec stack fp sp (pc + 1)
      | Ldl(idx)    -> stack.(sp) <- stack.(fp + idx);
                       exec stack fp (sp + 1) (pc + 1)
      (* Negative call targets are builtins: -1 prints the top value and
         replaces it with Null. *)
      | Call(-1)    -> print_endline (string_of_value stack.(sp - 1));
                       stack.(sp - 1) <- Null;
                       exec stack fp sp (pc + 1)
      (* -4 appends: the list is on top, the element below it; the new
         list replaces both. *)
      | Call(-4)    ->
            (match stack.(sp - 1) with
                VList(l) -> stack.(sp - 2) <- VList(l @ [stack.(sp - 2)])
              | _        -> raise (Failure "Can only append() to list"));
            exec stack fp (sp - 1) (pc + 1)
      (* -5 replaces the list on top with its length. *)
      | Call(-5)    ->
            (match stack.(sp - 1) with
                VList(l) -> stack.(sp - 1) <- Int(List.length l)
              | _        -> raise (Failure "Can only call length() on a list"));
            exec stack fp sp (pc + 1)
      (* User calls push the return address and jump to [idx]. *)
      | Call(idx)   -> stack.(sp) <- Int(pc + 1); exec stack fp (sp + 1) idx
      (* Restore the caller's fp and pc from the frame, and replace the
         [nargs] arguments with the return value (the current top). *)
      | Ret(nargs)  -> let ret_fp = ptr_of_value (stack.(fp))
                       and ret_pc = ptr_of_value (stack.(fp - 1)) in
                       stack.(fp - 1 - nargs) <- stack.(sp - 1);
                       exec stack ret_fp (fp - nargs) ret_pc
      (* Branch offsets are relative to the branch instruction. *)
      | Bz(idx)     -> exec stack fp (sp - 1)
                       (if not (bool_of_value stack.(sp - 1)) then
                            pc + idx
                        else
                            pc + 1)
      | Bnz(idx)    -> exec stack fp (sp - 1)
                       (if bool_of_value stack.(sp - 1) then
                           pc + idx
                       else
                           pc + 1)
      | Ba(idx)     -> exec stack fp sp (pc + idx)
      (* Pack the top n values (bottom-most first) into one list. *)
      | Mlst(n)     ->
            stack.(sp - n) <- VList(Array.to_list (Array.sub stack (sp - n) n));
            exec stack fp (sp - n + 1) (pc + 1)
      (* list, index -> element *)
      | Ilst        ->
            (match stack.(sp - 2) with
                VList(l) -> (match stack.(sp - 1) with
                    Int(i) -> stack.(sp - 2) <- list_get l i
                  | _      -> raise (Failure "Can only index list with int"))
              | _        -> raise (Failure "Cannot index non-list"));
            exec stack fp (sp - 1) (pc + 1)
      (* list, index, value -> updated list *)
      | Alst        ->
            (match stack.(sp - 3) with
                VList(l) -> (match stack.(sp - 2) with
                    Int(i) ->
                        stack.(sp - 3) <- VList(list_assign l i stack.(sp - 1))
                  | _ -> raise (Failure "Can only index list with int"))
              | _ -> raise (Failure "Cannot index non-list"));
            exec stack fp (sp - 2) (pc + 1)
      (* A blocked Read yields control; it is retried on the next turn. *)
      | Read        ->
            if (read_channel stack sp) then
                exec stack fp sp (pc + 1)
            else
                (fp, sp, pc,
                    (if stack.(sp - 1) = Channel(0) then !got_eof else false))
      | Write       -> write_channel stack sp; exec stack fp (sp - 1) (pc + 1)
      (* Save the caller's fp, make sp the new frame pointer and reserve
         [num] local slots above it. *)
      | Ent(num)    -> stack.(sp) <- Int(fp);
                       exec stack sp (sp + num + 1) (pc + 1)
      (* Start a kernel: move the top n values (its arguments) to the
         bottom of the next free kernel stack, entering at [idx]. *)
      | Par(idx, n) ->
            let state = kernel_states.(!next_state) in
            Array.blit stack (sp - n) state.stack 0 n;
            kernel_states.(!next_state) <- { state with sp = n; pc = idx };
            next_state := (!next_state) + 1;
            exec stack fp (sp - n) (pc + 1)
      | Kent(n)     -> exec stack fp (sp + n) (pc + 1)
      | Term        -> (fp, sp, pc, true)
      | Run         -> (*if debug then Array.iteri
            (fun n s -> Printf.printf "state %d: %d %d %d\n" n s.fp s.sp s.pc)
            kernel_states; *)
            run_kernels ();
            (fp, sp, (pc + 1), true)
    (* Scheduler: give every kernel one turn in order. Each turn runs until
       the kernel blocks or stops, and its registers are then saved. Repeat
       until some kernel has finished; stop at once if there are no
       kernels. *)
    and run_kernels () =
        if (Array.length kernel_states) = 0 ||
            Array.exists (fun s -> s.finished) kernel_states
        then (if debug then print_endline "program finished")
        else (Array.iteri (fun i s ->
            let (fp, sp, pc, finished) = exec s.stack s.fp s.sp s.pc in
                kernel_states.(i) <-
                    { s with finished = finished; fp = fp; sp = sp; pc = pc})
            kernel_states;
            run_kernels ())
    in

    exec global_stack 0 0 0

