6
6
import time
7
7
import uuid
8
8
from typing import Any , Dict , Optional
9
- from asyncio import Queue
9
+ from asyncio import Queue , Lock
10
10
11
11
from .rp_logger import RunPodLogger
12
12
@@ -72,10 +72,19 @@ def __new__(cls):
72
72
JobsProgress ._instance = set .__new__ (cls )
73
73
return JobsProgress ._instance
74
74
75
+ def __init__ (self ):
76
+ if not hasattr (self , "_lock" ):
77
+ # Initialize the lock once
78
+ self ._lock = Lock ()
79
+
75
80
def __repr__ (self ) -> str :
76
81
return f"<{ self .__class__ .__name__ } >: { self .get_job_list ()} "
77
82
78
- def add (self , element : Any ):
83
+ async def clear (self ) -> None :
84
+ async with self ._lock :
85
+ return super ().clear ()
86
+
87
+ async def add (self , element : Any ):
79
88
"""
80
89
Adds a Job object to the set.
81
90
@@ -92,16 +101,17 @@ def add(self, element: Any):
92
101
if not isinstance (element , Job ):
93
102
raise TypeError ("Only Job objects can be added to JobsProgress." )
94
103
95
- log .debug (f"JobsProgress.add | { element } " )
96
- return super ().add (element )
104
+ async with self ._lock :
105
+ log .debug (f"JobsProgress.add" , element .id )
106
+ super ().add (element )
97
107
98
- def remove (self , element : Any ):
108
+ async def remove (self , element : Any ):
99
109
"""
100
- Adds a Job object to the set.
110
+ Removes a Job object from the set.
101
111
102
- If the added element is a string, then `Job(id=element)` is added
112
+ If the element is a string, then `Job(id=element)` is removed
103
113
104
- If the added element is a dict, that `Job(**element)` is added
114
+ If the element is a dict, then `Job(**element)` is removed
105
115
"""
106
116
if isinstance (element , str ):
107
117
element = Job (id = element )
@@ -112,34 +122,37 @@ def remove(self, element: Any):
112
122
if not isinstance (element , Job ):
113
123
raise TypeError ("Only Job objects can be removed from JobsProgress." )
114
124
115
- log .debug (f"JobsProgress.remove | { element } " )
116
- return super ().remove (element )
125
+ async with self ._lock :
126
+ log .debug (f"JobsProgress.remove" , element .id )
127
+ return super ().discard (element )
117
128
118
- def get (self , element : Any ) -> Job :
129
+ async def get (self , element : Any ) -> Job :
119
130
if isinstance (element , str ):
120
131
element = Job (id = element )
121
132
122
133
if not isinstance (element , Job ):
123
134
raise TypeError ("Only Job objects can be retrieved from JobsProgress." )
124
135
125
- for job in self :
126
- if job == element :
127
- return job
136
+ async with self ._lock :
137
+ for job in self :
138
+ if job == element :
139
+ return job
128
140
129
141
def get_job_list (self ) -> str :
130
142
"""
131
143
Returns the list of job IDs as comma-separated string.
132
144
"""
133
- if not self . get_job_count ( ):
145
+ if not len ( self ):
134
146
return None
135
147
136
148
return "," .join (str (job ) for job in self )
137
149
138
- def get_job_count (self ) -> int :
150
+ async def get_job_count (self ) -> int :
139
151
"""
140
- Returns the number of jobs.
152
+ Returns the number of jobs in a thread-safe manner .
141
153
"""
142
- return len (self )
154
+ async with self ._lock :
155
+ return len (self )
143
156
144
157
145
158
class JobsQueue (Queue ):
@@ -162,7 +175,7 @@ async def add_job(self, job: dict):
162
175
If the queue is full, wait until a free
163
176
slot is available before adding item.
164
177
"""
165
- log .debug (f"JobsQueue.add_job | { job } " )
178
+ log .debug (f"JobsQueue.add_job" , job [ "id" ] )
166
179
return await self .put (job )
167
180
168
181
async def get_job (self ) -> dict :
0 commit comments