From 5f504d1de74a5189f93e65aa9283fc4607b8631c Mon Sep 17 00:00:00 2001
From: Pedro Caldeira
Date: Wed, 22 Nov 2023 22:57:32 +0000
Subject: [PATCH] Check for boolean values as argument on pow function.
 (#114133)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Hello everyone! 😄 Also @lezcano, nice to meet you! :)
Sorry if I miss anything, this is my first time around here. 🙃

This PR basically makes `torch.pow` behave the same way on CUDA as on CPU when the exponent is a boolean. Python treats True as 1 and False as 0, so I just added this check to the `pow` function. From what I understood, when I call `.equal` on a `Scalar` that is a boolean, I'm sure the types match, so that won't cause more trouble.

I know the issue suggests disabling this case, but that could be a little more complicated, in my humble opinion. And that could create some compatibility problems too, I guess. My argument is that the code below is valid plain Python, so I think it does make sense to accept booleans as a `Scalar`:

```
>>> x = True
>>> x + x
2
```

This was my first test:

```
Python 3.12.0 | packaged by Anaconda, Inc. | (main, Oct  2 2023, 17:29:18) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.pow(torch.tensor([1, 2], device='cuda'), True)
tensor([1, 2], device='cuda:0')
>>> torch.pow(torch.tensor([1, 2]), True)
tensor([1, 2])
>>> torch.pow(torch.tensor([1, 2]), False)
tensor([1, 1])
>>> torch.pow(torch.tensor([1, 2], device='cuda'), False)
tensor([1, 1], device='cuda:0')
```

I've run `test_torch.py` and got the following results, so my guess is that I didn't break anything. I was just looking for a test that uses linear regression, as suggested.
```
Ran 1619 tests in 52.363s

OK (skipped=111)

[TORCH_VITAL] Dataloader.enabled True
[TORCH_VITAL] Dataloader.basic_unit_test TEST_VALUE_STRING
[TORCH_VITAL] CUDA.used true
```
(I can paste the whole log, if necessary.)

If this is a bad idea overall, don't worry about it. It's not a big deal, it's actually a two-line change 😅 so we can talk about how to do things with a different strategy.

For the record, I've signed the agreement already. And I didn't run the linter because it's not working 😞. Looks like PyYAML 6.0 is broken and there's a 6.0.1 fix already, but I have no idea how to update that 😅

Fixes #113198

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114133
Approved by: https://github.com/lezcano
---
 aten/src/ATen/native/Pow.cpp | 4 ++--
 test/test_binary_ufuncs.py   | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp
index 0fa0fceab6ca37..5c8147d7ced38f 100644
--- a/aten/src/ATen/native/Pow.cpp
+++ b/aten/src/ATen/native/Pow.cpp
@@ -50,9 +50,9 @@ TORCH_IMPL_FUNC(pow_Tensor_Tensor_out) (const Tensor& base, const Tensor& exp, c
 }
 
 TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) {
-  if (exp.equal(0.0)) {
+  if (exp.equal(0.0) || exp.equal(false)) {
     out.fill_(1);
-  } else if (exp.equal(1.0)) {
+  } else if (exp.equal(1.0) || exp.equal(true)) {
     out.copy_(base);
   } else {
     pow_tensor_scalar_stub(device_type(), *this, exp);
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 2f22569f9cf1a8..9fcc8b445eb1cd 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -1345,7 +1345,7 @@ def test_pow(self, device, dtype):
             (100, 100), low=1, high=range_high, dtype=dtype, device=device
         )
 
-        exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3]
+        exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False]
         complex_exponents = [
             -2.5j,
             -1.0j,
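For reviewers who want to check the "True is 1, False is 0" premise without torch installed, here is a minimal plain-Python sketch of the semantics the PR relies on (standard-library only, nothing project-specific):

```python
# bool is a subclass of int in Python, so booleans participate in
# arithmetic as 0 and 1 -- which is why accepting them as a pow
# exponent is consistent with the language.
print(isinstance(True, int))  # True: bool subclasses int
print(True + True)            # 2
print(3 ** True)              # 3, same as 3 ** 1
print(3 ** False)             # 1, same as 3 ** 0
```

This is exactly the `x + x == 2` argument from the description, stated via `isinstance`.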