Skip to content

Commit

Permalink
padding not needed, as not doing ring passing
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 24, 2024
1 parent 18a1f86 commit 84b05ee
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions assert_zig_zag.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,23 +92,19 @@ def start(

# zig zag

padded_zig_zag_input, remove_pad = zig_zag_pad_seq(zig_zag_input)
inp, gather_seq = zig_zag_shard(zig_zag_input, all_gather_batch = True)

padded_zig_zag_input, gather_seq = zig_zag_shard(padded_zig_zag_input, all_gather_batch = True)

qkv = attention.to_qkv(padded_zig_zag_input)
qkv = attention.to_qkv(inp)

q, k, v = rearrange(qkv, 'b n (h d) -> b h n d', d = dim_head).split(attention.qkv_head_breakdown, dim = -3)

o = zig_zag_attn(q, k, v)

o = rearrange(o, 'b h n d -> b n (h d)')

padded_zig_zag_out = attention.to_out(o)

padded_zig_zag_out = gather_seq(padded_zig_zag_out)
zig_zag_out = attention.to_out(o)

zig_zag_out = remove_pad(padded_zig_zag_out)
zig_zag_out = gather_seq(zig_zag_out)

zig_zag_out.mean().backward()

Expand Down

0 comments on commit 84b05ee

Please sign in to comment.