From e8316853e60686e6b9073060ff704860d605cfc2 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 14 Aug 2025 10:35:09 -0700 Subject: [PATCH 1/5] PYTHON-5506 Prototype adaptive token bucket retry --- pymongo/asynchronous/collection.py | 1 + pymongo/asynchronous/database.py | 1 + pymongo/asynchronous/helpers.py | 86 ++++++++++++++++++++++++-- pymongo/asynchronous/mongo_client.py | 17 +++-- pymongo/synchronous/collection.py | 1 + pymongo/synchronous/database.py | 1 + pymongo/synchronous/helpers.py | 86 ++++++++++++++++++++++++-- pymongo/synchronous/mongo_client.py | 17 +++-- test/asynchronous/test_backpressure.py | 22 +++++++ test/test_backpressure.py | 22 +++++++ 10 files changed, 234 insertions(+), 20 deletions(-) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index dead0ed4dc..6ff62e9fe3 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -253,6 +253,7 @@ def __init__( unicode_decode_error_handler="replace", document_class=dict ) self._timeout = database.client.options.timeout + self._retry_policy = database.client._retry_policy if create or kwargs: if _IS_SYNC: diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index f3b35a0dcb..8abc7059d0 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -136,6 +136,7 @@ def __init__( self._name = name self._client: AsyncMongoClient[_DocumentType] = client self._timeout = client.options.timeout + self._retry_policy = client._retry_policy @property def client(self) -> AsyncMongoClient[_DocumentType]: diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 49d5ec604e..7a3b4d5ec7 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -34,6 +34,7 @@ PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE +from pymongo.lock import _async_create_lock _IS_SYNC = False @@ -78,7 +79,10 @@ async def inner(*args: Any, **kwargs: Any) -> Any: _MAX_RETRIES = 3 _BACKOFF_INITIAL = 0.05 _BACKOFF_MAX = 10 -_TIME = time +# DRIVERS-3240 will determine these defaults. +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 +DEFAULT_RETRY_TOKEN_RETURN = 0.1 +_TIME = time # Added so synchro script doesn't remove the time import. async def _backoff( @@ -89,23 +93,95 @@ async def _backoff( await asyncio.sleep(backoff) +class _TokenBucket: + """A token bucket implementation for rate limiting.""" + + def __init__( + self, + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, + ): + self.lock = _async_create_lock() + self.capacity = capacity + # DRIVERS-3240 will determine how full the bucket should start. + self.tokens = capacity + self.return_rate = return_rate + + async def consume(self) -> bool: + """Consume a token from the bucket if available.""" + async with self.lock: + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + async def deposit(self, retry: bool = False) -> None: + """Deposit a token back into the bucket.""" + retry_token = 1 if retry else 0 + async with self.lock: + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter. + + Retry attempts are limited by a token bucket to prevent overwhelming the server during + a prolonged outage or high load. + """ + + def __init__( + self, + token_bucket: _TokenBucket, + attempts: int = _MAX_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.token_bucket = token_bucket + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + async def record_success(self, retry: bool): + """Record a successful operation.""" + await self.token_bucket.deposit(retry) + + async def backoff(self, attempt: int) -> None: + """Return the backoff duration for the given .""" + await _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + async def should_retry(self, attempt: int) -> bool: + """Return if we have budget to retry and how long to backoff.""" + # TODO: Check CSOT deadline here. + if attempt > self.attempts: + return False + # Check token bucket last since we only want to consume a token if we actually retry. + if not await self.token_bucket.consume(): + # DRIVERS-3246 Improve diagnostics when this case happens. + # We could add info to the exception and log. + return False + return True + + def _retry_overload(func: F) -> F: @functools.wraps(func) - async def inner(*args: Any, **kwargs: Any) -> Any: + async def inner(self: Any, *args: Any, **kwargs: Any) -> Any: + retry_policy = self._retry_policy attempt = 0 while True: try: - return await func(*args, **kwargs) + res = await func(self, *args, **kwargs) + await retry_policy.record_success(retry=attempt > 0) + return res except PyMongoError as exc: if not exc.has_error_label("Retryable"): raise attempt += 1 - if attempt > _MAX_RETRIES: + if not await retry_policy.should_retry(attempt): raise # Implement exponential backoff on retry. if exc.has_error_label("SystemOverloaded"): - await _backoff(attempt) + await retry_policy.backoff(attempt) continue return cast(F, inner) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index ae6e819334..214159b35f 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -67,7 +67,11 @@ from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload +from pymongo.asynchronous.helpers import ( + _retry_overload, + _RetryPolicy, + _TokenBucket, +) from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions @@ -774,6 +778,7 @@ def __init__( self._timeout: float | None = None self._topology_settings: TopologySettings = None # type: ignore[assignment] self._event_listeners: _EventListeners | None = None + self._retry_policy = _RetryPolicy(_TokenBucket()) # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -2740,7 +2745,7 @@ def __init__( self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2775,7 +2780,9 @@ async def run(self) -> T: while True: self._check_last_error(check_csot=True) try: - return await self._read() if self._is_read else await self._write() + res = await self._read() if self._is_read else await self._write() + await self._retry_policy.record_success(self._attempt_number > 0) + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2846,13 +2853,13 @@ async def run(self) -> T: self._always_retryable = always_retryable if always_retryable: - if self._attempt_number > _MAX_RETRIES: + if not await self._retry_policy.should_retry(self._attempt_number): if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise if overloaded: - await _backoff(self._attempt_number) + await self._retry_policy.backoff(self._attempt_number) def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 3df867f7bc..324139d40a 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -256,6 +256,7 @@ def __init__( unicode_decode_error_handler="replace", document_class=dict ) self._timeout = database.client.options.timeout + self._retry_policy = database.client._retry_policy if create or kwargs: if _IS_SYNC: diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index d8b9ae6a10..62f8f69067 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -136,6 +136,7 @@ def __init__( self._name = name self._client: MongoClient[_DocumentType] = client self._timeout = client.options.timeout + self._retry_policy = client._retry_policy @property def client(self) -> MongoClient[_DocumentType]: diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 889382b19c..8462479f93 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -34,6 +34,7 @@ PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE +from pymongo.lock import _create_lock _IS_SYNC = True @@ -78,7 +79,10 @@ def inner(*args: Any, **kwargs: Any) -> Any: _MAX_RETRIES = 3 _BACKOFF_INITIAL = 0.05 _BACKOFF_MAX = 10 -_TIME = time +# DRIVERS-3240 will determine these defaults. +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 +DEFAULT_RETRY_TOKEN_RETURN = 0.1 +_TIME = time # Added so synchro script doesn't remove the time import. def _backoff( @@ -89,23 +93,95 @@ def _backoff( time.sleep(backoff) +class _TokenBucket: + """A token bucket implementation for rate limiting.""" + + def __init__( + self, + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, + ): + self.lock = _create_lock() + self.capacity = capacity + # DRIVERS-3240 will determine how full the bucket should start. + self.tokens = capacity + self.return_rate = return_rate + + def consume(self) -> bool: + """Consume a token from the bucket if available.""" + with self.lock: + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + def deposit(self, retry: bool = False) -> None: + """Deposit a token back into the bucket.""" + retry_token = 1 if retry else 0 + with self.lock: + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter. + + Retry attempts are limited by a token bucket to prevent overwhelming the server during + a prolonged outage or high load. + """ + + def __init__( + self, + token_bucket: _TokenBucket, + attempts: int = _MAX_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.token_bucket = token_bucket + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + def record_success(self, retry: bool): + """Record a successful operation.""" + self.token_bucket.deposit(retry) + + def backoff(self, attempt: int) -> None: + """Return the backoff duration for the given .""" + _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + def should_retry(self, attempt: int) -> bool: + """Return if we have budget to retry and how long to backoff.""" + # TODO: Check CSOT deadline here. + if attempt > self.attempts: + return False + # Check token bucket last since we only want to consume a token if we actually retry. + if not self.token_bucket.consume(): + # DRIVERS-3246 Improve diagnostics when this case happens. + # We could add info to the exception and log. + return False + return True + + def _retry_overload(func: F) -> F: @functools.wraps(func) - def inner(*args: Any, **kwargs: Any) -> Any: + def inner(self: Any, *args: Any, **kwargs: Any) -> Any: + retry_policy = self._retry_policy attempt = 0 while True: try: - return func(*args, **kwargs) + res = func(self, *args, **kwargs) + retry_policy.record_success(retry=attempt > 0) + return res except PyMongoError as exc: if not exc.has_error_label("Retryable"): raise attempt += 1 - if attempt > _MAX_RETRIES: + if not retry_policy.should_retry(attempt): raise # Implement exponential backoff on retry. if exc.has_error_label("SystemOverloaded"): - _backoff(attempt) + retry_policy.backoff(attempt) continue return cast(F, inner) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index dcd8c50cca..92c392154c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -110,7 +110,11 @@ from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload +from pymongo.synchronous.helpers import ( + _retry_overload, + _RetryPolicy, + _TokenBucket, +) from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -774,6 +778,7 @@ def __init__( self._timeout: float | None = None self._topology_settings: TopologySettings = None # type: ignore[assignment] self._event_listeners: _EventListeners | None = None + self._retry_policy = _RetryPolicy(_TokenBucket()) # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -2730,7 +2735,7 @@ def __init__( self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2765,7 +2770,9 @@ def run(self) -> T: while True: self._check_last_error(check_csot=True) try: - return self._read() if self._is_read else self._write() + res = self._read() if self._is_read else self._write() + self._retry_policy.record_success(self._attempt_number > 0) + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2836,13 +2843,13 @@ def run(self) -> T: self._always_retryable = always_retryable if always_retryable: - if self._attempt_number > _MAX_RETRIES: + if not self._retry_policy.should_retry(self._attempt_number): if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise if overloaded: - _backoff(self._attempt_number) + self._retry_policy.backoff(self._attempt_number) def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" diff --git a/test/asynchronous/test_backpressure.py b/test/asynchronous/test_backpressure.py index a9a6fb56f5..5e96bcb451 100644 --- a/test/asynchronous/test_backpressure.py +++ b/test/asynchronous/test_backpressure.py @@ -150,6 +150,28 @@ async def test_retry_overload_error_getMore(self): self.assertIn("Retryable", str(error.exception)) + @async_client_context.require_failCommand_appName + async def test_limit_retry_command(self): + client = await self.async_rs_or_single_client() + client._retry_policy.token_bucket.tokens = 1 + db = client.pymongo_test + await db.t.insert_one({"x": 1}) + + # Ensure command is retried once overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": 1} + async with self.fail_point(fail_many): + await db.command("find", "t") + + # Ensure command stops retrying when there are no tokens left. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": 2} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + if __name__ == "__main__": unittest.main() diff --git a/test/test_backpressure.py b/test/test_backpressure.py index 324dd6f15a..94ef8b0191 100644 --- a/test/test_backpressure.py +++ b/test/test_backpressure.py @@ -150,6 +150,28 @@ def test_retry_overload_error_getMore(self): self.assertIn("Retryable", str(error.exception)) + @client_context.require_failCommand_appName + def test_limit_retry_command(self): + client = self.rs_or_single_client() + client._retry_policy.token_bucket.tokens = 1 + db = client.pymongo_test + db.t.insert_one({"x": 1}) + + # Ensure command is retried once overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": 1} + with self.fail_point(fail_many): + db.command("find", "t") + + # Ensure command stops retrying when there are no tokens left. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": 2} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + if __name__ == "__main__": unittest.main() From 83798d11c4530a10b63734f80b010e7241369e39 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 21 Aug 2025 11:55:45 -0700 Subject: [PATCH 2/5] PYTHON-5506 Fix typing --- pymongo/asynchronous/helpers.py | 2 +- pymongo/synchronous/helpers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 7a3b4d5ec7..438a9db79a 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -141,7 +141,7 @@ def __init__( self.backoff_initial = backoff_initial self.backoff_max = backoff_max - async def record_success(self, retry: bool): + async def record_success(self, retry: bool) -> None: """Record a successful operation.""" await self.token_bucket.deposit(retry) diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 8462479f93..25f04b37ab 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -141,7 +141,7 @@ def __init__( self.backoff_initial = backoff_initial self.backoff_max = backoff_max - def record_success(self, retry: bool): + def record_success(self, retry: bool) -> None: """Record a successful operation.""" self.token_bucket.deposit(retry) From 6cdf7cbbaa5819605358f4dd90a52dce30062769 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 21 Aug 2025 13:31:26 -0700 Subject: [PATCH 3/5] PYTHON-5506 Add unittest for _RetryPolicy/_TokenBucket --- test/asynchronous/test_backpressure.py | 46 ++++++++++++++++++++++++-- test/test_backpressure.py | 44 ++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/test/asynchronous/test_backpressure.py b/test/asynchronous/test_backpressure.py index 5e96bcb451..de7a5795be 100644 --- a/test/asynchronous/test_backpressure.py +++ b/test/asynchronous/test_backpressure.py @@ -19,9 +19,15 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest - -from pymongo.asynchronous.helpers import _MAX_RETRIES +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + async_client_context, + unittest, +) + +from pymongo.asynchronous import helpers +from pymongo.asynchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket from pymongo.errors import PyMongoError _IS_SYNC = False @@ -173,5 +179,39 @@ async def test_limit_retry_command(self): self.assertIn("Retryable", str(error.exception)) +class TestRetryPolicy(AsyncPyMongoTestCase): + async def test_retry_policy(self): + capacity = 10 + retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity)) + self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES) + self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) + self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) + for i in range(1, helpers._MAX_RETRIES + 1): + self.assertTrue(await retry_policy.should_retry(i)) + self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1)) + for i in range(capacity - helpers._MAX_RETRIES): + self.assertTrue(await retry_policy.should_retry(1)) + # No tokens left, should not retry. + self.assertFalse(await retry_policy.should_retry(1)) + self.assertEqual(retry_policy.token_bucket.tokens, 0) + + # record_success should generate tokens. + for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): + await retry_policy.record_success(retry=False) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) + for i in range(2): + self.assertTrue(await retry_policy.should_retry(1)) + self.assertFalse(await retry_policy.should_retry(1)) + + # Recording a successful retry should return 1 additional token. + await retry_policy.record_success(retry=True) + self.assertAlmostEqual( + retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN + ) + self.assertTrue(await retry_policy.should_retry(1)) + self.assertFalse(await retry_policy.should_retry(1)) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_backpressure.py b/test/test_backpressure.py index 94ef8b0191..3403542ed0 100644 --- a/test/test_backpressure.py +++ b/test/test_backpressure.py @@ -19,10 +19,16 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + client_context, + unittest, +) from pymongo.errors import PyMongoError -from pymongo.synchronous.helpers import _MAX_RETRIES +from pymongo.synchronous import helpers +from pymongo.synchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket _IS_SYNC = True @@ -173,5 +179,39 @@ def test_limit_retry_command(self): self.assertIn("Retryable", str(error.exception)) +class TestRetryPolicy(PyMongoTestCase): + def test_retry_policy(self): + capacity = 10 + retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity)) + self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES) + self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) + self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) + for i in range(1, helpers._MAX_RETRIES + 1): + self.assertTrue(retry_policy.should_retry(i)) + self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1)) + for i in range(capacity - helpers._MAX_RETRIES): + self.assertTrue(retry_policy.should_retry(1)) + # No tokens left, should not retry. + self.assertFalse(retry_policy.should_retry(1)) + self.assertEqual(retry_policy.token_bucket.tokens, 0) + + # record_success should generate tokens. + for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): + retry_policy.record_success(retry=False) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) + for i in range(2): + self.assertTrue(retry_policy.should_retry(1)) + self.assertFalse(retry_policy.should_retry(1)) + + # Recording a successful retry should return 1 additional token. + retry_policy.record_success(retry=True) + self.assertAlmostEqual( + retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN + ) + self.assertTrue(retry_policy.should_retry(1)) + self.assertFalse(retry_policy.should_retry(1)) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + + if __name__ == "__main__": unittest.main() From 752d8c0f9f113fa7a6351ab4b85c5f1ed2477443 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 21 Aug 2025 16:55:45 -0700 Subject: [PATCH 4/5] PYTHON-5506 Check CSOT deadline before consuming a token --- pymongo/asynchronous/helpers.py | 30 ++++++++++++++++---------- pymongo/asynchronous/mongo_client.py | 8 +++++-- pymongo/synchronous/helpers.py | 28 +++++++++++++++--------- pymongo/synchronous/mongo_client.py | 8 +++++-- test/asynchronous/test_backpressure.py | 29 ++++++++++++++++++------- test/test_backpressure.py | 29 ++++++++++++++++++------- 6 files changed, 91 insertions(+), 41 deletions(-) diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 438a9db79a..477c8f902b 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -29,6 +29,7 @@ cast, ) +from pymongo import _csot from pymongo.errors import ( OperationFailure, PyMongoError, @@ -85,12 +86,11 @@ async def inner(*args: Any, **kwargs: Any) -> Any: _TIME = time # Added so synchro script doesn't remove the time import. -async def _backoff( +def _backoff( attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX -) -> None: +) -> float: jitter = random.random() # noqa: S311 - backoff = jitter * min(initial_delay * (2**attempt), max_delay) - await asyncio.sleep(backoff) + return jitter * min(initial_delay * (2**attempt), max_delay) class _TokenBucket: @@ -145,15 +145,20 @@ async def record_success(self, retry: bool) -> None: """Record a successful operation.""" await self.token_bucket.deposit(retry) - async def backoff(self, attempt: int) -> None: + def backoff(self, attempt: int) -> float: """Return the backoff duration for the given .""" - await _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) - async def should_retry(self, attempt: int) -> bool: + async def should_retry(self, attempt: int, delay: float) -> bool: """Return if we have budget to retry and how long to backoff.""" - # TODO: Check CSOT deadline here. if attempt > self.attempts: return False + + # If the delay would exceed the deadline, bail early before consuming a token. + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + # Check token bucket last since we only want to consume a token if we actually retry. if not await self.token_bucket.consume(): # DRIVERS-3246 Improve diagnostics when this case happens. @@ -176,12 +181,15 @@ async def inner(self: Any, *args: Any, **kwargs: Any) -> Any: if not exc.has_error_label("Retryable"): raise attempt += 1 - if not await retry_policy.should_retry(attempt): + delay = 0 + if exc.has_error_label("SystemOverloaded"): + delay = retry_policy.backoff(attempt) + if not await retry_policy.should_retry(attempt, delay): raise # Implement exponential backoff on retry. - if exc.has_error_label("SystemOverloaded"): - await retry_policy.backoff(attempt) + if delay: + await asyncio.sleep(delay) continue return cast(F, inner) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 214159b35f..47408f26db 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -35,6 +35,7 @@ import asyncio import contextlib import os +import time import warnings import weakref from collections import defaultdict @@ -174,6 +175,8 @@ UpdateMany, ] +_TIME = time # Added so synchro script doesn't remove the time import. + class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): HOST = "localhost" @@ -2853,13 +2856,14 @@ async def run(self) -> T: self._always_retryable = always_retryable if always_retryable: - if not await self._retry_policy.should_retry(self._attempt_number): + delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0 + if not await self._retry_policy.should_retry(self._attempt_number, delay): if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise if overloaded: - await self._retry_policy.backoff(self._attempt_number) + await asyncio.sleep(delay) def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 25f04b37ab..52bbbbbd6b 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -29,6 +29,7 @@ cast, ) +from pymongo import _csot from pymongo.errors import ( OperationFailure, PyMongoError, @@ -87,10 +88,9 @@ def inner(*args: Any, **kwargs: Any) -> Any: def _backoff( attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX -) -> None: +) -> float: jitter = random.random() # noqa: S311 - backoff = jitter * min(initial_delay * (2**attempt), max_delay) - time.sleep(backoff) + return jitter * min(initial_delay * (2**attempt), max_delay) class _TokenBucket: @@ -145,15 +145,20 @@ def record_success(self, retry: bool) -> None: """Record a successful operation.""" self.token_bucket.deposit(retry) - def backoff(self, attempt: int) -> None: + def backoff(self, attempt: int) -> float: """Return the backoff duration for the given .""" - _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) - def should_retry(self, attempt: int) -> bool: + def should_retry(self, attempt: int, delay: float) -> bool: """Return if we have budget to retry and how long to backoff.""" - # TODO: Check CSOT deadline here. if attempt > self.attempts: return False + + # If the delay would exceed the deadline, bail early before consuming a token. + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + # Check token bucket last since we only want to consume a token if we actually retry. if not self.token_bucket.consume(): # DRIVERS-3246 Improve diagnostics when this case happens. @@ -176,12 +181,15 @@ def inner(self: Any, *args: Any, **kwargs: Any) -> Any: if not exc.has_error_label("Retryable"): raise attempt += 1 - if not retry_policy.should_retry(attempt): + delay = 0 + if exc.has_error_label("SystemOverloaded"): + delay = retry_policy.backoff(attempt) + if not retry_policy.should_retry(attempt, delay): raise # Implement exponential backoff on retry. - if exc.has_error_label("SystemOverloaded"): - retry_policy.backoff(attempt) + if delay: + time.sleep(delay) continue return cast(F, inner) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 92c392154c..b4116375c8 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -35,6 +35,7 @@ import asyncio import contextlib import os +import time import warnings import weakref from collections import defaultdict @@ -171,6 +172,8 @@ UpdateMany, ] +_TIME = time # Added so synchro script doesn't remove the time import. + class MongoClient(common.BaseObject, Generic[_DocumentType]): HOST = "localhost" @@ -2843,13 +2846,14 @@ def run(self) -> T: self._always_retryable = always_retryable if always_retryable: - if not self._retry_policy.should_retry(self._attempt_number): + delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0 + if not self._retry_policy.should_retry(self._attempt_number, delay): if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise if overloaded: - self._retry_policy.backoff(self._attempt_number) + time.sleep(delay) def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" diff --git a/test/asynchronous/test_backpressure.py b/test/asynchronous/test_backpressure.py index de7a5795be..598236dbfe 100644 --- a/test/asynchronous/test_backpressure.py +++ b/test/asynchronous/test_backpressure.py @@ -15,8 +15,11 @@ """Test Client Backpressure spec.""" from __future__ import annotations +import asyncio import sys +import pymongo + sys.path[0:0] = [""] from test.asynchronous import ( @@ -187,12 +190,12 @@ async def test_retry_policy(self): self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) for i in range(1, helpers._MAX_RETRIES + 1): - self.assertTrue(await retry_policy.should_retry(i)) - self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1)) + self.assertTrue(await retry_policy.should_retry(i, 0)) + self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) for i in range(capacity - helpers._MAX_RETRIES): - self.assertTrue(await retry_policy.should_retry(1)) + self.assertTrue(await retry_policy.should_retry(1, 0)) # No tokens left, should not retry. - self.assertFalse(await retry_policy.should_retry(1)) + self.assertFalse(await retry_policy.should_retry(1, 0)) self.assertEqual(retry_policy.token_bucket.tokens, 0) # record_success should generate tokens. @@ -200,18 +203,28 @@ async def test_retry_policy(self): await retry_policy.record_success(retry=False) self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) for i in range(2): - self.assertTrue(await retry_policy.should_retry(1)) - self.assertFalse(await retry_policy.should_retry(1)) + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertFalse(await retry_policy.should_retry(1, 0)) # Recording a successful retry should return 1 additional token. await retry_policy.record_success(retry=True) self.assertAlmostEqual( retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN ) - self.assertTrue(await retry_policy.should_retry(1)) - self.assertFalse(await retry_policy.should_retry(1)) + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertFalse(await retry_policy.should_retry(1, 0)) self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + async def test_retry_policy_csot(self): + retry_policy = _RetryPolicy(_TokenBucket()) + self.assertTrue(await retry_policy.should_retry(1, 0.5)) + with pymongo.timeout(0.5): + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertTrue(await retry_policy.should_retry(1, 0.1)) + # Would exceed the timeout, should not retry. + self.assertFalse(await retry_policy.should_retry(1, 1.0)) + self.assertTrue(await retry_policy.should_retry(1, 1.0)) + if __name__ == "__main__": unittest.main() diff --git a/test/test_backpressure.py b/test/test_backpressure.py index 3403542ed0..182ce424a9 100644 --- a/test/test_backpressure.py +++ b/test/test_backpressure.py @@ -15,8 +15,11 @@ """Test Client Backpressure spec.""" from __future__ import annotations +import asyncio import sys +import pymongo + sys.path[0:0] = [""] from test import ( @@ -187,12 +190,12 @@ def test_retry_policy(self): self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) for i in range(1, helpers._MAX_RETRIES + 1): - self.assertTrue(retry_policy.should_retry(i)) - self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1)) + self.assertTrue(retry_policy.should_retry(i, 0)) + self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) for i in range(capacity - helpers._MAX_RETRIES): - self.assertTrue(retry_policy.should_retry(1)) + self.assertTrue(retry_policy.should_retry(1, 0)) # No tokens left, should not retry. - self.assertFalse(retry_policy.should_retry(1)) + self.assertFalse(retry_policy.should_retry(1, 0)) self.assertEqual(retry_policy.token_bucket.tokens, 0) # record_success should generate tokens. @@ -200,18 +203,28 @@ def test_retry_policy(self): retry_policy.record_success(retry=False) self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) for i in range(2): - self.assertTrue(retry_policy.should_retry(1)) - self.assertFalse(retry_policy.should_retry(1)) + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertFalse(retry_policy.should_retry(1, 0)) # Recording a successful retry should return 1 additional token. retry_policy.record_success(retry=True) self.assertAlmostEqual( retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN ) - self.assertTrue(retry_policy.should_retry(1)) - self.assertFalse(retry_policy.should_retry(1)) + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertFalse(retry_policy.should_retry(1, 0)) self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + def test_retry_policy_csot(self): + retry_policy = _RetryPolicy(_TokenBucket()) + self.assertTrue(retry_policy.should_retry(1, 0.5)) + with pymongo.timeout(0.5): + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertTrue(retry_policy.should_retry(1, 0.1)) + # Would exceed the timeout, should not retry. + self.assertFalse(retry_policy.should_retry(1, 1.0)) + self.assertTrue(retry_policy.should_retry(1, 1.0)) + if __name__ == "__main__": unittest.main() From 949f3517497517c43bd518daed3cf3e2535017d5 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Fri, 22 Aug 2025 09:48:44 -0700 Subject: [PATCH 5/5] PYTHON-5506 Cleaner time import --- pymongo/asynchronous/helpers.py | 3 +-- pymongo/asynchronous/mongo_client.py | 4 +--- pymongo/synchronous/helpers.py | 3 +-- pymongo/synchronous/mongo_client.py | 4 +--- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 477c8f902b..6ef3beacf5 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -21,7 +21,7 @@ import random import socket import sys -import time +import time as time # noqa: PLC0414 # needed in sync version from typing import ( Any, Callable, @@ -83,7 +83,6 @@ async def inner(*args: Any, **kwargs: Any) -> Any: # DRIVERS-3240 will determine these defaults. DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 DEFAULT_RETRY_TOKEN_RETURN = 0.1 -_TIME = time # Added so synchro script doesn't remove the time import. def _backoff( diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 47408f26db..d9994e9902 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -35,7 +35,7 @@ import asyncio import contextlib import os -import time +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -175,8 +175,6 @@ UpdateMany, ] -_TIME = time # Added so synchro script doesn't remove the time import. - class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): HOST = "localhost" diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 52bbbbbd6b..0a2cd71062 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -21,7 +21,7 @@ import random import socket import sys -import time +import time as time # noqa: PLC0414 # needed in sync version from typing import ( Any, Callable, @@ -83,7 +83,6 @@ def inner(*args: Any, **kwargs: Any) -> Any: # DRIVERS-3240 will determine these defaults. DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 DEFAULT_RETRY_TOKEN_RETURN = 0.1 -_TIME = time # Added so synchro script doesn't remove the time import. def _backoff( diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index b4116375c8..9beda807ef 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -35,7 +35,7 @@ import asyncio import contextlib import os -import time +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -172,8 +172,6 @@ UpdateMany, ] -_TIME = time # Added so synchro script doesn't remove the time import. - class MongoClient(common.BaseObject, Generic[_DocumentType]): HOST = "localhost"