diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..12631cbc27 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "candle-examples/examples/flash-attn/cutlass"] + path = candle-flash-attn/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/Cargo.toml b/Cargo.toml index 863a5ebcc4..6b08af83f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "candle-wasm-examples/whisper", ] exclude = [ + "candle-flash-attn", "candle-kernels", ] diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 787ea63ac3..3dbae7fc54 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -65,7 +65,7 @@ pub use dtype::{DType, IntDType, WithDType}; pub use error::{Error, Result}; pub use indexer::IndexOp; pub use layout::Layout; -pub use op::CustomOp1; +pub use op::{CustomOp1, CustomOp2, CustomOp3}; pub use shape::{Shape, D}; pub use storage::Storage; pub use strided_index::{StridedBlocks, StridedIndex}; diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 8c152dba4b..4a989760f8 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -14,6 +14,7 @@ readme = "README.md" candle = { path = "../candle-core" } candle-nn = { path = "../candle-nn" } candle-transformers = { path = "../candle-transformers" } +candle-flash-attn = { path = "../candle-flash-attn", optional = true } serde = { workspace = true } serde_json = { workspace = true } num-traits = { workspace = true } @@ -37,4 +38,5 @@ anyhow = { workspace = true } [features] default = [] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] diff --git a/candle-examples/build.rs b/candle-examples/build.rs index 119e5b0a50..e21f1767a8 100644 --- a/candle-examples/build.rs +++ b/candle-examples/build.rs @@ -6,11 +6,13 @@ use std::path::PathBuf; struct KernelDirectories { kernel_dir: &'static str, rust_target: &'static str, + include_dirs: &'static [&'static str], } const DIRS: [KernelDirectories; 1] = [KernelDirectories { kernel_dir: "examples/custom-ops/kernels/", rust_target: "examples/custom-ops/cuda_kernels.rs", + include_dirs: &[], }]; impl KernelDirectories { @@ -32,12 +34,15 @@ impl KernelDirectories { { let mut command = std::process::Command::new("nvcc"); let out_dir = ptx_file.parent().context("no parent for ptx file")?; + let include_dirs: Vec = + self.include_dirs.iter().map(|c| format!("-I{c}")).collect(); command .arg(format!("--gpu-architecture=sm_{compute_cap}")) .arg("--ptx") .args(["--default-stream", "per-thread"]) .args(["--output-directory", out_dir.to_str().unwrap()]) .arg(format!("-I/{}", self.kernel_dir)) + .args(include_dirs) .arg(cu_file); let output = command .spawn() @@ -221,6 +226,7 @@ fn compute_cap() -> Result { } println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { compute_cap = compute_cap_str .parse::() diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index ea3ee4d7f0..62bff0ae74 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -116,6 +116,9 @@ struct Args { #[arg(long)] v2: bool, + + #[arg(long)] + use_flash_attn: bool, } fn main() -> Result<()> { @@ -124,7 +127,7 @@ fn main() -> Result<()> { let args = Args::parse(); let device = candle_examples::device(args.cpu)?; - let config = Config::config_7b(); + let config = Config::config_7b(args.use_flash_attn); let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?; let dtype = if args.use_f32 { DType::F32 } else { DType::F16 }; let (llama, tokenizer_filename) = match args.npy { diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs index b074e5cb9e..f82d0ac859 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-examples/examples/llama/model.rs @@ -13,10 +13,11 @@ pub struct Config { pub n_head: usize, pub n_embd: usize, pub n_key_value_head: usize, + pub use_flash_attn: bool, } impl Config { - pub fn config_7b() -> Self { + pub fn config_7b(use_flash_attn: bool) -> Self { Self { hidden_size: 4096, intermediate_size: 11008, @@ -25,6 +26,7 @@ impl Config { n_head: 32, n_embd: 4096, n_key_value_head: 32, + use_flash_attn, } } } @@ -140,6 +142,17 @@ struct CausalSelfAttention { n_key_value_head: usize, head_dim: usize, cache: Cache, + use_flash_attn: bool, +} + +#[cfg(feature = "flash-attn")] +fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { + q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result { + unimplemented!("compile with '--features flash-attn'") } impl CausalSelfAttention { @@ -202,12 +215,17 @@ impl CausalSelfAttention { let k = self.repeat_kv(k)?; let v = self.repeat_kv(v)?; - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = att.softmax(D::Minus1)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; + + let y = if self.use_flash_attn { + flash_attn(&q, &k, &v)? + } else { + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = att.softmax(D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?)? + }; let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; let y = y.to_dtype(x_dtype)?; let y = self.o_proj.forward(&y)?; @@ -245,6 +263,7 @@ impl CausalSelfAttention { n_key_value_head: cfg.n_key_value_head, head_dim: cfg.hidden_size / cfg.n_head, cache: cache.clone(), + use_flash_attn: cfg.use_flash_attn, }) } } diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml new file mode 100644 index 0000000000..25201a0eb3 --- /dev/null +++ b/candle-flash-attn/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "candle-flash-attn" +version = "0.1.0" +edition = "2021" + +description = "Flash attention layer for the candle ML framework." +repository = "https://github.com/LaurentMazare/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT/Apache-2.0" +readme = "README.md" + +[dependencies] +candle = { path = "../candle-core", features = ["cuda"] } +half = { version = "2.3.1", features = ["num-traits"] } + +[build-dependencies] +anyhow = { version = "1", features = ["backtrace"] } diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs new file mode 100644 index 0000000000..dc5b82e84a --- /dev/null +++ b/candle-flash-attn/build.rs @@ -0,0 +1,182 @@ +#![allow(unused)] +use anyhow::{Context, Result}; +use std::io::Write; +use std::path::PathBuf; + +fn main() -> Result<()> { + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim32_fp16_sm80.cu"); + println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h"); + println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h"); + println!("cargo:rerun-if-changed=kernels/flash.h"); + println!("cargo:rerun-if-changed=kernels/philox.cuh"); + println!("cargo:rerun-if-changed=kernels/softmax.h"); + println!("cargo:rerun-if-changed=kernels/utils.h"); + println!("cargo:rerun-if-changed=kernels/kernel_traits.h"); + println!("cargo:rerun-if-changed=kernels/block_info.h"); + println!("cargo:rerun-if-changed=kernels/static_switch.h"); + + let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?; + let mut out_dir = PathBuf::from(out_dir); + // TODO: Getting up two levels avoid having to recompile this too often, however it's likely + // not a safe assumption. + out_dir.pop(); + out_dir.pop(); + set_cuda_include_dir()?; + let compute_cap = compute_cap()?; + + let mut command = std::process::Command::new("nvcc"); + let out_file = out_dir.join("libflashattention.a"); + + let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu"); + let should_compile = if out_file.exists() { + let out_modified = out_file.metadata()?.modified()?; + let in_modified = cu_file.metadata()?.modified()?; + in_modified.duration_since(out_modified).is_ok() + } else { + true + }; + if should_compile { + command + .arg(format!("--gpu-architecture=sm_{compute_cap}")) + .arg("--lib") + .args(["-o", out_file.to_str().unwrap()]) + .args(["--default-stream", "per-thread"]) + .arg("-Icutlass/include") + .arg("--expt-relaxed-constexpr") + .arg(cu_file); + let output = command + .spawn() + .context("failed spawning nvcc")? + .wait_with_output()?; + if !output.status.success() { + anyhow::bail!( + "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ) + } + } + println!("cargo:rustc-link-search={}", out_dir.display()); + println!("cargo:rustc-link-lib=flashattention"); + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=stdc++"); + + /* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never + finishing to run for some reason. Calling nvcc manually worked fine. + cc::Build::new() + .cuda(true) + .include("cutlass/include") + .flag("--expt-relaxed-constexpr") + .flag("--default-stream") + .flag("per-thread") + .flag(&format!("--gpu-architecture=sm_{compute_cap}")) + .file("kernels/flash_fwd_hdim32_fp16_sm80.cu") + .compile("flashattn"); + */ + Ok(()) +} + +fn set_cuda_include_dir() -> Result<()> { + // NOTE: copied from cudarc build.rs. + let env_vars = [ + "CUDA_PATH", + "CUDA_ROOT", + "CUDA_TOOLKIT_ROOT_DIR", + "CUDNN_LIB", + ]; + let env_vars = env_vars + .into_iter() + .map(std::env::var) + .filter_map(Result::ok) + .map(Into::::into); + + let roots = [ + "/usr", + "/usr/local/cuda", + "/opt/cuda", + "/usr/lib/cuda", + "C:/Program Files/NVIDIA GPU Computing Toolkit", + "C:/CUDA", + ]; + let roots = roots.into_iter().map(Into::::into); + let root = env_vars + .chain(roots) + .find(|path| path.join("include").join("cuda.h").is_file()) + .context("cannot find include/cuda.h")?; + println!( + "cargo:rustc-env=CUDA_INCLUDE_DIR={}", + root.join("include").display() + ); + Ok(()) +} + +#[allow(unused)] +fn compute_cap() -> Result { + // Grab compute code from nvidia-smi + let mut compute_cap = { + let out = std::process::Command::new("nvidia-smi") + .arg("--query-gpu=compute_cap") + .arg("--format=csv") + .output() + .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; + let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; + let mut lines = out.lines(); + assert_eq!( + lines.next().context("missing line in stdout")?, + "compute_cap" + ); + let cap = lines + .next() + .context("missing line in stdout")? + .replace('.', ""); + cap.parse::() + .with_context(|| format!("cannot parse as int {cap}"))? + }; + + // Grab available GPU codes from nvcc and select the highest one + let max_nvcc_code = { + let out = std::process::Command::new("nvcc") + .arg("--list-gpu-code") + .output() + .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); + let out = std::str::from_utf8(&out.stdout).unwrap(); + + let out = out.lines().collect::>(); + let mut codes = Vec::with_capacity(out.len()); + for code in out { + let code = code.split('_').collect::>(); + if !code.is_empty() && code.contains(&"sm") { + if let Ok(num) = code[1].parse::() { + codes.push(num); + } + } + } + codes.sort(); + if !codes.contains(&compute_cap) { + anyhow::bail!( + "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}." + ); + } + *codes.last().unwrap() + }; + + // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc, + // then choose the highest gpu code in nvcc + if compute_cap > max_nvcc_code { + println!( + "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}." + ); + compute_cap = max_nvcc_code; + } + + println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { + compute_cap = compute_cap_str + .parse::() + .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?; + println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP"); + } + println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}"); + Ok(compute_cap) +} diff --git a/candle-flash-attn/cutlass b/candle-flash-attn/cutlass new file mode 160000 index 0000000000..c4f6b8c6bc --- /dev/null +++ b/candle-flash-attn/cutlass @@ -0,0 +1 @@ +Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933 diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h new file mode 100644 index 0000000000..94251a41e4 --- /dev/null +++ b/candle-flash-attn/kernels/block_info.h @@ -0,0 +1,41 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) + { + } + + template + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const uint32_t actual_seqlen_q; + const uint32_t actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h new file mode 100644 index 0000000000..be4ae0ca8f --- /dev/null +++ b/candle-flash-attn/kernels/flash.h @@ -0,0 +1,141 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +// #ifdef OLD_GENERATOR_PATH +// #include +// #else +// #include +// #endif +// +// #include + + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + int *__restrict__ blockmask; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Random state. + // at::PhiloxCudaState philox_args; + + bool is_bf16; + bool is_causal; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); + +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000..1e534d672d --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,92 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// // For dropout there might be a lot of register spilling? +// // These two are very slow due to register spilling +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // This one is slightly slower +// // run_flash_fwd>(params, stream); +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + + +extern "C" void run_mha( + void *q_ptr, + void *k_ptr, + void *v_ptr, + void *o_ptr, + + uint32_t q_batch_stride, + uint32_t k_batch_stride, + uint32_t v_batch_stride, + uint32_t q_row_stride, + uint32_t k_row_stride, + uint32_t v_row_stride, + uint32_t q_head_stride, + uint32_t k_head_stride, + uint32_t v_head_stride, + + uint32_t b, + uint32_t h, + uint32_t h_k, + uint32_t d, + uint32_t d_rounded, + float softmax_scale, + + uint32_t seqlen_q, + uint32_t seqlen_k, + uint32_t seqlen_q_rounded, + uint32_t seqlen_k_rounded, + + int is_causal +) { + Flash_fwd_params params; + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + // Set the pointers and strides. + params.q_ptr = q_ptr; + params.k_ptr = k_ptr; + params.v_ptr = v_ptr; + // All stride are in elements, not bytes. + params.q_row_stride = q_row_stride; + params.k_row_stride = k_row_stride; + params.v_row_stride = v_row_stride; + params.q_head_stride = q_head_stride; + params.k_head_stride = k_head_stride; + params.v_head_stride = v_head_stride; + params.o_ptr = o_ptr; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + params.is_causal = is_causal; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + cudaStream_t stream = 0; // Use the default stream. + run_mha_fwd_(params, stream); +} diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h new file mode 100644 index 0000000000..232dea0dbc --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_kernel.h @@ -0,0 +1,579 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "philox.cuh" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(Layout, Int>, + Stride<_1, Int> >{}, + make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(Layout, Int>, + Stride<_1, Int> >{}, + // TODO: Shouldn't this be size<1>? + make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, + Tensor2 &acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + Tensor scores_max_prev = make_fragment_like(scores_max); + copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); + #pragma unroll + for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_thr_copy_P +) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + Layout l = tOrP.layout(); + Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); + // TODO(laurent): reactivate the following + // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); + #pragma unroll + for (int mi = 0; mi < size<1>(tPrP); ++mi) { + copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal) { + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); + auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // Copy rmem to smem + // // copy(tQrQ, tQsQ); + // flash::cp_async_wait<0>(); + // __syncthreads(); + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + } + + // auto seeds = at::cuda::philox::unpack(params.philox_args); + // unsigned long long seed = std::get<0>(seeds); + // unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + unsigned long long seed = 0; + unsigned long long offset = 0; + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1; + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal) { + if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + } else { + // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) + // static_assert(decltype(size<0>(taccScS))::value == 4); + // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices. + // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout())); + // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM); + // Idk why it's get<1> and not get<0> of the stride. + // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } + // I can't get the stride from idx_row + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + kNWarps * 16); + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + uint32_t block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + copy(tOrP, tOrP_copy); + flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); + } + if (Is_dropout) { + flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + // if (cute::thread0()) { print(tOrP); } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= 0) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= 0; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + uint32_t block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + copy(tOrP, tOrP_copy); + flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); + } + if (Is_dropout) { + flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor lse = make_fragment_like(scores_sum); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + + // if (cute::thread0()) { print(acc_o_rowcol); } + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + copy(smem_thr_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + copy(gmem_thr_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h new file mode 100644 index 0000000000..398ce0779d --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -0,0 +1,251 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +// #include + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { + flash::compute_attn(params); +} + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check + // for cu_seqlens_q as well. + const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // if (smem_size >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + // C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 32; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 64; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 96; + // auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 128; + // auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 160; + // auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 192; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 224; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h new file mode 100644 index 0000000000..3468e4bffc --- /dev/null +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -0,0 +1,366 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKtransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + static_assert(kBlockN >= 64); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = 64; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutAtomPdStransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutAtomQdOtransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); + static constexpr int kSmemdPsumCount = kBlockM; + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + static constexpr int kSmemSize = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + + kSmemdSSize + kSmemPSize; + + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype( + make_tiled_copy(Copy_Atom{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/candle-flash-attn/kernels/philox.cuh b/candle-flash-attn/kernels/philox.cuh new file mode 100644 index 0000000000..6ce1440f28 --- /dev/null +++ b/candle-flash-attn/kernels/philox.cuh @@ -0,0 +1,165 @@ +// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h +#pragma once +// Philox CUDA. + +namespace flash { + +struct ull2 { + unsigned long long x; + unsigned long long y; +}; + +inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; +} + +inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { + constexpr unsigned long kPhiloxSA = 0xD2511F53; + constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; +} + +inline __device__ uint4 philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + constexpr unsigned long kPhilox10A = 0x9E3779B9; + constexpr unsigned long kPhilox10B = 0xBB67AE85; + uint2 key = reinterpret_cast(seed); + uint4 counter; + ull2 *tmp = reinterpret_cast(&counter); + tmp->x = offset; + tmp->y = subsequence; + #pragma unroll + for (int i = 0; i < 6; i++) { + counter = philox_single_round(counter, key); + key.x += (kPhilox10A); + key.y += (kPhilox10B); + } + uint4 output = philox_single_round(counter, key); + return output; +} + +} // namespace flash + +namespace { + +class Philox { +public: + __device__ inline Philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) + : STATE(0) + , seed_(seed) + , offset_(offset) + , key(reinterpret_cast(seed)) { + //key.x = (unsigned int)seed; + //key.y = (unsigned int)(seed >> 32); + //counter = make_uint4(0, 0, 0, 0); + //counter.z = (unsigned int)(subsequence); + //counter.w = (unsigned int)(subsequence >> 32); + //STATE = 0; + //incr_n(offset / 4); + + // key = reinterpret_cast(seed); + ull2 * tmp = reinterpret_cast(&counter); + tmp->x = offset / 4; + tmp->y = subsequence; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); + // } + } + __device__ inline uint4 operator()() { + // // if (STATE == 0) { + // uint4 counter_ = counter; + // uint2 key_ = key; + // // 7-round philox + // #pragma unroll + // for (int i = 0; i < 6; i++) { + // counter_ = flash::philox_single_round(counter_, key_); + // key_.x += (kPhilox10A); + // key_.y += (kPhilox10B); + // } + // // output = philox_single_round(counter_, key_); + // uint4 output = flash::philox_single_round(counter_, key_); + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); + // // } + // incr(); + // // } + // // return a float4 directly + // // unsigned long ret; + // // switch(STATE) { + // // case 0: ret = output.x; break; + // // case 1: ret = output.y; break; + // // case 2: ret = output.z; break; + // // case 3: ret = output.w; break; + // //} + // // STATE = (STATE + 1) % 4; + // return output; + return flash::philox(seed_, offset_, offset_); + } + +private: + unsigned long long offset_, seed_; + struct ull2 { + uint64_t x; + uint64_t y; + }; + uint4 counter; + // uint4 output; + const uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + + __device__ uint4 incr128 (uint4 ctr) + { + uint4 res; + asm ("add.cc.u32 %0, %4, %8;\n\t" + "addc.cc.u32 %1, %5, %9;\n\t" + "addc.cc.u32 %2, %6, %10;\n\t" + "addc.u32 %3, %7, %11;\n\t" + : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) + : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), + "n"(1), "n"(0), "n"(0), "n"(0)); + return res; + } + + __device__ inline void incr() { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + counter = incr128(counter); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + // static const unsigned long kPhiloxSA = 0xD2511F53; + // static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +} // namespace diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h new file mode 100644 index 0000000000..3e9a7b4597 --- /dev/null +++ b/candle-flash-attn/kernels/softmax.h @@ -0,0 +1,272 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include +#include + +#include "philox.cuh" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. +template +inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor &tensor, const uint32_t max_seqlen_k) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const uint32_t lane_id = threadIdx.x % 32; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, + const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, + const uint32_t warp_row_stride) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const uint32_t lane_id = threadIdx.x % 32; + // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; + const uint32_t row_idx_offset = row_idx_offset_; + const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const uint32_t row_idx = row_idx_base + i * 8; + const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const uint32_t col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const uint32_t col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor &tensor, Tensor const &idx_rowcol, + const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) +{ + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); + #pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, + unsigned long long seed, unsigned long long offset, + uint32_t block_row_start, uint32_t block_col_start, + uint32_t block_row_stride) { + // tensor has shape (8, MMA_M, MMA_N / 2) + using T = typename Engine::value_type; + auto encode_dropout = [](bool keep, T val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); + }; + static_assert(decltype(size<2>(tensor))::value % 2 == 0); + const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); + const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); + // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } + #pragma unroll + for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { + uint2 rowcol = make_uint2(block_row_start, block_col_start); + #pragma unroll + for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { + // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} + uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); + // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} + uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); + // Special implementation for 16-bit types: we duplicate the threshold to the + // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction + // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, + // and the high 16 bits will be either 0xffff or 0x0000, depending on whether + // the random value is less than the threshold. + // We then do a bit-wise AND between the mask and the original value (in 32-bit). + // We're exploiting the fact that floating point comparison is equivalent to integer + // comparison, since we're comparing unsigned integers whose top 8-bits are zero. + if (!encode_dropout_in_sign_bit + && (std::is_same::value || std::is_same::value)) { + uint16_t rnd_16[16]; + #pragma unroll + for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } + uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); + #pragma unroll + for (int j = 0; j < 2; j++) { + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + #pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t mask; + asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); + tensor_uint32(i) &= mask; + } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } else { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); + } + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); + // // } + } + } +} + +} // namespace flash diff --git a/candle-flash-attn/kernels/static_switch.h b/candle-flash-attn/kernels/static_switch.h new file mode 100644 index 0000000000..4aa8474028 --- /dev/null +++ b/candle-flash-attn/kernels/static_switch.h @@ -0,0 +1,66 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h new file mode 100644 index 0000000000..2221a2faf3 --- /dev/null +++ b/candle-flash-attn/kernels/utils.h @@ -0,0 +1,388 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t relu2(const uint32_t x); + +template<> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +inline __device__ uint32_t convert_relu2(const float2 x); + +template<> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template<> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float2 half2_unpack(uint32_t a); + +template <> +inline __device__ float2 half2_unpack<__half>(uint32_t a) { + return __half22float2(reinterpret_cast<__half2 (&)>(a)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { + return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert two half2's or bf162's into float, then take their dot product. +template +inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { + float2 af = flash::half2_unpack(a); + float2 bf = flash::half2_unpack(b); + return af.x * bf.x + af.y * bf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float hmulsum8(const uint4 a, const uint4 b) { + float sum; + sum = flash::hfma2_to_float(a.x, b.x); + sum += flash::hfma2_to_float(a.y, b.y); + sum += flash::hfma2_to_float(a.z, b.z); + sum += flash::hfma2_to_float(a.w, b.w); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + get<0, 1>(l), + get<1, 1, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void relu_(Tensor &tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); + #pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +inline __device__ auto convert_type_relu(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); + #pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy(TiledCopy thr_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + copy(thr_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. + // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(thr_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(thr_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs new file mode 100644 index 0000000000..e2c1663b38 --- /dev/null +++ b/candle-flash-attn/src/ffi.rs @@ -0,0 +1,35 @@ +use core::ffi::{c_int, c_void}; + +extern "C" { + pub(crate) fn run_mha( + q_ptr: *const c_void, + k_ptr: *const c_void, + v_ptr: *const c_void, + o_ptr: *const c_void, + + q_batch_stride: u32, + k_batch_stride: u32, + v_batch_stride: u32, + q_row_stride: u32, + k_row_stride: u32, + v_row_stride: u32, + q_head_stride: u32, + k_head_stride: u32, + v_head_stride: u32, + + b: u32, + h: u32, + h_k: u32, + d: u32, + d_rounded: u32, + softmax_scale: f32, + + seqlen_q: u32, + seqlen_k: u32, + seqlen_q_rounded: u32, + seqlen_k_rounded: u32, + + is_causal: c_int, + ); + +} diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs new file mode 100644 index 0000000000..989e19051f --- /dev/null +++ b/candle-flash-attn/src/lib.rs @@ -0,0 +1,59 @@ +mod ffi; + +use candle::backend::BackendStorage; +use candle::cuda_backend::cudarc::driver::DevicePtr; +use candle::cuda_backend::WrapErr; +use candle::{CpuStorage, Error, Layout, Result, Shape}; +use half::f16; + +pub struct FlashHdim32Sm80; + +impl candle::CustomOp3 for FlashHdim32Sm80 { + fn name(&self) -> &'static str { + "flash-hdim32-sm80" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + Err(Error::Wrapped("no cpu support for flash-attn".into())) + } + + fn cuda_fwd( + &self, + q: &candle::CudaStorage, + _q_l: &Layout, + k: &candle::CudaStorage, + _k_l: &Layout, + v: &candle::CudaStorage, + _v_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = q.device(); + let out_shape = Shape::from(&[1]); + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; + let elem_count = out_shape.elem_count(); + let dst = unsafe { dev.alloc::(elem_count) }.w()?; + + unsafe { + let q_ptr = *q.device_ptr() as *const core::ffi::c_void; + let k_ptr = *k.device_ptr() as *const core::ffi::c_void; + let v_ptr = *v.device_ptr() as *const core::ffi::c_void; + let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; + ffi::run_mha( + q_ptr, k_ptr, v_ptr, dst_ptr, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.0, 1, 1, + 1, 1, 1, + ) + } + + let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone()); + Ok((dst, out_shape)) + } +}