(****************************************************************************
 *
 * File: jvm.ml
 *
 * Purpose: convert the bytecode into jvm instructions that can be compiled
 * into a java class using jasmin.
 *
 * Also, the created class file is added to the runtime class files and a
 * single jar file is produced that can be run.
 *
 *)

(* we use strings for constants so that any transformations are performed by
 * the jasmin assembler, this is mainly a concern for float constants
 *)

(* A single line of jasmin assembly: a jvm instruction, an assembler
 * directive, a label or a comment.  jcode_to_string defines the exact text
 * produced for each constructor.
 *)
type jcode = 
  (* operand stack manipulation *)
  | Dup 
  | Pop
  | Swap
  (* constants; the operand is kept as text and printed verbatim *)
  | Ldc_s of string   
  | Ldc_i of string 
  | Ldc_f of string
  | Bipush of string
  
  (* integer locals and arithmetic; the int is printed as the operand *)
  | Iload of int
  | Istore of int
  | Isub
  | Iadd
  | Idiv
  | Imul

  (* float locals, arithmetic and comparison *)
  | Fload of int
  | Fstore of int
  | Fsub
  | Fadd
  | Fdiv
  | Fmul
  | Fcmp 

  (* reference locals *)
  | Astore of int
  | Aload of int

  (* static fields: (field name, type descriptor); the class name is
   * supplied when the instruction is printed *)
  | GetStatic of string * string
  | PutStatic of string * string

  (* method calls; the string is the complete method signature *)
  | Invokespecial of string
  | Invokevirtual of string
  | Invokestatic of string

  (* arrays *)
  | Anewarray of string
  | Aastore 

  (* branches on a single value; the int is a label number, printed as
   * Label<n> *)
  | Ifeq of int
  | Ifne of int
  | Ifgt of int
  | Ifge of int
  | Iflt of int
  | Ifle of int

  (* branches comparing two ints, and the null test; int is a label number *)
  | If_icmpeq of int
  | If_icmpne of int
  | If_icmplt of int
  | If_icmple of int
  | If_icmpgt of int
  | If_icmpge of int
  | Ifnonnull of int

  (* method framing: Bmethod is (.method line, stack limit, locals limit);
   * the argument of Emethod is not used when printing *)
  | Bmethod of string * int * int
  | Emethod of string
  | Returni 
  | Returnf 
  | Returna 
  | Return 

  (* comments, label definitions (Label<n>:) and unconditional jumps *)
  | Comment of string
  | Label of int
  | Goto of int

  (* conversions; Nop is printed as an empty line *)
  | I2c
  | I2b
  | F2i
  | Nop

(* Map a source language type to its jvm type descriptor.  Bool shares the
 * int descriptor.
 *
 * NOTE(review): link packages RtlInFile.class and RtlOutFile.class, while the
 * descriptors below name Rtl/FileIn and Rtl/FileOut -- confirm which matches
 * the runtime library.
 *)
let to_jvm_type typ =
  match typ with
  | Ast.Void        -> "V"
  | Ast.Int         -> "I"
  | Ast.Bool        -> "I"
  | Ast.Char        -> "C"
  | Ast.Float       -> "F"
  | Ast.String      -> "Ljava/lang/String;"
  | Ast.StringArray -> "[Ljava/lang/String;"
  | Ast.Map         -> "Ljava/util/Map;"
  | Ast.FileIn      -> "LRtl/FileIn;"
  | Ast.FileOut     -> "LRtl/FileOut;"
  
(* Descriptor of one formal parameter.  A Void parameter contributes nothing,
 * so a parameter list of just Void yields the empty ().
 *)
let to_param_str typ =
  match typ with
  | Ast.Void -> ""
  | other -> to_jvm_type other

(* Build the Bmethod that opens a public static method.
 *
 * name:   method name
 * params: parameter types, in declaration order
 * rtn:    return type
 * stack:  value for the stack limit
 * locals: value for the locals limit
 *)
let to_bmethod name params rtn stack locals =
  let descriptor = params |> List.map to_param_str |> String.concat "" in
  let header =
    ".method public static " ^ name ^ "(" ^ descriptor ^ ")" ^ (to_jvm_type rtn)
  in
  Bmethod (header, stack, locals)

(****************************************************************************
 * Build the jasmin method signature used at a call site: the name followed
 * by the parameter descriptors in parentheses and the return descriptor.
 *)
let make_signature name params rtn =
  let descriptor = params |> List.map to_param_str |> String.concat "" in
  name ^ "(" ^ descriptor ^ ")" ^ (to_jvm_type rtn)

(* Choose the return instruction matching the type of the returned value:
 * int-like types use Returni, floats Returnf, references Returna and Void
 * the plain Return.
 *)
let to_return_type typ =
  match typ with
  | Ast.Void -> Return
  | Ast.Float -> Returnf
  | Ast.Int | Ast.Char | Ast.Bool -> Returni
  | Ast.String | Ast.StringArray | Ast.Map | Ast.FileIn | Ast.FileOut -> Returna


(****************************************************************************
 * Translate one IR bytecode instruction into the jvm instructions that
 * implement it.
 *
 * ir_code: a single Bytecode instruction
 * returns: a jcode list; a lone Nop stands for instructions that emit
 *          nothing and for combinations reported via Error.internal_error
 *
 * Call and CallV pass their string through unchanged, so it must already be
 * a complete jasmin signature (to_jvm_code rewrites the calls beforehand).
 *)
let to_jcode ir_code = 
  let module BC = Bytecode in

  (* report an internal compiler error and emit a placeholder instruction *)
  let reject msg =
    begin
      Error.internal_error msg;
      [Nop]
    end
  in

  (* error for an operator applied to a type it does not support *)
  let bad_operand action typ op =
    reject ("cannot " ^ action ^ " objects of type '" ^ (Ast.string_of_type typ) ^ "' with '" ^ op ^ "'")
  in

  (* error for an operator handed to the wrong translation table *)
  let wrong_table () =
    reject "attempt to match comparison operators with mathemcatical operators"
  in

  (* pick the int or float variant of an arithmetic operator *)
  let make_math_op op icode fcode = function 
    | Ast.Int -> icode
    | Ast.Float -> fcode
    | typ -> bad_operand "perform arithmetic on" typ op
  in

  let to_op_code typ = function
    | Ast.Add  -> make_math_op "+" [Iadd] [Fadd] typ
    | Ast.Sub  -> make_math_op "-" [Isub] [Fsub] typ
    | Ast.Mult -> make_math_op "*" [Imul] [Fmul] typ
    | Ast.Div  -> make_math_op "/" [Idiv] [Fdiv] typ

    (* not supported yet *)
    | Ast.LAnd -> make_math_op "&&" [Nop] [Nop] typ
    | Ast.LOr  -> make_math_op "||" [Nop] [Nop] typ

    (* sanity check *)
    | Ast.Equal | Ast.Neq | Ast.Less | Ast.Leq | Ast.Greater | Ast.Geq | Ast.NotNull -> 
        wrong_table ()
  in

  (* pick the int, float or string variant of a comparison; an empty scode
   * means the operator is not defined for strings *)
  let make_cmp_op op icode fcode scode typ =
    match typ, scode with
    | (Ast.Int | Ast.Char | Ast.Bool), _ -> icode
    | Ast.Float, _ -> fcode
    | Ast.String, _ :: _ -> scode
    | _ -> bad_operand "compare" typ op
  in

  (* lbl = label number when comparision is false
   *
   * NOTE(review): the int and float variants jump to lbl on the negated
   * condition, but the string Equal variant jumps when Rtl/equals returns
   * non-zero and NotNull jumps when the value is non-null.  Confirm against
   * Rtl.equals and the IR generator that this is intended.
   *)
  let to_cmp_code typ lbl = function
    | Ast.Equal -> 
        make_cmp_op "==" [If_icmpne(lbl)] [Fcmp; Ifne(lbl)] 
        [Invokestatic("Rtl/equals(Ljava/lang/String;Ljava/lang/String;)Z"); Ifne(lbl)] typ
    | Ast.Neq -> 
        make_cmp_op "!=" [If_icmpeq(lbl)] [Fcmp; Ifeq(lbl)] 
        [Invokestatic("Rtl/equals(Ljava/lang/String;Ljava/lang/String;)Z"); Ifeq(lbl)] typ
    | Ast.Less -> 
        make_cmp_op "<" [If_icmpge(lbl)] [Fcmp; Ifge(lbl)] [] typ
    | Ast.Leq ->  
        make_cmp_op "<=" [If_icmpgt(lbl)] [Fcmp; Ifgt(lbl)] [] typ
    | Ast.Greater ->
        (* the operator text only appears in the error message *)
        make_cmp_op ">" [If_icmple(lbl)] [Fcmp; Ifle(lbl)] [] typ
    | Ast.Geq -> 
        make_cmp_op ">=" [If_icmplt(lbl)] [Fcmp; Iflt(lbl)] [] typ
    | Ast.NotNull -> 
        [Ifnonnull(lbl)] 
    | Ast.Add | Ast.Sub | Ast.Mult | Ast.Div | Ast.LAnd | Ast.LOr -> 
        wrong_table ()
  in

  (* NOTE(review): Bool loads and stores emit Nop although Bool is an int
   * everywhere else in this file -- confirm that Bool locals cannot occur. *)
  let to_load i = function
    | Ast.Int | Ast.Char -> [Iload(i)]
    | Ast.Float -> [Fload(i)]
    | Ast.String | Ast.StringArray | Ast.FileIn | Ast.FileOut | Ast.Map -> [Aload(i)]
    | Ast.Void | Ast.Bool -> [Nop]
  in

  let to_store i = function
    | Ast.Int | Ast.Char -> [Istore(i)]
    | Ast.Float -> [Fstore(i)]
    | Ast.String | Ast.StringArray | Ast.FileIn | Ast.FileOut | Ast.Map -> [Astore(i)]
    | Ast.Void | Ast.Bool -> [Nop]
  in

  (* The descriptor comes from to_jvm_type so that it always agrees with the
   * descriptor written in the .field declaration of the same global. *)
  let to_get_field name = function
    | Ast.Void | Ast.Bool -> [Nop]
    | typ -> [GetStatic(name, to_jvm_type typ)]
  in
  
  let to_put_field name = function
    | Ast.Void | Ast.Bool -> [Nop]
    | typ -> [PutStatic(name, to_jvm_type typ)]
  in

  match ir_code with
    | BC.Dup -> [Dup]
    | BC.Pop -> [Pop]
    | BC.Swap -> [Swap]
    | BC.Pushb("true") -> [Bipush("1")]
    | BC.Pushb(_) -> [Bipush("0")]
    (* a char constant is pushed as the code of its first character *)
    | BC.Pushc(s) -> [Bipush(string_of_int (int_of_char (String.get s 0)))]
    | BC.Pushf(s) -> [Ldc_f(s)]
    | BC.Pushs(s) -> [Ldc_s(s)]
    | BC.Pushi(s) -> [Ldc_i(s)]
    | BC.Binop(op, typ) -> to_op_code typ op
    (* to_jvm_code reverses every instruction list it receives, so the
     * multi-instruction comparison sequences are handed over reversed *)
    | BC.Cmp(op, typ, lbl) -> List.rev(to_cmp_code typ lbl op)
    | BC.Load(i,typ) -> to_load i typ
    | BC.Store(i,typ) -> to_store i typ
    | BC.GetGlobal(n,typ) -> to_get_field n typ
    | BC.PutGlobal(n,typ) -> to_put_field n typ
    | BC.Call(s,_)  -> [Invokestatic(s)]
    | BC.CallV(s,_)  -> [Invokevirtual(s)]
    | BC.Bfunc(n,p,r,s,l) -> [to_bmethod n p r s l] 
    | BC.Efunc(s) -> [Emethod(s)]
    | BC.Return(typ) -> [to_return_type typ]
    | BC.Beq(i) -> [Ifeq(i)]
    | BC.Bne(i) -> [Ifne(i)]
    | BC.Jump(i) ->    [Goto(i)] 
    | BC.Label(i) -> [Label(i)]
    | BC.Halt   ->     [Nop]
    | BC.Nop   ->     [Nop]
    | BC.Comment(s) -> [Comment(s)]

(****************************************************************************
 * Convert a whole IR program into one flat list of jvm instructions.
 *
 * (ir, globals): the IR instruction list paired with the global variables
 * returns:       (jcode list, globals); globals is passed through untouched
 *
 * The IR instructions keep their order, but the instruction list produced
 * for each one is emitted back to front; to_jcode compensates for its
 * multi-instruction sequences.
 *)
let to_jvm_code (ir, globals) = 
  let module BC = Bytecode in
  (* full jasmin signature of the function registered under this name *)
  let resolve fname =
    let fn = Symbol.find_function fname in
    make_signature fn.Symbol.fid fn.Symbol.args fn.Symbol.rtn
  in
  (* calls flagged true still carry a bare function name: replace it with
   * the signature and clear the flag *)
  let fix_call = function
    | BC.Call(s,true) -> BC.Call(resolve s, false)
    | BC.CallV(s,true) -> BC.CallV(resolve s, false)
    | instr -> instr
  in
  let chunks = List.map to_jcode (List.map fix_call ir) in
  (* prepend each chunk so every append costs only the chunk length, then
   * reverse once at the end *)
  let rcode = List.fold_left (fun acc chunk -> chunk @ acc) [] chunks in
  (List.rev rcode, globals)

(*
 * An empty operand has to be written as a pair of double quotes ("");
 * every other string is returned unchanged.
 *)
let fix_empty_str s =
  if String.length s = 0 then "\"\"" else s
  

(*
 * Render one jvm instruction as a line of jasmin assembly (without any
 * leading indentation).
 *
 * cname: name of the generated class, used to qualify static field accesses
 *)
let jcode_to_string cname = function
  | Dup -> "dup"
  | Pop -> "pop"
  | Swap -> "swap"
  | Ldc_s(s) -> "ldc " ^ fix_empty_str s   
  | Ldc_i(s) -> "ldc " ^ fix_empty_str s
  | Ldc_f(s) -> "ldc " ^ fix_empty_str s
  | Bipush(s) -> "bipush " ^ fix_empty_str s 
  | Iload(i) -> "iload " ^ string_of_int i
  | Istore(i) -> "istore " ^ string_of_int i
  | Isub -> "isub" 
  | Iadd -> "iadd" 
  | Idiv -> "idiv" 
  | Imul -> "imul" 
  | Fload(i) -> "fload " ^ string_of_int i
  | Fstore(i) -> "fstore " ^ string_of_int i
  | Fsub -> "fsub"
  | Fadd -> "fadd"
  | Fdiv -> "fdiv"
  | Fmul -> "fmul"
  (* the jvm has no plain fcmp instruction, only fcmpl and fcmpg, which
   * differ solely in the value pushed when an operand is NaN (-1 for
   * fcmpl) *)
  | Fcmp -> "fcmpl"
  | Astore(i) -> "astore " ^ string_of_int i
  | Aload(i) -> "aload " ^ string_of_int i
  | GetStatic(n,t) -> "getstatic " ^ cname ^ "/" ^ n ^ " " ^ t
  | PutStatic(n,t) -> "putstatic " ^ cname ^ "/" ^ n ^ " " ^ t
  | Invokespecial(s) -> "invokespecial " ^ s
  | Invokevirtual(s) -> "invokevirtual " ^ s
  | Invokestatic(s) -> "invokestatic " ^ s
  (* anewarray needs the element class as its operand *)
  | Anewarray(s) -> "anewarray " ^ s
  | Aastore -> "aastore"
  | Ifeq(i) -> "ifeq Label" ^ string_of_int i
  | Ifne(i) -> "ifne Label" ^ string_of_int i
  | Ifgt(i) -> "ifgt Label" ^ string_of_int i
  | Ifge(i) -> "ifge Label" ^ string_of_int i
  | Iflt(i) -> "iflt Label" ^ string_of_int i
  | Ifle(i) -> "ifle Label" ^ string_of_int i
  | If_icmpeq(i) -> "if_icmpeq Label" ^ string_of_int i
  | If_icmpne(i) -> "if_icmpne Label" ^ string_of_int i
  | If_icmplt(i) -> "if_icmplt Label" ^ string_of_int i
  | If_icmple(i) -> "if_icmple Label" ^ string_of_int i
  | If_icmpgt(i) -> "if_icmpgt Label" ^ string_of_int i
  | If_icmpge(i) -> "if_icmpge Label" ^ string_of_int i
  | Ifnonnull(i) -> "ifnonnull Label" ^ string_of_int i
  (* the .method line is followed by its two .limit directives *)
  | Bmethod(s,stack,locals) -> s ^ (Printf.sprintf "\n\t.limit stack %d\n\t.limit locals %d" stack locals)
  | Emethod(_) -> ".end method\n"
  | Returni -> "ireturn"
  | Returnf -> "freturn"
  | Returna -> "areturn"
  | Return -> "return"
  | Comment(s) -> "; " ^ s
  | Label(i) -> "Label" ^ string_of_int i ^ ":"
  | Goto(i) -> "goto Label" ^ string_of_int i
  | I2c -> "i2c"
  | I2b -> "i2b"
  | F2i -> "f2i"
  | Nop -> ""



(****************************************************************************
 * Write the jasmin assembly for a program to a file.
 *
 * filename:        path of the source file; the output is written next to
 *                  it with the extension replaced by .j, and the base name
 *                  without extension becomes the class name
 * (code, globals): the jvm instructions and the global variables, which are
 *                  declared as private static fields
 * returns:         true on success; false after reporting a Failure or an
 *                  I/O error (Sys_error)
 *)
let assemble filename (code, globals) =
  let name = Filename.chop_extension (Filename.basename filename) in
  let header =
    Printf.sprintf ".source %s.j\n.class public %s\n.super java/lang/Object\n\n" name name
  in
  (* labels and method delimiters start at column 0, everything else is
   * indented with a tab *)
  let render instr =
    let prefix =
      match instr with
        | Label(_) | Bmethod(_,_,_) | Emethod(_) -> ""
        | _ -> "\t"
    in
    prefix ^ (jcode_to_string name instr)
  in
  let body = String.concat "\n" (List.map render code) in
  (* default constructor that only calls the Object constructor *)
  let init_method =
      ".method public <init>()V\n" ^
      "\taload_0\n" ^
      "\tinvokenonvirtual java/lang/Object/<init>()V\n" ^
      "\treturn\n" ^
      ".end method\n\n" in
  let generate_field (fname, var) = 
    Printf.sprintf ".field private static %s %s" fname (to_jvm_type var.Symbol.vtype)
  in
  let fields = String.concat "\n" (List.map generate_field globals) in 

  try 
    let oname = (Filename.chop_extension filename) ^ ".j" in
    (* the file is opened without truncation, so drop any previous output *)
    if Sys.file_exists oname then Sys.remove oname;
    let oc = open_out_gen [Open_creat; Open_text; Open_wronly] 0o640 oname in
    begin
      try
        Printf.fprintf oc "%s%s\n\n%s\n%s\n" header fields init_method body;
        close_out oc
      with e ->
        (* do not leak the channel when writing fails *)
        close_out_noerr oc;
        raise e
    end;
    true 
  with Failure(s) | Sys_error(s) -> 
    (* the file operations above signal errors with Sys_error *)
    Error.report s;
    false

(****************************************************************************
 * Combine the runtime files into a single jar with the program
 *
 * filename: path of the source file; <name>.class is expected in the current
 *           directory and <name>.jar ends up in the directory of filename
 * flag:     result of the previous stage; nothing is done when it is false
 * returns:  true when the jar was created and moved, false otherwise
 *
 * The runtime classes are looked up in the directory named by the RTL
 * environment variable, or in the current directory when it is not set.
 *
 * NOTE(review): the runtime classes are handed to jar with their directory
 * prefix; confirm that they are stored at the jar root when RTL names
 * another directory.
 *)
let link filename flag =
  if not flag then
    false 
  else begin
    let path = Filename.dirname filename in
    let name = Filename.chop_extension (Filename.basename filename) in
    let class_file = name ^ ".class" in
    let jar_file = name ^ ".jar" in
    let rtl_path = try Sys.getenv "RTL" with Not_found -> "./" in
    let rt_classes = ["Rtl.class"; "RtlException.class"; "RtlInFile.class"; "RtlOutFile.class"] in 
    (* every path is quoted because the command is run through the shell *)
    let rt_files =
      List.map (fun cls -> Filename.quote (Filename.concat rtl_path cls)) rt_classes
    in
    let cmd =
      Printf.sprintf "jar cfe %s %s %s %s"
        (Filename.quote jar_file) (Filename.quote name) (Filename.quote class_file)
        (String.concat " " rt_files)
    in 
    let rcode = Sys.command cmd in

    (* clean up class file; it may be missing when an earlier step went
     * wrong, and that must not hide the error report below *)
    let () = if Sys.file_exists class_file then Sys.remove class_file in 

    if rcode <> 0 then
      begin
        Error.report ("linking failed, error " ^ string_of_int rcode);
        Error.report ("command [" ^ cmd ^ "]");
        false
      end  
    else 
      begin
        (* move jar to same directory as source *)
        Sys.rename jar_file (Filename.concat path jar_file);
        true
      end
  end


(****************************************************************************
 * Call jasmin on the assembly file and, when that succeeds, link the result.
 *
 * filename: path of the source file; the .j file with the same base name is
 *           assembled
 * flag:     result of the previous stage; nothing is done when it is false
 * returns:  the result of link when jasmin succeeds, false otherwise
 *)
let compile filename flag =
  if not flag then
    false
  else 
    begin
      let iname = (Filename.chop_extension filename) ^ ".j" in
      (* the path is quoted because the command is run through the shell *)
      let cmd = Printf.sprintf "jasmin %s" (Filename.quote iname) in
      let rcode = Sys.command cmd in

      if rcode = 0 then
        link filename true 
      else begin 
        Error.report ("assembling failed, error " ^ string_of_int rcode);
        false
      end
    end

