-
Notifications
You must be signed in to change notification settings - Fork 1
/
diff_operators.py
executable file
·71 lines (56 loc) · 1.98 KB
/
diff_operators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from torch.autograd import grad
def hessian(y, x):
    """Hessian of each output channel of ``y`` with respect to ``x``.

    Args:
        y: tensor of shape (meta_batch_size, num_observations, channels),
            computed from ``x`` with autograd enabled.
        x: tensor of shape (meta_batch_size, num_observations, dim) with
            ``requires_grad=True`` (dim is typically 2 for 2-D coordinates).

    Returns:
        ``(h, status)`` where ``h`` has shape
        (meta_batch_size, num_observations, channels, dim, dim) with
        ``h[..., i, j, k] = d^2 y[..., i] / (dx_j dx_k)`` and the same dtype
        and device as ``y``; ``status`` is 0, or -1 if ``h`` contains NaN.

    Note:
        Derivatives are vector-Jacobian products seeded with ones, so this
        gives the per-point Hessian only if each ``y[b, n]`` depends solely on
        ``x[b, n]``; otherwise cross-point terms are summed in.
    """
    meta_batch_size, num_observations = y.shape[:2]
    num_channels, dim = y.shape[-1], x.shape[-1]
    # Ones seed for the vector-Jacobian products; already on y's device.
    grad_y = torch.ones_like(y[..., 0])
    # Allocate with y's dtype/device: the torch.zeros default (float32) would
    # silently truncate float64 results when the slices are written below.
    h = torch.zeros(
        meta_batch_size,
        num_observations,
        num_channels,
        dim,
        dim,
        dtype=y.dtype,
        device=y.device,
    )
    for i in range(num_channels):
        # First derivatives of channel i; keep the graph for the second pass.
        dydx = grad(y[..., i], x, grad_y, create_graph=True)[0]
        for j in range(dim):
            # Row j of the Hessian: gradient of d y_i / d x_j w.r.t. x.
            h[..., i, j, :] = grad(dydx[..., j], x, grad_y, create_graph=True)[0]
    status = -1 if torch.isnan(h).any() else 0
    return h, status
def laplace(y, x):
    """Laplacian of ``y`` with respect to ``x``.

    Computed as the divergence of the gradient:
    ``divergence(gradient(y, x), x)``.
    """
    return divergence(gradient(y, x), x)
def divergence(y, x):
    """Divergence of the vector field ``y`` with respect to ``x``.

    Sums ``d y[..., k] / d x[..., k]`` over the last axis of ``y``. Each term
    comes from a vector-Jacobian product seeded with ones and keeps a trailing
    axis of size 1 (slice ``k:k+1``), so a tensor result has the shape of ``x``
    with its last axis reduced to 1. The graph is kept (``create_graph=True``)
    so the result can be differentiated again. If ``y`` has an empty last
    axis, the Python float ``0.0`` is returned.
    """
    partials = (
        grad(y[..., k], x, torch.ones_like(y[..., k]), create_graph=True)[0][
            ..., k : k + 1
        ]
        for k in range(y.shape[-1])
    )
    return sum(partials, 0.0)
def gradient(y, x, grad_outputs=None):
    """Gradient of ``y`` with respect to ``x``, keeping the autograd graph.

    A leading axis of size 1 is removed from ``y`` (``squeeze(0)``) before
    differentiating. If ``grad_outputs`` is omitted, a ones tensor shaped like
    the squeezed ``y`` is used, which yields the gradient of ``y.sum()``. A
    caller-supplied ``grad_outputs`` must match the squeezed shape.

    Returns the first entry of ``torch.autograd.grad``'s result. This is
    ``None`` when ``x`` is not part of ``y``'s graph (``allow_unused=True``).
    """
    target = y.squeeze(0)
    seed = torch.ones_like(target) if grad_outputs is None else grad_outputs
    return torch.autograd.grad(
        target, x, grad_outputs=seed, create_graph=True, allow_unused=True
    )[0]
def jacobian(y, x):
    """Jacobian of ``y`` with respect to ``x``.

    Args:
        y: tensor of shape (meta_batch_size, num_observations, channels),
            computed from ``x`` with autograd enabled.
        x: tensor of shape (meta_batch_size, num_observations, dim) with
            ``requires_grad=True``.

    Returns:
        ``(jac, status)`` where ``jac`` has shape
        (meta_batch_size, num_observations, channels, dim) with
        ``jac[..., i, k] = d y[..., i] / d x[..., k]`` and the same dtype and
        device as ``y``; ``status`` is 0, or -1 if ``jac`` contains NaN.

    Note:
        Each row is a vector-Jacobian product seeded with ones, so this is the
        per-point Jacobian only if each ``y[b, n]`` depends solely on
        ``x[b, n]``; otherwise cross-point terms are summed in.
    """
    meta_batch_size, num_observations = y.shape[:2]
    # (meta_batch_size, num_observations, channels, dim). Use y's dtype/device
    # so float64 results are not truncated to the float32 default.
    jac = torch.zeros(
        meta_batch_size,
        num_observations,
        y.shape[-1],
        x.shape[-1],
        dtype=y.dtype,
        device=y.device,
    )
    for i in range(y.shape[-1]):
        # Differentiate the channel slice directly. Flattening it with .view()
        # yields the same product but fails on non-view-compatible layouts.
        y_i = y[..., i]
        jac[:, :, i, :] = grad(y_i, x, torch.ones_like(y_i), create_graph=True)[0]
    status = -1 if torch.isnan(jac).any() else 0
    return jac, status