Add flash attention (huggingface#241)
* Add an initial flash-attn kernel, importing the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files into a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
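
With the gating described in the last two bullets, the new path is opt-in twice over: `flash-attn` is a compile-time cargo feature, and `--use-flash-attn` is a runtime flag on the llama example. Assuming a CUDA toolchain and an initialized cutlass submodule, something like `cargo run --example llama --features flash-attn -- --use-flash-attn` (run from the candle-examples crate) should exercise the new kernel.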
LaurentMazare authored Jul 26, 2023
1 parent c97d512 commit d9f9c85
Showing 22 changed files with 2,699 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "candle-examples/examples/flash-attn/cutlass"]
path = candle-flash-attn/cutlass
url = https://github.com/NVIDIA/cutlass.git
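
Since cutlass is vendored as a git submodule, a fresh clone needs a `git submodule update --init` before candle-flash-attn can build: its build script passes `-Icutlass/include` to nvcc and fails if that directory is empty.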
1 change: 1 addition & 0 deletions Cargo.toml
@@ -9,6 +9,7 @@ members = [
"candle-wasm-examples/whisper",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
]

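Keeping candle-flash-attn out of the workspace (like candle-kernels) means a plain `cargo build` never needs the nvcc/CUDA toolchain that the new crate's build script requires.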
2 changes: 1 addition & 1 deletion candle-core/src/lib.rs
@@ -65,7 +65,7 @@ pub use dtype::{DType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
-pub use op::CustomOp1;
+pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};
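
`CustomOp2` and `CustomOp3` are the two- and three-input counterparts of the existing `CustomOp1`; flash attention is implemented as a `CustomOp3` over (q, k, v). A minimal sketch of the call site, mirroring the `flash_attn` helper added to the llama example below:

```rust
use candle::{Result, Tensor};

// Sketch of how a three-input custom op is dispatched: the op value
// (here the flash-attention kernel wrapper from the new crate) is
// passed alongside the two extra tensor arguments.
fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
}
```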
2 changes: 2 additions & 0 deletions candle-examples/Cargo.toml
@@ -14,6 +14,7 @@ readme = "README.md"
candle = { path = "../candle-core" }
candle-nn = { path = "../candle-nn" }
candle-transformers = { path = "../candle-transformers" }
candle-flash-attn = { path = "../candle-flash-attn", optional = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
@@ -37,4 +38,5 @@ anyhow = { workspace = true }
[features]
default = []
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
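
The `dep:` prefix makes `candle-flash-attn` a purely optional dependency (no implicit feature of the same name is created), and since `flash-attn` also lists `cuda`, enabling it transitively turns on the CUDA backends across the candle crates.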
6 changes: 6 additions & 0 deletions candle-examples/build.rs
@@ -6,11 +6,13 @@ use std::path::PathBuf;
struct KernelDirectories {
kernel_dir: &'static str,
rust_target: &'static str,
include_dirs: &'static [&'static str],
}

const DIRS: [KernelDirectories; 1] = [KernelDirectories {
kernel_dir: "examples/custom-ops/kernels/",
rust_target: "examples/custom-ops/cuda_kernels.rs",
include_dirs: &[],
}];

impl KernelDirectories {
@@ -32,12 +34,15 @@ impl KernelDirectories {
{
let mut command = std::process::Command::new("nvcc");
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
let include_dirs: Vec<String> =
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", out_dir.to_str().unwrap()])
.arg(format!("-I/{}", self.kernel_dir))
.args(include_dirs)
.arg(cu_file);
let output = command
.spawn()
@@ -221,6 +226,7 @@ fn compute_cap() -> Result<usize> {
}

println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");

if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
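
The new `include_dirs` field threads extra `-I` paths through to the nvcc invocation on a per-kernel-directory basis; the existing custom-ops example needs none, hence the empty slice.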
5 changes: 4 additions & 1 deletion candle-examples/examples/llama/main.rs
@@ -116,6 +116,9 @@ struct Args {

#[arg(long)]
v2: bool,

#[arg(long)]
use_flash_attn: bool,
}

fn main() -> Result<()> {
@@ -124,7 +127,7 @@ fn main() -> Result<()> {
let args = Args::parse();

let device = candle_examples::device(args.cpu)?;
let config = Config::config_7b();
let config = Config::config_7b(args.use_flash_attn);
let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let (llama, tokenizer_filename) = match args.npy {
33 changes: 26 additions & 7 deletions candle-examples/examples/llama/model.rs
@@ -13,10 +13,11 @@ pub struct Config {
pub n_head: usize,
pub n_embd: usize,
pub n_key_value_head: usize,
pub use_flash_attn: bool,
}

impl Config {
pub fn config_7b() -> Self {
pub fn config_7b(use_flash_attn: bool) -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
@@ -25,6 +26,7 @@ impl Config {
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
use_flash_attn,
}
}
}
@@ -140,6 +142,17 @@ struct CausalSelfAttention {
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
}

#[cfg(feature = "flash-attn")]
fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}

impl CausalSelfAttention {
@@ -202,12 +215,17 @@ impl CausalSelfAttention {

let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(D::Minus1)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
+        let y = if self.use_flash_attn {
+            flash_attn(&q, &k, &v)?
+        } else {
+            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+            let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+            let att = att.softmax(D::Minus1)?;
+            // Convert to contiguous as matmul doesn't support strided vs for now.
+            att.matmul(&v.contiguous()?)?
+        };
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = y.to_dtype(x_dtype)?;
let y = self.o_proj.forward(&y)?;
@@ -245,6 +263,7 @@ impl CausalSelfAttention {
n_key_value_head: cfg.n_key_value_head,
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
})
}
}
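
For reference, the fallback branch wrapped above is standard causal scaled-dot-product attention; the flash-attention path computes the same function without materializing the full seq_len × seq_len attention matrix:

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\mathrm{head}}}}+M\right)V$$

where $M$ is the causal mask ($0$ on visible positions, $-\infty$ elsewhere) and $d_{\mathrm{head}}$ is `head_dim`, matching the scaling and `masked_fill` in the non-flash code.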
18 changes: 18 additions & 0 deletions candle-flash-attn/Cargo.toml
@@ -0,0 +1,18 @@
[package]
name = "candle-flash-attn"
version = "0.1.0"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
repository = "https://github.com/LaurentMazare/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"] }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
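
Note that the crate depends on candle with its `cuda` feature hard-enabled, consistent with it living outside the default workspace and behind a feature gate in candle-examples.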
182 changes: 182 additions & 0 deletions candle-flash-attn/build.rs
@@ -0,0 +1,182 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::io::Write;
use std::path::PathBuf;

fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim32_fp16_sm80.cu");
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
println!("cargo:rerun-if-changed=kernels/flash.h");
println!("cargo:rerun-if-changed=kernels/philox.cuh");
println!("cargo:rerun-if-changed=kernels/softmax.h");
println!("cargo:rerun-if-changed=kernels/utils.h");
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");

let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
let mut out_dir = PathBuf::from(out_dir);
// TODO: Going up two levels avoids having to recompile this too often; however, this is
// likely not a safe assumption.
out_dir.pop();
out_dir.pop();
set_cuda_include_dir()?;
let compute_cap = compute_cap()?;

let mut command = std::process::Command::new("nvcc");
let out_file = out_dir.join("libflashattention.a");

let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
let should_compile = if out_file.exists() {
let out_modified = out_file.metadata()?.modified()?;
let in_modified = cu_file.metadata()?.modified()?;
in_modified.duration_since(out_modified).is_ok()
} else {
true
};
if should_compile {
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
println!("cargo:rustc-link-search={}", out_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");

/* laurent: I tried using the cc cuda integration as below but this led to ptxas never
finishing for some reason. Calling nvcc manually worked fine.
cc::Build::new()
.cuda(true)
.include("cutlass/include")
.flag("--expt-relaxed-constexpr")
.flag("--default-stream")
.flag("per-thread")
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
.compile("flashattn");
*/
Ok(())
}

fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);

let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}

#[allow(unused)]
fn compute_cap() -> Result<usize> {
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};

// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();

let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
};

// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}

println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}
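
Two practical notes on this build script: the `should_compile` mtime check above only re-runs nvcc when the `.cu` source is newer than the cached `libflashattention.a`, and the detected architecture can be pinned manually, e.g. `CUDA_COMPUTE_CAP=80 cargo build`, since the environment variable takes precedence over the `nvidia-smi`/`nvcc` probing.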
1 change: 1 addition & 0 deletions candle-flash-attn/cutlass
Submodule cutlass added at c4f6b8
(The remaining changed files, including the CUDA kernel sources under candle-flash-attn/kernels/ listed in the build script above, are not shown.)