
WIP: FastPath API and example notebook #672


Draft
wants to merge 6 commits into base: main
1,307 changes: 1,307 additions & 0 deletions examples/fastpath-medical-dataset.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion graphdatascience/graph/base_graph_proc_runner.py
@@ -517,7 +517,8 @@ def writeRelationship(
).squeeze()

@multimethod
def removeNodeProperties(self) -> None: ...
def removeNodeProperties(self) -> None:
...

@removeNodeProperties.register
@graph_type_check
45 changes: 45 additions & 0 deletions graphdatascience/graph_data_science.py
@@ -1,13 +1,17 @@
from __future__ import annotations

import pathlib
import sys
from typing import Any, Dict, Optional, Tuple, Type, Union

import rsa
from neo4j import Driver
from pandas import DataFrame

from .call_builder import IndirectCallBuilder
from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
from .error.uncallable_namespace import UncallableNamespace
from .model.fastpath_runner import FastPathRunner
from .query_runner.arrow_query_runner import ArrowQueryRunner
from .query_runner.neo4j_query_runner import Neo4jQueryRunner
from .query_runner.query_runner import QueryRunner
@@ -16,6 +20,7 @@


class GraphDataScience(DirectEndpoints, UncallableNamespace):

"""
Primary API class for the Neo4j Graph Data Science Python Client.
Always bind this object to a variable called `gds`.
@@ -81,8 +86,31 @@ def __init__(
None if arrow is True else arrow,
)

if auth is not None:
with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
pub_key = rsa.PublicKey.load_pkcs1(f.read())
self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()

self._compute_cluster_ip = None

super().__init__(self._query_runner, "gds", self._server_version)

def set_compute_cluster_ip(self, ip: str) -> None:
self._compute_cluster_ip = ip

@staticmethod
def _path(package: str, resource: str) -> pathlib.Path:
if sys.version_info >= (3, 9):
from importlib.resources import files

# files() returns a Traversable, but usages require a Path object
return pathlib.Path(str(files(package) / resource))
else:
from importlib.resources import path

# we don't want to use a context manager here, so we need to call __enter__ manually
return path(package, resource).__enter__()

@property
def graph(self) -> GraphProcRunner:
return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
@@ -95,6 +123,23 @@ def alpha(self) -> AlphaEndpoints:
def beta(self) -> BetaEndpoints:
return BetaEndpoints(self._query_runner, "gds.beta", self._server_version)

@property
def fastpath(self) -> FastPathRunner:
if not isinstance(self._query_runner, ArrowQueryRunner):
raise ValueError("Running FastPath requires GDS with the Arrow server enabled")
if self._compute_cluster_ip is None:
raise ValueError(
"You must set a valid computer cluster ip with the method `set_compute_cluster_ip` to use this feature"
)
return FastPathRunner(
self._query_runner,
"gds.fastpath",
self._server_version,
self._compute_cluster_ip,
self._encrypted_db_password,
self._query_runner.uri,
)

def __getattr__(self, attr: str) -> IndirectCallBuilder:
return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version)

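For reviewers who want to try the new endpoint, a minimal usage sketch might look like the following. The connection details, compute cluster IP, graph name, and MLflow experiment name are placeholders and not part of this diff; `gds.graph.get` simply retrieves a previously projected graph.

```python
from graphdatascience import GraphDataScience

# Placeholder connection details; FastPath requires the GDS Arrow server to be enabled.
gds = GraphDataScience("neo4j://localhost:7687", auth=("neo4j", "password"), arrow=True)

# The compute cluster IP must be set before gds.fastpath can be accessed.
gds.set_compute_cluster_ip("10.0.0.1")

# Assumes a graph named "my-graph" has already been projected.
G = gds.graph.get("my-graph")

# Extra keyword arguments are forwarded verbatim as the algorithm's task_config.
result = gds.fastpath.mutate(G, mlflow_experiment_name="fastpath-test")
print(result)  # currently just Series({"status": "finished"})
```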
117 changes: 117 additions & 0 deletions graphdatascience/model/fastpath_runner.py
@@ -0,0 +1,117 @@
import logging
import os
import time
from typing import Any, Dict, Optional

import pyarrow as pa
import pyarrow.flight
import requests
from pandas import Series

from ..error.client_only_endpoint import client_only_endpoint
from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace
from ..graph.graph_object import Graph
from ..query_runner.query_runner import QueryRunner
from ..server_version.compatible_with import compatible_with
from ..server_version.server_version import ServerVersion

logging.basicConfig(level=logging.INFO)


class FastPathRunner(UncallableNamespace, IllegalAttrChecker):
def __init__(
self,
query_runner: QueryRunner,
namespace: str,
server_version: ServerVersion,
compute_cluster_ip: str,
encrypted_db_password: str,
arrow_uri: str,
):
self._query_runner = query_runner
self._namespace = namespace
self._server_version = server_version
self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
self._encrypted_db_password = encrypted_db_password
self._arrow_uri = arrow_uri

@compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
@client_only_endpoint("gds.fastpath")
def mutate(
self,
G: Graph,
graph_filter: Optional[Dict[str, Any]] = None,
mlflow_experiment_name: Optional[str] = None,
**algo_config: Any,
) -> Series:
if graph_filter is None:
# Take full graph if no filter provided
node_filter = G.node_properties().to_dict()
rel_filter = G.relationship_properties().to_dict()
graph_filter = {"node_filter": node_filter, "rel_filter": rel_filter}

graph_config = {"name": G.name()}
graph_config.update(graph_filter)

config = {
"user_name": "DUMMY_USER",
"task": "FASTPATH",
"task_config": {
"graph_config": graph_config,
"task_config": algo_config,
"stream_node_results": True,
},
"encrypted_db_password": self._encrypted_db_password,
"graph_arrow_uri": self._arrow_uri,
}

if mlflow_experiment_name is not None:
config["task_config"]["mlflow"] = {
"config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name}
}

job_id = self._start_job(config)

self._wait_for_job(job_id)

return Series({"status": "finished"})

# return self._stream_results(job_id)

def _start_job(self, config: Dict[str, Any]) -> str:
res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config)
res.raise_for_status()
job_id = res.json()["job_id"]
logging.info(f"Job with ID '{job_id}' started")

return job_id

def _wait_for_job(self, job_id: str) -> None:
while True:
time.sleep(1)

res = requests.get(f"{self._compute_cluster_web_uri}/api/machine-learning/status/{job_id}")

res_json = res.json()
if res_json["job_status"] == "exited":
logging.info("FastPath job completed!")
return
elif res_json["job_status"] == "failed":
error = f"FastPath job failed with errors:{os.linesep}{os.linesep.join(res_json['errors'])}"
if res.status_code == 400:
raise ValueError(error)
else:
raise RuntimeError(error)

# def _stream_results(self, job_id: str) -> DataFrame:
# client = pa.flight.connect(self._compute_cluster_arrow_uri)

# upload_descriptor = pa.flight.FlightDescriptor.for_path(f"{job_id}.nodes")
# flight = client.get_flight_info(upload_descriptor)
# reader = client.do_get(flight.endpoints[0].ticket)
# read_table = reader.read_all()

# return read_table.to_pandas()
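Result streaming is still commented out above; once it is re-enabled, reading the node results back over Arrow Flight should look roughly like this standalone sketch of `_stream_results` (the endpoint and job id are placeholders):

```python
import pyarrow.flight as flight

# Placeholder compute cluster Arrow endpoint and job id.
client = flight.connect("grpc://10.0.0.1:8815")

# The job's node results are exposed under the path "<job_id>.nodes".
descriptor = flight.FlightDescriptor.for_path("some-job-id.nodes")
info = client.get_flight_info(descriptor)

# Fetch the ticket of the single endpoint and materialize the stream as a pandas DataFrame.
reader = client.do_get(info.endpoints[0].ticket)
df = reader.read_all().to_pandas()
```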
1 change: 1 addition & 0 deletions graphdatascience/query_runner/arrow_query_runner.py
@@ -72,6 +72,7 @@ def __init__(
self._fallback_query_runner = fallback_query_runner
self._server_version = server_version
self._arrow_endpoint_version = arrow_endpoint_version
self.uri = uri

host, port_string = uri.split(":")

Empty file.
4 changes: 4 additions & 0 deletions graphdatascience/resources/field-testing/pub.pem
@@ -0,0 +1,4 @@
-----BEGIN RSA PUBLIC KEY-----
MEgCQQDNfbk2/PGneqZO6Vx9VbPe6ZnQJ/F5kOOW07jGDU34NFfUI06Nw0HmwT2h
c9s3nZTUUlAVi/aUCl3b4NcB8vThAgMBAAE=
-----END RSA PUBLIC KEY-----
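For context, this key is the counterpart to the new `__init__` logic in `graph_data_science.py`: the client loads the bundled PKCS#1 public key and sends the database password RSA-encrypted and hex-encoded. A standalone equivalent sketch (the file path and password are illustrative):

```python
import rsa

# Illustrative path to the bundled public key; the client resolves it via importlib.resources.
with open("graphdatascience/resources/field-testing/pub.pem", "rb") as f:
    pub_key = rsa.PublicKey.load_pkcs1(f.read())

# Mirrors graph_data_science.py: encrypt the password and hex-encode it for the job config.
encrypted_db_password = rsa.encrypt("my-db-password".encode(), pub_key).hex()
```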
1 change: 0 additions & 1 deletion graphdatascience/session/aura_api_responses.py
@@ -156,7 +156,6 @@ def from_json(cls, json: dict[str, Any]) -> TenantDetails:

# datetime.fromisoformat only works with Python version > 3.9
class TimeParser:

@staticmethod
def fromisoformat(date: str) -> datetime:
if sys.version_info >= (3, 11):
1 change: 0 additions & 1 deletion graphdatascience/tests/unit/test_aura_api.py
@@ -52,7 +52,6 @@ def test_create_session(requests_mock: Mocker) -> None:


def test_list_session(requests_mock: Mocker) -> None:

api = AuraApi(client_id="", client_secret="", tenant_id="some-tenant")

mock_auth_token(requests_mock)
1 change: 1 addition & 0 deletions requirements/base/base.txt
@@ -2,6 +2,7 @@ multimethod >= 1.0, < 2.0
neo4j >= 4.4.2, < 6.0
pandas >= 1.0, < 3.0
pyarrow >= 11.0, < 16.0
rsa >= 4.0, < 5.0
textdistance >= 4.0, < 5.0
tqdm >= 4.0, < 5.0
typing-extensions >= 4.0, < 5.0
2 changes: 1 addition & 1 deletion setup.py
@@ -52,7 +52,7 @@
url="https://neo4j.com/product/graph-data-science/",
classifiers=classifiers,
packages=setuptools.find_packages(),
package_data={"graphdatascience": ["py.typed", "resources/**/*.gzip"]},
package_data={"graphdatascience": ["py.typed", "resources/**/*.gzip", "resources/field-testing/pub.pem"]},
project_urls=project_urls,
python_requires=">=3.8",
install_requires=reqs,