[HGEMM] Release toy-hgemm library 0.1.0 (#146)
* refactor

* refactor

* refactor

* refactor

* refactor

* refactor

* refactor

* refactor

* refactor

* refactor

* refactor

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
DefTruth authored Nov 22, 2024
1 parent 6ea2eb9 commit 56e2fe9
Showing 197 changed files with 222 additions and 202 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/issue.yml
@@ -19,4 +19,4 @@ jobs:
close-issue-message: "This issue was closed because it has been inactive for 7 days since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1
repo-token: ${{ secrets.GITHUB_TOKEN }}
repo-token: ${{ secrets.GITHUB_TOKEN }}
2 changes: 2 additions & 0 deletions .gitignore
@@ -16,3 +16,5 @@ __pycache__
*.ncu*
*.sqlite*
*.engine
*.bin
outupt
1 change: 1 addition & 0 deletions .gitmodules
@@ -1,3 +1,4 @@
[submodule "third-party/cutlass"]
path = third-party/cutlass
url = https://github.com/NVIDIA/cutlass.git

2 changes: 1 addition & 1 deletion LICENSE
@@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
<https://www.gnu.org/licenses/why-not-lgpl.html>.
286 changes: 143 additions & 143 deletions README.md

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion hgemm/tools/clear.sh

This file was deleted.

File renamed without changes.
93 changes: 50 additions & 43 deletions hgemm/README.md → kernels/hgemm/README.md
@@ -1,56 +1,65 @@
# 🔥🔥Toy-HGEMM Library: Achieve the performance of cuBLAS
## 🔥🔥Toy-HGEMM Library: Achieve the performance of cuBLAS

|CUDA Cores|Sliced K(Loop over K)|Tile Block|Tile Thread|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|**WMMA(m16n16k16)**|**MMA(m16n8k16)**|**Pack LDST(128 bits)**|**SMEM Padding**|
|WMMA(m16n16k16)|MMA(m16n8k16)|Pack LDST(128 bits)|SMEM Padding|
|✔️|✔️|✔️|✔️|
|**Copy Async**|**Tile MMA(More Threads)**|**Tile Warp(More Values)**|**Multi Stages**|
|Copy Async|Tile MMA(More Threads)|Tile Warp(More Values)|Multi Stages|
|✔️|✔️|✔️|✔️|
|**Reg Double Buffers**|**Block Swizzle**|**Warp Swizzle**|**Collective Store(Reg Reuse&Warp Shfl)**|
|Reg Double Buffers|Block Swizzle|Warp Swizzle|Collective Store(Warp Shfl)|
|✔️|✔️|✔️|✔️|
|**Row Major(NN)**|**Col Major(TN)**|**SGEMM TF32**|**SMEM Swizzle(CuTe)**|
|Row Major(NN)|Col Major(TN)|SGEMM TF32|SMEM Swizzle(CuTe)|
|✔️|✔️|✔️|✔️|
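The Sliced K technique in the table loops over the K dimension in tiles, accumulating partial products at higher precision. A minimal NumPy sketch of the idea (illustrative only, not the CUDA implementation):

```python
import numpy as np

def sliced_k_hgemm(a, b, tile_k=16):
    """FP16 GEMM with K sliced into tiles and FP32 accumulation."""
    M, K = a.shape
    _, N = b.shape
    acc = np.zeros((M, N), dtype=np.float32)
    for k0 in range(0, K, tile_k):  # loop over K, one tile at a time
        acc += (a[:, k0:k0 + tile_k].astype(np.float32)
                @ b[k0:k0 + tile_k, :].astype(np.float32))
    return acc.astype(np.float16)   # store result back in half precision

a = np.random.randn(32, 64).astype(np.float16)
b = np.random.randn(64, 32).astype(np.float16)
c = sliced_k_hgemm(a, b)
```

On the GPU, each K tile is staged through shared memory; the FP32 accumulator here stands in for the Tensor Core accumulator registers.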

<details>
<summary> 🔑️ Click to view all supported HGEMM kernels! </summary>

- [X] hgemm_sliced_k_f16_kernel
- [X] hgemm_t_8x8_sliced_k_f16x4_kernel(unpack)
- [X] hgemm_t_8x8_sliced_k_f16x4_pack_kernel(pack 16x4)
- [X] hgemm_t_8x8_sliced_k_f16x4_bcf_kernel(bank conflicts reduce)
- [X] hgemm_t_8x8_sliced_k_f16x4_pack_bcf_kernel(bank conflicts reduce, pack)
- [X] hgemm_t_8x8_sliced_k_f16x8_pack_bcf_kernel(bank conflicts reduce, pack)
- [X] hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf_kernel(bank conflicts reduce, pack, double buffers)
- [X] hgemm_t_8x8_sliced_k16/32_f16x8_pack_bcf_dbuf_kernel(pack, double buffers)
- [X] hgemm_t_8x8_sliced_k16/32_f16x8_pack_bcf_dbuf_async_kernel(pack, double buffers, copy async)
- [X] hgemm_wmma_m16n16k16_naive(WMMA)
- [X] hgemm_wmma_m16n16k16_mma4x2(WMMA, Tile MMA)
- [X] hgemm_wmma_m16n16k16_mma4x2_warp2x4(WMMA, Tile MMA/Warp, pack)
- [X] hgemm_wmma_m16n16k16_mma4x2_warp2x4_async(WMMA, Tile MMA/Warp, Copy Async)
- [X] hgemm_wmma_m16n16k16_mma4x2_warp2x4_async_offset(WMMA, Tile MMA/Warp, Copy Async, Pad)
- [X] hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(WMMA, Tile MMA/Warp, Copy Async, Double Buffers, Pad)
- [X] hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(WMMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
- [X] hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages(WMMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
- [X] hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages(WMMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
- [X] hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(WMMA, Tile MMA/Warp, Copy Async, Double Buffers, Pad)
- [X] hgemm_mma_m16n8k16_naive(MMA)
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4(MMA, Tile MMA/Warp, pack)
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle, Warp swizzle, Reg Double Buffers, Collective Store with Reg Reuse & Warp Shuffle)
- [X] hgemm_mma_stages_block_swizzle_tn_cute(MMA, Tile MMA/Warp, Copy Async, Stages, Block Swizzle, SMEM Swizzle, Collective Store with SMEM)
- [X] PyTorch bindings

</details>
## 📖 HGEMM CUDA Kernels in Toy-HGEMM Library 🎉🎉

```C++
void hgemm_naive_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_pack(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x8_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_stages_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
```
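Every binding above shares one contract: caller-allocated fp16 tensors `a` (M×K), `b` (K×N), and `c` (M×N), with `c` written in place. A NumPy sketch of the semantics the kernels implement (the real bindings take CUDA `torch::Tensor` arguments; this reference is only illustrative):

```python
import numpy as np

def hgemm_reference(a, b, c):
    """What every hgemm_* binding computes: c = a @ b,
    fp16 inputs/outputs with fp32 accumulation."""
    c[:] = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)

M, N, K = 128, 128, 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
c = np.empty((M, N), dtype=np.float16)  # caller pre-allocates the output
hgemm_reference(a, b, c)
```

The `stages`/`swizzle`/`swizzle_stride` parameters on the staged variants tune pipeline depth and block scheduling and do not change this numerical contract.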
## Installation
The HGEMM CUDA kernels in this repo can be used as a Python library, toy-hgemm; install with the commands below. (Optional)
## 📖 Installation
The HGEMM kernels in this repo can be used as a Python library (optional):
```bash
git submodule update --init --recursive --force
bash tools/install.sh # uninstall with: pip uninstall toy-hgemm
git submodule update --init --recursive --force # update cutlass (required)
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # uninstall with: pip uninstall toy-hgemm -y
```

## Test Commands
## 📖 Testing

**CUTLASS**: update the CUTLASS dependency
```bash
@@ -114,7 +123,7 @@ M N K = 16128 16128 16128, Time = 0.07319142 0.07320709 0.07326925 s, A
M N K = 16384 16384 16384, Time = 0.07668429 0.07669371 0.07670784 s, AVG Performance = 114.6912 Tflops
```
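The AVG Performance column follows the standard GEMM FLOP count: 2·M·N·K floating-point operations divided by elapsed time. Reproducing the last row above:

```python
M = N = K = 16384
seconds = 0.07669371  # middle timing from the benchmark row above
# one multiply + one add per (m, n, k) triple => 2 * M * N * K FLOPs
tflops = 2 * M * N * K / seconds / 1e12
print(f"AVG Performance = {tflops:.4f} Tflops")  # ≈ 114.69 Tflops
```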

## Current Performance
## 📖 Current Performance

### NVIDIA L20

@@ -132,8 +141,6 @@ M N K = 16384 16384 16384, Time = 0.07668429 0.07669371 0.07670784 s, A

![NVIDIA_L20_NN+TN+v2](https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99)

- WMMA: Up to 113.76 TFLOPS, 113.83/119.5=95.25% TFLOPS utilization, 113.83/116.25=97.91% cuBLAS performance.
- MMA: Up to 115.12 TFLOPS, 115.12/119.5=96.33% TFLOPS utilization, 115.12/116.25=99.03% cuBLAS performance.

Full MNK sweep command (tip: benchmarking each MNK shape separately yields more accurate numbers)
```bash
@@ -166,7 +173,7 @@ python3 hgemm.py --wmma-all --plot
```


## Performance Optimization Notes
## 📖 Performance Optimization Notes

### PyTorch HGEMM Profile

File renamed without changes
2 changes: 1 addition & 1 deletion hgemm/makefile → kernels/hgemm/makefile
@@ -1,4 +1,4 @@
INCLUDE_DIRS=-I ../ -I ./utils -I ../third-party/cutlass/include -I ../third-party/cutlass/tools/util/include
INCLUDE_DIRS=-I ./utils -I ../../third-party/cutlass/include -I ../../third-party/cutlass/tools/util/include
default:
nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.bin -O2 -arch=sm_89 -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
nvcc cublas/hgemm_cublas.cu -o hgemm_cublas.bin -O2 -arch=sm_89 -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
File renamed without changes.
2 changes: 1 addition & 1 deletion hgemm/setup.py → kernels/hgemm/setup.py
@@ -10,7 +10,7 @@
)
from tools.utils import (get_build_sources, get_build_cuda_cflags)

# package name managed by pip, which can be remove by `pip uninstall tiny_pkg`
# package name managed by pip, which can be remove by `pip uninstall toy-hgemm`
PACKAGE_NAME = "toy-hgemm"

ext_modules = []
5 changes: 5 additions & 0 deletions kernels/hgemm/tools/clear.sh
@@ -0,0 +1,5 @@
set -x

rm -rf __pycache__ build dist toy_hgemm.egg-info *.bin

set +x
4 changes: 4 additions & 0 deletions hgemm/tools/install.sh → kernels/hgemm/tools/install.sh
@@ -1,4 +1,8 @@
set -x

git submodule update --init --recursive --force
python3 -m pip uninstall toy-hgemm -y
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl && cd -
rm -rf toy_hgemm.egg-info __pycache__

set +x
22 changes: 11 additions & 11 deletions hgemm/tools/utils.py → kernels/hgemm/tools/utils.py
@@ -31,8 +31,8 @@ def get_build_sources():


def get_project_dir():
return os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
return os.path.dirname(os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))


def get_build_cuda_cflags(build_pkg: bool = False):
@@ -68,22 +68,22 @@ def get_build_cuda_cflags(build_pkg: bool = False):
extra_cuda_cflags.append("-Xptxas -v")
else:
extra_cuda_cflags.append("--ptxas-options=-v")
extra_cuda_cflags.append("--ptxas-options=-O2")
extra_cuda_cflags.append("--ptxas-options=-O3")
# extra cuda flags for cute hgemm
project_dir = get_project_dir()
extra_cuda_cflags.append('-DNO_MMA_HGEMM_BIN')
extra_cuda_cflags.append('-DNO_WMMA_HGEMM_BIN')
extra_cuda_cflags.append('-DNO_CUTE_HGEMM_BIN')
extra_cuda_cflags.append('-DNO_CUBLAS_HGEMM_BIN')
# add cutlass headers and link cublas.
extra_cuda_cflags.append(f'-I {project_dir}')
extra_cuda_cflags.append(f'-I {project_dir}/utils')
extra_cuda_cflags.append(f'-I {project_dir}/naive')
extra_cuda_cflags.append(f'-I {project_dir}/wmma')
extra_cuda_cflags.append(f'-I {project_dir}/mma')
extra_cuda_cflags.append(f'-I {project_dir}/cutlass')
extra_cuda_cflags.append(f'-I {project_dir}/cublas')
extra_cuda_cflags.append(f'-I {project_dir}/pybind')
extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm')
extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/utils')
extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/naive')
extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/wmma')
extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/mma')
extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/cutlass')
extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/cublas')
extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/pybind')
extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/include')
extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/tools/util/include')
extra_cuda_cflags.append('-lcublas')
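The `get_project_dir` change above adds one more `os.path.dirname` call because `tools/utils.py` moved one directory deeper (`hgemm/tools` → `kernels/hgemm/tools`). A quick sketch of the path arithmetic, using a hypothetical helper:

```python
import os

def project_dir_from(path, depth):
    """Walk `depth` directory levels up from a file path."""
    d = os.path.abspath(path)
    for _ in range(depth):
        d = os.path.dirname(d)
    return d

# kernels/hgemm/tools/utils.py now sits four levels below the repo root,
# so four dirname() calls are needed instead of the previous three:
root = project_dir_from("/repo/kernels/hgemm/tools/utils.py", depth=4)
print(root)  # /repo
```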
File renamed without changes.
2 changes: 2 additions & 0 deletions third-party/.gitignore
@@ -7,4 +7,6 @@
build
*.whl
tmp
bin

