Commit

padding not needed, as not doing ring passing
lucidrains committed Oct 24, 2024
1 parent 18a1f86 commit eb1c284
Showing 2 changed files with 4 additions and 22 deletions.
12 changes: 4 additions & 8 deletions assert_zig_zag.py
@@ -92,23 +92,19 @@ def start(
 
     # zig zag
 
-    padded_zig_zag_input, remove_pad = zig_zag_pad_seq(zig_zag_input)
-
-    padded_zig_zag_input, gather_seq = zig_zag_shard(padded_zig_zag_input, all_gather_batch = True)
+    inp, gather_seq = zig_zag_shard(zig_zag_input, all_gather_batch = True)
 
-    qkv = attention.to_qkv(padded_zig_zag_input)
+    qkv = attention.to_qkv(inp)
 
     q, k, v = rearrange(qkv, 'b n (h d) -> b h n d', d = dim_head).split(attention.qkv_head_breakdown, dim = -3)
 
     o = zig_zag_attn(q, k, v)
 
     o = rearrange(o, 'b h n d -> b n (h d)')
 
-    padded_zig_zag_out = attention.to_out(o)
-
-    padded_zig_zag_out = gather_seq(padded_zig_zag_out)
+    zig_zag_out = attention.to_out(o)
 
-    zig_zag_out = remove_pad(padded_zig_zag_out)
+    zig_zag_out = gather_seq(zig_zag_out)
 
     zig_zag_out.mean().backward()
 
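As context for the simplified test flow above, here is a minimal single-process sketch of the zig-zag shard/gather round trip, assuming the common zig-zag layout in which rank r holds chunks r and 2·W − 1 − r of a sequence cut into 2 × W chunks; it stands in for zig_zag_shard / gather_seq without any torch.distributed setup and is not the library's actual API.

import torch

# assumed number of ranks, purely for illustration
world_size = 4

# one token id per chunk, so the permutation is easy to see
seq = torch.arange(2 * world_size)

# rank r keeps chunks (r, 2*W - 1 - r), the zig-zag assignment that balances causal attention work
shards = [torch.stack((seq[r], seq[2 * world_size - 1 - r])) for r in range(world_size)]

# "all gather" the shards, then invert the zig-zag permutation to restore the original order
gathered = torch.cat(shards)
perm = torch.cat([torch.tensor([r, 2 * world_size - 1 - r]) for r in range(world_size)])
restored = torch.empty_like(gathered)
restored[perm] = gathered

assert torch.equal(restored, seq)
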
14 changes: 0 additions & 14 deletions ring_attention_pytorch/zig_zag_attention.py
@@ -27,20 +27,6 @@ def default(v, d):
 def divisible_by(num, den):
     return (num % den) == 0
 
-# pad sequence to 2 x <world size> for sharding
-
-def zig_zag_pad_seq(t):
-    seq_len = t.shape[-2]
-    chunks = 2 * get_world_size()
-
-    padded_seq_len = ceil(seq_len / chunks) * chunks
-    t = F.pad(t, (0, 0, 0, padded_seq_len - seq_len), value = 0.)
-
-    def inverse(out):
-        return out[..., :seq_len, :]
-
-    return t, inverse
-
 # zig zag sharding and its inverse
 
 def zig_zag_shard(t, all_gather_batch = False):
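Per the commit message, the padding removed here was only needed when chunks are ring-passed between ranks; the deleted helper rounded the sequence length up to the next multiple of 2 × world_size so the sequence splits into equal-size chunks. A small worked example of that arithmetic (the world size of 3 is an assumed value, not taken from the tests):

from math import ceil

world_size = 3
seq_len = 100

# two chunks per rank, so the sequence must divide into 2 * world_size pieces
chunks = 2 * world_size
padded_seq_len = ceil(seq_len / chunks) * chunks

# 100 rounds up to 102, i.e. two extra padding positions
assert padded_seq_len == 102
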
