[torch.compile] Dynamic fp8 + rms_norm fusion #31
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Force-pushed from 81ad334 to 7d1adbf
Force-pushed from 7d1adbf to 651ebdc
      has_residual>(
          out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
    } else {
      // FP8 - Do not invert s_token_scale for exact match with FBGemm
`s_token_scale` -> `token_scale`
      ss += x * x;
    }

    using BlockReduce = cub::BlockReduce<float, 1024>;
The block dimension is defined as `dim3 block(std::min(hidden_size, 1024));`. Is it safe to use `cub::BlockReduce<float, 1024>` when `block.x` is less than 1024?
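For reference, CUB's documented contract ties the `BLOCK_THREADS` template parameter to the actual launch block size. Below is a minimal sketch of that conventional pattern; the kernel name, signature, and host launch are illustrative and not code from this PR.

// Minimal sketch: cub::BlockReduce with the template parameter equal to the
// launch block size (the configuration CUB documents as supported).
#include <cub/block/block_reduce.cuh>

template <int BLOCK_THREADS>
__global__ void sum_of_squares(const float* __restrict__ in, float* out, int n) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Each thread accumulates a strided partial sum of squares.
  float ss = 0.0f;
  for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) {
    float x = in[i];
    ss += x * x;
  }

  // CUB assumes all BLOCK_THREADS threads of the block participate here.
  float block_ss = BlockReduce(temp_storage).Sum(ss);
  if (threadIdx.x == 0) {
    *out = block_ss;
  }
}

// Host side: the launch block size matches the template parameter.
// sum_of_squares<1024><<<1, 1024>>>(d_in, d_out, n);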
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
NUM_TOKENS = [1, 7, 83, 2048, 4096]  # Arbitrary values for testing
HIDDEN_SIZES = [1, 2, 3, 4, 16, 64, 67, 768, 2048, 5120, 5137, 8192,
                8193]  # Arbitrary values for testing
We can probably reduce HIDDEN_SIZES to [1, 3, 4, 16, 64, 2048, 5120, 5137] plus the vectorization edge cases, to save test time.
@@ -22,6 +22,7 @@
 supports_moe_ops = False
 with contextlib.suppress(ImportError):
     import vllm._moe_C  # noqa: F401

nit: whitespace changes, here and below
Reviewed the kernel files and kernel tests. Left some minor comments. LGTM otherwise.
This PR cleans up the fusion pass to make it easier to add other multi-output patterns. Then it adds dynamic fp8 rmsnorm fusion.
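For context, here is a rough unfused reference of the pattern being fused: RMSNorm followed by dynamic (per-token) fp8 quantization. Function names, epsilon, and scale handling are illustrative assumptions and may differ from the vLLM kernels; it assumes a recent PyTorch build with float8_e4m3fn support.

# Unfused reference: two passes over memory that the fusion pass would
# replace with a single fused kernel. Illustrative only, not vLLM code.
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize each token by its root-mean-square, then apply the weight.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight

def dynamic_fp8_quant_ref(x: torch.Tensor):
    # Dynamic per-token scale: max |x| per row divided by the fp8 max value.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (x.abs().amax(dim=-1, keepdim=True) / fp8_max).clamp(min=1e-12)
    x_q = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scale

x = torch.randn(4, 64)
w = torch.ones(64)
y_q, s = dynamic_fp8_quant_ref(rms_norm_ref(x, w))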