take care of caching for simple gateloop
lucidrains committed Dec 13, 2023
1 parent 652e8b6 commit f08fc62
Showing 2 changed files with 33 additions and 10 deletions.
41 changes: 32 additions & 9 deletions gateloop_transformer/simplified_gate_loop.py
@@ -4,15 +4,18 @@
 
 from typing import Tuple
 
-from einops import rearrange
+from einops import rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from gateloop_transformer.gateloop_transformer import RMSNorm
 from gateloop_transformer.associative_scan import associative_scan
 
 # plain pytorch non-fused associative scan
 
-def gate_loop_operator(q, kv, a):
+def exists(v):
+    return v is not None
+
+def gate_loop_operator(q, kv, a, cache = None):
 
     @torch.jit.script
     def binary_operator(
@@ -23,9 +26,18 @@ def binary_operator(
         a_j, kv_j = b
         return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
 
-    _, kv = associative_scan(binary_operator, (a, kv))
+    if exists(cache):
+        cache_a, cache_kv = cache
+        a, a_ps = pack([cache_a, a], 'b * d')
+        kv, kv_ps = pack([cache_kv, kv], 'b * d')
+
+    a, kv = associative_scan(binary_operator, (a, kv))
 
-    return q * kv
+    if exists(cache):
+        _, a = unpack(a, a_ps, 'b * d')
+        _, kv = unpack(kv, kv_ps, 'b * d')
+
+    return q * kv, (a[:, -1], kv[:, -1])
 
 # using jax associative scan
 
@@ -48,7 +60,7 @@ def binary_operator(e_i, e_j):
 
         return q * y
 
-    return jax2torch(jax_gate_loop_operator)
+    return jax2torch(jax_gate_loop_operator), None
 
 # simple gate loop layer
 
@@ -75,6 +87,8 @@ def __init__(
             Rearrange('b n (qkva d) -> qkva (b d) n 1', qkva = 3)
         )
 
+        self.use_jax = use_jax_associative_scan
+
         if use_jax_associative_scan:
             self.gate_loop_fn = get_jax_gate_loop_operator()
         else:
@@ -84,20 +98,29 @@ def __init__(
 
         self.reverse = reverse
 
-    def forward(self, x):
-
+    def forward(
+        self,
+        x,
+        cache = None,
+        return_cache = False
+    ):
         if self.reverse:
             x = torch.flip(x, dims = (-2,))
 
         x = self.norm(x)
 
         q, kv, a = self.to_qkva(x)
 
-        out = self.gate_loop_fn(q, kv, a.sigmoid())
+        out, cache = self.gate_loop_fn(q, kv, a.sigmoid(), cache = cache)
 
         out = self.split_heads(out)
 
         if self.reverse:
             out = torch.flip(out, dims = (-2,))
 
-        return out
+        if not return_cache:
+            assert not self.reverse, 'caching only works with non-reversed seq'
+            assert not self.use_jax, 'jax associative scan does not have caching yet'
+            return out
+
+        return out, cache
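The cache path in gate_loop_operator prepends the last scanned state (a[:, -1], kv[:, -1]) as an extra leading step, so a later call continues the recurrence state_t = a_t * state_{t-1} + kv_t where the previous call stopped. Below is a minimal check sketch, not part of the commit: shapes are illustrative, it uses the plain pytorch operator only, and it assumes the scan runs over dim 1 as the module's usage implies.

import torch
from gateloop_transformer.simplified_gate_loop import gate_loop_operator

b, n, d = 2, 8, 16
q  = torch.randn(b, n, d)
kv = torch.randn(b, n, d)
a  = torch.rand(b, n, d)   # gates in (0, 1), as the layer produces via a.sigmoid()

# one pass over the full sequence
out_full, _ = gate_loop_operator(q, kv, a)

# two passes, carrying the returned cache across the split
out1, cache = gate_loop_operator(q[:, :5], kv[:, :5], a[:, :5])
out2, _     = gate_loop_operator(q[:, 5:], kv[:, 5:], a[:, 5:], cache = cache)

# chunked output with cache should match the single full pass
assert torch.allclose(torch.cat([out1, out2], dim = 1), out_full, atol = 1e-5)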
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'gateloop-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'GateLoop Transformer',
   author = 'Phil Wang',
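
At the layer level, the new forward signature allows a prompt to be processed once and subsequent chunks to be stepped with the returned state. The following is a hypothetical usage sketch, not taken from the commit: the class name SimpleGateLoopLayer, its constructor arguments, and the tensor shapes are assumptions, and it relies on the default pytorch scan path (the asserts above rule out reverse and the jax scan when caching).

import torch
from gateloop_transformer.simplified_gate_loop import SimpleGateLoopLayer  # class name assumed

layer = SimpleGateLoopLayer(dim = 512)  # constructor signature assumed

prompt = torch.randn(1, 1024, 512)

# full pass over the prompt, also requesting the recurrent state
out, cache = layer(prompt, return_cache = True)

# the next chunk reuses the cached state instead of re-processing the prompt
next_chunk = torch.randn(1, 16, 512)
out_next, cache = layer(next_chunk, cache = cache, return_cache = True)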
