Skip to content

Commit

Permalink
Merge pull request #89 from Nixtla/feat/warning_horizon
Browse files Browse the repository at this point in the history
feat: add warning for shot horizon
  • Loading branch information
AzulGarza authored Aug 18, 2023
2 parents a3fcd45 + 35efe61 commit 6800540
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 1 deletion.
137 changes: 136 additions & 1 deletion nbs/timegpt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
"from itertools import product\n",
"\n",
"from dotenv import load_dotenv\n",
"from fastcore.test import test_eq, test_fail\n",
"from fastcore.test import test_eq, test_fail, test_warns\n",
"from nbdev.showdoc import show_doc\n",
"from tqdm import TqdmExperimentalWarning\n",
"\n",
Expand Down Expand Up @@ -426,6 +426,12 @@
" input_size: int,\n",
" model_horizon: int,\n",
" ):\n",
" if h > model_horizon:\n",
" main_logger.warning(\n",
" 'The specified horizon \"h\" exceeds the model horizon. '\n",
" 'This may lead to less accurate forecasts. '\n",
" 'Please consider using a smaller horizon.'\n",
" )\n",
" # restrict input if\n",
" # - we dont want to finetune\n",
" # - we dont have exogenous regegressors\n",
Expand Down Expand Up @@ -1173,6 +1179,135 @@
"- `target_col`: The variable that we want to forecast."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:__main__:Validating inputs...\n",
"INFO:__main__:Preprocessing dataframes...\n",
"INFO:__main__:Calling Forecast Endpoint...\n",
"WARNING:__main__:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>timestamp</th>\n",
" <th>TimeGPT</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1961-01-01</td>\n",
" <td>426.181030</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1961-02-01</td>\n",
" <td>440.651031</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1961-03-01</td>\n",
" <td>458.021057</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1961-04-01</td>\n",
" <td>452.927734</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1961-05-01</td>\n",
" <td>435.785400</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>1968-12-01</td>\n",
" <td>415.290924</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>1969-01-01</td>\n",
" <td>416.995209</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>1969-02-01</td>\n",
" <td>423.885803</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>1969-03-01</td>\n",
" <td>429.938690</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>1969-04-01</td>\n",
" <td>430.399078</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>100 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" timestamp TimeGPT\n",
"0 1961-01-01 426.181030\n",
"1 1961-02-01 440.651031\n",
"2 1961-03-01 458.021057\n",
"3 1961-04-01 452.927734\n",
"4 1961-05-01 435.785400\n",
".. ... ...\n",
"95 1968-12-01 415.290924\n",
"96 1969-01-01 416.995209\n",
"97 1969-02-01 423.885803\n",
"98 1969-03-01 429.938690\n",
"99 1969-04-01 430.399078\n",
"\n",
"[100 rows x 2 columns]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#| hide\n",
"# test warning horizon too long\n",
"timegpt.forecast(df=df.tail(3), h=100, time_col='timestamp', target_col='value')"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
6 changes: 6 additions & 0 deletions nixtlats/timegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,12 @@ def _hit_multi_series_endpoint(
input_size: int,
model_horizon: int,
):
if h > model_horizon:
main_logger.warning(
'The specified horizon "h" exceeds the model horizon. '
"This may lead to less accurate forecasts. "
"Please consider using a smaller horizon."
)
# restrict input if
# - we dont want to finetune
# - we dont have exogenous regegressors
Expand Down

0 comments on commit 6800540

Please sign in to comment.