From 0781eb1f2edeb315002e3b2f05a3d28e3cdb74c3 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 9 May 2023 18:52:14 -0700
Subject: [PATCH 1/3] Revert "actually, just follow @ipoletaev advice and remove autotuner for now"

This reverts commit 2226ec8aeee03e9fbbf561e50fbf114b9677d3e9.
---
 lion_pytorch/lion_pytorch.py | 17 +++++++++-----
 lion_pytorch/triton.py       | 43 +++++++++++++++++++-----------------
 setup.py                     |  2 +-
 3 files changed, 36 insertions(+), 26 deletions(-)

diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py
index f96fec6..e025f89 100644
--- a/lion_pytorch/lion_pytorch.py
+++ b/lion_pytorch/lion_pytorch.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import Tuple, Optional, Callable
 
 import torch
@@ -34,8 +33,7 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False,
-        triton_block_size: int = 1024
+        use_triton: bool = False
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
@@ -54,7 +52,7 @@ def __init__(
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
-            self.update_fn = partial(triton_update_fn, BLOCK_SIZE = triton_block_size)
+            self.update_fn = triton_update_fn
 
     @torch.no_grad()
     def step(
@@ -67,6 +65,11 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
+        # address an issue with autotune and in-place updates with triton
+        # on the first .step call, simply do not update parameters in-place, if using triton
+
+        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
+
         # update all parameters
 
         for group in self.param_groups:
@@ -88,7 +91,11 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2
+                    beta2,
+                    **update_kwargs
                 )
 
+        if not self.took_first_step:
+            self.took_first_step = True
+
         return loss
diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index 0615dbc..ca59800 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -8,18 +8,12 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# helper functions
-
-def calc_num_warps(block_size):
-    num_warps = 4
-    if block_size >= 2048:
-        num_warps = 8
-    if block_size >= 4096:
-        num_warps = 16
-    return num_warps
-
 # triton cuda kernel
 
+@triton.autotune(configs = [
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
+], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
@@ -87,20 +81,25 @@ def update_fn(
     wd: float,
     beta1: float,
     beta2: float,
-    inplace: bool = True,
-    BLOCK_SIZE: int = 1024
+    inplace: bool = True
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
-
     n_elements = p.numel()
 
-    block_size = triton.next_power_of_2(BLOCK_SIZE)
-    num_warps = calc_num_warps(block_size)
-    n_rows = triton.cdiv(n_elements, block_size)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+
+    # address autotune and in-place update issue
+
+    if not inplace:
+        orig_p = p
+        orig_exp_avg = exp_avg
+
+        p = p.clone()
+        exp_avg = exp_avg.clone()
 
     # call triton cuda kernel
 
-    update_fn_kernel[(n_rows,)](
+    update_fn_kernel[grid](
         p,
         grad,
         exp_avg,
@@ -108,7 +107,11 @@ def update_fn(
         wd,
         beta1,
         beta2,
-        n_elements,
-        num_warps = num_warps,
-        BLOCK_SIZE = BLOCK_SIZE
+        n_elements
     )
+
+    # update if not in-place call
+
+    if not inplace:
+        orig_p.copy_(p)
+        orig_exp_avg.copy_(exp_avg)
diff --git a/setup.py b/setup.py
index 35eb081..1eace6d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.0.8',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
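
NOTE (illustration, not part of the patch above): with the autotuner restored by PATCH 1/3, the optimizer's first .step() call passes inplace = False down to the Triton update, so autotune's repeated benchmark launches never write into the live parameters. A minimal usage sketch, assuming a CUDA device with lion-pytorch and triton installed; the model and shapes below are made up for illustration:

    import torch
    from lion_pytorch import Lion

    model = torch.nn.Linear(512, 512).cuda()
    opt = Lion(model.parameters(), lr = 1e-4, use_triton = True)

    loss = model(torch.randn(8, 512).cuda()).sum()
    loss.backward()

    opt.step()       # first call: the Triton path clones p / exp_avg rather than updating them in place
    opt.zero_grad()
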
From 2671a69efd8a3b4ff1043f83685a53fad92ce2c4 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 9 May 2023 18:52:21 -0700
Subject: [PATCH 2/3] Revert "address an issue with triton auto-tuner and in-place calls. make the assumption that after the first optimizer.step call, things are properly cached"

This reverts commit 6ab873a380b47ebc5ea6f68ea588606daebb8b85.
---
 lion_pytorch/lion_pytorch.py | 15 +--------------
 lion_pytorch/triton.py       | 28 ++++------------------------
 setup.py                     |  2 +-
 3 files changed, 6 insertions(+), 39 deletions(-)

diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py
index e025f89..0a6258a 100644
--- a/lion_pytorch/lion_pytorch.py
+++ b/lion_pytorch/lion_pytorch.py
@@ -47,8 +47,6 @@ def __init__(
         super().__init__(params, defaults)
 
         self.update_fn = update_fn
-        self.use_triton = use_triton
-        self.took_first_step = False
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
@@ -65,13 +63,6 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
-        # address an issue with autotune and in-place updates with triton
-        # on the first .step call, simply do not update parameters in-place, if using triton
-
-        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
-
-        # update all parameters
-
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
@@ -91,11 +82,7 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2,
-                    **update_kwargs
+                    beta2
                 )
 
-        if not self.took_first_step:
-            self.took_first_step = True
-
         return loss
diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index ca59800..4cc35f0 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -1,5 +1,4 @@
 import torch
-from torch import Tensor
 
 try:
     import triton
@@ -8,7 +7,6 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# triton cuda kernel
 
 @triton.autotune(configs = [
     triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
     triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
@@ -74,31 +72,19 @@ def update_fn_kernel(
     tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
 
 def update_fn(
-    p: Tensor,
-    grad: Tensor,
-    exp_avg: Tensor,
+    p: torch.Tensor,
+    grad: torch.Tensor,
+    exp_avg: torch.Tensor,
     lr: float,
     wd: float,
     beta1: float,
-    beta2: float,
-    inplace: bool = True
+    beta2: float
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
     n_elements = p.numel()
 
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
 
-    # address autotune and in-place update issue
-
-    if not inplace:
-        orig_p = p
-        orig_exp_avg = exp_avg
-
-        p = p.clone()
-        exp_avg = exp_avg.clone()
-
-    # call triton cuda kernel
-
     update_fn_kernel[grid](
         p,
         grad,
         exp_avg,
@@ -109,9 +95,3 @@ def update_fn(
         beta2,
         n_elements
     )
-
-    # update if not in-place call
-
-    if not inplace:
-        orig_p.copy_(p)
-        orig_exp_avg.copy_(exp_avg)
diff --git a/setup.py b/setup.py
index 1eace6d..dab77ac 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.7',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
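
NOTE (toy illustration, not from the patches): the two reverts above bring back the plain autotuned kernel that updates p and exp_avg in place. Triton's autotuner benchmarks each candidate config by launching the kernel several times, so an in-place update gets applied once per benchmark launch instead of once; that is the problem PATCH 3/3 then tackles. The sketch below mimics the effect in plain PyTorch with a hand-written Lion-style update (the helper lion_update_ is hypothetical and only meant to show the repeated-launch issue):

    import torch

    def lion_update_(p, grad, exp_avg, lr = 1e-4, wd = 0.0, beta1 = 0.9, beta2 = 0.99):
        # decoupled weight decay, sign-of-interpolation step, then EMA update, all in place
        p.mul_(1 - lr * wd)
        update = exp_avg.mul(beta1).add(grad, alpha = 1 - beta1).sign_()
        p.add_(update, alpha = -lr)
        exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)

    p = torch.ones(4)
    exp_avg = torch.zeros(4)
    grad = torch.ones(4)

    for _ in range(3):          # stand-in for the autotuner timing several configs
        lion_update_(p, grad, exp_avg)

    print(p)                    # the parameter has taken three optimizer steps instead of one
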
From 3d1e555a52060ec67d7fd51d93890930b3039346 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 10 May 2023 07:43:10 -0700
Subject: [PATCH 3/3] attempt to fix autotune + inplace update issue

---
 lion_pytorch/triton.py | 12 ++++++++++--
 setup.py               |  2 +-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
index 4cc35f0..ab5b08e 100644
--- a/lion_pytorch/triton.py
+++ b/lion_pytorch/triton.py
@@ -7,10 +7,18 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
+# clone param and exp_avg before autotuning takes place
+# as those are updated in-place
+
+def clone_inplace_updated_params(nargs):
+    nargs['p_ptr'] = nargs['p_ptr'].clone()
+    nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone()
+
+# triton cuda kernel
 
 @triton.autotune(configs = [
-    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
-    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
 ], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
diff --git a/setup.py b/setup.py
index dab77ac..ee9ddb1 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.1.2',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
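
NOTE (sketch, not part of the patch): PATCH 3/3 relies on triton.Config accepting a pre_hook callable that is handed the kernel's argument dict before a launch; pointing every autotune config at a hook that clones the buffers the kernel mutates keeps the tuning runs from corrupting them. The same pattern, sketched for some other hypothetical in-place kernel (out_ptr, n_elements and my_inplace_kernel are invented names, not part of lion-pytorch):

    import triton
    import triton.language as tl

    def clone_inplace_args(nargs):
        # benchmark launches write into a throwaway copy of the mutated buffer
        nargs['out_ptr'] = nargs['out_ptr'].clone()

    configs = [
        triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_args),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_args),
    ]

    @triton.autotune(configs = configs, key = ['n_elements'])
    @triton.jit
    def my_inplace_kernel(out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # double every element of out_ptr in place
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(out_ptr + offsets, mask = mask)
        tl.store(out_ptr + offsets, x * 2, mask = mask)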