diff --git a/docs/source/conf.py b/docs/source/conf.py index f0a8c41e..765ea786 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,7 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath(os.path.join("..", ".."))) +sys.path.insert(0, os.path.abspath(os.path.join("..", "..", "src"))) from neo4j import __version__ as project_version @@ -345,6 +345,9 @@ def setup(app): intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "dateutil": ("https://dateutil.readthedocs.io/en/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "pandas": ("https://pandas.pydata.org/docs/", None), + "pyarrow": ("https://arrow.apache.org/docs/", None), } autodoc_default_options = { diff --git a/docs/source/index.rst b/docs/source/index.rst index 6480c91f..b59c6edf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,8 +9,8 @@ Bolt protocol versions supported: .. # [bolt-version-bump] search tag when changing bolt version support +* Bolt 6.0 * Bolt 5.0 - 5.8 -* Bolt 4.4 See https://7687.org/bolt-compatibility/ for what Neo4j DBMS versions support which Bolt versions. See https://neo4j.com/developer/kb/neo4j-supported-versions/ for a driver-server compatibility matrix. @@ -36,6 +36,8 @@ Topics + :ref:`temporal-data-types` ++ :ref:`vector-data-types` + + :ref:`breaking-changes` @@ -47,6 +49,7 @@ Topics async_api.rst types/spatial.rst types/temporal.rst + types/vector.rst breaking_changes.rst diff --git a/docs/source/types/vector.rst b/docs/source/types/vector.rst new file mode 100644 index 00000000..7673e6fb --- /dev/null +++ b/docs/source/types/vector.rst @@ -0,0 +1,18 @@ +.. _vector-data-types: + +***************** +Vector Data Types +***************** + +.. 
autoclass:: neo4j.vector.Vector + :members: + + +.. autoclass:: neo4j.vector.VectorEndian + :show-inheritance: + :members: + + +.. autoclass:: neo4j.vector.VectorDType + :show-inheritance: + :members: diff --git a/pyproject.toml b/pyproject.toml index c02b395c..3bb5055d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,12 +54,12 @@ Forum = "https://community.neo4j.com/c/drivers-stacks/python/" Discord = "https://discord.com/invite/neo4j" [project.optional-dependencies] -numpy = ["numpy >= 1.7.0, < 3.0.0"] +numpy = ["numpy >= 1.21.2, < 3.0.0"] pandas = [ "pandas >= 1.1.0, < 3.0.0", - "numpy >= 1.7.0, < 3.0.0", + "numpy >= 1.21.2, < 3.0.0", ] -pyarrow = ["pyarrow >= 1.0.0"] +pyarrow = ["pyarrow >= 6.0.0, < 21.0.0"] [build-system] @@ -207,8 +207,9 @@ asyncio_default_fixture_loop_scope="function" [[tool.mypy.overrides]] module = [ "pandas.*", - "neo4j._codec.packstream._rust", - "neo4j._codec.packstream._rust.*", + "pyarrow.*", + "neo4j._rust", + "neo4j._rust.*", ] ignore_missing_imports = true diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index 7c950834..7409d3d5 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -36,6 +36,7 @@ _bolt3, _bolt4, _bolt5, + _bolt6, ) from ._bolt import AsyncBolt from ._common import ConnectionErrorHandler diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index f2e6aa55..e79dae5a 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -286,7 +286,6 @@ def __init_subclass__(cls: type[t.Self], **kwargs: t.Any) -> None: cls.protocol_handlers[protocol_version] = cls super().__init_subclass__(**kwargs) - # [bolt-version-bump] search tag when changing bolt version support @classmethod def get_handshake(cls) -> bytes: """ diff --git a/src/neo4j/_async/io/_bolt6.py b/src/neo4j/_async/io/_bolt6.py new file mode 100644 index 00000000..311c954e --- /dev/null +++ b/src/neo4j/_async/io/_bolt6.py @@ -0,0 +1,615 @@ +# Copyright (c) 
"Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from logging import getLogger +from ssl import SSLSocket + +from ... import _typing as t +from ..._api import TelemetryAPI +from ..._async_compat.util import AsyncUtil +from ..._codec.hydration import v3 as hydration_v3 +from ..._exceptions import BoltProtocolError +from ..._io import BoltProtocolVersion +from ..._meta import BOLT_AGENT_DICT +from ...api import READ_ACCESS +from ...exceptions import ( + DatabaseUnavailable, + ForbiddenOnReadOnlyDatabase, + Neo4jError, + NotALeader, + ServiceUnavailable, +) +from ._bolt import ( + AsyncBolt, + ClientStateManagerBase, + ServerStateManagerBase, + tx_timeout_as_ms, +) +from ._bolt5 import ( + BoltStates5x1, + ClientStateManager5x1, + ServerStateManager5x1, +) +from ._common import ( + CommitResponse, + InitResponse, + LogonResponse, + ResetResponse, + Response, +) + + +log = getLogger("neo4j.io") + + +class AsyncBolt6x0(AsyncBolt): + """Protocol handler for Bolt 6.0.""" + + PROTOCOL_VERSION = BoltProtocolVersion(6, 0) + + HYDRATION_HANDLER_CLS = hydration_v3.HydrationHandler + + supports_multiple_results = True + + supports_multiple_databases = True + + supports_re_auth = True + + supports_notification_filtering = True + + bolt_states: t.Any = BoltStates5x1 + + DEFAULT_ERROR_DIAGNOSTIC_RECORD = DEFAULT_STATUS_DIAGNOSTIC_RECORD = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), + ) + + def 
__init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager5x1( + BoltStates5x1.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager5x1( + BoltStates5x1.CONNECTED, on_change=self._on_client_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug( + "[#%04X] _: server state: %s > %s", + self.local_port, + old_state.name, + new_state.name, + ) + + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + + def _on_client_state_change(self, old_state, new_state): + log.debug( + "[#%04X] _: client state: %s > %s", + self.local_port, + old_state.name, + new_state.name, + ) + + def _get_client_state_manager(self) -> ClientStateManagerBase: + return self._client_state_manager + + @property + def ssr_enabled(self) -> bool: + return self.connection_hints.get("ssr.enabled", False) is True + + @property + def is_reset(self): + # We can't be sure of the server's state if there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. 
+ if self.responses: + return self.responses[-1] and self.responses[-1].message == "reset" + return self._server_state_manager.state == self.bolt_states.READY + + @property + def encrypted(self): + return isinstance(self.socket, SSLSocket) + + @property + def der_encoded_server_certificate(self): + return self.socket.getpeercert(binary_form=True) + + def get_base_headers(self): + headers = {"user_agent": self.user_agent} + if self.routing_context is not None: + headers["routing"] = self.routing_context + if self.notifications_min_severity is not None: + headers["notifications_minimum_severity"] = ( + self.notifications_min_severity + ) + if self.notifications_disabled_classifications is not None: + headers["notifications_disabled_classifications"] = ( + self.notifications_disabled_classifications + ) + headers["bolt_agent"] = BOLT_AGENT_DICT + return headers + + async def hello(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + + def on_success(metadata): + self.connection_hints.update(metadata.pop("hints", {})) + self.server_info.update(metadata) + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ + "connection.recv_timeout_seconds" + ] + if isinstance(recv_timeout, int) and recv_timeout > 0: + self.socket.settimeout(recv_timeout) + else: + log.info( + "[#%04X] _: Server supplied an " + "invalid value for " + "connection.recv_timeout_seconds (%r). 
Make sure " + "the server and network is set up correctly.", + self.local_port, + recv_timeout, + ) + + extra = self.get_base_headers() + log.debug("[#%04X] C: HELLO %r", self.local_port, extra) + self._append( + b"\x01", + (extra,), + response=InitResponse( + self, "hello", hydration_hooks, on_success=on_success + ), + dehydration_hooks=dehydration_hooks, + ) + + self.logon(dehydration_hooks, hydration_hooks) + await self.send_all() + await self.fetch_all() + + def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + logged_auth_dict = dict(self.auth_dict) + if "credentials" in logged_auth_dict: + logged_auth_dict["credentials"] = "*******" + log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) + self._append( + b"\x6a", + (self.auth_dict,), + response=LogonResponse(self, "logon", hydration_hooks), + dehydration_hooks=dehydration_hooks, + ) + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + log.debug("[#%04X] C: LOGOFF", self.local_port) + self._append( + b"\x6b", + response=LogonResponse(self, "logoff", hydration_hooks), + dehydration_hooks=dehydration_hooks, + ) + + def telemetry( + self, + api: TelemetryAPI, + dehydration_hooks=None, + hydration_hooks=None, + **handlers, + ) -> None: + if self.telemetry_disabled or not self.connection_hints.get( + "telemetry.enabled", False + ): + return + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + api_raw = int(api) + log.debug( + "[#%04X] C: TELEMETRY %i # (%r)", self.local_port, api_raw, api + ) + self._append( + b"\x54", + (api_raw,), + Response(self, "telemetry", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks, + ) + + async def route( + self, + database=None, + imp_user=None, 
+ bookmarks=None, + dehydration_hooks=None, + hydration_hooks=None, + ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + routing_context = self.routing_context or {} + db_context = {} + if database is not None: + db_context.update(db=database) + if imp_user is not None: + db_context.update(imp_user=imp_user) + log.debug( + "[#%04X] C: ROUTE %r %r %r", + self.local_port, + routing_context, + bookmarks, + db_context, + ) + metadata = {} + bookmarks = [] if bookmarks is None else list(bookmarks) + self._append( + b"\x66", + (routing_context, bookmarks, db_context), + response=Response( + self, "route", hydration_hooks, on_success=metadata.update + ), + dehydration_hooks=dehydration_hooks, + ) + await self.send_all() + await self.fetch_all() + return [metadata.get("rt")] + + def run( + self, + query, + parameters=None, + mode=None, + bookmarks=None, + metadata=None, + timeout=None, + db=None, + imp_user=None, + notifications_min_severity=None, + notifications_disabled_classifications=None, + dehydration_hooks=None, + hydration_hooks=None, + **handlers, + ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + if not parameters: + parameters = {} + extra = {} + if mode in {READ_ACCESS, "r"}: + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if ( + self._client_state_manager.state + != self.bolt_states.TX_READY_OR_TX_STREAMING + ): + self.last_database = db + if imp_user: + extra["imp_user"] = imp_user + if notifications_min_severity is not None: + extra["notifications_minimum_severity"] = ( + notifications_min_severity + ) + if notifications_disabled_classifications is not None: + extra["notifications_disabled_classifications"] = ( + notifications_disabled_classifications + ) + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError( + "Bookmarks 
must be provided as iterable" + ) from None + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError( + "Metadata must be coercible to a dict" + ) from None + if timeout is not None: + extra["tx_timeout"] = tx_timeout_as_ms(timeout) + fields = (query, parameters, extra) + log.debug( + "[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)) + ) + self._append( + b"\x10", + fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks, + ) + + def discard( + self, + n=-1, + qid=-1, + dehydration_hooks=None, + hydration_hooks=None, + **handlers, + ): + handlers["on_success"] = self._make_enrich_statuses_handler( + wrapped_handler=handlers.get("on_success") + ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + extra = {"n": n} + if qid != -1: + extra["qid"] = qid + log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) + self._append( + b"\x2f", + (extra,), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks, + ) + + def pull( + self, + n=-1, + qid=-1, + dehydration_hooks=None, + hydration_hooks=None, + **handlers, + ): + handlers["on_success"] = self._make_enrich_statuses_handler( + wrapped_handler=handlers.get("on_success") + ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + extra = {"n": n} + if qid != -1: + extra["qid"] = qid + log.debug("[#%04X] C: PULL %r", self.local_port, extra) + self._append( + b"\x3f", + (extra,), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks, + ) + + def begin( + self, + mode=None, + bookmarks=None, + metadata=None, + timeout=None, + db=None, + imp_user=None, + notifications_min_severity=None, + notifications_disabled_classifications=None, + dehydration_hooks=None, + hydration_hooks=None, + **handlers, + ): + dehydration_hooks, 
hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + extra = {} + if mode in {READ_ACCESS, "r"}: + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + self.last_database = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError( + "Bookmarks must be provided as iterable" + ) from None + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError( + "Metadata must be coercible to a dict" + ) from None + if timeout is not None: + extra["tx_timeout"] = tx_timeout_as_ms(timeout) + if notifications_min_severity is not None: + extra["notifications_minimum_severity"] = ( + notifications_min_severity + ) + if notifications_disabled_classifications is not None: + extra["notifications_disabled_classifications"] = ( + notifications_disabled_classifications + ) + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append( + b"\x11", + (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks, + ) + + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + log.debug("[#%04X] C: COMMIT", self.local_port) + self._append( + b"\x12", + (), + CommitResponse(self, "commit", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks, + ) + + def rollback( + self, dehydration_hooks=None, hydration_hooks=None, **handlers + ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + log.debug("[#%04X] C: ROLLBACK", self.local_port) + self._append( + b"\x13", + (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks, + ) + + async def reset(self, dehydration_hooks=None, 
hydration_hooks=None): + """ + Reset the connection. + + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + log.debug("[#%04X] C: RESET", self.local_port) + response = ResetResponse(self, "reset", hydration_hooks) + self._append( + b"\x0f", response=response, dehydration_hooks=dehydration_hooks + ) + await self.send_all() + await self.fetch_all() + + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + log.debug("[#%04X] C: GOODBYE", self.local_port) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) + + async def _process_message(self, tag, fields): + """Process at most one message from the server, if available. + + :returns: 2-tuple of number of detail messages and number of summary + messages fetched + """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + + if details: + # Do not log any data + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) + await self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug( + "[#%04X] S: SUCCESS %r", self.local_port, summary_metadata + ) + self._server_state_manager.transition( + response.message, summary_metadata + ) + await response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7e": + log.debug("[#%04X] S: IGNORED", self.local_port) + await response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7f": + log.debug( + "[#%04X] S: FAILURE %r", self.local_port, 
summary_metadata + ) + self._server_state_manager.state = self.bolt_states.FAILED + self._enrich_error_diagnostic_record(summary_metadata) + try: + await response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + await self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + await self.pool.on_write_failure( + address=self.unresolved_address, + database=self.last_database, + ) + raise + except Neo4jError as e: + if self.pool: + await self.pool.on_neo4j_error(e, self) + raise + else: + sig_int = ord(summary_signature) + raise BoltProtocolError( + f"Unexpected response message with signature {sig_int:02X}", + self.unresolved_address, + ) + + return len(details), 1 + + def _enrich_error_diagnostic_record(self, metadata): + if not isinstance(metadata, dict): + return + diag_record = metadata.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info( + "[#%04X] _: Server supplied an " + "invalid error diagnostic record (%r).", + self.local_port, + diag_record, + ) + else: + for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + self._enrich_error_diagnostic_record(metadata.get("cause")) + + def _make_enrich_statuses_handler(self, wrapped_handler=None): + async def handler(metadata): + def enrich(metadata_): + if not isinstance(metadata_, dict): + return + statuses = metadata_.get("statuses") + if not isinstance(statuses, list): + return + for status in statuses: + if not isinstance(status, dict): + continue + diag_record = status.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info( + "[#%04X] _: Server supplied an " + "invalid status diagnostic record (%r).", + self.local_port, + diag_record, + ) + continue + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + + enrich(metadata) + await 
AsyncUtil.callback(wrapped_handler, metadata) + + return handler diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index fe12b1ad..062a13c3 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -66,6 +66,11 @@ from ...graph import Graph +if False: + # Ugly work-around to make sphinx understand `@_t.overload` + import typing as t # type: ignore[no-redef] + + notification_log = getLogger("neo4j.notifications") diff --git a/src/neo4j/_codec/hydration/_common.py b/src/neo4j/_codec/hydration/_common.py index 633357aa..5c0f6858 100644 --- a/src/neo4j/_codec/hydration/_common.py +++ b/src/neo4j/_codec/hydration/_common.py @@ -46,7 +46,7 @@ def get_transformer(self, item): transformer = self.exact_types.get(type_) if transformer is not None: return transformer - transformer = next( + return next( ( f for super_type, f in self.subtypes.items() @@ -54,9 +54,6 @@ def get_transformer(self, item): ), None, ) - if transformer is not None: - return transformer - return None class BrokenHydrationObject: diff --git a/src/neo4j/_codec/hydration/v1/hydration_handler.py b/src/neo4j/_codec/hydration/v1/hydration_handler.py index 9863a676..fd4a2dad 100644 --- a/src/neo4j/_codec/hydration/v1/hydration_handler.py +++ b/src/neo4j/_codec/hydration/v1/hydration_handler.py @@ -41,6 +41,7 @@ Duration, Time, ) +from ....vector import Vector from .._common import ( GraphHydrator, HydrationScope, @@ -49,6 +50,7 @@ from . 
import ( spatial, temporal, + vector, ) @@ -177,6 +179,7 @@ def __init__(self): datetime: temporal.dehydrate_datetime, Duration: temporal.dehydrate_duration, timedelta: temporal.dehydrate_timedelta, + Vector: vector.dehydrate_vector, } ) if np is not None: diff --git a/src/neo4j/_codec/hydration/v1/vector.py b/src/neo4j/_codec/hydration/v1/vector.py new file mode 100644 index 00000000..a580e8ff --- /dev/null +++ b/src/neo4j/_codec/hydration/v1/vector.py @@ -0,0 +1,25 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...._io import BoltProtocolVersion +from ....exceptions import ConfigurationError + + +def dehydrate_vector(_): + raise ConfigurationError( + "Vector types require at least Bolt " + f"Protocol {BoltProtocolVersion(6, 0)}." + ) diff --git a/src/neo4j/_codec/hydration/v2/hydration_handler.py b/src/neo4j/_codec/hydration/v2/hydration_handler.py index 4ea46def..17097643 100644 --- a/src/neo4j/_codec/hydration/v2/hydration_handler.py +++ b/src/neo4j/_codec/hydration/v2/hydration_handler.py @@ -36,11 +36,13 @@ Duration, Time, ) +from ....vector import Vector from .._common import HydrationScope from .._interface import HydrationHandlerABC from ..v1 import ( spatial, temporal as temporal_v1, + vector, ) from ..v1.hydration_handler import _GraphHydrator from . 
import temporal as temporal_v2 @@ -75,6 +77,7 @@ def __init__(self): datetime: temporal_v2.dehydrate_datetime, Duration: temporal_v1.dehydrate_duration, timedelta: temporal_v1.dehydrate_timedelta, + Vector: vector.dehydrate_vector, } ) if np is not None: diff --git a/src/neo4j/_codec/hydration/v3/__init__.py b/src/neo4j/_codec/hydration/v3/__init__.py new file mode 100644 index 00000000..e0307550 --- /dev/null +++ b/src/neo4j/_codec/hydration/v3/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .hydration_handler import HydrationHandler + + +__all__ = [ + "HydrationHandler", +] diff --git a/src/neo4j/_codec/hydration/v3/hydration_handler.py b/src/neo4j/_codec/hydration/v3/hydration_handler.py new file mode 100644 index 00000000..169acecb --- /dev/null +++ b/src/neo4j/_codec/hydration/v3/hydration_handler.py @@ -0,0 +1,102 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import ( + date, + datetime, + time, + timedelta, +) + +from ...._optional_deps import ( + np, + pd, +) +from ....spatial import ( + CartesianPoint, + Point, + WGS84Point, +) +from ....time import ( + Date, + DateTime, + Duration, + Time, +) +from ....vector import Vector +from .._common import HydrationScope +from .._interface import HydrationHandlerABC +from ..v1 import ( + spatial, + temporal as temporal_v1, +) +from ..v1.hydration_handler import _GraphHydrator +from ..v2 import temporal as temporal_v2 +from . import vector + + +class HydrationHandler(HydrationHandlerABC): # type: ignore[no-redef] + def __init__(self): + super().__init__() + self._created_scope = False + self.struct_hydration_functions = { + **self.struct_hydration_functions, + b"X": spatial.hydrate_point, + b"Y": spatial.hydrate_point, + b"D": temporal_v1.hydrate_date, + b"T": temporal_v1.hydrate_time, # time zone offset + b"t": temporal_v1.hydrate_time, # no time zone + b"I": temporal_v2.hydrate_datetime, # time zone offset + b"i": temporal_v2.hydrate_datetime, # time zone name + b"d": temporal_v2.hydrate_datetime, # no time zone + b"E": temporal_v1.hydrate_duration, + b"V": vector.hydrate_vector, + } + self.dehydration_hooks.update( + exact_types={ + Point: spatial.dehydrate_point, + CartesianPoint: spatial.dehydrate_point, + WGS84Point: spatial.dehydrate_point, + Date: temporal_v1.dehydrate_date, + date: temporal_v1.dehydrate_date, + Time: temporal_v1.dehydrate_time, + time: temporal_v1.dehydrate_time, + DateTime: temporal_v2.dehydrate_datetime, + datetime: temporal_v2.dehydrate_datetime, + Duration: temporal_v1.dehydrate_duration, + timedelta: temporal_v1.dehydrate_timedelta, + Vector: vector.dehydrate_vector, + } + ) + if np is not None: + self.dehydration_hooks.update( + exact_types={ + np.datetime64: temporal_v1.dehydrate_np_datetime, + np.timedelta64: 
temporal_v1.dehydrate_np_timedelta, + } + ) + if pd is not None: + self.dehydration_hooks.update( + exact_types={ + pd.Timestamp: temporal_v2.dehydrate_pandas_datetime, + pd.Timedelta: temporal_v1.dehydrate_pandas_timedelta, + type(pd.NaT): lambda _: None, + } + ) + + def new_hydration_scope(self): + self._created_scope = True + return HydrationScope(self, _GraphHydrator()) diff --git a/src/neo4j/_codec/hydration/v3/vector.py b/src/neo4j/_codec/hydration/v3/vector.py new file mode 100644 index 00000000..338ae3fe --- /dev/null +++ b/src/neo4j/_codec/hydration/v3/vector.py @@ -0,0 +1,56 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ....vector import ( + Vector, + VectorDType, +) +from ...packstream import Structure + + +_DTYPE_LOOKUP = { + b"\xc8": VectorDType.I8, + b"\xc9": VectorDType.I16, + b"\xca": VectorDType.I32, + b"\xcb": VectorDType.I64, + b"\xc6": VectorDType.F32, + b"\xc1": VectorDType.F64, +} + +_TYP_LOOKUP = {v: k for k, v in _DTYPE_LOOKUP.items()} + + +def hydrate_vector(typ: bytes, data: bytes) -> Vector: + """ + Hydrator for `Vector` values. + + :param typ: bytes marking the inner type of the vector + :param data: big-endian bytes representing the vector data + :returns: Vector + """ + dtype = _DTYPE_LOOKUP[typ] + return Vector(data, dtype) + + +def dehydrate_vector(value: Vector) -> Structure: + """ + Dehydrator for `Vector` values. 
+ + :param value: + :type value: Vector + :returns: + """ + return Structure(b"V", _TYP_LOOKUP[value.dtype], value.raw()) diff --git a/src/neo4j/_codec/packstream/_common.py b/src/neo4j/_codec/packstream/_common.py index 38241ebd..4bab66ec 100644 --- a/src/neo4j/_codec/packstream/_common.py +++ b/src/neo4j/_codec/packstream/_common.py @@ -15,7 +15,7 @@ try: - from ._rust import Structure + from ..._rust.codec.packstream import Structure RUST_AVAILABLE = True except ImportError: diff --git a/src/neo4j/_codec/packstream/v1/__init__.py b/src/neo4j/_codec/packstream/v1/__init__.py index 0e71a3fa..ad85e2a8 100644 --- a/src/neo4j/_codec/packstream/v1/__init__.py +++ b/src/neo4j/_codec/packstream/v1/__init__.py @@ -36,7 +36,7 @@ try: - from .._rust.v1 import ( + from ...._rust.codec.packstream.v1 import ( pack as _rust_pack, unpack as _rust_unpack, ) diff --git a/src/neo4j/_optional_deps/__init__.py b/src/neo4j/_optional_deps/__init__.py index f16f0e59..2002bf65 100644 --- a/src/neo4j/_optional_deps/__init__.py +++ b/src/neo4j/_optional_deps/__init__.py @@ -31,8 +31,10 @@ with suppress(ImportError): import pandas as pd # type: ignore[no-redef] +pa: t.Any = None -__all__ = [ - "np", - "pd", -] +with suppress(ImportError): + import pyarrow as pa # type: ignore[no-redef] + + +__all__ = ["np", "pa", "pd"] diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index 775fd504..62e6dfee 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -36,6 +36,7 @@ _bolt3, _bolt4, _bolt5, + _bolt6, ) from ._bolt import Bolt from ._common import ConnectionErrorHandler diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 7b5b5a86..9a146baa 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -286,7 +286,6 @@ def __init_subclass__(cls: type[t.Self], **kwargs: t.Any) -> None: cls.protocol_handlers[protocol_version] = cls super().__init_subclass__(**kwargs) - # [bolt-version-bump] search tag 
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from logging import getLogger
from ssl import SSLSocket

from ... import _typing as t
from ..._api import TelemetryAPI
from ..._async_compat.util import Util
from ..._codec.hydration import v3 as hydration_v3
from ..._exceptions import BoltProtocolError
from ..._io import BoltProtocolVersion
from ..._meta import BOLT_AGENT_DICT
from ...api import READ_ACCESS
from ...exceptions import (
    DatabaseUnavailable,
    ForbiddenOnReadOnlyDatabase,
    Neo4jError,
    NotALeader,
    ServiceUnavailable,
)
from ._bolt import (
    Bolt,
    ClientStateManagerBase,
    ServerStateManagerBase,
    tx_timeout_as_ms,
)
from ._bolt5 import (
    BoltStates5x1,
    ClientStateManager5x1,
    ServerStateManager5x1,
)
from ._common import (
    CommitResponse,
    InitResponse,
    LogonResponse,
    ResetResponse,
    Response,
)


log = getLogger("neo4j.io")


class Bolt6x0(Bolt):
    """Protocol handler for Bolt 6.0."""

    PROTOCOL_VERSION = BoltProtocolVersion(6, 0)

    # Bolt 6.0 re-uses the PackStream hydration scheme introduced for 5.x.
    HYDRATION_HANDLER_CLS = hydration_v3.HydrationHandler

    supports_multiple_results = True

    supports_multiple_databases = True

    supports_re_auth = True

    supports_notification_filtering = True

    # State machine is unchanged from Bolt 5.1 (LOGON/LOGOFF states).
    bolt_states: t.Any = BoltStates5x1

    # Defaults merged into error/status diagnostic records when the server
    # omits these keys (see _enrich_error_diagnostic_record and
    # _make_enrich_statuses_handler).
    DEFAULT_ERROR_DIAGNOSTIC_RECORD = DEFAULT_STATUS_DIAGNOSTIC_RECORD = (
        ("OPERATION", ""),
        ("OPERATION_CODE", "0"),
        ("CURRENT_SCHEMA", "/"),
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._server_state_manager = ServerStateManager5x1(
            BoltStates5x1.CONNECTED, on_change=self._on_server_state_change
        )
        self._client_state_manager = ClientStateManager5x1(
            BoltStates5x1.CONNECTED, on_change=self._on_client_state_change
        )

    def _on_server_state_change(self, old_state, new_state):
        # Debug hook: trace server-side protocol state transitions.
        log.debug(
            "[#%04X] _: <CONNECTION> server state: %s > %s",
            self.local_port,
            old_state.name,
            new_state.name,
        )

    def _get_server_state_manager(self) -> ServerStateManagerBase:
        return self._server_state_manager

    def _on_client_state_change(self, old_state, new_state):
        # Debug hook: trace client-side protocol state transitions.
        log.debug(
            "[#%04X] _: <CONNECTION> client state: %s > %s",
            self.local_port,
            old_state.name,
            new_state.name,
        )

    def _get_client_state_manager(self) -> ClientStateManagerBase:
        return self._client_state_manager

    @property
    def ssr_enabled(self) -> bool:
        # Server-side routing: only enabled if the server explicitly said so
        # via the "ssr.enabled" connection hint.
        return self.connection_hints.get("ssr.enabled", False) is True

    @property
    def is_reset(self):
        # We can't be sure of the server's state if there are still pending
        # responses. Unless the last message we sent was RESET. In that case
        # the server state will always be READY when we're done.
        if self.responses:
            return self.responses[-1] and self.responses[-1].message == "reset"
        return self._server_state_manager.state == self.bolt_states.READY

    @property
    def encrypted(self):
        # True iff the underlying socket is TLS-wrapped.
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        return self.socket.getpeercert(binary_form=True)

    def get_base_headers(self):
        """Build the extra-dict sent with the HELLO message."""
        headers = {"user_agent": self.user_agent}
        if self.routing_context is not None:
            headers["routing"] = self.routing_context
        if self.notifications_min_severity is not None:
            headers["notifications_minimum_severity"] = (
                self.notifications_min_severity
            )
        if self.notifications_disabled_classifications is not None:
            headers["notifications_disabled_classifications"] = (
                self.notifications_disabled_classifications
            )
        headers["bolt_agent"] = BOLT_AGENT_DICT
        return headers

    def hello(self, dehydration_hooks=None, hydration_hooks=None):
        """Send HELLO (and LOGON) and synchronously await the responses."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )

        def on_success(metadata):
            # Server hints and server info piggyback on the HELLO response.
            self.connection_hints.update(metadata.pop("hints", {}))
            self.server_info.update(metadata)
            if "connection.recv_timeout_seconds" in self.connection_hints:
                recv_timeout = self.connection_hints[
                    "connection.recv_timeout_seconds"
                ]
                if isinstance(recv_timeout, int) and recv_timeout > 0:
                    self.socket.settimeout(recv_timeout)
                else:
                    log.info(
                        "[#%04X] _: <CONNECTION> Server supplied an "
                        "invalid value for "
                        "connection.recv_timeout_seconds (%r). Make sure "
                        "the server and network is set up correctly.",
                        self.local_port,
                        recv_timeout,
                    )

        extra = self.get_base_headers()
        log.debug("[#%04X] C: HELLO %r", self.local_port, extra)
        self._append(
            b"\x01",
            (extra,),
            response=InitResponse(
                self, "hello", hydration_hooks, on_success=on_success
            ),
            dehydration_hooks=dehydration_hooks,
        )

        # Authentication is a separate message since Bolt 5.1.
        self.logon(dehydration_hooks, hydration_hooks)
        self.send_all()
        self.fetch_all()

    def logon(self, dehydration_hooks=None, hydration_hooks=None):
        """Append a LOGON message carrying the auth token."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        # Never log raw credentials.
        logged_auth_dict = dict(self.auth_dict)
        if "credentials" in logged_auth_dict:
            logged_auth_dict["credentials"] = "*******"
        log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict)
        self._append(
            b"\x6a",
            (self.auth_dict,),
            response=LogonResponse(self, "logon", hydration_hooks),
            dehydration_hooks=dehydration_hooks,
        )

    def logoff(self, dehydration_hooks=None, hydration_hooks=None):
        """Append a LOGOFF message (drop authentication, keep connection)."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        log.debug("[#%04X] C: LOGOFF", self.local_port)
        self._append(
            b"\x6b",
            response=LogonResponse(self, "logoff", hydration_hooks),
            dehydration_hooks=dehydration_hooks,
        )

    def telemetry(
        self,
        api: TelemetryAPI,
        dehydration_hooks=None,
        hydration_hooks=None,
        **handlers,
    ) -> None:
        """Append a TELEMETRY message unless disabled client- or server-side."""
        if self.telemetry_disabled or not self.connection_hints.get(
            "telemetry.enabled", False
        ):
            return
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        api_raw = int(api)
        log.debug(
            "[#%04X] C: TELEMETRY %i # (%r)", self.local_port, api_raw, api
        )
        self._append(
            b"\x54",
            (api_raw,),
            Response(self, "telemetry", hydration_hooks, **handlers),
            dehydration_hooks=dehydration_hooks,
        )

    def route(
        self,
        database=None,
        imp_user=None,
        bookmarks=None,
        dehydration_hooks=None,
        hydration_hooks=None,
    ):
        """Send a ROUTE message and return the fetched routing table."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        routing_context = self.routing_context or {}
        db_context = {}
        if database is not None:
            db_context.update(db=database)
        if imp_user is not None:
            db_context.update(imp_user=imp_user)
        log.debug(
            "[#%04X] C: ROUTE %r %r %r",
            self.local_port,
            routing_context,
            bookmarks,
            db_context,
        )
        metadata = {}
        bookmarks = [] if bookmarks is None else list(bookmarks)
        self._append(
            b"\x66",
            (routing_context, bookmarks, db_context),
            response=Response(
                self, "route", hydration_hooks, on_success=metadata.update
            ),
            dehydration_hooks=dehydration_hooks,
        )
        # ROUTE is synchronous: flush and wait for the routing table ("rt").
        self.send_all()
        self.fetch_all()
        return [metadata.get("rt")]

    def run(
        self,
        query,
        parameters=None,
        mode=None,
        bookmarks=None,
        metadata=None,
        timeout=None,
        db=None,
        imp_user=None,
        notifications_min_severity=None,
        notifications_disabled_classifications=None,
        dehydration_hooks=None,
        hydration_hooks=None,
        **handlers,
    ):
        """Append a RUN message (auto-commit or in-transaction query)."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        if not parameters:
            parameters = {}
        extra = {}
        if mode in {READ_ACCESS, "r"}:
            # It will default to mode "w" if nothing is specified
            extra["mode"] = "r"
        if db:
            extra["db"] = db
        if (
            self._client_state_manager.state
            != self.bolt_states.TX_READY_OR_TX_STREAMING
        ):
            # Only auto-commit RUNs pin the database; explicit transactions
            # already did so in begin().
            self.last_database = db
        if imp_user:
            extra["imp_user"] = imp_user
        if notifications_min_severity is not None:
            extra["notifications_minimum_severity"] = (
                notifications_min_severity
            )
        if notifications_disabled_classifications is not None:
            extra["notifications_disabled_classifications"] = (
                notifications_disabled_classifications
            )
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError(
                    "Bookmarks must be provided as iterable"
                ) from None
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError(
                    "Metadata must be coercible to a dict"
                ) from None
        if timeout is not None:
            extra["tx_timeout"] = tx_timeout_as_ms(timeout)
        fields = (query, parameters, extra)
        log.debug(
            "[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))
        )
        self._append(
            b"\x10",
            fields,
            Response(self, "run", hydration_hooks, **handlers),
            dehydration_hooks=dehydration_hooks,
        )

    def discard(
        self,
        n=-1,
        qid=-1,
        dehydration_hooks=None,
        hydration_hooks=None,
        **handlers,
    ):
        """Append a DISCARD message (drop up to n records of result qid)."""
        # Wrap the success handler so GQL statuses get default diagnostics.
        handlers["on_success"] = self._make_enrich_statuses_handler(
            wrapped_handler=handlers.get("on_success")
        )
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        log.debug("[#%04X] C: DISCARD %r", self.local_port, extra)
        self._append(
            b"\x2f",
            (extra,),
            Response(self, "discard", hydration_hooks, **handlers),
            dehydration_hooks=dehydration_hooks,
        )

    def pull(
        self,
        n=-1,
        qid=-1,
        dehydration_hooks=None,
        hydration_hooks=None,
        **handlers,
    ):
        """Append a PULL message (fetch up to n records of result qid)."""
        handlers["on_success"] = self._make_enrich_statuses_handler(
            wrapped_handler=handlers.get("on_success")
        )
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        log.debug("[#%04X] C: PULL %r", self.local_port, extra)
        self._append(
            b"\x3f",
            (extra,),
            Response(self, "pull", hydration_hooks, **handlers),
            dehydration_hooks=dehydration_hooks,
        )

    def begin(
        self,
        mode=None,
        bookmarks=None,
        metadata=None,
        timeout=None,
        db=None,
        imp_user=None,
        notifications_min_severity=None,
        notifications_disabled_classifications=None,
        dehydration_hooks=None,
        hydration_hooks=None,
        **handlers,
    ):
        """Append a BEGIN message opening an explicit transaction."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        extra = {}
        if mode in {READ_ACCESS, "r"}:
            # It will default to mode "w" if nothing is specified
            extra["mode"] = "r"
        if db:
            extra["db"] = db
            self.last_database = db
        if imp_user:
            extra["imp_user"] = imp_user
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError(
                    "Bookmarks must be provided as iterable"
                ) from None
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError(
                    "Metadata must be coercible to a dict"
                ) from None
        if timeout is not None:
            extra["tx_timeout"] = tx_timeout_as_ms(timeout)
        if notifications_min_severity is not None:
            extra["notifications_minimum_severity"] = (
                notifications_min_severity
            )
        if notifications_disabled_classifications is not None:
            extra["notifications_disabled_classifications"] = (
                notifications_disabled_classifications
            )
        log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
        self._append(
            b"\x11",
            (extra,),
            Response(self, "begin", hydration_hooks, **handlers),
            dehydration_hooks=dehydration_hooks,
        )

    def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers):
        """Append a COMMIT message."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        log.debug("[#%04X] C: COMMIT", self.local_port)
        self._append(
            b"\x12",
            (),
            CommitResponse(self, "commit", hydration_hooks, **handlers),
            dehydration_hooks=dehydration_hooks,
        )

    def rollback(
        self, dehydration_hooks=None, hydration_hooks=None, **handlers
    ):
        """Append a ROLLBACK message."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        log.debug("[#%04X] C: ROLLBACK", self.local_port)
        self._append(
            b"\x13",
            (),
            Response(self, "rollback", hydration_hooks, **handlers),
            dehydration_hooks=dehydration_hooks,
        )

    def reset(self, dehydration_hooks=None, hydration_hooks=None):
        """
        Reset the connection.

        Add a RESET message to the outgoing queue, send it and consume all
        remaining messages.
        """
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        log.debug("[#%04X] C: RESET", self.local_port)
        response = ResetResponse(self, "reset", hydration_hooks)
        self._append(
            b"\x0f", response=response, dehydration_hooks=dehydration_hooks
        )
        self.send_all()
        self.fetch_all()

    def goodbye(self, dehydration_hooks=None, hydration_hooks=None):
        """Append a GOODBYE message (fire-and-forget; no response expected)."""
        dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
            dehydration_hooks, hydration_hooks
        )
        log.debug("[#%04X] C: GOODBYE", self.local_port)
        self._append(b"\x02", (), dehydration_hooks=dehydration_hooks)

    def _process_message(self, tag, fields):
        """Process at most one message from the server, if available.

        :returns: 2-tuple of number of detail messages and number of summary
            messages fetched
        """
        details = []
        summary_signature = summary_metadata = None
        if tag == b"\x71":  # RECORD
            details = fields
        elif fields:
            summary_signature = tag
            summary_metadata = fields[0]
        else:
            summary_signature = tag

        if details:
            # Do not log any data
            log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details))
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":  # SUCCESS
            log.debug(
                "[#%04X] S: SUCCESS %r", self.local_port, summary_metadata
            )
            self._server_state_manager.transition(
                response.message, summary_metadata
            )
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7e":  # IGNORED
            log.debug("[#%04X] S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7f":  # FAILURE
            log.debug(
                "[#%04X] S: FAILURE %r", self.local_port, summary_metadata
            )
            self._server_state_manager.state = self.bolt_states.FAILED
            self._enrich_error_diagnostic_record(summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ServiceUnavailable, DatabaseUnavailable):
                # The server (or this database) is gone: make sure the pool
                # stops handing out this address.
                if self.pool:
                    self.pool.deactivate(address=self.unresolved_address)
                raise
            except (NotALeader, ForbiddenOnReadOnlyDatabase):
                # Writes must be re-routed to the (new) leader.
                if self.pool:
                    self.pool.on_write_failure(
                        address=self.unresolved_address,
                        database=self.last_database,
                    )
                raise
            except Neo4jError as e:
                if self.pool:
                    self.pool.on_neo4j_error(e, self)
                raise
        else:
            sig_int = ord(summary_signature)
            raise BoltProtocolError(
                f"Unexpected response message with signature {sig_int:02X}",
                self.unresolved_address,
            )

        return len(details), 1

    def _enrich_error_diagnostic_record(self, metadata):
        # Recursively fill default keys into the diagnostic record of a
        # failure's metadata and all of its chained causes.
        if not isinstance(metadata, dict):
            return
        diag_record = metadata.setdefault("diagnostic_record", {})
        if not isinstance(diag_record, dict):
            log.info(
                "[#%04X] _: <CONNECTION> Server supplied an "
                "invalid error diagnostic record (%r).",
                self.local_port,
                diag_record,
            )
        else:
            for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD:
                diag_record.setdefault(key, value)
        self._enrich_error_diagnostic_record(metadata.get("cause"))

    def _make_enrich_statuses_handler(self, wrapped_handler=None):
        # Build an on_success handler that fills default diagnostic-record
        # keys into every GQL status before delegating to wrapped_handler.
        def handler(metadata):
            def enrich(metadata_):
                if not isinstance(metadata_, dict):
                    return
                statuses = metadata_.get("statuses")
                if not isinstance(statuses, list):
                    return
                for status in statuses:
                    if not isinstance(status, dict):
                        continue
                    diag_record = status.setdefault("diagnostic_record", {})
                    if not isinstance(diag_record, dict):
                        log.info(
                            "[#%04X] _: <CONNECTION> Server supplied an "
                            "invalid status diagnostic record (%r).",
                            self.local_port,
                            diag_record,
                        )
                        continue
                    for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD:
                        diag_record.setdefault(key, value)

            enrich(metadata)
            Util.callback(wrapped_handler, metadata)

        return handler
b/src/neo4j/_sync/work/result.py index fbb3938a..1b8522ca 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -66,6 +66,11 @@ from ...graph import Graph +if False: + # Ugly work-around to make sphinx understand `@_t.overload` + import typing as t # type: ignore[no-redef] + + notification_log = getLogger("neo4j.notifications") diff --git a/src/neo4j/_typing.py b/src/neo4j/_typing.py index d09cb60c..284c4ccf 100644 --- a/src/neo4j/_typing.py +++ b/src/neo4j/_typing.py @@ -14,7 +14,7 @@ # limitations under the License. -from __future__ import annotations +from __future__ import annotations as _ from collections.abc import ( AsyncIterator, @@ -32,6 +32,7 @@ Set, ValuesView, ) +from importlib.util import find_spec as _find_spec from typing import ( Any, cast, @@ -89,12 +90,13 @@ "overload", ) -if TYPE_CHECKING: - from typing_extensions import NotRequired # Python 3.11+ # noqa: TC004 - from typing_extensions import Self # Python 3.11 # noqa: TC004 - from typing_extensions import ( # Python 3.11 # noqa: TC004 Python - LiteralString, - ) + +_te_available = _find_spec("typing_extensions") is not None + +if TYPE_CHECKING or _te_available: + from typing_extensions import LiteralString # Python 3.11 + from typing_extensions import NotRequired # Python 3.11+ + from typing_extensions import Self # Python 3.11 __all__ = ( # noqa: PLE0604 false positive *__all__, diff --git a/src/neo4j/time/__init__.py b/src/neo4j/time/__init__.py index 2c7afa93..8c39b719 100644 --- a/src/neo4j/time/__init__.py +++ b/src/neo4j/time/__init__.py @@ -59,6 +59,11 @@ from .._warnings import deprecated as _deprecated +if False: + # Ugly work-around to make sphinx understand `@_t.overload` + import typing as _t # type: ignore[no-redef] + + __all__ = [ "MAX_INT64", "MAX_YEAR", diff --git a/src/neo4j/vector.py b/src/neo4j/vector.py new file mode 100644 index 00000000..ceb08ae6 --- /dev/null +++ b/src/neo4j/vector.py @@ -0,0 +1,1256 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden 
AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Vectors.

https://trello.com/c/2xcLszsC/1164-python-vector-types-design-investigation
"""

from __future__ import annotations as _

import abc as _abc
import struct as _struct
import sys as _sys
from enum import Enum as _Enum

from . import _typing as _t
from ._optional_deps import (
    np as _np,
    pa as _pa,
)


if False:
    # Ugly work-around to make sphinx understand `@_t.overload`
    import typing as _t  # type: ignore[no-redef]


# Optional Rust acceleration: falls back to numpy or pure Python below.
try:
    from ._rust import vector as _vec_rust
    from ._rust.vector import swap_endian as _swap_endian_unchecked_rust
except ImportError:
    _swap_endian_unchecked_rust = None
    _vec_rust = None

if _t.TYPE_CHECKING:
    import numpy  # type: ignore[import]
    import pyarrow  # type: ignore[import]


__all__ = [
    "Vector",
    "VectorDType",
    "VectorEndian",
]


class Vector:
    r"""
    A class representing a Neo4j vector.

    The constructor accepts various types of data to create a vector.
    Depending on ``data``'s type, further arguments may be required/allowed.
    Examples of valid invocations are::

        Vector([1, 2, 3], "i8")
        Vector(b"\x00\x01\x00\x02", VectorDType.I16)
        Vector(b"\x01\x00\x02\x00", "i16", byteorder="little")
        Vector(numpy.array([1, 2, 3]))
        Vector(pyarrow.array([1, 2, 3]))

    Internally, a vector is stored as a contiguous block of memory
    (:class:`bytes`), containing homogeneous values encoded in big-endian
    order. Support for this feature requires a DBMS supporting Bolt version
    6.0 or later. This corresponds to Neo4j 2025.05 or later.

    TODO: check and update final server version above!

    :param data:
        The data from which the vector will be constructed.
        The constructor accepts the following types:

        * ``Iterable[float]``, ``Iterable[int]`` (but not ``bytes`` or
          ``bytearray``):
          Use an iterable of floats or an iterable of ints to construct the
          vector from native Python values.
          The ``dtype`` parameter is required.
        * ``bytes``, ``bytearray``: Use raw bytes to construct the vector.
          The ``dtype`` parameter is required and ``byteorder`` is optional.
        * ``numpy.ndarray``: Use a numpy array to construct the vector.
          No further parameters are accepted.
        * ``pyarrow.Array``: Use a pyarrow array to construct the vector.
          No further parameters are accepted.
    :param dtype: The type of the vector.
        See :attr:`.dtype` for currently supported inner data types.

        This parameter is required if ``data`` is of type :class:`bytes`,
        :class:`bytearray`, ``Iterable[float]``, or ``Iterable[int]``.
        Otherwise, it must be omitted.
    :param byteorder: The endianness of the input data (default: ``"big"``).
        If ``"little"`` is given, ``neo4j-rust-ext`` or ``numpy`` is used to
        speed up the internal byte flipping (if either package is installed).
        Use :data:`sys.byteorder` if you want to use the system's native
        endianness.

        This parameter is optional if ``data`` is of type ``bytes`` or
        ``bytearray``. Otherwise, it must be omitted.

    :raises ValueError:
        Depending on the type of ``data``:

        * ``Iterable[float]``, ``Iterable[int]`` (excluding byte types):

          * If the dtype is not supported.

        * ``bytes``, ``bytearray``:

          * If the dtype is not supported or data's size is not a
            multiple of dtype's size.
          * If byteorder is not one of ``"big"`` or ``"little"``.

        * ``numpy.ndarray``:

          * If the dtype is not supported.
          * If the array is not one-dimensional.

        * ``pyarrow.Array``:

          * If the array's type is not supported.
          * If the array contains null values.
    :raises TypeError:
        Depending on the type of ``data``:

        * ``Iterable[float]``, ``Iterable[int]`` (excluding byte types):

          * If data's elements don't match the expected type depending on
            dtype.
    :raises OverflowError:
        Depending on the type of ``data``:

        * ``Iterable[float]``, ``Iterable[int]`` (excluding byte types):

          * If the value is out of range for the given type.

    .. versionadded:: 6.0
    """

    __slots__ = ("__weakref__", "_inner")

    # Delegate holding the actual big-endian byte buffer; one concrete
    # _InnerVector subclass exists per supported dtype.
    _inner: _InnerVector

    @_t.overload
    def __init__(
        self,
        data: _t.Iterable[float],
        dtype: _T_VectorDTypeFloat,
        /,
    ) -> None: ...

    @_t.overload
    def __init__(
        self,
        data: _t.Iterable[int],
        dtype: _T_VectorDTypeInt,
        /,
    ) -> None: ...

    @_t.overload
    def __init__(
        self,
        data: bytes | bytearray,
        dtype: _T_VectorDType,
        /,
        *,
        byteorder: _T_VectorEndian = "big",
    ) -> None: ...

    @_t.overload
    def __init__(self, data: numpy.ndarray, /) -> None: ...

    @_t.overload
    def __init__(self, data: pyarrow.Array, /) -> None: ...

    def __init__(self, data, *args, **kwargs) -> None:
        # Dispatch on data's type; numpy/pyarrow branches only exist if the
        # respective optional dependency is installed.
        if isinstance(data, (bytes, bytearray)):
            self._set_bytes(bytes(data), *args, **kwargs)
        elif _np is not None and isinstance(data, _np.ndarray):
            self._set_numpy(data, *args, **kwargs)
        elif _pa is not None and isinstance(data, _pa.Array):
            self._set_pyarrow(data, *args, **kwargs)
        else:
            self._set_native(data, *args, **kwargs)

    def raw(self, /, *, byteorder: _T_VectorEndian = "big") -> bytes:
        """
        Get the raw bytes of the vector.

        The data is a continuous block of memory, containing an array of the
        vector's data type. The data is stored in big-endian order. Pass
        another byte-order to this method to get the converted data.

        :param byteorder: The endianness the data should be returned in.
            If the data's byte-order needs flipping, this method tries to use
            ``neo4j-rust-ext`` or ``numpy``, if installed, to speed up the
            process. Use :data:`sys.byteorder` if you want to use the system's
            native endianness.

        :returns: The raw bytes of the vector.

        :raises ValueError:
            If byteorder is not one of ``"big"`` or ``"little"``.
        """
        match byteorder:
            case "big":
                return self._inner.data
            case "little":
                return self._inner.data_le
            case _:
                raise ValueError(
                    f"Invalid byteorder: {byteorder!r}. "
                    "Must be 'big' or 'little'."
                )

    def set_raw(
        self,
        data: bytes,
        /,
        *,
        byteorder: _T_VectorEndian = "big",
    ) -> None:
        """
        Set the raw bytes of the vector.

        :param data: The new raw bytes of the vector.
        :param byteorder: The endianness of ``data``.
            The data will always be stored in big-endian order. If passed-in
            byte-order needs flipping, this method tries to use
            ``neo4j-rust-ext`` or ``numpy``, if installed, to speed up the
            process. Use :data:`sys.byteorder` if you want to use the system's
            native endianness.

        :raises ValueError:
            * If data's size is not a multiple of dtype's size.
            * If byteorder is not one of ``"big"`` or ``"little"``.
        :raises TypeError: If the data is not of type bytes.
        """
        match byteorder:
            case "big":
                self._inner.data = data
            case "little":
                self._inner.data_le = data
            case _:
                raise ValueError(
                    f"Invalid byteorder: {byteorder!r}. "
                    "Must be 'big' or 'little'."
                )

    @property
    def dtype(self) -> VectorDType:
        """
        Get the type of the vector.

        :returns: The type of the vector.
        """
        return self._inner.dtype

    def __len__(self) -> int:
        """
        Get the number of elements in the vector.

        :returns: The number of elements in the vector.
        """
        return len(self._inner)

    def __str__(self) -> str:
        return str(self._inner)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.raw()!r}, {self.dtype!r})"

    @classmethod
    def from_bytes(
        cls,
        data: bytes,
        dtype: _T_VectorDType,
        /,
        *,
        byteorder: _T_VectorEndian = "big",
    ) -> _t.Self:
        """
        Create a Vector instance from raw bytes.

        :param data: The raw bytes to create the vector from.
        :param dtype: The type of the vector.
            See also :attr:`.dtype`.
        :param byteorder: The endianness of the data.
            If ``"little"``, the bytes in data will be flipped to big-endian.
            If installed, ``neo4j-rust-ext`` or ``numpy`` will be used to speed
            up the byte flipping. Use :data:`sys.byteorder` if you want to use
            the system's native endianness.

        :raises ValueError:
            * If data's size is not a multiple of dtype's size.
            * If byteorder is not one of ``"big"`` or ``"little"``.
        :raises TypeError: If the data is not of type bytes.
        """
        obj = cls.__new__(cls)
        obj._set_bytes(data, dtype, byteorder=byteorder)
        return obj

    def _set_bytes(
        self,
        data: bytes,
        dtype: _T_VectorDType,
        /,
        *,
        byteorder: _T_VectorEndian = "big",
    ) -> None:
        # _get_type maps the dtype to the concrete _InnerVector subclass.
        self._inner = _get_type(dtype)(data, byteorder=byteorder)

    @classmethod
    @_t.overload
    def from_native(
        cls, data: _t.Iterable[float], dtype: _T_VectorDTypeFloat, /
    ) -> _t.Self: ...

    @classmethod
    @_t.overload
    def from_native(
        cls, data: _t.Iterable[int], dtype: _T_VectorDTypeInt, /
    ) -> _t.Self: ...

    @classmethod
    def from_native(
        cls,
        data: _t.Iterable[float] | _t.Iterable[int],
        dtype: _T_VectorDType,
        /,
    ) -> _t.Self:
        """
        Create a Vector instance from an iterable of values.

        :param data: The list, tuple, or other iterable of values to create the
            vector from.
        :param dtype: The type of the vector.
            See also :attr:`.dtype`.

            ``data`` must contain values that match the expected type given by
            ``dtype``:

            * ``dtype == "f32"``: :class:`float`
            * ``dtype == "f64"``: :class:`float`
            * ``dtype == "i8"``: :class:`int`
            * ``dtype == "i16"``: :class:`int`
            * ``dtype == "i32"``: :class:`int`
            * ``dtype == "i64"``: :class:`int`

        :raises ValueError: If the dtype is not supported.
        :raises TypeError: If data's elements don't match the expected type
            depending on dtype.
        :raises OverflowError: If the value is out of range for the given type.
        """
        obj = cls.__new__(cls)
        obj._set_native(data, dtype)
        return obj

    def _set_native(
        self,
        data: _t.Iterable[float] | _t.Iterable[int],
        dtype: _T_VectorDType,
        /,
    ) -> None:
        self._inner = _get_type(dtype).from_native(data)

    def to_native(self) -> list[object]:
        """
        Convert the vector to a native Python list.

        The type of the elements in the list depends on the dtype of the
        vector. See :meth:`Vector.from_native` for details.

        :returns: A list of values representing the vector.
        """
        return self._inner.to_native()

    @classmethod
    def from_numpy(cls, data: numpy.ndarray, /) -> _t.Self:
        """
        Create a Vector instance from a numpy array.

        :param data: The numpy array to create the vector from.
            The array must be one-dimensional and have a dtype that is
            supported by Neo4j vectors: ``float64``, ``float32``,
            ``int64``, ``int32``, ``int16``, or ``int8``.
            See also :attr:`.dtype`.

        :raises ValueError:
            * If the dtype is not supported.
            * If the array is not one-dimensional.
        :raises ImportError: If numpy is not installed.

        :returns: A Vector instance constructed from the numpy array.
        """
        obj = cls.__new__(cls)
        obj._set_numpy(data)
        return obj

    def to_numpy(self) -> numpy.ndarray:
        """
        Convert the vector to a numpy array.

        The array's dtype depends on the dtype of the vector. However, it will
        always be in big-endian order.

        :returns: A numpy array representing the vector.

        :raises ImportError: If numpy is not installed.
        """
        return self._inner.to_numpy()

    def _set_numpy(self, data: numpy.ndarray, /) -> None:
        if data.ndim != 1:
            raise ValueError("Data must be one-dimensional")
        type_: type[_InnerVector]
        match data.dtype.name:
            case "float64":
                type_ = _VecF64
            case "float32":
                type_ = _VecF32
            case "int64":
                type_ = _VecI64
            case "int32":
                type_ = _VecI32
            case "int16":
                type_ = _VecI16
            case "int8":
                type_ = _VecI8
            case _:
                raise ValueError(f"Unsupported numpy dtype: {data.dtype.name}")
        self._inner = type_.from_numpy(data)

    @classmethod
    def from_pyarrow(cls, data: pyarrow.Array, /) -> _t.Self:
        """
        Create a Vector instance from a pyarrow array.

        :param data: The pyarrow array to create the vector from.
            The array must have a type that is supported by Neo4j.
            See also :attr:`.dtype`.

        PyArrow stores data in little endian. Therefore, the byte-order needs
        to be swapped. If ``neo4j-rust-ext`` or ``numpy`` is installed, it will
        be used to speed up the byte flipping.

        :raises ValueError:
            * If the array's type is not supported.
            * If the array contains null values.
        :raises ImportError: If pyarrow is not installed.

        :returns: A Vector instance constructed from the pyarrow array.
        """
        obj = cls.__new__(cls)
        obj._set_pyarrow(data)
        return obj

    def to_pyarrow(self) -> pyarrow.Array:
        """
        Convert the vector to a pyarrow array.

        :returns: A pyarrow array representing the vector.

        :raises ImportError: If pyarrow is not installed.
        """
        return self._inner.to_pyarrow()

    def _set_pyarrow(self, data: pyarrow.Array, /) -> None:
        # Local import: raises ImportError here if pyarrow is missing.
        import pyarrow

        type_: type[_InnerVector]
        if data.type == pyarrow.float64():
            type_ = _VecF64
        elif data.type == pyarrow.float32():
            type_ = _VecF32
        elif data.type == pyarrow.int64():
            type_ = _VecI64
        elif data.type == pyarrow.int32():
            type_ = _VecI32
        elif data.type == pyarrow.int16():
            type_ = _VecI16
        elif data.type == pyarrow.int8():
            type_ = _VecI8
        else:
            raise ValueError(f"Unsupported pyarrow dtype: {data.type}")
        inner = type_.from_pyarrow(data)
        self._inner = inner

    # TODO: consider conversion to/from
    #  * tensorflow
    #  * pandas
    #  * polars


class VectorEndian(str, _Enum):
    """
    Data endianness (i.e., byte order) of the elements in a :class:`Vector`.

    Inherits from :class:`str` and :class:`enum.Enum`.
    Every driver API accepting a :class:`.VectorEndian` value will also accept
    a string::

        >>> VectorEndian.BIG == "big"
        True
        >>> VectorEndian.LITTLE == "little"
        True

    .. seealso:: :attr:`Vector.raw`

    .. versionadded:: 6.0
    """

    BIG = "big"
    LITTLE = "little"


_T_VectorEndian = VectorEndian | _t.Literal["big", "little"]


class VectorDType(str, _Enum):
    """
    The data type of the elements in a :class:`Vector`.

    Currently supported types are:

    * ``f32``: 32-bit floating point number (single)
    * ``f64``: 64-bit floating point number (double)
    * ``i8``: 8-bit integer
    * ``i16``: 16-bit integer
    * ``i32``: 32-bit integer
    * ``i64``: 64-bit integer

    Inherits from :class:`str` and :class:`enum.Enum`.
    Every driver API accepting a :class:`.VectorDType` value will also accept
    a string::

        >>> VectorDType.F32 == "f32"
        True
        >>> VectorDType.I8 == "i8"
        True

    .. seealso:: :attr:`Vector.dtype`

    ..
versionadded:: 6.0 + """ + + F32 = "f32" + F64 = "f64" + I8 = "i8" + I16 = "i16" + I32 = "i32" + I64 = "i64" + + +_T_VectorDType = ( + VectorDType | _t.Literal["f32", "f64", "i8", "i16", "i32", "i64"] +) +_T_VectorDTypeInt = _t.Literal[ + VectorDType.I8, + VectorDType.I16, + VectorDType.I32, + VectorDType.I64, + "i8", + "i16", + "i32", + "i64", +] +_T_VectorDTypeFloat = _t.Literal[ + VectorDType.F32, VectorDType.F64, "f32", "f64" +] + + +def _swap_endian(type_size: int, data: bytes, /) -> bytes: + """Swap from big endian to little endian.""" + if type_size == 1: + return data + if type_size not in {2, 4, 8}: + raise ValueError(f"Unsupported type size: {type_size}") + if len(data) % type_size != 0: + raise ValueError( + f"Data length {len(data)} is not a multiple of {type_size}" + ) + return _swap_endian_unchecked(type_size, data) + + +def _swap_endian_unchecked_np(type_size: int, data: bytes, /) -> bytes: + match type_size: + case 2: + dtype = _np.dtype(" bytes: + return bytes( + byte + for i in range(0, len(data), type_size) + for byte in data[i : i + type_size][::-1] + ) + + +if _swap_endian_unchecked_rust is not None: + _swap_endian_unchecked = _swap_endian_unchecked_rust +elif _np is not None: + _swap_endian_unchecked = _swap_endian_unchecked_np +else: + _swap_endian_unchecked = _swap_endian_unchecked_py + + +def _get_type(dtype: _T_VectorDType, /) -> type[_InnerVector]: + if isinstance(dtype, str): + if dtype not in VectorDType.__members__.values(): + raise ValueError(f"Unsupported vector type: {dtype!r}.") + dtype = VectorDType(dtype) + if not isinstance(dtype, VectorDType): + raise TypeError(f"Expected a VectorDType or str, got {type(dtype)}.") + if dtype not in _TYPES: + raise ValueError(f"Unsupported vector type: {dtype!r}.") + return _TYPES[dtype] + + +_TYPES: dict[VectorDType, type[_InnerVector]] = {} + + +class _InnerVector(_abc.ABC): + __slots__ = ("_data", "_data_le") + + dtype: _t.ClassVar[VectorDType] + size: _t.ClassVar[int] + _data: bytes + 
_data_le: bytes | None + + def __init__( + self, data: bytes, /, *, byteorder: _T_VectorEndian = "big" + ) -> None: + super().__init__() + if self.__class__ == _InnerVector: + raise TypeError("Cannot instantiate abstract class InnerVector") + match byteorder: + case "big": + self.data = data + self._data_le = None + case "little": + self.data = _swap_endian(self.size, data) + self._data_le = data + case _: + raise ValueError( + f"Invalid byteorder: {byteorder!r}. " + "Must be 'big' or 'little'." + ) + + @property + def data(self) -> bytes: + return self._data + + @data.setter + def data(self, data: bytes, /) -> None: + if not isinstance(data, bytes): + raise TypeError("Data must be of type bytes") + if not len(data) % self.size == 0: + raise ValueError( + f"Data length {len(data)} is not a multiple of {self.size}" + ) + self._data = data + + @property + def data_le(self) -> bytes: + if self._data_le is None: + self._data_le = _swap_endian(self.size, self.data) + return self._data_le + + @data_le.setter + def data_le(self, data: bytes, /) -> None: + self.data = _swap_endian(self.size, data) + self._data_le = data + + def __init_subclass__(cls) -> None: + super().__init_subclass__() + dtype = getattr(cls, "dtype", None) + if not isinstance(dtype, VectorDType): + raise TypeError( + f"Class {cls.__name__} must have a VectorDType attribute" + "'dtype'" + ) + if not isinstance(getattr(cls, "size", None), int): + raise TypeError( + f"Class {cls.__name__} must have a str attribute 'size'" + ) + if cls.size not in {1, 2, 4, 8}: + # Either change the sub-type's size if it was a typo or add support + # for the new size in the swap_endian function. 
+ raise ValueError( + f"Class {cls.__name__} has an unhandled size {cls.size}" + ) + if dtype in _TYPES: + raise ValueError( + f"Class {cls.__name__} has a duplicate type '{dtype}'" + ) + _TYPES[dtype] = cls + + def __len__(self) -> int: + return len(self.data) // self.size + + def __str__(self) -> str: + size = len(self) + return f"Vec[{self.dtype}; {size}]" + + def __repr__(self) -> str: + cls_name = self.__class__.__name__ + return f"{cls_name}({self.data!r})" + + @classmethod + @_abc.abstractmethod + def from_native(cls, data: _t.Iterable[object], /) -> _t.Self: ... + + @_abc.abstractmethod + def to_native(self) -> list[object]: ... + + @classmethod + def from_numpy(cls, data: numpy.ndarray, /) -> _t.Self: + if data.dtype.byteorder == "<" or ( + data.dtype.byteorder == "=" and _sys.byteorder == "little" + ): + data = data.byteswap() + return cls(data.tobytes()) + + @_abc.abstractmethod + def to_numpy(self) -> numpy.ndarray: ... + + @classmethod + def from_pyarrow(cls, data: pyarrow.Array, /) -> _t.Self: + width = data.type.byte_width + assert cls.size == width + if _pa.compute.count(data, mode="only_null").as_py(): + raise ValueError("PyArrow array must not contain any null values.") + _, buffer = data.buffers() + buffer = buffer[ + data.offset * width : (data.offset + len(data)) * width + ] + return cls(bytes(buffer), byteorder=_sys.byteorder) + + @_abc.abstractmethod + def to_pyarrow(self) -> pyarrow.Array: ... + + +class _VecF64(_InnerVector): + __slots__ = () + + dtype = VectorDType.F64 + size = 8 + + @classmethod + def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self: + return cls(_vec_rust.vec_f64_from_native(data)) + + @classmethod + def _from_native_np(cls, data: _t.Iterable[object], /) -> _t.Self: + data = tuple(data) + non_float = tuple(item for item in data if not isinstance(item, float)) + if non_float: + raise TypeError( + f"Cannot build f64 vector from {type(non_float[0]).__name__}, " + "expected float." 
+ ) + return cls(_np.fromiter(data, dtype=_np.dtype(">f8")).tobytes()) + + @classmethod + def _from_native_py(cls, data: _t.Iterable[object], /) -> _t.Self: + bytes_ = bytearray() + for item in data: + if not isinstance(item, float): + raise TypeError( + f"Cannot build f64 vector from {type(item).__name__}, " + "expected float." + ) + bytes_.extend(_struct.pack(">d", item)) + return cls(bytes(bytes_)) + + if _vec_rust is not None: + from_native = _from_native_rust + elif _np is not None: + from_native = _from_native_np + else: + from_native = _from_native_py + + def _to_native_rust(self) -> list[object]: + return _vec_rust.vec_f64_to_native(self.data) + + def _to_native_np(self) -> list[object]: + return _np.frombuffer(self.data, dtype=_np.dtype(">f8")).tolist() + + def _to_native_py(self) -> list[object]: + return [ + _struct.unpack(">d", self.data[i : i + self.size])[0] + for i in range(0, len(self.data), self.size) + ] + + if _vec_rust is not None: + to_native = _to_native_rust + elif _np is not None: + to_native = _to_native_np + else: + to_native = _to_native_py + + def to_numpy(self) -> numpy.ndarray: + import numpy + + return numpy.frombuffer(self.data, dtype=numpy.dtype(">f8")) + + def to_pyarrow(self) -> pyarrow.Array: + import pyarrow + + buffer = pyarrow.py_buffer(self.data_le) + return pyarrow.Array.from_buffers( + pyarrow.float64(), len(self), [None, buffer], 0 + ) + + +class _VecF32(_InnerVector): + __slots__ = () + + dtype = VectorDType.F32 + size = 4 + + @classmethod + def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self: + return cls(_vec_rust.vec_f32_from_native(data)) + + @classmethod + def _from_native_np(cls, data: _t.Iterable[object], /) -> _t.Self: + data = tuple(data) + non_float = tuple(item for item in data if not isinstance(item, float)) + if non_float: + raise TypeError( + f"Cannot build f32 vector from {type(non_float[0]).__name__}, " + "expected float." 
+ ) + return cls(_np.fromiter(data, dtype=_np.dtype(">f4")).tobytes()) + + @classmethod + def _from_native_py(cls, data: _t.Iterable[object], /) -> _t.Self: + bytes_ = bytearray() + for item in data: + if not isinstance(item, float): + raise TypeError( + f"Cannot build f32 vector from {type(item).__name__}, " + "expected float." + ) + bytes_.extend(_struct.pack(">f", item)) + return cls(bytes(bytes_)) + + if _vec_rust is not None: + from_native = _from_native_rust + elif _np is not None: + from_native = _from_native_np + else: + from_native = _from_native_py + + def _to_native_rust(self) -> list[object]: + return _vec_rust.vec_f32_to_native(self.data) + + def _to_native_np(self) -> list[object]: + return _np.frombuffer(self.data, dtype=_np.dtype(">f4")).tolist() + + def _to_native_py(self) -> list[object]: + return [ + _struct.unpack(">f", self.data[i : i + self.size])[0] + for i in range(0, len(self.data), self.size) + ] + + if _vec_rust is not None: + to_native = _to_native_rust + elif _np is not None: + to_native = _to_native_np + else: + to_native = _to_native_py + + def to_numpy(self) -> numpy.ndarray: + import numpy + + return numpy.frombuffer(self.data, dtype=numpy.dtype(">f4")) + + def to_pyarrow(self) -> pyarrow.Array: + import pyarrow + + buffer = pyarrow.py_buffer(self.data_le) + return pyarrow.Array.from_buffers( + pyarrow.float32(), len(self), [None, buffer], 0 + ) + + +_I64_MIN = -9_223_372_036_854_775_808 +_I64_MAX = 9_223_372_036_854_775_807 + + +class _VecI64(_InnerVector): + __slots__ = () + + dtype = VectorDType.I64 + size = 8 + + @classmethod + def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self: + return cls(_vec_rust.vec_i64_from_native(data)) + + @classmethod + def _from_native_np(cls, data: _t.Iterable[object], /) -> _t.Self: + data = tuple(data) + non_int = tuple(item for item in data if not isinstance(item, int)) + if non_int: + raise TypeError( + f"Cannot build i64 vector from {type(non_int[0]).__name__}, " + "expected 
int." + ) + data = _t.cast(tuple[int, ...], data) + overflow_int = tuple( + item for item in data if not _I64_MIN <= item <= _I64_MAX + ) + if overflow_int: + raise OverflowError( + f"Value {overflow_int[0]} is out of range for i64: " + f"[-{_I64_MIN}, {_I64_MAX}]" + ) + return cls(_np.fromiter(data, dtype=_np.dtype(">i8")).tobytes()) + + @classmethod + def _from_native_py(cls, data: _t.Iterable[object], /) -> _t.Self: + bytes_ = bytearray() + for item in data: + if not isinstance(item, int): + raise TypeError( + f"Cannot build i64 vector from {type(item).__name__}, " + "expected int." + ) + if not _I64_MIN <= item <= _I64_MAX: + raise OverflowError( + f"Value {item} is out of range for i64: " + f"[-{_I64_MIN}, {_I64_MAX}]" + ) + bytes_.extend(_struct.pack(">q", item)) + return cls(bytes(bytes_)) + + if _vec_rust is not None: + from_native = _from_native_rust + elif _np is not None: + from_native = _from_native_np + else: + from_native = _from_native_py + + def _to_native_rust(self) -> list[object]: + return _vec_rust.vec_i64_to_native(self.data) + + def _to_native_np(self) -> list[object]: + return _np.frombuffer(self.data, dtype=_np.dtype(">i8")).tolist() + + def _to_native_py(self) -> list[object]: + return [ + _struct.unpack(">q", self.data[i : i + self.size])[0] + for i in range(0, len(self.data), self.size) + ] + + if _vec_rust is not None: + to_native = _to_native_rust + elif _np is not None: + to_native = _to_native_np + else: + to_native = _to_native_py + + def to_numpy(self) -> numpy.ndarray: + import numpy + + return numpy.frombuffer(self.data, dtype=numpy.dtype(">i8")) + + def to_pyarrow(self) -> pyarrow.Array: + import pyarrow + + buffer = pyarrow.py_buffer(self.data_le) + return pyarrow.Array.from_buffers( + pyarrow.int64(), len(self), [None, buffer], 0 + ) + + +_I32_MIN = -2_147_483_648 +_I32_MAX = 2_147_483_647 + + +class _VecI32(_InnerVector): + __slots__ = () + + dtype = VectorDType.I32 + size = 4 + + @classmethod + def _from_native_rust(cls, 
data: _t.Iterable[object], /) -> _t.Self: + return cls(_vec_rust.vec_i32_from_native(data)) + + @classmethod + def _from_native_np(cls, data: _t.Iterable[object], /) -> _t.Self: + data = tuple(data) + non_int = tuple(item for item in data if not isinstance(item, int)) + if non_int: + raise TypeError( + f"Cannot build i32 vector from {type(non_int[0]).__name__}, " + "expected int." + ) + data = _t.cast(tuple[int, ...], data) + overflow_int = tuple( + item for item in data if not _I32_MIN <= item <= _I32_MAX + ) + if overflow_int: + raise OverflowError( + f"Value {overflow_int[0]} is out of range for i32: " + f"[-{_I32_MIN}, {_I32_MAX}]" + ) + return cls(_np.fromiter(data, dtype=_np.dtype(">i4")).tobytes()) + + @classmethod + def _from_native_py(cls, data: _t.Iterable[object], /) -> _t.Self: + bytes_ = bytearray() + for item in data: + if not isinstance(item, int): + raise TypeError( + f"Cannot build i32 vector from {type(item).__name__}, " + "expected int." + ) + if not _I32_MIN <= item <= _I32_MAX: + raise OverflowError( + f"Value {item} is out of range for i32: " + f"[-{_I32_MIN}, {_I32_MAX}]" + ) + bytes_.extend(_struct.pack(">i", item)) + return cls(bytes(bytes_)) + + if _vec_rust is not None: + from_native = _from_native_rust + elif _np is not None: + from_native = _from_native_np + else: + from_native = _from_native_py + + def _to_native_rust(self) -> list[object]: + return _vec_rust.vec_i32_to_native(self.data) + + def _to_native_np(self) -> list[object]: + return _np.frombuffer(self.data, dtype=_np.dtype(">i4")).tolist() + + def _to_native_py(self) -> list[object]: + return [ + _struct.unpack(">i", self.data[i : i + self.size])[0] + for i in range(0, len(self.data), self.size) + ] + + if _vec_rust is not None: + to_native = _to_native_rust + elif _np is not None: + to_native = _to_native_np + else: + to_native = _to_native_py + + def to_numpy(self) -> numpy.ndarray: + import numpy + + return numpy.frombuffer(self.data, dtype=numpy.dtype(">i4")) + + def 
to_pyarrow(self) -> pyarrow.Array: + import pyarrow + + buffer = pyarrow.py_buffer(self.data_le) + return pyarrow.Array.from_buffers( + pyarrow.int32(), len(self), [None, buffer], 0 + ) + + +_I16_MIN = -32_768 +_I16_MAX = 32_767 + + +class _VecI16(_InnerVector): + __slots__ = () + + dtype = VectorDType.I16 + size = 2 + + @classmethod + def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self: + return cls(_vec_rust.vec_i16_from_native(data)) + + @classmethod + def _from_native_np(cls, data: _t.Iterable[object], /) -> _t.Self: + data = tuple(data) + non_int = tuple(item for item in data if not isinstance(item, int)) + if non_int: + raise TypeError( + f"Cannot build i16 vector from {type(non_int[0]).__name__}, " + "expected int." + ) + data = _t.cast(tuple[int, ...], data) + overflow_int = tuple( + item for item in data if not _I16_MIN <= item <= _I16_MAX + ) + if overflow_int: + raise OverflowError( + f"Value {overflow_int[0]} is out of range for i16: " + f"[-{_I16_MIN}, {_I16_MAX}]" + ) + return cls(_np.fromiter(data, dtype=_np.dtype(">i2")).tobytes()) + + @classmethod + def _from_native_py(cls, data: _t.Iterable[object], /) -> _t.Self: + bytes_ = bytearray() + for item in data: + if not isinstance(item, int): + raise TypeError( + f"Cannot build i16 vector from {type(item).__name__}, " + "expected int." 
+ ) + if not _I16_MIN <= item <= _I16_MAX: + raise OverflowError( + f"Value {item} is out of range for i16: " + f"[-{_I16_MIN}, {_I16_MAX}]" + ) + bytes_.extend(_struct.pack(">h", item)) + return cls(bytes(bytes_)) + + if _vec_rust is not None: + from_native = _from_native_rust + elif _np is not None: + from_native = _from_native_np + else: + from_native = _from_native_py + + def _to_native_rust(self) -> list[object]: + return _vec_rust.vec_i16_to_native(self.data) + + def _to_native_np(self) -> list[object]: + return _np.frombuffer(self.data, dtype=_np.dtype(">i2")).tolist() + + def _to_native_py(self) -> list[object]: + return [ + _struct.unpack(">h", self.data[i : i + self.size])[0] + for i in range(0, len(self.data), self.size) + ] + + if _vec_rust is not None: + to_native = _to_native_rust + elif _np is not None: + to_native = _to_native_np + else: + to_native = _to_native_py + + def to_numpy(self) -> numpy.ndarray: + import numpy + + return numpy.frombuffer(self.data, dtype=numpy.dtype(">i2")) + + def to_pyarrow(self) -> pyarrow.Array: + import pyarrow + + buffer = pyarrow.py_buffer(self.data_le) + return pyarrow.Array.from_buffers( + pyarrow.int16(), len(self), [None, buffer], 0 + ) + + +_I8_MIN = -128 +_I8_MAX = 127 + + +class _VecI8(_InnerVector): + __slots__ = () + + dtype = VectorDType.I8 + size = 1 + + @classmethod + def _from_native_rust(cls, data: _t.Iterable[object], /) -> _t.Self: + return cls(_vec_rust.vec_i8_from_native(data)) + + @classmethod + def _from_native_np(cls, data: _t.Iterable[object], /) -> _t.Self: + data = tuple(data) + non_int = tuple(item for item in data if not isinstance(item, int)) + if non_int: + raise TypeError( + f"Cannot build i8 vector from {type(non_int[0]).__name__}, " + "expected int." 
+ ) + data = _t.cast(tuple[int, ...], data) + overflow_int = tuple( + item for item in data if not _I8_MIN <= item <= _I8_MAX + ) + if overflow_int: + raise OverflowError( + f"Value {overflow_int[0]} is out of range for i8: " + f"[-{_I8_MIN}, {_I8_MAX}]" + ) + return cls(_np.fromiter(data, dtype=_np.dtype(">i1")).tobytes()) + + @classmethod + def _from_native_py(cls, data: _t.Iterable[object], /) -> _t.Self: + bytes_ = bytearray() + for item in data: + if not isinstance(item, int): + raise TypeError( + f"Cannot build i8 vector from {type(item).__name__}, " + "expected int." + ) + if not _I8_MIN <= item <= _I8_MAX: + raise OverflowError( + f"Value {item} is out of range for i8: " + f"[-{_I8_MIN}, {_I8_MAX}]" + ) + bytes_.extend(_struct.pack(">b", item)) + return cls(bytes(bytes_)) + + if _vec_rust is not None: + from_native = _from_native_rust + elif _np is not None: + from_native = _from_native_np + else: + from_native = _from_native_py + + def _to_native_rust(self) -> list[object]: + return _vec_rust.vec_i8_to_native(self.data) + + def _to_native_np(self) -> list[object]: + return _np.frombuffer(self.data, dtype=_np.dtype(">i1")).tolist() + + def _to_native_py(self) -> list[object]: + return [ + _struct.unpack(">b", self.data[i : i + self.size])[0] + for i in range(0, len(self.data), self.size) + ] + + if _vec_rust is not None: + to_native = _to_native_rust + elif _np is not None: + to_native = _to_native_np + else: + to_native = _to_native_py + + def to_numpy(self) -> numpy.ndarray: + import numpy + + return numpy.frombuffer(self.data, dtype=numpy.dtype(">i1")) + + def to_pyarrow(self) -> pyarrow.Array: + import pyarrow + + buffer = pyarrow.py_buffer(self.data_le) + return pyarrow.Array.from_buffers( + pyarrow.int8(), len(self), [None, buffer], 0 + ) diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py index 438e08f4..e860275a 100644 --- a/testkitbackend/fromtestkit.py +++ b/testkitbackend/fromtestkit.py @@ -37,6 +37,10 @@ Duration, Time, ) 
+from neo4j.vector import ( + Vector, + VectorDType, +) from ._preview_imports import NotificationDisabledClassification @@ -185,6 +189,11 @@ def to_param(m): seconds=data["seconds"], nanoseconds=data["nanoseconds"], ) + if name == "CypherVector": + return Vector( + VectorDType(data["dtype"]), + bytes([int(byte, 16) for byte in data["data"].split()]), + ) raise ValueError("Unknown param type " + name) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 0c784bc7..1cc6e098 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -49,6 +49,7 @@ "Feature:API:Summary:GqlStatusObjects": true, "Feature:API:Type.Spatial": true, "Feature:API:Type.Temporal": true, + "Feature:API:Type.Vector": true, "Feature:Auth:Bearer": true, "Feature:Auth:Custom": true, "Feature:Auth:Kerberos": true, @@ -68,6 +69,7 @@ "Feature:Bolt:5.6": true, "Feature:Bolt:5.7": true, "Feature:Bolt:5.8": true, + "Feature:Bolt:6.0": true, "Feature:Bolt:HandshakeManifestV1": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 4f4805d2..8b9acdd2 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -39,6 +39,7 @@ Duration, Time, ) +from neo4j.vector import Vector from neo4j.warnings import PreviewWarning from ._warning_check import warning_check @@ -297,6 +298,14 @@ def to(name, val): "nanoseconds": v.nanoseconds, }, } + if isinstance(v, Vector): + return { + "name": "CypherVector", + "data": { + "dtype": v.dtype, + "data": " ".join(f"{byte:02x}" for byte in v.raw()), + }, + } raise ValueError("Unhandled type:" + str(type(v))) diff --git a/tests/unit/async_/io/test__bolt_socket.py b/tests/unit/async_/io/test__bolt_socket.py index 954c0b1b..10718a50 100644 --- a/tests/unit/async_/io/test__bolt_socket.py +++ b/tests/unit/async_/io/test__bolt_socket.py @@ -42,7 +42,6 @@ def _deque_popleft_n(d: deque[_T], n: int) -> list[_T]: DEADLINE 
= Deadline(float("inf")) -# [bolt-version-bump] search tag when changing bolt version support @mark_async_test @pytest.mark.parametrize("log_level", (1, logging.DEBUG, logging.CRITICAL)) async def test_handshake(async_bolt_socket_factory, caplog, log_level): @@ -68,7 +67,7 @@ async def test_handshake_manifest_v1( caplog, log_level, ): - chosen_version = (5, 8) + chosen_version = (6, 0) expected_feature_bits = b"\x00" # varint(0) caplog.set_level(log_level) diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 45f2af1e..b50b1809 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -49,6 +49,7 @@ def test_class_method_protocol_handlers(): (3, 0), (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), + (6, 0), } # fmt: on @@ -81,7 +82,8 @@ def test_class_method_protocol_handlers(): ((5, 7), 1), ((5, 8), 1), ((5, 9), 0), - ((6, 0), 0), + ((6, 0), 1), + ((6, 1), 0), ], ) def test_class_method_protocol_handlers_with_protocol_version( @@ -91,7 +93,6 @@ def test_class_method_protocol_handlers_with_protocol_version( assert (test_input in protocol_handlers) == expected -# [bolt-version-bump] search tag when changing bolt version support def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( @@ -153,6 +154,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), ((5, 7), "neo4j._async.io._bolt5.AsyncBolt5x7"), ((5, 8), "neo4j._async.io._bolt5.AsyncBolt5x8"), + ((6, 0), "neo4j._async.io._bolt6.AsyncBolt6x0"), ), ) @mark_async_test @@ -193,14 +195,15 @@ async def test_version_negotiation( (4, 0), (4, 1), (5, 9), - (6, 0), + (6, 1), ), ) @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" + "'5.0', 
'5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8', " + "'6.0')" ) address = ("localhost", 7687) diff --git a/tests/unit/async_/io/test_class_bolt6x0.py b/tests/unit/async_/io/test_class_bolt6x0.py new file mode 100644 index 00000000..ad6777cc --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt6x0.py @@ -0,0 +1,872 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt6 import AsyncBolt6x0 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt6x0( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt6x0( + address, fake_socket(address), max_connection_lifetime + ) + if 
set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt6x0( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await 
socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = 
neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt6x0( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + await connection.hello() + tag, 
fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.connection_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await 
sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt6x0( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt6x0( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for 
key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt6x0.UNPACKER_CLS) + connection = AsyncBolt6x0( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": 
"Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt6x0( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt6x0( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt6x0( + address, sockets.client, max_connection_lifetime, 
user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt6x0(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + _tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + 
packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + connection = AsyncBolt6x0(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 
7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + connection = AsyncBolt6x0(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_async_test +async def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + connection = AsyncBolt6x0(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + await sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + await connection.send_all() + with pytest.raises(Neo4jError): + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + 
"message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. {i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt6x0.PACKER_CLS, + unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt6x0( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is bool(ssr_hint) diff --git a/tests/unit/common/codec/hydration/v1/_base.py b/tests/unit/common/codec/hydration/_base.py similarity index 100% rename from tests/unit/common/codec/hydration/v1/_base.py rename to tests/unit/common/codec/hydration/_base.py diff --git a/tests/unit/common/codec/hydration/v1/test_graph_hydration.py b/tests/unit/common/codec/hydration/v1/test_graph_hydration.py index c78ff9dd..f51ce10c 100644 --- a/tests/unit/common/codec/hydration/v1/test_graph_hydration.py +++ b/tests/unit/common/codec/hydration/v1/test_graph_hydration.py @@ -23,7 +23,7 @@ Relationship, ) -from ._base import HydrationHandlerTestBase +from .._base import HydrationHandlerTestBase class TestGraphHydration(HydrationHandlerTestBase): diff --git 
a/tests/unit/common/codec/hydration/v1/test_hydration_handler.py b/tests/unit/common/codec/hydration/v1/test_hydration_handler.py index 4a927695..4cd0fe0a 100644 --- a/tests/unit/common/codec/hydration/v1/test_hydration_handler.py +++ b/tests/unit/common/codec/hydration/v1/test_hydration_handler.py @@ -43,8 +43,9 @@ Duration, Time, ) +from neo4j.vector import Vector -from ._base import HydrationHandlerTestBase +from .._base import HydrationHandlerTestBase class TestHydrationHandler(HydrationHandlerTestBase): @@ -85,6 +86,7 @@ def test_scope_dehydration_keys(self, hydration_scope): pd.Timestamp, pd.Timedelta, type(pd.NaT), + Vector, } assert not hooks.subtypes diff --git a/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py b/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py index 94e58147..0a8fc38c 100644 --- a/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py +++ b/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py @@ -24,7 +24,7 @@ WGS84Point, ) -from ._base import HydrationHandlerTestBase +from .._base import HydrationHandlerTestBase class TestSpatialDehydration(HydrationHandlerTestBase): diff --git a/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py b/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py index 5791ab8f..9ebdf953 100644 --- a/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py +++ b/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py @@ -24,7 +24,7 @@ WGS84Point, ) -from ._base import HydrationHandlerTestBase +from .._base import HydrationHandlerTestBase class TestSpatialHydration(HydrationHandlerTestBase): diff --git a/tests/unit/common/codec/hydration/v1/test_temporal_dehydration.py b/tests/unit/common/codec/hydration/v1/test_temporal_dehydration.py index 90fcb4e8..b27be6b0 100644 --- a/tests/unit/common/codec/hydration/v1/test_temporal_dehydration.py +++ b/tests/unit/common/codec/hydration/v1/test_temporal_dehydration.py @@ -34,7 +34,7 @@ 
Time, ) -from ._base import HydrationHandlerTestBase +from .._base import HydrationHandlerTestBase class TestTimeDehydration(HydrationHandlerTestBase): diff --git a/tests/unit/common/codec/hydration/v1/test_temporal_hydration.py b/tests/unit/common/codec/hydration/v1/test_temporal_hydration.py index 8b9667c7..2c59dc1e 100644 --- a/tests/unit/common/codec/hydration/v1/test_temporal_hydration.py +++ b/tests/unit/common/codec/hydration/v1/test_temporal_hydration.py @@ -27,7 +27,7 @@ Time, ) -from ._base import HydrationHandlerTestBase +from .._base import HydrationHandlerTestBase class TestTemporalHydration(HydrationHandlerTestBase): diff --git a/tests/unit/common/codec/hydration/v1/test_unknown_hydration.py b/tests/unit/common/codec/hydration/v1/test_unknown_hydration.py index 36323bc7..9b28a109 100644 --- a/tests/unit/common/codec/hydration/v1/test_unknown_hydration.py +++ b/tests/unit/common/codec/hydration/v1/test_unknown_hydration.py @@ -20,7 +20,7 @@ from neo4j._codec.hydration.v1 import HydrationHandler from neo4j._codec.packstream import Structure -from ._base import HydrationHandlerTestBase +from .._base import HydrationHandlerTestBase class TestUnknownHydration(HydrationHandlerTestBase): diff --git a/tests/unit/common/codec/hydration/v1/test_vector_dehydration.py b/tests/unit/common/codec/hydration/v1/test_vector_dehydration.py new file mode 100644 index 00000000..d91c3b80 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_vector_dehydration.py @@ -0,0 +1,44 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j.exceptions import ConfigurationError +from neo4j.vector import Vector + +from .._base import HydrationHandlerTestBase + + +class TestVectorDehydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + @pytest.fixture + def transformer(self, hydration_scope): + def transformer(value): + transformer_ = hydration_scope.dehydration_hooks.get_transformer( + value + ) + assert callable(transformer_) + return transformer_(value) + + return transformer + + def test_vector(self, transformer): + with pytest.raises(ConfigurationError, match="Vector"): + transformer(Vector.from_native("f64", [1.0])) diff --git a/tests/unit/common/codec/hydration/v1/test_vector_hydration.py b/tests/unit/common/codec/hydration/v1/test_vector_hydration.py new file mode 100644 index 00000000..3037a9b3 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_vector_hydration.py @@ -0,0 +1,39 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from struct import pack + +import pytest + +from neo4j._codec.hydration import BrokenHydrationObject +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure + +from .._base import HydrationHandlerTestBase + + +class TestVectorHydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_vector_structure_tag(self, hydration_scope): + struct = Structure(b"V", bytes(0xC1), pack(">f", 1.0)) + res = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(res, BrokenHydrationObject) + error = res.error + assert isinstance(error, ValueError) + assert repr(b"V") in str(error) diff --git a/tests/unit/common/codec/hydration/v2/test_hydration_handler.py b/tests/unit/common/codec/hydration/v2/test_hydration_handler.py index 3751d69f..a0112df1 100644 --- a/tests/unit/common/codec/hydration/v2/test_hydration_handler.py +++ b/tests/unit/common/codec/hydration/v2/test_hydration_handler.py @@ -19,11 +19,11 @@ from neo4j._codec.hydration.v2 import HydrationHandler from ..v1.test_hydration_handler import ( - TestHydrationHandler as TestHydrationHandlerV1, + TestHydrationHandler as _TestHydrationHandler, ) -class TestHydrationHandler(TestHydrationHandlerV1): +class TestHydrationHandler(_TestHydrationHandler): @pytest.fixture def hydration_handler(self): return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py b/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py index ace4d896..03ca83cb 100644 --- a/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py +++ b/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py @@ -19,11 +19,11 @@ from neo4j._codec.hydration.v2 import HydrationHandler from ..v1.test_spacial_dehydration import ( - TestSpatialDehydration as 
_TestSpatialDehydrationV1, + TestSpatialDehydration as _TestSpatialDehydration, ) -class TestSpatialDehydration(_TestSpatialDehydrationV1): +class TestSpatialDehydration(_TestSpatialDehydration): @pytest.fixture def hydration_handler(self): return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py b/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py index 207d35ab..a5555dea 100644 --- a/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py +++ b/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py @@ -19,11 +19,11 @@ from neo4j._codec.hydration.v2 import HydrationHandler from ..v1.test_spacial_hydration import ( - TestSpatialHydration as _TestSpatialHydrationV1, + TestSpatialHydration as _TestSpatialHydration, ) -class TestSpatialHydration(_TestSpatialHydrationV1): +class TestSpatialHydration(_TestSpatialHydration): @pytest.fixture def hydration_handler(self): return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_temporal_dehydration.py b/tests/unit/common/codec/hydration/v2/test_temporal_dehydration.py index ca83ea7e..52e8e956 100644 --- a/tests/unit/common/codec/hydration/v2/test_temporal_dehydration.py +++ b/tests/unit/common/codec/hydration/v2/test_temporal_dehydration.py @@ -25,11 +25,11 @@ from neo4j.time import DateTime from ..v1.test_temporal_dehydration import ( - TestTimeDehydration as _TestTemporalDehydrationV1, + TestTimeDehydration as _TestTemporalDehydration, ) -class TestTimeDehydration(_TestTemporalDehydrationV1): +class TestTimeDehydration(_TestTemporalDehydration): @pytest.fixture def hydration_handler(self): return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_temporal_hydration.py b/tests/unit/common/codec/hydration/v2/test_temporal_hydration.py index 150f6a02..8298b4f0 100644 --- a/tests/unit/common/codec/hydration/v2/test_temporal_hydration.py +++ b/tests/unit/common/codec/hydration/v2/test_temporal_hydration.py @@ 
-23,11 +23,11 @@ from neo4j.time import DateTime from ..v1.test_temporal_hydration import ( - TestTemporalHydration as _TestTemporalHydrationV1, + TestTemporalHydration as _TestTemporalHydration, ) -class TestTemporalHydration(_TestTemporalHydrationV1): +class TestTemporalHydration(_TestTemporalHydration): @pytest.fixture def hydration_handler(self): return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_vector_dehydration.py b/tests/unit/common/codec/hydration/v2/test_vector_dehydration.py new file mode 100644 index 00000000..b2913491 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_vector_dehydration.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler + +from ..v1.test_vector_dehydration import ( + TestVectorDehydration as _TestVectorDehydration, +) + + +class TestVectorDehydration(_TestVectorDehydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_vector_hydration.py b/tests/unit/common/codec/hydration/v2/test_vector_hydration.py new file mode 100644 index 00000000..334d145b --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_vector_hydration.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v2 import HydrationHandler + +from ..v1.test_vector_hydration import ( + TestVectorHydration as _TestVectorHydration, +) + + +class TestVectorHydration(_TestVectorHydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v3/__init__.py b/tests/unit/common/codec/hydration/v3/__init__.py new file mode 100644 index 00000000..3f968099 --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/hydration/v3/test_graph_hydration.py b/tests/unit/common/codec/hydration/v3/test_graph_hydration.py new file mode 100644 index 00000000..6966ff3c --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_graph_hydration.py @@ -0,0 +1,27 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler + +from ..v2.test_graph_hydration import TestGraphHydration as _TestGraphHydration + + +class TestGraphHydration(_TestGraphHydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v3/test_hydration_handler.py b/tests/unit/common/codec/hydration/v3/test_hydration_handler.py new file mode 100644 index 00000000..1752adcc --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_hydration_handler.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler + +from ..v2.test_hydration_handler import ( + TestHydrationHandler as _TestHydrationHandler, +) + + +class TestHydrationHandler(_TestHydrationHandler): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v3/test_spacial_dehydration.py b/tests/unit/common/codec/hydration/v3/test_spacial_dehydration.py new file mode 100644 index 00000000..b45dd90d --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_spacial_dehydration.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler + +from ..v2.test_spacial_dehydration import ( + TestSpatialDehydration as _TestSpatialDehydration, +) + + +class TestSpatialDehydration(_TestSpatialDehydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v3/test_spacial_hydration.py b/tests/unit/common/codec/hydration/v3/test_spacial_hydration.py new file mode 100644 index 00000000..2b63e0ba --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_spacial_hydration.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler + +from ..v2.test_spacial_hydration import ( + TestSpatialHydration as _TestSpatialHydration, +) + + +class TestSpatialHydration(_TestSpatialHydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v3/test_temporal_dehydration.py b/tests/unit/common/codec/hydration/v3/test_temporal_dehydration.py new file mode 100644 index 00000000..0eb8e5a4 --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_temporal_dehydration.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler + +from ..v2.test_temporal_dehydration import ( + TestTimeDehydration as _TestTemporalDehydration, +) + + +class TestTimeDehydration(_TestTemporalDehydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v3/test_temporal_hydration.py b/tests/unit/common/codec/hydration/v3/test_temporal_hydration.py new file mode 100644 index 00000000..1d191768 --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_temporal_hydration.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler + +from ..v2.test_temporal_hydration import ( + TestTemporalHydration as _TestTemporalHydration, +) + + +class TestTemporalHydration(_TestTemporalHydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v3/test_unknown_hydration.py b/tests/unit/common/codec/hydration/v3/test_unknown_hydration.py new file mode 100644 index 00000000..be681bcf --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_unknown_hydration.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler + +from ..v2.test_unknown_hydration import ( + TestUnknownHydration as _TestUnknownHydration, +) + + +class TestUnknownHydration(_TestUnknownHydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v3/test_vector_dehydration.py b/tests/unit/common/codec/hydration/v3/test_vector_dehydration.py new file mode 100644 index 00000000..becadafa --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_vector_dehydration.py @@ -0,0 +1,74 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.vector import Vector + +from .._base import HydrationHandlerTestBase + + +class TestVectorDehydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + @pytest.fixture + def transformer(self, hydration_scope): + def transformer(value): + transformer_ = hydration_scope.dehydration_hooks.get_transformer( + value + ) + assert callable(transformer_) + return transformer_(value) + + return transformer + + @pytest.fixture + def assert_transforms(self, transformer): + def assert_(value, expected): + struct = transformer(value) + assert struct == expected + + return assert_ + + @pytest.mark.parametrize( + ("dtype", "marker", "data"), + ( + *( + (dtype, marker, data) + for (dtype, marker) in ( + ("i8", b"\xc8"), + ("i16", b"\xc9"), + ("i32", b"\xca"), + ("i64", b"\xcb"), + ("f32", b"\xc6"), + ("f64", b"\xc1"), + ) + for data in (b"", bytes(range(128))) + ), + ("i8", b"\xc8", bytes(range(1))), + ("i16", b"\xc9", bytes(range(2))), + ("i32", b"\xca", bytes(range(4))), + ("i64", b"\xcb", bytes(range(8))), + ("f32", b"\xc6", bytes(range(4))), + ("f64", b"\xc1", bytes(range(8))), + ), + ) + def test_vector(self, assert_transforms, dtype, marker, data): + assert_transforms(Vector(data, dtype), Structure(b"V", marker, data)) diff --git a/tests/unit/common/codec/hydration/v3/test_vector_hydration.py b/tests/unit/common/codec/hydration/v3/test_vector_hydration.py new file mode 100644 index 00000000..10790ed8 --- /dev/null +++ b/tests/unit/common/codec/hydration/v3/test_vector_hydration.py @@ -0,0 +1,59 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v3 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.vector import Vector + +from .._base import HydrationHandlerTestBase + + +class TestVectorHydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + @pytest.mark.parametrize( + ("dtype", "marker", "data"), + ( + *( + (dtype, marker, data) + for (dtype, marker) in ( + ("i8", b"\xc8"), + ("i16", b"\xc9"), + ("i32", b"\xca"), + ("i64", b"\xcb"), + ("f32", b"\xc6"), + ("f64", b"\xc1"), + ) + for data in (b"", bytes(range(128))) + ), + ("i8", b"\xc8", bytes(range(1))), + ("i16", b"\xc9", bytes(range(2))), + ("i32", b"\xca", bytes(range(4))), + ("i64", b"\xcb", bytes(range(8))), + ("f32", b"\xc6", bytes(range(4))), + ("f64", b"\xc1", bytes(range(8))), + ), + ) + def test_vector(self, hydration_scope, dtype, marker, data): + struct = Structure(b"V", marker, data) + vector = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(vector, Vector) + assert vector.dtype == dtype + assert vector.raw() == data diff --git a/tests/unit/common/test_import_neo4j.py b/tests/unit/common/test_import_neo4j.py index e83985a2..f06cd9d0 100644 --- a/tests/unit/common/test_import_neo4j.py +++ b/tests/unit/common/test_import_neo4j.py @@ -153,6 +153,7 @@ def test_import_star(): ("auth_management", None), ("debug", None), ("exceptions", None), + ("vector", None), ("warnings", None), ) diff --git a/tests/unit/common/vector/__init__.py b/tests/unit/common/vector/__init__.py new 
file mode 100644 index 00000000..3f968099 --- /dev/null +++ b/tests/unit/common/vector/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/vector/test_import_vector.py b/tests/unit/common/vector/test_import_vector.py new file mode 100644 index 00000000..8be94f13 --- /dev/null +++ b/tests/unit/common/vector/test_import_vector.py @@ -0,0 +1,73 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import importlib + +import pytest + + +MODULE_PATH = "neo4j.vector" +VECTOR_ATTRIBUTES = ( + # (name, warning) + ("Vector", None), + ("VectorDType", None), + ("VectorEndian", None), +) + + +def _get_module(): + module = importlib.__import__(MODULE_PATH) + for submodule in MODULE_PATH.split(".")[1:]: + module = getattr(module, submodule) + return module + + +@pytest.mark.parametrize(("name", "warning"), VECTOR_ATTRIBUTES) +def test_attribute_import(name, warning): + module = _get_module() + if warning: + with pytest.warns(warning): + getattr(module, name) + else: + getattr(module, name) + + +@pytest.mark.parametrize(("name", "warning"), VECTOR_ATTRIBUTES) +def test_attribute_from_import(name, warning): + if warning: + with pytest.warns(warning): + importlib.__import__(MODULE_PATH, fromlist=(name,)) + else: + importlib.__import__(MODULE_PATH, fromlist=(name,)) + + +def test_all(): + module = _get_module() + + assert sorted(module.__all__) == sorted([i[0] for i in VECTOR_ATTRIBUTES]) + + +def test_dir(): + module = _get_module() + + dir_attrs = (attr for attr in dir(module) if not attr.startswith("_")) + assert sorted(dir_attrs) == sorted([i[0] for i in VECTOR_ATTRIBUTES]) + + +def test_import_star(): + # ignore PT029: purposefully capturing all warnings to then apply further + # checks on them + importlib.__import__(MODULE_PATH, fromlist=("*",)) diff --git a/tests/unit/common/vector/test_vector.py b/tests/unit/common/vector/test_vector.py new file mode 100644 index 00000000..0612bc83 --- /dev/null +++ b/tests/unit/common/vector/test_vector.py @@ -0,0 +1,856 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import math +import random +import struct +import sys +import timeit +import typing as t + +import pytest + +from neo4j._optional_deps import ( + np, + pa, +) +from neo4j.vector import ( + _swap_endian, + Vector, +) + + +if t.TYPE_CHECKING: + import numpy + import pyarrow + + +def _max_value_be_bytes(size: t.Literal[1, 2, 4, 8], count: int = 1) -> bytes: + def generator(count_: int) -> t.Iterable[int]: + pack_format = { + 1: ">b", + 2: ">h", + 4: ">i", + 8: ">q", + }[size] + if count_ <= 0: + return + yield from struct.pack(pack_format, 0) + count_ -= 1 + i = 0 + min_value = -(2 ** (size * 8 - 1)) + max_value = 2 ** (size * 8 - 1) - 1 + while True: + if count_ <= 0: + return + yield from struct.pack(pack_format, min_value + i) + count_ -= 1 + if count_ == 0: + return + yield from struct.pack(pack_format, max_value - i) + count_ -= 1 + i += 1 + i %= 2 ** (size * 8) + + return bytes(generator(count)) + + +def _random_value_be_bytes( + size: t.Literal[1, 2, 4, 8], count: int = 1 +) -> bytes: + def generator(count_: int) -> t.Iterable[int]: + pack_format = { + 1: ">B", + 2: ">H", + 4: ">I", + 8: ">Q", + }[size] + while count_ > 0: + yield from struct.pack( + pack_format, random.randint(0, 2 ** (size * 8) - 1) + ) + count_ -= 1 + + return bytes(generator(count)) + + +def _get_type_size(dtype: str) -> t.Literal[1, 2, 4, 8]: + lookup: dict[str, t.Literal[1, 2, 4, 8]] = { + "i8": 1, + "i16": 2, + "i32": 4, + "i64": 8, + "f32": 4, + "f64": 8, + } + return lookup[dtype] + + +def _normalize_float_bytes(dtype: str, data: 
bytes) -> bytes: + if dtype not in {"f32", "f64"}: + raise ValueError(f"Invalid dtype {dtype}") + type_size = _get_type_size(dtype) + pack_format = _dtype_to_pack_format(dtype) + chunks = (data[i : i + type_size] for i in range(0, len(data), type_size)) + return bytes( + b + for chunk in chunks + for b in struct.pack(pack_format, struct.unpack(pack_format, chunk)[0]) + ) + + +def _dtype_to_pack_format(dtype: str) -> str: + return { + "i8": ">b", + "i16": ">h", + "i32": ">i", + "i64": ">q", + "f32": ">f", + "f64": ">d", + }[dtype] + + +def _mock_mask_extensions(mocker, used_ext): + from neo4j.vector import ( + _swap_endian_unchecked_np, + _swap_endian_unchecked_py, + _swap_endian_unchecked_rust, + _VecF32, + _VecF64, + _VecI8, + _VecI16, + _VecI32, + _VecI64, + ) + + vec_types = (_VecF64, _VecF32, _VecI64, _VecI32, _VecI16, _VecI8) + match used_ext: + case "numpy": + if _swap_endian_unchecked_np is None: + pytest.skip("numpy not installed") + mocker.patch( + "neo4j.vector._swap_endian_unchecked", + new=_swap_endian_unchecked_np, + ) + for vec_type in vec_types: + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.from_native", + new=vec_type._from_native_np, + ) + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.to_native", + new=vec_type._to_native_np, + ) + case "rust": + if _swap_endian_unchecked_rust is None: + pytest.skip("rust extensions are not installed") + mocker.patch( + "neo4j.vector._swap_endian_unchecked", + new=_swap_endian_unchecked_rust, + ) + for vec_type in vec_types: + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.from_native", + new=vec_type._from_native_rust, + ) + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.to_native", + new=vec_type._to_native_rust, + ) + case "python": + mocker.patch( + "neo4j.vector._swap_endian_unchecked", + new=_swap_endian_unchecked_py, + ) + for vec_type in vec_types: + mocker.patch( + f"neo4j.vector.{vec_type.__name__}.from_native", + new=vec_type._from_native_py, + ) + mocker.patch( + 
f"neo4j.vector.{vec_type.__name__}.to_native", + new=vec_type._to_native_py, + ) + case _: + raise ValueError(f"Invalid ext value {used_ext}") + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def _test_bench_swap_endian(mocker, ext): + data = bytes(i % 256 for i in range(10_000)) + _mock_mask_extensions(mocker, ext) + print(timeit.timeit(lambda: _swap_endian(2, data), number=1_000)) # noqa: T201 + print(timeit.timeit(lambda: _swap_endian(4, data), number=1_000)) # noqa: T201 + print(timeit.timeit(lambda: _swap_endian(8, data), number=1_000)) # noqa: T201 + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +def _test_bench_from_native(mocker, ext, dtype): + print(f"Testing {ext} for {dtype}") # noqa: T201 + data_raw = bytes(i % 256 for i in range(8 * 1_000)) + data = Vector.from_bytes(data_raw, dtype).to_native() + _mock_mask_extensions(mocker, ext) + + def work(data, dtype): + Vector.from_native(data, dtype) + + print(timeit.timeit(lambda: work(iter(data), dtype), number=1_000)) # noqa: T201 + print(timeit.timeit(lambda: work(data, dtype), number=1_000)) # noqa: T201 + + print() # noqa: T201 + data_raw = bytes(i % 256 for i in range(8 * 1)) + data = Vector.from_bytes(data_raw, dtype).to_native() + print(timeit.timeit(lambda: work(iter(data), dtype), number=100_000)) # noqa: T201 + print(timeit.timeit(lambda: work(data, dtype), number=100_000)) # noqa: T201 + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +def _test_bench_to_native(mocker, ext, dtype): + print(f"Testing {ext} for {dtype}") # noqa: T201 + data = Vector.from_bytes(bytes(i % 256 for i in range(8 * 1_000)), dtype) + _mock_mask_extensions(mocker, ext) + + print(timeit.timeit(data.to_native, number=1_000)) # noqa: T201 + print(timeit.timeit(data.to_native, number=1_000)) # noqa: T201 + + print() # noqa: 
T201 + data = Vector.from_bytes(bytes(i % 256 for i in range(8 * 1)), dtype) + print(timeit.timeit(data.to_native, number=100_000)) # noqa: T201 + print(timeit.timeit(data.to_native, number=100_000)) # noqa: T201 + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_swap_endian(mocker, ext): + data = bytes(range(1, 17)) + _mock_mask_extensions(mocker, ext) + res = _swap_endian(2, data) + assert isinstance(res, bytes) + assert res == bytes( + (2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15) + ) + res = _swap_endian(4, data) + assert isinstance(res, bytes) + assert res == bytes( + (4, 3, 2, 1, 8, 7, 6, 5, 12, 11, 10, 9, 16, 15, 14, 13) + ) + res = _swap_endian(8, data) + assert isinstance(res, bytes) + assert res == bytes( + (8, 7, 6, 5, 4, 3, 2, 1, 16, 15, 14, 13, 12, 11, 10, 9) + ) + + +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +@pytest.mark.parametrize("type_size", (-1, 0, 3, 5, 7, 9, 16, 32)) +def test_swap_endian_unhandled_size(mocker, ext, type_size): + data = bytes(i % 256 for i in range(1, abs(type_size) * 4)) + _mock_mask_extensions(mocker, ext) + + with pytest.raises(ValueError, match=str(type_size)): + _swap_endian(type_size, data) + + +@pytest.mark.parametrize( + ("dtype", "data"), + ( + ("i8", b""), + ("i8", b"\x01"), + ("i8", b"\x01\x02\x03\x04"), + ("i8", _max_value_be_bytes(1, 4096)), + ("i16", b""), + ("i16", b"\x00\x01"), + ("i16", b"\x00\x01\x00\x02"), + ("i16", _max_value_be_bytes(2, 4096)), + ("i32", b""), + ("i32", b"\x00\x00\x00\x01"), + ("i32", b"\x00\x00\x00\x01\x00\x00\x00\x02"), + ("i32", _max_value_be_bytes(4, 4096)), + ("i64", b""), + ("i64", b"\x00\x00\x00\x00\x00\x00\x00\x01"), + ( + "i64", + ( + b"\x00\x00\x00\x00\x00\x00\x00\x01" + b"\x00\x00\x00\x00\x00\x00\x00\x02" + ), + ), + ("i64", _max_value_be_bytes(8, 4096)), + ("f32", b""), + ("f32", _random_value_be_bytes(4, 4096)), + ("f64", b""), + ("f64", _random_value_be_bytes(8, 4096)), + ), +) +@pytest.mark.parametrize("input_endian", 
(None, "big", "little")) +@pytest.mark.parametrize("as_bytearray", (False, True)) +def test_raw_data( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + data: bytes, + input_endian: t.Literal["big", "little"] | None, + as_bytearray: bool, +) -> None: + swapped_data = _swap_endian(_get_type_size(dtype), data) + if input_endian is None: + input_data = bytearray(data) if as_bytearray else data + v = Vector(input_data, dtype) + elif input_endian == "big": + input_data = bytearray(data) if as_bytearray else data + v = Vector(input_data, dtype, byteorder=input_endian) + elif input_endian == "little": + input_data = bytearray(swapped_data) if as_bytearray else swapped_data + v = Vector(input_data, dtype, byteorder=input_endian) + else: + raise ValueError(f"Invalid input_endian {input_endian}") + assert v.dtype == dtype + assert v.raw() == data + assert v.raw(byteorder="big") == data + assert v.raw(byteorder="little") == swapped_data + + +def nan_equals(a: list[object], b: list[object]) -> bool: + if len(a) != len(b): + return False + for i in range(len(a)): + ai = a[i] + bi = b[i] + if ai != bi and not ( + isinstance(ai, float) + and isinstance(bi, float) + and math.isnan(ai) + and math.isnan(bi) + ): + return False + return True + + +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_from_native_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + repeat: int, + size: int, + ext: str, + mocker: t.Any, +) -> None: + _mock_mask_extensions(mocker, ext) + type_size = _get_type_size(dtype) + for _ in range(repeat): + data = _random_value_be_bytes(type_size, size) + values = [ + struct.unpack( + _dtype_to_pack_format(dtype), data[i : i + type_size] + )[0] + for i in range(0, len(data), type_size) + ] + v = Vector.from_native(values, dtype) + expected_raw = data + if 
dtype.startswith("f"): + expected_raw = _normalize_float_bytes(dtype, data) + assert v.raw() == expected_raw + + +SPECIAL_VALUES = ( + # (dtype, value, packed_bytes_be) + # i8 + ("i8", -128, b"\x80"), + ("i8", 0, b"\x00"), + ("i8", 127, b"\x7f"), + # i16 + ("i16", -32768, b"\x80\x00"), + ("i16", 0, b"\x00\x00"), + ("i16", 32767, b"\x7f\xff"), + # i32 + ("i32", -2147483648, b"\x80\x00\x00\x00"), + ("i32", 0, b"\x00\x00\x00\x00"), + ("i32", 2147483647, b"\x7f\xff\xff\xff"), + # i64 + ("i64", -9223372036854775808, b"\x80\x00\x00\x00\x00\x00\x00\x00"), + ("i64", 0, b"\x00\x00\x00\x00\x00\x00\x00\x00"), + ("i64", 9223372036854775807, b"\x7f\xff\xff\xff\xff\xff\xff\xff"), + # f32 + # NaN + ("f32", float("nan"), b"\x7f\xc0\x00\x00"), + ("f32", float("-nan"), b"\xff\xc0\x00\x00"), + ( + "f32", + struct.unpack(">f", b"\x7f\xc0\x00\x11")[0], + b"\x7f\xc0\x00\x11", + ), + ( + "f32", + struct.unpack(">f", b"\x7f\x80\x00\x01")[0], + # Python < 3.14 does not properly preserver all NaN payload + # when calling struct.pack + _normalize_float_bytes("f32", b"\x7f\x80\x00\x01"), + ), + # ±inf + ("f32", float("inf"), b"\x7f\x80\x00\x00"), + ("f32", float("-inf"), b"\xff\x80\x00\x00"), + # ±0.0 + ("f32", 0.0, b"\x00\x00\x00\x00"), + ("f32", -0.0, b"\x80\x00\x00\x00"), + # smallest normal + ( + "f32", + struct.unpack(">f", b"\x00\x80\x00\x00")[0], + b"\x00\x80\x00\x00", + ), + ( + "f32", + struct.unpack(">f", b"\x80\x80\x00\x00")[0], + b"\x80\x80\x00\x00", + ), + # subnormal + ( + "f32", + struct.unpack(">f", b"\x00\x00\x00\x01")[0], + b"\x00\x00\x00\x01", + ), + ( + "f32", + struct.unpack(">f", b"\x80\x00\x00\x01")[0], + b"\x80\x00\x00\x01", + ), + # largest normal + ( + "f32", + struct.unpack(">f", b"\x7f\x7f\xff\xff")[0], + b"\x7f\x7f\xff\xff", + ), + ( + "f32", + struct.unpack(">f", b"\xff\x7f\xff\xff")[0], + b"\xff\x7f\xff\xff", + ), + # f64 + # NaN + ("f64", float("nan"), b"\x7f\xf8\x00\x00\x00\x00\x00\x00"), + ("f64", float("-nan"), b"\xff\xf8\x00\x00\x00\x00\x00\x00"), + ( + 
"f64", + struct.unpack(">d", b"\x7f\xf8\x00\x00\x00\x00\x00\x11")[0], + b"\x7f\xf8\x00\x00\x00\x00\x00\x11", + ), + ( + "f64", + struct.unpack(">d", b"\x7f\xf0\x00\x01\x00\x00\x00\x01")[0], + b"\x7f\xf0\x00\x01\x00\x00\x00\x01", + ), + # ±inf + ("f64", float("inf"), b"\x7f\xf0\x00\x00\x00\x00\x00\x00"), + ("f64", float("-inf"), b"\xff\xf0\x00\x00\x00\x00\x00\x00"), + # ±0.0 + ("f64", 0.0, b"\x00\x00\x00\x00\x00\x00\x00\x00"), + ("f64", -0.0, b"\x80\x00\x00\x00\x00\x00\x00\x00"), + # smallest normal + ( + "f64", + struct.unpack(">d", b"\x00\x10\x00\x00\x00\x00\x00\x00")[0], + b"\x00\x10\x00\x00\x00\x00\x00\x00", + ), + ( + "f64", + struct.unpack(">d", b"\x80\x10\x00\x00\x00\x00\x00\x00")[0], + b"\x80\x10\x00\x00\x00\x00\x00\x00", + ), + # subnormal + ( + "f64", + struct.unpack(">d", b"\x00\x00\x00\x00\x00\x00\x00\x01")[0], + b"\x00\x00\x00\x00\x00\x00\x00\x01", + ), + ( + "f64", + struct.unpack(">d", b"\x80\x00\x00\x00\x00\x00\x00\x01")[0], + b"\x80\x00\x00\x00\x00\x00\x00\x01", + ), + # largest normal + ( + "f64", + struct.unpack(">d", b"\x7f\xef\xff\xff\xff\xff\xff\xff")[0], + b"\x7f\xef\xff\xff\xff\xff\xff\xff", + ), + ( + "f64", + struct.unpack(">d", b"\xff\xef\xff\xff\xff\xff\xff\xff")[0], + b"\xff\xef\xff\xff\xff\xff\xff\xff", + ), +) + + +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_from_native_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + data_be: bytes, + ext: str, + mocker: t.Any, +) -> None: + _mock_mask_extensions(mocker, ext) + if dtype in {"f32", "f64"}: + assert isinstance(value, float) + dtype_f = t.cast(t.Literal["f32", "f64"], dtype) + v = Vector.from_native([value], dtype_f) + elif dtype in {"i8", "i16", "i32", "i64"}: + assert isinstance(value, int) + dtype_i = t.cast(t.Literal["i8", "i16", "i32", "i64"], dtype) + v = Vector.from_native([value], dtype_i) + else: + raise ValueError(f"Invalid dtype 
{dtype}") + assert v.raw() == data_be + + +@pytest.mark.parametrize( + ("dtype", "value"), + ( + ("i8", "1"), + ("i8", None), + ("i8", 1.0), + ("i16", "1"), + ("i16", None), + ("i16", 1.0), + ("i32", "1"), + ("i32", None), + ("i32", 1.0), + ("i64", "1"), + ("i64", None), + ("i64", 1.0), + ("f32", "1.0"), + ("f32", None), + ("f32", 1), + ("f64", "1.0"), + ("f64", None), + ("f64", 1), + ), +) +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_from_native_wrong_type( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + ext: str, + mocker: t.Any, +) -> None: + _mock_mask_extensions(mocker, ext) + with pytest.raises(TypeError) as exc: + Vector.from_native([value], dtype) # type: ignore + + assert dtype in str(exc.value) + assert str(type(value).__name__) in str(exc.value) + + +@pytest.mark.parametrize( + ("dtype", "value"), + ( + ("i8", -129), + ("i8", 128), + ("i16", -32769), + ("i16", 32768), + ("i32", -2147483649), + ("i32", 2147483648), + ("i64", -9223372036854775809), + ("i64", 9223372036854775808), + ), +) +@pytest.mark.parametrize("ext", ("numpy", "rust", "python")) +def test_from_native_overflow( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + ext: str, + mocker: t.Any, +) -> None: + _mock_mask_extensions(mocker, ext) + with pytest.raises(OverflowError) as exc: + Vector.from_native([value], dtype) # type: ignore + + assert dtype in str(exc.value) + + +def _vector_from_data( + data: bytes, + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, +) -> Vector: + match endian: + case None: + return Vector(data, dtype) + case "big": + return Vector(data, dtype, byteorder=endian) + case "little": + type_size = _get_type_size(dtype) + data_le = _swap_endian(type_size, data) + return Vector(data_le, dtype, byteorder=endian) + case _: + raise ValueError(f"Invalid endian {endian}") + + +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", 
"i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", "little", None)) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_to_native_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + for _ in range(repeat): + data = _random_value_be_bytes(type_size, size) + expected = [ + struct.unpack( + _dtype_to_pack_format(dtype), data[i : i + type_size] + )[0] + for i in range(0, len(data), type_size) + ] + v = _vector_from_data(data, dtype, endian) + assert nan_equals(v.to_native(), expected) + + +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +def test_to_native_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + data_be: bytes, +) -> None: + type_size = _get_type_size(dtype) + pack_format = _dtype_to_pack_format(dtype) + expected = [ + struct.unpack(pack_format, data_be[i : i + type_size])[0] + for i in range(0, len(data_be), type_size) + ] + v = Vector(data_be, dtype) + assert nan_equals(v.to_native(), expected) + + +def _get_numpy_dtype(dtype: str) -> str: + return { + "i8": "i1", + "i16": "i2", + "i32": "i4", + "i64": "i8", + "f32": "f4", + "f64": "f8", + }[dtype] + + +def _get_numpy_array( + data_be: bytes, dtype: str, endian: t.Literal["big", "little", "native"] +) -> numpy.ndarray: + np_type = _get_numpy_dtype(dtype) + type_size = _get_type_size(dtype) + data_in = data_be + match endian: + case "big": + data_in = data_be + np_type = f">{np_type}" + case "little": + data_in = _swap_endian(type_size, data_be) + np_type = f"<{np_type}" + case "native": + if sys.byteorder == "little": + data_in = _swap_endian(type_size, data_be) + np_type = f"={np_type}" + return np.frombuffer(data_in, dtype=np_type) + + +@pytest.mark.skipif(np is None, reason="numpy not installed") +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", 
"i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", "little", "native")) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_from_numpy_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little", "native"], + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + for _ in range(repeat): + data_be = _random_value_be_bytes(type_size, size) + array = _get_numpy_array(data_be, dtype, endian) + v = Vector.from_numpy(array) + assert v.dtype == dtype + assert v.raw() == data_be + assert nan_equals(array.tolist(), v.to_native()) + + +@pytest.mark.skipif(np is None, reason="numpy not installed") +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +@pytest.mark.parametrize("endian", ("big", "little", "native")) +def test_from_numpy_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little", "native"], + value: object, + data_be: bytes, +) -> None: + array = _get_numpy_array(data_be, dtype, endian) + v = Vector.from_numpy(array) + assert v.dtype == dtype + assert v.raw() == data_be + assert nan_equals(array.tolist(), v.to_native()) + + +@pytest.mark.skipif(np is None, reason="numpy not installed") +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", "little", None)) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_to_numpy_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + np_type = _get_numpy_dtype(dtype) + for _ in range(repeat): + data = _random_value_be_bytes(type_size, size) + v = _vector_from_data(data, dtype, endian) + array = v.to_numpy() + assert array.dtype == np.dtype(f">{np_type}") + assert array.size == len(data) // type_size + assert 
array.tobytes() == data + assert nan_equals(array.tolist(), v.to_native()) + + +@pytest.mark.skipif(np is None, reason="numpy not installed") +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +@pytest.mark.parametrize("endian", ("big", "little", None)) +def test_to_numpy_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + value: object, + data_be: bytes, +) -> None: + np_type = _get_numpy_dtype(dtype) + v = _vector_from_data(data_be, dtype, endian) + array = v.to_numpy() + assert array.dtype == np.dtype(f">{np_type}") + assert array.size == 1 + assert array.tobytes() == data_be + assert nan_equals(array.tolist(), v.to_native()) + + +def _get_pyarrow_dtype(dtype: str) -> pyarrow.DataType: + return { + "i8": pa.int8(), + "i16": pa.int16(), + "i32": pa.int32(), + "i64": pa.int64(), + "f32": pa.float32(), + "f64": pa.float64(), + }[dtype] + + +def _get_pyarrow_array(data_be: bytes, dtype: str) -> pyarrow.Array: + type_size = _get_type_size(dtype) + length = len(data_be) // type_size + data_in = data_be + if sys.byteorder == "little": + data_in = _swap_endian(type_size, data_be) + pa_type = _get_pyarrow_dtype(dtype) + buffers = [None, pa.py_buffer(data_in)] + return pa.Array.from_buffers(pa_type, length, buffers, 0) + + +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", "little", "native")) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_from_pyarrow_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little", "native"], + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + for _ in range(repeat): + data_be = _random_value_be_bytes(type_size, size) + array = _get_pyarrow_array(data_be, dtype) + v = Vector.from_pyarrow(array) + assert 
v.dtype == dtype + assert v.raw() == data_be + assert nan_equals(array.to_pylist(), v.to_native()) + + +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +def test_from_pyarrow_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + value: object, + data_be: bytes, +) -> None: + array = _get_pyarrow_array(data_be, dtype) + v = Vector.from_pyarrow(array) + assert v.dtype == dtype + assert v.raw() == data_be + assert nan_equals(array.to_pylist(), v.to_native()) + + +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +@pytest.mark.parametrize("dtype", ("i8", "i16", "i32", "i64", "f32", "f64")) +@pytest.mark.parametrize("endian", ("big", "little", None)) +@pytest.mark.parametrize(("repeat", "size"), ((10_000, 1), (1, 10_000))) +def test_to_pyarrow_random( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + repeat: int, + size: int, +) -> None: + type_size = _get_type_size(dtype) + pa_type = _get_pyarrow_dtype(dtype) + for _ in range(repeat): + data_be = _random_value_be_bytes(type_size, size) + data_ne = data_be + if sys.byteorder == "little": + data_ne = _swap_endian(type_size, data_be) + v = _vector_from_data(data_be, dtype, endian) + array = v.to_pyarrow() + assert array.type == pa_type + assert pa.compute.count(array, mode="only_null").as_py() == 0 + buffers = array.buffers() + assert len(buffers) == 2 + assert buffers[0] is None + assert buffers[1].to_pybytes() == data_ne + assert nan_equals(array.tolist(), v.to_native()) + + +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +@pytest.mark.parametrize(("dtype", "value", "data_be"), SPECIAL_VALUES) +@pytest.mark.parametrize("endian", ("big", "little", None)) +def test_to_pyarrow_special_values( + dtype: t.Literal["i8", "i16", "i32", "i64", "f32", "f64"], + endian: t.Literal["big", "little"] | None, + value: object, + data_be: 
bytes, +) -> None: + type_size = _get_type_size(dtype) + data_ne = data_be + if sys.byteorder == "little": + data_ne = _swap_endian(type_size, data_be) + pa_type = _get_pyarrow_dtype(dtype) + v = _vector_from_data(data_be, dtype, endian) + array = v.to_pyarrow() + assert array.type == pa_type + assert pa.compute.count(array, mode="only_null").as_py() == 0 + buffers = array.buffers() + assert len(buffers) == 2 + assert buffers[0] is None + assert buffers[1].to_pybytes() == data_ne + assert nan_equals(array.tolist(), v.to_native()) diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 36fcc208..893aed0c 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -895,6 +895,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 6), "t_first"), ((5, 7), "t_first"), ((5, 8), "t_first"), + ((6, 0), "t_first"), ), ) def test_summary_result_available_after( @@ -933,6 +934,7 @@ def test_summary_result_available_after( ((5, 6), "t_last"), ((5, 7), "t_last"), ((5, 8), "t_last"), + ((6, 0), "t_last"), ), ) def test_summary_result_consumed_after( diff --git a/tests/unit/sync/io/test__bolt_socket.py b/tests/unit/sync/io/test__bolt_socket.py index fb1a293d..127c222b 100644 --- a/tests/unit/sync/io/test__bolt_socket.py +++ b/tests/unit/sync/io/test__bolt_socket.py @@ -42,7 +42,6 @@ def _deque_popleft_n(d: deque[_T], n: int) -> list[_T]: DEADLINE = Deadline(float("inf")) -# [bolt-version-bump] search tag when changing bolt version support @mark_sync_test @pytest.mark.parametrize("log_level", (1, logging.DEBUG, logging.CRITICAL)) def test_handshake(bolt_socket_factory, caplog, log_level): @@ -68,7 +67,7 @@ def test_handshake_manifest_v1( caplog, log_level, ): - chosen_version = (5, 8) + chosen_version = (6, 0) expected_feature_bits = b"\x00" # varint(0) caplog.set_level(log_level) diff --git a/tests/unit/sync/io/test_class_bolt.py 
b/tests/unit/sync/io/test_class_bolt.py index 6f5f1a66..e87eebb9 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -49,6 +49,7 @@ def test_class_method_protocol_handlers(): (3, 0), (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), + (6, 0), } # fmt: on @@ -81,7 +82,8 @@ def test_class_method_protocol_handlers(): ((5, 7), 1), ((5, 8), 1), ((5, 9), 0), - ((6, 0), 0), + ((6, 0), 1), + ((6, 1), 0), ], ) def test_class_method_protocol_handlers_with_protocol_version( @@ -91,7 +93,6 @@ def test_class_method_protocol_handlers_with_protocol_version( assert (test_input in protocol_handlers) == expected -# [bolt-version-bump] search tag when changing bolt version support def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( @@ -153,6 +154,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), ((5, 7), "neo4j._sync.io._bolt5.Bolt5x7"), ((5, 8), "neo4j._sync.io._bolt5.Bolt5x8"), + ((6, 0), "neo4j._sync.io._bolt6.Bolt6x0"), ), ) @mark_sync_test @@ -193,14 +195,15 @@ def test_version_negotiation( (4, 0), (4, 1), (5, 9), - (6, 0), + (6, 1), ), ) @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8', " + "'6.0')" ) address = ("localhost", 7687) diff --git a/tests/unit/sync/io/test_class_bolt6x0.py b/tests/unit/sync/io/test_class_bolt6x0.py new file mode 100644 index 00000000..eb8b3867 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt6x0.py @@ -0,0 +1,872 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt6 import Bolt6x0 +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt6x0( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt6x0( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt6x0( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": 
"something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) 
+@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, socket, PoolConfig.max_connection_lifetime + ) + 
connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt6x0( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, + socket, + PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.connection_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert 
fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt6x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + 
neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt6x0( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt6x0.UNPACKER_CLS) + connection = Bolt6x0( + address, + socket, + PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + 
notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt6x0( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt6x0( + 
address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt6x0( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt6x0(address, sockets.client, 0) + func = 
getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + _tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + connection = Bolt6x0(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert 
connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + connection = Bolt6x0(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_sync_test +def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + connection = Bolt6x0(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + connection.send_all() + with pytest.raises(Neo4jError): + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be 
helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. {i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt6x0.PACKER_CLS, + unpacker_cls=Bolt6x0.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt6x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is bool(ssr_hint)