diff --git a/tools/flash_attn/hopper.patch b/tools/flash_attn/hopper.patch index 584e8ff..b34a402 100644 --- a/tools/flash_attn/hopper.patch +++ b/tools/flash_attn/hopper.patch @@ -8,7 +8,7 @@ index f9f3cfd..132ce07 100644 def append_nvcc_threads(nvcc_extra_args): - return nvcc_extra_args + ["--threads", "4"] + nvcc_threads = os.getenv("NVCC_THREADS") or "4" -+ return nvcc_extra_args + ["--threads", NVCC_THREADS] ++ return nvcc_extra_args + ["--threads", nvcc_threads] cmdclass = {}