Skip to content

Commit

Permalink
Merge pull request #5 from soumik12345/wandb-dev
Browse files Browse the repository at this point in the history
remove changes from deprecated files
  • Loading branch information
rishiraj authored Oct 22, 2023
2 parents be11810 + 03d774f commit af4b265
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/autotrain/trainers/lm_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def group_texts(examples):
logging_steps = int(0.2 * len(valid_data) / job_config.train_batch_size)
if logging_steps == 0:
logging_steps = 1

training_args = dict(
output_dir=model_path,
per_device_train_batch_size=job_config.train_batch_size,
Expand All @@ -400,7 +400,7 @@ def group_texts(examples):
save_strategy="epoch",
disable_tqdm=not bool(os.environ.get("ENABLE_TQDM", 0)),
gradient_accumulation_steps=job_config.gradient_accumulation_steps,
report_to=job_config.log,
report_to="none",
auto_find_batch_size=True,
lr_scheduler_type=job_config.scheduler,
optim=job_config.optimizer,
Expand Down Expand Up @@ -465,4 +465,4 @@ def group_texts(examples):
model_repo.git_pull()
model_repo.git_add()
model_repo.git_commit(commit_message="Commit From AutoTrain")
model_repo.git_push()
model_repo.git_push()
6 changes: 3 additions & 3 deletions src/autotrain/trainers/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def train(co2_tracker, payload, huggingface_token, model_path):
fp16 = True
if model_config.model_type in FP32_MODELS or device == "cpu":
fp16 = False

training_args = dict(
output_dir="/tmp/autotrain",
per_device_train_batch_size=job_config.train_batch_size,
Expand All @@ -219,7 +219,7 @@ def train(co2_tracker, payload, huggingface_token, model_path):
save_strategy="epoch",
disable_tqdm=not bool(os.environ.get("ENABLE_TQDM", 0)),
gradient_accumulation_steps=job_config.gradient_accumulation_steps,
report_to=job_config.log,
report_to="none",
auto_find_batch_size=True,
lr_scheduler_type=job_config.scheduler,
optim=job_config.optimizer,
Expand Down Expand Up @@ -274,4 +274,4 @@ def train(co2_tracker, payload, huggingface_token, model_path):
model_repo.git_pull()
model_repo.git_add()
model_repo.git_commit(commit_message="Commit From AutoTrain")
model_repo.git_push()
model_repo.git_push()

0 comments on commit af4b265

Please sign in to comment.