diff --git a/install.py b/install.py index b47188c..e2d51a4 100644 --- a/install.py +++ b/install.py @@ -59,7 +59,7 @@ def install_fa2(compile=False, ci=False): env = os.environ.copy() # limit max jobs to save memory in CI if ci: - env["MAX_JOBS"] = 4 + env["MAX_JOBS"] = "4" FA2_PATH = REPO_PATH.joinpath("submodules", "flash-attention") cmd = [sys.executable, "setup.py", "install"] subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve()), env=env) @@ -74,7 +74,7 @@ def install_fa3(ci=False): env = os.environ.copy() # limit max jobs to save memory in CI if ci: - env["MAX_JOBS"] = 4 + env["MAX_JOBS"] = "4" cmd = [sys.executable, "setup.py", "install"] subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()), env=env)