From 6e5c13dcfb310d6f613e7dbacfc829db1b0cb184 Mon Sep 17 00:00:00 2001 From: Ruslan Baikulov Date: Sat, 3 Feb 2024 15:48:49 +0300 Subject: [PATCH] feat: Return metrics from fit method the same way as from from validate --- argus/model/model.py | 6 +++++- tests/model/test_model.py | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/argus/model/model.py b/argus/model/model.py index 1db695e..5655cc0 100644 --- a/argus/model/model.py +++ b/argus/model/model.py @@ -196,6 +196,9 @@ def fit(self, List of callbacks to be attached to the validation process. Defaults to `None`. + Returns: + dict: The metrics dictionary. + """ self._check_train_ready() metrics = [] if metrics is None else metrics @@ -222,7 +225,8 @@ def validation_epoch(train_state, val_engine, val_loader): val_engine.run(val_loader, -1, 0) attach_callbacks(train_engine, callbacks) - train_engine.run(train_loader, 0, num_epochs) + state = train_engine.run(train_loader, 0, num_epochs) + return state.metrics def validate(self, val_loader: Iterable, diff --git a/tests/model/test_model.py b/tests/model/test_model.py index 9b0e8fb..a3ba0dd 100644 --- a/tests/model/test_model.py +++ b/tests/model/test_model.py @@ -91,10 +91,10 @@ def test_fit_train_val_loader(self, drop_last=True, batch_size=32) val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64) val_loss_before = model.validate(val_loader)['val_loss'] - model.fit(train_loader, - val_loader=val_loader, - num_epochs=32) - val_loss_after = model.validate(val_loader)['val_loss'] + val_loss_after = model.fit(train_loader, + val_loader=val_loader, + num_epochs=32)['val_loss'] + assert val_loss_after == model.validate(val_loader)['val_loss'] assert val_loss_after < val_loss_before assert val_loss_after < 0.3