Merge pull request #26 from lucidrains/pre-hook-autotuner
use pre hook to fix in-place / autotuner issue
lucidrains authored May 10, 2023
2 parents 2226ec8 + 3d1e555 commit 6629519
Showing 3 changed files with 19 additions and 34 deletions.
10 changes: 2 additions & 8 deletions lion_pytorch/lion_pytorch.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import Tuple, Optional, Callable
 
 import torch
@@ -34,8 +33,7 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False,
-        triton_block_size: int = 1024
+        use_triton: bool = False
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
@@ -49,12 +47,10 @@ def __init__(
         super().__init__(params, defaults)

         self.update_fn = update_fn
-        self.use_triton = use_triton
-        self.took_first_step = False

         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
-            self.update_fn = partial(triton_update_fn, BLOCK_SIZE = triton_block_size)
+            self.update_fn = triton_update_fn

     @torch.no_grad()
     def step(
@@ -67,8 +63,6 @@ def step(
             with torch.enable_grad():
                 loss = closure()

-        # update all parameters
-
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):

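With this change the optimizer no longer takes a triton_block_size argument; use_triton simply switches the update to the Triton kernel, whose block size is now chosen by the autotuner. A minimal usage sketch of the new constructor (the toy model and hyperparameters below are illustrative, not part of this commit):

import torch
from lion_pytorch import Lion

model = torch.nn.Linear(512, 512).cuda()

# note: no triton_block_size argument anymore; BLOCK_SIZE is picked by the autotuner
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, use_triton = True)

loss = model(torch.randn(8, 512).cuda()).sum()
loss.backward()
opt.step()
opt.zero_grad()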
41 changes: 16 additions & 25 deletions lion_pytorch/triton.py
@@ -1,5 +1,4 @@
 import torch
-from torch import Tensor

 try:
     import triton
@@ -8,18 +7,19 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()

-# helper functions
+# clone param and exp_avg before autotuning takes place
+# as those are updated in-place

-def calc_num_warps(block_size):
-    num_warps = 4
-    if block_size >= 2048:
-        num_warps = 8
-    if block_size >= 4096:
-        num_warps = 16
-    return num_warps
+def clone_inplace_updated_params(nargs):
+    nargs['p_ptr'] = nargs['p_ptr'].clone()
+    nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone()

 # triton cuda kernel

+@triton.autotune(configs = [
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
+], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
@@ -80,35 +80,26 @@ def update_fn_kernel(
     tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)

 def update_fn(
-    p: Tensor,
-    grad: Tensor,
-    exp_avg: Tensor,
+    p: torch.Tensor,
+    grad: torch.Tensor,
+    exp_avg: torch.Tensor,
     lr: float,
     wd: float,
     beta1: float,
-    beta2: float,
-    inplace: bool = True,
-    BLOCK_SIZE: int = 1024
+    beta2: float
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
-
     n_elements = p.numel()

-    block_size = triton.next_power_of_2(BLOCK_SIZE)
-    num_warps = calc_num_warps(block_size)
-    n_rows = triton.cdiv(n_elements, block_size)
-
-    # call triton cuda kernel
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

-    update_fn_kernel[(n_rows,)](
+    update_fn_kernel[grid](
         p,
         grad,
         exp_avg,
         lr,
         wd,
         beta1,
         beta2,
-        n_elements,
-        num_warps = num_warps,
-        BLOCK_SIZE = BLOCK_SIZE
+        n_elements
     )
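The reason for the pre_hook: triton.autotune times every candidate Config by launching the kernel several times before caching the fastest one, and since update_fn_kernel writes p_ptr and exp_avg_ptr in place, those extra benchmarking launches would otherwise fold the Lion update into the real parameter and momentum tensors more than once; clone_inplace_updated_params is there to hand the tuning runs clones of those buffers instead. For reference, a hedged sketch of calling the slimmed-down update_fn directly (the tensors below are placeholders, not part of this commit):

import torch
from lion_pytorch.triton import update_fn

p = torch.randn(4096, device = 'cuda')
grad = torch.randn_like(p)
exp_avg = torch.zeros_like(p)

# new signature: inplace and BLOCK_SIZE are gone, the launch grid comes from
# whichever BLOCK_SIZE the autotuner settles on
update_fn(p, grad, exp_avg, lr = 1e-4, wd = 1e-2, beta1 = 0.9, beta2 = 0.99)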
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.2',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
