Optimal ordering with block mask #56
Comments
Yeah, this is a pretty fun idea :) I had previously played around with an idea like this using a
and so, this allows you to transform any existing

For 2D images, uwu (on Discord) suggested trying a Morton curve, which could be a good alternative, since it's cheap to "compute" :)
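For the 2D case, a Morton (Z-order) index is just the row and column bits interleaved, which is why it is cheap to compute. A minimal pure-Python sketch, assuming non-negative integer coordinates below 2**16; `part1by1` and `morton2d` are hypothetical helper names:

```python
def part1by1(v: int) -> int:
    """Spread the low 16 bits of v so that a zero bit sits between consecutive bits."""
    v &= 0x0000FFFF
    v = (v | (v << 8)) & 0x00FF00FF
    v = (v | (v << 4)) & 0x0F0F0F0F
    v = (v | (v << 2)) & 0x33333333
    v = (v | (v << 1)) & 0x55555555
    return v


def morton2d(row: int, col: int) -> int:
    """Z-order index: column bits land in even positions, row bits in odd positions."""
    return part1by1(col) | (part1by1(row) << 1)


# Sorting the pixels of a 4x4 image by Morton code gives the Z-order traversal.
H = W = 4
order = sorted(((r, c) for r in range(H) for c in range(W)), key=lambda rc: morton2d(*rc))
print(order[:4])  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

Each aligned 2x2 (then 4x4, 8x8, ...) square of the image ends up contiguous in the sequence, which is what makes it a cheap stand-in for a Hilbert curve.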
I think it is worth it if you can do the permutation once before a series of attention operations. That is pretty much the case in vision transformers with local windows. I also tried the Hilbert and Moore curves, but I haven't conducted a proper benchmark.
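Doing the permutation once around a stack of attention layers amounts to one gather before and one inverse gather after. A minimal sketch in plain Python, assuming a row-major token list and using a simple 2x2 patch ordering as a stand-in for a Hilbert or Moore curve; `patch_order` and `invert` are hypothetical helpers:

```python
def patch_order(height, width, patch):
    """Row-major token indices listed patch by patch (patches row-major, pixels row-major inside)."""
    perm = []
    for pr in range(0, height, patch):
        for pc in range(0, width, patch):
            for r in range(pr, min(pr + patch, height)):
                for c in range(pc, min(pc + patch, width)):
                    perm.append(r * width + c)
    return perm


def invert(perm):
    """inv[perm[i]] = i, so gathering with inv undoes gathering with perm."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv


tokens = [f"t{i}" for i in range(16)]  # a 4x4 image, row-major
perm = patch_order(4, 4, 2)
inv = invert(perm)

x = [tokens[p] for p in perm]          # permute once
# ... run all local-attention layers on x here ...
restored = [x[i] for i in inv]         # unpermute once at the end
assert restored == tokens
```

With tensors, the same two gathers would be indexing along the sequence dimension with `perm` and then with its inverse (for a permutation, `torch.argsort(perm)` is the inverse), so the cost is paid once per stack rather than once per layer.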
The issue isn't necessarily that permuting the tokens itself is expensive, but rather that after the permutation you need to load the permutation index into the "inner loop" of the attention, which offsets some of the sparsity gains you can get. This is why Morton curves were an interesting suggestion to me, since I think they're fairly cheaply computable "within" the kernel itself.
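The reason a Morton order can be computed "within" the kernel: on a 2^k x 2^k grid sorted by Morton code, a token's sequence position is its Morton code, so a mask function can recover (row, col) from the index with a few bit operations instead of loading a coordinate or permutation table. A pure-Python sketch of that decode; `compact1by1`, `morton_decode`, and `neighborhood_mask` are hypothetical helpers, and the square window of radius `radius` is an illustrative stand-in for a real neighborhood mask:

```python
def compact1by1(v: int) -> int:
    """Inverse of bit spreading: keep the even bits of v and pack them together."""
    v &= 0x55555555
    v = (v | (v >> 1)) & 0x33333333
    v = (v | (v >> 2)) & 0x0F0F0F0F
    v = (v | (v >> 4)) & 0x00FF00FF
    v = (v | (v >> 8)) & 0x0000FFFF
    return v


def morton_decode(idx: int):
    """(row, col) of the token at Morton position idx; column bits are even, row bits odd."""
    return compact1by1(idx >> 1), compact1by1(idx)


def neighborhood_mask(q_idx: int, kv_idx: int, radius: int = 1) -> bool:
    """Mask computed from the indices alone: no permutation or coordinate table is loaded."""
    qr, qc = morton_decode(q_idx)
    kr, kc = morton_decode(kv_idx)
    return abs(qr - kr) <= radius and abs(qc - kc) <= radius


# 4x4 grid stored in Morton order: position 3 is (1, 1), position 4 is (0, 2).
print(neighborhood_mask(0, 3), neighborhood_mask(0, 4))  # True False
```

In flex attention this logic would have to be written with tensor ops inside a `mask_mod(b, h, q_idx, kv_idx)`; whether it actually beats a gathered lookup in the compiled kernel is something to benchmark rather than assume.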
I don't think you need to load the permutation index if you compute the

Basically my idea was to find a permutation to minimize the number of (non-empty) blocks in a
If you can guarantee that all of your non-empty blocks are "full" (i.e. not masked at all), then you don't need to load the permutation index for those blocks. However, for the partially-masked blocks, you still need to load the permutation index to compute the mask for those blocks. For example, this is NATTEN with a Hilbert curve.
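The distinction between full and partially-masked blocks can be made concrete by tiling a dense boolean mask into blocks and classifying each one. Only the partial blocks need the mask evaluated element by element (and so, in a permuted layout, need whatever per-token data the mask depends on); if I read the API correctly, PyTorch's `BlockMask` likewise keeps `full_kv_num_blocks`/`full_kv_indices` separate from the partial blocks. A minimal pure-Python sketch; `classify_blocks` is a hypothetical helper, and the tiny block size is only for illustration:

```python
from collections import Counter


def classify_blocks(mask, block):
    """Label each (q_block, kv_block) tile of a dense mask as 'empty', 'full', or 'partial'."""
    n_q, n_kv = len(mask), len(mask[0])
    labels = {}
    for qb in range(0, n_q, block):
        for kb in range(0, n_kv, block):
            vals = [mask[q][k]
                    for q in range(qb, min(qb + block, n_q))
                    for k in range(kb, min(kb + block, n_kv))]
            if not any(vals):
                labels[qb // block, kb // block] = "empty"
            elif all(vals):
                labels[qb // block, kb // block] = "full"
            else:
                labels[qb // block, kb // block] = "partial"
    return labels


# 1D sliding window of radius 2 over 8 tokens, block size 2.
n, radius, block = 8, 2, 2
mask = [[abs(q - k) <= radius for k in range(n)] for q in range(n)]
labels = classify_blocks(mask, block)
counts = Counter(labels.values())  # 4 full, 6 partial, 6 empty
```

Here the diagonal blocks are full, the blocks next to the diagonal are partial, and everything else is skipped entirely.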
I don't see where the permutation indices appear anymore after the
Yes, that's what I mean. You must load from your column indices (which represent a permutation) in your inner loop.
I think we are speaking of the same thing in different terms, but I don't see how the column indices represent permutations. They are ordered (which allows faster access than random indexing) and target a subset of the full block. I agree that the subset is determined by the original permutation, but the indexing operation does not involve a permutation. Anyway, your NATTEN + Hilbert curve seems much more efficient than NATTEN alone! Do you still have the code to generate the permutation? I used a random Python library previously.
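For reference, the block-sparse layout under discussion stores, for each query block, how many kv blocks are non-empty and their column indices, in the spirit of CSR. A pure-Python sketch that derives sorted column lists from a dense mask; `block_csr` is a hypothetical helper whose outputs only loosely mirror the `kv_num_blocks`/`kv_indices` idea, not PyTorch's exact tensors:

```python
def block_csr(mask, block):
    """Per query block: count and sorted indices of kv blocks with at least one unmasked entry."""
    n_q, n_kv = len(mask), len(mask[0])
    num_blocks, indices = [], []
    for qb in range(0, n_q, block):
        cols = [kb // block
                for kb in range(0, n_kv, block)
                if any(mask[q][k]
                       for q in range(qb, min(qb + block, n_q))
                       for k in range(kb, min(kb + block, n_kv)))]
        num_blocks.append(len(cols))
        indices.append(cols)
    return num_blocks, indices


# 1D sliding window of radius 2 over 8 tokens, block size 2.
n, radius, block = 8, 2, 2
mask = [[abs(q - k) <= radius for k in range(n)] for q in range(n)]
num_blocks, indices = block_csr(mask, block)
# An attention kernel's inner loop for query block i walks indices[i][:num_blocks[i]].
```

The column indices come out sorted because they are produced by a scan over columns; which columns appear, though, depends entirely on how the tokens were ordered before the mask was built.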
As long as the positions of your latents aren't changing, there shouldn't be any case where you need to apply a permutation at score/mask mod time, unless you are relying on some non-permutation-equivariant function of q_idx and kv_idx. Then:
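There is one case where you would need to permute/unpermute

The Morton code in PyTorch (for 3D, but you can also use it for 2D by passing in 3D coords with the last dim zeroed), if anyone finds this helpful. Sorting by this is decent, but it probably won't give as good an ordering as a Hilbert curve:

```python
import torch


def quantize_coords(coords: torch.Tensor, bits: int = 21) -> torch.Tensor:
    # Map coordinates (expected to be normalized to [0, 1]) onto integers in [0, 2**bits - 1].
    max_int = (1 << bits) - 1
    coords = coords.clamp(0, 1)
    return (coords * max_int).long()


def split_by_3_bits_21(x: torch.Tensor) -> torch.Tensor:
    # Spread the low 21 bits of each int64 so that two zero bits separate consecutive bits.
    x = (x | (x << 32)) & 0x1f00000000ffff
    x = (x | (x << 16)) & 0x1f0000ff0000ff
    x = (x | (x << 8)) & 0x100f00f00f00f00f
    x = (x | (x << 4)) & 0x10c30c30c30c30c3
    x = (x | (x << 2)) & 0x1249249249249249
    return x


@torch.compile
def morton_encode(coords: torch.Tensor, bits: int = 21) -> torch.Tensor:
    # coords: (..., 3) float tensor in [0, 1]; returns a (...,) int64 Morton code.
    # bits must be <= 21, since split_by_3_bits_21 packs 3 x 21 bits into a 63-bit code.
    # To reorder tokens: order = torch.argsort(morton_encode(coords)).
    coords = quantize_coords(coords, bits)
    x = split_by_3_bits_21(coords[..., 0])
    y = split_by_3_bits_21(coords[..., 1]) << 1
    z = split_by_3_bits_21(coords[..., 2]) << 2
    morton_code = x | y | z
    return morton_code
```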
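To illustrate the earlier point that a mask defined on the latents' positions needs no permutation at mask time: if the coordinates are sorted together with the tokens, evaluating the same distance rule on the sorted coordinates gives exactly the permuted mask. A pure-Python sketch; the random 2D points, the `radius` threshold, `mask_from_coords`, and the coarse grid-cell sort key (standing in for a Morton or Hilbert key) are all illustrative assumptions:

```python
import random

random.seed(0)
points = [(random.random(), random.random()) for _ in range(32)]
radius = 0.25


def mask_from_coords(coords):
    """Dense mask: True where two points are within `radius` (Chebyshev distance)."""
    return [[max(abs(a[0] - b[0]), abs(a[1] - b[1])) <= radius for b in coords] for a in coords]


# Any reordering works; a coarse 4x4 grid-cell key stands in for a space-filling curve here.
order = sorted(range(len(points)), key=lambda i: (int(points[i][1] * 4), int(points[i][0] * 4)))
sorted_points = [points[i] for i in order]

original = mask_from_coords(points)
permuted = [[original[i][j] for j in order] for i in order]  # rows and columns reordered
assert mask_from_coords(sorted_points) == permuted
```

In flex attention terms, the mask mod would index the already-sorted coordinate tensor with `q_idx` and `kv_idx` directly; the permutation appears only once, when tokens and coordinates are sorted together.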
From my understanding, flex attention (using `block_mask`) gets faster when the number of empty blocks is larger. If the inputs (Q, K, V) do not represent sequences but graphs with local connectivity (e.g. pixels in an image), the ordering of the elements has a huge impact on the number of empty blocks. It would be very useful to add helpers to find optimal, or simply better, orderings given a mask. For example, for images, it is likely better to order the pixels by small patch (close to the attention window size) rather than in the standard row-by-row order.
Note that this is related to the minimum degree algorithm.
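The effect of ordering on block sparsity can be measured directly by counting the non-empty blocks of a local-window mask under two pixel orderings. A minimal pure-Python sketch, assuming a 32x32 image, a square (Chebyshev) window clipped at the borders (not NATTEN's exact border handling), and a block size of 16 so that one block holds one 4x4 patch; `row_major`, `patch_major`, and `count_nonempty_blocks` are hypothetical helpers:

```python
def row_major(h, w):
    """Sequence position of each pixel (r, c) in the usual row-by-row order."""
    return {(r, c): r * w + c for r in range(h) for c in range(w)}


def patch_major(h, w, p):
    """Sequence position when pixels are grouped into p x p patches (h and w divisible by p)."""
    pos, i = {}, 0
    for pr in range(0, h, p):
        for pc in range(0, w, p):
            for r in range(pr, pr + p):
                for c in range(pc, pc + p):
                    pos[(r, c)] = i
                    i += 1
    return pos


def count_nonempty_blocks(pos, h, w, radius, block):
    """Number of (query block, kv block) pairs containing at least one unmasked pair."""
    blocks = set()
    for (r, c), q in pos.items():
        for dr in range(-radius, radius + 1):
            for dc in range(-radius, radius + 1):
                rr, cc = r + dr, c + dc
                if 0 <= rr < h and 0 <= cc < w:
                    blocks.add((q // block, pos[(rr, cc)] // block))
    return len(blocks)


H = W = 32
print(count_nonempty_blocks(row_major(H, W), H, W, radius=2, block=16))       # 616
print(count_nonempty_blocks(patch_major(H, W, 4), H, W, radius=2, block=16))  # 484
```

Out of 64 x 64 = 4096 blocks, the patch ordering touches 484 versus 616 for row-major with a radius-2 window. With a radius-1 window on the same image the comparison flips (376 for row-major versus 484 for patches), which is consistent with choosing the patch size relative to the window size.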