Commit 1ef756d: add flash attention

abhishekkrthakur committed Sep 27, 2023
1 parent d15150f

Showing 4 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt

```diff
@@ -31,3 +31,4 @@ transformers
 accelerate
 diffusers
 bitsandbytes
+flash-attn
```
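flash-attn builds CUDA extensions at install time and only imports on a machine with a supported GPU stack, so it is worth checking availability before enabling the new flag. A minimal sketch, assuming the PyPI `flash-attn` package (whose import name is `flash_attn`):

```python
# A minimal availability check, an illustration rather than AutoTrain code.
# Assumes the PyPI package `flash-attn`; its import name is `flash_attn`.
import importlib.util

if importlib.util.find_spec("flash_attn") is None:
    raise RuntimeError(
        "flash-attn is not installed; run `pip install flash-attn` on a "
        "CUDA machine before passing --use_flash_attention_2"
    )
```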
9 changes: 9 additions & 0 deletions src/autotrain/cli/run_llm.py

```diff
@@ -325,6 +325,13 @@ def register_subcommand(parser: ArgumentParser):
             "required": False,
             "type": str,
         },
+        {
+            "arg": "--use_flash_attention_2",
+            "help": "Use flash attention 2",
+            "required": False,
+            "action": "store_true",
+            "alias": ["--use-flash-attention-2", "--use-fa2"],
+        },
     ]
     run_llm_parser = parser.add_parser("llm", description="✨ Run AutoTrain LLM")
     for arg in arg_list:
```
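The registration loop body sits outside this hunk. A standalone sketch of how an entry like the one above could map onto argparse, assuming the loop consumes the `arg`, `alias`, `help`, `required`, and `action` keys:

```python
# A sketch under assumptions: the real loop in run_llm.py is not shown in
# this diff, so this only illustrates how the dict fields map onto argparse.
from argparse import ArgumentParser

entry = {
    "arg": "--use_flash_attention_2",
    "help": "Use flash attention 2",
    "required": False,
    "action": "store_true",
    "alias": ["--use-flash-attention-2", "--use-fa2"],
}

parser = ArgumentParser()
names = [entry["arg"]] + entry.get("alias", [])
parser.add_argument(*names, help=entry["help"], required=entry["required"], action=entry["action"])

# argparse derives the destination from the first long option string:
print(parser.parse_args(["--use-fa2"]))  # Namespace(use_flash_attention_2=True)
```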
```diff
@@ -364,6 +371,7 @@ def __init__(self, args):
         "use_int8",
         "use_int4",
         "merge_adapter",
+        "use_flash_attention_2",
     ]
     for arg_name in store_true_arg_names:
         if getattr(self.args, arg_name) is None:
```
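The hunk cuts off after the `if` line, so the loop body is an assumption here. The usual pattern is to coerce store-true flags the user did not pass (which arrive as `None`) to `False`; a standalone sketch:

```python
# Assumed behavior -- the actual loop body is outside this hunk.
# Unset store_true flags are normalized from None to False.
from argparse import Namespace

store_true_arg_names = ["use_int8", "use_int4", "merge_adapter", "use_flash_attention_2"]
args = Namespace(use_int8=True, use_int4=None, merge_adapter=None, use_flash_attention_2=None)

for arg_name in store_true_arg_names:
    if getattr(args, arg_name) is None:
        setattr(args, arg_name, False)

print(args)  # Namespace(use_int8=True, use_int4=False, merge_adapter=False, use_flash_attention_2=False)
```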
```diff
@@ -451,6 +459,7 @@ def run(self):
         token=self.args.token,
         merge_adapter=self.args.merge_adapter,
         username=self.args.username,
+        use_flash_attention_2=self.args.use_flash_attention_2,
     )

     # space training
```
2 changes: 2 additions & 0 deletions src/autotrain/trainers/clm/__main__.py

```diff
@@ -127,13 +127,15 @@ def train(config):
             torch_dtype=torch.float16,
             device_map={"": Accelerator().process_index} if torch.cuda.is_available() else None,
             trust_remote_code=True,
+            use_flash_attention_2=config.use_flash_attention_2,
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
             config.model,
             config=model_config,
             use_auth_token=config.token,
             trust_remote_code=True,
+            use_flash_attention_2=config.use_flash_attention_2,
         )
 
     model.resize_token_embeddings(len(tokenizer))
```
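Here the flag is forwarded straight to transformers. A minimal standalone sketch of the same call pattern, assuming a transformers release from this era (around 4.34) that accepts `use_flash_attention_2` in `from_pretrained` (later releases replace it with `attn_implementation="flash_attention_2"`); the model id is a placeholder:

```python
# A minimal sketch, not AutoTrain's exact call: load a causal LM with flash
# attention 2 enabled. Assumes a transformers version that accepts the
# `use_flash_attention_2` kwarg and a CUDA GPU; the model id is a placeholder.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder; any FA2-supported architecture
    torch_dtype=torch.float16,   # flash attention 2 requires fp16 or bf16
    use_flash_attention_2=True,
).to("cuda")
```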
1 change: 1 addition & 0 deletions src/autotrain/trainers/clm/params.py

```diff
@@ -44,6 +44,7 @@ class LLMTrainingParams(BaseModel):
     target_modules: str = Field(None, title="Target modules")
     merge_adapter: bool = Field(False, title="Merge adapter")
     username: str = Field(None, title="Hugging Face Username")
+    use_flash_attention_2: bool = Field(False, title="Use flash attention 2")
 
     def save(self, output_dir):
         os.makedirs(output_dir, exist_ok=True)
```
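The new field defaults to off, so existing configs are unaffected. A reduced pydantic sketch with only this field, so it runs standalone (the real `LLMTrainingParams` carries many more fields):

```python
# A reduced sketch of the new field; the real class in params.py has many
# more fields. The flag is False unless set explicitly.
from pydantic import BaseModel, Field

class LLMTrainingParams(BaseModel):
    use_flash_attention_2: bool = Field(False, title="Use flash attention 2")

print(LLMTrainingParams().use_flash_attention_2)                             # False
print(LLMTrainingParams(use_flash_attention_2=True).use_flash_attention_2)   # True
```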
