Skip to content

Commit cc68014

Browse files
authored
feat: Persist RequestList state (#1274)
- closes #99
1 parent 7d928fd commit cc68014

File tree

5 files changed

+357
-46
lines changed

5 files changed

+357
-46
lines changed

src/crawlee/_utils/recoverable_state.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Generic, TypeVar
3+
from typing import TYPE_CHECKING, Generic, Literal, TypeVar
44

55
from pydantic import BaseModel
66

@@ -34,7 +34,7 @@ def __init__(
3434
*,
3535
default_state: TStateModel,
3636
persist_state_key: str,
37-
persistence_enabled: bool = False,
37+
persistence_enabled: Literal[True, False, 'explicit_only'] = False,
3838
persist_state_kvs_name: str | None = None,
3939
persist_state_kvs_id: str | None = None,
4040
logger: logging.Logger,
@@ -43,13 +43,14 @@ def __init__(
4343
4444
Args:
4545
default_state: The default state model instance to use when no persisted state is found.
46-
A deep copy is made each time the state is used.
46+
A deep copy is made each time the state is used.
4747
persist_state_key: The key under which the state is stored in the KeyValueStore
48-
persistence_enabled: Flag to enable or disable state persistence
48+
persistence_enabled: Flag to enable or disable state persistence. Use 'explicit_only' if you want to be able
49+
to save the state manually, but without any automatic persistence.
4950
persist_state_kvs_name: The name of the KeyValueStore to use for persistence.
50-
If neither a name nor and id are supplied, the default store will be used.
51+
If neither a name nor an id are supplied, the default store will be used.
5152
persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
52-
If neither a name nor and id are supplied, the default store will be used.
53+
If neither a name nor an id are supplied, the default store will be used.
5354
logger: A logger instance for logging operations related to state persistence
5455
"""
5556
self._default_state = default_state
@@ -71,7 +72,7 @@ async def initialize(self) -> TStateModel:
7172
Returns:
7273
The loaded state model
7374
"""
74-
if not self._persistence_enabled:
75+
if self._persistence_enabled is False:
7576
self._state = self._default_state.model_copy(deep=True)
7677
return self.current_value
7778

@@ -84,11 +85,12 @@ async def initialize(self) -> TStateModel:
8485

8586
await self._load_saved_state()
8687

87-
# Import here to avoid circular imports.
88-
from crawlee import service_locator # noqa: PLC0415
88+
if self._persistence_enabled is True:
89+
# Import here to avoid circular imports.
90+
from crawlee import service_locator # noqa: PLC0415
8991

90-
event_manager = service_locator.get_event_manager()
91-
event_manager.on(event=Event.PERSIST_STATE, listener=self.persist_state)
92+
event_manager = service_locator.get_event_manager()
93+
event_manager.on(event=Event.PERSIST_STATE, listener=self.persist_state)
9294

9395
return self.current_value
9496

@@ -101,12 +103,13 @@ async def teardown(self) -> None:
101103
if not self._persistence_enabled:
102104
return
103105

104-
# Import here to avoid circular imports.
105-
from crawlee import service_locator # noqa: PLC0415
106+
if self._persistence_enabled is True:
107+
# Import here to avoid circular imports.
108+
from crawlee import service_locator # noqa: PLC0415
106109

107-
event_manager = service_locator.get_event_manager()
108-
event_manager.off(event=Event.PERSIST_STATE, listener=self.persist_state)
109-
await self.persist_state()
110+
event_manager = service_locator.get_event_manager()
111+
event_manager.off(event=Event.PERSIST_STATE, listener=self.persist_state)
112+
await self.persist_state()
110113

111114
@property
112115
def current_value(self) -> TStateModel:
@@ -116,6 +119,21 @@ def current_value(self) -> TStateModel:
116119

117120
return self._state
118121

122+
@property
123+
def is_initialized(self) -> bool:
124+
"""Check if the state has already been initialized."""
125+
return self._state is not None
126+
127+
async def has_persisted_state(self) -> bool:
128+
"""Check if there is any persisted state in the key-value store."""
129+
if not self._persistence_enabled:
130+
return False
131+
132+
if self._key_value_store is None:
133+
raise RuntimeError('Recoverable state has not yet been initialized')
134+
135+
return await self._key_value_store.record_exists(self._persist_state_key)
136+
119137
async def reset(self) -> None:
120138
"""Reset the state to the default values and clear any persisted state.
121139
@@ -139,17 +157,21 @@ async def persist_state(self, event_data: EventPersistStateData | None = None) -
139157
Args:
140158
event_data: Optional data associated with a PERSIST_STATE event
141159
"""
142-
self._log.debug(f'Persisting state of the Statistics (event_data={event_data}).')
160+
self._log.debug(
161+
f'Persisting RecoverableState (model={self._default_state.__class__.__name__}, event_data={event_data}).'
162+
)
143163

144164
if self._key_value_store is None or self._state is None:
145165
raise RuntimeError('Recoverable state has not yet been initialized')
146166

147-
if self._persistence_enabled:
167+
if self._persistence_enabled is True or self._persistence_enabled == 'explicit_only':
148168
await self._key_value_store.set_value(
149169
self._persist_state_key,
150170
self._state.model_dump(mode='json', by_alias=True),
151171
'application/json',
152172
)
173+
else:
174+
self._log.debug('Persistence is not enabled - not doing anything')
153175

154176
async def _load_saved_state(self) -> None:
155177
if self._key_value_store is None:

src/crawlee/request_loaders/_request_list.py

Lines changed: 131 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,83 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from collections.abc import AsyncIterable, AsyncIterator, Iterable
5-
from typing import TYPE_CHECKING
4+
import contextlib
5+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable
6+
from logging import getLogger
7+
from typing import Annotated
68

9+
from pydantic import BaseModel, ConfigDict, Field
710
from typing_extensions import override
811

12+
from crawlee._request import Request
913
from crawlee._utils.docs import docs_group
1014
from crawlee.request_loaders._request_loader import RequestLoader
1115

12-
if TYPE_CHECKING:
13-
from crawlee._request import Request
16+
logger = getLogger(__name__)
17+
18+
19+
class RequestListState(BaseModel):
20+
model_config = ConfigDict(populate_by_name=True)
21+
22+
next_index: Annotated[int, Field(alias='nextIndex')] = 0
23+
next_unique_key: Annotated[str | None, Field(alias='nextUniqueKey')] = None
24+
in_progress: Annotated[set[str], Field(alias='inProgress')] = set()
25+
26+
27+
class RequestListData(BaseModel):
28+
requests: Annotated[list[Request], Field()]
1429

1530

1631
@docs_group('Request loaders')
1732
class RequestList(RequestLoader):
18-
"""Represents a (potentially very large) list of URLs to crawl.
19-
20-
Disclaimer: The `RequestList` class is in its early version and is not fully implemented. It is currently
21-
intended mainly for testing purposes and small-scale projects. The current implementation is only in-memory
22-
storage and is very limited. It will be (re)implemented in the future. For more details, see the GitHub issue:
23-
https://github.com/apify/crawlee-python/issues/99. For production usage we recommend to use the `RequestQueue`.
24-
"""
33+
"""Represents a (potentially very large) list of URLs to crawl."""
2534

2635
def __init__(
2736
self,
2837
requests: Iterable[str | Request] | AsyncIterable[str | Request] | None = None,
2938
name: str | None = None,
39+
persist_state_key: str | None = None,
40+
persist_requests_key: str | None = None,
3041
) -> None:
3142
"""Initialize a new instance.
3243
3344
Args:
3445
requests: The request objects (or their string representations) to be added to the provider.
3546
name: A name of the request list.
47+
persist_state_key: A key for persisting the progress information of the RequestList.
48+
If you do not pass a key but pass a `name`, a key will be derived using the name.
49+
Otherwise, state will not be persisted.
50+
persist_requests_key: A key for persisting the request data loaded from the `requests` iterator.
51+
If specified, the request data will be stored in the KeyValueStore to make sure that they don't change
52+
over time. This is useful if the `requests` iterator pulls the data dynamically.
3653
"""
54+
from crawlee._utils.recoverable_state import RecoverableState # noqa: PLC0415
55+
3756
self._name = name
3857
self._handled_count = 0
3958
self._assumed_total_count = 0
4059

41-
self._in_progress = set[str]()
42-
self._next: Request | None = None
60+
self._next: tuple[Request | None, Request | None] = (None, None)
61+
62+
if persist_state_key is None and name is not None:
63+
persist_state_key = f'SDK_REQUEST_LIST_STATE-{name}'
64+
65+
self._state = RecoverableState(
66+
default_state=RequestListState(),
67+
persistence_enabled=bool(persist_state_key),
68+
persist_state_key=persist_state_key or '',
69+
logger=logger,
70+
)
71+
72+
self._persist_request_data = bool(persist_requests_key)
73+
74+
self._requests_data = RecoverableState(
75+
default_state=RequestListData(requests=[]),
76+
# With request data persistence enabled, a snapshot of the requests will be done on initialization
77+
persistence_enabled='explicit_only' if self._persist_request_data else False,
78+
persist_state_key=persist_requests_key or '',
79+
logger=logger,
80+
)
4381

4482
if isinstance(requests, AsyncIterable):
4583
self._requests = requests.__aiter__()
@@ -50,6 +88,53 @@ def __init__(
5088

5189
self._requests_lock: asyncio.Lock | None = None
5290

91+
async def _get_state(self) -> RequestListState:
92+
# If state is already initialized, we are done
93+
if self._state.is_initialized:
94+
return self._state.current_value
95+
96+
# Initialize recoverable state
97+
await self._state.initialize()
98+
await self._requests_data.initialize()
99+
100+
# Initialize lock if necessary
101+
if self._requests_lock is None:
102+
self._requests_lock = asyncio.Lock()
103+
104+
# If the RequestList is configured to persist request data, ensure that a copy of request data is used
105+
if self._persist_request_data:
106+
async with self._requests_lock:
107+
if not await self._requests_data.has_persisted_state():
108+
self._requests_data.current_value.requests = [
109+
request if isinstance(request, Request) else Request.from_url(request)
110+
async for request in self._requests
111+
]
112+
await self._requests_data.persist_state()
113+
114+
self._requests = self._iterate_in_threadpool(
115+
self._requests_data.current_value.requests[self._state.current_value.next_index :]
116+
)
117+
# If not using persistent request data, advance the request iterator
118+
else:
119+
async with self._requests_lock:
120+
for _ in range(self._state.current_value.next_index):
121+
with contextlib.suppress(StopAsyncIteration):
122+
await self._requests.__anext__()
123+
124+
# Check consistency of the stored state and the request iterator
125+
if (unique_key_to_check := self._state.current_value.next_unique_key) is not None:
126+
await self._ensure_next_request()
127+
128+
next_unique_key = self._next[0].unique_key if self._next[0] is not None else None
129+
if next_unique_key != unique_key_to_check:
130+
raise RuntimeError(
131+
f"""Mismatch at index {
132+
self._state.current_value.next_index
133+
} in persisted requests - Expected unique key `{unique_key_to_check}`, got `{next_unique_key}`"""
134+
)
135+
136+
return self._state.current_value
137+
53138
@property
54139
def name(self) -> str | None:
55140
return self._name
@@ -65,42 +150,62 @@ async def get_total_count(self) -> int:
65150
@override
66151
async def is_empty(self) -> bool:
67152
await self._ensure_next_request()
68-
return self._next is None
153+
return self._next[0] is None
69154

70155
@override
71156
async def is_finished(self) -> bool:
72-
return len(self._in_progress) == 0 and await self.is_empty()
157+
state = await self._get_state()
158+
return len(state.in_progress) == 0 and await self.is_empty()
73159

74160
@override
75161
async def fetch_next_request(self) -> Request | None:
162+
await self._get_state()
76163
await self._ensure_next_request()
77164

78-
if self._next is None:
165+
if self._next[0] is None:
79166
return None
80167

81-
self._in_progress.add(self._next.id)
168+
state = await self._get_state()
169+
state.in_progress.add(self._next[0].id)
82170
self._assumed_total_count += 1
83171

84-
next_request = self._next
85-
self._next = None
172+
next_request = self._next[0]
173+
if next_request is not None:
174+
state.next_index += 1
175+
state.next_unique_key = self._next[1].unique_key if self._next[1] is not None else None
176+
177+
self._next = (self._next[1], None)
178+
await self._ensure_next_request()
86179

87180
return next_request
88181

89182
@override
90183
async def mark_request_as_handled(self, request: Request) -> None:
91184
self._handled_count += 1
92-
self._in_progress.remove(request.id)
185+
state = await self._get_state()
186+
state.in_progress.remove(request.id)
93187

94188
async def _ensure_next_request(self) -> None:
189+
await self._get_state()
190+
95191
if self._requests_lock is None:
96192
self._requests_lock = asyncio.Lock()
97193

98-
try:
99-
async with self._requests_lock:
100-
if self._next is None:
101-
self._next = self._transform_request(await self._requests.__anext__())
102-
except StopAsyncIteration:
103-
self._next = None
194+
async with self._requests_lock:
195+
if None in self._next:
196+
if self._next[0] is None:
197+
to_enqueue = [item async for item in self._dequeue_requests(2)]
198+
self._next = (to_enqueue[0], to_enqueue[1])
199+
else:
200+
to_enqueue = [item async for item in self._dequeue_requests(1)]
201+
self._next = (self._next[0], to_enqueue[0])
202+
203+
async def _dequeue_requests(self, count: int) -> AsyncGenerator[Request | None]:
204+
for _ in range(count):
205+
try:
206+
yield self._transform_request(await self._requests.__anext__())
207+
except StopAsyncIteration: # noqa: PERF203
208+
yield None
104209

105210
async def _iterate_in_threadpool(self, iterable: Iterable[str | Request]) -> AsyncIterator[str | Request]:
106211
"""Inspired by a function of the same name from encode/starlette."""

src/crawlee/statistics/_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class Statistics(Generic[TStatisticsState]):
6767
def __init__(
6868
self,
6969
*,
70-
persistence_enabled: bool = False,
70+
persistence_enabled: bool | Literal['explicit_only'] = False,
7171
persist_state_kvs_name: str | None = None,
7272
persist_state_key: str | None = None,
7373
log_message: str = 'Statistics',

0 commit comments

Comments
 (0)