Skip to content

Commit

Permalink
Dev icenet-ai#252: trying to load_model rather than weights
Browse files Browse the repository at this point in the history
  • Loading branch information
JimCircadian committed May 28, 2024
1 parent 0ebcdcb commit 6bd2cfa
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions icenet/model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import pandas as pd
import tensorflow as tf

from tensorflow.keras.models import load_model

from icenet.data.loader import save_sample
from icenet.data.dataset import IceNetDataSet
from icenet.model.cli import predict_args
from icenet.model.networks.tensorflow import unet_batchnorm

"""
Expand All @@ -21,9 +21,6 @@ def predict_forecast(
dataset_config: object,
network_name: object,
dataset_name: object = None,
legacy_rounding: bool = False,
model_func: callable = unet_batchnorm,
n_filters_factor: float = 1 / 8,
network_folder: object = None,
output_folder: object = None,
save_args: bool = False,
Expand Down Expand Up @@ -55,18 +52,14 @@ def predict_forecast(
network_folder = os.path.join(".", "results", "networks", network_name)

dataset_name = dataset_name if dataset_name else ds.identifier
network_path = os.path.join(
network_folder, "{}.network_{}.{}.h5".format(network_name,
dataset_name,
seed))
model_path = os.path.join(
network_folder, "{}.network_{}.{}".format(network_name,
dataset_name,
seed))

logging.info("Loading model from {}...".format(network_path))
logging.info("Loading model from {}...".format(model_path))

network = model_func((*ds.shape, dl.num_channels), [], [],
legacy_rounding=legacy_rounding,
n_filters_factor=n_filters_factor,
n_forecast_days=ds.n_forecast_days)
network.load_weights(network_path)
network = load_model(model_path, compile=False)

if not test_set:
logging.info("Generating forecast inputs from processed/ files")
Expand Down Expand Up @@ -168,8 +161,6 @@ def main():
# do we need to retain the train SD name in the
# network?
dataset_name=args.ident if args.ident else args.dataset,
legacy_rounding=args.legacy_rounding,
n_filters_factor=args.n_filters_factor,
output_folder=output_folder,
save_args=args.save_args,
seed=args.seed,
Expand Down

0 comments on commit 6bd2cfa

Please sign in to comment.