Skip to content

Commit

Permalink
Add a safe wrapper for cross entropy.
Browse files Browse the repository at this point in the history
This fixes an error we were getting when softmax_cross_entropy_with_logits received empty tensors.
  • Loading branch information
IanTayler committed Dec 13, 2017
1 parent 849b0b3 commit a664f4d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
5 changes: 4 additions & 1 deletion luminoth/models/fasterrcnn/rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from luminoth.models.fasterrcnn.rcnn_target import RCNNTarget
from luminoth.models.fasterrcnn.roi_pool import ROIPoolingLayer
from luminoth.utils.losses import smooth_l1_loss
from luminoth.utils.safe_wrappers import (
safe_softmax_cross_entropy_with_logits
)
from luminoth.utils.vars import (
get_initializer, layer_summaries, variable_summaries,
get_activation_function
Expand Down Expand Up @@ -304,7 +307,7 @@ def loss(self, prediction_dict):

# We get cross entropy loss of each proposal.
cross_entropy_per_proposal = (
tf.nn.softmax_cross_entropy_with_logits(
safe_softmax_cross_entropy_with_logits(
labels=cls_target_one_hot, logits=cls_score_labeled
)
)
Expand Down
5 changes: 4 additions & 1 deletion luminoth/models/fasterrcnn/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from .rpn_target import RPNTarget
from .rpn_proposal import RPNProposal
from luminoth.utils.losses import smooth_l1_loss
from luminoth.utils.safe_wrappers import (
safe_softmax_cross_entropy_with_logits
)
from luminoth.utils.vars import (
get_initializer, layer_summaries, variable_summaries,
get_activation_function
Expand Down Expand Up @@ -257,7 +260,7 @@ def loss(self, prediction_dict):
cls_target = tf.one_hot(labels, depth=2)

# Equivalent to log loss
ce_per_anchor = tf.nn.softmax_cross_entropy_with_logits(
ce_per_anchor = safe_softmax_cross_entropy_with_logits(
labels=cls_target, logits=cls_score
)
prediction_dict['cross_entropy_per_anchor'] = ce_per_anchor
Expand Down
16 changes: 16 additions & 0 deletions luminoth/utils/safe_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import tensorflow as tf


def safe_softmax_cross_entropy_with_logits(
labels, logits, name='safe_cross_entropy'):
with tf.name_scope(name):
safety_condition = tf.greater(
tf.shape(logits)[0], 0, name='safety_condition'
)
return tf.cond(
safety_condition,
true_fn=lambda: tf.nn.softmax_cross_entropy_with_logits(
labels=labels, logits=logits
),
false_fn=lambda: tf.constant([], dtype=logits.dtype)
)

0 comments on commit a664f4d

Please sign in to comment.