Skip to content

Commit

Permalink
feat: replace TimeGPT class by NixtlaClient class (#276)
Browse files Browse the repository at this point in the history
  • Loading branch information
AzulGarza authored Apr 2, 2024
1 parent d1f5268 commit e00bd51
Show file tree
Hide file tree
Showing 15 changed files with 592 additions and 401 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ jobs:
run: pip install ./

- name: Check import
run: python -c "from nixtlats import TimeGPT;"
run: |
python -c "from nixtlats import TimeGPT;"
python -c "from nixtlats import NixtlaClient;"
run-tests:
runs-on: ${{ matrix.os }}
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ Get started with TimeGPT now:
```python
df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity-short.csv')

from nixtlats import TimeGPT
timegpt = TimeGPT(
from nixtlats import NixtlaClient
nixtla_client = NixtlaClient(
# defaults to os.environ.get("NIXTLA_API_KEY")
api_key = 'my_api_key_provided_by_nixtla'
)
fcst_df = timegpt.forecast(df, h=24, level=[80, 90])
fcst_df = nixtla_client.forecast(df, h=24, level=[80, 90])
```

![](./nbs/img/forecast_readme.png)
6 changes: 3 additions & 3 deletions action_files/models_performance/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import mae, mape, mse

from nixtlats import TimeGPT
from nixtlats import NixtlaClient


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -141,7 +141,7 @@ def evaluate_timegpt(self, model: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
init_time = time()
# A: this should be replaced with
# cross validation
timegpt = TimeGPT()
timegpt = NixtlaClient()
fcst_df = timegpt.forecast(
df=self.df_train,
X_df=self.df_test.drop(columns=self.target_col)
Expand Down Expand Up @@ -200,7 +200,7 @@ def evaluate_benchmark_performace(self) -> Tuple[pd.DataFrame, pd.DataFrame]:

def plot_and_save_forecasts(self, cv_df: pd.DataFrame, plot_dir: str) -> str:
"""Plot ans saves forecasts, returns the path of the plot"""
timegpt = TimeGPT()
timegpt = NixtlaClient()
df = self.df.copy()
df[self.time_col] = pd.to_datetime(df[self.time_col])
if not self.has_id_col:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"#| default_exp distributed.timegpt"
"#| default_exp distributed.nixtla_client"
]
},
{
Expand Down Expand Up @@ -83,7 +83,7 @@
"outputs": [],
"source": [
"#| export\n",
"class _DistributedTimeGPT:\n",
"class _DistributedNixtlaClient:\n",
"\n",
" def __init__(\n",
" self, \n",
Expand Down Expand Up @@ -300,49 +300,49 @@
" )\n",
" return fcst_df\n",
" \n",
" def _instantiate_timegpt(self):\n",
" from nixtlats.timegpt import _TimeGPT\n",
" timegpt = _TimeGPT(\n",
" def _instantiate_nixtla_client(self):\n",
" from nixtlats.nixtla_client import _NixtlaClient\n",
" nixtla_client = _NixtlaClient(\n",
" api_key=self.api_key, \n",
" base_url=self.base_url,\n",
" max_retries=self.max_retries,\n",
" retry_interval=self.retry_interval,\n",
" max_wait_time=self.max_wait_time,\n",
" )\n",
" return timegpt\n",
" return nixtla_client\n",
"\n",
" def _forecast(\n",
" self, \n",
" df: pd.DataFrame, \n",
" kwargs,\n",
" ) -> pd.DataFrame:\n",
" timegpt = self._instantiate_timegpt()\n",
" return timegpt._forecast(df=df, **kwargs)\n",
" nixtla_client = self._instantiate_nixtla_client()\n",
" return nixtla_client._forecast(df=df, **kwargs)\n",
"\n",
" def _forecast_x(\n",
" self, \n",
" df: pd.DataFrame, \n",
" X_df: pd.DataFrame,\n",
" kwargs,\n",
" ) -> pd.DataFrame:\n",
" timegpt = self._instantiate_timegpt()\n",
" return timegpt._forecast(df=df, X_df=X_df, **kwargs)\n",
" nixtla_client = self._instantiate_nixtla_client()\n",
" return nixtla_client._forecast(df=df, X_df=X_df, **kwargs)\n",
"\n",
" def _detect_anomalies(\n",
" self, \n",
" df: pd.DataFrame, \n",
" kwargs,\n",
" ) -> pd.DataFrame:\n",
" timegpt = self._instantiate_timegpt()\n",
" return timegpt._detect_anomalies(df=df, **kwargs)\n",
" nixtla_client = self._instantiate_nixtla_client()\n",
" return nixtla_client._detect_anomalies(df=df, **kwargs)\n",
"\n",
" def _cross_validation(\n",
" self, \n",
" df: pd.DataFrame, \n",
" kwargs,\n",
" ) -> pd.DataFrame:\n",
" timegpt = self._instantiate_timegpt()\n",
" return timegpt._cross_validation(df=df, **kwargs)\n",
" nixtla_client = self._instantiate_nixtla_client()\n",
" return nixtla_client._cross_validation(df=df, **kwargs)\n",
" \n",
" @staticmethod\n",
" def _get_forecast_schema(id_col, time_col, level, quantiles, cv=False):\n",
Expand Down Expand Up @@ -400,7 +400,7 @@
" time_col: str = 'ds',\n",
" **fcst_kwargs,\n",
" ):\n",
" fcst_df = distributed_timegpt.forecast(\n",
" fcst_df = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" h=horizon,\n",
" id_col=id_col,\n",
Expand Down Expand Up @@ -442,7 +442,7 @@
" time_col: str = 'ds',\n",
" **fcst_kwargs,\n",
" ):\n",
" fcst_df = distributed_timegpt.forecast(\n",
" fcst_df = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" h=horizon, \n",
" num_partitions=1,\n",
Expand All @@ -452,7 +452,7 @@
" **fcst_kwargs\n",
" )\n",
" fcst_df = fa.as_pandas(fcst_df)\n",
" fcst_df_2 = distributed_timegpt.forecast(\n",
" fcst_df_2 = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" h=horizon, \n",
" num_partitions=1,\n",
Expand Down Expand Up @@ -485,7 +485,7 @@
" time_col: str = 'ds',\n",
" **fcst_kwargs,\n",
" ):\n",
" fcst_df = distributed_timegpt.forecast(\n",
" fcst_df = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" h=horizon, \n",
" num_partitions=1,\n",
Expand All @@ -494,7 +494,7 @@
" **fcst_kwargs\n",
" )\n",
" fcst_df = fa.as_pandas(fcst_df)\n",
" fcst_df_2 = distributed_timegpt.forecast(\n",
" fcst_df_2 = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" h=horizon, \n",
" num_partitions=2,\n",
Expand Down Expand Up @@ -523,7 +523,7 @@
" time_col: str = 'ds',\n",
" **fcst_kwargs,\n",
" ):\n",
" fcst_df = distributed_timegpt.cross_validation(\n",
" fcst_df = distributed_nixtla_client.cross_validation(\n",
" df=df, \n",
" h=horizon, \n",
" num_partitions=1,\n",
Expand All @@ -532,7 +532,7 @@
" **fcst_kwargs\n",
" )\n",
" fcst_df = fa.as_pandas(fcst_df)\n",
" fcst_df_2 = distributed_timegpt.cross_validation(\n",
" fcst_df_2 = distributed_nixtla_client.cross_validation(\n",
" df=df, \n",
" h=horizon, \n",
" num_partitions=2,\n",
Expand Down Expand Up @@ -592,7 +592,7 @@
" time_col: str = 'ds',\n",
" **fcst_kwargs,\n",
" ):\n",
" fcst_df = distributed_timegpt.forecast(\n",
" fcst_df = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" X_df=X_df,\n",
" h=horizon,\n",
Expand All @@ -610,7 +610,7 @@
" exp_cols.extend([f'TimeGPT-lo-{lv}' for lv in reversed(level)])\n",
" exp_cols.extend([f'TimeGPT-hi-{lv}' for lv in level])\n",
" test_eq(cols, exp_cols)\n",
" fcst_df_2 = distributed_timegpt.forecast(\n",
" fcst_df_2 = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" h=horizon,\n",
" id_col=id_col,\n",
Expand Down Expand Up @@ -640,7 +640,7 @@
" time_col: str = 'ds',\n",
" **fcst_kwargs,\n",
" ):\n",
" fcst_df = distributed_timegpt.forecast(\n",
" fcst_df = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" X_df=X_df,\n",
" h=horizon, \n",
Expand All @@ -650,7 +650,7 @@
" **fcst_kwargs\n",
" )\n",
" fcst_df = fa.as_pandas(fcst_df)\n",
" fcst_df_2 = distributed_timegpt.forecast(\n",
" fcst_df_2 = distributed_nixtla_client.forecast(\n",
" df=df, \n",
" h=horizon, \n",
" num_partitions=2,\n",
Expand Down Expand Up @@ -705,7 +705,7 @@
" time_col: str = 'ds',\n",
" **anomalies_kwargs,\n",
" ):\n",
" anomalies_df = distributed_timegpt.detect_anomalies(\n",
" anomalies_df = distributed_nixtla_client.detect_anomalies(\n",
" df=df, \n",
" id_col=id_col,\n",
" time_col=time_col,\n",
Expand All @@ -731,15 +731,15 @@
" time_col: str = 'ds',\n",
" **anomalies_kwargs,\n",
" ):\n",
" anomalies_df = distributed_timegpt.detect_anomalies(\n",
" anomalies_df = distributed_nixtla_client.detect_anomalies(\n",
" df=df, \n",
" num_partitions=1,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" **anomalies_kwargs\n",
" )\n",
" anomalies_df = fa.as_pandas(anomalies_df)\n",
" anomalies_df_2 = distributed_timegpt.detect_anomalies(\n",
" anomalies_df_2 = distributed_nixtla_client.detect_anomalies(\n",
" df=df, \n",
" num_partitions=2,\n",
" id_col=id_col,\n",
Expand All @@ -766,7 +766,7 @@
" time_col: str = 'ds',\n",
" **anomalies_kwargs,\n",
" ):\n",
" anomalies_df = distributed_timegpt.detect_anomalies(\n",
" anomalies_df = distributed_nixtla_client.detect_anomalies(\n",
" df=df, \n",
" num_partitions=1,\n",
" id_col=id_col,\n",
Expand All @@ -775,7 +775,7 @@
" **anomalies_kwargs\n",
" )\n",
" anomalies_df = fa.as_pandas(anomalies_df)\n",
" anomalies_df_2 = distributed_timegpt.detect_anomalies(\n",
" anomalies_df_2 = distributed_nixtla_client.detect_anomalies(\n",
" df=df, \n",
" num_partitions=1,\n",
" id_col=id_col,\n",
Expand Down Expand Up @@ -844,9 +844,9 @@
" assert all(col in df_qls.columns for col in exp_q_cols)\n",
" # test monotonicity of quantiles\n",
" df_qls.apply(lambda x: x.is_monotonic_increasing, axis=1).sum() == len(exp_q_cols)\n",
" test_method_qls(distributed_timegpt.forecast)\n",
" test_method_qls(distributed_timegpt.forecast, add_history=True)\n",
" test_method_qls(distributed_timegpt.cross_validation)"
" test_method_qls(distributed_nixtla_client.forecast)\n",
" test_method_qls(distributed_nixtla_client.forecast, add_history=True)\n",
" test_method_qls(distributed_nixtla_client.cross_validation)"
]
},
{
Expand All @@ -856,7 +856,7 @@
"outputs": [],
"source": [
"#| hide\n",
"distributed_timegpt = _DistributedTimeGPT()"
"distributed_nixtla_client = _DistributedNixtlaClient()"
]
},
{
Expand Down
Loading

0 comments on commit e00bd51

Please sign in to comment.