diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py index e3bc04d414683..7359da3c8ec9b 100644 --- a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py @@ -441,6 +441,23 @@ def test_invalid_filter(): self.assertRaises(ValueError, test_invalid_filter) + def test_invalid_groups(): + paddle.enable_static() + input = paddle.static.data( + name='input_groups', shape=[1, 1, 1, 1], dtype='float32' + ) + offset = paddle.static.data( + name='offset_groups', shape=[1, 1], dtype='float32' + ) + mask = paddle.static.data( + name='mask_groups', shape=[1], dtype='float32' + ) + paddle.static.nn.deform_conv2d( + input, offset, mask, 1, 1, padding=1, groups=0 + ) + + self.assertRaises(ValueError, test_invalid_groups) + class TestDeformConv2DAPI(unittest.TestCase): def test_api(self): diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 53954f49f343a..0b278eefa1551 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -2255,6 +2255,8 @@ def deformable_conv( if groups is None: num_filter_channels = num_channels else: + if groups == 0: + raise ValueError("groups should not be 0.") if num_channels % groups != 0: raise ValueError("num_channels must be divisible by groups.") num_filter_channels = num_channels // groups