From f08fc6291f2b492f168d3d0b9c7f4121d5aa4d6e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 13 Dec 2023 09:33:07 -0800
Subject: [PATCH] take care of caching for simple gateloop

---
 gateloop_transformer/simplified_gate_loop.py | 41 +++++++++++++++-----
 setup.py                                     |  2 +-
 2 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/gateloop_transformer/simplified_gate_loop.py b/gateloop_transformer/simplified_gate_loop.py
index a7a1a00..91c293b 100644
--- a/gateloop_transformer/simplified_gate_loop.py
+++ b/gateloop_transformer/simplified_gate_loop.py
@@ -4,7 +4,7 @@
 
 from typing import Tuple
 
-from einops import rearrange
+from einops import rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from gateloop_transformer.gateloop_transformer import RMSNorm
@@ -12,7 +12,10 @@
 
 # plain pytorch non-fused associative scan
 
-def gate_loop_operator(q, kv, a):
+def exists(v):
+    return v is not None
+
+def gate_loop_operator(q, kv, a, cache = None):
 
     @torch.jit.script
     def binary_operator(
@@ -23,9 +26,18 @@ def binary_operator(
         a_j, kv_j = b
         return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
 
-    _, kv = associative_scan(binary_operator, (a, kv))
+    if exists(cache):
+        cache_a, cache_kv = cache
+        a, a_ps = pack([cache_a, a], 'b * d')
+        kv, kv_ps = pack([cache_kv, kv], 'b * d')
+
+    a, kv = associative_scan(binary_operator, (a, kv))
 
-    return q * kv
+    if exists(cache):
+        _, a = unpack(a, a_ps, 'b * d')
+        _, kv = unpack(kv, kv_ps, 'b * d')
+
+    return q * kv, (a[:, -1], kv[:, -1])
 
 # using jax associative scan
 
@@ -48,7 +60,7 @@ def binary_operator(e_i, e_j):
 
         return q * y
 
-    return jax2torch(jax_gate_loop_operator)
+    return jax2torch(jax_gate_loop_operator), None
 
 # simple gate loop layer
 
@@ -75,6 +87,8 @@ def __init__(
             Rearrange('b n (qkva d) -> qkva (b d) n 1', qkva = 3)
         )
 
+        self.use_jax = use_jax_associative_scan
+
         if use_jax_associative_scan:
             self.gate_loop_fn = get_jax_gate_loop_operator()
         else:
@@ -84,8 +98,12 @@ def __init__(
 
         self.reverse = reverse
 
-    def forward(self, x):
-
+    def forward(
+        self,
+        x,
+        cache = None,
+        return_cache = False
+    ):
         if self.reverse:
             x = torch.flip(x, dims = (-2,))
 
@@ -93,11 +111,16 @@ def forward(self, x):
 
         q, kv, a = self.to_qkva(x)
 
-        out = self.gate_loop_fn(q, kv, a.sigmoid())
+        out, cache = self.gate_loop_fn(q, kv, a.sigmoid(), cache = cache)
 
         out = self.split_heads(out)
 
         if self.reverse:
             out = torch.flip(out, dims = (-2,))
 
-        return out
+        if not return_cache:
+            return out
+
+        assert not self.reverse, 'caching only works with non-reversed seq'
+        assert not self.use_jax, 'jax associative scan does not have caching yet'
+        return out, cache
diff --git a/setup.py b/setup.py
index 2fbf167..260b394 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'gateloop-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'GateLoop Transformer',
   author = 'Phil Wang',
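
Note (not part of the patch): a minimal usage sketch of the caching interface added above. The class name SimpleGateLoopLayer, its `dim` argument, and the tensor shapes are assumptions based on the surrounding file, as is the use of the default configuration (non-reversed, pytorch associative scan), which is the only path the new `return_cache` flag supports.

    import torch
    from gateloop_transformer.simplified_gate_loop import SimpleGateLoopLayer  # assumed export

    layer = SimpleGateLoopLayer(dim = 512)  # assumed constructor signature

    prompt = torch.randn(1, 1024, 512)

    # one pass over the full prompt, also returning the recurrent state
    # (the last gate `a` and state `kv` of the associative scan)
    out, cache = layer(prompt, return_cache = True)

    # subsequent tokens can then be fed one at a time with the cached state,
    # instead of re-running the scan over the whole prefix
    next_token = torch.randn(1, 1, 512)
    next_out, cache = layer(next_token, cache = cache, return_cache = True)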