Skip to content

Commit a477453

Browse files
authored
Fix: failed requests due to race conditions in the job queue vs. job progress (#376)

* fix: `JobsProgress` is now asyncio-safe. This prevents race conditions when `job_progress.get_job_count()` was checked before getting more jobs.
* fix: strict job count for evaluating whether new jobs can be taken: `jobs_needed = concurrency - queue - in progress`
* debug: better debug logs
* improved unit tests
1 parent 85a402c commit a477453

File tree

14 files changed

+117
-89
lines changed

14 files changed

+117
-89
lines changed

.github/workflows/CD-publish_to_pypi.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ jobs:
1313

1414
steps:
1515
- uses: actions/checkout@v4
16-
- name: Set up Python 3.11.0
16+
- name: Set up Python 3.11.10
1717
uses: actions/setup-python@v5
1818
with:
19-
python-version: 3.11.0
19+
python-version: 3.11.10
2020

2121
- name: Install pypa/build
2222
run: >-

.github/workflows/CD-test_publish_to_pypi.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ jobs:
1414

1515
steps:
1616
- uses: actions/checkout@v4
17-
- name: Set up Python 3.11.0
17+
- name: Set up Python 3.11.10
1818
uses: actions/setup-python@v5
1919
with:
20-
python-version: 3.11.0
20+
python-version: 3.11.10
2121

2222
- name: Install pypa/build
2323
run: >-

.github/workflows/CI-pytests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
run_tests:
1616
strategy:
1717
matrix:
18-
python-version: [3.8, 3.9, 3.10.12, 3.11.0]
18+
python-version: [3.8, 3.9, 3.10.15, 3.11.10]
1919
runs-on: ubuntu-latest
2020

2121
steps:

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ runpod = "runpod.cli.entry:runpod_cli"
5454
test = [
5555
"asynctest",
5656
"nest_asyncio",
57-
"pylint==3.3.1",
5857
"pytest-asyncio",
5958
"pytest-cov",
6059
"pytest-timeout",

runpod/serverless/modules/rp_fastapi.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,12 @@ async def _realtime(self, job: Job):
286286
Performs model inference on the input data using the provided handler.
287287
If handler is not provided, returns an error message.
288288
"""
289-
job_list.add(job.id)
289+
await job_list.add(job.id)
290290

291291
# Process the job using the provided handler, passing in the job input.
292292
job_results = await run_job(self.config["handler"], job.__dict__)
293293

294-
job_list.remove(job.id)
294+
await job_list.remove(job.id)
295295

296296
# Return the results of the job processing.
297297
return jsonable_encoder(job_results)
@@ -304,7 +304,7 @@ async def _realtime(self, job: Job):
304304
async def _sim_run(self, job_request: DefaultRequest) -> JobOutput:
305305
"""Development endpoint to simulate run behavior."""
306306
assigned_job_id = f"test-{uuid.uuid4()}"
307-
job_list.add({
307+
await job_list.add({
308308
"id": assigned_job_id,
309309
"input": job_request.input,
310310
"webhook": job_request.webhook
@@ -345,7 +345,7 @@ async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput:
345345
# ---------------------------------- stream ---------------------------------- #
346346
async def _sim_stream(self, job_id: str) -> StreamOutput:
347347
"""Development endpoint to simulate stream behavior."""
348-
stashed_job = job_list.get(job_id)
348+
stashed_job = await job_list.get(job_id)
349349
if stashed_job is None:
350350
return jsonable_encoder(
351351
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -367,7 +367,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
367367
}
368368
)
369369

370-
job_list.remove(job.id)
370+
await job_list.remove(job.id)
371371

372372
if stashed_job.webhook:
373373
thread = threading.Thread(
@@ -384,7 +384,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
384384
# ---------------------------------- status ---------------------------------- #
385385
async def _sim_status(self, job_id: str) -> JobOutput:
386386
"""Development endpoint to simulate status behavior."""
387-
stashed_job = job_list.get(job_id)
387+
stashed_job = await job_list.get(job_id)
388388
if stashed_job is None:
389389
return jsonable_encoder(
390390
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -400,7 +400,7 @@ async def _sim_status(self, job_id: str) -> JobOutput:
400400
else:
401401
job_output = await run_job(self.config["handler"], job.__dict__)
402402

403-
job_list.remove(job.id)
403+
await job_list.remove(job.id)
404404

405405
if job_output.get("error", None):
406406
return jsonable_encoder(

runpod/serverless/modules/rp_job.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import traceback
99
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, List
1010

11+
import aiohttp
12+
1113
from runpod.http_client import ClientSession, TooManyRequests
1214
from runpod.serverless.modules.rp_logger import RunPodLogger
1315

@@ -34,15 +36,18 @@ def _job_get_url(batch_size: int = 1):
3436
Returns:
3537
str: The prepared URL for the 'get' request to the serverless API.
3638
"""
37-
job_in_progress = "1" if job_progress.get_job_count() else "0"
3839

3940
if batch_size > 1:
4041
job_take_url = JOB_GET_URL.replace("/job-take/", "/job-take-batch/")
4142
job_take_url += f"&batch_size={batch_size}"
4243
else:
4344
job_take_url = JOB_GET_URL
4445

45-
return job_take_url + f"&job_in_progress={job_in_progress}"
46+
job_in_progress = "1" if job_progress.get_job_list() else "0"
47+
job_take_url += f"&job_in_progress={job_in_progress}"
48+
49+
log.debug(f"rp_job | get_job: {job_take_url}")
50+
return job_take_url
4651

4752

4853
async def get_job(
@@ -60,14 +65,14 @@ async def get_job(
6065
num_jobs (int): The number of jobs to get.
6166
"""
6267
async with session.get(_job_get_url(num_jobs)) as response:
63-
log.debug(f"- Response: {type(response).__name__} {response.status}")
68+
log.debug(f"rp_job | Response: {type(response).__name__} {response.status}")
6469

6570
if response.status == 204:
66-
log.debug("- No content, no job to process.")
71+
log.debug("rp_job | Received 204 status, no jobs.")
6772
return
6873

6974
if response.status == 400:
70-
log.debug("- Received 400 status, expected when FlashBoot is enabled.")
75+
log.debug("rp_job | Received 400 status, expected when FlashBoot is enabled.")
7176
return
7277

7378
if response.status == 429:
@@ -83,16 +88,23 @@ async def get_job(
8388

8489
# Verify if the content type is JSON
8590
if response.content_type != "application/json":
86-
log.debug(f"- Unexpected content type: {response.content_type}")
91+
log.debug(f"rp_job | Unexpected content type: {response.content_type}")
8792
return
8893

8994
# Check if there is a non-empty content to parse
9095
if response.content_length == 0:
91-
log.debug("- No content to parse.")
96+
log.debug("rp_job | No content to parse.")
9297
return
9398

94-
jobs = await response.json()
95-
log.debug(f"- Received Job(s)")
99+
try:
100+
jobs = await response.json()
101+
log.debug("rp_job | Received Job(s)")
102+
except aiohttp.ContentTypeError:
103+
log.debug(f"rp_job | Response content is not valid JSON. {response.content}")
104+
return
105+
except ValueError as json_error:
106+
log.debug(f"rp_job | Failed to parse JSON response: {json_error}")
107+
return
96108

97109
# legacy job-take API
98110
if isinstance(jobs, dict):

runpod/serverless/modules/rp_scale.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,20 @@ async def get_jobs(self, session: ClientSession):
114114
Adds jobs to the JobsQueue
115115
"""
116116
while self.is_alive():
117-
log.debug(f"JobScaler.get_jobs | Jobs in progress: {job_progress.get_job_count()}")
117+
log.debug("JobScaler.get_jobs | Starting job acquisition.")
118118

119119
self.current_concurrency = self.concurrency_modifier(
120120
self.current_concurrency
121121
)
122-
log.debug(f"JobScaler.get_jobs | Concurrency set to: {self.current_concurrency}")
122+
log.debug(f"JobScaler.get_jobs | current Concurrency set to: {self.current_concurrency}")
123123

124-
jobs_needed = self.current_concurrency - job_progress.get_job_count()
124+
current_progress_count = await job_progress.get_job_count()
125+
log.debug(f"JobScaler.get_jobs | current Jobs in progress: {current_progress_count}")
126+
127+
current_queue_count = job_list.get_job_count()
128+
log.debug(f"JobScaler.get_jobs | current Jobs in queue: {current_queue_count}")
129+
130+
jobs_needed = self.current_concurrency - current_progress_count - current_queue_count
125131
if jobs_needed <= 0:
126132
log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
127133
await asyncio.sleep(1) # don't go rapidly
@@ -197,10 +203,9 @@ async def handle_job(self, session: ClientSession, job: dict):
197203
"""
198204
Process an individual job. This function is run concurrently for multiple jobs.
199205
"""
200-
log.debug(f"JobScaler.handle_job | {job}")
201-
job_progress.add(job)
202-
203206
try:
207+
await job_progress.add(job)
208+
204209
await handle_job(session, self.config, job)
205210

206211
if self.config.get("refresh_worker", False):
@@ -215,4 +220,4 @@ async def handle_job(self, session: ClientSession, job: dict):
215220
job_list.task_done()
216221

217222
# Job is no longer in progress
218-
job_progress.remove(job["id"])
223+
await job_progress.remove(job["id"])

runpod/serverless/modules/worker_state.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import time
77
import uuid
88
from typing import Any, Dict, Optional
9-
from asyncio import Queue
9+
from asyncio import Queue, Lock
1010

1111
from .rp_logger import RunPodLogger
1212

@@ -72,10 +72,19 @@ def __new__(cls):
7272
JobsProgress._instance = set.__new__(cls)
7373
return JobsProgress._instance
7474

75+
def __init__(self):
76+
if not hasattr(self, "_lock"):
77+
# Initialize the lock once
78+
self._lock = Lock()
79+
7580
def __repr__(self) -> str:
7681
return f"<{self.__class__.__name__}>: {self.get_job_list()}"
7782

78-
def add(self, element: Any):
83+
async def clear(self) -> None:
84+
async with self._lock:
85+
return super().clear()
86+
87+
async def add(self, element: Any):
7988
"""
8089
Adds a Job object to the set.
8190
@@ -92,16 +101,17 @@ def add(self, element: Any):
92101
if not isinstance(element, Job):
93102
raise TypeError("Only Job objects can be added to JobsProgress.")
94103

95-
log.debug(f"JobsProgress.add | {element}")
96-
return super().add(element)
104+
async with self._lock:
105+
log.debug(f"JobsProgress.add", element.id)
106+
super().add(element)
97107

98-
def remove(self, element: Any):
108+
async def remove(self, element: Any):
99109
"""
100-
Adds a Job object to the set.
110+
Removes a Job object from the set.
101111
102-
If the added element is a string, then `Job(id=element)` is added
112+
If the element is a string, then `Job(id=element)` is removed
103113
104-
If the added element is a dict, that `Job(**element)` is added
114+
If the element is a dict, then `Job(**element)` is removed
105115
"""
106116
if isinstance(element, str):
107117
element = Job(id=element)
@@ -112,34 +122,37 @@ def remove(self, element: Any):
112122
if not isinstance(element, Job):
113123
raise TypeError("Only Job objects can be removed from JobsProgress.")
114124

115-
log.debug(f"JobsProgress.remove | {element}")
116-
return super().remove(element)
125+
async with self._lock:
126+
log.debug(f"JobsProgress.remove", element.id)
127+
return super().discard(element)
117128

118-
def get(self, element: Any) -> Job:
129+
async def get(self, element: Any) -> Job:
119130
if isinstance(element, str):
120131
element = Job(id=element)
121132

122133
if not isinstance(element, Job):
123134
raise TypeError("Only Job objects can be retrieved from JobsProgress.")
124135

125-
for job in self:
126-
if job == element:
127-
return job
136+
async with self._lock:
137+
for job in self:
138+
if job == element:
139+
return job
128140

129141
def get_job_list(self) -> str:
130142
"""
131143
Returns the list of job IDs as comma-separated string.
132144
"""
133-
if not self.get_job_count():
145+
if not len(self):
134146
return None
135147

136148
return ",".join(str(job) for job in self)
137149

138-
def get_job_count(self) -> int:
150+
async def get_job_count(self) -> int:
139151
"""
140-
Returns the number of jobs.
152+
Returns the number of jobs in a thread-safe manner.
141153
"""
142-
return len(self)
154+
async with self._lock:
155+
return len(self)
143156

144157

145158
class JobsQueue(Queue):
@@ -162,7 +175,7 @@ async def add_job(self, job: dict):
162175
If the queue is full, wait until a free
163176
slot is available before adding item.
164177
"""
165-
log.debug(f"JobsQueue.add_job | {job}")
178+
log.debug(f"JobsQueue.add_job", job["id"])
166179
return await self.put(job)
167180

168181
async def get_job(self) -> dict:

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
"test": [
1818
"asynctest",
1919
"nest_asyncio",
20-
"pylint==3.3.1",
2120
"pytest",
2221
"pytest-cov",
2322
"pytest-timeout",

tests/test_serverless/test_modules/test_http.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ def __init__(self, *args, **kwargs):
3030
class TestHTTP(unittest.IsolatedAsyncioTestCase):
3131
"""Test HTTP module."""
3232

33-
def setUp(self) -> None:
33+
async def asyncSetUp(self) -> None:
3434
self.job = {"id": "test_id"}
3535
self.job_data = {"output": "test_output"}
3636

37-
def tearDown(self) -> None:
37+
async def asyncTearDown(self) -> None:
3838
gc.collect()
3939

4040
async def test_send_result(self):

0 commit comments

Comments (0)