diff --git a/requirements.txt b/requirements.txt
index 03401ef327..cdcada601d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,3 +31,4 @@ transformers
 accelerate
 diffusers
 bitsandbytes
+flash-attn
\ No newline at end of file
diff --git a/src/autotrain/cli/run_llm.py b/src/autotrain/cli/run_llm.py
index 71cc89c796..c8816b5060 100644
--- a/src/autotrain/cli/run_llm.py
+++ b/src/autotrain/cli/run_llm.py
@@ -325,6 +325,13 @@ def register_subcommand(parser: ArgumentParser):
                 "required": False,
                 "type": str,
             },
+            {
+                "arg": "--use_flash_attention_2",
+                "help": "Use flash attention 2",
+                "required": False,
+                "action": "store_true",
+                "alias": ["--use-flash-attention-2", "--use-fa2"],
+            },
         ]
         run_llm_parser = parser.add_parser("llm", description="✨ Run AutoTrain LLM")
         for arg in arg_list:
@@ -364,6 +371,7 @@ def __init__(self, args):
             "use_int8",
             "use_int4",
             "merge_adapter",
+            "use_flash_attention_2",
         ]
         for arg_name in store_true_arg_names:
             if getattr(self.args, arg_name) is None:
@@ -451,6 +459,7 @@ def run(self):
                 token=self.args.token,
                 merge_adapter=self.args.merge_adapter,
                 username=self.args.username,
+                use_flash_attention_2=self.args.use_flash_attention_2,
             )

             # space training
diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py
index 10fc12544a..d2b5693a2c 100644
--- a/src/autotrain/trainers/clm/__main__.py
+++ b/src/autotrain/trainers/clm/__main__.py
@@ -127,6 +127,7 @@ def train(config):
             torch_dtype=torch.float16,
             device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None,
             trust_remote_code=True,
+            use_flash_attention_2=config.use_flash_attention_2,
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
@@ -134,6 +135,7 @@ def train(config):
             config=model_config,
             use_auth_token=config.token,
             trust_remote_code=True,
+            use_flash_attention_2=config.use_flash_attention_2,
         )

     model.resize_token_embeddings(len(tokenizer))
diff --git a/src/autotrain/trainers/clm/params.py b/src/autotrain/trainers/clm/params.py
index 5d52272bbf..c3c8054831 100644
--- a/src/autotrain/trainers/clm/params.py
+++ b/src/autotrain/trainers/clm/params.py
@@ -44,6 +44,7 @@ class LLMTrainingParams(BaseModel):
    target_modules: str = Field(None, title="Target modules")
    merge_adapter: bool = Field(False, title="Merge adapter")
    username: str = Field(None, title="Hugging Face Username")
+   use_flash_attention_2: bool = Field(False, title="Use flash attention 2")

    def save(self, output_dir):
        os.makedirs(output_dir, exist_ok=True)
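
Note (not part of the diff): the change threads a single boolean from the CLI (--use_flash_attention_2, with aliases --use-flash-attention-2 and --use-fa2), through the new LLMTrainingParams.use_flash_attention_2 field (default False), down to both AutoModelForCausalLM.from_pretrained calls in the trainer. A minimal sketch of what that kwarg does at the transformers level is below. It assumes a transformers release that still accepts use_flash_attention_2 in from_pretrained (later releases replaced it with attn_implementation="flash_attention_2"), that the flash-attn package added to requirements.txt is installed on a supported GPU, and the checkpoint name is purely illustrative.

# Sketch only: how the flag wired through by this PR reaches transformers.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # illustrative checkpoint, not from the diff
    torch_dtype=torch.float16,    # FA2 kernels require fp16 or bf16 weights
    use_flash_attention_2=True,   # what config.use_flash_attention_2 toggles
)

Because the flag is registered as a store_true argument, omitting it keeps the default of False from LLMTrainingParams, so existing autotrain llm invocations are unaffected.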