Check for boolean values as argument on pow function. (pytorch#114133)
Hello everyone! 😄
Also @lezcano, nice to meet you! :)

Sorry if I missed anything; this is my first time around here. 🙃

This PR makes `torch.pow` behave the same on CUDA as on CPU when the exponent is a boolean. Python treats `True` as 1 and `False` as 0, so I just added this check to the `pow` function. From what I understood, when I call `.equal` with a boolean on a `Scalar`, the types are guaranteed to match, so that won't cause more trouble.
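
To illustrate why (this snippet is mine, not part of the change, and assumes the tagged-union semantics of `c10::Scalar` as I understand them): `equal()` only matches when the tags agree, so a boolean `Scalar` never compares equal to `0.0` or `1.0`, which is why the old check missed booleans entirely.

```
// Standalone sketch, assuming c10::Scalar's equal() compares tags first.
// Build and link against c10; not part of this PR.
#include <c10/core/Scalar.h>
#include <iostream>

int main() {
  c10::Scalar b(true);  // tagged as a boolean
  c10::Scalar d(1.0);   // tagged as a double
  std::cout << std::boolalpha
            << b.equal(true) << "\n"   // true: boolean tag, same value
            << b.equal(1.0) << "\n"    // false: a boolean Scalar never equals a double
            << d.equal(true) << "\n";  // false: a double Scalar never equals a boolean
  return 0;
}
```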

I know the issue suggests disabling this case, but that would be a little more complicated, in my humble opinion, and it could create some compatibility problems too, I guess.

My argument is that the code below is valid Python, so I think it does make sense to accept booleans as a `Scalar`:

```
>>> x = True
>>> x + x
2
```

This was my first test:
```
Python 3.12.0 | packaged by Anaconda, Inc. | (main, Oct  2 2023, 17:29:18) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.pow(torch.tensor([1, 2], device='cuda'), True)
tensor([1, 2], device='cuda:0')
>>> torch.pow(torch.tensor([1, 2]), True)
tensor([1, 2])
>>> torch.pow(torch.tensor([1, 2]), False)
tensor([1, 1])
>>> torch.pow(torch.tensor([1, 2], device='cuda'), False)
tensor([1, 1], device='cuda:0')
```

I've run `test_torch.py` and got the following results, so my guess is that I didn't break anything. I was just looking for a test that uses linear regression, as suggested.

```
Ran 1619 tests in 52.363s

OK (skipped=111)
[TORCH_VITAL] Dataloader.enabled		 True
[TORCH_VITAL] Dataloader.basic_unit_test		 TEST_VALUE_STRING
[TORCH_VITAL] CUDA.used		 true

```
(I can paste the whole log if necessary.)

If this is a bad idea overall, don't worry about it. It's not a big deal; it's actually a two-line change 😅 so we can talk about how to do things with a different strategy.

For the record, I've already signed the agreement. I didn't run the linter because it's not working 😞. It looks like PyYAML 6.0 is broken and there's already a 6.0.1 fix, but I have no idea how to update that 😅

Fixes pytorch#113198

Pull Request resolved: pytorch#114133
Approved by: https://github.com/lezcano
pdrocaldeira authored and pytorchmergebot committed Nov 22, 2023
1 parent aca6446 commit 5f504d1
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/Pow.cpp
```
@@ -50,9 +50,9 @@ TORCH_IMPL_FUNC(pow_Tensor_Tensor_out) (const Tensor& base, const Tensor& exp, c
 }

 TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) {
-  if (exp.equal(0.0)) {
+  if (exp.equal(0.0) || exp.equal(false)) {
     out.fill_(1);
-  } else if (exp.equal(1.0)) {
+  } else if (exp.equal(1.0) || exp.equal(true) ) {
     out.copy_(base);
   } else {
     pow_tensor_scalar_stub(device_type(), *this, exp);
```
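
For context, here is a rough libtorch-level sketch (mine, not from the PR) of what the two fast paths above do once a boolean exponent is recognized:

```
// Sketch assuming a standard libtorch setup; mirrors the fast paths above.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor base = at::arange(1, 4);            // tensor([1, 2, 3])
  std::cout << at::pow(base, at::Scalar(false))  // fill_(1) path:    [1, 1, 1]
            << at::pow(base, at::Scalar(true));  // copy_(base) path: [1, 2, 3]
  return 0;
}
```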
2 changes: 1 addition & 1 deletion test/test_binary_ufuncs.py
```
@@ -1345,7 +1345,7 @@ def test_pow(self, device, dtype):
             (100, 100), low=1, high=range_high, dtype=dtype, device=device
         )

-        exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3]
+        exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False]
         complex_exponents = [
             -2.5j,
             -1.0j,
```
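
As a rough C++ analogue (my own sketch; the real check is the Python test above) of the property the new `True`/`False` exponents exercise, boolean exponents should agree with their integer equivalents:

```
// Sketch assuming a standard libtorch setup; not part of this PR.
#include <ATen/ATen.h>

int main() {
  at::Tensor base = at::rand({100, 100}) + 1.0;  // positive bases, as in test_pow
  TORCH_CHECK(at::equal(at::pow(base, true), at::pow(base, 1)));
  TORCH_CHECK(at::equal(at::pow(base, false), at::pow(base, 0)));
  return 0;
}
```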
