Skip to content

Commit 0c8c4ec

Browse files
committed
Fix atomic write on Windows
1 parent 2cb04c5 commit 0c8c4ec

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

src/crawlee/_utils/file.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import csv
55
import json
66
import os
7+
import sys
78
import tempfile
89
from pathlib import Path
910
from typing import TYPE_CHECKING, overload
@@ -95,14 +96,21 @@ async def atomic_write(
9596
"""
9697
max_retries = 3
9798
dir_path = path.parent
99+
tmp_path: str | None = None
98100

99-
def _sync_write() -> str:
100-
# Create a tmp file in the target dir, return its name.
101+
def _write_windows() -> None:
102+
if is_binary:
103+
path.write_bytes(data) # type: ignore[arg-type]
104+
else:
105+
path.write_text(data, encoding='utf-8') # type: ignore[arg-type]
106+
107+
def _write_linux() -> str:
101108
fd, tmp_path = tempfile.mkstemp(
102109
suffix=f'{path.suffix}.tmp',
103110
prefix=f'{path.name}.',
104111
dir=str(dir_path),
105112
)
113+
106114
try:
107115
if is_binary:
108116
with os.fdopen(fd, 'wb') as tmp_file:
@@ -116,11 +124,17 @@ def _sync_write() -> str:
116124
return tmp_path
117125

118126
try:
119-
tmp_path = await asyncio.to_thread(_sync_write)
120-
await asyncio.to_thread(os.replace, tmp_path, str(path))
127+
# Windows and all other platforms need different strategies: Windows raises permission errors
128+
# when working with temporary files, so there the data is written directly to the target path.
129+
if sys.platform == 'win32':
130+
await asyncio.to_thread(_write_windows)
131+
else:
132+
tmp_path = await asyncio.to_thread(_write_linux)
133+
await asyncio.to_thread(os.replace, tmp_path, str(path))
121134
except (FileNotFoundError, PermissionError):
122135
if retry_count < max_retries:
123-
await asyncio.to_thread(Path(tmp_path).unlink, missing_ok=True)
136+
if tmp_path is not None:
137+
await asyncio.to_thread(Path(tmp_path).unlink, missing_ok=True)
124138
return await atomic_write(
125139
path,
126140
data,
@@ -131,7 +145,8 @@ def _sync_write() -> str:
131145
raise
132146

133147
finally:
134-
await asyncio.to_thread(Path(tmp_path).unlink, missing_ok=True)
148+
if tmp_path is not None:
149+
await asyncio.to_thread(Path(tmp_path).unlink, missing_ok=True)
135150

136151

137152
async def export_json_to_stream(

0 commit comments

Comments
 (0)