Skip to content

Commit

Permalink
fix glu
Browse files Browse the repository at this point in the history
  • Loading branch information
longranger2 committed Feb 6, 2023
1 parent 44f93ae commit 21b2d6a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_glu.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def glu_axis_size(self):
paddle.nn.functional.glu(x, axis=256)

def test_errors(self):
self.assertRaises(AssertionError, self.glu_axis_size)
self.assertRaises(ValueError, self.glu_axis_size)


if __name__ == '__main__':
Expand Down
8 changes: 7 additions & 1 deletion python/paddle/nn/functional/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,13 @@ def glu(x, axis=-1, name=None):
check_variable_and_dtype(
x, 'input', ['float16', 'float32', 'float64'], "glu"
)
assert axis < len(x.shape), "axis must < rank(x)"
rank = len(x.shape)
if not (-rank <= axis < rank):
raise ValueError(
"Expected value range of `axis` is [{}, {}), but received axis: {}".format(
-rank, rank, axis
)
)
a, b = chunk(x, 2, axis=axis, name=name)
gate = sigmoid(b, name=name)
out = paddle.multiply(a, gate, name=name)
Expand Down

0 comments on commit 21b2d6a

Please sign in to comment.