From 9eeafb048fb1a0c0f5c2d4cc06d04babc5656953 Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Mon, 12 Oct 2020 07:04:10 +0100 Subject: [PATCH] Rename LearningRateLogger and reposition EarlyStopping callback for lightning 1.0.0rc4 compatibility --- README.md | 7 +++-- docs/source/getting-started.rst | 7 +++-- docs/source/tutorials/ar.ipynb | 4 +-- docs/source/tutorials/stallion.ipynb | 26 ++++++++++++++----- examples/ar.py | 7 +++-- examples/nbeats.py | 2 +- examples/stallion.py | 7 +++-- .../temporal_fusion_transformer/tuning.py | 11 +++++--- tests/test_models/test_nbeats.py | 3 +-- .../test_temporal_fusion_transformer.py | 2 +- 10 files changed, 43 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 02b44b99..aee1ea97 100755 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ documentation with detailed tutorials. ```python import pytorch_lightning as pl -from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer @@ -89,14 +89,13 @@ val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, nu early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min") -lr_logger = LearningRateLogger() +lr_logger = LearningRateMonitor() trainer = pl.Trainer( max_epochs=100, gpus=0, gradient_clip_val=0.1, - early_stop_callback=early_stop_callback, limit_train_batches=30, - callbacks=[lr_logger], + callbacks=[lr_logger, early_stop_callback], ) diff --git a/docs/source/getting-started.rst b/docs/source/getting-started.rst index 5a5f04a9..a75fe841 100644 --- a/docs/source/getting-started.rst +++ b/docs/source/getting-started.rst @@ -62,7 +62,7 @@ Example .. code-block:: python import pytorch_lightning as pl - from pytorch_lightning.callbacks import EarlyStopping + from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer @@ -98,14 +98,13 @@ Example # define trainer with early stopping early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min") - lr_logger = LearningRateLogger() + lr_logger = LearningRateMonitor() trainer = pl.Trainer( max_epochs=100, gpus=0, gradient_clip_val=0.1, - early_stop_callback=early_stop_callback, limit_train_batches=30, - callbacks=[lr_logger], + callbacks=[lr_logger, early_stop_callback], ) # create the model diff --git a/docs/source/tutorials/ar.ipynb b/docs/source/tutorials/ar.ipynb index a7ee8bb0..df60a4a7 100644 --- a/docs/source/tutorials/ar.ipynb +++ b/docs/source/tutorials/ar.ipynb @@ -723,7 +723,7 @@ " gpus=0,\n", " weights_summary=\"top\",\n", " gradient_clip_val=0.1,\n", - " early_stop_callback=early_stop_callback,\n", + " callbacks=[early_stop_callback],\n", " limit_train_batches=30,\n", ")\n", "\n", @@ -1097,7 +1097,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.8.3" } }, "nbformat": 4, diff --git a/docs/source/tutorials/stallion.ipynb b/docs/source/tutorials/stallion.ipynb index 3d3dbed3..38f4e846 100644 --- a/docs/source/tutorials/stallion.ipynb +++ b/docs/source/tutorials/stallion.ipynb @@ -62,7 +62,7 @@ "\n", "\n", "import pytorch_lightning as pl\n", - "from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger\n", + "from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor\n", "from pytorch_lightning.loggers import TensorBoardLogger\n", "\n", "from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline\n", @@ -1078,7 +1078,7 @@ "source": [ "# configure network and trainer\n", "early_stop_callback = EarlyStopping(monitor=\"val_loss\", min_delta=1e-4, patience=10, verbose=False, mode=\"min\")\n", - "lr_logger = LearningRateLogger() # log the learning rate\n", + "lr_logger = LearningRateMonitor() # log the learning rate\n", "logger = TensorBoardLogger(\"lightning_logs\") # logging results to a tensorboard\n", "\n", "trainer = pl.Trainer(\n", @@ -1086,10 +1086,9 @@ " gpus=0,\n", " weights_summary=\"top\",\n", " gradient_clip_val=0.1,\n", - " early_stop_callback=early_stop_callback,\n", " limit_train_batches=30, # coment in for training, running valiation every 30 batches\n", " # fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs\n", - " callbacks=[lr_logger],\n", + " callbacks=[lr_logger, early_stop_callback],\n", " logger=logger,\n", ")\n", "\n", @@ -2503,7 +2502,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -2529,9 +2528,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEHCAYAAAC+1b08AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deZRcZ3nn8e9Te1fvm1prq2VbtiTLsmS3ZRsbW2BjsDHYhhDj4SQBQswwEMJkOCETFidkyAkZkkkCJwHBeBwmwUkYgzFgIAYvMnhDtvZdspau3qrXqq59ue/8UdVtye5Wt7qr+tbyfM7p01V1a3lu39avr977LmKMQSmlVPlz2F2AUkqpwtBAV0qpCqGBrpRSFUIDXSmlKoQGulJKVQiXXR/c1tZmurq67Pp4pZQqSy+//PKwMaZ9um22BXpXVxc7d+606+OVUqosicjpmbZpk4tSSlUIDXSllKoQGuhKKVUhbGtDV0pVt3Q6TSAQIJFI2F1KSfL5fKxcuRK32z3n12igK6VsEQgEqK+vp6urCxGxu5ySYoxhZGSEQCDAmjVr5vw6bXJRStkikUjQ2tqqYT4NEaG1tfWC//eiga6Uso2G+czm87PRQFdKqQqhga6UUgu0bdu2qYGSd9xxB+Pj49M+z7IMqYxVtDr0oqhSShXQ448//obHspYhlcmSzhqcDsHjKs65tJ6hK6Wq0qlTp1i3bh0f+chH2LhxIx/4wAf4+c9/zg033MDatWt56aWXiEajfPjDH+aaa65hy5Yt/OAHPwAgHo/z/ve/n02bNnHvvfcSj8en3rerq4vh4WEA7rrrLrZcdRUbN17O9u3fnHpOXV0dn/3sZ7nyyiu57rrrGBwcLMg+6Rm6Usp2f/bDAxzsCxf0PTcsb+CBd11+3uccP36c7373u2zfvp1rrrmG73znO/zyl7/kscce4y/+4i/YsGEDb33rW3nwwQcZHx9n69at3HrrrXzjG9/A7/ezd+9e9u7dy1VXXXXO+6azFtFkhr/7h+20tLQQj8d5y5uv566776G9vY1oNMp1113Hl770Jf7oj/6Ib37zm3zuc59b8D5roCulqtaaNWu44oorALj88su55ZZbEBGuuOIKTp06RSAQ4LHHHuMrX/kKkOtqeebMGXbs2MEnP/lJADZt2sSmTZswJtc+bgwkUllqLcM3/uFr/OiHubP63kCAEyeO097ehsfj4c477wTg6quv5oknnijI/migK6VsN9uZdLF4vd6p2w6HY+q+w+Egk8ngdDp55JFHuOyyy97w2sluhcYYjIFYKksinZ3a/uyOZ3j6qSd54qln8fv9vPPtt071K3e73VOvdzqdZDKZguyPtqErpdQM3v72t/PVr34VYwwAu3btAuCmm27in//5n0llLH79ym727dtL/ilTwqEQTU1N+P1+jh45zK9ferHo9WqgK6XUDD7/+c+TTqfZtGkTGzdu5POf/zwAH7n/o4TCEa7aciX/62/+mqu7r3nDa2+97e1kMlnetPUq/scX/5Rrtl5b9HrFvP7PyiLp7u42usCFUtXr0KFDrF+/3u4yLkgma5HMWGSt+eem0yHUeufW2j3dz0hEXjbGdE/3fG1DV0qpWWQsi2R6YUG+GDTQlVJqBlnLkMxkyWRLO8gnaaArpWxjjCnJCboyWYtU1rI1yOfTHD7rRVEReVBEgiKyf4bt20QkJCK7819fuOAqlFJVx+fzMTIyMq/gKpbJAUGxlL1n5ZPzoft8vgt63VzO0B8CvgZ8+zzPedYYc+cFfbJSqqqtXLmSQCDA0NCQrXUYA5YxZLIWixHhIoJ3DnO5TK5YdCFmDXRjzA4R6bqgd1VKqVm43e4LWo2n0DJZi97xOGdGYyTT1qI1QDf63WzuainKexdqF64XkT1AH/BpY8yBAr2vUkoVVDprcXokRmAsVjYXO+eqEIH+CrDaGBMRkTuAR4G10z1RRO4H7gfo7OwswEcrpdTcZLIWp0djnBmNka2wIJ+04JGixpiwMSaSv/044BaRthmeu90Y022M6W5vb1/oRyul1KyyluHUcJRfnRjh5FC0YsMcCnCGLiJLgUFjjBGRreT+SIwsuDKllFoAyzL0jsc5ORwt6ipBpWTWQBeRh4FtQJuIBIAHADeAMebrwG8AHxORDBAH3m9KqR+SUqqqGGPoCyU4ORQ9Z/bDajCXXi73zbL9a+S6NSqllG2MMQyEc0EeS1VXkE/SkaJKqbI3HElyPBghkijMvOLlSgNdKVW2QvE0x4MTjEXTdpdSEjTQlVJlJ5bKcCIYZTCcsLuUkqKBrpQqG8lMlpPDUfrG41jV0XHlgmigK6VKXtYynB6JcrqCBwUVgga6Uqqk9Y7HORGMVE1f8oXQQFdKlaSsZTjUH2YgpO3kc6WBrpQqOYl0lt0941XfDfFCaaArpUrKaDTFvt4QaW1iuWAa6EqpktEzGuPo4AQ6ecj8aKArpWxnWYZDA2H6x7W9fCE00JVStkqks+wNhAjHdbTnQmmgK6VsMx5LsTcQ0i6JBaKBrpSyRWAs116uIz4LRwNdKbXoTg1HOR6M2F1GxVnwEnRKKXUhghMJDfMi0UBXSi2aiUSaA71hu8uoWBroSqlFkcxk2dMTImtpJ/Ni0UBXShWdZRn2BkJVt8bnYtNAV0oV3cH+MKGY9jMvNg10pVRRnRqO6oyJi0QDXSlVNMGJBCeGtEfLYtFAV0oVxUQizYG+sE60tYg00JVSBZfKWLkeLbpc3KLSQFdKFVSuR8u49mixgQa6UqqgDg2EGdceLbaYNdBF5EERCYrI/lmed42IZEXkNwpXnlKqnJweieqc5jaayxn6Q8A7zvcEEXECXwZ+VoCalFJl6MxIjGOD2qPFTrMGujFmBzA6y9N+H3gECBaiKKVUeTk5HOXo4ITdZVS9Bbehi8gK4B7g6wsvRylVbo4HI5zQ2RNLQiEuiv4t8BljzKyXtEXkfhHZKSI7h4aGCvDRSik7HR2c4NRw1O4yVF4hFrjoBv5VRADagDtEJGOMefT1TzTGbAe2A3R3d2sHVaXKlDGGwwMT9I7F7S5FnWXBgW6MWTN5W0QeAn40XZgrpSqDMYYDfWGdn6UEzRroIvIwsA1oE5EA8ADgBjDGaLu5UlXEsgz7+0IEw0m7S1HTmDXQjTH3zfXNjDEfXFA1SqmSZVmGvb0hhic0zEuVLhKtlJpV1jLs7hlnLJqyuxR1HhroSqnzymQtdveM63D+MqBzuSilZmSMYX+fzs1SLjTQlVIzOjEU1TbzMqKBrpSaVjCc0EFDZUYDXSn1BpOrDanyooGulDpHOmuxNxAia+lg7nKjga6UmmKMYV9viHhKVxsqRxroSqkpx4MRRiPa17xcaaArpQDoD8U5PRKzuwy1ABroSinCiTSH+vUiaLnTQFeqyiUzWfb2hLAsuytRC6WBrlQVsyzDvkCIRFovglYCDXSlqtjR4IQO668gGuhKVane8TiBUV1xqJJooCtVhSYSaY4M6EXQSqOBrlSVyVq5wUN6EbTyaKArVWUO9YeJJfUiaCXSQFeqivSH4rq4cwXTQFeqSsRSGQ4PTNhdhioiDXSlqoBlmdwMilmdQbGSaaArVQWOBieIJDJ2l6GKTANdqQoXDCe0v3mV0EBXqoIl0lkO6qRbVUMDXakKNblYRUbbzauGBrpSFerEUISQztNSVWYNdBF5UESCIrJ/hu13icheEdktIjtF5MbCl6mUuhAjkSSnhnWximozlzP0h4B3nGf7L4ArjTGbgQ8D3ypAXUqpeUpmshzo03bzajRroBtjdgCj59keMcZMNtLVAtpgp5RNjDHs7w2TyuhELdWoIG3oInKPiBwGfkzuLF0pZYOTw1HGorrIc7UqSKAbY75vjFkH3A38+UzPE5H78+3sO4eGhgrx0UqpvNFoipPDUbvLUDYqaC+XfPPMxSLSNsP27caYbmNMd3t7eyE/Wqmqlsxk2d8bwmiDZ1VbcKCLyCUiIvnbVwEeYGSh76uUmptcu3lI280VrtmeICIPA9uANhEJAA8AbgBjzNeB9wK/LSJpIA7ce9ZFUqVUkZ0YijIW1f7mag6Bboy5b5btXwa+XLCKlFJzlutvru3mKkdHiipVphLpLPu1v7k6iwa6UmVost08re3m6iwa6EqVoePBCOM6T4t6HQ10pcrM0ESS0yM6T4t6Iw10pcpIIp3lQF/I7jLUPGUtQ2AsRmCsOH+QZ+3lopQqDZPrgur85qXPGEMoniYwFqd3PE5gLE5gLEZ/KEHGMnz05ov477evL/jnaqArVSaOD0UIx7XdvNRYlqE3FOfkUJTAeC64e8fiRFPZqec01bhZ0VzDhmUNrO2o571XryxKLRroSpWBYDjBGW03LwnprMXJ4SjHghGOBSc4EYwST+fC2+tysKKphqtXN7OiqYaVzX5WNNVQ53stahv9bta01RalNg10pUrcWDSl85vbKJLMcDwY4Xg+wE+PxMhYuWav5Y0+tq5p4ZIldVzcXktbnRdHbiYUW2igK1XCxqIpdveMk7W03XyxGGPoGY3zSs8Yu3vGCYzFAXA6hK5WP7eu7+CSJXVc0l53zpl3KSitapRSU8ZjGuaLxbIMx4IRduVDfDiSQgQuaa/j7s3LubSjnq7WWjyu0u4YqIGuVAkaj6XYpWFeVOmsxcH+MLvOjLO7Z5xIMoPLIWxY3sA7r1jG5lVN1Pvcdpd5QTTQlSoxU2Gu3RMLzspPmfCrEyPs7w2RzFjUuJ1sWtnIllVNbFzRiM/ttLvMedNAV6qEaJgXRzSZ4Vcnhnnq8BBDkSQNPhfXXdTKllVNrFtaj8tZ2k0pc6WBrlSJCMXSGuYF1jMW46nDQV54dZRU1mLtkjrec9UKtnQ24XJURoifTQNdqRIQiqV5pWdMw7wAMpbF7p5xnjwc5OhgBI/TwbVrWnjLuiV0tvjtLq+oNNCVspmGeWGE42l2HBvimaNDjMXStNV5eN/VK7nhkjbqvNURddWxl0qVqFA8zS4N83kzxnByOMqTR4LsPDVGxjJsWNbAB65dzaYVjTgc9g3ysYMGulI2mUik2XVmTCfbmodUxuLXp0Z58kiQ0yMxfG4HN61t5y3r2lnWWGN3ebbRQFfKBslMlt094xrmF2g4kuTpI0P88vgwkWSGZY0+PrC1k+svbi3r7oaFooGu1CKzLMOenhDJtC4fNxeWMRzqD/PU4SH2BMZBYMuqJt66bgmXddQjNs6dUmo00JVaZAf6wjoN7hwk0lmeOzHCk4eDDIQT1Ptc3H7FUm5e205rndfu8kqSBrpSi+jEUITBcMLuMkraYDjBU0eC/PL4MIm0xZq2Wn73xjV0r27GXSEDgIpFA12pRTIYTnByKGp3GSXJGMPB/jC/OBxkXyCEQ4TurmZuWbeEi9rr7C6vbGigK7UIQvG0rgU6jWQ6y/Ov5ppV+kK5ZpU7Ny3j5kvbafJ77C6v7GigK1VkiXSWPT3jWHoNdMpIJMmTh4M8e3yYWCpLZ4ufD9/QxTVdLdqssgAa6EoVUdYy7O4ZJ5XRNIfcsPyfHRjkh3v6sIzhqs5mblm/hEva67S3SgHMGugi8iBwJxA0xmycZvsHgM/k70aAjxlj9hS0SqXK1P7eEJFExu4ySsKpkSj/9NwpesbiXL26md+8eqX2VimwuZyhPwR8Dfj2DNtPAjcbY8ZE5HZgO3BtYcpTqnwdD04wNJG0uwzbpTIWj+3p4z8ODlDvc/PxbRezpbPZ7rIq0qyBbozZISJd59n+3Fl3XwBWLrwspcpb33icU8Mxu8uw3ZGBCb79/CkGJ5K8+ZI23te9Er9HW3qLpdA/2d8FfjLTRhG5H7gfoLOzs8AfrVRpGI+lODwQtrsMW8VSGR55pZdnjg7RXuflv73tUtYva7C7rIpXsEAXkbeQC/QbZ3qOMWY7uSYZuru7dRILVXFC8TS7q7xHy+6ecf7lxdOMx9PctqGDuzYvx+vSeVYWQ0ECXUQ2Ad8CbjfGjBTiPZUqN9W+fNx4LMW/7wzw0qlRVjTV8F+2XcKatlq7y6oqCw50EekEvgf8ljHm6MJLUqr8jEVT7A5UZ5hPJNL8dP8ATx4JYgzctXk5t1++tGLW6Swnc+m2+DCwDWgTkQDwAOAGMMZ8HfgC0Ar8Q74facYY012sgpUqNSORJHsDIbJWdYV5LJXhiYODPHFokGTG4ro1rbz7yuW012tXRLvMpZfLfbNs/wjwkYJVpFQZGY4k2RuorjbzZDrLk0eC/HT/ANFUlqtXN3PXlctZ3lS9C0uUCu0/pNQ8BScS7O8NVU2Yp7MWO44O8eN9/YQTGa5Y0cg9m1fQ2VrZCy8XksspeF3Fa4rSQFdqHgbDuTA3VdDKkrUMz50Y5od7+xmNpriso56PbVvO2iX1dpdWEhwOqPe5qfW48LgEt9OBy+nA7RQ8TgfuqS8p+vQGGuhKXaD+UJyDfeGKD/OsZXjh5AiP7+1ncCLJmrZaPnh9F+uXVe8qQZPhXe9z0ZD/Xud1lczPQwNdqQvQOx7ncH9lh3kma/H8qyM8vm+AoUiSzhY/n3jLJVy5srFkgmsxuJxCnddFXYmG93Q00JWao57RGEcGJuwuo2jSWYvnTozw+L5+RqIpulr9vH/rJWxaUdlB7nE5qPW6qPU6qfW4pm6X42AoDXSl5uDUcJTjwYjdZRRFOmvx7LFhfrK/n7FYmovaavmt61Zz+fKGigtyv9dJU42HRr+bWo+TWq+rouZf10BXahbHBic4PVJ5E22lMhY7jg3xk/0DhOJp1i6p40NvWlMxbeROh9BQ46Kxxk1jjYfGGjeeIvYwKQUa6ErNYHKdy/7xylrUOZ21eOboEI/nux9e1lHP7715DZd1lHeQu10Omv3u3Bl4Ta7N2+Eo3/2ZDw10paaRtQz7ekMMV9B85lnL8PyrIzy2p4/RaIp1S+v5z1cu59KO8u1+6HU7WFLvY0m9lya/u6z/IBWCBrpSr5POWuzpGWc8lra7lIIwxvDymTEe3d3HQChBV6ufD17fxYbl5Tmdrd/jZEmDl/Y6H41+t93llBQNdKXOkkhn2d0zXhHLxk02GX1vVy+nR2Isa/TxsZsv5qrOprI7k63zuWiv97Kk3ku9T0N8JhroSuXFUhl2nRknnsraXcqCnRiK8MgrAY4ORmit9fChG7q4fk1rWbUp1/lcdDT46Gjw6ipHc6Q/JaWAcCLNrjPjpDPlPTFLYCzG93f1sicQot7n4r5rVnHTpe1l0zXP73HS0eijo8FHnVfj6ULpT0xVvdFoij1lPpd573icH+7pY+fpMWrcTu7evJxb13fgc5f+4Jgaj5OOBi9LGnw0aHPKgmigq6oWDCfY31e+Myb2h+L8cE8/vz41isfl4I4rlnLbhqUlf3Zb43HSXu+lo14vbBZSaR91pYrozEiMY8GJspyXZSCc4Id7+njp1Cgep4N3bFzKbRs6SvKCodMpNPjc+QE+7qoY4GMXDXRVdYwxHAtGOFOGoz8Hwwl+tLefF06O4HY6uG1DB++4fGnJBLkI+D350Zn+XHjXepxl16umXGmgq6piWYb9fSGC4fIaMDQ0keRHe/t4/tURnA7h1vW5IG+sKY0g97oddLb4Wd5UUzYXYCuRBrqqGqmMxZ7AOKEyGjAUTWb4wZ4+njkyhAi8dd0Sbt+4rGSCvM7nYnWrn456X1l1iaxUGuiqKsRSGXafGSdWJn3MLcuw49gQj+7uI5rKcNPadt61aRlNfo/dpQHQUudhdYuf1jpdELqUaKCriheKpdkdKJ8+5kcHJ3j4pTP0jMW5tKOO+67pZFWL/et2OhywpN7H6lZ/ybTZq3NpoKuKFpxIcKA3TNYq/a4sI5Ek3305wM7TY7T4PXz0povoXt1s+wVFp1NY2VTDqhZ/WfRrr2Ya6Kpi9YzGODpY+t0SUxmLnx0Y4Cf7BzAY3rVpGe/YuNT2FXMaatysaK5haYMPp7aPlwUNdFWRjg5OlHy3xMlZEL+7M8BINEX36mbed/VKW9ulnU5hWaOPFU012qxShjTQVUVJZrIc6AszGknZXcp5HR2c4NHdvRwdjLCyuYZP33Ap65baN52tno1XBg10VTFGIkkO9IVJlfDFz6ODEzy2p4/DAxM01rj5wLWd3LS23ZYQ1bPxyjNroIvIg8CdQNAYs3Ga7euA/wNcBXzWGPOVglep1HlYluHV4Qinhku3ieX1QX5v9ypuvrTdliHwHpeD1a1+VjTV4NJBQBVlLmfoDwFfA749w/ZR4JPA3QWqSak5i6ey7OsNEY6X5mChUgpyn9s5FeQ6CKgyzRroxpgdItJ1nu1BICgi7yxgXUrNaiCU4NBAuCSnvS2lIPd7nKxuq2VZg47mrHSL2oYuIvcD9wN0dnYu5kerCpK1DIcHwvSPJ+wu5Q2OByM8uru3NILc62RNWy1LG3y292VXi2NRA90Ysx3YDtDd3V16p1Wq5E0k0uzrDRFLltYQ/v5QnO+90suunnEafC5bg7zO5+Kitlra670a5FVGe7mostEzmpu/vJQWoxiPpXhsTx/PHh/G63Jw9+blvG19B95FHlHpcEB7nY/lTT6dX6WKaaCrkpfOWhzqD5fUlLexVIafHhjg5weDZI3hlnVLeOcVyxa9+1+t18WKphqWNfl02lo1p26LDwPbgDYRCQAPAG4AY8zXRWQpsBNoACwR+RSwwRgTLlrVqmqEE2n2BULES2SWxHTW4qkjQX68t59oKsu1a1q4e/MK2usX76zY6RQ66nP9x3X5NnW2ufRyuW+W7QPAyoJVpFReKTWxWMbw4slRHt3Vy0g0xYZlDbz3qhWsbq1dtBoa/W6WN+loTjUzbXJRJSeTtTjUP8Fg2P5eLOmsxYsnR3ni4CC943E6W/z89vWruXx546J8vtvlYFmjj+VNNSW/8LOyn/6GqJISTqTZHwjZvhDFRCLN00eGeOpIkHAiw4qmGn7vxjVcs6YFR5F7johAS62HFU01tNV5te+4mjMNdFUySqGJpXc8zs8PDvL8qyNkLMPGFQ3ctn4p65fVF70LYI3HOXU2rvOOq/nQQFe2y2QtDg9MMBCyp4nFGMPB/jD/cXCQA31h3E7hTRe3cuv6DpY31RT1s7W7oSokDXRlq1A8zYFee5pY0lmLF14d4YlDg/SNJ2iscXP35uXcfGl70bsfelwOOlv8LG+qsWXwkapMGujKFumsxYmhCL1j8UVfUSgUT/P0kSBPHx1iIpFhVXMNH76hi2u6Worel9vvcdLZ6md5o06QpQpPA10tuoFQgqODE4s+b3lgLMYTBwd58eQoGcuwaWUjb1vfwbqlxW8fr/O56GqtpaNBh+Or4tFAV4smlspwqH+CsejirSZkGcP+3hBPHBzk0MAEHpeDN69t45b1HSxt8BX985v8brraamnT9nG1CDTQVdFZluHkSJTTI9FF68GSzGR5/sQIPz8UZCCcoNnv5j1bVnDTpe2L0p+7rd5LV6ufJr+n6J+l1CQNdFVUI5EkRwYmFu2iZ+9YnOdfHeHZY0NEU1lWt/r5vRvXcHVXMy5HcdvH630u2uq9dDT4dBCQsoX+1qmiSKSzHBuMLMpoz5FIkhdPjvLiyVF6x+M4BK5c1cTb1newdkld0dqsnQ6hudZDW52Htjqv9h1XttNAVwWVyVqcHo1xZjRW1JWEIokMO0/nQvxYMALAxe21/KetnXSvbqahpjjdDr1uB211XtrqvLTUenROFVVSNNBVQViWoWcsxqmRGOki9V5JprPsDozz4slRDvSGyRrD8kYf92xZwdaulqLNeNhQ486dhdd7aVjk6XGVuhAa6GpBjDH0hRK8OhQhmS58kCfTuUWgXz4zxt5AiGTGotnv5m0bOrh2TQsrm2sK3qTicgqttV5a6zy01nnwurQpRZUHDXQ1b4PhBCeCkYJf8IylMuwN5EJ8f2+IdNZQ73Nx3UWtXLumhUuW1BV8gqxar4v2eg+ttV6a/G7tK67Kkga6umDDkSQnghEmEpmCvWckmWF3zzivnB7jYH+YjGVoqnHz5rXtXN3ZzNoldQUfWdnod7O0wUd7vV7QVJVBA13N2Wg0xcnhCGPRdEHeL5bKsPPUGDtPj3F4IIxloLXWw1vWLaF7dTNr2moLfibudTtY1ljD8iYffo/++qvKor/R6ryMMQyGk5weiRbkjDxjWezvDfPCqyPs7hknYxmW1Ht5++VLubqzmdWt/oI3dzgdQnu9l2WNPlpqPdqcoiqWBrqaVtYy9I7F6RmLLXg9T2MMp0diPPfqCL8+NcpEIkOd18XNl7Zz/UWtRQlxyDWpLGv00dGgCyir6qCBrs6RzGTpGY0TGIuRWWA/8skBP8+9OsJAKIHLIWxe1cT1F7dy+fKGgo/cFMl1MWyt9dDR4KNWR2uqKqO/8QqAaDLD6ZEYA+H4guZbGQwn2N8bYlfPOEcGJjDA2iV13Hbdarq7mgvebu12OWit9UwN9NG5xVU100CvYol0lpFoimA4wUhkfjMgprMWRwYm2N8XYl8gxOBEEoCOBi/vvnI5113UWvABP5NzprTVemmocWmbuFJ5GuhVxBjDeCzNSDTJcCRFZJ4XOYcjSfYFQuzrC3F4IDevudspXLa0nreuW8LGFY10FHBq2hqPkya/m2a/DvRR6nw00Cvc5Fn4SCTJSDQ1r/lVEuksx4IRDvaF2d8Xoj+/9md7nZcbL27jipWNXNZRX7DmDr/XSbPfQ7PfQ5PfrX3ElZojDfQKE0tlCMczhBNpRqPzOwvPWoZTI1EO9Yc52B/mxFCUrGVwOYRLO+q5aW07V6xspKN+4avviORGaeYC3E2TX9vBlZovDfQylkhnCSfSUwEejqfn1TNlsq/5wf4wh/rDHB6YIJ7OIsCqFj9vW9/BhmUNXLKkrmBhq6M0lSq8WQNdRB4E7gSCxpiN02wX4O+AO4AY8EFjzCuFLrSapbMW8XSWeCpLLJUlHE8TTqTnNRmWMYaxWJqesRg9ozF6RuOcHI4yGstdFG2r83BNVzMbljVw2dJ66gs4u2C9z8XSfL9wDXGlCm8uZ+gPAV8Dvj3D9tuBtfmva4F/zH9Xc2SMIZHOh3Y6SzyVIZ6yiKUyxNPZefcHz1qGgXAiH9y5Oekea1cAAAvRSURBVMp7xuJEkq81w3TUe7movZZ3LlvGhmUNBe+RUuudDHGvDrVXqshm/RdmjNkhIl3necpdwLeNMQZ4QUSaRGSZMaa/QDWe4/kTI/zPnx1maaOPJfW5s72ljV466n105M/+Smn5r1TGIpnJksxYua/0G2+nsxZmnmN4LMswGksxNJEkOJFkKP8VnEjQH0qQsXJv7HIIK5tr2LKqiVUtfjpb/Kxsrin4mbLTKTT43DT53SV3LJSqdIX417YC6DnrfiD/2BsCXUTuB+4H6OzsnPcHel1ODvVP8PSRoWmnbq1xO6fmsm7xe2iu9UxddGvO32+p9dBa66HO68qtOjN5bc+AgamANRhM/jHInU1bVm5OkqxlyFhm6nsma51zP2tZCxqkk8xkiSQyRJKvfYXi6anQHppIMhxNkbVe+2vgdAhtdR7a672sX9ZAZ4ufVS1+ljb4Cr66zuQFzcYaN401bhpq3NR6nNovXCmbFCLQp/vXO+35pjFmO7AdoLu7e17npNdf3Mr1F7dO3Y8kMwyGE/SPJzg9EqVnLEbveJzBcJKRSJK+sQQTyTSJGdqbXQ6h3ueizuuixuOkxu3E5859P+f+1G3HnOcFMSbX/p3KWqQzue+p/Pd01kzdTmUs4qnsa8GdD/FUdvqaa9xO2uu9rGzxs6WzmSX1XtrrvSyp99Ls9xR8mtmpz/U4afBNhreLep9bl2BTqoQUItADwKqz7q8E+grwvnNS53VR117Hxe113Li27ZxtmaxFJJlhIpFhJJKiLxRnMJwgFEszkcj1DJnIh+dEIk08nWUsliaeSky1Z5999ltoAnhcuT8Qfo+TOq+LJr+blc01uf3yuqjzuag/57abWm9xz4KdTsl9Zv4P3eRnunSCK6VKWiEC/THgEyLyr+QuhoaK1X5+oVxOB01+D01+D6ta/GymCWMM8XSWiURmKswjiQyJ9BubbozJNZ3EU9mzLlheWMi7nQ48LgcepwO3S/Dk77udDlwOsbV5wuGAGreLWq/znD8YNR7tgaJUOZpLt8WHgW1Am4gEgAcAN4Ax5uvA4+S6LB4n123xQ8UqthBEBL/Hhd/joqPhtcczWYtoKhfY0VQm9z2ZIZbO4nY6iraK/GJwOYU6b26fa73Oqe81bm3vVqqSzKWXy32zbDfAxwtWkU1cTgeNNQ4apwnuZGYy6HNdChNpa6p3SmqBvVQKweUUfG4nXpfjnO9+Ty68deSlUtVB+5TNgdflxOty0uSffrsxJneR86yAn7zgmbUMmazBMq/1fMlkcz1hsvnHrHwTjtMhuBwOnA6Z+nK97rbL6XhDcOuFSaUUaKAXhIjgcUnuTLiw43KUUmrO9P/iSilVITTQlVKqQmigK6VUhdBAV0qpCqGBrpRSFUIDXSmlKoQGulJKVQgNdKWUqhAa6EopVSHE2DQJiYgMAafn+fI2YLiA5dhJ96U0Vcq+VMp+gO7LpNXGmPbpNtgW6AshIjuNMd1211EIui+lqVL2pVL2A3Rf5kKbXJRSqkJooCulVIUo10DfbncBBaT7UpoqZV8qZT9A92VWZdmGrpRS6o3K9QxdKaXU62igK6VUhSi7QBeRd4jIERE5LiJ/bHc9CyEip0Rkn4jsFpGddtdzIUTkQREJisj+sx5rEZEnRORY/nuznTXOxQz78aci0ps/LrtF5A47a5wrEVklIk+JyCEROSAif5B/vKyOy3n2o+yOi4j4ROQlEdmT35c/yz9elGNSVm3oIuIEjgJvAwLAr4H7jDEHbS1snkTkFNBtjCm7wRIichMQAb5tjNmYf+yvgFFjzF/m/9g2G2M+Y2eds5lhP/4UiBhjvmJnbRdKRJYBy4wxr4hIPfAycDfwQcrouJxnP36TMjsuIiJArTEmIiJu4JfAHwDvoQjHpNzO0LcCx40xrxpjUsC/AnfZXFNVMsbsAEZf9/BdwD/lb/8TuX+EJW2G/ShLxph+Y8wr+dsTwCFgBWV2XM6zH2XH5ETyd935L0ORjkm5BfoKoOes+wHK9EDnGeA/RORlEbnf7mIKoMMY0w+5f5TAEpvrWYhPiMjefJNMSTdRTEdEuoAtwIuU8XF53X5AGR4XEXGKyG4gCDxhjCnaMSm3QJdpHiufNqM3usEYcxVwO/Dx/H//lf3+EbgY2Az0A39tbzkXRkTqgEeATxljwnbXM1/T7EdZHhdjTNYYsxlYCWwVkY3F+qxyC/QAsOqs+yuBPptqWTBjTF/+exD4PrkmpXI2mG//nGwHDdpcz7wYYwbz/wgt4JuU0XHJt9M+AvyLMeZ7+YfL7rhMtx/lfFwAjDHjwNPAOyjSMSm3QP81sFZE1oiIB3g/8JjNNc2LiNTmL/ggIrXAbcD+87+q5D0G/E7+9u8AP7Cxlnmb/IeWdw9lclzyF+D+N3DIGPM3Z20qq+My036U43ERkXYRacrfrgFuBQ5TpGNSVr1cAPJdlf4WcAIPGmO+ZHNJ8yIiF5E7KwdwAd8pp30RkYeBbeSmAR0EHgAeBf4d6ATOAO8zxpT0BccZ9mMbuf/WG+AU8NHJ9s5SJiI3As8C+wAr//CfkGt/Lpvjcp79uI8yOy4isoncRU8nuRPofzfGfFFEWinCMSm7QFdKKTW9cmtyUUopNQMNdKWUqhAa6EopVSE00JVSqkJooCulVIXQQFdKqQqhga6KKj/l6adF5IsicquNdWyebbpVEXl3BUzJ/Cd216Dso/3QVVGVylS0IvJBclMVf8LOOqYjIk5jTLZA7xUxxtQV4r1U+dEzdFVwIvJZyS1C8nPgsvxjD4nIb+Rv/6WIHMzPmveV/GMdIvL9/EIAe0TkTfnH/1BE9ue/PpV/rEvOXZDi0/k/HIjI0yLy5fyiAkdF5M35aSK+CNybXxjh3hnq/qCIfO2sev9eRJ4TkVcna5/hddtEZEe+/oMi8nURceS33SYiz4vIKyLy3fyEU5OLm3xBRH4JvE9yC7e8kt/3X+SfU5ufVfDXIrJLRO46q87vichPJbdAwl9N/lyBmvw+/sv8jp4qZy67C1CVRUSuJjfHzhZyv1+vkFugYHJ7C7l5ONYZY8zkPBfA3wPPGGPukdxCJnX59/oQcC25mTZfFJFngLFZynAZY7bmm1geMMbcKiJf4MLP0JcBNwLryM298f/O89ytwAbgNPBT4D0i8jTwOeBWY0xURD4D/CG5Py4ACWPMjSLSTu7ndJMx5mT+ZwTwWeBJY8yH8z+nl/J/JCE3BH4LkASOiMhXjTF/LCKfyM/sp6qQBroqtDcD3zfGxABE5PWTp4WBBPAtEfkx8KP8428Ffhty040CofycHt83xkTz7/W9/PvPNiHb5CyDLwNdC9iXR/Mz+x0UkY5ZnvuSMebVfJ0Pk/tDkCAX8r/KzTeFB3j+rNf8W/77dcAOY8xJgLPm9LgNeLeIfDp/30du7g+AXxhjQvnPOwis5ty1AlQV0kBXxTDjhRljTEZEtgK3kDuT/wS5MJ/OdPPfA2Q4t7nQ97rtyfz3LAv7HU+edXumWia9fp9N/jVPGGPum+E10bPee7qfmQDvNcYcOedBkWtfV9tC91NVCG1DV4W2A7hHRGry0wO/6+yN+TbkRmPM48CnyDUdAPwC+Fj+OU4Raci/190i4pfcFMP3kJuFbxBYIiKtIuIF7pxDXRNA/cJ3b0ZbJTetswO4l9zakS8AN4jIJQD5/bh0mtc+D9wsImvyz5tscvkZ8PuSP70XkS1zqCMtubnEVRXSQFcFlV8L8t+A3eQWKHj2dU+pB34kInuBZ4D/mn/8D4C3iMg+ck0ll+ff6yHgJXJTwH7LGLPLGJMm1w79Irkmm8NzKO0pYMP5Loou0PPAX5Kbo/skuaaiIXILND+c398XyLXHnyP/vPuB74nIHl5rivlzcmtQ7s1fBP7zOdSxPf98vShahbTbolILJCLbgE8bY+byPwWlikbP0JVSqkLoGbqqOiLyIXJNPGf7lTHm47O87grg/77u4aQx5tpC1qfUfGmgK6VUhdAmF6WUqhAa6EopVSE00JVSqkJooCulVIX4/9yd/tXa7f04AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "# plotting median and 25% and 75% percentile\n", "agg_dependency = dependency.groupby(\"discount_in_percent\").normalized_prediction.agg(\n", diff --git a/examples/ar.py b/examples/ar.py index 964cc6af..bf64de2e 100644 --- a/examples/ar.py +++ b/examples/ar.py @@ -6,7 +6,7 @@ import pandas as pd from pandas.core.common import SettingWithCopyWarning import pytorch_lightning as pl -from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor from pytorch_lightning.loggers import TensorBoardLogger import torch @@ -64,20 +64,19 @@ validation.save("validation.pkl") early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min") -lr_logger = LearningRateLogger() +lr_logger = LearningRateMonitor() trainer = pl.Trainer( max_epochs=100, gpus=0, weights_summary="top", gradient_clip_val=0.1, - early_stop_callback=early_stop_callback, limit_train_batches=30, limit_val_batches=3, # fast_dev_run=True, # logger=logger, # profiler=True, - callbacks=[lr_logger], + callbacks=[lr_logger, early_stop_callback], ) diff --git a/examples/nbeats.py b/examples/nbeats.py index c4adf2bf..f44f962c 100644 --- a/examples/nbeats.py +++ b/examples/nbeats.py @@ -55,7 +55,7 @@ gpus=0, weights_summary="top", gradient_clip_val=0.1, - early_stop_callback=early_stop_callback, + callbacks=[early_stop_callback], limit_train_batches=15, # limit_val_batches=1, # fast_dev_run=True, diff --git a/examples/stallion.py b/examples/stallion.py index c0b6ec1b..0d0456ae 100644 --- a/examples/stallion.py +++ b/examples/stallion.py @@ -6,7 +6,7 @@ import pandas as pd from pandas.core.common import SettingWithCopyWarning import pytorch_lightning as pl -from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor from pytorch_lightning.loggers import TensorBoardLogger import torch @@ -93,21 +93,20 @@ validation.save("validation.pkl") early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min") -lr_logger = LearningRateLogger() +lr_logger = LearningRateMonitor() trainer = pl.Trainer( max_epochs=100, gpus=0, weights_summary="top", gradient_clip_val=0.1, - early_stop_callback=early_stop_callback, limit_train_batches=30, # val_check_interval=20, # limit_val_batches=1, # fast_dev_run=True, # logger=logger, # profiler=True, - callbacks=[lr_logger], + callbacks=[lr_logger, early_stop_callback], ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index a803a54c..c81019fd 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -9,7 +9,7 @@ from optuna.integration import PyTorchLightningPruningCallback, TensorBoardCallback import pytorch_lightning as pl from pytorch_lightning import Callback -from pytorch_lightning.callbacks import LearningRateLogger +from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.loggers import TensorBoardLogger import statsmodels.api as sm import torch @@ -99,7 +99,7 @@ def objective(trial: optuna.Trial) -> float: # TensorBoard. We don't use any logger here as it requires us to implement several abstract # methods. Instead we setup a simple callback, that saves metrics from each validation step. metrics_callback = MetricsCallback() - learning_rate_callback = LearningRateLogger() + learning_rate_callback = LearningRateMonitor() logger = TensorBoardLogger(log_dir, name="optuna", version=trial.number) gradient_clip_val = trial.suggest_loguniform("gradient_clip_val", *gradient_clip_val_range) trainer = pl.Trainer( @@ -107,8 +107,11 @@ def objective(trial: optuna.Trial) -> float: max_epochs=max_epochs, gradient_clip_val=gradient_clip_val, gpus=[0] if torch.cuda.is_available() else None, - callbacks=[metrics_callback, learning_rate_callback], - early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="val_loss"), + callbacks=[ + metrics_callback, + learning_rate_callback, + PyTorchLightningPruningCallback(trial, monitor="val_loss"), + ], logger=logger, **trainer_kwargs, ) diff --git a/tests/test_models/test_nbeats.py b/tests/test_models/test_nbeats.py index d43276f1..0c8e798b 100644 --- a/tests/test_models/test_nbeats.py +++ b/tests/test_models/test_nbeats.py @@ -4,7 +4,6 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_forecasting.metrics import QuantileLoss from pytorch_forecasting.models import NBeats @@ -21,7 +20,7 @@ def test_integration(dataloaders_fixed_window_without_coveratiates, tmp_path, gp gpus=gpus, weights_summary="top", gradient_clip_val=0.1, - early_stop_callback=early_stop_callback, + callbacks=[early_stop_callback], fast_dev_run=True, logger=logger, ) diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index ee42a199..648fbdd8 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -40,7 +40,7 @@ def test_integration(multiple_dataloaders_with_coveratiates, tmp_path, gpus): gpus=gpus, weights_summary="top", gradient_clip_val=0.1, - early_stop_callback=early_stop_callback, + callbacks=[early_stop_callback], fast_dev_run=True, logger=logger, )