forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
huasdorff_distance_loss.py
160 lines (140 loc) · 5.55 KB
/
huasdorff_distance_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
master/code/train_LA_HD.py (Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import Tensor
from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss
def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
    """Compute the per-class foreground distance transform map (DTM).

    For every image ``b`` and every non-background class ``c`` (channel 0 is
    treated as background and left at zero), the pixels of ``img_gt[b]``
    labelled ``c`` form a binary foreground mask. Each foreground pixel is
    assigned its Euclidean distance to the nearest pixel outside that mask
    (``scipy.ndimage.distance_transform_edt``). Pixels outside the mask get 0.
    A class that does not occur in an image keeps an all-zero map.

    Args:
        img_gt (Tensor): Integer label map of shape (b, h, w).
        pred (Tensor): Tensor of shape (b, c, h, w), e.g. the softmax
            predictions. Only its shape, dtype and device are used; they
            define the output tensor.

    Returns:
        Tensor: Foreground distance map with the same shape, dtype and device
        as ``pred``::

            dtm[b, c](x) = 0                        if img_gt[b](x) != c
                           min_{y: img_gt[b](y) != c} |x - y|   otherwise
    """
    fg_dtm = torch.zeros_like(pred)
    batch_size, num_classes = pred.shape[:2]
    for b in range(batch_size):
        labels = img_gt[b]
        for c in range(1, num_classes):  # channel 0 is background
            # Binary mask of class ``c`` only, so that each channel gets the
            # distance map of its own class.
            posmask = labels == c
            if posmask.any():
                posdis = distance(posmask.cpu().numpy())
                # ``.to(fg_dtm)`` matches both dtype (scipy returns float64)
                # and device (``pred`` may live on the GPU).
                fg_dtm[b, c] = torch.from_numpy(posdis).to(fg_dtm)
    return fg_dtm
@weighted_loss
def hd_loss(seg_soft: Tensor,
            gt: Tensor,
            seg_dtm: Tensor,
            gt_dtm: Tensor,
            class_weight=None,
            ignore_index=255) -> Tensor:
    """Compute the Hausdorff-distance (DTM-weighted) loss for segmentation.

    For every channel ``i >= 1`` (channel 0 is background; a channel whose
    index equals ``ignore_index`` is also skipped) the per-class term is::

        mean((seg_soft[:, i] - [gt == i]) ** 2
             * (seg_dtm[:, i] ** 2 + gt_dtm[:, i] ** 2))

    It is optionally scaled by ``class_weight[i]``. The per-class terms are
    summed and divided by the total number of channels ``c``.

    Args:
        seg_soft (Tensor): Softmax results, shape (b, c, h, w).
        gt (Tensor): Integer ground-truth labels, shape (b, h, w).
        seg_dtm (Tensor): Distance transform map of the predicted
            segmentation, shape (b, c, h, w).
        gt_dtm (Tensor): Distance transform map of the ground truth,
            shape (b, c, h, w).
        class_weight (Tensor | Sequence[float], optional): One weight per
            channel. Defaults to None.
        ignore_index (int | None): Channel index to skip. Defaults to 255.

    Returns:
        Tensor: Scalar loss.
    """
    assert seg_soft.shape[0] == gt.shape[0]
    num_class = seg_soft.shape[1]
    if class_weight is not None:
        # Exactly one weight per channel: the length, not the rank, of the
        # weight vector must match the number of classes.
        assert len(class_weight) == num_class
    # Start from a tensor so a Tensor is returned even when no channel
    # contributes (e.g. a single-channel prediction).
    total_loss = seg_soft.new_zeros(())
    for i in range(1, num_class):
        if i == ignore_index:
            continue
        # One-hot target of class ``i``; for binary labels this equals ``gt``.
        target_i = (gt == i).float()
        delta_s = (seg_soft[:, i, ...] - target_i)**2
        dtm = seg_dtm[:, i, ...]**2 + gt_dtm[:, i, ...]**2
        class_loss = (delta_s * dtm).mean()
        if class_weight is not None:
            class_loss = class_loss * class_weight[i]
        total_loss = total_loss + class_loss
    return total_loss / num_class
@MODELS.register_module()
class HuasdorffDisstanceLoss(nn.Module):
    """HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform
    Maps Boost Segmentation CNNs: An Empirical Study.
    <http://proceedings.mlr.press/v121/ma20b.html>`_.

    The logits are turned into softmax probabilities. Distance transform maps
    of the ground truth and of the arg-max prediction are computed with
    ``compute_dtm`` (under ``torch.no_grad``). The probabilities and both maps
    are then passed to ``hd_loss``.

    Args:
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        ignore_index (int | None): The label index to be ignored. Pixels with
            this label are relabelled as background (0) before the loss is
            computed. ``None`` disables the relabelling. Default: 255.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_huasdorff_disstance'.
    """

    def __init__(self,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_huasdorff_disstance',
                 **kwargs):
        # ``**kwargs`` is accepted (and ignored) so that extra config keys do
        # not break construction.
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name
        self.ignore_index = ignore_index

    def forward(self,
                pred: Tensor,
                target: Tensor,
                avg_factor=None,
                reduction_override=None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predictions (logits) of the segmentation head.
                (B, C, H, W)
            target (Tensor): Ground truth of the image. (B, H, W)
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None

        pred_soft = F.softmax(pred, dim=1)
        if self.ignore_index is not None:
            # Relabel ignored pixels as background (0) so they are valid class
            # indices for the distance maps.
            # NOTE(review): these pixels are not excluded from the loss; they
            # are scored as background.
            valid_mask = (target != self.ignore_index).long()
            target = target * valid_mask

        # The distance maps are computed with scipy on the CPU and act as
        # constant per-pixel weights, so no gradient flows through them.
        with torch.no_grad():
            gt_dtm = compute_dtm(target.cpu(), pred_soft).float()
            seg_dtm = compute_dtm(
                pred_soft.argmax(dim=1).cpu(), pred_soft).float()

        loss_hd = self.loss_weight * hd_loss(
            pred_soft,
            target,
            seg_dtm=seg_dtm,
            gt_dtm=gt_dtm,
            reduction=reduction,
            avg_factor=avg_factor,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss_hd

    @property
    def loss_name(self):
        """str: Name under which this loss item is reported."""
        return self._loss_name