7
7
import json
8
8
import os
9
9
import typing
10
+ from traceback import format_exc
10
11
from urllib .parse import urlparse
11
12
12
13
import niquests
22
23
from starlette .requests import HTTPConnection , Request
23
24
from starlette .types import ASGIApp , Receive , Scope , Send
24
25
26
+ from .._exceptions import ModelFetchError
25
27
from .._misc import get_username_secret_from_headers
26
28
from ..nextcloud import AsyncNextcloudApp , NextcloudApp
27
29
from ..talk_bot import TalkBotMessage
28
- from .defs import LogLvl
29
30
from .misc import persistent_storage
30
31
31
32
@@ -70,9 +71,24 @@ def set_handlers(
70
71
71
72
.. note:: When this parameter is ``False``, the provision of ``models_to_fetch`` is not allowed.
72
73
73
- :param models_to_fetch: Dictionary describing which models should be downloaded during `init`.
74
+ :param models_to_fetch: Dictionary describing which models should be downloaded during `init` of the form:
75
+ .. code-block:: python
76
+ {
77
+ "model_url_1": {
78
+ "save_path": "path_or_filename_to_save_the_model_to",
79
+ },
80
+ "huggingface_model_name_1": {
81
+ "max_workers": 4,
82
+ "cache_dir": "path_to_cache_dir",
83
+ "revision": "revision_to_fetch",
84
+ ...
85
+ },
86
+ ...
87
+ }
88
+
74
89
75
90
.. note:: ``huggingface_hub`` package should be present for automatic models fetching.
91
+ All model options are optional and can be left empty.
76
92
77
93
:param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not.
78
94
@@ -121,74 +137,98 @@ def __map_app_static_folders(fast_api_app: FastAPI):
121
137
122
138
123
139
def fetch_models_task(nc: NextcloudApp, models: dict[str, dict], progress_init_start_value: int) -> None:
    """Use for cases when you want to define custom `/init` but still need to easily download models.

    :param nc: NextcloudApp instance.
    :param models: Dictionary describing which models should be downloaded of the form:
        .. code-block:: python
            {
                "model_url_1": {
                    "save_path": "path_or_filename_to_save_the_model_to",
                },
                "huggingface_model_name_1": {
                    "max_workers": 4,
                    "cache_dir": "path_to_cache_dir",
                    "revision": "revision_to_fetch",
                    ...
                },
                ...
            }

        .. note:: ``huggingface_hub`` package should be present for automatic models fetching.
            All model options are optional and can be left empty.

        After a successful download, the resulting location is stored under the ``"path"`` key
        of each model's options dictionary.

    :param progress_init_start_value: Integer value defining from which percent the progress should start.

    :raises ModelFetchError: in case of a model download error.
    :raises NextcloudException: in case of a network error reaching the Nextcloud server.
    """
    if models:
        current_progress = progress_init_start_value
        # Split the remaining progress range evenly between models; 100 is reserved for the final status.
        percent_for_each = min(int((100 - progress_init_start_value) / len(models)), 99)
        for model, options in models.items():
            try:
                # Keys that look like URLs are fetched as single files, everything else as a snapshot.
                if model.startswith(("http://", "https://")):
                    options["path"] = __fetch_model_as_file(current_progress, percent_for_each, nc, model, options)
                else:
                    options["path"] = __fetch_model_as_snapshot(
                        current_progress, percent_for_each, nc, model, options
                    )
                current_progress += percent_for_each
            except BaseException as e:  # noqa pylint: disable=broad-exception-caught
                # Report the failure (with traceback) as the init status before propagating it.
                nc.set_init_status(current_progress, f"Downloading of '{model}' failed: {e}: {format_exc()}")
                raise ModelFetchError(f"Downloading of '{model}' failed.") from e
    nc.set_init_status(100)
139
184
140
185
141
186
def __fetch_model_as_file(
    current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
) -> str:
    """Download a single model file from a URL, reporting progress through ``nc.set_init_status``.

    The download is skipped when a local file of the same size already exists and its quoted SHA-256
    hex digest equals the ETag advertised by the server.

    :param current_progress: Progress percent already reached before this download starts.
    :param progress_for_task: Share of the total progress (in percent) allotted to this download.
    :param nc: NextcloudApp instance used to publish the init progress.
    :param model_path: URL of the file to download.
    :param download_options: Options for this model; ``save_path`` is popped from it and defaults to
        the last path segment of the URL.

    :returns: Path of the downloaded (or already present and verified) file.

    :raises ModelFetchError: when the server answers with a non-success status.
    """
    result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
    with niquests.get(model_path, stream=True) as response:
        if not response.ok:
            raise ModelFetchError(
                f"Downloading of '{model_path}' failed, returned ({response.status_code}) {response.text}"
            )

        # Prefer an ``X-Linked-ETag`` from any redirect hop, then fall back to the final response headers.
        linked_etag = ""
        for each_history in response.history:
            linked_etag = each_history.headers.get("X-Linked-ETag", "")
            if linked_etag:
                break
        if not linked_etag:
            linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))

        # ``Content-Length`` may be absent or malformed; treat that as "size unknown" (0)
        # instead of failing on ``int(None)``.
        try:
            total_size = int(response.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            total_size = 0

        try:
            existing_size = os.path.getsize(result_path)
        except OSError:
            existing_size = 0

        # Only try to reuse a local copy when the remote size is known; otherwise ``0 == 0`` would
        # send us to open a file that does not exist.
        if linked_etag and total_size and total_size == existing_size:
            with builtins.open(result_path, "rb") as file:
                sha256_hash = hashlib.sha256()
                for byte_block in iter(lambda: file.read(4096), b""):
                    sha256_hash.update(byte_block)
            if f'"{sha256_hash.hexdigest()}"' == linked_etag:
                nc.set_init_status(min(current_progress + progress_for_task, 99))
                return result_path

        downloaded_size = 0
        with builtins.open(result_path, "wb") as file:
            last_progress = current_progress
            for chunk in response.iter_raw(-1):
                downloaded_size += file.write(chunk)
                if total_size:
                    # Progress is capped at 99; the caller publishes 100 once every model is done.
                    new_progress = min(current_progress + int(progress_for_task * downloaded_size / total_size), 99)
                    if new_progress != last_progress:
                        nc.set_init_status(new_progress)
                        last_progress = new_progress

    return result_path
188
228
189
229
190
230
def __fetch_model_as_snapshot (
191
- current_progress : int , progress_for_task , nc : NextcloudApp , mode_name : str , download_options : dict
231
+ current_progress : int , progress_for_task , nc : NextcloudApp , model_name : str , download_options : dict
192
232
) -> str :
193
233
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
194
234
from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
@@ -201,7 +241,7 @@ def display(self, msg=None, pos=None):
201
241
workers = download_options .pop ("max_workers" , 2 )
202
242
cache = download_options .pop ("cache_dir" , persistent_storage ())
203
243
return snapshot_download (
204
- mode_name , tqdm_class = TqdmProgress , ** download_options , max_workers = workers , cache_dir = cache
244
+ model_name , tqdm_class = TqdmProgress , ** download_options , max_workers = workers , cache_dir = cache
205
245
)
206
246
207
247
0 commit comments