Skip to content

Commit

Permalink
Fix plotting for msssim sampled method (#173)
Browse files Browse the repository at this point in the history
* Fix plotting for msssim

* Towncrier
  • Loading branch information
FeGeyer authored Aug 16, 2024
1 parent 799a1aa commit d40ea95
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/changes/173.bubugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- fix plotting for `evaluate_msssim_sampled`
5 changes: 2 additions & 3 deletions radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def evaluate_ms_ssim_sampled(conf):
name_model = Path(model_path).stem
data_path = str(out_path) + f"/sampled_imgs_{name_model}.h5"
loader = create_sampled_databunch(data_path, conf["batch_size"])
vals = []
vals = np.array([])

# iterate trough DataLoader
for i, (samp, std, img_true) in enumerate(tqdm(loader)):
Expand All @@ -578,10 +578,9 @@ def evaluate_ms_ssim_sampled(conf):
win_size=7,
size_average=False,
)
vals.extend(val)
vals = np.append(vals, val)

click.echo("\nCreating ms-ssim histogram.\n")
vals = torch.tensor(vals)
histogram_ms_ssim(vals, out_path, plot_format=conf["format"])

click.echo(f"\nThe mean ms-ssim value is {vals.mean()}.\n")
Expand Down

0 comments on commit d40ea95

Please sign in to comment.