open Sast

(* Python runtime support emitted at the top of every generated program: the
   imports plus the PLOG_ helper functions and classes that translated code
   calls. Each element is one top-level Python statement or definition. *)
let prelude = [ "#from plog import *";
    "from __future__ import print_function";  (* NOTE: __future__ is only needed for calling Python's print(). *)
    "from collections import defaultdict";
    "def PLOG_append( elem, container ):\n"^
    "\tcontainer.append( elem )";
    (* Removes elem if it is present; a no-op otherwise. *)
    "def PLOG_remove( elem, container ):\n"^
    "\tif elem in container:\n"^
    "\t\tcontainer.remove( elem )";
    "def PLOG_length( container ):\n"^
    "\treturn len( container )";
    (* Printf-style output without a trailing newline: the %d/%b/%s placeholders
       are mapped onto str.format's "{}". Literal braces in the format string are
       doubled FIRST; otherwise str.format would treat them as replacement fields
       and e.g. printing "{x}" would raise. *)
    "def PLOG_print( string, *args ):\n"^
    "\tprint( string.replace(\"{\",\"{{\").replace(\"}\",\"}}\").replace(\"%d\",\"{}\").replace(\"%b\",\"{}\").replace(\"%s\",\"{}\").format( *args ), end='' )";
    "class PLOG_NODE():\n"^
    "\tdef __init__(self):\n"^
    "\t\tself.in_edges = set()\n"^
    "\t\tself.out_edges = set()\n"^
    "\t\tself.props = defaultdict( lambda: None )";  (* For == NIL checks. *)
    "class PLOG_EDGE():\n"^
    "\tdef __init__(self, fromnode=None, t=None, tonode=None):\n"^ (* `=None`s were included only because we support declaration of "edges" locally, and they have to be initialized to something (mainly for access to props) *)
    "\t\tself.from_node = fromnode\n"^
    "\t\tself.to_node = tonode\n"^
    "\t\tself.etype = t\n"^
    "\t\tself.props = defaultdict( lambda: None )";
    "class PLOG_GRAPH():\n"^
    "\tdef __init__(self):\n"^
    "\t\tself.nodes = {}\n"^
    "\t\tself.edges = {}\n" ]

(* Python code appended after everything else: when the generated file is run
   as a script, call the translated user main function (PUSER_main). *)
let epilogue = [ String.concat "\n" [ "if __name__ == '__main__':"; "\tPUSER_main()" ] ]

(* Translate a semantically-checked program into the full text of a Python
   program. [nodes] are node-type declarations, [graphs] global graph
   declarations and [funcs] user function declarations. *)
let translate (nodes, graphs, funcs) =
    (* Indent one generated line by one Python level (a single tab). *)
    let addTab = (^) "\t"
    in

    (* ******************** *)

    (* Render [s] as a double-quoted Python string literal.
       OCaml's String.escaped writes non-printable bytes as DECIMAL "\ddd"
       escapes, but Python reads "\ddd" as OCTAL (and e.g. "\195" is not even a
       full octal escape), so the escapes are produced explicitly here. Bytes
       outside printable ASCII become "\xHH": exact bytes in a Python 2 str,
       Latin-1 code points in a Python 3 str. *)
    let python_string_literal s =
        let buf = Buffer.create (String.length s + 2) in
        Buffer.add_char buf '"';
        String.iter (fun c -> match c with
            '"' -> Buffer.add_string buf "\\\""
            | '\\' -> Buffer.add_string buf "\\\\"
            | '\n' -> Buffer.add_string buf "\\n"
            | '\t' -> Buffer.add_string buf "\\t"
            | '\r' -> Buffer.add_string buf "\\r"
            | ' ' .. '~' -> Buffer.add_char buf c
            | _ -> Buffer.add_string buf (Printf.sprintf "\\x%02x" (Char.code c))) s;
        Buffer.add_char buf '"';
        Buffer.contents buf
    in

    (* Translate a SAST expression into a Python expression string.
       User identifiers get the "PUSER_" prefix; calls to print/append/length/
       remove go to the PLOG_ helpers defined in the prelude. *)
    let rec string_of_expr = function
        IntLit(i) -> string_of_int i
        | BoolLit(b) -> if b then "True" else "False"
        | StrLit(s) -> python_string_literal s
        | Lst(_, el) -> "["^ String.concat ", " (List.map string_of_expr el) ^"]"
        | Id(s) -> "PUSER_"^ s
        | Unop(o, e) -> (match o with
            Neg -> "(-"^ string_of_expr e ^")"
            | Not -> "(not "^ string_of_expr e ^")")
        | Biop(e1, o, e2) -> "("^ string_of_expr e1 ^" "^ string_of_biop o ^" "^ string_of_expr e2 ^")"
        | Call(s, el) ->
            (if List.mem s ["print"; "append"; "length"; "remove"] then "PLOG_" else "PUSER_")
            ^ s ^"( "^ String.concat ", " (List.map string_of_expr el) ^" )"
        | GraphAccess(s, ge, locals) ->
            (* A name listed in [locals] is an in-scope variable and is referenced
               directly; any other name is looked up by key in graph [s], giving
               None when it is absent. *)
            let graph = "PUSER_"^ s in
            let node_key id =
                if List.mem id locals then "PUSER_"^ id
                else graph ^".nodes.get('"^ id ^"', None)"
            in
            (match ge with
            NodeId(id) ->
                if List.mem id locals then "PUSER_"^ id
                else graph ^".nodes.get( '"^ id ^"', None )"
            | EdgeId(fn, et, tn) ->
                (* Edges are keyed by the (from_node, etype, to_node) tuple. *)
                let etype = if List.mem et locals then "PUSER_"^ et else "'"^ et ^"'" in
                graph ^".edges.get( ("^ node_key fn ^", "^ etype ^", "^ node_key tn ^"), None )")
        | Property(e, s) -> string_of_expr e ^".props[ '"^ s ^"' ]"
        | Nil(_) -> "None"
        | Inf -> "float('Inf')"
    in

    (* Python condition: the element bound to the Python name [n] (a node or an
       edge) has property [prop] equal to the integer [ival]. The membership test
       runs first, so the defaultdict lookup only happens for keys already set. *)
    let string_of_pred (prop, ival) =
        Printf.sprintf "'%s' in n.props and n.props['%s'] == %d" prop prop ival
    in

    (* Emit the Python lines that set up a pattern-matching node loop. The result
       ends with the header "for PUSER_<nidpatt> in matching_nodes:"; the caller
       appends the loop body.
       - nidpatt: the loop variable; it must be one of the node ids in npatt.pids.
       - npatt.pids: alternates node id, edge type, node id, ... The generated
         code starts from candidates for pids[0] (n1) and walks backwards along
         in_edges: n<i> holds the from_nodes of n<i-1>'s in_edges whose etype
         equals pids[2i-3].
       - npatt.preds: (prop, int) equality predicates, applied at every level of
         the walk.
       - locals: names of in-scope variables. A node id listed here restricts that
         position to exactly the node held by PUSER_<id>.
       - e / eisgraph: the source expression. When eisgraph is true, the values of
         its .nodes dict are used.
       - namednode: Some name additionally filters the matches through
         PLOG_GETTNN_<name>.
       Raises Failure "Not found!" (eagerly) if nidpatt is not a node id of the
       pattern. *)
    let get_node_forloop nidpatt npatt locals e eisgraph namednode =
        let patt_pids_length = List.length npatt.pids in

        (* 1-based position of [nidpatt] among the node ids of the pattern
           (pids[0] -> 1, pids[2] -> 2, ...). This is the i of the generated n<i>
           variable that gets added to matching_nodes. *)
        let nidpatt_index =
            let rec get_nidpatt_index idx = function
                [] -> raise (Failure "Not found!")
                | nid :: [] -> if nid = nidpatt then idx else raise (Failure "Not found!")
                | nid :: etype :: xs -> if nid = nidpatt then idx else get_nidpatt_index (idx+1) xs
            in
            get_nidpatt_index 1 npatt.pids
        in

        (* Lines that follow the "n1_nodes = ..." line, up to and including the
           innermost matching_nodes.add. "" entries are placeholders for filters
           that don't apply; they become blank lines in the output. *)
        let get_node_forloop_body =
            if patt_pids_length = 1 then (
                (* If the "first" node in the pattern is a local var, next line handles that *)
                (if List.exists (fun x -> x = List.nth npatt.pids 0) locals then
                ("n1_nodes = [ n for n in n1_nodes if n == PUSER_"^ List.nth npatt.pids 0 ^" ]")
                else "") ::

                (if List.length npatt.preds = 0 then ""
                else "n1_nodes = [ n for n in n1_nodes if "^
                    (String.concat " and " (List.map string_of_pred npatt.preds))
                    ^" ]")
                ::
                "for n1 in n1_nodes:"
                ::
                [ "\tmatching_nodes.add( n1 )" ]
            ) else (
                (* Lines for pattern level i >= 2, nested inside the loop over n<i-1>.
                   The n<i> candidates are the from_nodes of n<i-1>'s in_edges whose
                   etype equals pids[2i-3]. That edge type is always emitted as a
                   quoted literal; locals are not consulted for it. The candidates
                   are then filtered by locals and preds. The recursion stops once
                   node pids[2i-2] is the last element of pids. *)
                let rec funcy i =
                    ("n"^ string_of_int i ^"_nodes = [ e.from_node for e in n"^ string_of_int (i-1) ^
                    ".in_edges if e.etype == '"^ (List.nth npatt.pids (i*2 - 3)) ^"' ]")
                    ::
                    (if List.exists (fun x -> x = List.nth npatt.pids (i*2 - 2)) locals then
                    ("n"^ string_of_int i ^"_nodes = [ n for n in n"^ string_of_int i ^"_nodes if n == PUSER_"^ List.nth npatt.pids (i*2 - 2) ^" ]")
                    else "")
                    ::
                    (if List.length npatt.preds = 0 then ""
                    else "n"^ string_of_int i ^"_nodes = [ n for n in n"^ string_of_int i ^"_nodes if "^
                        String.concat " and " (List.map string_of_pred npatt.preds)
                        ^" ]")
                    ::
                    (if i*2 - 1 = patt_pids_length then
                        (("for n"^ string_of_int i ^" in n"^ string_of_int i ^"_nodes:")
                        ::
                        [ "\tmatching_nodes.add( n"^ string_of_int nidpatt_index ^" )" ])
                    else (("for n"^ string_of_int i ^" in n"^ string_of_int i ^"_nodes:")
                        ::
                        (List.map addTab (funcy (i+1)))) )
                in

                (* If the "first" node in the pattern is a local var, next line handles that *)
                (if List.exists (fun x -> x = List.nth npatt.pids 0) locals then
                    ("n1_nodes = [ n for n in n1_nodes if n == PUSER_"^ List.nth npatt.pids 0 ^" ]")
                else "")
                ::
                (if List.length npatt.preds = 0 then ""
                else "n1_nodes = [ n for n in n1_nodes if "^
                    String.concat " and " (List.map string_of_pred npatt.preds)
                    ^" ]")
                ::
                "for n1 in n1_nodes:"
                ::
                (List.map addTab (funcy 2))
            )
        in
        "matching_nodes = set()\n" ::
        ("n1_nodes = "^ string_of_expr e ^
            (if eisgraph then ".nodes.copy().values()" else "")) ::

        (* String.concat "\n\t" get_node_forloop_body :: *)
        get_node_forloop_body @

        ((match namednode with None -> ""
            (* If we're iterating with a named node, get the subset of the above results which match the type *)
            | Some(name) -> ("matching_nodes = PLOG_GETTNN_"^ name ^"( matching_nodes )\n"))
        (* Now we set ourselves for the for loop block, where we'll actually iterate over the matching nodes *)
        :: ["for PUSER_"^ nidpatt ^" in matching_nodes:"])
    in

    (* Emit the Python filter function for a named node declaration:
         def PLOG_GETTNN_<nname>( n1_nodes ): ... return list( matching_nodes )
       It walks the declaration's pattern the same way the pattern loops do.
       Candidates for pids[0] come from the argument. n<i> holds the from_nodes of
       n<i-1>'s in_edges whose etype equals pids[2i-3], filtered by every
       predicate. The function returns the distinct n<k> values, where k is the
       position of nidpatt among the pattern's node ids.
       Layout: the body lines are joined with "\n\t", so every element except the
       first gains one tab. The first element carries its own leading "\t" when it
       is non-empty.
       Raises Failure "Not found!" if nidpatt is not a node id of the pattern. *)
    let string_of_node_decl node_decl =
        let nidpatt = node_decl.nidpatt in
        let npatt = node_decl.npatt in

        let patt_pids_length = List.length npatt.pids in (* NOTE: We assume it's odd, based on static sem check *)

        (* This func takes a patt (and nid and idx accumulator) and
        returns the 1-based index of the nid from the `right` of the pattern.
        Assumes List.rev was NOT called on the pids *)
        (* NOTE(review): the count starts at pids[0] (= index 1) and advances one per
           node id; "from the right" presumably refers to the source-level pattern
           order before parsing — confirm against the parser. *)
        let nidpatt_index =
            let rec get_nidpatt_index idx = function
                [] -> raise (Failure "Not found!")
                | nid :: [] -> if nid = nidpatt then idx else raise (Failure "Not found!")
                | nid :: etype :: xs -> if nid = nidpatt then idx else get_nidpatt_index (idx+1) xs
            in
            get_nidpatt_index 1 npatt.pids
        in

        (* Body lines after "matching_nodes = set()". "" entries are placeholders
           for filters that don't apply. *)
        let get_node_body =
            if patt_pids_length = 1 then (
                (if List.length npatt.preds = 0 then ""
                else "\tn1_nodes = [ n for n in n1_nodes if "^
                    (String.concat " and " (List.map string_of_pred npatt.preds))
                    ^" ]")
                ::
                "for n1 in n1_nodes:"
                ::
                [ "\tmatching_nodes.add( n1 )" ]
            ) else (
                (* Lines for pattern level i >= 2 (nested inside the loop over
                   n<i-1>). The recursion stops when node pids[2i-2] is the last
                   element of pids. *)
                let rec funcy i =
                    ("\tn"^ string_of_int i ^"_nodes = [ e.from_node for e in n"^ string_of_int (i-1) ^
                    ".in_edges if e.etype == '"^ (List.nth npatt.pids (i*2 - 3)) ^"' ]")
                    ::
                    (if List.length npatt.preds = 0 then ""
                    else "\tn"^ string_of_int i ^"_nodes = [ n for n in n"^ string_of_int i ^"_nodes if "^
                        String.concat " and " (List.map string_of_pred npatt.preds)
                        ^" ]")
                    ::
                    (if i*2 - 1 = patt_pids_length then
                        (("\tfor n"^ string_of_int i ^" in n"^ string_of_int i ^"_nodes:")
                        ::
                        [ "\t\tmatching_nodes.add( n"^ string_of_int nidpatt_index ^" )" ])
                    else (("\tfor n"^ string_of_int i ^" in n"^ string_of_int i ^"_nodes:")
                        ::
                        (List.map addTab (funcy (i+1)))) )
                in

                (if List.length npatt.preds = 0 then ""
                else "\tn1_nodes = [ n for n in n1_nodes if "^
                    String.concat " and " (List.map string_of_pred npatt.preds)
                    ^" ]")
                ::
                "for n1 in n1_nodes:"
                ::
                funcy 2
            )
        in
        "def PLOG_GETTNN_"^ node_decl.nname ^"( n1_nodes ):\n"^
        "\tmatching_nodes = set()\n"^
        String.concat "\n\t" get_node_body ^
        "\n\treturn list( matching_nodes )"
    in

    (* Translate the statements of a graph definition on graph [gname] into
       Python lines that mutate two dicts:
       - PUSER_<gname>.nodes, which maps a node key to a PLOG_NODE;
       - PUSER_<gname>.edges, which maps a (from_node, etype, to_node) tuple to a
         PLOG_EDGE.
       - gbody: the GraphSet / GraphDel statements, translated in order.
       - locals: names of in-scope variables. A node id in [locals] means the
         variable PUSER_<id>, not the key '<id>'. If that node is not yet in the
         graph, it is stored under the key '<id><n>', where n is the smallest
         integer >= 1 not already used; the node is created first if the variable
         is None. An edge type in [locals] is used through PUSER_<et> instead of
         the literal '<et>'.
       The returned lines are not indented for their enclosing scope. *)
    let string_of_graph_def gname gbody locals =
        let gnodes = "PUSER_"^ gname ^".nodes" in
        let gedges = "PUSER_"^ gname ^".edges" in

        (* Lines ensuring that a single node is present in the graph. *)
        let string_of_nodedec = function
            NodeId(id) -> if List.exists (fun x -> x = id) locals then
                    (("if PUSER_"^ id ^" not in "^ gnodes ^".copy().values():") ::
                    ("\tfint = 1") ::
                    ("\twhile '"^ id ^"'+ str(fint) in "^ gnodes ^":") ::
                    ("\t\tfint += 1") ::
                    ("\tif PUSER_"^ id ^" == None:") ::
                    ("\t\tPUSER_"^ id ^" = PLOG_NODE()") ::
                    ["\t"^ gnodes ^"[ '"^ id ^"'+ str(fint) ] = PUSER_"^ id])
                else (("if '"^ id ^"' not in "^ gnodes ^":") ::
                    ["\t"^ gnodes ^"[ '"^ id ^"' ] = PLOG_NODE()"])
        in

        (* One "props[...] = int" assignment per (prop, value) pair, for one node. *)
        let string_of_nodeprop sil = function
            (* NOTE: This function assumes that the node exists in the graph. *)
            NodeId(id) -> List.map (fun (s,i) ->
                (if List.exists (fun x -> x = id) locals then ("PUSER_"^id)
                 else (gnodes ^"[ '"^ id ^"' ]")) ^".props[ '" ^ s ^"' ] = "^ string_of_int i) sil
        in

        (* GraphSet with exactly one EdgeId: create the missing endpoints and the
           edge, register the edge in both endpoints' out_edges/in_edges, then set
           the edge's props. Any other GraphSet list is expected to hold only
           NodeIds; each node is added, then its props are set.
           GraphDel with exactly one EdgeId: remove that edge. Otherwise, remove
           each listed node together with every edge touching it. *)
        let string_of_graph_stmt = function
            GraphSet(gel, sil) -> (match gel with
                [EdgeId(fromn,et,ton)] ->
                    let flocal = List.exists (fun x -> x = fromn) locals in
                    let tlocal = List.exists (fun x -> x = ton) locals in
                    let gfromn = if flocal then "PUSER_"^ fromn else gnodes ^"['"^ fromn ^"']" in
                    let get = if List.exists (fun x -> x = et) locals then "PUSER_"^ et
                        else "'"^ et ^"'" in
                    let gton = if tlocal then "PUSER_"^ ton else gnodes ^"['"^ ton ^"']" in
                    let etriple = gfromn ^", "^ get ^", "^ gton in
                    (* Check if both nodes exist (before checking if edge exists). Create if not existing. *)
                    (* NOTE: We have an issue if the nodes are local to the GraphDef but don't exist in the graph.
                        How we resolve this now is: create a new node in the graph with the "name" localname+nextfreeint
                            We also effectively "remember" anything about that node (its props, in/out edges). *)
                    (if flocal then
                        ("if "^ gfromn ^" not in "^ gnodes ^".copy().values():") ::
                        ("\tfint = 1") ::
                        ("\twhile '"^ fromn ^"'+ str(fint) in "^ gnodes ^":") ::
                        ("\t\tfint += 1") ::
                        ("\tif "^ gfromn ^" == None:") ::
                        ("\t\t"^ gfromn ^" = PLOG_NODE()") ::
                        ["\t"^ gnodes ^"[ '"^ fromn ^"'+ str(fint) ] = "^ gfromn]
                    else
                        ("if '"^ fromn ^"' not in "^ gnodes ^":") ::
                        ["\t"^ gfromn ^" = PLOG_NODE()"]
                    ) @
                    (if tlocal then
                        ("if "^ gton ^" not in "^ gnodes ^".copy().values():") ::
                        ("\ttint = 1") ::
                        ("\twhile '"^ ton ^"'+ str(tint) in "^ gnodes ^":") ::
                        ("\t\ttint += 1") ::
                        ("\tif "^ gton ^" == None:") ::
                        ("\t\t"^ gton ^" = PLOG_NODE()") ::
                        ["\t"^ gnodes ^"[ '"^ ton ^"'+ str(tint) ] = "^ gton]
                    else
                        ("if '"^ ton ^"' not in "^ gnodes ^":") ::
                        ["\t"^ gton ^" = PLOG_NODE()"]
                    ) @
                    (* Create the edge based on the two nodes *)
                    ("if ("^ etriple ^") not in "^ gedges ^":") ::
                    ("\t"^ gedges ^"[ ("^ etriple ^") ] = PLOG_EDGE( "^ etriple ^" )") ::
                    (* Update the nodes: Add this new edge to each node accordingly *)
                    ("\t"^ gfromn ^".out_edges.add( "^ gedges ^"[ ("^ etriple ^") ] )") ::
                    ("\t"^ gton ^".in_edges.add( "^ gedges ^"[ ("^ etriple ^") ] )")
                    ::
                    (* Now set any properties for the edge *)
                    List.map (fun (s,i) -> gedges ^"[ ("^ etriple ^") ].props[ '"^ s ^"' ] = "^ string_of_int i) sil
                | ns -> List.concat (List.map string_of_nodedec gel)
                    @
                    List.concat (List.map (string_of_nodeprop sil) ns)
            )
            | GraphDel(gel) -> (match gel with
                [EdgeId(fromn,et,ton)] ->
                    let flocal = List.exists (fun x -> x = fromn) locals in
                    let tlocal = List.exists (fun x -> x = ton) locals in
                    let gfromn = if flocal then "PUSER_"^ fromn else gnodes ^"['"^ fromn ^"']" in
                    let get = if List.exists (fun x -> x = et) locals then "PUSER_"^ et
                        else "'"^ et ^"'" in
                    let gton = if tlocal then "PUSER_"^ ton else gnodes ^"['"^ ton ^"']" in
                    let etriple = gfromn ^", "^ get ^", "^ gton in

                    (* We need to delete all references of the edge (which might exist in nodes' in/out_edges lists). *)
                    (* We assume an invariant: that if a node doesn't exist in a graph, then no edge referencing that node exists in the graph;
                        that is, it's necessary that both nodes (the "from" and "to" nodes) exist for an edge referencing both nodes to exist. *)
                    ("if "^
                        (if flocal then "PUSER_"^ fromn ^" in "^ gnodes ^".copy().values()"
                         else "'"^ fromn ^"' in "^ gnodes)
                     ^" and "^
                        (if tlocal then "PUSER_"^ ton ^" in "^ gnodes ^".copy().values():"
                         else "'"^ ton ^"' in "^ gnodes ^":")) ::
                    ("\tfor node in "^ gnodes ^".copy().values():") ::
                    ("\t\tnode.in_edges.discard( "^ gedges ^".get(("^ etriple ^"), None) )") ::
                    ("\t\tnode.out_edges.discard( "^ gedges ^".get(("^ etriple ^"), None) )") ::
                    ["\t"^ gedges ^".pop( ("^ etriple ^"), None )"]
                | ns -> let nid_to_del = function
                    NodeId(id) ->
                        if List.exists (fun x -> x = id) locals then
                            ("for nkey, val in "^ gnodes ^".copy().items():") ::
                            ("\tif val == PUSER_"^ id ^":") ::
                            (* Delete all edges (NOTE: only in this graph) referencing this node. *)
                            ("\t\tfor ekey, edge in "^ gedges ^".copy().items():") ::
                            ("\t\t\tif edge.from_node == val or edge.to_node == val:") ::

                            ("\t\t\t\tfor node in "^ gnodes ^".copy().values():") ::
                            ("\t\t\t\t\tnode.in_edges.discard( edge )") ::
                            ("\t\t\t\t\tnode.out_edges.discard( edge )") ::
                            ("\t\t\t\t"^ gedges ^".pop( ekey, None )") ::
                            
                            ("\t\tdel "^ gnodes ^"[ nkey ]") ::
                            ["\t\tbreak"] (* NOTE: We assume there's at most 1 such node in a graph. *)
                        else
                            ("if '"^ id ^"' in "^ gnodes ^":") ::
                            (* Delete all edges (NOTE: only in this graph) referencing this node. *)
                            ("\tfor ekey, edge in "^ gedges ^".copy().items():") ::
                            ("\t\tif edge.from_node == "^ gnodes ^"[ '"^ id ^"' ] or edge.to_node == "^ gnodes ^"[ '"^ id ^"' ]:") ::

                            ("\t\t\tfor node in "^ gnodes ^".copy().values():") ::
                            ("\t\t\t\tnode.in_edges.discard( edge )") ::
                            ("\t\t\t\tnode.out_edges.discard( edge )") ::
                            ("\t\t\t"^ gedges ^".pop( ekey, None )") ::

                            ["\tdel "^ gnodes ^"[ '"^ id ^"' ]"]
                    in
                    List.concat (List.map nid_to_del ns)
            )
        in

        List.concat (List.map string_of_graph_stmt gbody)
    in

    (* Python lines for a global graph declaration: create PUSER_<gname> as an
       empty PLOG_GRAPH, then run its body. At global scope no names are locals. *)
    let string_of_graph_decl { gname; gbody; _ } =
        let create = "PUSER_"^ gname ^" = PLOG_GRAPH()" in
        create :: string_of_graph_def gname gbody []
    in

    (* Python initializers for a block's local declarations, in declaration
       order. Lists start as [], nodes as PLOG_NODE() and edges as PLOG_EDGE().
       Locals of any other type produce no line.
       TODO (original note): initializing a list is unnecessary when the user
       initializes it explicitly, so the " = []" could later be skipped there. *)
    let strings_of_local_decs decls =
        let init_for (typ, name) =
            let var = "PUSER_"^ name in
            match typ with
            TList(_) -> [var ^" = []"]
            | TNode -> [var ^" = PLOG_NODE()"]
            | TEdge -> [var ^" = PLOG_EDGE()"]
            | _ -> []
        in
        List.concat (List.map init_for decls)
    in

    (* Translate one SAST statement into Python source lines.
       The returned lines carry no indentation for the statement's own level;
       the caller places them. A Block indents its contents by one tab. The suite
       of every compound statement (if/else, for, while) goes through
       [string_of_body], so it is indented even when it is not a Block (e.g. an
       "else if" or a single-statement body).
       Lines may be empty strings; they become blank lines in the Python output. *)
    let rec string_of_stmt = function
        Assign(e1, e2) -> [string_of_expr e1 ^" = "^ string_of_expr e2]
        | Block(tnl, sl) ->
            let bsl = lines_of_block tnl sl in
            (* A suite made only of empty/blank lines is not a valid Python block,
               so such a body (including a truly empty one) becomes "pass". *)
            let is_blank line = String.trim line = "" in
            List.map addTab (if List.for_all is_blank bsl then ["pass"] else bsl)
        | Expr(e) -> [string_of_expr e]
        | Return(e) -> ["return "^ string_of_expr e]
        | If(e, s, Block([],[])) -> ("if "^ string_of_expr e ^":") :: string_of_body s
        | If(e, s1, s2) -> ( ("if "^ string_of_expr e ^":") :: string_of_body s1) @
            ("else:" :: string_of_body s2)
        | For(t, s, {pids=[];preds=[]}, e, eisgraph, st, _) ->
            (* Plain loop, no pattern: iterate a graph's edges/nodes or a list. *)
            let source = string_of_expr e in
            (match t with
            TEdge ->
                ("for PUSER_"^ s ^" in "^ source ^
                    (if eisgraph then ".edges.copy().values():" else ":")) :: string_of_body st
            | TNode ->
                ("for PUSER_"^ s ^" in "^ source ^
                    (if eisgraph then ".nodes.copy().values():" else ":")) :: string_of_body st
            | TNNode(tnns) ->
                ("matching_nodes = PLOG_GETTNN_"^ tnns ^"( "^ source ^
                    (if eisgraph then ".nodes.copy().values()" else "") ^" )") ::
                ("for PUSER_"^ s ^" in matching_nodes:") :: string_of_body st
            | _ -> ("for PUSER_"^ s ^" in "^ source ^":") :: string_of_body st)
        | For(t, s, p, e, eisgraph, st, locals) -> (match t with
            TEdge ->
                (* NOTE: We assume static semantic analysis checked there are exactly 3 identifiers in p.pids.
                   n1 ranges over candidates for p.pids[0] (the edge's to_node). n2
                   ranges over the from_nodes of n1's in_edges (candidates for
                   p.pids[2]). A position named in [locals] is restricted to that
                   variable's node. *)
                let is_local_at k = List.exists (fun x -> x = List.nth p.pids k) locals in
                "matching_edges = set()" ::
                ("n1_nodes = "^ string_of_expr e ^
                    (if eisgraph then ".nodes.copy().values()" else "")) ::
                (if is_local_at 0 then
                    ("n1_nodes = [ n for n in n1_nodes if n == PUSER_"^ List.nth p.pids 0 ^" ]")
                else "") ::
                "for n1 in n1_nodes:" ::
                "\tn2_nodes = [ e.from_node for e in n1.in_edges ]" ::
                (if is_local_at 2 then
                    ("\tn2_nodes = [ n for n in n2_nodes if n == PUSER_"^ List.nth p.pids 2 ^" ]")
                else "") ::
                (* Keep only n1's in_edges whose from_node survived the n2 filter.
                   Taking every out_edge of each n2 instead would also match edges
                   to nodes other than n1, i.e. ignore the to-node constraint.
                   NOTE(review): the edge-type position p.pids[1] is not used as a
                   filter here — confirm that is intended. *)
                "\tedges = [ e for e in n1.in_edges if e.from_node in n2_nodes ]" ::
                (if List.length p.preds = 0 then "" else
                    "\tedges = [ n for n in edges if "^ String.concat " and " (List.map string_of_pred p.preds) ^" ]") ::
                "\tfor edge in edges:" ::
                "\t\tmatching_edges.add( edge )" ::
                ("for PUSER_"^ s ^" in matching_edges:") :: string_of_body st
            | TNode -> get_node_forloop s p locals e eisgraph None @ string_of_body st
            | TNNode(tnns) -> get_node_forloop s p locals e eisgraph (Some tnns) @ string_of_body st )
        | While(e, s) -> ("while "^ string_of_expr e ^":") :: string_of_body s
        | GraphDef(s, sl, locals) -> string_of_graph_def s sl locals

    (* Unindented lines of a block's contents: initializers for its locals, then
       its statements. A Block nested directly inside a block opens no new Python
       suite, so its contents are spliced in at the same level instead of being
       indented a second time (which Python would reject as an unexpected indent). *)
    and lines_of_block tnl sl =
        strings_of_local_decs tnl @
        List.concat (List.map (fun stmt -> match stmt with
            Block(inner_tnl, inner_sl) -> lines_of_block inner_tnl inner_sl
            | _ -> string_of_stmt stmt) sl)

    (* Lines for the suite of a compound statement, indented one level. A Block
       indents itself. Any other statement is wrapped in a declaration-free Block,
       so it is indented too, and an empty result still becomes "pass". *)
    and string_of_body stmt = match stmt with
        Block(_, _) -> string_of_stmt stmt
        | _ -> string_of_stmt (Block([], [stmt]))
    in

    (* Emit one complete Python function definition for a user function. The
       function name and every parameter get the "PUSER_" prefix.
       flocals are already included in the func's fbody (as a Block), and Blocks
       indent their own lines, so the statement lines are used as-is. A function
       whose statements produce no lines at all gets "\tpass".
       NOTE: "blank lines" from the statement list are output unchanged. *)
    let string_of_func fdecl =
        let name = fdecl.fname in
        let params = List.map (fun (_, formal) -> "PUSER_"^ formal) fdecl.formals in
        let header = "def PUSER_"^ name ^"( "^ String.concat ", " params ^" ):\n" in
        let body =
            match List.concat (List.map string_of_stmt fdecl.fbody) with
            [] -> "\tpass"
            | lines -> String.concat "\n" lines
        in
        header ^ body
    in

    (* Assemble the whole Python program, in this order:
       - the runtime prelude;
       - one PLOG_GETTNN_ filter per node declaration;
       - global graph construction;
       - the user functions;
       - the __main__ epilogue.
       Sections are separated by a newline and the text ends with a newline. *)
    let sections = [
        String.concat "\n" prelude;
        String.concat "\n" (List.map string_of_node_decl nodes);
        String.concat "\n" (List.concat (List.map string_of_graph_decl graphs));
        String.concat "\n" (List.map string_of_func funcs);
        String.concat "\n" epilogue ]
    in
    String.concat "\n" sections ^ "\n"
