diff --git a/candle/ckpt_utils.py b/candle/ckpt_utils.py index b051411..822883f 100644 --- a/candle/ckpt_utils.py +++ b/candle/ckpt_utils.py @@ -209,6 +209,9 @@ def scan_params(self, gParams): self.save_weights_only = self.param( "ckpt_save_weights_only", True, ParamType.BOOLEAN ) + self.save_weights_file = self.param( + "ckpt_save_weights_file", True, ParamType.BOOLEAN + ) self.checksum_enabled = self.param("ckpt_checksum", False, ParamType.BOOLEAN) self.keep_mode = self.param( "ckpt_keep_mode", @@ -349,23 +352,25 @@ def write_model(self, dir_work, epoch): Do the I/O, report stats dir_work: A pathlib.Path """ - self.model_file = dir_work / "model.h5" - self.debug("writing model to: '%s'" % self.relpath(self.model_file)) - start = time.time() - - # Call down to backend-specific model writer: - self.write_model_backend(self.model, epoch) - - stop = time.time() - duration = stop - start - stats = os.stat(self.model_file) - MB = stats.st_size / (1024 * 1024) - rate = MB / duration - self.debug( - "model wrote: %0.3f MB in %0.3f seconds (%0.2f MB/s)." - % (MB, duration, rate) - ) - self.checksum(dir_work) + if self.save_weights_file: + self.model_file = dir_work / "model.h5" + self.debug("writing model to: '%s'" % + self.relpath(self.model_file)) + start = time.time() + + # Call down to backend-specific model writer: + self.write_model_backend(self.model, epoch) + + stop = time.time() + duration = stop - start + stats = os.stat(self.model_file) + MB = stats.st_size / (1024 * 1024) + rate = MB / duration + self.debug( + "model wrote: %0.3f MB in %0.3f seconds (%0.2f MB/s)." + % (MB, duration, rate) + ) + self.checksum(dir_work) self.write_json(dir_work / "ckpt-info.json", epoch) def checksum(self, dir_work): @@ -387,8 +392,9 @@ def write_json(self, jsonfile, epoch): D["epoch"] = epoch D["save_best_metric"] = self.save_best_metric D["best_metric_last"] = self.best_metric_last - D["model_file"] = "model.h5" - D["checksum"] = self.cksum_model + if self.save_weights_file: + D["model_file"] = "model.h5" + D["checksum"] = self.cksum_model D["timestamp"] = now.strftime("%Y-%m-%d %H:%M:%S") if self.timestamp_last is None: time_elapsed = "__FIRST__" diff --git a/candle/parsing_utils.py b/candle/parsing_utils.py index 069a7c3..5655f3f 100644 --- a/candle/parsing_utils.py +++ b/candle/parsing_utils.py @@ -187,6 +187,7 @@ class ConfigDict(TypedDict): ckpt_save_best: bool ckpt_save_best_metric: str ckpt_save_weights_only: bool + ckpt_save_weights_file: bool ckpt_save_interval: int ckpt_keep_mode: str ckpt_keep_limit: int @@ -612,6 +613,12 @@ class ConfigDict(TypedDict): "default": False, "help": "Toggle saving only weights (not optimizer) (NYI).", }, + { + "name": "ckpt_save_weights_file", + "type": str2bool, + "default": True, + "help": "Toggle whether to save weights. JSON will still be saved.", + }, { "name": "ckpt_save_interval", "type": int,