Skip to content

Commit

Permalink
[DIPU] Wx/modify maximum schema due to the case in the inference of internlm (DeepLink-org#494)
Browse files Browse the repository at this point in the history

* improve maximum schema due to the case in the inference of internlm

* fix bug according to comments

* fix bug
  • Loading branch information
POI-WX authored and ustclight-sls committed Dec 8, 2023
1 parent 51978d9 commit 34924b0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
15 changes: 12 additions & 3 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1134,8 +1134,12 @@
interface: diopiMaxAll(ctx, out, self)

# maximum.out: elementwise maximum written into `out`.
- schema: "maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
# NOTE(review): the next two lines are the pre-change version shown by this
# diff rendering; they are superseded by the entry that follows.
no_device_check_args: [other]
interface: diopiMaximum(ctx, out, self, other)
# Skip device checks on both operands so a 1-element CPU tensor (e.g. a
# scalar bound from the internlm inference path) can pair with a device
# tensor; the custom code copies the CPU scalar to the other operand's
# device before the DIOPI kernel is called.
no_device_check_args: [self, other]
ins: [selfTemp, otherTemp]
custom_code_at_the_beginning: |
auto selfTemp = (self.numel() == 1 && self.is_cpu()) ? self.to(other.device()) : self;
auto otherTemp = (other.numel() == 1 && other.is_cpu()) ? other.to(self.device()) : other;
interface: diopiMaximum(ctx, out, selfTemp, otherTemp)

- schema: "max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_indices) -> (Tensor(a!) max, Tensor(b!) max_indices)"
custom_code_at_the_beginning: |
Expand Down Expand Up @@ -1679,7 +1683,12 @@
interface: diopiClampMaxInp(ctx, self, max)

# minimum.out: elementwise minimum written into `out`.
- schema: "minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
# NOTE(review): the next line is the pre-change version shown by this diff
# rendering; it is superseded by the entry that follows.
interface: diopiMinimum(ctx,out, self, other)
# Mirror of maximum.out above: disable device checks on both operands and
# copy a 1-element CPU tensor to the device-resident operand's device
# before invoking the DIOPI kernel.
no_device_check_args: [self, other]
ins: [selfTemp, otherTemp]
custom_code_at_the_beginning: |
auto selfTemp = (self.numel() == 1 && self.is_cpu()) ? self.to(other.device()) : self;
auto otherTemp = (other.numel() == 1 && other.is_cpu()) ? other.to(self.device()) : other;
interface: diopiMinimum(ctx, out, selfTemp, otherTemp)

- schema: "scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)"
interface: diopiScatterScalar(ctx, out, self, dim, value, index, "")
Expand Down
40 changes: 40 additions & 0 deletions dipu/tests/python/unittests/test_minimum_maximum.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,53 @@ def test_minimum(self):
r_cpu = torch.minimum(a.to(self.cpu), b.to(self.cpu))
self.assertEqual(r_dipu.to(self.cpu), r_cpu)

def test_minimum_scalar(self):
    """Minimum of a device tensor vs. a 0-d scalar tensor.

    Special case from the inference of internlm: the scalar may live on
    the CPU while the other operand is on the DIPU device.
    """
    lhs = torch.randn((3, 4))
    scalar = torch.tensor(torch.finfo(lhs.dtype).max)
    expected = torch.minimum(lhs, scalar)
    lhs_dev = lhs.to(self.dipu)
    # First pass: scalar stays on the CPU; second pass: scalar on device.
    for rhs in (scalar, scalar.to(self.dipu)):
        result = torch.minimum(lhs_dev, rhs)
        self.assertEqual(result.to(self.cpu), expected)

def test_minimum_different_devices(self):
    """torch.minimum on multi-element tensors from different devices must raise."""
    cpu_tensor = torch.tensor([1, -2, 3])
    dev_tensor = torch.tensor([4, 0, 2]).to(self.dipu)
    with self.assertRaises(RuntimeError) as ctx:
        torch.minimum(cpu_tensor, dev_tensor)
    self.assertIn(
        'Expected all tensors to be on the same device', str(ctx.exception))

def test_maximum(self):
    """Elementwise maximum on the device matches the CPU reference result."""
    x = torch.tensor((1, 2, -1))
    y = torch.tensor((3, 0, 4))
    expected = torch.maximum(x.to(self.cpu), y.to(self.cpu))
    actual = torch.maximum(x.to(self.dipu), y.to(self.dipu))
    self.assertEqual(actual.to(self.cpu), expected)

def test_maximum_scalar(self):
    """Maximum of a device tensor vs. a 0-d scalar tensor.

    Special case from the inference of internlm: the scalar may live on
    the CPU while the other operand is on the DIPU device.
    """
    lhs = torch.randn((3, 4))
    scalar = torch.tensor(torch.finfo(lhs.dtype).min)
    expected = torch.maximum(lhs, scalar)
    lhs_dev = lhs.to(self.dipu)
    # First pass: scalar stays on the CPU; second pass: scalar on device.
    for rhs in (scalar, scalar.to(self.dipu)):
        result = torch.maximum(lhs_dev, rhs)
        self.assertEqual(result.to(self.cpu), expected)

def test_maximum_different_devices(self):
    """torch.maximum on multi-element tensors from different devices must raise."""
    cpu_tensor = torch.tensor([1, -2, 3])
    dev_tensor = torch.tensor([4, 0, 2]).to(self.dipu)
    with self.assertRaises(RuntimeError) as ctx:
        torch.maximum(cpu_tensor, dev_tensor)
    self.assertIn(
        'Expected all tensors to be on the same device', str(ctx.exception))


# Entry point: run the test suite via the project's test runner when the
# file is executed directly (run_tests is imported elsewhere in this file).
if __name__ == "__main__":
run_tests()

0 comments on commit 34924b0

Please sign in to comment.