Skip to content

Commit 2909b97

Browse files
authored
Merge pull request #922 from DarthMax/arrow_wcc_endpoints_integration_tests
Wcc Endpoints + integration tests
2 parents 713cc5c + ced69c0 commit 2909b97

File tree

16 files changed

+507
-21
lines changed

16 files changed

+507
-21
lines changed

graphdatascience/arrow_client/v2/job_client.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,19 @@ def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]
4545
return deserialize_single(res)
4646

4747
@staticmethod
48-
def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame:
49-
encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
48+
def stream_results(client: AuthenticatedArrowClient, graph_name: str, job_id: str) -> DataFrame:
49+
payload = {
50+
"graphName": graph_name,
51+
"jobId": job_id,
52+
}
5053

51-
res = client.do_action_with_retry("v2/results.stream", encoded_config)
54+
res = client.do_action_with_retry("v2/results.stream", json.dumps(payload).encode("utf-8"))
5255
export_job_id = JobIdConfig(**deserialize_single(res)).job_id
5356

54-
payload = {
55-
"name": export_job_id,
56-
"version": 1,
57-
}
57+
stream_payload = {"version": "v2", "name": export_job_id, "body": {}}
5858

59-
ticket = Ticket(json.dumps(payload).encode("utf-8"))
60-
with client.get_stream(ticket) as get:
61-
arrow_table = get.read_all()
59+
ticket = Ticket(json.dumps(stream_payload).encode("utf-8"))
6260

61+
get = client.get_stream(ticket)
62+
arrow_table = get.read_all()
6363
return arrow_table.to_pandas(types_mapper=ArrowDtype) # type: ignore

graphdatascience/procedure_surface/api/wcc_endpoints.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def write(
246246
@abstractmethod
247247
def estimate(
248248
self,
249-
graph_name: Optional[str] = None,
249+
G: Optional[Graph] = None,
250250
projection_config: Optional[dict[str, Any]] = None,
251251
) -> EstimationResult:
252252
"""
@@ -259,8 +259,8 @@ def estimate(
259259
260260
Parameters
261261
----------
262-
graph_name : Optional[str], optional
263-
Name of the graph to be used in the estimation
262+
G : Optional[Graph], optional
263+
The graph to be used in the estimation
264264
projection_config : Optional[dict[str, Any]], optional
265265
Configuration dictionary for the projection.
266266
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import json
2+
from typing import Any, List, Optional
3+
4+
from pandas import DataFrame
5+
6+
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7+
from ...arrow_client.v2.data_mapper_utils import deserialize_single
8+
from ...arrow_client.v2.job_client import JobClient
9+
from ...arrow_client.v2.mutation_client import MutationClient
10+
from ...arrow_client.v2.write_back_client import WriteBackClient
11+
from ...graph.graph_object import Graph
12+
from ..api.estimation_result import EstimationResult
13+
from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult
14+
from ..utils.config_converter import ConfigConverter
15+
16+
WCC_ENDPOINT = "v2/community.wcc"
17+
18+
19+
class WccArrowEndpoints(WccEndpoints):
    """Arrow Flight (v2) implementation of the WCC algorithm endpoints.

    Every mode submits the computation through ``JobClient`` against the
    ``v2/community.wcc`` endpoint and adapts the raw summary dictionary onto
    the typed result objects of the procedure-surface API.

    NOTE(review): the ``username`` parameter is accepted on every mode for
    interface parity with the abstract API but is never forwarded into the
    job configuration — confirm this is intended.
    """

    def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[WriteBackClient] = None):
        self._arrow_client = arrow_client
        # Only required by `write`; all other modes operate without it.
        self._write_back_client = write_back_client

    @staticmethod
    def _job_config(G: Graph, **algo_params: Any) -> dict[str, Any]:
        # Shared config builder for all WCC modes. Unset optionals are passed
        # through as None; ConfigConverter presumably drops None-valued
        # entries, as all call sites forward unset optionals unconditionally
        # — TODO(review): confirm against ConfigConverter.
        return ConfigConverter.convert_to_gds_config(graph_name=G.name(), **algo_params)

    def mutate(
        self,
        G: Graph,
        mutate_property: str,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> WccMutateResult:
        """Run WCC and store the component ids on the in-memory graph.

        Returns a ``WccMutateResult`` built from the job summary augmented
        with the mutation counters.
        """
        config = self._job_config(
            G,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)

        mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property)
        computation_result = JobClient.get_summary(self._arrow_client, job_id)

        computation_result["nodePropertiesWritten"] = mutate_result.node_properties_written
        # The Arrow backend does not report mutation time; expose a stable 0.
        computation_result["mutateMillis"] = 0

        return WccMutateResult(**computation_result)

    def stats(
        self,
        G: Graph,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> WccStatsResult:
        """Run WCC and return only the computation statistics."""
        config = self._job_config(
            G,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
        computation_result = JobClient.get_summary(self._arrow_client, job_id)

        return WccStatsResult(**computation_result)

    def stream(
        self,
        G: Graph,
        min_component_size: Optional[int] = None,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> DataFrame:
        """Run WCC and stream the per-node component ids as a DataFrame."""
        config = self._job_config(
            G,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            min_component_size=min_component_size,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
        return JobClient.stream_results(self._arrow_client, G.name(), job_id)

    def write(
        self,
        G: Graph,
        write_property: str,
        min_component_size: Optional[int] = None,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        write_concurrency: Optional[int] = None,
    ) -> WccWriteResult:
        """Run WCC and write the component ids back to the database.

        Raises
        ------
        Exception
            If no write-back client was configured on construction.
        """
        # Fail fast: without a write-back client the job result could never
        # be written, so don't run the (potentially expensive) computation.
        if self._write_back_client is None:
            raise Exception("Write back client is not initialized")

        config = self._job_config(
            G,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            min_component_size=min_component_size,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
        computation_result = JobClient.get_summary(self._arrow_client, job_id)

        write_millis = self._write_back_client.write(
            G.name(), job_id, write_concurrency if write_concurrency is not None else concurrency
        )

        computation_result["writeMillis"] = write_millis

        return WccWriteResult(**computation_result)

    def estimate(
        self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
    ) -> EstimationResult:
        """Estimate the memory required to run WCC.

        Exactly one of ``G`` (an existing graph) or ``projection_config``
        (a projection configuration dictionary) must be provided.

        Raises
        ------
        ValueError
            If neither ``G`` nor ``projection_config`` is given.
        """
        if G is not None:
            payload: dict[str, Any] = {"graphName": G.name()}
        elif projection_config is not None:
            payload = projection_config
        else:
            # Message kept in sync with the renamed `G` parameter.
            raise ValueError("Either G or projection_config must be provided.")

        res = self._arrow_client.do_action_with_retry("v2/community.wcc.estimate", json.dumps(payload).encode("utf-8"))

        return EstimationResult(**deserialize_single(res))

graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import OrderedDict
12
from typing import Any, List, Optional, Union
23

34
from pandas import DataFrame
@@ -177,18 +178,20 @@ def write(
177178
return WccWriteResult(**result.to_dict())
178179

179180
def estimate(
180-
self, graph_name: Optional[str] = None, projection_config: Optional[dict[str, Any]] = None
181+
self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
181182
) -> EstimationResult:
182-
config: Union[str, dict[str, Any]] = {}
183+
config: Union[dict[str, Any]] = OrderedDict()
183184

184-
if graph_name is not None:
185-
config = graph_name
185+
if G is not None:
186+
config["graphNameOrConfiguration"] = G.name()
186187
elif projection_config is not None:
187-
config = projection_config
188+
config["graphNameOrConfiguration"] = projection_config
188189
else:
189190
raise ValueError("Either graph_name or projection_config must be provided.")
190191

191-
params = CallParameters(config=config)
192+
config["algoConfig"] = {}
193+
194+
params = CallParameters(**config)
192195

193196
result = self._query_runner.call_procedure(endpoint="gds.wcc.stats.estimate", params=params).squeeze()
194197

graphdatascience/tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ def pytest_addoption(parser: Any) -> None:
1616
parser.addoption(
1717
"--include-cloud-architecture", action="store_true", help="include tests resuiring a cloud architecture setup"
1818
)
19+
parser.addoption("--include-integration-v2", action="store_true", help="include integration tests for v2")

graphdatascience/tests/integrationV2/__init__.py

Whitespace-only changes.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import Any
2+
3+
import pytest
4+
5+
6+
def pytest_collection_modifyitems(config: Any, items: Any) -> None:
    """Mark every collected test as skipped unless --include-integration-v2 was given."""
    if config.getoption("--include-integration-v2"):
        return

    skip_marker = pytest.mark.skip(reason="need --include-integration-v2 option to run")
    for collected_item in items:
        collected_item.add_marker(skip_marker)

graphdatascience/tests/integrationV2/procedure_surface/__init__.py

Whitespace-only changes.

graphdatascience/tests/integrationV2/procedure_surface/arrow/__init__.py

Whitespace-only changes.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import os
2+
import tempfile
3+
from typing import Generator
4+
5+
import pytest
6+
from testcontainers.core.container import DockerContainer
7+
from testcontainers.core.waiting_utils import wait_for_logs
8+
9+
from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
10+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
11+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
12+
13+
14+
@pytest.fixture(scope="session")
def password_file() -> Generator[str, None, None]:
    """Yield the path of a temporary directory holding a file named 'password'.

    The file contains the literal text "password"; both the file and the
    directory are removed when the session ends.
    """
    directory = tempfile.mkdtemp()
    secret_path = os.path.join(directory, "password")

    with open(secret_path, "w") as secret_file:
        secret_file.write("password")

    yield directory

    # Tear down: remove the file first, then the now-empty directory.
    os.unlink(secret_path)
    os.rmdir(directory)
28+
29+
30+
@pytest.fixture(scope="session")
def session_container(password_file: str) -> Generator[DockerContainer, None, None]:
    """Start a GDS session container for the test session and yield it.

    Requires the GDS_SESSION_IMAGE environment variable to name the image;
    the container's logs are printed once the session is torn down.
    """
    image = os.getenv("GDS_SESSION_IMAGE")
    if image is None:
        raise ValueError("GDS_SESSION_IMAGE environment variable is not set")

    container = (
        DockerContainer(image=image)
        .with_env("ALLOW_LIST", "DEFAULT")
        .with_env("DNS_NAME", "gds-session")
        .with_env("PAGE_CACHE_SIZE", "100M")
        .with_exposed_ports(8491)
        .with_network_aliases(["gds-session"])
        .with_volume_mapping(password_file, "/passwords")
    )

    with container as started:
        # Block until the session reports it is idle and ready.
        wait_for_logs(started, "Running GDS tasks: 0")
        yield started
        stdout, stderr = started.get_logs()
        print(stdout)
56+
@pytest.fixture
def arrow_client(session_container: DockerContainer) -> AuthenticatedArrowClient:
    """Create an authenticated Arrow client connected to the session container."""
    address = f"{session_container.get_container_host_ip()}:{session_container.get_exposed_port(8491)}"

    return AuthenticatedArrowClient.create(
        arrow_info=ArrowInfo(address, True, True, ["v1", "v2"]),
        auth=UsernamePasswordAuthentication("neo4j", "password"),
        encrypted=False,
    )

0 commit comments

Comments
 (0)