From 6bd2cfac06045d43c46599813898e64360854da0 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Tue, 28 May 2024 16:18:46 +0100 Subject: [PATCH] Dev #252: trying to load_model rather than weights --- icenet/model/predict.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/icenet/model/predict.py b/icenet/model/predict.py index 6aeec42..309aa53 100644 --- a/icenet/model/predict.py +++ b/icenet/model/predict.py @@ -6,11 +6,11 @@ import pandas as pd import tensorflow as tf +from tensorflow.keras.models import load_model from icenet.data.loader import save_sample from icenet.data.dataset import IceNetDataSet from icenet.model.cli import predict_args -from icenet.model.networks.tensorflow import unet_batchnorm """ @@ -21,9 +21,6 @@ def predict_forecast( dataset_config: object, network_name: object, dataset_name: object = None, - legacy_rounding: bool = False, - model_func: callable = unet_batchnorm, - n_filters_factor: float = 1 / 8, network_folder: object = None, output_folder: object = None, save_args: bool = False, @@ -55,18 +52,14 @@ def predict_forecast( network_folder = os.path.join(".", "results", "networks", network_name) dataset_name = dataset_name if dataset_name else ds.identifier - network_path = os.path.join( - network_folder, "{}.network_{}.{}.h5".format(network_name, - dataset_name, - seed)) + model_path = os.path.join( + network_folder, "{}.network_{}.{}".format(network_name, + dataset_name, + seed)) - logging.info("Loading model from {}...".format(network_path)) + logging.info("Loading model from {}...".format(model_path)) - network = model_func((*ds.shape, dl.num_channels), [], [], - legacy_rounding=legacy_rounding, - n_filters_factor=n_filters_factor, - n_forecast_days=ds.n_forecast_days) - network.load_weights(network_path) + network = load_model(model_path, compile=False) if not test_set: logging.info("Generating forecast inputs from processed/ files") @@ -168,8 +161,6 @@ def main(): # do we need to retain the train SD name in the # network? dataset_name=args.ident if args.ident else args.dataset, - legacy_rounding=args.legacy_rounding, - n_filters_factor=args.n_filters_factor, output_folder=output_folder, save_args=args.save_args, seed=args.seed,