From f9ec389b22a008a6c803cf31c27e3e7508b0be50 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 26 Jun 2025 16:08:49 +0200 Subject: [PATCH 1/4] Unify `connection_acquisition_timeout` behavior The config option `connection_acquisition_timeout` now spans anything that's required to acquire a working connection from the pool. This includes * Potentially fetching a routing table This entails acquiring a connection in itself. * Bolt, TLS, TCP handshaking * Authentication * Any other required IO (e.g., DNS lookups) * Waiting for room in the pool * possibly more Previously, the timeout wold be restarted for auxiliary connection acquisitions like those for fetching a routing table. --- docs/source/api.rst | 5 +++++ src/neo4j/_async/io/__init__.py | 2 ++ src/neo4j/_async/io/_pool.py | 15 +++++++++++++-- src/neo4j/_async/work/workspace.py | 20 +++++++++++++++----- src/neo4j/_sync/io/__init__.py | 2 ++ src/neo4j/_sync/io/_pool.py | 15 +++++++++++++-- src/neo4j/_sync/work/workspace.py | 20 +++++++++++++++----- testkitbackend/test_config.json | 8 +------- 8 files changed, 66 insertions(+), 21 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 23608e6a9..0941399bb 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -430,6 +430,11 @@ it should be chosen larger than :ref:`connection-timeout-ref`. :Type: ``float`` :Default: ``60.0`` +.. versionadded:: 6.0 + The setting now entails *anything* required to acquire a connection. + This includes potential fetching of routing tables which in itself requires acquiring a connection. + Previously, the timeout wold be restarted for such auxiliary connection acquisitions. + .. _connection-timeout-ref: diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index 7c950834e..f7c926f0d 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -28,6 +28,7 @@ "AsyncBoltPool", "AsyncNeo4jPool", "ConnectionErrorHandler", + "acquisition_timeout_to_deadline", ] @@ -40,6 +41,7 @@ from ._bolt import AsyncBolt from ._common import ConnectionErrorHandler from ._pool import ( + acquisition_timeout_to_deadline, AcquisitionAuth, AcquisitionDatabase, AsyncBoltPool, diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 58eeafdaf..935e979f9 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -260,6 +260,8 @@ async def connection_creator(): with self.lock: self.connections_reservations[address] -= 1 + if deadline.expired(): + return None max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = max_pool_size < 0 or max_pool_size == float("inf") with self.lock: @@ -969,7 +971,9 @@ async def update_routing_table( :raise neo4j.exceptions.ServiceUnavailable: """ - _check_acquisition_timeout(acquisition_timeout) + acquisition_timeout = acquisition_timeout_to_deadline( + acquisition_timeout + ) async with self.refresh_lock: routing_table = await self.get_routing_table(database) if routing_table is not None: @@ -1147,7 +1151,7 @@ async def acquire( database_callback=None, ): access_mode = check_access_mode(access_mode) - _check_acquisition_timeout(timeout) + timeout = acquisition_timeout_to_deadline(timeout) target_database = database.name @@ -1242,6 +1246,13 @@ async def on_write_failure(self, address, database): log.debug("[#0000] _: table=%r", self.routing_tables) +def acquisition_timeout_to_deadline(timeout: object) -> Deadline: + if isinstance(timeout, Deadline): + return timeout + _check_acquisition_timeout(timeout) + return Deadline(timeout) + + def _check_acquisition_timeout(timeout: object) -> None: if not isinstance(timeout, (int, float)): raise TypeError( diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index dae2fbbc3..75960d226 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -30,12 +30,14 @@ ) from .._debug import AsyncNonConcurrentMethodChecker from ..io import ( + acquisition_timeout_to_deadline, AcquisitionAuth, AcquisitionDatabase, ) if t.TYPE_CHECKING: + from ..._deadline import Deadline from ...api import _TAuth from ...auth_management import ( AsyncAuthManager, @@ -159,13 +161,19 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: await self._connection.fetch_all() await self._disconnect() + acquisition_deadline = acquisition_timeout_to_deadline( + acquisition_timeout + ) + ssr_enabled = self._pool.ssr_enabled target_db = await self._get_routing_target_database( - acquire_auth, ssr_enabled=ssr_enabled + acquire_auth, + ssr_enabled=ssr_enabled, + acquisition_deadline=acquisition_deadline, ) acquire_kwargs_ = { "access_mode": access_mode, - "timeout": acquisition_timeout, + "timeout": acquisition_deadline, "database": target_db, "bookmarks": await self._get_bookmarks(), "auth": acquire_auth, @@ -188,7 +196,9 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: ) await self._disconnect() target_db = await self._get_routing_target_database( - acquire_auth, ssr_enabled=False + acquire_auth, + ssr_enabled=False, + acquisition_deadline=acquisition_deadline, ) acquire_kwargs_["database"] = target_db self._connection = await self._pool.acquire(**acquire_kwargs_) @@ -198,6 +208,7 @@ async def _get_routing_target_database( self, acquire_auth: AcquisitionAuth, ssr_enabled: bool, + acquisition_deadline: Deadline, ) -> AcquisitionDatabase: if ( self._pinned_database @@ -232,14 +243,13 @@ async def _get_routing_target_database( ) return AcquisitionDatabase(cached_db, guessed=True) - acquisition_timeout = self._config.connection_acquisition_timeout log.debug("[#0000] _: resolve home database") await self._pool.update_routing_table( database=self._config.database, imp_user=self._config.impersonated_user, bookmarks=await self._get_bookmarks(), auth=acquire_auth, - acquisition_timeout=acquisition_timeout, + acquisition_timeout=acquisition_deadline, database_callback=self._make_db_resolution_callback(), ) return AcquisitionDatabase(self._config.database) diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index 775fd504d..097890e61 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -28,6 +28,7 @@ "BoltPool", "Neo4jPool", "ConnectionErrorHandler", + "acquisition_timeout_to_deadline", ] @@ -40,6 +41,7 @@ from ._bolt import Bolt from ._common import ConnectionErrorHandler from ._pool import ( + acquisition_timeout_to_deadline, AcquisitionAuth, AcquisitionDatabase, BoltPool, diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 1819ad08c..8deb06c6b 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -257,6 +257,8 @@ def connection_creator(): with self.lock: self.connections_reservations[address] -= 1 + if deadline.expired(): + return None max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = max_pool_size < 0 or max_pool_size == float("inf") with self.lock: @@ -966,7 +968,9 @@ def update_routing_table( :raise neo4j.exceptions.ServiceUnavailable: """ - _check_acquisition_timeout(acquisition_timeout) + acquisition_timeout = acquisition_timeout_to_deadline( + acquisition_timeout + ) with self.refresh_lock: routing_table = self.get_routing_table(database) if routing_table is not None: @@ -1144,7 +1148,7 @@ def acquire( database_callback=None, ): access_mode = check_access_mode(access_mode) - _check_acquisition_timeout(timeout) + timeout = acquisition_timeout_to_deadline(timeout) target_database = database.name @@ -1239,6 +1243,13 @@ def on_write_failure(self, address, database): log.debug("[#0000] _: table=%r", self.routing_tables) +def acquisition_timeout_to_deadline(timeout: object) -> Deadline: + if isinstance(timeout, Deadline): + return timeout + _check_acquisition_timeout(timeout) + return Deadline(timeout) + + def _check_acquisition_timeout(timeout: object) -> None: if not isinstance(timeout, (int, float)): raise TypeError( diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index 1be5a744b..a13b53c75 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -30,12 +30,14 @@ ) from .._debug import NonConcurrentMethodChecker from ..io import ( + acquisition_timeout_to_deadline, AcquisitionAuth, AcquisitionDatabase, ) if t.TYPE_CHECKING: + from ..._deadline import Deadline from ...api import _TAuth from ...auth_management import AuthManager from ..home_db_cache import ( @@ -156,13 +158,19 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: self._connection.fetch_all() self._disconnect() + acquisition_deadline = acquisition_timeout_to_deadline( + acquisition_timeout + ) + ssr_enabled = self._pool.ssr_enabled target_db = self._get_routing_target_database( - acquire_auth, ssr_enabled=ssr_enabled + acquire_auth, + ssr_enabled=ssr_enabled, + acquisition_deadline=acquisition_deadline, ) acquire_kwargs_ = { "access_mode": access_mode, - "timeout": acquisition_timeout, + "timeout": acquisition_deadline, "database": target_db, "bookmarks": self._get_bookmarks(), "auth": acquire_auth, @@ -185,7 +193,9 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: ) self._disconnect() target_db = self._get_routing_target_database( - acquire_auth, ssr_enabled=False + acquire_auth, + ssr_enabled=False, + acquisition_deadline=acquisition_deadline, ) acquire_kwargs_["database"] = target_db self._connection = self._pool.acquire(**acquire_kwargs_) @@ -195,6 +205,7 @@ def _get_routing_target_database( self, acquire_auth: AcquisitionAuth, ssr_enabled: bool, + acquisition_deadline: Deadline, ) -> AcquisitionDatabase: if ( self._pinned_database @@ -229,14 +240,13 @@ def _get_routing_target_database( ) return AcquisitionDatabase(cached_db, guessed=True) - acquisition_timeout = self._config.connection_acquisition_timeout log.debug("[#0000] _: resolve home database") self._pool.update_routing_table( database=self._config.database, imp_user=self._config.impersonated_user, bookmarks=self._get_bookmarks(), auth=acquire_auth, - acquisition_timeout=acquisition_timeout, + acquisition_timeout=acquisition_deadline, database_callback=self._make_db_resolution_callback(), ) return AcquisitionDatabase(self._config.database) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 0c784bc78..3777f4e29 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -15,13 +15,7 @@ "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id'": "test_subtest_skips.tz_id", "stub\\.routing\\.test_routing_v[0-9x]+\\.RoutingV[0-9x]+\\.test_should_drop_connections_failing_liveness_check": - "Liveness check error handling is not (yet) unified: https://github.com/neo-technology/drivers-adr/pull/83", - "'stub.homedb.test_homedb.TestHomeDbMixedCluster.test_connection_acquisition_timeout_during_fallback'": - "TODO: 6.0 - pending unification: connection acquisition timeout should count towards the total time spent waiting for a connection (including routing, home db resolution, ...)", - "'stub.driver_parameters.test_connection_acquisition_timeout_ms.TestConnectionAcquisitionTimeoutMs.test_does_encompass_router_route_response'": - "TODO: 6.0 - pending unification: connection acquisition timeout should count towards the total time spent waiting for a connection (including routing, home db resolution, ...)", - "'stub.driver_parameters.test_connection_acquisition_timeout_ms.TestConnectionAcquisitionTimeoutMs.test_router_handshake_shares_acquisition_timeout'": - "TODO: 6.0 - pending unification: connection acquisition timeout should count towards the total time spent waiting for a connection (including routing, home db resolution, ...)" + "Liveness check error handling is not (yet) unified: https://github.com/neo-technology/drivers-adr/pull/83" }, "features": { "Feature:API:BookmarkManager": true, From 27c8b8572d5e52c3defe67d254af0e71e8350152 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 15 Jul 2025 13:02:15 +0200 Subject: [PATCH 2/4] fixup! Unify `connection_acquisition_timeout` behavior --- src/neo4j/_async/io/_pool.py | 5 +- src/neo4j/_sync/io/_pool.py | 5 +- tests/unit/async_/io/test_bolt_pool.py | 156 ++++++++++++++++++++++++ tests/unit/async_/io/test_direct.py | 2 +- tests/unit/async_/io/test_neo4j_pool.py | 36 ++++++ tests/unit/mixed/io/test_direct.py | 4 +- tests/unit/sync/io/test_direct.py | 2 +- tests/unit/sync/io/test_neo4j_pool.py | 36 ++++++ 8 files changed, 236 insertions(+), 10 deletions(-) create mode 100644 tests/unit/async_/io/test_bolt_pool.py diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 935e979f9..fabdbe702 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -661,7 +661,7 @@ async def acquire( timeout, database, bookmarks, - auth: AcquisitionAuth, + auth: AcquisitionAuth | None, liveness_check_timeout, unprepared=False, database_callback=None, @@ -669,14 +669,13 @@ async def acquire( # The access_mode and database is not needed for a direct connection, # it's just there for consistency. access_mode = check_access_mode(access_mode) - _check_acquisition_timeout(timeout) + deadline = acquisition_timeout_to_deadline(timeout) log.debug( "[#0000] _: acquire direct connection, " "access_mode=%r, database=%r", access_mode, database, ) - deadline = Deadline.from_timeout_or_deadline(timeout) return await self._acquire( self.address, auth, deadline, liveness_check_timeout, unprepared ) diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 8deb06c6b..d3946ff3e 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -658,7 +658,7 @@ def acquire( timeout, database, bookmarks, - auth: AcquisitionAuth, + auth: AcquisitionAuth | None, liveness_check_timeout, unprepared=False, database_callback=None, @@ -666,14 +666,13 @@ def acquire( # The access_mode and database is not needed for a direct connection, # it's just there for consistency. access_mode = check_access_mode(access_mode) - _check_acquisition_timeout(timeout) + deadline = acquisition_timeout_to_deadline(timeout) log.debug( "[#0000] _: acquire direct connection, " "access_mode=%r, database=%r", access_mode, database, ) - deadline = Deadline.from_timeout_or_deadline(timeout) return self._acquire( self.address, auth, deadline, liveness_check_timeout, unprepared ) diff --git a/tests/unit/async_/io/test_bolt_pool.py b/tests/unit/async_/io/test_bolt_pool.py new file mode 100644 index 000000000..bb90d0be9 --- /dev/null +++ b/tests/unit/async_/io/test_bolt_pool.py @@ -0,0 +1,156 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import contextlib + +import pytest + +from neo4j import READ_ACCESS +from neo4j._addressing import ResolvedAddress +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io import ( + AcquisitionDatabase, + AsyncBoltPool, +) +from neo4j._conf import WorkspaceConfig +from neo4j.auth_management import AsyncAuthManagers +from neo4j.exceptions import ConnectionAcquisitionTimeoutError + +from ...._async_compat import mark_async_test + + +SERVER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host") + + +def make_home_db_resolve(home_db): + def _home_db_resolve(db): + return db or home_db + + return _home_db_resolve + + +_default_db_resolve = make_home_db_resolve("neo4j") + + +@pytest.fixture +def custom_opener(async_fake_connection_generator, mocker): + def make_opener( + failures=None, + db_resolve=_default_db_resolve, + on_open=None, + ): + def routing_side_effect(*args, **kwargs): + nonlocal failures + opener_.route_requests.append(kwargs.get("database")) + res = next(failures, None) + if res is None: + routers = readers = writers = [str(SERVER1_ADDRESS)] + rt = { + "ttl": 1000, + "servers": [ + {"addresses": routers, "role": "ROUTE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": writers, "role": "WRITE"}, + ], + } + db = db_resolve(kwargs.get("database")) + if db is not ...: + rt["db"] = db + return [rt] + raise res + + async def open_(addr, auth, timeout): + connection = async_fake_connection_generator() + connection.unresolved_address = addr + connection.timeout = timeout + connection.auth = auth + route_mock = mocker.AsyncMock() + + route_mock.side_effect = routing_side_effect + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + + if callable(on_open): + on_open(connection) + + return connection + + failures = iter(failures or []) + opener_ = mocker.AsyncMock() + opener_.connections = [] + opener_.route_requests = [] + opener_.side_effect = open_ + return opener_ + + return make_opener + + +@pytest.fixture +def opener(custom_opener): + return custom_opener() + + +def _pool_config(): + pool_config = AsyncPoolConfig() + pool_config.auth = _auth_manager(("user", "pass")) + return pool_config + + +def _auth_manager(auth): + return AsyncAuthManagers.static(auth) + + +def _simple_pool(opener) -> AsyncBoltPool: + return AsyncBoltPool( + opener, _pool_config(), WorkspaceConfig(), SERVER1_ADDRESS + ) + + +TEST_DB1 = AcquisitionDatabase("test_db1") + + +@pytest.mark.parametrize( + ("timeout", "expected_error"), + ( + (1, None), + (2 ^ 128, None), + (0.000000001, None), + (float("inf"), None), + (-1, ValueError), + (0, ValueError), + (float("-inf"), ValueError), + (float("NaN"), ValueError), + (float("-NaN"), ValueError), + ("1", TypeError), + (None, TypeError), + ([1], TypeError), + ), +) +@mark_async_test +async def test_invalid_acquisition_timeouts(opener, timeout, expected_error): + pool = _simple_pool(opener) + + async def call(): + with contextlib.suppress(ConnectionAcquisitionTimeoutError): + await pool.acquire( + READ_ACCESS, timeout, TEST_DB1, None, None, None + ) + + if expected_error is None: + await call() + else: + with pytest.raises(expected_error): + await call() diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index e242e3c16..83a5abf0f 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -204,7 +204,7 @@ async def test_pool_max_conn_pool_size(async_fake_connection_generator): async_fake_connection_generator, (), max_connection_pool_size=1 ) as pool: address = neo4j.Address(("127.0.0.1", 7687)) - await pool._acquire(address, None, Deadline(0), None) + await pool._acquire(address, None, Deadline(float("inf")), None) assert pool.in_use_connection_count(address) == 1 with pytest.raises(ConnectionAcquisitionTimeoutError): await pool._acquire(address, None, Deadline(0), None) diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 1517d5d6e..a758ac3c3 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -14,6 +14,7 @@ # limitations under the License. +import contextlib import inspect import pytest @@ -37,6 +38,7 @@ from neo4j._deadline import Deadline from neo4j.auth_management import AsyncAuthManagers from neo4j.exceptions import ( + ConnectionAcquisitionTimeoutError, Neo4jError, ServiceUnavailable, SessionExpired, @@ -914,3 +916,37 @@ def on_open(connection): await pool.release(cx) assert pool.ssr_enabled + + +@pytest.mark.parametrize( + ("timeout", "expected_error"), + ( + (1, None), + (2 ^ 128, None), + (0.000000001, None), + (float("inf"), None), + (-1, ValueError), + (0, ValueError), + (float("-inf"), ValueError), + (float("NaN"), ValueError), + (float("-NaN"), ValueError), + ("1", TypeError), + (None, TypeError), + ([1], TypeError), + ), +) +@mark_async_test +async def test_invalid_acquisition_timeouts(opener, timeout, expected_error): + pool = _simple_pool(opener) + + async def call(): + with contextlib.suppress(ConnectionAcquisitionTimeoutError): + await pool.acquire( + READ_ACCESS, timeout, TEST_DB1, None, None, None + ) + + if expected_error is None: + await call() + else: + with pytest.raises(expected_error): + await call() diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index 5b42976f0..f74d05884 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -141,7 +141,7 @@ def test_full_pool_re_auth(self, fake_connection_generator, mocker): def acquire1(pool_): nonlocal cx1 - cx = pool_._acquire(address, acquire_auth1, Deadline(0), None) + cx = pool_._acquire(address, acquire_auth1, Deadline(1), None) acquire1_event.set() cx1 = cx while True: @@ -258,7 +258,7 @@ async def test_full_pool_re_auth_async( async def acquire1(pool_): nonlocal cx1 cx = await pool_._acquire( - address, acquire_auth1, Deadline(0), None + address, acquire_auth1, Deadline(1), None ) cx1 = cx while len(pool_.cond._waiters) == 0: diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index e0a0636b1..6c8934f4d 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -204,7 +204,7 @@ def test_pool_max_conn_pool_size(fake_connection_generator): fake_connection_generator, (), max_connection_pool_size=1 ) as pool: address = neo4j.Address(("127.0.0.1", 7687)) - pool._acquire(address, None, Deadline(0), None) + pool._acquire(address, None, Deadline(float("inf")), None) assert pool.in_use_connection_count(address) == 1 with pytest.raises(ConnectionAcquisitionTimeoutError): pool._acquire(address, None, Deadline(0), None) diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 7ee1f4a79..5228f8967 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -14,6 +14,7 @@ # limitations under the License. +import contextlib import inspect import pytest @@ -37,6 +38,7 @@ ) from neo4j.auth_management import AuthManagers from neo4j.exceptions import ( + ConnectionAcquisitionTimeoutError, Neo4jError, ServiceUnavailable, SessionExpired, @@ -914,3 +916,37 @@ def on_open(connection): pool.release(cx) assert pool.ssr_enabled + + +@pytest.mark.parametrize( + ("timeout", "expected_error"), + ( + (1, None), + (2 ^ 128, None), + (0.000000001, None), + (float("inf"), None), + (-1, ValueError), + (0, ValueError), + (float("-inf"), ValueError), + (float("NaN"), ValueError), + (float("-NaN"), ValueError), + ("1", TypeError), + (None, TypeError), + ([1], TypeError), + ), +) +@mark_sync_test +def test_invalid_acquisition_timeouts(opener, timeout, expected_error): + pool = _simple_pool(opener) + + def call(): + with contextlib.suppress(ConnectionAcquisitionTimeoutError): + pool.acquire( + READ_ACCESS, timeout, TEST_DB1, None, None, None + ) + + if expected_error is None: + call() + else: + with pytest.raises(expected_error): + call() From 2146420a7dc7328702a5cdad50eb7e2b86ae7810 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 15 Jul 2025 13:31:52 +0200 Subject: [PATCH 3/4] fixup! Unify `connection_acquisition_timeout` behavior --- tests/unit/sync/io/test_bolt_pool.py | 156 +++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 tests/unit/sync/io/test_bolt_pool.py diff --git a/tests/unit/sync/io/test_bolt_pool.py b/tests/unit/sync/io/test_bolt_pool.py new file mode 100644 index 000000000..d207b83d2 --- /dev/null +++ b/tests/unit/sync/io/test_bolt_pool.py @@ -0,0 +1,156 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import contextlib + +import pytest + +from neo4j import READ_ACCESS +from neo4j._addressing import ResolvedAddress +from neo4j._conf import WorkspaceConfig +from neo4j._sync.config import PoolConfig +from neo4j._sync.io import ( + AcquisitionDatabase, + BoltPool, +) +from neo4j.auth_management import AuthManagers +from neo4j.exceptions import ConnectionAcquisitionTimeoutError + +from ...._async_compat import mark_sync_test + + +SERVER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host") + + +def make_home_db_resolve(home_db): + def _home_db_resolve(db): + return db or home_db + + return _home_db_resolve + + +_default_db_resolve = make_home_db_resolve("neo4j") + + +@pytest.fixture +def custom_opener(fake_connection_generator, mocker): + def make_opener( + failures=None, + db_resolve=_default_db_resolve, + on_open=None, + ): + def routing_side_effect(*args, **kwargs): + nonlocal failures + opener_.route_requests.append(kwargs.get("database")) + res = next(failures, None) + if res is None: + routers = readers = writers = [str(SERVER1_ADDRESS)] + rt = { + "ttl": 1000, + "servers": [ + {"addresses": routers, "role": "ROUTE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": writers, "role": "WRITE"}, + ], + } + db = db_resolve(kwargs.get("database")) + if db is not ...: + rt["db"] = db + return [rt] + raise res + + def open_(addr, auth, timeout): + connection = fake_connection_generator() + connection.unresolved_address = addr + connection.timeout = timeout + connection.auth = auth + route_mock = mocker.MagicMock() + + route_mock.side_effect = routing_side_effect + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + + if callable(on_open): + on_open(connection) + + return connection + + failures = iter(failures or []) + opener_ = mocker.MagicMock() + opener_.connections = [] + opener_.route_requests = [] + opener_.side_effect = open_ + return opener_ + + return make_opener + + +@pytest.fixture +def opener(custom_opener): + return custom_opener() + + +def _pool_config(): + pool_config = PoolConfig() + pool_config.auth = _auth_manager(("user", "pass")) + return pool_config + + +def _auth_manager(auth): + return AuthManagers.static(auth) + + +def _simple_pool(opener) -> BoltPool: + return BoltPool( + opener, _pool_config(), WorkspaceConfig(), SERVER1_ADDRESS + ) + + +TEST_DB1 = AcquisitionDatabase("test_db1") + + +@pytest.mark.parametrize( + ("timeout", "expected_error"), + ( + (1, None), + (2 ^ 128, None), + (0.000000001, None), + (float("inf"), None), + (-1, ValueError), + (0, ValueError), + (float("-inf"), ValueError), + (float("NaN"), ValueError), + (float("-NaN"), ValueError), + ("1", TypeError), + (None, TypeError), + ([1], TypeError), + ), +) +@mark_sync_test +def test_invalid_acquisition_timeouts(opener, timeout, expected_error): + pool = _simple_pool(opener) + + def call(): + with contextlib.suppress(ConnectionAcquisitionTimeoutError): + pool.acquire( + READ_ACCESS, timeout, TEST_DB1, None, None, None + ) + + if expected_error is None: + call() + else: + with pytest.raises(expected_error): + call() From e5883dec499d1e7c6b6c450a4747ef9faec6645d Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 15 Jul 2025 14:57:37 +0200 Subject: [PATCH 4/4] fixup! Unify `connection_acquisition_timeout` behavior --- CHANGELOG.md | 5 ++++- docs/source/api.rst | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dbfee64be..32bd57ba6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,12 +39,15 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. (instead of internal `neo4j._exceptions.BoltHandshakeError`). `UnsupportedServerProduct` is now a subclass of `ServiceUnavailable` (instead of `Exception` directly). - `connection_acquisition_timeout` configuration option - - `ValueError` on invalid values (instead of `ClientError`) + - Raise `ValueError` on invalid values (instead of `ClientError`). - Consistently restrict the value to be strictly positive - New `ConnectionAcquisitionTimeoutError` (subclass of `DriverError`) instead of `ClientError` (subclass of `Neo4jError`) the timeout is exceeded. - This improves the differentiation between `DriverError` for client-side errors and `Neo4jError` for server-side errors. + - The option now spans *anything* required to acquire a connection. + This includes potential fetching of routing tables which in itself requires acquiring a connection. + Previously, the timeout would be restarted for such auxiliary connection acquisitions. - `TypeError` instead of `ValueError` when passing a `Query` object to `Transaction.run`. - `TransactionError` (subclass of `DriverError`) instead of `ClientError` (subclass of `Neo4jError`) when calling `session.run()` while an explicit transaction is active on that session. diff --git a/docs/source/api.rst b/docs/source/api.rst index 71de8815f..9622e2afa 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -433,7 +433,7 @@ it should be chosen larger than :ref:`connection-timeout-ref`. .. versionadded:: 6.0 The setting now entails *anything* required to acquire a connection. This includes potential fetching of routing tables which in itself requires acquiring a connection. - Previously, the timeout wold be restarted for such auxiliary connection acquisitions. + Previously, the timeout would be restarted for such auxiliary connection acquisitions. .. _connection-timeout-ref: