From 54c32cec350662ef9c49093ae942859c55461955 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 9 Jan 2025 11:10:42 +0100 Subject: [PATCH 1/6] WiP: port unit tests to older bolt versions + unit tests for pool --- src/neo4j/_async/io/_bolt.py | 27 +- src/neo4j/_async/io/_bolt5.py | 47 ++ src/neo4j/_async/io/_pool.py | 84 +- src/neo4j/_sync/io/_bolt.py | 27 +- src/neo4j/_sync/io/_bolt5.py | 47 ++ src/neo4j/_sync/io/_pool.py | 84 +- tests/unit/async_/io/test_class_bolt.py | 12 +- tests/unit/async_/io/test_class_bolt5x8.py | 906 +++++++++++++++++++++ tests/unit/common/work/test_summary.py | 2 + tests/unit/sync/io/test_class_bolt.py | 12 +- tests/unit/sync/io/test_class_bolt5x8.py | 906 +++++++++++++++++++++ 11 files changed, 2116 insertions(+), 38 deletions(-) create mode 100644 tests/unit/async_/io/test_class_bolt5x8.py create mode 100644 tests/unit/sync/io/test_class_bolt5x8.py diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index b3e485d42..9acd37758 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -135,6 +135,8 @@ class AsyncBolt: # results for it. most_recent_qid = None + _address_callback = None + def __init__( self, unresolved_address, @@ -148,8 +150,9 @@ def __init__( notifications_min_severity=None, notifications_disabled_classifications=None, telemetry_disabled=False, + address_callback=None, ): - self.unresolved_address = unresolved_address + self._unresolved_address = unresolved_address self.socket = sock self.local_port = self.socket.getsockname()[1] self.server_info = ServerInfo( @@ -190,6 +193,7 @@ def __init__( self.auth_dict = self._to_auth_dict(auth) self.auth_manager = auth_manager self.telemetry_disabled = telemetry_disabled + self._address_callback = address_callback self.notifications_min_severity = notifications_min_severity self.notifications_disabled_classifications = ( @@ -200,6 +204,15 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @property + def unresolved_address(self): + return self._unresolved_address + + @unresolved_address.setter + def unresolved_address(self, value): + self._unresolved_address = value + self.server_info._address = value + @abc.abstractmethod def _get_server_state_manager(self) -> ServerStateManagerBase: ... @@ -308,6 +321,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x5, AsyncBolt5x6, AsyncBolt5x7, + AsyncBolt5x8, ) handlers = { @@ -325,6 +339,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7, + AsyncBolt5x8.PROTOCOL_VERSION: AsyncBolt5x8, } if protocol_version is None: @@ -424,6 +439,7 @@ async def open( deadline=None, routing_context=None, pool_config=None, + address_callback=None, ): """ Open a new Bolt connection to a given server address. 
@@ -433,6 +449,7 @@ async def open( :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: + :param address_callback: :returns: connected AsyncBolt instance @@ -461,7 +478,10 @@ async def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 7): + if protocol_version == (5, 8): + from ._bolt5 import AsyncBolt5x8 + bolt_cls = AsyncBolt5x8 + elif protocol_version == (5, 7): from ._bolt5 import AsyncBolt5x7 bolt_cls = AsyncBolt5x7 elif protocol_version == (5, 6): @@ -542,7 +562,7 @@ async def open( raise connection = bolt_cls( - address, + address._unresolved, s, pool_config.max_connection_lifetime, auth=auth, @@ -552,6 +572,7 @@ async def open( notifications_min_severity=pool_config.notifications_min_severity, notifications_disabled_classifications=pool_config.notifications_disabled_classifications, telemetry_disabled=pool_config.telemetry_disabled, + address_callback=address_callback, ) try: diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 06336193f..1cfbbdf9a 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -24,6 +24,7 @@ from ..._codec.hydration import v2 as hydration_v2 from ..._exceptions import BoltProtocolError from ..._meta import BOLT_AGENT_DICT +from ...addressing import Address from ...api import ( READ_ACCESS, Version, @@ -1225,3 +1226,49 @@ async def _process_message(self, tag, fields): ) return len(details), 1 + + +class AsyncBolt5x8(AsyncBolt5x7): + PROTOCOL_VERSION = Version(5, 8) + + def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + logged_auth_dict = dict(self.auth_dict) + if "credentials" in logged_auth_dict: + logged_auth_dict["credentials"] = "*******" + log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) + self._append( + b"\x6a", + (self.auth_dict,), + response=LogonResponse( + self, "logon", hydration_hooks, on_success=self._logon_success + ), + dehydration_hooks=dehydration_hooks, + ) + + async def _logon_success(self, meta: object) -> None: + if not isinstance(meta, dict): + log.warning( + "[#%04X] _: " + "LOGON expected dictionary metadata, got %r", + self.local_port, + meta, + ) + return + address = meta.get("advertised_address", ...) + if address is ...: + return + if not isinstance(address, str): + log.warning( + "[#%04X] _: " + "LOGON expected string advertised_address, got %r", + self.local_port, + address, + ) + return + address = Address.parse(address, default_port=7687) + if address != self.unresolved_address: + await AsyncUtil.callback(self._address_callback, self, address) + self.unresolved_address = address diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 7a520abe7..047732096 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -133,10 +133,12 @@ def _remove_connection(self, connection): # connection isn't in the pool anymore. 
with suppress(ValueError): self.connections.get(address, []).remove(connection) + self._log_pool_stats() async def _acquire_from_pool_checked( self, address, health_check, deadline ): + address = address._unresolved while not deadline.expired(): connection = await self._acquire_from_pool(address) if not connection: @@ -165,15 +167,17 @@ async def _acquire_from_pool_checked( return None def _acquire_new_later(self, address, auth, deadline): + unresolved_address = address._unresolved + async def connection_creator(): released_reservation = False try: try: connection = await self.opener( - address, auth or self.pool_config.auth, deadline + self, address, auth or self.pool_config.auth, deadline ) except ServiceUnavailable: - await self.deactivate(address) + await self.deactivate(unresolved_address) raise if auth: # It's unfortunate that we have to create a connection @@ -191,25 +195,31 @@ async def connection_creator(): connection.pool = self connection.in_use = True with self.lock: - self.connections_reservations[address] -= 1 + self.connections_reservations[unresolved_address] -= 1 released_reservation = True - self.connections[address].append(connection) + self.connections[connection.unresolved_address].append( + connection + ) + self._log_pool_stats() return connection finally: if not released_reservation: with self.lock: - self.connections_reservations[address] -= 1 + self.connections_reservations[unresolved_address] -= 1 + self._log_pool_stats() max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = max_pool_size < 0 or max_pool_size == float("inf") with self.lock: - connections = self.connections[address] + connections = self.connections[unresolved_address] pool_size = ( - len(connections) + self.connections_reservations[address] + len(connections) + + self.connections_reservations[unresolved_address] ) if infinite_pool_size or pool_size < max_pool_size: # there's room for a new connection - self.connections_reservations[address] += 1 + self.connections_reservations[unresolved_address] += 1 + self._log_pool_stats() return connection_creator return None @@ -497,6 +507,7 @@ async def deactivate(self, address): connections.remove(conn) if not self.connections[address]: del self.connections[address] + self._log_pool_stats() await self._close_connections(closable_connections) @@ -540,10 +551,29 @@ async def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] + self._log_pool_stats() await self._close_connections(connections) except TypeError: pass + def _log_pool_stats(self): + if log.isEnabledFor(5): + with self.lock: + addresses = sorted( + set(self.connections.keys()) + | set(self.connections_reservations.keys()) + ) + stats = { + address: { + "connections": len(self.connections.get(address, ())), + "reservations": self.connections_reservations.get( + address, 0 + ), + } + for address in addresses + } + log.log(5, "[#0000] _: stats %r", stats) + class AsyncBoltPool(AsyncIOPool): is_direct_pool = True @@ -559,7 +589,7 @@ def open(cls, address, *, pool_config, workspace_config): :returns: BoltPool """ - async def opener(addr, auth_manager, deadline): + async def opener(pool_, addr, auth_manager, deadline): return await AsyncBolt.open( addr, auth_manager=auth_manager, @@ -629,13 +659,14 @@ def open( ) routing_context["address"] = str(address) - async def opener(addr, auth_manager, deadline): + async def opener(pool_, addr, auth_manager, deadline): return await AsyncBolt.open( addr, auth_manager=auth_manager, 
deadline=deadline, routing_context=routing_context, pool_config=pool_config, + address_callback=pool_._move_connection, ) pool = cls(opener, pool_config, workspace_config, address) @@ -855,6 +886,8 @@ async def _update_routing_table_from( ) if callable(database_callback): database_callback(new_database) + + await self.update_connection_pool(database=new_database) return True await self.deactivate(router) return False @@ -943,6 +976,9 @@ async def update_routing_table( raise ServiceUnavailable("Unable to retrieve routing information") async def update_connection_pool(self, *, database): + log.debug( + "[#0000] _: update connection pool, database=%r", database + ) async with self.refresh_lock: routing_tables = [await self.get_or_create_routing_table(database)] for db in self.routing_tables: @@ -952,6 +988,11 @@ async def update_connection_pool(self, *, database): servers = set.union(*(rt.servers() for rt in routing_tables)) for address in list(self.connections): if address._unresolved not in servers: + log.debug( + "[#0000] _: deactivating address (not used in any " + "routing table): %r", + address, + ) await super().deactivate(address) async def ensure_routing_table_is_fresh( @@ -1013,7 +1054,6 @@ async def ensure_routing_table_is_fresh( acquisition_timeout=acquisition_timeout, database_callback=database_callback, ) - await self.update_connection_pool(database=database) return True @@ -1149,3 +1189,25 @@ async def on_write_failure(self, address, database): if table is not None: table.writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) + + async def _move_connection(self, connection, address): + log.debug( + "[#%04X] _: moving connection from %r to %r", + connection.local_port, + connection.unresolved_address, + address, + ) + with self.lock: + old_pool = self.connections[connection.unresolved_address] + new_pool = self.connections[address] + try: + old_pool.remove(connection) + except ValueError: + log.debug( + "[#%04X] _: abort move (connection not in pool)", + connection.local_port, + ) + return + new_pool.append(connection) + self._log_pool_stats() + self.cond.notify_all() diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 2ad1790a4..4146aa84c 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -135,6 +135,8 @@ class Bolt: # results for it. most_recent_qid = None + _address_callback = None + def __init__( self, unresolved_address, @@ -148,8 +150,9 @@ def __init__( notifications_min_severity=None, notifications_disabled_classifications=None, telemetry_disabled=False, + address_callback=None, ): - self.unresolved_address = unresolved_address + self._unresolved_address = unresolved_address self.socket = sock self.local_port = self.socket.getsockname()[1] self.server_info = ServerInfo( @@ -190,6 +193,7 @@ def __init__( self.auth_dict = self._to_auth_dict(auth) self.auth_manager = auth_manager self.telemetry_disabled = telemetry_disabled + self._address_callback = address_callback self.notifications_min_severity = notifications_min_severity self.notifications_disabled_classifications = ( @@ -200,6 +204,15 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @property + def unresolved_address(self): + return self._unresolved_address + + @unresolved_address.setter + def unresolved_address(self, value): + self._unresolved_address = value + self.server_info._address = value + @abc.abstractmethod def _get_server_state_manager(self) -> ServerStateManagerBase: ... 
@@ -308,6 +321,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x5, Bolt5x6, Bolt5x7, + Bolt5x8, ) handlers = { @@ -325,6 +339,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x5.PROTOCOL_VERSION: Bolt5x5, Bolt5x6.PROTOCOL_VERSION: Bolt5x6, Bolt5x7.PROTOCOL_VERSION: Bolt5x7, + Bolt5x8.PROTOCOL_VERSION: Bolt5x8, } if protocol_version is None: @@ -424,6 +439,7 @@ def open( deadline=None, routing_context=None, pool_config=None, + address_callback=None, ): """ Open a new Bolt connection to a given server address. @@ -433,6 +449,7 @@ def open( :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: + :param address_callback: :returns: connected Bolt instance @@ -461,7 +478,10 @@ def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 7): + if protocol_version == (5, 8): + from ._bolt5 import Bolt5x8 + bolt_cls = Bolt5x8 + elif protocol_version == (5, 7): from ._bolt5 import Bolt5x7 bolt_cls = Bolt5x7 elif protocol_version == (5, 6): @@ -542,7 +562,7 @@ def open( raise connection = bolt_cls( - address, + address._unresolved, s, pool_config.max_connection_lifetime, auth=auth, @@ -552,6 +572,7 @@ def open( notifications_min_severity=pool_config.notifications_min_severity, notifications_disabled_classifications=pool_config.notifications_disabled_classifications, telemetry_disabled=pool_config.telemetry_disabled, + address_callback=address_callback, ) try: diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 4138a9d5d..9d3296582 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -24,6 +24,7 @@ from ..._codec.hydration import v2 as hydration_v2 from ..._exceptions import BoltProtocolError from ..._meta import BOLT_AGENT_DICT +from ...addressing import Address from ...api import ( READ_ACCESS, Version, @@ -1225,3 +1226,49 @@ def _process_message(self, tag, fields): ) return len(details), 1 + + +class Bolt5x8(Bolt5x7): + PROTOCOL_VERSION = Version(5, 8) + + def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + logged_auth_dict = dict(self.auth_dict) + if "credentials" in logged_auth_dict: + logged_auth_dict["credentials"] = "*******" + log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) + self._append( + b"\x6a", + (self.auth_dict,), + response=LogonResponse( + self, "logon", hydration_hooks, on_success=self._logon_success + ), + dehydration_hooks=dehydration_hooks, + ) + + def _logon_success(self, meta: object) -> None: + if not isinstance(meta, dict): + log.warning( + "[#%04X] _: " + "LOGON expected dictionary metadata, got %r", + self.local_port, + meta, + ) + return + address = meta.get("advertised_address", ...) 
+ if address is ...: + return + if not isinstance(address, str): + log.warning( + "[#%04X] _: " + "LOGON expected string advertised_address, got %r", + self.local_port, + address, + ) + return + address = Address.parse(address, default_port=7687) + if address != self.unresolved_address: + Util.callback(self._address_callback, self, address) + self.unresolved_address = address diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 1570e745c..f3a0a28b2 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -130,10 +130,12 @@ def _remove_connection(self, connection): # connection isn't in the pool anymore. with suppress(ValueError): self.connections.get(address, []).remove(connection) + self._log_pool_stats() def _acquire_from_pool_checked( self, address, health_check, deadline ): + address = address._unresolved while not deadline.expired(): connection = self._acquire_from_pool(address) if not connection: @@ -162,15 +164,17 @@ def _acquire_from_pool_checked( return None def _acquire_new_later(self, address, auth, deadline): + unresolved_address = address._unresolved + def connection_creator(): released_reservation = False try: try: connection = self.opener( - address, auth or self.pool_config.auth, deadline + self, address, auth or self.pool_config.auth, deadline ) except ServiceUnavailable: - self.deactivate(address) + self.deactivate(unresolved_address) raise if auth: # It's unfortunate that we have to create a connection @@ -188,25 +192,31 @@ def connection_creator(): connection.pool = self connection.in_use = True with self.lock: - self.connections_reservations[address] -= 1 + self.connections_reservations[unresolved_address] -= 1 released_reservation = True - self.connections[address].append(connection) + self.connections[connection.unresolved_address].append( + connection + ) + self._log_pool_stats() return connection finally: if not released_reservation: with self.lock: - self.connections_reservations[address] -= 1 + self.connections_reservations[unresolved_address] -= 1 + self._log_pool_stats() max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = max_pool_size < 0 or max_pool_size == float("inf") with self.lock: - connections = self.connections[address] + connections = self.connections[unresolved_address] pool_size = ( - len(connections) + self.connections_reservations[address] + len(connections) + + self.connections_reservations[unresolved_address] ) if infinite_pool_size or pool_size < max_pool_size: # there's room for a new connection - self.connections_reservations[address] += 1 + self.connections_reservations[unresolved_address] += 1 + self._log_pool_stats() return connection_creator return None @@ -494,6 +504,7 @@ def deactivate(self, address): connections.remove(conn) if not self.connections[address]: del self.connections[address] + self._log_pool_stats() self._close_connections(closable_connections) @@ -537,10 +548,29 @@ def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] + self._log_pool_stats() self._close_connections(connections) except TypeError: pass + def _log_pool_stats(self): + if log.isEnabledFor(5): + with self.lock: + addresses = sorted( + set(self.connections.keys()) + | set(self.connections_reservations.keys()) + ) + stats = { + address: { + "connections": len(self.connections.get(address, ())), + "reservations": self.connections_reservations.get( + address, 0 + ), + } + for address in addresses + } + log.log(5, "[#0000] _: stats %r", 
stats) + class BoltPool(IOPool): is_direct_pool = True @@ -556,7 +586,7 @@ def open(cls, address, *, pool_config, workspace_config): :returns: BoltPool """ - def opener(addr, auth_manager, deadline): + def opener(pool_, addr, auth_manager, deadline): return Bolt.open( addr, auth_manager=auth_manager, @@ -626,13 +656,14 @@ def open( ) routing_context["address"] = str(address) - def opener(addr, auth_manager, deadline): + def opener(pool_, addr, auth_manager, deadline): return Bolt.open( addr, auth_manager=auth_manager, deadline=deadline, routing_context=routing_context, pool_config=pool_config, + address_callback=pool_._move_connection, ) pool = cls(opener, pool_config, workspace_config, address) @@ -852,6 +883,8 @@ def _update_routing_table_from( ) if callable(database_callback): database_callback(new_database) + + self.update_connection_pool(database=new_database) return True self.deactivate(router) return False @@ -940,6 +973,9 @@ def update_routing_table( raise ServiceUnavailable("Unable to retrieve routing information") def update_connection_pool(self, *, database): + log.debug( + "[#0000] _: update connection pool, database=%r", database + ) with self.refresh_lock: routing_tables = [self.get_or_create_routing_table(database)] for db in self.routing_tables: @@ -949,6 +985,11 @@ def update_connection_pool(self, *, database): servers = set.union(*(rt.servers() for rt in routing_tables)) for address in list(self.connections): if address._unresolved not in servers: + log.debug( + "[#0000] _: deactivating address (not used in any " + "routing table): %r", + address, + ) super().deactivate(address) def ensure_routing_table_is_fresh( @@ -1010,7 +1051,6 @@ def ensure_routing_table_is_fresh( acquisition_timeout=acquisition_timeout, database_callback=database_callback, ) - self.update_connection_pool(database=database) return True @@ -1146,3 +1186,25 @@ def on_write_failure(self, address, database): if table is not None: table.writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) + + def _move_connection(self, connection, address): + log.debug( + "[#%04X] _: moving connection from %r to %r", + connection.local_port, + connection.unresolved_address, + address, + ) + with self.lock: + old_pool = self.connections[connection.unresolved_address] + new_pool = self.connections[address] + try: + old_pool.remove(connection) + except ValueError: + log.debug( + "[#%04X] _: abort move (connection not in pool)", + connection.local_port, + ) + return + new_pool.append(connection) + self._log_pool_stats() + self.cond.notify_all() diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index b0ddbc968..5fbb49abc 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -69,7 +69,8 @@ def test_class_method_protocol_handlers(): ((5, 5), 1), ((5, 6), 1), ((5, 7), 1), - ((5, 8), 0), + ((5, 8), 1), + ((5, 9), 0), ((6, 0), 0), ], ) @@ -92,7 +93,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -143,6 +144,7 @@ async def 
test_cancel_hello_in_open(mocker, none_auth): ((5, 5), "neo4j._async.io._bolt5.AsyncBolt5x5"), ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), ((5, 7), "neo4j._async.io._bolt5.AsyncBolt5x7"), + ((5, 8), "neo4j._async.io._bolt5.AsyncBolt5x8"), ), ) @mark_async_test @@ -181,7 +183,7 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 8), + (5, 9), (6, 0), ), ) @@ -189,7 +191,7 @@ async def test_version_negotiation( async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) address = ("localhost", 7687) diff --git a/tests/unit/async_/io/test_class_bolt5x8.py b/tests/unit/async_/io/test_class_bolt5x8.py new file mode 100644 index 000000000..a07e58c18 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x8.py @@ -0,0 +1,906 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt5 import AsyncBolt5x8 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = 
fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_async_test 
+async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, 
sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", 
"DEPRECATION"]) +) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8(address, sockets.client, 0) + func = 
getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + _tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_async_test +async def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + await sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + await connection.send_all() + with pytest.raises(Neo4jError): + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. 
{i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata + + +@pytest.mark.parametrize( + ("advertised_address", "expected_call"), + ( + (..., None), + (None, Warning), + (1.2, Warning), + ("example.com", neo4j.Address(("example.com", 7687))), + ("example.com:1234", neo4j.Address(("example.com", 1234))), + ), +) +@mark_async_test +async def test_address_callback( + advertised_address, expected_call, fake_socket_pair, caplog +): + cb_calls = [] + + async def cb(connection_, address_): + assert connection_ is connection + assert connection.unresolved_address == address + cb_calls.append(address_) + + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + success_meta = {} + if advertised_address is not ...: + success_meta["advertised_address"] = advertised_address + await sockets.server.send_message(b"\x70", success_meta) + + connection = AsyncBolt5x8(address, sockets.client, 0, address_callback=cb) + + connection.logon() + await connection.send_all() + + if type(expected_call) is type and issubclass(expected_call, Warning): + with caplog.at_level(logging.WARNING): + await connection.fetch_all() + warning_logs = [rec.message for rec in caplog.records] + assert len(warning_logs) == 1 + assert "NON-FATAL PROTOCOL VIOLATION" in warning_logs[0] + assert not cb_calls + return + + await connection.fetch_all() + + if expected_call is None: + assert not cb_calls + return + + assert cb_calls == [expected_call] + assert connection.unresolved_address == expected_call diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 74f2059a3..46c85c539 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -890,6 +890,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 5), "t_first"), ((5, 6), "t_first"), ((5, 7), "t_first"), + ((5, 8), "t_first"), ), ) def test_summary_result_available_after( @@ -927,6 +928,7 @@ def test_summary_result_available_after( ((5, 5), "t_last"), ((5, 6), "t_last"), ((5, 7), "t_last"), + ((5, 8), "t_last"), ), ) def test_summary_result_consumed_after( diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index f3b063037..7c2fc7f0b 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -69,7 +69,8 @@ def test_class_method_protocol_handlers(): ((5, 5), 1), ((5, 6), 1), ((5, 7), 1), - ((5, 8), 0), + ((5, 8), 1), + ((5, 9), 0), ((6, 0), 0), ], ) @@ -92,7 +93,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -143,6 +144,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 5), "neo4j._sync.io._bolt5.Bolt5x5"), ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), ((5, 7), "neo4j._sync.io._bolt5.Bolt5x7"), + ((5, 8), "neo4j._sync.io._bolt5.Bolt5x8"), ), ) @mark_sync_test @@ -181,7 +183,7 @@ def 
test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 8), + (5, 9), (6, 0), ), ) @@ -189,7 +191,7 @@ def test_version_negotiation( def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) address = ("localhost", 7687) diff --git a/tests/unit/sync/io/test_class_bolt5x8.py b/tests/unit/sync/io/test_class_bolt5x8.py new file mode 100644 index 000000000..76cdd7a49 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x8.py @@ -0,0 +1,906 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x8 +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", 
"expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + 
socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, + socket, + PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", 
+ ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, + socket, + PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + 
expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + _tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + 
sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_sync_test +def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + connection.send_all() + with pytest.raises(Neo4jError): + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. 
{i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata + + +@pytest.mark.parametrize( + ("advertised_address", "expected_call"), + ( + (..., None), + (None, Warning), + (1.2, Warning), + ("example.com", neo4j.Address(("example.com", 7687))), + ("example.com:1234", neo4j.Address(("example.com", 1234))), + ), +) +@mark_sync_test +def test_address_callback( + advertised_address, expected_call, fake_socket_pair, caplog +): + cb_calls = [] + + def cb(connection_, address_): + assert connection_ is connection + assert connection.unresolved_address == address + cb_calls.append(address_) + + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + success_meta = {} + if advertised_address is not ...: + success_meta["advertised_address"] = advertised_address + sockets.server.send_message(b"\x70", success_meta) + + connection = Bolt5x8(address, sockets.client, 0, address_callback=cb) + + connection.logon() + connection.send_all() + + if type(expected_call) is type and issubclass(expected_call, Warning): + with caplog.at_level(logging.WARNING): + connection.fetch_all() + warning_logs = [rec.message for rec in caplog.records] + assert len(warning_logs) == 1 + assert "NON-FATAL PROTOCOL VIOLATION" in warning_logs[0] + assert not cb_calls + return + + connection.fetch_all() + + if expected_call is None: + assert not cb_calls + return + + assert cb_calls == [expected_call] + assert connection.unresolved_address == expected_call From 81e5edbaa7c69f668d23991e2776f80ea8e6d72f Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 14 Jan 2025 11:00:20 +0100 Subject: [PATCH 2/6] TestKit: introduce feature flag for DNS resolver hook --- testkitbackend/test_config.json | 1 + 1 file changed, 1 insertion(+) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index bca7f0caa..3e0ef6381 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -78,6 +78,7 @@ "ConfHint:connection.recv_timeout_seconds": true, + "Backend:DNSResolver": true, "Backend:MockTime": true, "Backend:RTFetch": true, "Backend:RTForceUpdate": true From 268f38551d7da937aa63e73260a4fb4bcbb97978 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 14 Jan 2025 12:03:23 +0100 Subject: [PATCH 3/6] TestKit backend: enable bolt 5.8 feature flag --- testkitbackend/test_config.json | 1 + 1 file changed, 1 insertion(+) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 3e0ef6381..d236ff11e 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -59,6 +59,7 @@ "Feature:Bolt:5.5": true, "Feature:Bolt:5.6": true, "Feature:Bolt:5.7": true, + "Feature:Bolt:5.8": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", From 8db92d4a1aea7825e09d2beb93c622f02e066aba Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 14 Jan 2025 15:59:52 +0100 Subject: [PATCH 4/6] Fix direct driver should ignore advertised address --- src/neo4j/_async/io/_bolt.py | 43 ++++++------- src/neo4j/_async/io/_bolt3.py | 6 +- src/neo4j/_async/io/_bolt4.py | 6 +- src/neo4j/_async/io/_bolt5.py | 18 +++--- src/neo4j/_async/io/_pool.py | 64 +++++++++++-------- src/neo4j/_async/work/result.py | 2 +- src/neo4j/_sync/io/_bolt.py | 43 ++++++------- src/neo4j/_sync/io/_bolt3.py | 6 +- 
src/neo4j/_sync/io/_bolt4.py | 6 +- src/neo4j/_sync/io/_bolt5.py | 18 +++--- src/neo4j/_sync/io/_pool.py | 64 +++++++++++-------- src/neo4j/_sync/work/result.py | 2 +- tests/unit/async_/fixtures/fake_connection.py | 3 +- tests/unit/async_/io/test_class_bolt5x8.py | 11 ++-- tests/unit/async_/io/test_neo4j_pool.py | 60 ++++++++--------- tests/unit/async_/test_conf.py | 3 - tests/unit/async_/work/test_result.py | 2 +- tests/unit/common/test_conf.py | 3 - tests/unit/sync/fixtures/fake_connection.py | 3 +- tests/unit/sync/io/test_class_bolt5x8.py | 11 ++-- tests/unit/sync/io/test_neo4j_pool.py | 60 ++++++++--------- tests/unit/sync/test_conf.py | 3 - tests/unit/sync/work/test_result.py | 2 +- 23 files changed, 222 insertions(+), 217 deletions(-) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 9acd37758..d5a9f0114 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -61,6 +61,7 @@ if t.TYPE_CHECKING: from ..._api import TelemetryAPI + from ...addressing import Address # Set up logger @@ -135,11 +136,12 @@ class AsyncBolt: # results for it. most_recent_qid = None - _address_callback = None + address_callback = None + advertised_address: Address | None = None def __init__( self, - unresolved_address, + address, sock, max_connection_lifetime, *, @@ -150,14 +152,13 @@ def __init__( notifications_min_severity=None, notifications_disabled_classifications=None, telemetry_disabled=False, - address_callback=None, ): - self._unresolved_address = unresolved_address + self._address = address self.socket = sock self.local_port = self.socket.getsockname()[1] self.server_info = ServerInfo( ResolvedAddress( - sock.getpeername(), host_name=unresolved_address.host + sock.getpeername(), host_name=address._unresolved.host ), self.PROTOCOL_VERSION, ) @@ -193,7 +194,6 @@ def __init__( self.auth_dict = self._to_auth_dict(auth) self.auth_manager = auth_manager self.telemetry_disabled = telemetry_disabled - self._address_callback = address_callback self.notifications_min_severity = notifications_min_severity self.notifications_disabled_classifications = ( @@ -205,13 +205,13 @@ def __del__(self): self.close() @property - def unresolved_address(self): - return self._unresolved_address + def address(self): + return self._address - @unresolved_address.setter - def unresolved_address(self, value): - self._unresolved_address = value - self.server_info._address = value + @address.setter + def address(self, value): + self._address = value + self.server_info._address = value._unresolved @abc.abstractmethod def _get_server_state_manager(self) -> ServerStateManagerBase: ... @@ -439,7 +439,6 @@ async def open( deadline=None, routing_context=None, pool_config=None, - address_callback=None, ): """ Open a new Bolt connection to a given server address. 
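
Note on this hunk: address_callback is removed from AsyncBolt.open() because, after this patch, the pool wires the callback up itself. The pool's acquire path hands every freshly opened connection to _on_new_connection(); the routing pool overrides that hook (see the _pool.py hunks below) to rebucket the connection and to register _move_connection as the Bolt 5.8 address_callback, while the direct AsyncBoltPool keeps the default no-op, so a direct driver ignores any advertised address. The standalone sketch below illustrates that handshake; _FakeConnection and the simplified _move_connection are illustrative stand-ins only, not code from this patch.

    import asyncio


    class _FakeConnection:
        # stand-in for AsyncBolt5x8: just the attributes this patch adds
        advertised_address = None   # set by _logon_success()
        address_callback = None     # set by the routing pool
        address = ("localhost", 7687)


    async def _move_connection(connection):
        # routing-pool behaviour, minus locking and pool bookkeeping:
        # ignore connections without an advertised address, else rebucket
        if connection.advertised_address is None:
            return
        connection.address = connection.advertised_address


    async def main():
        cx = _FakeConnection()
        # pool side, _on_new_connection(): catch up, then subscribe
        await _move_connection(cx)
        cx.address_callback = _move_connection
        # connection side: LOGON SUCCESS carrying "advertised_address"
        cx.advertised_address = ("example.com", 7687)
        await cx.address_callback(cx)
        assert cx.address == ("example.com", 7687)


    asyncio.run(main())

In the real pool, _move_connection additionally updates the per-address connection lists and connection.address under the pool lock.
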
@@ -449,7 +448,6 @@ async def open( :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: - :param address_callback: :returns: connected AsyncBolt instance @@ -562,7 +560,7 @@ async def open( raise connection = bolt_cls( - address._unresolved, + address, s, pool_config.max_connection_lifetime, auth=auth, @@ -572,7 +570,6 @@ async def open( notifications_min_severity=pool_config.notifications_min_severity, notifications_disabled_classifications=pool_config.notifications_disabled_classifications, telemetry_disabled=pool_config.telemetry_disabled, - address_callback=address_callback, ) try: @@ -975,12 +972,12 @@ async def send_all(self): if self.closed(): raise ServiceUnavailable( "Failed to write to closed connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if self.defunct(): raise ServiceUnavailable( "Failed to write to defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) await self._send_all() @@ -998,12 +995,12 @@ async def fetch_message(self): if self._closed: raise ServiceUnavailable( "Failed to read from closed connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if self._defunct: raise ServiceUnavailable( "Failed to read from defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if not self.responses: return 0, 0 @@ -1035,14 +1032,14 @@ async def fetch_all(self): async def _set_defunct_read(self, error=None, silent=False): message = ( "Failed to read from defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) await self._set_defunct(message, error=error, silent=silent) async def _set_defunct_write(self, error=None, silent=False): message = ( "Failed to write data to connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) await self._set_defunct(message, error=error, silent=silent) @@ -1081,7 +1078,7 @@ async def _set_defunct(self, message, error=None, silent=False): # connection again. 
await self.close() if self.pool and not self._get_server_state_manager().failed(): - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) # Iterate through the outstanding responses, and if any correspond # to COMMIT requests then raise an error to signal that we are diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 08e75abb2..556380746 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -579,12 +579,12 @@ async def _process_message(self, tag, fields): await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -595,7 +595,7 @@ async def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 202d55707..70066b3e6 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -494,12 +494,12 @@ async def _process_message(self, tag, fields): await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -511,7 +511,7 @@ async def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 1cfbbdf9a..b1a36be85 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -497,12 +497,12 @@ async def _process_message(self, tag, fields): await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -514,7 +514,7 @@ async def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 @@ -1205,12 +1205,12 @@ async def _process_message(self, tag, fields): await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - 
address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -1222,7 +1222,7 @@ async def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 @@ -1268,7 +1268,5 @@ async def _logon_success(self, meta: object) -> None: address, ) return - address = Address.parse(address, default_port=7687) - if address != self.unresolved_address: - await AsyncUtil.callback(self._address_callback, self, address) - self.unresolved_address = address + self.advertised_address = Address.parse(address, default_port=7687) + await AsyncUtil.callback(self.address_callback, self) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 047732096..bd695e6dd 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -119,7 +119,7 @@ async def _acquire_from_pool(self, address): return None # no free connection available def _remove_connection(self, connection): - address = connection.unresolved_address + address = connection.address with self.lock: log.debug( "[#%04X] _: remove connection from pool %r %s", @@ -138,7 +138,6 @@ def _remove_connection(self, connection): async def _acquire_from_pool_checked( self, address, health_check, deadline ): - address = address._unresolved while not deadline.expired(): connection = await self._acquire_from_pool(address) if not connection: @@ -167,17 +166,15 @@ async def _acquire_from_pool_checked( return None def _acquire_new_later(self, address, auth, deadline): - unresolved_address = address._unresolved - async def connection_creator(): released_reservation = False try: try: connection = await self.opener( - self, address, auth or self.pool_config.auth, deadline + address, auth or self.pool_config.auth, deadline ) except ServiceUnavailable: - await self.deactivate(unresolved_address) + await self.deactivate(address) raise if auth: # It's unfortunate that we have to create a connection @@ -195,30 +192,27 @@ async def connection_creator(): connection.pool = self connection.in_use = True with self.lock: - self.connections_reservations[unresolved_address] -= 1 + self.connections_reservations[address] -= 1 released_reservation = True - self.connections[connection.unresolved_address].append( - connection - ) + self.connections[address].append(connection) self._log_pool_stats() return connection finally: if not released_reservation: with self.lock: - self.connections_reservations[unresolved_address] -= 1 + self.connections_reservations[address] -= 1 self._log_pool_stats() max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = max_pool_size < 0 or max_pool_size == float("inf") with self.lock: - connections = self.connections[unresolved_address] + connections = self.connections[address] pool_size = ( - len(connections) - + self.connections_reservations[unresolved_address] + len(connections) + self.connections_reservations[address] ) if infinite_pool_size or pool_size < max_pool_size: # there's room for a new connection - self.connections_reservations[unresolved_address] += 1 + self.connections_reservations[address] += 1 self._log_pool_stats() return connection_creator return None @@ -357,7 +351,12 @@ async def health_check(connection_, deadline_): f"{deadline.original_timeout!r}s (timeout)" ) log.debug("[#0000] _: trying to hand out new connection") - return await connection_creator() + connection = await 
connection_creator() + await self._on_new_connection(connection) + return connection + + async def _on_new_connection(self, connection): + return @abc.abstractmethod async def acquire( @@ -519,7 +518,7 @@ async def on_write_failure(self, address, database): async def on_neo4j_error(self, error, connection): assert isinstance(error, Neo4jError) if error._unauthenticates_all_connections(): - address = connection.unresolved_address + address = connection.address log.debug( "[#0000] _: mark all connections to %r as " "unauthenticated", @@ -557,7 +556,8 @@ async def close(self): pass def _log_pool_stats(self): - if log.isEnabledFor(5): + level = logging.DEBUG + if log.isEnabledFor(level): with self.lock: addresses = sorted( set(self.connections.keys()) @@ -572,7 +572,7 @@ def _log_pool_stats(self): } for address in addresses } - log.log(5, "[#0000] _: stats %r", stats) + log.log(level, "[#0000] _: stats %r", stats) class AsyncBoltPool(AsyncIOPool): @@ -589,7 +589,7 @@ def open(cls, address, *, pool_config, workspace_config): :returns: BoltPool """ - async def opener(pool_, addr, auth_manager, deadline): + async def opener(addr, auth_manager, deadline): return await AsyncBolt.open( addr, auth_manager=auth_manager, @@ -659,14 +659,13 @@ def open( ) routing_context["address"] = str(address) - async def opener(pool_, addr, auth_manager, deadline): + async def opener(addr, auth_manager, deadline): return await AsyncBolt.open( addr, auth_manager=auth_manager, deadline=deadline, routing_context=routing_context, pool_config=pool_config, - address_callback=pool_._move_connection, ) pool = cls(opener, pool_config, workspace_config, address) @@ -1086,6 +1085,10 @@ async def _select_address(self, *, access_mode, database): ) return choice(addresses_by_usage[min(addresses_by_usage)]) + async def _on_new_connection(self, connection): + await self._move_connection(connection) + connection.address_callback = self._move_connection + async def acquire( self, access_mode, @@ -1190,16 +1193,22 @@ async def on_write_failure(self, address, database): table.writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) - async def _move_connection(self, connection, address): + async def _move_connection(self, connection): + to_addr = connection.advertised_address + if to_addr is None: + return + from_addr = connection.address + if from_addr == to_addr: + return log.debug( "[#%04X] _: moving connection from %r to %r", connection.local_port, - connection.unresolved_address, - address, + from_addr, + to_addr, ) with self.lock: - old_pool = self.connections[connection.unresolved_address] - new_pool = self.connections[address] + old_pool = self.connections[from_addr] + new_pool = self.connections[to_addr] try: old_pool.remove(connection) except ValueError: @@ -1209,5 +1218,6 @@ async def _move_connection(self, connection, address): ) return new_pool.append(connection) + connection.address = connection.advertised_address self._log_pool_stats() self.cond.notify_all() diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index ac62fa1f9..1a5b00fc3 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -123,7 +123,7 @@ def __init__( self._on_error = on_error self._on_closed = on_closed self._metadata: dict = {} - self._address: Address = self._connection.unresolved_address + self._address: Address = self._connection.address self._keys: tuple[str, ...] 
= () self._had_record = False self._record_buffer: deque[Record] = deque() diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 4146aa84c..49c28c8f7 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -61,6 +61,7 @@ if t.TYPE_CHECKING: from ..._api import TelemetryAPI + from ...addressing import Address # Set up logger @@ -135,11 +136,12 @@ class Bolt: # results for it. most_recent_qid = None - _address_callback = None + address_callback = None + advertised_address: Address | None = None def __init__( self, - unresolved_address, + address, sock, max_connection_lifetime, *, @@ -150,14 +152,13 @@ def __init__( notifications_min_severity=None, notifications_disabled_classifications=None, telemetry_disabled=False, - address_callback=None, ): - self._unresolved_address = unresolved_address + self._address = address self.socket = sock self.local_port = self.socket.getsockname()[1] self.server_info = ServerInfo( ResolvedAddress( - sock.getpeername(), host_name=unresolved_address.host + sock.getpeername(), host_name=address._unresolved.host ), self.PROTOCOL_VERSION, ) @@ -193,7 +194,6 @@ def __init__( self.auth_dict = self._to_auth_dict(auth) self.auth_manager = auth_manager self.telemetry_disabled = telemetry_disabled - self._address_callback = address_callback self.notifications_min_severity = notifications_min_severity self.notifications_disabled_classifications = ( @@ -205,13 +205,13 @@ def __del__(self): self.close() @property - def unresolved_address(self): - return self._unresolved_address + def address(self): + return self._address - @unresolved_address.setter - def unresolved_address(self, value): - self._unresolved_address = value - self.server_info._address = value + @address.setter + def address(self, value): + self._address = value + self.server_info._address = value._unresolved @abc.abstractmethod def _get_server_state_manager(self) -> ServerStateManagerBase: ... @@ -439,7 +439,6 @@ def open( deadline=None, routing_context=None, pool_config=None, - address_callback=None, ): """ Open a new Bolt connection to a given server address. 
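
The same refactoring applies on the sync side. For reference, the Address.parse step used by _logon_success fills in the Bolt default port when the advertised address does not carry one; the minimal check below mirrors the expectations encoded in test_address_callback and is an illustration, not part of the patch.

    from neo4j import Address

    # bare host: the default port 7687 is filled in
    assert (Address.parse("example.com", default_port=7687)
            == Address(("example.com", 7687)))
    # explicit port: taken over verbatim
    assert (Address.parse("example.com:1234", default_port=7687)
            == Address(("example.com", 1234)))
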
@@ -449,7 +448,6 @@ def open( :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: - :param address_callback: :returns: connected Bolt instance @@ -562,7 +560,7 @@ def open( raise connection = bolt_cls( - address._unresolved, + address, s, pool_config.max_connection_lifetime, auth=auth, @@ -572,7 +570,6 @@ def open( notifications_min_severity=pool_config.notifications_min_severity, notifications_disabled_classifications=pool_config.notifications_disabled_classifications, telemetry_disabled=pool_config.telemetry_disabled, - address_callback=address_callback, ) try: @@ -975,12 +972,12 @@ def send_all(self): if self.closed(): raise ServiceUnavailable( "Failed to write to closed connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if self.defunct(): raise ServiceUnavailable( "Failed to write to defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) self._send_all() @@ -998,12 +995,12 @@ def fetch_message(self): if self._closed: raise ServiceUnavailable( "Failed to read from closed connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if self._defunct: raise ServiceUnavailable( "Failed to read from defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if not self.responses: return 0, 0 @@ -1035,14 +1032,14 @@ def fetch_all(self): def _set_defunct_read(self, error=None, silent=False): message = ( "Failed to read from defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) self._set_defunct(message, error=error, silent=silent) def _set_defunct_write(self, error=None, silent=False): message = ( "Failed to write data to connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) self._set_defunct(message, error=error, silent=silent) @@ -1081,7 +1078,7 @@ def _set_defunct(self, message, error=None, silent=False): # connection again. 
self.close() if self.pool and not self._get_server_state_manager().failed(): - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) # Iterate through the outstanding responses, and if any correspond # to COMMIT requests then raise an error to signal that we are diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index e3cfd1429..9d67b6e2c 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -579,12 +579,12 @@ def _process_message(self, tag, fields): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -595,7 +595,7 @@ def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 69bb6dd6e..6dc9e7cb9 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -494,12 +494,12 @@ def _process_message(self, tag, fields): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -511,7 +511,7 @@ def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 9d3296582..d6a1a518f 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -497,12 +497,12 @@ def _process_message(self, tag, fields): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -514,7 +514,7 @@ def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 @@ -1205,12 +1205,12 @@ def _process_message(self, tag, fields): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -1222,7 +1222,7 @@ def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise 
BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 @@ -1268,7 +1268,5 @@ def _logon_success(self, meta: object) -> None: address, ) return - address = Address.parse(address, default_port=7687) - if address != self.unresolved_address: - Util.callback(self._address_callback, self, address) - self.unresolved_address = address + self.advertised_address = Address.parse(address, default_port=7687) + Util.callback(self.address_callback, self) diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index f3a0a28b2..9e851f2f2 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -116,7 +116,7 @@ def _acquire_from_pool(self, address): return None # no free connection available def _remove_connection(self, connection): - address = connection.unresolved_address + address = connection.address with self.lock: log.debug( "[#%04X] _: remove connection from pool %r %s", @@ -135,7 +135,6 @@ def _remove_connection(self, connection): def _acquire_from_pool_checked( self, address, health_check, deadline ): - address = address._unresolved while not deadline.expired(): connection = self._acquire_from_pool(address) if not connection: @@ -164,17 +163,15 @@ def _acquire_from_pool_checked( return None def _acquire_new_later(self, address, auth, deadline): - unresolved_address = address._unresolved - def connection_creator(): released_reservation = False try: try: connection = self.opener( - self, address, auth or self.pool_config.auth, deadline + address, auth or self.pool_config.auth, deadline ) except ServiceUnavailable: - self.deactivate(unresolved_address) + self.deactivate(address) raise if auth: # It's unfortunate that we have to create a connection @@ -192,30 +189,27 @@ def connection_creator(): connection.pool = self connection.in_use = True with self.lock: - self.connections_reservations[unresolved_address] -= 1 + self.connections_reservations[address] -= 1 released_reservation = True - self.connections[connection.unresolved_address].append( - connection - ) + self.connections[address].append(connection) self._log_pool_stats() return connection finally: if not released_reservation: with self.lock: - self.connections_reservations[unresolved_address] -= 1 + self.connections_reservations[address] -= 1 self._log_pool_stats() max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = max_pool_size < 0 or max_pool_size == float("inf") with self.lock: - connections = self.connections[unresolved_address] + connections = self.connections[address] pool_size = ( - len(connections) - + self.connections_reservations[unresolved_address] + len(connections) + self.connections_reservations[address] ) if infinite_pool_size or pool_size < max_pool_size: # there's room for a new connection - self.connections_reservations[unresolved_address] += 1 + self.connections_reservations[address] += 1 self._log_pool_stats() return connection_creator return None @@ -354,7 +348,12 @@ def health_check(connection_, deadline_): f"{deadline.original_timeout!r}s (timeout)" ) log.debug("[#0000] _: trying to hand out new connection") - return connection_creator() + connection = connection_creator() + self._on_new_connection(connection) + return connection + + def _on_new_connection(self, connection): + return @abc.abstractmethod def acquire( @@ -516,7 +515,7 @@ def on_write_failure(self, address, database): def on_neo4j_error(self, error, connection): assert isinstance(error, Neo4jError) if 
error._unauthenticates_all_connections(): - address = connection.unresolved_address + address = connection.address log.debug( "[#0000] _: mark all connections to %r as " "unauthenticated", @@ -554,7 +553,8 @@ def close(self): pass def _log_pool_stats(self): - if log.isEnabledFor(5): + level = logging.DEBUG + if log.isEnabledFor(level): with self.lock: addresses = sorted( set(self.connections.keys()) @@ -569,7 +569,7 @@ def _log_pool_stats(self): } for address in addresses } - log.log(5, "[#0000] _: stats %r", stats) + log.log(level, "[#0000] _: stats %r", stats) class BoltPool(IOPool): @@ -586,7 +586,7 @@ def open(cls, address, *, pool_config, workspace_config): :returns: BoltPool """ - def opener(pool_, addr, auth_manager, deadline): + def opener(addr, auth_manager, deadline): return Bolt.open( addr, auth_manager=auth_manager, @@ -656,14 +656,13 @@ def open( ) routing_context["address"] = str(address) - def opener(pool_, addr, auth_manager, deadline): + def opener(addr, auth_manager, deadline): return Bolt.open( addr, auth_manager=auth_manager, deadline=deadline, routing_context=routing_context, pool_config=pool_config, - address_callback=pool_._move_connection, ) pool = cls(opener, pool_config, workspace_config, address) @@ -1083,6 +1082,10 @@ def _select_address(self, *, access_mode, database): ) return choice(addresses_by_usage[min(addresses_by_usage)]) + def _on_new_connection(self, connection): + self._move_connection(connection) + connection.address_callback = self._move_connection + def acquire( self, access_mode, @@ -1187,16 +1190,22 @@ def on_write_failure(self, address, database): table.writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) - def _move_connection(self, connection, address): + def _move_connection(self, connection): + to_addr = connection.advertised_address + if to_addr is None: + return + from_addr = connection.address + if from_addr == to_addr: + return log.debug( "[#%04X] _: moving connection from %r to %r", connection.local_port, - connection.unresolved_address, - address, + from_addr, + to_addr, ) with self.lock: - old_pool = self.connections[connection.unresolved_address] - new_pool = self.connections[address] + old_pool = self.connections[from_addr] + new_pool = self.connections[to_addr] try: old_pool.remove(connection) except ValueError: @@ -1206,5 +1215,6 @@ def _move_connection(self, connection, address): ) return new_pool.append(connection) + connection.address = connection.advertised_address self._log_pool_stats() self.cond.notify_all() diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 27164cf84..cf4f1ce07 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -123,7 +123,7 @@ def __init__( self._on_error = on_error self._on_closed = on_closed self._metadata: dict = {} - self._address: Address = self._connection.unresolved_address + self._address: Address = self._connection.address self._keys: tuple[str, ...] 
= () self._had_record = False self._record_buffer: deque[Record] = deque() diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 9bf967791..65dba3b51 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -55,7 +55,8 @@ def __init__(self, *args, **kwargs): self.attach_mock( mock.AsyncMock(spec=AsyncAuthManager), "auth_manager" ) - self.unresolved_address = next(iter(args), "localhost") + self.address = next(iter(args), "localhost") + self.advertised_address = None self.callbacks = [] diff --git a/tests/unit/async_/io/test_class_bolt5x8.py b/tests/unit/async_/io/test_class_bolt5x8.py index a07e58c18..408e641aa 100644 --- a/tests/unit/async_/io/test_class_bolt5x8.py +++ b/tests/unit/async_/io/test_class_bolt5x8.py @@ -866,10 +866,10 @@ async def test_address_callback( ): cb_calls = [] - async def cb(connection_, address_): + async def cb(connection_): assert connection_ is connection - assert connection.unresolved_address == address - cb_calls.append(address_) + assert connection.address == address + cb_calls.append(connection_.advertised_address) address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair( @@ -882,7 +882,8 @@ async def cb(connection_, address_): success_meta["advertised_address"] = advertised_address await sockets.server.send_message(b"\x70", success_meta) - connection = AsyncBolt5x8(address, sockets.client, 0, address_callback=cb) + connection = AsyncBolt5x8(address, sockets.client, 0) + connection.address_callback = cb connection.logon() await connection.send_all() @@ -903,4 +904,4 @@ async def cb(connection_, address_): return assert cb_calls == [expected_call] - assert connection.unresolved_address == expected_call + assert connection.address == address diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index c0be16adf..d409a4958 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -84,7 +84,7 @@ def routing_side_effect(*args, **kwargs): async def open_(addr, auth, timeout): connection = async_fake_connection_generator() - connection.unresolved_address = addr + connection.address = addr connection.timeout = timeout connection.auth = auth route_mock = mocker.AsyncMock() @@ -188,9 +188,9 @@ async def test_chooses_right_connection_type(opener, type_): ) await pool.release(cx1) if type_ == "r": - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS else: - assert cx1.unresolved_address == WRITER1_ADDRESS + assert cx1.address == WRITER1_ADDRESS @mark_async_test @@ -206,7 +206,7 @@ async def test_reuses_connection(opener): @mark_async_test async def test_closes_stale_connections(opener, break_on_close): async def break_connection(): - await pool.deactivate(cx1.unresolved_address) + await pool.deactivate(cx1.address) if cx_close_mock_side_effect: res = cx_close_mock_side_effect() @@ -218,7 +218,7 @@ async def break_connection(): pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) - assert cx1 in pool.connections[cx1.unresolved_address] + assert cx1 in pool.connections[cx1.address] # simulate connection going stale (e.g. 
exceeding idle timeout) and then
     # breaking when the pool tries to close the connection
     cx1.stale.return_value = True
@@ -233,16 +233,16 @@ async def break_connection():
     else:
         cx1.close.assert_called_once()
     assert cx2 is not cx1
-    assert cx2.unresolved_address == cx1.unresolved_address
-    assert cx1 not in pool.connections[cx1.unresolved_address]
-    assert cx2 in pool.connections[cx2.unresolved_address]
+    assert cx2.address == cx1.address
+    assert cx1 not in pool.connections[cx1.address]
+    assert cx2 in pool.connections[cx2.address]
 
 
 @mark_async_test
 async def test_does_not_close_stale_connections_in_use(opener):
     pool = _simple_pool(opener)
     cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
-    assert cx1 in pool.connections[cx1.unresolved_address]
+    assert cx1 in pool.connections[cx1.address]
     # simulate connection going stale (e.g. exceeding idle timeout) while being
     # in use
     cx1.stale.return_value = True
@@ -250,9 +250,9 @@ async def test_does_not_close_stale_connections_in_use(opener):
     await pool.release(cx2)
     cx1.close.assert_not_called()
     assert cx2 is not cx1
-    assert cx2.unresolved_address == cx1.unresolved_address
-    assert cx1 in pool.connections[cx1.unresolved_address]
-    assert cx2 in pool.connections[cx2.unresolved_address]
+    assert cx2.address == cx1.address
+    assert cx1 in pool.connections[cx1.address]
+    assert cx2 in pool.connections[cx2.address]
 
     await pool.release(cx1)
     # now that cx1 is back in the pool and still stale,
@@ -263,9 +263,9 @@ async def test_does_not_close_stale_connections_in_use(opener):
     await pool.release(cx3)
     cx1.close.assert_called_once()
     assert cx2 is cx3
-    assert cx3.unresolved_address == cx1.unresolved_address
-    assert cx1 not in pool.connections[cx1.unresolved_address]
-    assert cx3 in pool.connections[cx2.unresolved_address]
+    assert cx3.address == cx1.address
+    assert cx1 not in pool.connections[cx1.address]
+    assert cx3 in pool.connections[cx2.address]
 
 
 @mark_async_test
@@ -314,7 +314,7 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection(
     cx1 = await pool._acquire(
         READER1_ADDRESS, None, Deadline(30), liveness_timeout
     )
-    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
     cx1.reset.assert_not_called()
 
 
@@ -330,7 +330,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
     )
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
 
@@ -367,7 +367,7 @@ def liveness_side_effect(*args, **kwargs):
     )
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
 
@@ -384,11 +384,11 @@ def liveness_side_effect(*args, **kwargs):
         READER1_ADDRESS, None, Deadline(30), liveness_timeout
     )
     assert cx1 is not cx2
-    assert cx1.unresolved_address == cx2.unresolved_address
+    assert cx1.address == cx2.address
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
     cx2.reset.assert_not_called()
-    assert cx1 not in pool.connections[cx1.unresolved_address]
-    assert cx2 in pool.connections[cx1.unresolved_address]
+    assert cx1 not in pool.connections[cx1.address]
+    assert cx2 in pool.connections[cx1.address]
 
 
 @pytest.mark.parametrize(
@@ -412,8 +412,8 @@ def liveness_side_effect(*args, **kwargs):
     )
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER1_ADDRESS
-    assert cx2.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
+    assert cx2.address == READER1_ADDRESS
     assert cx1 is not cx2
     cx1.is_idle_for.assert_not_called()
     cx2.is_idle_for.assert_not_called()
@@ -439,8 +439,8 @@ def liveness_side_effect(*args, **kwargs):
     cx1.reset.assert_awaited_once()
     cx3.is_idle_for.assert_called_once_with(liveness_timeout)
     cx3.reset.assert_awaited_once()
-    assert cx1 not in pool.connections[cx1.unresolved_address]
-    assert cx3 in pool.connections[cx1.unresolved_address]
+    assert cx1 not in pool.connections[cx1.address]
+    assert cx3 in pool.connections[cx1.address]
 
 
 @mark_async_test
@@ -701,7 +701,7 @@ def get_readers(database):
         opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS
     )
     cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
-    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
     await pool.release(cx1)
     cx1.close.assert_not_called()
 
@@ -712,7 +712,7 @@ def get_readers(database):
     readers["db1"] = [str(READER2_ADDRESS)]
 
     cx2 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
-    assert cx2.unresolved_address == READER2_ADDRESS
+    assert cx2.address == READER2_ADDRESS
     cx1.close.assert_awaited_once()
     assert len(pool.connections[READER1_ADDRESS]) == 0
 
@@ -740,14 +740,14 @@ def get_readers(database):
     )
     cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
    await pool.release(cx1)
-    assert cx1.unresolved_address in {READER1_ADDRESS, READER2_ADDRESS}
+    assert cx1.address in {READER1_ADDRESS, READER2_ADDRESS}
     reader1_connection_count = len(pool.connections[READER1_ADDRESS])
     reader2_connection_count = len(pool.connections[READER2_ADDRESS])
     assert reader1_connection_count + reader2_connection_count == 1
 
     cx2 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
     await pool.release(cx2)
-    assert cx2.unresolved_address == READER1_ADDRESS
+    assert cx2.address == READER1_ADDRESS
     cx1.close.assert_not_called()
     cx2.close.assert_not_called()
     assert len(pool.connections[READER1_ADDRESS]) == 1
@@ -759,7 +759,7 @@ def get_readers(database):
 
     cx3 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
     await pool.release(cx3)
-    assert cx3.unresolved_address == READER3_ADDRESS
+    assert cx3.address == READER3_ADDRESS
 
     cx1.close.assert_not_called()
     cx2.close.assert_not_called()
diff --git a/tests/unit/async_/test_conf.py b/tests/unit/async_/test_conf.py
index 1902a5d65..c9d20c6ec 100644
--- a/tests/unit/async_/test_conf.py
+++ b/tests/unit/async_/test_conf.py
@@ -36,7 +36,6 @@
     AsyncClientCertificateProviders,
     ClientCertificate,
 )
-from neo4j.debug import watch
 from neo4j.exceptions import ConfigurationError
 
 from ..._async_compat import mark_async_test
@@ -45,8 +44,6 @@
 # python -m pytest tests/unit/test_conf.py -s -v
 
 
-watch("neo4j")
-
 test_pool_config = {
     "connection_timeout": 30.0,
     "keep_alive": True,
diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py
index 1291d95ee..8f8248833 100644
--- a/tests/unit/async_/work/test_result.py
+++ b/tests/unit/async_/work/test_result.py
@@ -149,7 +149,7 @@ def __init__(
         self.run_meta = run_meta
         self.summary_meta = summary_meta
         AsyncConnectionStub.server_info.update({"server": "Neo4j/4.3.0"})
-        self.unresolved_address = None
+        self.address = None
         self._new_hydration_scope_called = False
 
     async def send_all(self):
diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py
index 93b2c2436..63baf227b 100644
--- a/tests/unit/common/test_conf.py
+++ b/tests/unit/common/test_conf.py
@@ -24,12 +24,9 @@
     READ_ACCESS,
     WRITE_ACCESS,
 )
-from neo4j.debug import watch
 from neo4j.exceptions import ConfigurationError
 
 
-watch("neo4j")
-
 test_session_config = {
     "connection_acquisition_timeout": 60.0,
     "max_transaction_retry_time": 30.0,
diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py
index 8785badb6..fcb3a3708 100644
--- a/tests/unit/sync/fixtures/fake_connection.py
+++ b/tests/unit/sync/fixtures/fake_connection.py
@@ -55,7 +55,8 @@ def __init__(self, *args, **kwargs):
         self.attach_mock(
             mock.MagicMock(spec=AuthManager), "auth_manager"
         )
-        self.unresolved_address = next(iter(args), "localhost")
+        self.address = next(iter(args), "localhost")
+        self.advertised_address = None
 
         self.callbacks = []
 
diff --git a/tests/unit/sync/io/test_class_bolt5x8.py b/tests/unit/sync/io/test_class_bolt5x8.py
index 76cdd7a49..09b83ab4a 100644
--- a/tests/unit/sync/io/test_class_bolt5x8.py
+++ b/tests/unit/sync/io/test_class_bolt5x8.py
@@ -866,10 +866,10 @@ def test_address_callback(
 ):
     cb_calls = []
 
-    def cb(connection_, address_):
+    def cb(connection_):
         assert connection_ is connection
-        assert connection.unresolved_address == address
-        cb_calls.append(address_)
+        assert connection.address == address
+        cb_calls.append(connection_.advertised_address)
 
     address = neo4j.Address(("127.0.0.1", 7687))
     sockets = fake_socket_pair(
@@ -882,7 +882,8 @@ def cb(connection_, address_):
     success_meta["advertised_address"] = advertised_address
     sockets.server.send_message(b"\x70", success_meta)
 
-    connection = Bolt5x8(address, sockets.client, 0, address_callback=cb)
+    connection = Bolt5x8(address, sockets.client, 0)
+    connection.address_callback = cb
     connection.logon()
     connection.send_all()
 
@@ -903,4 +904,4 @@ def cb(connection_, address_):
         return
 
     assert cb_calls == [expected_call]
-    assert connection.unresolved_address == expected_call
+    assert connection.address == address
diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py
index 89b4d16b3..fbd19ce69 100644
--- a/tests/unit/sync/io/test_neo4j_pool.py
+++ b/tests/unit/sync/io/test_neo4j_pool.py
@@ -84,7 +84,7 @@ def routing_side_effect(*args, **kwargs):
 
         def open_(addr, auth, timeout):
             connection = fake_connection_generator()
-            connection.unresolved_address = addr
+            connection.address = addr
             connection.timeout = timeout
             connection.auth = auth
             route_mock = mocker.MagicMock()
@@ -188,9 +188,9 @@ def test_chooses_right_connection_type(opener, type_):
     )
     pool.release(cx1)
     if type_ == "r":
-        assert cx1.unresolved_address == READER1_ADDRESS
+        assert cx1.address == READER1_ADDRESS
     else:
-        assert cx1.unresolved_address == WRITER1_ADDRESS
+        assert cx1.address == WRITER1_ADDRESS
 
 
 @mark_sync_test
@@ -206,7 +206,7 @@ def test_reuses_connection(opener):
 @mark_sync_test
 def test_closes_stale_connections(opener, break_on_close):
     def break_connection():
-        pool.deactivate(cx1.unresolved_address)
+        pool.deactivate(cx1.address)
 
         if cx_close_mock_side_effect:
             res = cx_close_mock_side_effect()
@@ -218,7 +218,7 @@ def break_connection():
     pool = _simple_pool(opener)
     cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx1)
-    assert cx1 in pool.connections[cx1.unresolved_address]
+    assert cx1 in pool.connections[cx1.address]
     # simulate connection going stale (e.g. exceeding idle timeout) and then
     # breaking when the pool tries to close the connection
     cx1.stale.return_value = True
@@ -233,16 +233,16 @@ def break_connection():
     else:
         cx1.close.assert_called_once()
     assert cx2 is not cx1
-    assert cx2.unresolved_address == cx1.unresolved_address
-    assert cx1 not in pool.connections[cx1.unresolved_address]
-    assert cx2 in pool.connections[cx2.unresolved_address]
+    assert cx2.address == cx1.address
+    assert cx1 not in pool.connections[cx1.address]
+    assert cx2 in pool.connections[cx2.address]
 
 
 @mark_sync_test
 def test_does_not_close_stale_connections_in_use(opener):
     pool = _simple_pool(opener)
     cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
-    assert cx1 in pool.connections[cx1.unresolved_address]
+    assert cx1 in pool.connections[cx1.address]
     # simulate connection going stale (e.g. exceeding idle timeout) while being
     # in use
     cx1.stale.return_value = True
@@ -250,9 +250,9 @@ def test_does_not_close_stale_connections_in_use(opener):
     pool.release(cx2)
     cx1.close.assert_not_called()
     assert cx2 is not cx1
-    assert cx2.unresolved_address == cx1.unresolved_address
-    assert cx1 in pool.connections[cx1.unresolved_address]
-    assert cx2 in pool.connections[cx2.unresolved_address]
+    assert cx2.address == cx1.address
+    assert cx1 in pool.connections[cx1.address]
+    assert cx2 in pool.connections[cx2.address]
 
     pool.release(cx1)
     # now that cx1 is back in the pool and still stale,
@@ -263,9 +263,9 @@ def test_does_not_close_stale_connections_in_use(opener):
     pool.release(cx3)
     cx1.close.assert_called_once()
     assert cx2 is cx3
-    assert cx3.unresolved_address == cx1.unresolved_address
-    assert cx1 not in pool.connections[cx1.unresolved_address]
-    assert cx3 in pool.connections[cx2.unresolved_address]
+    assert cx3.address == cx1.address
+    assert cx1 not in pool.connections[cx1.address]
+    assert cx3 in pool.connections[cx2.address]
 
 
 @mark_sync_test
@@ -314,7 +314,7 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection(
     cx1 = pool._acquire(
         READER1_ADDRESS, None, Deadline(30), liveness_timeout
     )
-    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
     cx1.reset.assert_not_called()
 
 
@@ -330,7 +330,7 @@ def test_acquire_performs_liveness_check_on_existing_connection(
     )
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
 
@@ -367,7 +367,7 @@ def liveness_side_effect(*args, **kwargs):
     )
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
 
@@ -384,11 +384,11 @@ def liveness_side_effect(*args, **kwargs):
         READER1_ADDRESS, None, Deadline(30), liveness_timeout
     )
     assert cx1 is not cx2
-    assert cx1.unresolved_address == cx2.unresolved_address
+    assert cx1.address == cx2.address
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
     cx2.reset.assert_not_called()
-    assert cx1 not in pool.connections[cx1.unresolved_address]
-    assert cx2 in pool.connections[cx1.unresolved_address]
+    assert cx1 not in pool.connections[cx1.address]
+    assert cx2 in pool.connections[cx1.address]
 
 
 @pytest.mark.parametrize(
@@ -412,8 +412,8 @@ def liveness_side_effect(*args, **kwargs):
     )
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER1_ADDRESS
-    assert cx2.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
+    assert cx2.address == READER1_ADDRESS
     assert cx1 is not cx2
     cx1.is_idle_for.assert_not_called()
     cx2.is_idle_for.assert_not_called()
@@ -439,8 +439,8 @@ def liveness_side_effect(*args, **kwargs):
     cx1.reset.assert_called_once()
     cx3.is_idle_for.assert_called_once_with(liveness_timeout)
     cx3.reset.assert_called_once()
-    assert cx1 not in pool.connections[cx1.unresolved_address]
-    assert cx3 in pool.connections[cx1.unresolved_address]
+    assert cx1 not in pool.connections[cx1.address]
+    assert cx3 in pool.connections[cx1.address]
 
 
 @mark_sync_test
@@ -701,7 +701,7 @@ def get_readers(database):
         opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS
     )
     cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
-    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx1.address == READER1_ADDRESS
     pool.release(cx1)
     cx1.close.assert_not_called()
 
@@ -712,7 +712,7 @@ def get_readers(database):
     readers["db1"] = [str(READER2_ADDRESS)]
 
     cx2 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
-    assert cx2.unresolved_address == READER2_ADDRESS
+    assert cx2.address == READER2_ADDRESS
     cx1.close.assert_called_once()
     assert len(pool.connections[READER1_ADDRESS]) == 0
 
@@ -740,14 +740,14 @@ def get_readers(database):
     )
     cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
     pool.release(cx1)
-    assert cx1.unresolved_address in {READER1_ADDRESS, READER2_ADDRESS}
+    assert cx1.address in {READER1_ADDRESS, READER2_ADDRESS}
     reader1_connection_count = len(pool.connections[READER1_ADDRESS])
     reader2_connection_count = len(pool.connections[READER2_ADDRESS])
     assert reader1_connection_count + reader2_connection_count == 1
 
     cx2 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
     pool.release(cx2)
-    assert cx2.unresolved_address == READER1_ADDRESS
+    assert cx2.address == READER1_ADDRESS
     cx1.close.assert_not_called()
     cx2.close.assert_not_called()
     assert len(pool.connections[READER1_ADDRESS]) == 1
@@ -759,7 +759,7 @@ def get_readers(database):
 
     cx3 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
     pool.release(cx3)
-    assert cx3.unresolved_address == READER3_ADDRESS
+    assert cx3.address == READER3_ADDRESS
 
     cx1.close.assert_not_called()
     cx2.close.assert_not_called()
diff --git a/tests/unit/sync/test_conf.py b/tests/unit/sync/test_conf.py
index 65481811b..bc589a38a 100644
--- a/tests/unit/sync/test_conf.py
+++ b/tests/unit/sync/test_conf.py
@@ -36,7 +36,6 @@
     ClientCertificate,
     ClientCertificateProviders,
 )
-from neo4j.debug import watch
 from neo4j.exceptions import ConfigurationError
 
 from ..._async_compat import mark_sync_test
@@ -45,8 +44,6 @@
 # python -m pytest tests/unit/test_conf.py -s -v
 
 
-watch("neo4j")
-
 test_pool_config = {
     "connection_timeout": 30.0,
     "keep_alive": True,
diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py
index 623d50143..c9a1bb48a 100644
--- a/tests/unit/sync/work/test_result.py
+++ b/tests/unit/sync/work/test_result.py
@@ -149,7 +149,7 @@ def __init__(
         self.run_meta = run_meta
         self.summary_meta = summary_meta
         ConnectionStub.server_info.update({"server": "Neo4j/4.3.0"})
-        self.unresolved_address = None
+        self.address = None
         self._new_hydration_scope_called = False
 
     def send_all(self):

From b8ccf60353545e1662a3e79ccde432d40473a5b9 Mon Sep 17 00:00:00 2001
From: Robsdedude
Date: Tue, 14 Jan 2025 17:29:54 +0100
Subject: [PATCH 5/6] Unit tests for advertised address optimization

---
 tests/unit/async_/fixtures/fake_connection.py |  1 +
 tests/unit/async_/io/test_direct.py           | 51 +++++++++++++++++++
 tests/unit/async_/io/test_neo4j_pool.py       | 40 ++++++++++++++-
 tests/unit/sync/fixtures/fake_connection.py   |  1 +
 tests/unit/sync/io/test_direct.py             | 51 +++++++++++++++++++
 tests/unit/sync/io/test_neo4j_pool.py         | 40 ++++++++++++++-
 6 files changed, 180 insertions(+), 4 deletions(-)

diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py
index 65dba3b51..56ac79bdc 100644
--- a/tests/unit/async_/fixtures/fake_connection.py
+++ b/tests/unit/async_/fixtures/fake_connection.py
@@ -57,6 +57,7 @@ def __init__(self, *args, **kwargs):
         )
         self.address = next(iter(args), "localhost")
         self.advertised_address = None
+        self.address_callback = None
 
         self.callbacks = []
 
diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py
index 80014266a..77950813c 100644
--- a/tests/unit/async_/io/test_direct.py
+++ b/tests/unit/async_/io/test_direct.py
@@ -25,6 +25,7 @@
     WorkspaceConfig,
 )
 from neo4j._deadline import Deadline
+from neo4j.addressing import Address
 from neo4j.auth_management import AsyncAuthManagers
 from neo4j.exceptions import (
     ClientError,
@@ -37,6 +38,8 @@
 class AsyncFakeBoltPool(AsyncIOPool):
     is_direct_pool = False
 
+    __on_open = None
+
     def __init__(self, connection_gen, address, *, auth=None, **config):
         self.buffered_connection_mocks = []
         config["auth"] = static_auth(None)
@@ -54,6 +57,8 @@ async def opener(addr, auth, timeout):
             else:
                 mock = connection_gen()
                 mock.address = addr
+            if self.__on_open is not None:
+                self.__on_open(mock)
             return mock
 
         super().__init__(opener, self.pool_config, self.workspace_config)
@@ -273,3 +278,49 @@ async def test_liveness_check(
     cx1.reset.reset_mock()
     await pool.release(cx1)
     cx1.reset.assert_not_called()
+
+
+@pytest.fixture
+async def simple_pool_factory(async_fake_connection_generator):
+    pools = []
+
+    def factory(**config):
+        pool_ = AsyncFakeBoltPool(
+            async_fake_connection_generator,
+            ("127.0.0.1", 7687),
+            **config,
+        )
+        pools.append(pool_)
+        return pool_
+
+    yield factory
+
+    for pool in pools:
+        await pool.close()
+
+
+async def test_configures_no_address_cb_on_connection(simple_pool_factory):
+    pool = simple_pool_factory()
+    cx = await pool.acquire("r", Deadline(3), "test_db", None, None, None)
+
+    assert cx.address_callback is None
+
+
+async def test_does_not_move_connection_to_advertised_address_after_open(
+    simple_pool_factory,
+):
+    advertised_address = Address(("example.com", 1234))
+
+    def on_open(connection):
+        assert connection.address != advertised_address  # sanity check
+        connection.advertised_address = advertised_address
+
+    pool = simple_pool_factory()
+    pool._AsyncFakeBoltPool__on_open = on_open
+    cx = await pool.acquire("r", Deadline(3), "test_db", None, None, None)
+
+    # assert has been moved
+    assert cx.address == pool.address
+    assert len(pool.connections[pool.address]) == 1
+    assert len(pool.connections[advertised_address]) == 0
+    assert cx in pool.connections[pool.address]
diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py
index d409a4958..a1f20252c 100644
--- a/tests/unit/async_/io/test_neo4j_pool.py
+++ b/tests/unit/async_/io/test_neo4j_pool.py
@@ -33,7 +33,10 @@
     WorkspaceConfig,
 )
 from neo4j._deadline import Deadline
-from neo4j.addressing import ResolvedAddress
+from neo4j.addressing import (
+    Address,
+    ResolvedAddress,
+)
 from neo4j.auth_management import AsyncAuthManagers
 from neo4j.exceptions import (
     Neo4jError,
@@ -55,7 +58,7 @@
 
 @pytest.fixture
 def custom_routing_opener(async_fake_connection_generator, mocker):
-    def make_opener(failures=None, get_readers=None):
+    def make_opener(failures=None, get_readers=None, on_open=None):
         def routing_side_effect(*args, **kwargs):
             nonlocal failures
             res = next(failures, None)
@@ -92,6 +95,10 @@ async def open_(addr, auth, timeout):
             route_mock.side_effect = routing_side_effect
             connection.attach_mock(route_mock, "route")
             opener_.connections.append(connection)
+
+            if on_open is not None:
+                on_open(connection)
+
             return connection
 
         failures = iter(failures or [])
@@ -767,3 +774,32 @@ def get_readers(database):
     assert len(pool.connections[READER1_ADDRESS]) == 1
     assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count
     assert len(pool.connections[READER3_ADDRESS]) == 1
+
+
+@mark_async_test
+async def test_configures_address_cb_on_connection(opener):
+    pool = _simple_pool(opener)
+    cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
+
+    assert cx.address_callback == pool._move_connection
+
+
+@mark_async_test
+async def test_moves_connection_to_advertised_address_after_open(
+    custom_routing_opener,
+):
+    advertised_address = Address(("example.com", 1234))
+
+    def on_open(connection):
+        assert connection.address != advertised_address  # sanity check
+        connection.advertised_address = advertised_address
+
+    opener = custom_routing_opener(on_open=on_open)
+    pool = _simple_pool(opener)
+    cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
+
+    # assert has been moved
+    assert cx.address == advertised_address
+    assert len(pool.connections[READER1_ADDRESS]) == 0
+    assert len(pool.connections[advertised_address]) == 1
+    assert cx in pool.connections[advertised_address]
diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py
index fcb3a3708..f66694ece 100644
--- a/tests/unit/sync/fixtures/fake_connection.py
+++ b/tests/unit/sync/fixtures/fake_connection.py
@@ -57,6 +57,7 @@ def __init__(self, *args, **kwargs):
         )
         self.address = next(iter(args), "localhost")
         self.advertised_address = None
+        self.address_callback = None
 
         self.callbacks = []
 
diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py
index a899ae499..933efa0b1 100644
--- a/tests/unit/sync/io/test_direct.py
+++ b/tests/unit/sync/io/test_direct.py
@@ -25,6 +25,7 @@
 from neo4j._sync.config import PoolConfig
 from neo4j._sync.io import Bolt
 from neo4j._sync.io._pool import IOPool
+from neo4j.addressing import Address
 from neo4j.auth_management import AuthManagers
 from neo4j.exceptions import (
     ClientError,
@@ -37,6 +38,8 @@
 class FakeBoltPool(IOPool):
     is_direct_pool = False
 
+    __on_open = None
+
     def __init__(self, connection_gen, address, *, auth=None, **config):
         self.buffered_connection_mocks = []
         config["auth"] = static_auth(None)
@@ -54,6 +57,8 @@ def opener(addr, auth, timeout):
             else:
                 mock = connection_gen()
                 mock.address = addr
+            if self.__on_open is not None:
+                self.__on_open(mock)
             return mock
 
         super().__init__(opener, self.pool_config, self.workspace_config)
@@ -273,3 +278,49 @@ def test_liveness_check(
     cx1.reset.reset_mock()
     pool.release(cx1)
     cx1.reset.assert_not_called()
+
+
+@pytest.fixture
+def simple_pool_factory(fake_connection_generator):
+    pools = []
+
+    def factory(**config):
+        pool_ = FakeBoltPool(
+            fake_connection_generator,
+            ("127.0.0.1", 7687),
+            **config,
+        )
+        pools.append(pool_)
+        return pool_
+
+    yield factory
+
+    for pool in pools:
+        pool.close()
+
+
+def test_configures_no_address_cb_on_connection(simple_pool_factory):
+    pool = simple_pool_factory()
+    cx = pool.acquire("r", Deadline(3), "test_db", None, None, None)
+
+    assert cx.address_callback is None
+
+
+def test_does_not_move_connection_to_advertised_address_after_open(
+    simple_pool_factory,
+):
+    advertised_address = Address(("example.com", 1234))
+
+    def on_open(connection):
+        assert connection.address != advertised_address  # sanity check
+        connection.advertised_address = advertised_address
+
+    pool = simple_pool_factory()
+    pool._FakeBoltPool__on_open = on_open
+    cx = pool.acquire("r", Deadline(3), "test_db", None, None, None)
+
+    # assert has been moved
+    assert cx.address == pool.address
+    assert len(pool.connections[pool.address]) == 1
+    assert len(pool.connections[advertised_address]) == 0
+    assert cx in pool.connections[pool.address]
diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py
index fbd19ce69..e7bb9475f 100644
--- a/tests/unit/sync/io/test_neo4j_pool.py
+++ b/tests/unit/sync/io/test_neo4j_pool.py
@@ -33,7 +33,10 @@
     Bolt,
     Neo4jPool,
 )
-from neo4j.addressing import ResolvedAddress
+from neo4j.addressing import (
+    Address,
+    ResolvedAddress,
+)
 from neo4j.auth_management import AuthManagers
 from neo4j.exceptions import (
     Neo4jError,
@@ -55,7 +58,7 @@
 
 @pytest.fixture
 def custom_routing_opener(fake_connection_generator, mocker):
-    def make_opener(failures=None, get_readers=None):
+    def make_opener(failures=None, get_readers=None, on_open=None):
         def routing_side_effect(*args, **kwargs):
             nonlocal failures
             res = next(failures, None)
@@ -92,6 +95,10 @@ def open_(addr, auth, timeout):
             route_mock.side_effect = routing_side_effect
             connection.attach_mock(route_mock, "route")
             opener_.connections.append(connection)
+
+            if on_open is not None:
+                on_open(connection)
+
             return connection
 
         failures = iter(failures or [])
@@ -767,3 +774,32 @@ def get_readers(database):
     assert len(pool.connections[READER1_ADDRESS]) == 1
     assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count
     assert len(pool.connections[READER3_ADDRESS]) == 1
+
+
+@mark_sync_test
+def test_configures_address_cb_on_connection(opener):
+    pool = _simple_pool(opener)
+    cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
+
+    assert cx.address_callback == pool._move_connection
+
+
+@mark_sync_test
+def test_moves_connection_to_advertised_address_after_open(
+    custom_routing_opener,
+):
+    advertised_address = Address(("example.com", 1234))
+
+    def on_open(connection):
+        assert connection.address != advertised_address  # sanity check
+        connection.advertised_address = advertised_address
+
+    opener = custom_routing_opener(on_open=on_open)
+    pool = _simple_pool(opener)
+    cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
+
+    # assert has been moved
+    assert cx.address == advertised_address
+    assert len(pool.connections[READER1_ADDRESS]) == 0
+    assert len(pool.connections[advertised_address]) == 1
+    assert cx in pool.connections[advertised_address]

From 80e3cfe453f66edc950bf7a77fc2f87208b68324 Mon Sep 17 00:00:00 2001
From: Robsdedude
Date: Wed, 15 Jan 2025 11:51:02 +0100
Subject: [PATCH 6/6] TMP! point CI to corresponding TestKit PR branch

---
 testkit/testkit.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/testkit/testkit.json b/testkit/testkit.json
index 931900356..307f10913 100644
--- a/testkit/testkit.json
+++ b/testkit/testkit.json
@@ -1,6 +1,6 @@
 {
   "testkit": {
     "uri": "https://github.com/neo4j-drivers/testkit.git",
-    "ref": "5.0"
+    "ref": "advertised-address"
   }
 }