Commit 5e0160c

update configs

abhishekkrthakur committed May 23, 2024
1 parent 027f4b0
Showing 9 changed files with 20 additions and 16 deletions.
configs/llm_finetuning/gpt2_sft.yml (1 addition, 2 deletions)

@@ -1,4 +1,4 @@
-task: llm
+task: llm-sft
 base_model: openai-community/gpt2
 project_name: autotrain-gpt2-finetuned-guanaco
 log: tensorboard
@@ -13,7 +13,6 @@ data:
   text_column: text
 
 params:
-  trainer: sft
   block_size: 1024
   model_max_length: 2048
   max_prompt_length: 512
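
Note: the five config diffs that follow repeat this same pattern, so it is worth spelling out once: the trainer selection moves out of params.trainer and into the task field as a suffix (llm-sft, llm-orpo, llm-dpo). A minimal Python sketch of the equivalence, assuming PyYAML is installed; the configs are trimmed to the two keys the diff actually touches:

import yaml  # PyYAML, assumed installed

old_cfg = yaml.safe_load("""
task: llm
params:
  trainer: sft
""")
new_cfg = yaml.safe_load("""
task: llm-sft
params: {}
""")

# The trainer formerly read from params.trainer is now the suffix of the task name.
assert old_cfg["params"]["trainer"] == new_cfg["task"].split("-", 1)[1]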
configs/llm_finetuning/llama3-70b-orpo-v1.yml (1 addition, 2 deletions)

@@ -1,4 +1,4 @@
-task: llm
+task: llm-orpo
 base_model: meta-llama/Meta-Llama-3-70B-Instruct
 project_name: autotrain-llama3-70b-orpo-v1
 log: tensorboard
@@ -15,7 +15,6 @@ data:
   prompt_text_column: prompt
 
 params:
-  trainer: orpo
   block_size: 2048
   model_max_length: 8192
   max_prompt_length: 1024
configs/llm_finetuning/llama3-70b-sft.yml (1 addition, 2 deletions)

@@ -1,4 +1,4 @@
-task: llm
+task: llm-sft
 base_model: meta-llama/Meta-Llama-3-70B-Instruct
 project_name: autotrain-llama3-70b-math-v1
 log: tensorboard
@@ -13,7 +13,6 @@ data:
   text_column: text
 
 params:
-  trainer: sft
   block_size: 2048
   model_max_length: 8192
   epochs: 2
configs/llm_finetuning/llama3-8b-dpo-qlora.yml (1 addition, 2 deletions)

@@ -1,4 +1,4 @@
-task: llm
+task: llm-dpo
 base_model: meta-llama/Meta-Llama-3-8B-Instruct
 project_name: autotrain-llama3-8b-dpo-qlora
 log: tensorboard
@@ -15,7 +15,6 @@ data:
   prompt_text_column: prompt
 
 params:
-  trainer: dpo
   block_size: 1024
   model_max_length: 2048
   max_prompt_length: 512
configs/llm_finetuning/llama3-8b-orpo-space.yml (1 addition, 2 deletions)

@@ -1,4 +1,4 @@
-task: llm
+task: llm-orpo
 base_model: meta-llama/Meta-Llama-3-8B-Instruct
 project_name: autotrain-llama3-8b-orpo-t1
 log: tensorboard
@@ -15,7 +15,6 @@ data:
   prompt_text_column: prompt
 
 params:
-  trainer: orpo
   block_size: 1024
   model_max_length: 8192
   max_prompt_length: 512
configs/llm_finetuning/llama3-8b-orpo.yml (1 addition, 2 deletions)

@@ -1,4 +1,4 @@
-task: llm
+task: llm-orpo
 base_model: meta-llama/Meta-Llama-3-8B-Instruct
 project_name: autotrain-llama3-8b-orpo
 log: tensorboard
@@ -15,7 +15,6 @@ data:
   prompt_text_column: prompt
 
 params:
-  trainer: orpo
   block_size: 1024
   model_max_length: 8192
   max_prompt_length: 512
src/autotrain/__init__.py (1 addition, 1 deletion)

@@ -41,7 +41,7 @@
 warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")
 
 logger = Logger().get_logger()
-__version__ = "0.7.106"
+__version__ = "0.7.108"
 
 
 def is_colab():
src/autotrain/app/params.py (2 additions, 1 deletion)

@@ -283,9 +283,10 @@ def _munge_params_img_obj_det(self):
         _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["image_column"] = "autotrain_image"
-            _params["objects_column"] = "autotrain_label"
+            _params["objects_column"] = "autotrain_objects"
             _params["valid_split"] = "validation"
         else:
+
             _params["image_column"] = self.column_mapping.get("image" if not self.api else "image_column", "image")
             _params["objects_column"] = self.column_mapping.get(
                 "objects" if not self.api else "objects_column", "objects"
src/autotrain/parser.py (11 additions, 2 deletions)

@@ -70,8 +70,11 @@ def __post_init__(self):
         }
         self.task_aliases = {
             "llm": "lm_training",
-            "llm_training": "lm_training",
-            "llm_finetuning": "lm_training",
+            "llm-sft": "lm_training",
+            "llm-orpo": "lm_training",
+            "llm-generic": "lm_training",
+            "llm-dpo": "lm_training",
+            "llm-reward": "lm_training",
             "dreambooth": "dreambooth",
             "image_binary_classification": "image_multi_class_classification",
             "image-binary-classification": "image_multi_class_classification",
@@ -123,6 +126,12 @@ def _parse_config(self):
 
         if self.task == "lm_training":
             params["chat_template"] = self.config["data"]["chat_template"]
+            if "-" in self.config["task"]:
+                params["trainer"] = self.config["task"].split("-")[1]
+                if params["trainer"] == "generic":
+                    params["trainer"] = "default"
+                if params["trainer"] not in ["sft", "orpo", "dpo", "reward", "default"]:
+                    raise ValueError("Invalid LLM training task")
 
         if self.task != "dreambooth":
             for k, v in self.config["data"]["column_mapping"].items():
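
Taken together, the two hunks mean a config can now say task: llm-sft and the parser both resolves the alias to lm_training and derives the trainer from the suffix. A standalone sketch of that derivation; derive_trainer is our name, not part of the AutoTrain API, and the fallback for a plain llm task is an assumption (the commit leaves params untouched in that case):

def derive_trainer(task: str, fallback: str = "default") -> str:
    # Mirrors the branching added to _parse_config in the hunk above.
    trainer = fallback  # a bare "llm" task never enters the branch below
    if "-" in task:
        trainer = task.split("-")[1]
        if trainer == "generic":
            trainer = "default"
        if trainer not in ["sft", "orpo", "dpo", "reward", "default"]:
            raise ValueError("Invalid LLM training task")
    return trainer

assert derive_trainer("llm-orpo") == "orpo"
assert derive_trainer("llm-generic") == "default"
assert derive_trainer("llm") == "default"  # assumed fallback, see above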
