4
4
import csv
5
5
import json
6
6
import os
7
+ import sys
7
8
import tempfile
8
9
from pathlib import Path
9
10
from typing import TYPE_CHECKING , overload
@@ -95,14 +96,21 @@ async def atomic_write(
95
96
"""
96
97
max_retries = 3
97
98
dir_path = path .parent
99
+ tmp_path : str | None = None
98
100
99
- def _sync_write () -> str :
100
- # Create a tmp file in the target dir, return its name.
101
+ def _write_windows () -> None :
102
+ if is_binary :
103
+ path .write_bytes (data ) # type: ignore[arg-type]
104
+ else :
105
+ path .write_text (data , encoding = 'utf-8' ) # type: ignore[arg-type]
106
+
107
+ def _write_linux () -> str :
101
108
fd , tmp_path = tempfile .mkstemp (
102
109
suffix = f'{ path .suffix } .tmp' ,
103
110
prefix = f'{ path .name } .' ,
104
111
dir = str (dir_path ),
105
112
)
113
+
106
114
try :
107
115
if is_binary :
108
116
with os .fdopen (fd , 'wb' ) as tmp_file :
@@ -116,11 +124,17 @@ def _sync_write() -> str:
116
124
return tmp_path
117
125
118
126
try :
119
- tmp_path = await asyncio .to_thread (_sync_write )
120
- await asyncio .to_thread (os .replace , tmp_path , str (path ))
127
+ # We have to differentiate between Windows and Linux due to the permissions errors
128
+ # in Windows when working with temporary files.
129
+ if sys .platform == 'win32' :
130
+ await asyncio .to_thread (_write_windows )
131
+ else :
132
+ tmp_path = await asyncio .to_thread (_write_linux )
133
+ await asyncio .to_thread (os .replace , tmp_path , str (path ))
121
134
except (FileNotFoundError , PermissionError ):
122
135
if retry_count < max_retries :
123
- await asyncio .to_thread (Path (tmp_path ).unlink , missing_ok = True )
136
+ if tmp_path is not None :
137
+ await asyncio .to_thread (Path (tmp_path ).unlink , missing_ok = True )
124
138
return await atomic_write (
125
139
path ,
126
140
data ,
@@ -131,7 +145,8 @@ def _sync_write() -> str:
131
145
raise
132
146
133
147
finally :
134
- await asyncio .to_thread (Path (tmp_path ).unlink , missing_ok = True )
148
+ if tmp_path is not None :
149
+ await asyncio .to_thread (Path (tmp_path ).unlink , missing_ok = True )
135
150
136
151
137
152
async def export_json_to_stream (
0 commit comments