Skip to content

Commit

Permalink
another researcher shared with me that complex-valued states do not…
Browse files Browse the repository at this point in the history
… add anything. prepare the gate loop operator to work without complex values, for use in another project
  • Loading branch information
lucidrains committed Nov 17, 2023
1 parent 3e237a9 commit d170c86
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion gateloop_transformer/gateloop_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ def forward(

# data gated linear attention with "gateloop operator"

def maybe_real(t):
    """Return the real component of `t` when it is complex, else `t` itself.

    Lets callers handle real and complex tensors uniformly:
    `torch.is_complex` decides which branch is taken, and non-complex
    input is passed through untouched (the same object is returned).
    """
    if not torch.is_complex(t):
        return t

    return t.real

def gate_loop_operator(q, k, v, a):
"""
the pseudocode in section 3.2 of the paper
Expand All @@ -200,7 +206,7 @@ def binary_operator(a, b):
a_i, kv_i = a
a_j, kv_j = b

return a_j * a_i, a_j.real * kv_i + kv_j
return a_j * a_i, maybe_real(a_j) * kv_i + kv_j

_, kv = associative_scan(binary_operator, (a, kv))

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'gateloop-transformer',
packages = find_packages(exclude=[]),
version = '0.0.22',
version = '0.0.23',
license='MIT',
description = 'GateLoop Transformer',
author = 'Phil Wang',
Expand Down

0 comments on commit d170c86

Please sign in to comment.