Skip to content

Commit dd497bc

Browse files
authored
Relax supported model check (#634)
1 parent ce5ac93 commit dd497bc

File tree

6 files changed

+61
-27
lines changed

6 files changed

+61
-27
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
run: python -c "from nixtla import NixtlaClient"
4141

4242
run-all-tests:
43-
runs-on: ubuntu-latest
43+
runs-on: nixtla-linux-large-public
4444
timeout-minutes: 60
4545
strategy:
4646
fail-fast: false

nbs/src/nixtla_client.ipynb

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"import logging\n",
4747
"import math\n",
4848
"import os\n",
49+
"import re\n",
4950
"import warnings\n",
5051
"from collections.abc import Sequence\n",
5152
"from concurrent.futures import ThreadPoolExecutor, as_completed\n",
@@ -715,7 +716,34 @@
715716
" if hist_exog_list and logger:\n",
716717
" logger.info(f'Using historical exogenous features: {hist_exog_list}')\n",
717718
"\n",
718-
" return X, hist_exog"
719+
" return X, hist_exog\n",
720+
"\n",
721+
"def _model_in_list(model:str, model_list: tuple[Any]) -> bool:\n",
722+
" for m in model_list:\n",
723+
" if isinstance(m, str):\n",
724+
" if m == model:\n",
725+
" return True\n",
726+
" elif isinstance(m, re.Pattern):\n",
727+
" if m.fullmatch(model):\n",
728+
" return True\n",
729+
" return False"
730+
]
731+
},
732+
{
733+
"cell_type": "code",
734+
"execution_count": null,
735+
"metadata": {},
736+
"outputs": [],
737+
"source": [
738+
"#| hide\n",
739+
"\n",
740+
"assert _model_in_list(\"a\", (\"a\", \"b\"))\n",
741+
"assert not _model_in_list(\"a\", (\"b\", \"c\"))\n",
742+
"assert _model_in_list(\"axb\", (\"x\", re.compile(\"a.*b\")))\n",
743+
"assert _model_in_list(\"axb\", (\"x\", re.compile(\"^a.*b$\")))\n",
744+
"assert _model_in_list(\"a-b\", (\"x\", re.compile(\"^a-.*b$\")))\n",
745+
"assert _model_in_list(\"a-dfdfb\", (\"x\", re.compile(\"^a-.*b$\")))\n",
746+
"assert not _model_in_list(\"abc\", (\"x\", re.compile(\"ab\"), re.compile(\"abcd\")))"
719747
]
720748
},
721749
{
@@ -1139,10 +1167,7 @@
11391167
" )\n",
11401168
" self._model_params: dict[tuple[str, str], tuple[int, int]] = {}\n",
11411169
" self._is_azure = 'ai.azure' in base_url\n",
1142-
" if self._is_azure:\n",
1143-
" self.supported_models = ['azureai']\n",
1144-
" else:\n",
1145-
" self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n",
1170+
" self.supported_models:list[Any] = [re.compile('^timegpt-.+$'), 'azureai']\n",
11461171
"\n",
11471172
" def _make_request(\n",
11481173
" self,\n",
@@ -1374,9 +1399,9 @@
13741399
" ) -> tuple[DFType, Optional[DFType], bool, _FreqType]:\n",
13751400
" if validate_api_key and not self.validate_api_key(log=False):\n",
13761401
" raise Exception('API Key not valid, please email support@nixtla.io')\n",
1377-
" if model not in self.supported_models:\n",
1402+
" if not _model_in_list(model, tuple(self.supported_models)):\n",
13781403
" raise ValueError(\n",
1379-
" f'unsupported model: {model}. supported models: {self.supported_models}'\n",
1404+
" f'unsupported model: {model}.'\n",
13801405
" )\n",
13811406
" drop_id = id_col not in df.columns\n",
13821407
" if drop_id:\n",

nixtla/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.6.7.dev2"
1+
__version__ = "0.6.7.dev3"
22
__all__ = ["NixtlaClient"]
33
from .nixtla_client import NixtlaClient

nixtla/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@
129129
'nixtla/nixtla_client.py'),
130130
'nixtla.nixtla_client._maybe_infer_freq': ( 'src/nixtla_client.html#_maybe_infer_freq',
131131
'nixtla/nixtla_client.py'),
132+
'nixtla.nixtla_client._model_in_list': ( 'src/nixtla_client.html#_model_in_list',
133+
'nixtla/nixtla_client.py'),
132134
'nixtla.nixtla_client._parse_in_sample_output': ( 'src/nixtla_client.html#_parse_in_sample_output',
133135
'nixtla/nixtla_client.py'),
134136
'nixtla.nixtla_client._partition_series': ( 'src/nixtla_client.html#_partition_series',

nixtla/nixtla_client.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import math
1111
import os
12+
import re
1213
import warnings
1314
from collections.abc import Sequence
1415
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -638,15 +639,26 @@ def _process_exog_features(
638639

639640
return X, hist_exog
640641

641-
# %% ../nbs/src/nixtla_client.ipynb 11
642+
643+
def _model_in_list(model: str, model_list: tuple[Any]) -> bool:
644+
for m in model_list:
645+
if isinstance(m, str):
646+
if m == model:
647+
return True
648+
elif isinstance(m, re.Pattern):
649+
if m.fullmatch(model):
650+
return True
651+
return False
652+
653+
# %% ../nbs/src/nixtla_client.ipynb 12
642654
class AuditDataSeverity(Enum):
643655
"""Enum class to indicate audit data severity levels"""
644656

645657
FAIL = "Fail" # Indicates a critical issue that requires immediate attention
646658
CASE_SPECIFIC = "Case Specific" # Indicates an issue that may be acceptable in specific contexts
647659
PASS = "Pass" # Indicates that the data is acceptable
648660

649-
# %% ../nbs/src/nixtla_client.ipynb 13
661+
# %% ../nbs/src/nixtla_client.ipynb 14
650662
def _audit_duplicate_rows(
651663
df: AnyDFType,
652664
id_col: str = "unique_id",
@@ -660,7 +672,7 @@ def _audit_duplicate_rows(
660672
else:
661673
raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
662674

663-
# %% ../nbs/src/nixtla_client.ipynb 16
675+
# %% ../nbs/src/nixtla_client.ipynb 17
664676
def _audit_missing_dates(
665677
df: AnyDFType,
666678
freq: _Freq,
@@ -688,7 +700,7 @@ def _audit_missing_dates(
688700
else:
689701
raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
690702

691-
# %% ../nbs/src/nixtla_client.ipynb 19
703+
# %% ../nbs/src/nixtla_client.ipynb 20
692704
def _audit_categorical_variables(
693705
df: AnyDFType,
694706
id_col: str = "unique_id",
@@ -708,7 +720,7 @@ def _audit_categorical_variables(
708720
else:
709721
raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
710722

711-
# %% ../nbs/src/nixtla_client.ipynb 22
723+
# %% ../nbs/src/nixtla_client.ipynb 23
712724
def _audit_leading_zeros(
713725
df: pd.DataFrame,
714726
id_col: str = "unique_id",
@@ -733,7 +745,7 @@ def _audit_leading_zeros(
733745
else:
734746
raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
735747

736-
# %% ../nbs/src/nixtla_client.ipynb 25
748+
# %% ../nbs/src/nixtla_client.ipynb 26
737749
def _audit_negative_values(
738750
df: AnyDFType,
739751
target_col: str = "y",
@@ -746,7 +758,7 @@ def _audit_negative_values(
746758
else:
747759
raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
748760

749-
# %% ../nbs/src/nixtla_client.ipynb 27
761+
# %% ../nbs/src/nixtla_client.ipynb 28
750762
class ApiError(Exception):
751763
status_code: Optional[int]
752764
body: Any
@@ -760,7 +772,7 @@ def __init__(
760772
def __str__(self) -> str:
761773
return f"status_code: {self.status_code}, body: {self.body}"
762774

763-
# %% ../nbs/src/nixtla_client.ipynb 29
775+
# %% ../nbs/src/nixtla_client.ipynb 30
764776
class NixtlaClient:
765777

766778
def __init__(
@@ -821,10 +833,7 @@ def __init__(
821833
)
822834
self._model_params: dict[tuple[str, str], tuple[int, int]] = {}
823835
self._is_azure = "ai.azure" in base_url
824-
if self._is_azure:
825-
self.supported_models = ["azureai"]
826-
else:
827-
self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"]
836+
self.supported_models: list[Any] = [re.compile("^timegpt-.+$"), "azureai"]
828837

829838
def _make_request(
830839
self,
@@ -1057,10 +1066,8 @@ def _run_validations(
10571066
) -> tuple[DFType, Optional[DFType], bool, _FreqType]:
10581067
if validate_api_key and not self.validate_api_key(log=False):
10591068
raise Exception("API Key not valid, please email support@nixtla.io")
1060-
if model not in self.supported_models:
1061-
raise ValueError(
1062-
f"unsupported model: {model}. supported models: {self.supported_models}"
1063-
)
1069+
if not _model_in_list(model, tuple(self.supported_models)):
1070+
raise ValueError(f"unsupported model: {model}.")
10641071
drop_id = id_col not in df.columns
10651072
if drop_id:
10661073
df = ufp.copy_if_pandas(df, deep=False)
@@ -2862,7 +2869,7 @@ def clean_data(
28622869

28632870
return df, all_pass, error_dfs, case_specific_dfs
28642871

2865-
# %% ../nbs/src/nixtla_client.ipynb 31
2872+
# %% ../nbs/src/nixtla_client.ipynb 32
28662873
def _forecast_wrapper(
28672874
df: pd.DataFrame,
28682875
client: NixtlaClient,

settings.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ author = Nixtla
88
author_email = business@nixtla.io
99
copyright = Nixtla Inc.
1010
branch = main
11-
version = 0.6.7.dev2
11+
version = 0.6.7.dev3
1212
min_python = 3.9
1313
audience = Developers
1414
language = English

0 commit comments

Comments
 (0)