HolisticAttention.py
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.parameter import Parameter
import numpy as np
import scipy.stats as st


def gkern(kernlen=16, nsig=3):
    # Build a kernlen x kernlen Gaussian kernel covering roughly +/- nsig
    # standard deviations, normalized so its entries sum to 1.
    interval = (2 * nsig + 1.) / kernlen
    x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw / kernel_raw.sum()
    return kernel


def min_max_norm(in_):
    # Min-max normalize each map over its spatial dimensions (H, W),
    # with a small epsilon to avoid division by zero.
    max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    in_ = in_ - min_
    return in_.div(max_ - min_ + 1e-8)


class HA(nn.Module):
    # holistic attention module
    def __init__(self):
        super(HA, self).__init__()
        # 31x31 Gaussian kernel, stored as a learnable parameter and used as
        # a single-channel convolution filter.
        gaussian_kernel = np.float32(gkern(31, 4))
        gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...]
        self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel))

    def forward(self, attention, x):
        # Blur the attention map (padding=15 keeps the 31x31 convolution
        # size-preserving), rescale it to [0, 1], then weight the features by
        # the element-wise maximum of the blurred and original attention.
        soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15)
        soft_attention = min_max_norm(soft_attention)
        x = torch.mul(x, soft_attention.max(attention))
        return x
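

# A minimal usage sketch, not part of the original module. The shapes below
# are assumptions: a single-channel attention/saliency map (required, since
# the Gaussian kernel has one input channel) and a feature tensor with the
# same spatial size, which is broadcast-multiplied by the attention.
if __name__ == '__main__':
    ha = HA()
    attention = torch.rand(2, 1, 44, 44)  # coarse saliency map in [0, 1]
    features = torch.rand(2, 32, 44, 44)  # features to be re-weighted
    out = ha(attention, features)
    print(out.shape)  # torch.Size([2, 32, 44, 44])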