From 6cd7b52684ee2bb00c4af2397d85025c1fb00d41 Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Fri, 4 Jun 2021 14:05:32 -0500 Subject: [PATCH] [#98] don't pass weights to `fit` call --- river_dl/train.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/river_dl/train.py b/river_dl/train.py index 3bfb551..7fbf151 100644 --- a/river_dl/train.py +++ b/river_dl/train.py @@ -99,9 +99,7 @@ def train_model( # use built in 'fit' method unless model is grad correction x_trn_pre = io_data["x_trn"] # combine with weights to pass to loss function - y_trn_pre = np.concatenate( - [io_data["y_pre_trn"], io_data["y_pre_wgts"]], axis=2 - ) + y_trn_pre = io_data["y_pre_trn"] model.compile(optimizer_pre, loss=loss_func) @@ -138,9 +136,7 @@ def train_model( ) x_trn_obs = io_data["x_trn"] - y_trn_obs = np.concatenate( - [io_data["y_obs_trn"], io_data["y_obs_wgts"]], axis=2 - ) + y_trn_obs = io_data["y_obs_trn"] model.fit( x=x_trn_obs,