Skip to content

Commit

Permalink
Fix plotting for msssim
Browse files Browse the repository at this point in the history
  • Loading branch information
FeGeyer committed Aug 16, 2024
1 parent 799a1aa commit 9de966a
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def evaluate_ms_ssim_sampled(conf):
name_model = Path(model_path).stem
data_path = str(out_path) + f"/sampled_imgs_{name_model}.h5"
loader = create_sampled_databunch(data_path, conf["batch_size"])
vals = []
vals = np.array([])

# iterate trough DataLoader
for i, (samp, std, img_true) in enumerate(tqdm(loader)):
Expand All @@ -578,10 +578,9 @@ def evaluate_ms_ssim_sampled(conf):
win_size=7,
size_average=False,
)
vals.extend(val)
vals = np.append(vals, val)

click.echo("\nCreating ms-ssim histogram.\n")
vals = torch.tensor(vals)
histogram_ms_ssim(vals, out_path, plot_format=conf["format"])

click.echo(f"\nThe mean ms-ssim value is {vals.mean()}.\n")
Expand Down

0 comments on commit 9de966a

Please sign in to comment.