Skip to content

Commit

Permalink
Error handling for ACLSD loss
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 7, 2023
1 parent 0b49aa7 commit 474cfe4
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/raygun/torch/losses/WeightedMSELoss_ACLSD.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,16 @@ def forward(
"pred_affs_ac": pred_affs_ac.detach(),
}
)

lsd_loss = self._calc_loss(pred_lsds, gt_lsds, lsds_weights)
aff_loss = self._calc_loss(pred_affs, gt_affs, affs_weights)
ac_aff_loss = self._calc_loss(pred_affs_ac, gt_affs)
try:
lsd_loss = self._calc_loss(pred_lsds, gt_lsds, lsds_weights)
aff_loss = self._calc_loss(pred_affs, gt_affs, affs_weights)
except:
lsd_loss = aff_loss = 0.

try:
ac_aff_loss = self._calc_loss(pred_affs_ac, gt_affs)
except:
ac_aff_loss = 0.

self.loss_dict = {"LSDs": lsd_loss.detach(), "Affinities1": aff_loss.detach(), "Affinities2": ac_aff_loss}

Expand Down

0 comments on commit 474cfe4

Please sign in to comment.