-
Notifications
You must be signed in to change notification settings - Fork 0
/
pixel_plot.py
177 lines (158 loc) · 5.12 KB
/
pixel_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import random
from functools import partial
import numpy as np
import PIL
import torch
from beartype import beartype
from beartype.typing import (
Callable,
Optional,
)
from jaxtyping import Float
from jsonargparse import CLI
from matplotlib import pyplot as plt
from torchvision import transforms
from imclassplots.directions import (
get_gradient_based_direction,
get_orthogonal_1d_direction,
get_random_1d_direction,
)
from imclassplots.peturb import (
peturb,
peturb_and_predict,
)
from imclassplots.plot import plot_predictions
def seed_everything(seed: int) -> None:
    """Seed every random number generator used here, for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators with the same ``seed``, and configures cuDNN to run
    deterministically without benchmark-based algorithm selection.

    Args:
        seed: value used to seed every generator.
    """
    random.seed(seed)  # must honour ``seed`` like the other generators
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
@beartype
def main(
    image_fpath: str,
    true_label: int,
    grid_size: int,
    scale_factor: float,
    model: Optional[torch.nn.Module],
    model_fn: Optional[Callable[[], torch.nn.Module]],
    model_fn_kwargs: Optional[dict],
    direction: str,
    batch_size: int,
    display_ims: bool,
    dataset_labels: list[str],
    dataset_transform: Callable[
        [PIL.Image.Image],
        Float[torch.Tensor, " dim1 dim2 dim3"],
    ],
    dataset_normalize: Callable[
        [Float[torch.Tensor, " dim1 dim2 dim3"]],
        Float[torch.Tensor, " dim1 dim2 dim3"],
    ],
    dataset_imsize_x: int,
    dataset_imsize_y: int,
    device: str = "cpu",
    random_seed: int = 42,
) -> None:
    """Evaluate image over grid of perturbations in two directions and plot.

    An ``x_direction`` is chosen according to ``direction`` and a
    ``y_direction`` orthogonal to it is derived. The model's predictions
    over the resulting grid of perturbed images are computed. The tuple
    ``(predictions, x_direction, y_direction, orig_img)`` is saved with
    ``torch.save`` under ``./plots``; ``orig_img`` is the resized image
    converted with ``ToTensor``. A figure is also saved there and shown.

    Args:
        image_fpath: path to image. It is resized (LANCZOS) to
            ``(dataset_imsize_x, dataset_imsize_y)``.
        true_label: expected class index of the image. It is forwarded as
            ``label`` to the direction/perturbation helpers and as
            ``true_image_label`` to ``plot_predictions``. It is also used
            in the output filenames.
        grid_size: size of grid (square)
        scale_factor: scale factor for perturbations
        model: model to evaluate
        model_fn: function to create model if model is not provided
        model_fn_kwargs: kwargs for model_fn (``None`` means no kwargs)
        direction: method to pick x_direction: "random" or "gradient"
        batch_size: batch size
        display_ims: visualise images alongside plot
        dataset_labels: labels used in classifier training
        dataset_transform: transform used in classifier training
        dataset_normalize: normalize used in classifier training
            (typically as part of transform)
        dataset_imsize_x: target image width in pixels
        dataset_imsize_y: target image height in pixels
        device: device to run model on
        random_seed: seed for random number generators

    Raises:
        ValueError: if neither ``model`` nor ``model_fn`` is provided, or if
            ``direction`` is not ``"random"`` or ``"gradient"``.
    """
    seed_everything(random_seed)

    # Open in a context manager so the file handle is released promptly;
    # ``resize`` returns a new, fully loaded image independent of the file.
    with PIL.Image.open(image_fpath) as raw_img:
        img = raw_img.resize(
            (dataset_imsize_x, dataset_imsize_y), PIL.Image.Resampling.LANCZOS
        )

    if model is None:  # create model if not provided
        if model_fn is None:
            raise ValueError("model or model_fn must be provided")
        if model_fn_kwargs is None:
            model_fn_kwargs = {}
        model = model_fn(**model_fn_kwargs)
    model = model.eval()
    model.to(device)

    # directions
    if direction == "random":
        im_size = img.height * img.width * len(img.getbands())
        x_direction = get_random_1d_direction(size=im_size)
    elif direction == "gradient":
        x_direction = get_gradient_based_direction(
            model=model,
            imtensor=transforms.ToTensor()(img),
            normalize=dataset_normalize,
            label=true_label,
            device=device,
        )
    else:
        raise ValueError(
            f"direction must be random or gradient, not {direction}"
        )
    y_direction = get_orthogonal_1d_direction(u=x_direction)

    predictions = peturb_and_predict(
        image=img,
        model=model,
        label=true_label,
        grid_size=grid_size,
        data_transform=dataset_transform,
        x_direction=x_direction,
        y_direction=y_direction,
        device=device,
        batch_size=batch_size,
        scale_factor=scale_factor,
    )

    plot_directory = "./plots"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(plot_directory, exist_ok=True)
    # Shared filename suffix for the data file and the figure.
    run_tag = f"{true_label}_gridsize{grid_size}_sf{scale_factor}"

    data_fname = os.path.join(
        plot_directory, f"predictionsAndDirs_label{run_tag}.pt"
    )
    # ToTensor is applied afresh (not reused from the gradient branch) in
    # case the direction helper modified the tensor it was given.
    torch.save(
        (predictions, x_direction, y_direction, transforms.ToTensor()(img)),
        data_fname,
    )
    print(
        "saved (predictions,x_direction,y_direction,orig_img) tuple at "
        f"{data_fname}"
    )

    figure_fname = os.path.join(plot_directory, f"fig_{run_tag}.png")
    im_gen_fn = (
        partial(
            peturb, img=img, direction_a=x_direction, direction_b=y_direction
        )
        if display_ims
        else None
    )
    fig = plot_predictions(
        predictions=predictions,
        class_labels=dataset_labels,
        true_image_label=true_label,
        display_ims=display_ims,
        im_generation_fn=im_gen_fn,
        scale_factor=scale_factor,
    )
    fig.savefig(figure_fname)
    print(f"saved figure at {figure_fname}")
    plt.show()
if __name__ == "__main__":
    # Build a command-line interface from ``main``'s signature. Required
    # parameters become ``--name`` options (not positionals), and
    # parameters without type hints are tolerated.
    CLI(main, fail_untyped=False, as_positional=False)