Skip to content

Commit

Permalink
feat: Return metrics from fit method the same way as from from validate
Browse files Browse the repository at this point in the history
  • Loading branch information
lRomul committed Feb 3, 2024
1 parent 73cb23b commit 6e5c13d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
6 changes: 5 additions & 1 deletion argus/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def fit(self,
List of callbacks to be attached to the validation process.
Defaults to `None`.
Returns:
dict: The metrics dictionary.
"""
self._check_train_ready()
metrics = [] if metrics is None else metrics
Expand All @@ -222,7 +225,8 @@ def validation_epoch(train_state, val_engine, val_loader):
val_engine.run(val_loader, -1, 0)

attach_callbacks(train_engine, callbacks)
train_engine.run(train_loader, 0, num_epochs)
state = train_engine.run(train_loader, 0, num_epochs)
return state.metrics

def validate(self,
val_loader: Iterable,
Expand Down
8 changes: 4 additions & 4 deletions tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def test_fit_train_val_loader(self,
drop_last=True, batch_size=32)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64)
val_loss_before = model.validate(val_loader)['val_loss']
model.fit(train_loader,
val_loader=val_loader,
num_epochs=32)
val_loss_after = model.validate(val_loader)['val_loss']
val_loss_after = model.fit(train_loader,
val_loader=val_loader,
num_epochs=32)['val_loss']
assert val_loss_after == model.validate(val_loader)['val_loss']
assert val_loss_after < val_loss_before
assert val_loss_after < 0.3

Expand Down

0 comments on commit 6e5c13d

Please sign in to comment.