Skip to content

Commit

Permalink
Update to the app (#354)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur authored Nov 21, 2023
1 parent 22944c0 commit ae5e9ca
Show file tree
Hide file tree
Showing 19 changed files with 939 additions and 1,283 deletions.
5 changes: 2 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ nltk==3.8.1
optuna==3.3.0
Pillow==10.0.0
protobuf==4.23.4
pydantic==1.10.11
sacremoses==0.0.53
scikit-learn==1.3.0
sentencepiece==0.1.99
Expand All @@ -20,7 +19,6 @@ werkzeug==2.3.6
xgboost==1.7.6
huggingface_hub>=0.16.4
requests==2.31.0
gradio==3.41.0
einops==0.6.1
invisible-watermark==0.2.0
packaging==23.1
Expand All @@ -35,4 +33,5 @@ diffusers==0.21.4
bitsandbytes==0.41.0
# extras
rouge_score==0.1.2
py7zr==0.20.6
py7zr==0.20.6
fastapi==0.104.1
15 changes: 6 additions & 9 deletions src/autotrain/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ def run_training():
params = json.loads(PARAMS)
logger.info(params)
if TASK_ID == 9:
try:
params = LLMTrainingParams.parse_raw(params)
except Exception:
params = LLMTrainingParams.parse_obj(params)
params = LLMTrainingParams.model_validate_json(params)
params.project_name = "/tmp/model"
params.save(output_dir=params.project_name)
cmd = ["accelerate", "launch", "--num_machines", "1", "--num_processes", "1"]
Expand All @@ -60,7 +57,7 @@ def run_training():
]
)
elif TASK_ID == 28:
params = Seq2SeqParams.parse_raw(params)
params = Seq2SeqParams.model_validate_json(params)
params.project_name = "/tmp/model"
params.save(output_dir=params.project_name)
cmd = ["accelerate", "launch", "--num_machines", "1", "--num_processes", "1"]
Expand All @@ -79,7 +76,7 @@ def run_training():
]
)
elif TASK_ID in (1, 2):
params = TextClassificationParams.parse_raw(params)
params = TextClassificationParams.model_validate_json(params)
params.project_name = "/tmp/model"
params.save(output_dir=params.project_name)
cmd = ["accelerate", "launch", "--num_machines", "1", "--num_processes", "1"]
Expand All @@ -98,7 +95,7 @@ def run_training():
]
)
elif TASK_ID in (13, 14, 15, 16, 26):
params = TabularParams.parse_raw(params)
params = TabularParams.model_validate_json(params)
params.project_name = "/tmp/model"
params.save(output_dir=params.project_name)
cmd = [
Expand All @@ -109,7 +106,7 @@ def run_training():
os.path.join(params.project_name, "training_params.json"),
]
elif TASK_ID == 27:
params = GenericParams.parse_raw(params)
params = GenericParams.model_validate_json(params)
params.project_name = "/tmp/model"
params.save(output_dir=params.project_name)
cmd = [
Expand All @@ -120,7 +117,7 @@ def run_training():
os.path.join(params.project_name, "training_params.json"),
]
elif TASK_ID == 25:
params = DreamBoothTrainingParams.parse_raw(params)
params = DreamBoothTrainingParams.model_validate_json(params)
params.project_name = "/tmp/model"
params.save(output_dir=params.project_name)
cmd = [
Expand Down
Loading

0 comments on commit ae5e9ca

Please sign in to comment.