Skip to content

Commit

Permalink
Make compatible with new keras Callback model
Browse files Browse the repository at this point in the history
see https://keras.io/guides/writing_your_own_callbacks/ for guide -
methods have been split to provide a different one for
train/eval/predict time.

This change translates `on_batch_end` to `on_train_batch_end` rather
than also for validation batches, so batch history only includes train
batch loss.
  • Loading branch information
ali.teeney committed Jun 2, 2020
1 parent 9a13745 commit f6e9966
Showing 1 changed file with 39 additions and 5 deletions.
44 changes: 39 additions & 5 deletions kerashistoryplot/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ class Callback(object):
optionally include `val_loss`
(if validation is enabled in `fit`), and `val_acc`
(if validation and accuracy monitoring are enabled).
on_batch_begin: logs include `size`,
on_(train/predict/test)_batch_begin: logs include `size`,
the number of samples in the current batch.
on_batch_end: logs include `loss`, and optionally `acc`
(if accuracy monitoring is enabled).
on_(train/predict/test)_batch_end: logs include `loss`, and optionally
`acc` (if accuracy monitoring is enabled).
For other methods, currently no data is passed in logs but that may change
in the future.
"""

def __init__(self):
Expand All @@ -42,9 +44,29 @@ def on_epoch_end(self, epoch, logs=None):
pass

def on_batch_begin(self, batch, logs=None):
pass
"""Backwards compatibility alias for `on_train_batch_begin`."""
return self.on_train_batch_begin(batch, logs)

def on_batch_end(self, batch, logs=None):
"""Backwards compatibility alias for `on_train_batch_end`."""
return self.on_train_batch_end(batch, logs)

def on_train_batch_begin(self, batch, logs=None):
pass

def on_train_batch_end(self, batch, logs=None):
pass

def on_test_batch_begin(self, batch, logs=None):
pass

def on_test_batch_end(self, batch, logs=None):
pass

def on_predict_batch_begin(self, batch, logs=None):
pass

def on_predict_batch_end(self, batch, logs=None):
pass

def on_train_begin(self, logs=None):
Expand All @@ -53,6 +75,18 @@ def on_train_begin(self, logs=None):
def on_train_end(self, logs=None):
pass

def on_test_begin(self, logs=None):
pass

def on_test_end(self, logs=None):
pass

def on_predict_begin(self, logs=None):
pass

def on_predict_end(self, logs=None):
pass


class BatchHistory(Callback):
"""Log all history, including per-batch losses and metrics.
Expand All @@ -67,7 +101,7 @@ def on_epoch_begin(self, epoch, logs=None):
batch_history = self.history.setdefault('batches', [])
batch_history.append({})

def on_batch_end(self, batch, logs=None):
def on_train_batch_end(self, batch, logs=None):
batch_history = self.history['batches'][-1]
for k, v in logs.items():
batch_history.setdefault(k, []).append(v)
Expand Down

0 comments on commit f6e9966

Please sign in to comment.