Extremely slow fp8 conv2d wgrad operation #103
Comments
Hi jimgao1, I noticed that you are using the NHWC layout for both the dy and X tensors in the fp8 wgrad, which is a low-performance configuration (https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#supported-graph-patterns : "All tensors can be in either NHWC or CHWN layout. In general, both dy and X tensors provide the best performance when they are in a CHWN layout.").
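In cuDNN frontend, the layout is expressed through strides. The logical dims stay in {N, C, H, W} order and only the strides change. Below is a minimal sketch of how NHWC and CHWN strides can be derived, assuming fully packed tensors. `packed_strides` and `Dim` are hypothetical helpers and not part of the cuDNN frontend API.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Logical dimension indices, in the {N, C, H, W} order used by set_dim().
enum Dim : int { N = 0, C = 1, H = 2, W = 3 };

// Hypothetical helper: packed strides for a 4-D tensor whose logical dims are
// {N, C, H, W}, stored in memory in `order` (outermost to innermost).
// NHWC -> order = {N, H, W, C}; CHWN -> order = {C, H, W, N}.
std::array<int64_t, 4> packed_strides(const std::array<int64_t, 4>& dims,
                                      const std::array<Dim, 4>& order) {
    std::array<int64_t, 4> strides{};
    int64_t running = 1;
    // Walk from the innermost to the outermost dimension, accumulating
    // the number of elements spanned by one step of each dimension.
    for (int i = 3; i >= 0; --i) {
        assert(dims[order[i]] > 0);
        strides[order[i]] = running;
        running *= dims[order[i]];
    }
    return strides;
}
```

With dims {64, 64, 56, 56}, NHWC yields {200704, 1, 3584, 64} and CHWN yields {1, 200704, 3584, 64}. After copying them into a `std::vector<int64_t>`, you can pass these values to `.set_stride(...)`.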
Hi @yanqinz2, thanks for the suggestion! I tried switching the layout to CHWN, but the performance is still suboptimal. Below are my change and the profiling results for CHWN.
Code change:
auto X = graph->tensor(fe::graph::Tensor_attributes()
                           .set_name("image")
                           .set_dim({n, k, p, q})
                           // .set_stride({k * p * q, 1, q * k, k})  // previous NHWC strides
                           .set_stride({1, n * p * q, n * q, n})     // CHWN strides
                           .set_data_type(io_type));
auto DY = graph->tensor(fe::graph::Tensor_attributes()
                            .set_name("grad")
                            .set_dim({n, c, h, w})
                            // .set_stride({h * w * c, 1, w * c, c})  // previous NHWC strides
                            .set_stride({1, h * w * n, w * n, n})     // CHWN strides
                            .set_data_type(io_type));
Measurements were taken with n = 64, h = 56, w = 56, k = 64, c = 64, r = 3, s = 3, stride = 1, padding = 1.
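Converting the wall-clock numbers into achieved throughput makes the gap easier to judge. Below is a minimal sketch that assumes dilation 1 and the common convention of 2 FLOPs per multiply-accumulate. The count is the usual dense one, which also includes multiplications against zero padding. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Output spatial size of a 2-D convolution along one axis (dilation 1 assumed).
int64_t conv_out_dim(int64_t in, int64_t filter, int64_t stride, int64_t pad) {
    assert(stride > 0 && in + 2 * pad >= filter);
    return (in + 2 * pad - filter) / stride + 1;
}

// Dense wgrad FLOP count, 2 FLOPs per multiply-accumulate:
// dW[k][c][r][s] = sum over n, p, q of
//     dY[n][k][p][q] * X[n][c][p * stride + r - pad][q * stride + s - pad]
int64_t wgrad_flops(int64_t n, int64_t k, int64_t c, int64_t r, int64_t s,
                    int64_t p, int64_t q) {
    return 2 * n * k * c * r * s * p * q;
}

// Achieved throughput in TFLOP/s for a measured runtime in milliseconds.
double tflops(int64_t flops, double ms) {
    assert(ms > 0.0);
    return static_cast<double>(flops) / (ms * 1e-3) / 1e12;
}
```

For the configuration above, p = q = 56 and one wgrad call is about 1.48e10 FLOPs. A 1 ms run would therefore correspond to roughly 14.8 TFLOP/s.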
May I know whether this issue is related to this item in the cuDNN release notes?
It is not related to the item in the release notes. It is actually an issue with heuristics mode A, and we are actively working on it.
Describe the bug
FP8 (E4M3) wgrad appears to be extremely slow compared to both FP32 and FP16, often 50x to 100x slower.
I have attached the profiling results in this Google spreadsheet.
I have tested a variety of problem sizes. For each size, I measured fp16 wgrad and several fp8 wgrad variants that differ in their IO, intermediate, and compute data types.
Expected behavior
We expect fp8 wgrad operations to be at least as fast as (if not faster than) their fp16 and fp32 counterparts.
System Environment
API logs
Both frontend and backend logs are attached in this gist.
To Reproduce
Compile and run the benchmarking script.
The command I used to compile it is:
Additional context
This issue references this post on the NVIDIA forums.