Skip to content

Commit 3ace579

Browse files
authored
fix: parse Content-Disposition properly, closes #414 (#416)
1 parent 7f22cb7 commit 3ace579

File tree

2 files changed

+66
-43
lines changed

2 files changed

+66
-43
lines changed

runpod/serverless/utils/rp_download.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
import uuid
1212
import zipfile
1313
from concurrent.futures import ThreadPoolExecutor
14-
from email import message_from_string
15-
from typing import List, Union
14+
from typing import List, Union, Dict
1615
from urllib.parse import urlparse
1716

1817
import backoff
@@ -35,28 +34,36 @@ def calculate_chunk_size(file_size: int) -> int:
3534
return 1024 * 1024 * 10 # 10 MB
3635

3736

37+
def extract_disposition_params(content_disposition: str) -> Dict[str, str]:
38+
parts = (p.strip() for p in content_disposition.split(";"))
39+
40+
params = {
41+
key.strip().lower(): value.strip().strip('"')
42+
for part in parts
43+
if "=" in part
44+
for key, value in [part.split("=", 1)]
45+
}
46+
47+
return params
48+
49+
3850
def download_files_from_urls(job_id: str, urls: Union[str, List[str]]) -> List[str]:
3951
"""
4052
Accepts a single URL or a list of URLs and downloads the files.
4153
Returns the list of downloaded file absolute paths.
4254
Saves the files in a directory called "downloaded_files" in the job directory.
4355
"""
44-
download_directory = os.path.abspath(
45-
os.path.join("jobs", job_id, "downloaded_files")
46-
)
56+
download_directory = os.path.abspath(os.path.join("jobs", job_id, "downloaded_files"))
4757
os.makedirs(download_directory, exist_ok=True)
4858

4959
@backoff.on_exception(backoff.expo, RequestException, max_tries=3)
5060
def download_file(url: str, path_to_save: str) -> str:
51-
with SyncClientSession().get(
52-
url, headers=HEADERS, stream=True, timeout=5
53-
) as response:
61+
with SyncClientSession().get(url, headers=HEADERS, stream=True, timeout=5) as response:
5462
response.raise_for_status()
5563
content_disposition = response.headers.get("Content-Disposition")
5664
file_extension = ""
5765
if content_disposition:
58-
msg = message_from_string(f"Content-Disposition: {content_disposition}")
59-
params = dict(msg.items())
66+
params = extract_disposition_params(content_disposition)
6067
file_extension = os.path.splitext(params.get("filename", ""))[1]
6168

6269
# If no extension could be determined from 'Content-Disposition', get it from the URL
@@ -113,15 +120,15 @@ def file(file_url: str) -> dict:
113120

114121
download_response = SyncClientSession().get(file_url, headers=HEADERS, timeout=30)
115122

116-
original_file_name = []
117-
if "Content-Disposition" in download_response.headers.keys():
118-
original_file_name = re.findall(
119-
"filename=(.+)", download_response.headers["Content-Disposition"]
120-
)
123+
content_disposition = download_response.headers.get("Content-Disposition")
121124

122-
if len(original_file_name) > 0:
123-
original_file_name = original_file_name[0]
124-
else:
125+
original_file_name = ""
126+
if content_disposition:
127+
params = extract_disposition_params(content_disposition)
128+
129+
original_file_name = params.get("filename", "")
130+
131+
if not original_file_name:
125132
download_path = urlparse(file_url).path
126133
original_file_name = os.path.basename(download_path)
127134

tests/test_serverless/test_utils/test_download.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Tests for runpod | serverless | modules | download.py """
1+
"""Tests for runpod | serverless | modules | download.py"""
22

33
# pylint: disable=R0903,W0613
44

@@ -17,6 +17,7 @@
1717
URL_LIST = [
1818
"https://example.com/picture.jpg",
1919
"https://example.com/picture.jpg?X-Amz-Signature=123",
20+
"https://example.com/file_without_extension",
2021
]
2122

2223
JOB_ID = "job_123"
@@ -75,9 +76,7 @@ def test_calculate_chunk_size(self):
7576
self.assertEqual(calculate_chunk_size(1024), 1024)
7677
self.assertEqual(calculate_chunk_size(1024 * 1024), 1024)
7778
self.assertEqual(calculate_chunk_size(1024 * 1024 * 1024), 1024 * 1024)
78-
self.assertEqual(
79-
calculate_chunk_size(1024 * 1024 * 1024 * 10), 1024 * 1024 * 10
80-
)
79+
self.assertEqual(calculate_chunk_size(1024 * 1024 * 1024 * 10), 1024 * 1024 * 10)
8180

8281
@patch("os.makedirs", return_value=None)
8382
@patch("runpod.http_client.SyncClientSession.get", side_effect=mock_requests_get)
@@ -86,29 +85,26 @@ def test_download_files_from_urls(self, mock_open_file, mock_get, mock_makedirs)
8685
"""
8786
Tests download_files_from_urls
8887
"""
88+
urls = ["https://example.com/picture.jpg", "https://example.com/file_without_extension"]
8989
downloaded_files = download_files_from_urls(
9090
JOB_ID,
91-
[
92-
"https://example.com/picture.jpg",
93-
],
91+
urls,
9492
)
9593

96-
self.assertEqual(len(downloaded_files), 1)
94+
self.assertEqual(len(downloaded_files), len(urls))
9795

98-
# Check that the url was called with SyncClientSession.get
99-
self.assertIn("https://example.com/picture.jpg", mock_get.call_args_list[0][0])
96+
for index, url in enumerate(urls):
97+
# Check that the url was called with SyncClientSession.get
98+
self.assertIn(url, mock_get.call_args_list[index][0])
10099

101-
# Check that the file has the correct extension
102-
self.assertTrue(downloaded_files[0].endswith(".jpg"))
100+
# Check that the file has the correct extension
101+
self.assertTrue(downloaded_files[index].endswith(".jpg"))
103102

104-
mock_open_file.assert_called_once_with(downloaded_files[0], "wb")
105-
mock_makedirs.assert_called_once_with(
106-
os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True
107-
)
103+
mock_open_file.assert_any_call(downloaded_files[index], "wb")
108104

109-
string_download_file = download_files_from_urls(
110-
JOB_ID, "https://example.com/picture.jpg"
111-
)
105+
mock_makedirs.assert_called_once_with(os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True)
106+
107+
string_download_file = download_files_from_urls(JOB_ID, "https://example.com/picture.jpg")
112108
self.assertTrue(string_download_file[0].endswith(".jpg"))
113109

114110
# Check if None is returned when url is None
@@ -124,9 +120,7 @@ def test_download_files_from_urls(self, mock_open_file, mock_get, mock_makedirs)
124120
@patch("os.makedirs", return_value=None)
125121
@patch("runpod.http_client.SyncClientSession.get", side_effect=mock_requests_get)
126122
@patch("builtins.open", new_callable=mock_open)
127-
def test_download_files_from_urls_signed(
128-
self, mock_open_file, mock_get, mock_makedirs
129-
):
123+
def test_download_files_from_urls_signed(self, mock_open_file, mock_get, mock_makedirs):
130124
"""
131125
Tests download_files_from_urls with signed urls
132126
"""
@@ -147,9 +141,7 @@ def test_download_files_from_urls_signed(
147141
self.assertTrue(downloaded_files[0].endswith(".jpg"))
148142

149143
mock_open_file.assert_called_once_with(downloaded_files[0], "wb")
150-
mock_makedirs.assert_called_once_with(
151-
os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True
152-
)
144+
mock_makedirs.assert_called_once_with(os.path.abspath(f"jobs/{JOB_ID}/downloaded_files"), exist_ok=True)
153145

154146

155147
class FileDownloaderTestCase(unittest.TestCase):
@@ -179,6 +171,30 @@ def test_download_file(self, mock_file, mock_get):
179171
# Check that the file was written correctly
180172
mock_file().write.assert_called_once_with(b"file content")
181173

174+
@patch("runpod.serverless.utils.rp_download.SyncClientSession.get")
175+
@patch("builtins.open", new_callable=mock_open)
176+
def test_download_file(self, mock_file, mock_get):
177+
"""
178+
Tests download_file using filename from Content-Disposition
179+
"""
180+
# Mock the response from SyncClientSession.get
181+
mock_response = MagicMock()
182+
mock_response.content = b"file content"
183+
mock_response.headers = {"Content-Disposition": 'inline; filename="test_file.txt"'}
184+
mock_get.return_value = mock_response
185+
186+
# Call the function with a test URL
187+
result = file("http://test.com/file_without_extension")
188+
189+
# Check the result
190+
self.assertEqual(result["type"], "txt")
191+
self.assertEqual(result["original_name"], "test_file.txt")
192+
self.assertTrue(result["file_path"].endswith(".txt"))
193+
self.assertIsNone(result["extracted_path"])
194+
195+
# Check that the file was written correctly
196+
mock_file().write.assert_called_once_with(b"file content")
197+
182198
@patch("runpod.serverless.utils.rp_download.SyncClientSession.get")
183199
@patch("builtins.open", new_callable=mock_open)
184200
@patch("runpod.serverless.utils.rp_download.zipfile.ZipFile")

0 commit comments

Comments
 (0)