Skip to content

Commit

Permalink
Merge pull request #71 from Kevin2/update_filter_deep_notebook
Browse files Browse the repository at this point in the history
Small additions to new main changes
  • Loading branch information
FeGeyer authored Nov 11, 2020
2 parents 74db327 + de69c71 commit e5874ae
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 54 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/.rc/filter_deep_eval/source_normal.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 5 additions & 4 deletions examples/.rc/training_filter_deep_amp.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ amp_phase = true
source_list = false
arch_name = "filter_deep_amp"
loss_func = "loss_amp"
num_epochs = 400
num_epochs = 200
inspection = false
output_format = "png"

[hypers]
batch_size = 100
lr = 2e-3

[param_scheduling]
use = true
lr_start = 1e-1
lr_max = 5e-1
lr_stop = 1e0
lr_start = 1e-3
lr_max = 5e-3
lr_stop = 1e-2
1 change: 1 addition & 0 deletions examples/.rc/training_filter_deep_phase.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ arch_name = "filter_deep_phase"
loss_func = "loss_phase"
num_epochs = 400
inspection = false
output_format = "png"

[hypers]
batch_size = 100
Expand Down
118 changes: 113 additions & 5 deletions examples/04_filter_deep.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion radionets/dl_framework/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def reshape_2d(array):
return array.reshape(-1, *shape)


def plot_loss(learn, model_path, output_format="png"):
def plot_loss(learn, model_path, output_format="pdf"):
"""
Plot train and valid loss of model.
Expand Down
2 changes: 1 addition & 1 deletion radionets/dl_framework/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def define_learner(
test=False,
):
model_path = train_conf["model_path"]
model_name = model_path.split("build/")[-1].split("/")[0]
model_name = model_path.split("build/")[-1].split("/")[-1].split("/")[0].split(".")[0]
lr = train_conf["lr"]
opt_func = Adam
if train_conf["norm_path"] != "none":
Expand Down
145 changes: 103 additions & 42 deletions radionets/evaluation/train_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
histogram_ms_ssim,
)
from radionets.evaluation.utils import (
create_databunch,
reshape_2d,
load_pretrained_model,
get_images,
Expand All @@ -22,6 +23,7 @@
from radionets.evaluation.jet_angle import calc_jet_angle
from radionets.evaluation.dynamic_range import calc_dr
from pytorch_msssim import ms_ssim
from tqdm import tqdm


def get_prediction(conf, num_images=None, rand=False):
Expand Down Expand Up @@ -155,90 +157,149 @@ def create_source_plots(conf, num_images=3, rand=False):


def evaluate_viewing_angle(conf):
if conf["separate"]:
pred, img_test, img_true = get_separate_prediction(conf)
else:
pred, img_test, img_true = get_prediction(conf)
# create DataLoader
loader = create_databunch(
conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
)
model_path = conf["model_path"]
out_path = Path(model_path).parent / "evaluation"
out_path.mkdir(parents=True, exist_ok=True)

ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
img_size = loader.dataset[0][0][0].shape[-1]
model = load_pretrained_model(conf["arch_name"], conf["model_path"], img_size)
if conf["separate"]:
model_2 = load_pretrained_model(
conf["arch_name_2"], conf["model_path_2"], img_size
)

alpha_truths = []
alpha_preds = []

# iterate trough DataLoader
for i, (img_test, img_true) in enumerate(tqdm(loader)):

m_truth, n_truth, alpha_truth = calc_jet_angle(torch.tensor(ifft_truth))
m_pred, n_pred, alpha_pred = calc_jet_angle(torch.tensor(ifft_pred))
pred = eval_model(img_test, model)
if conf["separate"]:
pred_2 = eval_model(img_test, model_2)
pred = torch.cat((pred, pred_2), dim=1)

ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])

m_truth, n_truth, alpha_truth = calc_jet_angle(torch.tensor(ifft_truth))
m_pred, n_pred, alpha_pred = calc_jet_angle(torch.tensor(ifft_pred))

alpha_truths.extend(alpha_truth)
alpha_preds.extend(alpha_pred)

alpha_truths = torch.tensor(alpha_truths)
alpha_preds = torch.tensor(alpha_preds)

click.echo("\nCreating histogram of jet angles.\n")
histogram_jet_angles(
alpha_truth,
alpha_pred,
out_path,
plot_format=conf["format"],
alpha_truths, alpha_preds, out_path, plot_format=conf["format"],
)


def evaluate_dynamic_range(conf):
if conf["separate"]:
pred, img_test, img_true = get_separate_prediction(conf)
else:
pred, img_test, img_true = get_prediction(conf)
# create Dataloader
loader = create_databunch(
conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
)
model_path = conf["model_path"]
out_path = Path(model_path).parent / "evaluation"
out_path.mkdir(parents=True, exist_ok=True)

ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
img_size = loader.dataset[0][0][0].shape[-1]
model = load_pretrained_model(conf["arch_name"], conf["model_path"], img_size)
if conf["separate"]:
model_2 = load_pretrained_model(
conf["arch_name_2"], conf["model_path_2"], img_size
)

dr_truths = np.array([])
dr_preds = np.array([])

# iterate trough DataLoader
for i, (img_test, img_true) in enumerate(tqdm(loader)):

pred = eval_model(img_test, model)
if conf["separate"]:
pred_2 = eval_model(img_test, model_2)
pred = torch.cat((pred, pred_2), dim=1)

dr_truth, dr_pred, _, _ = calc_dr(ifft_truth, ifft_pred)
ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])

dr_truth, dr_pred, _, _ = calc_dr(ifft_truth, ifft_pred)
dr_truths = np.append(dr_truths, dr_truth)
dr_preds = np.append(dr_preds, dr_pred)

click.echo(
f"\nMean dynamic range for true source distributions:\
{round(dr_truth.mean())}\n"
{round(dr_truths.mean())}\n"
)
click.echo(
f"\nMean dynamic range for predicted source distributions:\
{round(dr_pred.mean())}\n"
{round(dr_preds.mean())}\n"
)

click.echo("\nCreating histogram of dynamic ranges.\n")
histogram_dynamic_ranges(
dr_truth,
dr_pred,
out_path,
plot_format=conf["format"],
dr_truths, dr_preds, out_path, plot_format=conf["format"],
)


def evaluate_ms_ssim(conf):
pred, _, img_true = get_prediction(conf)
# create DataLoader
loader = create_databunch(
conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
)
model_path = conf["model_path"]
out_path = Path(model_path).parent / "evaluation"
out_path.mkdir(parents=True, exist_ok=True)

ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
img_size = loader.dataset[0][0][0].shape[-1]
model = load_pretrained_model(conf["arch_name"], conf["model_path"], img_size)
if conf["separate"]:
model_2 = load_pretrained_model(
conf["arch_name_2"], conf["model_path_2"], img_size
)

vals = []

if ifft_truth.shape[-1] < 160:
if img_size < 160:
click.echo(
"\nThis is only a placeholder!\
Images too small for meaningful ms ssim calculations.\n"
Images too small for meaningful ms ssim calculations.\n"
)

ifft_truth = pad_unsqueeze(torch.tensor(ifft_truth))
ifft_pred = pad_unsqueeze(torch.tensor(ifft_pred))
# iterate trough DataLoader
for i, (img_test, img_true) in enumerate(tqdm(loader)):

vals = torch.tensor(
[
ms_ssim(pred.unsqueeze(0), truth.unsqueeze(0), data_range=truth.max())
for pred, truth in zip(ifft_pred, ifft_truth)
]
)
pred = eval_model(img_test, model)
if conf["separate"]:
pred_2 = eval_model(img_test, model_2)
pred = torch.cat((pred, pred_2), dim=1)

click.echo("\nCreating ms-ssim histogram.\n")
ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])

if img_size < 160:
ifft_truth = pad_unsqueeze(torch.tensor(ifft_truth))
ifft_pred = pad_unsqueeze(torch.tensor(ifft_pred))

vals.extend(
[
ms_ssim(pred.unsqueeze(0), truth.unsqueeze(0), data_range=truth.max())
for pred, truth in zip(ifft_pred, ifft_truth)
]
)

click.echo("\nCreating ms-ssim histogram.\n")
vals = torch.tensor(vals)
histogram_ms_ssim(
vals,
out_path,
plot_format=conf["format"],
vals, out_path, plot_format=conf["format"],
)

click.echo(f"\nThe mean ms-ssim value is {vals.mean()}.\n")
15 changes: 14 additions & 1 deletion radionets/evaluation/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import numpy as np
import pandas as pd
from radionets.dl_framework.model import load_pre_model
from radionets.dl_framework.data import do_normalisation
from radionets.dl_framework.data import do_normalisation, load_data
import radionets.dl_framework.architecture as architecture
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def create_databunch(data_path, fourier, source_list, batch_size):
# Load data sets
test_ds = load_data(
data_path, mode="test", fourier=fourier, source_list=source_list,
)

# Create databunch with defined batchsize
data = DataLoader(test_ds, batch_size=batch_size, shuffle=True)
return data


def read_config(config):
Expand Down Expand Up @@ -36,6 +48,7 @@ def read_config(config):
eval_conf["viewing_angle"] = config["eval"]["evaluate_viewing_angle"]
eval_conf["dynamic_range"] = config["eval"]["evaluate_dynamic_range"]
eval_conf["ms_ssim"] = config["eval"]["evaluate_ms_ssim"]
eval_conf["batch_size"] = config["eval"]["batch_size"]
return eval_conf


Expand Down

0 comments on commit e5874ae

Please sign in to comment.