Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Check for boolean values as argument on pow function. (pytorch#114133)
Hello everyone! 😄 Also @lezcano , nice to meet you! :) Sorry if I miss anything, this is my first time around here. 🙃 This PR basically makes the same behaviour for cuda when using `torch.pow`. Basically Python considers True as 1 and False as 0. I just added this check into `pow` function. From what I understood, when I do `.equal` for `Scalar` that is boolean, I'm sure that types match so that won't cause more trouble. I know that the issue suggest to disable this case but that could be a little more complicated, in my humble opinion. And that can create some compability problems too, I guess. My argument is that code below is correct for native language, so I guess it does makes sense sending booleans as Scalar. ``` $ x = True $ x + x 2 ``` This was my first test: ``` Python 3.12.0 | packaged by Anaconda, Inc. | (main, Oct 2 2023, 17:29:18) [GCC 11.2.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import torch >>> torch.pow(torch.tensor([1, 2], device='cuda'), True) tensor([1, 2], device='cuda:0') >>> torch.pow(torch.tensor([1, 2]), True) tensor([1, 2]) >>> torch.pow(torch.tensor([1, 2]), False) tensor([1, 1]) >>> torch.pow(torch.tensor([1, 2], device='cuda'), False) tensor([1, 1], device='cuda:0') ``` I've run `test_torch.py` and got following results, so my guess is that I didn't break anything. I was just looking for a test that uses linear regression, as suggested. ``` Ran 1619 tests in 52.363s OK (skipped=111) [TORCH_VITAL] Dataloader.enabled True [TORCH_VITAL] Dataloader.basic_unit_test TEST_VALUE_STRING [TORCH_VITAL] CUDA.used true ``` (I can paste whole log, if necessary) If this is a bad idea overall, dont worry about it. It's not a big deal, it's actually a two line change 😅 so can we talk of how do things in a different strategy. For the record I've signed the agreement already. And I didn't run linter because it's not working 😞 . Looks like PyYaml 6.0 is broken and there's a 6.0.1 fix already but I have no idea how to update that 😅 Fixes pytorch#113198 Pull Request resolved: pytorch#114133 Approved by: https://github.com/lezcano
- Loading branch information