(* Code generation: translate the Ast of a poly program into Java source text. *)

open Ast

(* Package clause written at the top of every generated Java file. *)
let package_del : string = "package poly;"

(* Import section of the generated Java file; currently nothing is imported. *)
let import_decl : string = ""

(* Java entry point emitted into every generated class.  It only delegates to
   [polyMain], the name under which the source program's "main" is emitted. *)
let main_fdecl =
  String.concat "\n"
    [ "public static void main(String[] args) throws Exception";
      "{";
      "polyMain();";
      "}";
      "" ]

(* Best-effort static type of [exp].  In this file it is used to decide
   whether an operator must be emitted as a PolynomialFunction method call.
   Variables are searched in [local_vars] before [global_vars], so a local or
   formal parameter hides a global of the same name.  Expressions that are
   not recognised (and unknown variables) yield [UnknownType]. *)
let type_of_expr global_vars local_vars exp =
  match exp with
    Lvalue(lv) ->
      (match lv with
        Id(s) ->
          (* locals first: the innermost declaration of [s] wins *)
          (try (List.find (fun a -> a.vname = s) (local_vars @ global_vars)).vtype
           with Not_found -> UnknownType)
      (* indexing a polynomial yields one of its (double) coefficients *)
      | PolyElmt(_, _) -> Float
      )
  | PolyLiteral(_) -> Poly
  (* always translated to "new PolynomialFunction(...)" *)
  | PolyInit(_) -> Poly
  | _ -> UnknownType


(* Java spelling of a source-language type, used for fields, locals, formals
   and return types.  [UnknownType] maps to a placeholder that is not a
   valid Java type. *)
let jstring_of_datatype = function
  | Int -> "int"
  | Float -> "double"
  | Boolean -> "boolean"
  | Poly -> "PolynomialFunction"
  | String -> "String"
  | UnknownType -> "unknowntype"


(* Translate an expression to Java source text.
   [global_vars] and [local_vars] are only consulted through [type_of_expr]:
   when the left operand of an operator is a polynomial, the operator is
   emitted as a PolynomialFunction method call instead of Java infix syntax. *)
let rec jstring_of_expr global_vars local_vars = function
    IntLiteral(l) -> string_of_int l
  | FloatLiteral(l) -> string_of_float l
  | BooleanLiteral(l) -> string_of_bool l
  (* emitted verbatim; presumably the lexer keeps the quotes -- confirm *)
  | StringLiteral(l) -> l
  | PolyLiteral(l) ->
      (* coefficient expressions become a Java double[] initializer *)
      "new PolynomialFunction(new double[]{ " ^
      String.concat ", " (List.map (jstring_of_expr global_vars local_vars) l) ^
      " })"
  | PolyInit(l) ->
      "new PolynomialFunction(" ^ jstring_of_expr global_vars local_vars l ^ ")"
  | Lvalue(lv) -> jstring_of_lvalue global_vars local_vars lv
  | Binop(e1, o, e2) ->
      (match (type_of_expr global_vars local_vars e1) with
        Poly ->
          (* polynomial operators become method calls: e1.<method>(e2) *)
          (jstring_of_expr global_vars local_vars e1 ^ "." ^
           (match o with
             Add -> "add"
           | Sub -> "subtract"
           | Mult -> "multiply"
           | Div -> "divide"
           | Lshift -> "shiftLeft"
           | Rshift -> "shiftRight"
           | Equal -> "equals"
           | Neq -> "notEqual"
           (* NOTE(review): unsupported operators are emitted as placeholder
              text, so the generated Java will not compile *)
           | _ -> "operator " ^ string_of_binop o ^ " not implemented for polynomials"
           ) ^ "(" ^ (jstring_of_expr global_vars local_vars e2) ^ ")"
          )
      | _ ->
          (* scalar operators: fully parenthesised Java infix expression *)
          ( "(" ^ jstring_of_expr global_vars local_vars e1 ^ " " ^
            (match o with
              Add -> "+" | Sub -> "-" | Mult -> "*" | Div -> "/"
            | Lshift -> "<<" | Rshift -> ">>"
            | Equal -> "==" | Neq -> "!="
            | Less -> "<" | Leq -> "<=" | Greater -> ">" | Geq -> ">="
            ) ^ " " ^ jstring_of_expr global_vars local_vars e2 ^ ")"
          )
      )
  | Negate(e) ->
      (match (type_of_expr global_vars local_vars e) with
        Poly ->
          "(" ^ jstring_of_expr global_vars local_vars e ^ ".negate()" ^ ")"
      | _ ->
          (* Parenthesise, and separate the minus from its operand with a
             space: a nested negation (or an operand text starting with '-')
             would otherwise produce "--x", which Java lexes as a
             pre-decrement. *)
          "(- " ^ jstring_of_expr global_vars local_vars e ^ ")"
      )
  | Assign(v, e) ->
      jstring_of_lvalue global_vars local_vars v ^ " = " ^ jstring_of_expr global_vars local_vars e
  | Call(f, el) ->
      (match f with
        (* builtins; NOTE(review): multiple arguments are concatenated with no
           separator, which only yields valid Java for a single argument *)
        "print" -> "System.out.print(" ^ (String.concat "" (List.map (jstring_of_expr global_vars local_vars) el)) ^ ")"
      | "println" -> "System.out.println(" ^ (String.concat "" (List.map (jstring_of_expr global_vars local_vars) el)) ^ ")"
      | "deg" -> (String.concat "" (List.map (jstring_of_expr global_vars local_vars) el)) ^ ".degree()"
      (* user-defined function: plain static call *)
      | _ -> f ^ "(" ^ String.concat ", " (List.map (jstring_of_expr global_vars local_vars) el) ^ ")"
      )
  | Noexpr -> ""

(* Translate an assignable location: a plain variable, or a coefficient of a
   polynomial accessed through its [coefficients] field. *)
and jstring_of_lvalue global_vars local_vars = function
    Id(s) -> s
  | PolyElmt(s, e) -> s ^ ".coefficients[" ^ jstring_of_expr global_vars local_vars e  ^ "]"
  
(* Translate a statement to Java source text; the emitted text always ends in
   a newline.  An [If] whose else-branch is the empty block is printed
   without an [else] clause. *)
let rec jstring_of_stmt global_vars local_vars stmt =
  let expr = jstring_of_expr global_vars local_vars in
  let sub = jstring_of_stmt global_vars local_vars in
  match stmt with
  | Block(body) -> "{\n" ^ String.concat "" (List.map sub body) ^ "}\n"
  | Expr(e) -> expr e ^ ";\n"
  | Return(e) -> "return " ^ expr e ^ ";\n"
  | If(cond, then_s, Block([])) -> "if (" ^ expr cond ^ ")\n" ^ sub then_s
  | If(cond, then_s, else_s) ->
      "if (" ^ expr cond ^ ")\n" ^ sub then_s ^ "else\n" ^ sub else_s
  | For(init, test, step, body) ->
      "for (" ^ expr init ^ " ; " ^ expr test ^ " ; " ^ expr step ^ ") " ^ sub body
  | While(cond, body) -> "while (" ^ expr cond ^ ") " ^ sub body

(* Declaration "<java type> <name>;" followed by a newline; no initialiser
   is emitted. *)
let jstring_of_vdecl vdecl =
  String.concat "" [ jstring_of_datatype vdecl.vtype; " "; vdecl.vname; ";\n" ]

(* A global variable becomes a [public static] field of the generated class. *)
let jstring_of_gvdecl gvdecl =
  let field_decl = jstring_of_vdecl gvdecl in
  "public static " ^ field_decl

(* One entry of a Java parameter list: "<java type> <name>". *)
let jstring_of_formal formal =
  let java_type = jstring_of_datatype formal.vtype in
  String.concat " " [ java_type; formal.vname ]

(* Emit one Java static method for [fdecl], declared [throws Exception].
   The source-level "main" is emitted as the parameterless [polyMain]; every
   other function keeps its name and formals.  Local variables are declared
   at the top of the body, before the translated statements.  Formals and
   locals together form the scope passed to statement translation. *)
let jstring_of_fdecl global_vars fdecl =
  let formal_vars = List.map (fun f -> { vname = f.vname; vtype = f.vtype }) fdecl.formals in
  let declared_vars = List.map (fun l -> { vname = l.vname; vtype = l.vtype }) fdecl.locals in
  let local_vars = formal_vars @ declared_vars in
  let return_type = jstring_of_datatype fdecl.rtype in
  let signature =
    if fdecl.fname = "main" then "static " ^ return_type ^ " polyMain()"
    else
      "static " ^ return_type ^ " " ^ fdecl.fname ^ "("
      ^ String.concat ", " (List.map jstring_of_formal fdecl.formals) ^ ")"
  in
  let declarations = String.concat "" (List.map jstring_of_vdecl fdecl.locals) in
  let statements =
    String.concat "" (List.map (jstring_of_stmt global_vars local_vars) fdecl.body)
  in
  signature ^ " throws Exception" ^ "\n{\n" ^ declarations ^ statements ^ "}\n"
  
(* Emit the complete Java compilation unit for a program [(vars, funcs)].
   The output is, in order:
   - the package clause and imports;
   - a public class named after [file_name], containing the Java [main]
     entry point;
   - one [public static] field per global;
   - one static method per function.
   [file_name] is assumed to be the output file name, e.g. "Foo.java",
   giving class "Foo". *)
let jstring_of_program (vars, funcs) file_name =
  let global_vars = List.map (fun a -> { vname = a.vname; vtype = a.vtype }) vars
  in
  (* Java requires a public class to be named after its file's base name.
     Strip any directory prefix and the extension instead of slicing off a
     fixed five characters, which kept directories in the class name and
     raised Invalid_argument on names shorter than five characters.  For a
     plain "Name.java" the result is unchanged. *)
  let class_name =
    let base = Filename.basename file_name in
    try Filename.chop_extension base with Invalid_argument _ -> base
  in
  package_del ^ "\n" ^ import_decl ^ "\n\n" ^
  "public class " ^ class_name ^
  "\n{\n" ^ main_fdecl ^
  String.concat "" (List.map jstring_of_gvdecl vars) ^ "\n" ^
  String.concat "\n" (List.map (jstring_of_fdecl global_vars) funcs) ^
  "\n}"
