-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merge OpenAI Triton commit f8b5301
#3069
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This should enable python 3.13 and 3.13t wheels and disable building python 3.8 wheels. Related to: pytorch/pytorch#143654
Running the example given in the [autotune docstring](https://triton-lang.org/main/python-api/generated/triton.autotune.html) gives the error ```python import triton import torch @triton.autotune(configs=[ triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), ], key=['x_size'] # the two above configs will be evaluated anytime # the value of x_size changes ) @triton.jit def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] if __name__ == '__main__': x = torch.ones(8, device="cuda") kernel[lambda _: (1,)](x, x.numel()) ``` ``` Traceback (most recent call last): File "...", line 18, in <module> kernel[lambda _: (1,)](x, x.size) File ".../triton/runtime/jit.py", line 345, in <lambda> return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) File ".../triton/runtime/autotuner.py", line 156, in run timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} File ".../triton/runtime/autotuner.py", line 156, in <dictcomp> timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} File ".../triton/runtime/autotuner.py", line 133, in _bench return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) File ".../triton/testing.py", line 106, in do_bench fn() File ".../triton/runtime/autotuner.py", line 114, in kernel_call self.fn.run( File ".../triton/runtime/jit.py", line 618, in run bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) TypeError: dynamic_func() missing 1 required positional argument: 'META' ``` It seems it cannot parse the kwargs `**META`, so the keyword arguments must be manually specified. Also, `BLOCK_SIZE` should probably be marked as `tl.constexpr`. ```python @triton.autotune( configs=[ triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4), triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8), ], key=["x_size"], # the two above configs will be evaluated anytime # the value of x_size changes ) @triton.jit def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr): ... ``` Similarly, for the [heuristics](https://triton-lang.org/main/python-api/generated/triton.heuristics.html) example, first, the same `**META` issue applies, second, `args` is no longer a list of positional argument values but a dictionary from argument name to value, and third, `2 ** int(math.ceil(math.log2(args[1])))` is awkward and `triton.next_power_of_2(args['x_size'])` should be preferred. ```python import torch import triton @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) @triton.jit def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size if __name__ == "__main__": x = torch.ones(8, device="cuda") kernel[lambda _: (1,)](x, x.numel()) ``` ``` Traceback (most recent call last): File "...", line 15, in <module> kernel[lambda _: (1,)](x, x.numel()) File ".../triton/runtime/jit.py", line 345, in <lambda> return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) File ".../triton/runtime/autotuner.py", line 337, in run kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) File "...", line 7, in <lambda> @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) KeyError: 1 ``` Applying the suggested changes results in ```python # smallest power-of-two >= x_size @triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])}) @triton.jit def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr): ... ```
Fixes #5484. Since python objects can arbitrarily override `__contains__`, using `inspect.ismodule` seems to be the most general solution, beyond numpy arrays. Overriding a module's `__contains__` would be very strange. ```python >>> import triton.language as tl >>> import inspect >>> inspect.ismodule(tl) True >>> inspect.ismodule(tl.core) True ```
<!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> I accidentally deleted a newline in triton-lang/triton#5487 which causes the [website](https://triton-lang.org/main/python-api/generated/triton.heuristics.html) to not have the code block... oops. Generated locally and can confirm the syntax is now correct. ![image](https://github.com/user-attachments/assets/0856b672-6951-4064-8c47-a0839f7b4a98) # New contributor declaration - [ ] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because this is a documentation change. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
Also create the `getThreadsPerWarpForOperand` interface. This PR doesn't completely remove the decompose method because `getThreadsPerWarpForOperand` hasn't been implemented yet for some AMD specific encodings
Signed-off-by: Whitney Tsang <[email protected]>
whitneywhtsang
changed the title
Merge OpenAI Triton commit
Merge OpenAI Triton commit Dec 26, 2024
755d416
f8b5301
anmyachev
approved these changes
Dec 26, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR change the Triton base from 755d416 to f8b5301 (Dec 25).
Pass rate: 99.03%
Please do not squash and merge this PR.