Skip to content

Commit f680790

Browse files
kyteinsky and bigcat88
authored
fix: fetch_models_task fixed for files and docs added (#376)
Changes proposed in this pull request: * `fetch_models_task` fixed for individual file downloads * added docs for `fetch_models_task` and the model dict in `set_handlers` * added a custom exception for model download failure, since this function can be used independently too — although I'm not too sure of the custom exception; it doesn't serve much purpose except a few pylint ignore comments. --------- Signed-off-by: Anupam Kumar <kyteinsky@gmail.com> Signed-off-by: bigcat88 <bigcat88@icloud.com> Co-authored-by: bigcat88 <bigcat88@icloud.com>
1 parent 9518a89 commit f680790

File tree

4 files changed

+105
-68
lines changed

4 files changed

+105
-68
lines changed

nc_py_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from . import ex_app, options
44
from ._exceptions import (
5+
ModelFetchError,
56
NextcloudException,
67
NextcloudExceptionNotFound,
78
NextcloudMissingCapabilities,

nc_py_api/_exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,7 @@ def check_error(response: Response, info: str = ""):
6565
response.raise_for_status()
6666
except HTTPError as e:
6767
raise NextcloudException(status_code, reason=response.reason, info=info) from e
68+
69+
70+
class ModelFetchError(Exception):
71+
"""Exception raised when model fetching fails."""

nc_py_api/ex_app/integration_fastapi.py

Lines changed: 98 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import os
99
import typing
10+
from traceback import format_exc
1011
from urllib.parse import urlparse
1112

1213
import niquests
@@ -22,10 +23,10 @@
2223
from starlette.requests import HTTPConnection, Request
2324
from starlette.types import ASGIApp, Receive, Scope, Send
2425

26+
from .._exceptions import ModelFetchError
2527
from .._misc import get_username_secret_from_headers
2628
from ..nextcloud import AsyncNextcloudApp, NextcloudApp
2729
from ..talk_bot import TalkBotMessage
28-
from .defs import LogLvl
2930
from .misc import persistent_storage
3031

3132

@@ -70,9 +71,24 @@ def set_handlers(
7071
7172
.. note:: When this parameter is ``False``, the provision of ``models_to_fetch`` is not allowed.
7273
73-
:param models_to_fetch: Dictionary describing which models should be downloaded during `init`.
74+
:param models_to_fetch: Dictionary describing which models should be downloaded during `init` of the form:
75+
.. code-block:: python
76+
{
77+
"model_url_1": {
78+
"save_path": "path_or_filename_to_save_the_model_to",
79+
},
80+
"huggingface_model_name_1": {
81+
"max_workers": 4,
82+
"cache_dir": "path_to_cache_dir",
83+
"revision": "revision_to_fetch",
84+
...
85+
},
86+
...
87+
}
88+
7489
7590
.. note:: ``huggingface_hub`` package should be present for automatic models fetching.
91+
All model options are optional and can be left empty.
7692
7793
:param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not.
7894
@@ -121,74 +137,98 @@ def __map_app_static_folders(fast_api_app: FastAPI):
121137

122138

123139
def fetch_models_task(nc: NextcloudApp, models: dict[str, dict], progress_init_start_value: int) -> None:
124-
"""Use for cases when you want to define custom `/init` but still need to easy download models."""
140+
"""Use for cases when you want to define a custom `/init` but still need to easily download models.
141+
142+
:param nc: NextcloudApp instance.
143+
:param models_to_fetch: Dictionary describing which models should be downloaded of the form:
144+
.. code-block:: python
145+
{
146+
"model_url_1": {
147+
"save_path": "path_or_filename_to_save_the_model_to",
148+
},
149+
"huggingface_model_name_1": {
150+
"max_workers": 4,
151+
"cache_dir": "path_to_cache_dir",
152+
"revision": "revision_to_fetch",
153+
...
154+
},
155+
...
156+
}
157+
158+
.. note:: ``huggingface_hub`` package should be present for automatic models fetching.
159+
All model options are optional and can be left empty.
160+
161+
:param progress_init_start_value: Integer value defining from which percent the progress should start.
162+
163+
:raises ModelFetchError: in case of a model download error.
164+
:raises NextcloudException: in case of a network error reaching the Nextcloud server.
165+
"""
125166
if models:
126167
current_progress = progress_init_start_value
127168
percent_for_each = min(int((100 - progress_init_start_value) / len(models)), 99)
128169
for model in models:
129-
if model.startswith(("http://", "https://")):
130-
models[model]["path"] = __fetch_model_as_file(
131-
current_progress, percent_for_each, nc, model, models[model]
132-
)
133-
else:
134-
models[model]["path"] = __fetch_model_as_snapshot(
135-
current_progress, percent_for_each, nc, model, models[model]
136-
)
137-
current_progress += percent_for_each
170+
try:
171+
if model.startswith(("http://", "https://")):
172+
models[model]["path"] = __fetch_model_as_file(
173+
current_progress, percent_for_each, nc, model, models[model]
174+
)
175+
else:
176+
models[model]["path"] = __fetch_model_as_snapshot(
177+
current_progress, percent_for_each, nc, model, models[model]
178+
)
179+
current_progress += percent_for_each
180+
except BaseException as e: # noqa pylint: disable=broad-exception-caught
181+
nc.set_init_status(current_progress, f"Downloading of '{model}' failed: {e}: {format_exc()}")
182+
raise ModelFetchError(f"Downloading of '{model}' failed.") from e
138183
nc.set_init_status(100)
139184

140185

141186
def __fetch_model_as_file(
142187
current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
143-
) -> str | None:
188+
) -> str:
144189
result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
145-
try:
146-
147-
with niquests.get("GET", model_path, stream=True) as response:
148-
if not response.is_success:
149-
nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' returned {response.status_code} status.")
150-
return None
151-
downloaded_size = 0
152-
linked_etag = ""
153-
for each_history in response.history:
154-
linked_etag = each_history.headers.get("X-Linked-ETag", "")
155-
if linked_etag:
156-
break
157-
if not linked_etag:
158-
linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))
159-
total_size = int(response.headers.get("Content-Length"))
160-
try:
161-
existing_size = os.path.getsize(result_path)
162-
except OSError:
163-
existing_size = 0
164-
if linked_etag and total_size == existing_size:
165-
with builtins.open(result_path, "rb") as file:
166-
sha256_hash = hashlib.sha256()
167-
for byte_block in iter(lambda: file.read(4096), b""):
168-
sha256_hash.update(byte_block)
169-
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
170-
nc.set_init_status(min(current_progress + progress_for_task, 99))
171-
return None
172-
173-
with builtins.open(result_path, "wb") as file:
174-
last_progress = current_progress
175-
for chunk in response.iter_raw(-1):
176-
downloaded_size += file.write(chunk)
177-
if total_size:
178-
new_progress = min(current_progress + int(progress_for_task * downloaded_size / total_size), 99)
179-
if new_progress != last_progress:
180-
nc.set_init_status(new_progress)
181-
last_progress = new_progress
182-
183-
return result_path
184-
except Exception as e: # noqa pylint: disable=broad-exception-caught
185-
nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' raised an exception: {e}")
186-
187-
return None
190+
with niquests.get(model_path, stream=True) as response:
191+
if not response.ok:
192+
raise ModelFetchError(
193+
f"Downloading of '{model_path}' failed, returned ({response.status_code}) {response.text}"
194+
)
195+
downloaded_size = 0
196+
linked_etag = ""
197+
for each_history in response.history:
198+
linked_etag = each_history.headers.get("X-Linked-ETag", "")
199+
if linked_etag:
200+
break
201+
if not linked_etag:
202+
linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))
203+
total_size = int(response.headers.get("Content-Length"))
204+
try:
205+
existing_size = os.path.getsize(result_path)
206+
except OSError:
207+
existing_size = 0
208+
if linked_etag and total_size == existing_size:
209+
with builtins.open(result_path, "rb") as file:
210+
sha256_hash = hashlib.sha256()
211+
for byte_block in iter(lambda: file.read(4096), b""):
212+
sha256_hash.update(byte_block)
213+
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
214+
nc.set_init_status(min(current_progress + progress_for_task, 99))
215+
return result_path
216+
217+
with builtins.open(result_path, "wb") as file:
218+
last_progress = current_progress
219+
for chunk in response.iter_raw(-1):
220+
downloaded_size += file.write(chunk)
221+
if total_size:
222+
new_progress = min(current_progress + int(progress_for_task * downloaded_size / total_size), 99)
223+
if new_progress != last_progress:
224+
nc.set_init_status(new_progress)
225+
last_progress = new_progress
226+
227+
return result_path
188228

189229

190230
def __fetch_model_as_snapshot(
191-
current_progress: int, progress_for_task, nc: NextcloudApp, mode_name: str, download_options: dict
231+
current_progress: int, progress_for_task, nc: NextcloudApp, model_name: str, download_options: dict
192232
) -> str:
193233
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
194234
from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
@@ -201,7 +241,7 @@ def display(self, msg=None, pos=None):
201241
workers = download_options.pop("max_workers", 2)
202242
cache = download_options.pop("cache_dir", persistent_storage())
203243
return snapshot_download(
204-
mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache
244+
model_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache
205245
)
206246

207247

tests/_install_init_handler_models.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55

66
from nc_py_api import NextcloudApp, ex_app
77

8-
INVALID_URL = "https://invalid_url"
9-
MODEL_NAME1 = "MBZUAI/LaMini-T5-61M"
8+
# TO-DO: add tests when ExApp fails to initialize due to invalid model fetch
109
MODEL_NAME2 = "https://huggingface.co/MBZUAI/LaMini-T5-61M/resolve/main/pytorch_model.bin"
1110
MODEL_NAME2_http = "http://huggingface.co/MBZUAI/LaMini-T5-61M/resolve/main/pytorch_model.bin"
12-
INVALID_PATH = "https://huggingface.co/invalid_path"
1311
SOME_FILE = "https://raw.githubusercontent.com/cloud-py-api/nc_py_api/main/README.md"
1412

1513

@@ -19,11 +17,8 @@ async def lifespan(_app: FastAPI):
1917
APP,
2018
enabled_handler,
2119
models_to_fetch={
22-
INVALID_URL: {},
23-
MODEL_NAME1: {},
2420
MODEL_NAME2: {},
2521
MODEL_NAME2_http: {},
26-
INVALID_PATH: {},
2722
SOME_FILE: {},
2823
},
2924
)
@@ -35,10 +30,7 @@ async def lifespan(_app: FastAPI):
3530

3631
def enabled_handler(enabled: bool, _nc: NextcloudApp) -> str:
3732
if enabled:
38-
try:
39-
assert ex_app.get_model_path(MODEL_NAME1)
40-
except Exception: # noqa
41-
return "model1 not found"
33+
assert ex_app.get_model_path(MODEL_NAME2)
4234
assert Path("pytorch_model.bin").is_file()
4335
return ""
4436

0 commit comments

Comments
(0)