diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 8812397c5..4b58185a2 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1134,8 +1134,12 @@
   interface: diopiMaxAll(ctx, out, self)
 
 - schema: "maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
-  no_device_check_args: [other]
-  interface: diopiMaximum(ctx, out, self, other)
+  no_device_check_args: [self, other]
+  ins: [selfTemp, otherTemp]
+  custom_code_at_the_beginning: |
+    auto selfTemp = (self.numel() == 1 && self.is_cpu()) ? self.to(other.device()) : self;
+    auto otherTemp = (other.numel() == 1 && other.is_cpu()) ? other.to(self.device()) : other;
+  interface: diopiMaximum(ctx, out, selfTemp, otherTemp)
 
 - schema: "max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_indices) -> (Tensor(a!) max, Tensor(b!) max_indices)"
   custom_code_at_the_beginning: |
@@ -1679,7 +1683,12 @@
   interface: diopiClampMaxInp(ctx, self, max)
 
 - schema: "minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
-  interface: diopiMinimum(ctx,out, self, other)
+  no_device_check_args: [self, other]
+  ins: [selfTemp, otherTemp]
+  custom_code_at_the_beginning: |
+    auto selfTemp = (self.numel() == 1 && self.is_cpu()) ? self.to(other.device()) : self;
+    auto otherTemp = (other.numel() == 1 && other.is_cpu()) ? other.to(self.device()) : other;
+  interface: diopiMinimum(ctx, out, selfTemp, otherTemp)
 
 - schema: "scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)"
   interface: diopiScatterScalar(ctx, out, self, dim, value, index, "")
diff --git a/dipu/tests/python/unittests/test_minimum_maximum.py b/dipu/tests/python/unittests/test_minimum_maximum.py
index eecc57bc1..a6b00383d 100644
--- a/dipu/tests/python/unittests/test_minimum_maximum.py
+++ b/dipu/tests/python/unittests/test_minimum_maximum.py
@@ -15,6 +15,26 @@ def test_minimum(self):
         r_cpu = torch.minimum(a.to(self.cpu), b.to(self.cpu))
         self.assertEqual(r_dipu.to(self.cpu), r_cpu)
 
+    def test_minimum_scalar(self):
+        # special test cases from the inference of internlm
+        a = torch.randn((3, 4))
+        b = torch.tensor(torch.finfo(a.dtype).max)
+        # scalar on cpu
+        r_dipu1 = torch.minimum(a.to(self.dipu), b)
+        # scalar on device
+        r_dipu2 = torch.minimum(a.to(self.dipu), b.to(self.dipu))
+        r_cpu = torch.minimum(a, b)
+        self.assertEqual(r_dipu1.to(self.cpu), r_cpu)
+        self.assertEqual(r_dipu2.to(self.cpu), r_cpu)
+
+    def test_minimum_different_devices(self):
+        a = torch.tensor([1, -2, 3])
+        b = torch.tensor([4, 0, 2]).to(self.dipu)
+        with self.assertRaises(RuntimeError) as context:
+            torch.minimum(a, b)
+        self.assertIn(
+            'Expected all tensors to be on the same device', str(context.exception))
+
     def test_maximum(self):
         a = torch.tensor((1, 2, -1))
         b = torch.tensor((3, 0, 4))
@@ -22,6 +42,26 @@ def test_maximum(self):
         r_cpu = torch.maximum(a.to(self.cpu), b.to(self.cpu))
         self.assertEqual(r_dipu.to(self.cpu), r_cpu)
 
+    def test_maximum_scalar(self):
+        # special test cases from the inference of internlm
+        a = torch.randn((3, 4))
+        b = torch.tensor(torch.finfo(a.dtype).min)
+        # scalar on cpu
+        r_dipu1 = torch.maximum(a.to(self.dipu), b)
+        # scalar on device
+        r_dipu2 = torch.maximum(a.to(self.dipu), b.to(self.dipu))
+        r_cpu = torch.maximum(a, b)
+        self.assertEqual(r_dipu1.to(self.cpu), r_cpu)
+        self.assertEqual(r_dipu2.to(self.cpu), r_cpu)
+
+    def test_maximum_different_devices(self):
+        a = torch.tensor([1, -2, 3])
+        b = torch.tensor([4, 0, 2]).to(self.dipu)
+        with self.assertRaises(RuntimeError) as context:
+            torch.maximum(a, b)
+        self.assertIn(
+            'Expected all tensors to be on the same device', str(context.exception))
+
 
 if __name__ == "__main__":
     run_tests()
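
For context (illustrative only, not part of the patch): a minimal sketch of the call pattern the promoted-scalar handling in the YAML change is meant to support. The `torch_dipu` import and reaching the vendor device through the mocked `"cuda"` device type are assumptions about the runtime environment, not something established by this diff.

```python
import torch
import torch_dipu  # noqa: F401  # assumption: DIPU build where this import patches torch

device = torch.device("cuda")  # assumption: vendor device exposed via the mocked cuda interface

a = torch.randn(3, 4, device=device)
b = torch.tensor(torch.finfo(a.dtype).min)  # 0-dim scalar tensor left on the CPU

# With the custom_code_at_the_beginning above, the generated wrapper copies the
# 0-dim CPU operand to a's device before calling diopiMaximum, so this
# mixed-device call succeeds and matches the pure-CPU result.
print(torch.maximum(a, b))

# Two non-scalar tensors on different devices still raise the usual
# "Expected all tensors to be on the same device" RuntimeError, as the new
# test_*_different_devices cases check.
```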