 from __future__ import annotations
 
 import asyncio
-from collections.abc import AsyncIterable, AsyncIterator, Iterable
-from typing import TYPE_CHECKING
+import contextlib
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable
+from logging import getLogger
+from typing import Annotated
 
+from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import override
 
+from crawlee._request import Request
 from crawlee._utils.docs import docs_group
 from crawlee.request_loaders._request_loader import RequestLoader
 
-if TYPE_CHECKING:
-    from crawlee._request import Request
+logger = getLogger(__name__)
+
+
+class RequestListState(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
+    next_index: Annotated[int, Field(alias='nextIndex')] = 0
+    next_unique_key: Annotated[str | None, Field(alias='nextUniqueKey')] = None
+    in_progress: Annotated[set[str], Field(alias='inProgress')] = set()
+
+
+class RequestListData(BaseModel):
+    requests: Annotated[list[Request], Field()]
 
 
 @docs_group('Request loaders')
 class RequestList(RequestLoader):
-    """Represents a (potentially very large) list of URLs to crawl.
-
-    Disclaimer: The `RequestList` class is in its early version and is not fully implemented. It is currently
-    intended mainly for testing purposes and small-scale projects. The current implementation is only in-memory
-    storage and is very limited. It will be (re)implemented in the future. For more details, see the GitHub issue:
-    https://github.com/apify/crawlee-python/issues/99. For production usage we recommend to use the `RequestQueue`.
-    """
+    """Represents a (potentially very large) list of URLs to crawl."""
 
     def __init__(
         self,
         requests: Iterable[str | Request] | AsyncIterable[str | Request] | None = None,
         name: str | None = None,
+        persist_state_key: str | None = None,
+        persist_requests_key: str | None = None,
     ) -> None:
         """Initialize a new instance.
 
         Args:
             requests: The request objects (or their string representations) to be added to the provider.
             name: A name of the request list.
+            persist_state_key: A key for persisting the progress information of the RequestList.
+                If you do not pass a key but pass a `name`, a key will be derived using the name.
+                Otherwise, state will not be persisted.
+            persist_requests_key: A key for persisting the request data loaded from the `requests` iterator.
+                If specified, the request data will be stored in the KeyValueStore to make sure that they don't change
+                over time. This is useful if the `requests` iterator pulls the data dynamically.
         """
+        from crawlee._utils.recoverable_state import RecoverableState  # noqa: PLC0415
+
         self._name = name
         self._handled_count = 0
         self._assumed_total_count = 0
 
-        self._in_progress = set[str]()
-        self._next: Request | None = None
+        self._next: tuple[Request | None, Request | None] = (None, None)
+
+        if persist_state_key is None and name is not None:
+            persist_state_key = f'SDK_REQUEST_LIST_STATE-{name}'
+
+        self._state = RecoverableState(
+            default_state=RequestListState(),
+            persistence_enabled=bool(persist_state_key),
+            persist_state_key=persist_state_key or '',
+            logger=logger,
+        )
+
+        self._persist_request_data = bool(persist_requests_key)
+
+        self._requests_data = RecoverableState(
+            default_state=RequestListData(requests=[]),
+            # With request data persistence enabled, a snapshot of the requests will be done on initialization
+            persistence_enabled='explicit_only' if self._persist_request_data else False,
+            persist_state_key=persist_requests_key or '',
+            logger=logger,
+        )
 
         if isinstance(requests, AsyncIterable):
             self._requests = requests.__aiter__()
@@ -50,6 +88,53 @@ def __init__(
 
         self._requests_lock: asyncio.Lock | None = None
 
+    async def _get_state(self) -> RequestListState:
+        # If state is already initialized, we are done
+        if self._state.is_initialized:
+            return self._state.current_value
+
+        # Initialize recoverable state
+        await self._state.initialize()
+        await self._requests_data.initialize()
+
+        # Initialize lock if necessary
+        if self._requests_lock is None:
+            self._requests_lock = asyncio.Lock()
+
+        # If the RequestList is configured to persist request data, ensure that a copy of request data is used
+        if self._persist_request_data:
+            async with self._requests_lock:
+                if not await self._requests_data.has_persisted_state():
+                    self._requests_data.current_value.requests = [
+                        request if isinstance(request, Request) else Request.from_url(request)
+                        async for request in self._requests
+                    ]
+                    await self._requests_data.persist_state()
+
+                self._requests = self._iterate_in_threadpool(
+                    self._requests_data.current_value.requests[self._state.current_value.next_index :]
+                )
+        # If not using persistent request data, advance the request iterator
+        else:
+            async with self._requests_lock:
+                for _ in range(self._state.current_value.next_index):
+                    with contextlib.suppress(StopAsyncIteration):
+                        await self._requests.__anext__()
+
+        # Check consistency of the stored state and the request iterator
+        if (unique_key_to_check := self._state.current_value.next_unique_key) is not None:
+            await self._ensure_next_request()
+
+            next_unique_key = self._next[0].unique_key if self._next[0] is not None else None
+            if next_unique_key != unique_key_to_check:
+                raise RuntimeError(
+                    f"""Mismatch at index {
+                        self._state.current_value.next_index
+                    } in persisted requests - Expected unique key `{unique_key_to_check}`, got `{next_unique_key}`"""
+                )
+
+        return self._state.current_value
+
     @property
     def name(self) -> str | None:
         return self._name
@@ -65,42 +150,62 @@ async def get_total_count(self) -> int:
     @override
     async def is_empty(self) -> bool:
         await self._ensure_next_request()
-        return self._next is None
+        return self._next[0] is None
 
     @override
     async def is_finished(self) -> bool:
-        return len(self._in_progress) == 0 and await self.is_empty()
+        state = await self._get_state()
+        return len(state.in_progress) == 0 and await self.is_empty()
 
     @override
     async def fetch_next_request(self) -> Request | None:
+        await self._get_state()
         await self._ensure_next_request()
 
-        if self._next is None:
+        if self._next[0] is None:
             return None
 
-        self._in_progress.add(self._next.id)
+        state = await self._get_state()
+        state.in_progress.add(self._next[0].id)
         self._assumed_total_count += 1
 
-        next_request = self._next
-        self._next = None
+        next_request = self._next[0]
+        if next_request is not None:
+            state.next_index += 1
+            state.next_unique_key = self._next[1].unique_key if self._next[1] is not None else None
+
+        self._next = (self._next[1], None)
+        await self._ensure_next_request()
 
         return next_request
 
     @override
     async def mark_request_as_handled(self, request: Request) -> None:
         self._handled_count += 1
-        self._in_progress.remove(request.id)
+        state = await self._get_state()
+        state.in_progress.remove(request.id)
 
     async def _ensure_next_request(self) -> None:
+        await self._get_state()
+
         if self._requests_lock is None:
             self._requests_lock = asyncio.Lock()
 
-        try:
-            async with self._requests_lock:
-                if self._next is None:
-                    self._next = self._transform_request(await self._requests.__anext__())
-        except StopAsyncIteration:
-            self._next = None
+        async with self._requests_lock:
+            if None in self._next:
+                if self._next[0] is None:
+                    to_enqueue = [item async for item in self._dequeue_requests(2)]
+                    self._next = (to_enqueue[0], to_enqueue[1])
+                else:
+                    to_enqueue = [item async for item in self._dequeue_requests(1)]
+                    self._next = (self._next[0], to_enqueue[0])
+
+    async def _dequeue_requests(self, count: int) -> AsyncGenerator[Request | None]:
+        for _ in range(count):
+            try:
+                yield self._transform_request(await self._requests.__anext__())
+            except StopAsyncIteration:  # noqa: PERF203
+                yield None
 
     async def _iterate_in_threadpool(self, iterable: Iterable[str | Request]) -> AsyncIterator[str | Request]:
         """Inspired by a function of the same name from encode/starlette."""