Skip to content

Commit

Permalink
ci fix and removing unused parameter and adding check if pvc exists a…
Browse files Browse the repository at this point in the history
…lready
  • Loading branch information
deepanker13 committed Jan 11, 2024
1 parent 1034403 commit f520329
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
4 changes: 2 additions & 2 deletions sdk/python/kubeflow/trainer/hf_llm_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import json


def setup_model_and_tokenizer(model_uri, transformer_type, model_dir, train_args):
def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
# Set up the model and tokenizer
parsed_uri = urlparse(model_uri)
model_name = parsed_uri.netloc + parsed_uri.path
Expand Down Expand Up @@ -115,7 +115,7 @@ def parse_arguments():
args = parse_arguments()
train_args = TrainingArguments(**json.loads(args.training_parameters))
model, tokenizer = setup_model_and_tokenizer(
args.model_uri, args.transformer_type, args.model_dir, train_args
args.model_uri, args.transformer_type, args.model_dir
)
train_data, eval_data = load_and_preprocess_data(
args.dataset_name, args.dataset_dir, args.transformer_type, tokenizer
Expand Down
11 changes: 10 additions & 1 deletion sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,16 @@ def train(
),
)
except Exception as e:
raise RuntimeError("failed to create pvc")
pvc_list = self.core_api.list_namespaced_persistent_volume_claim(namespace)
# Check if the PVC with the specified name exists
for pvc in pvc_list.items:
if pvc.metadata.name == constants.TRAINER_PVC_NAME:
print(
f"PVC '{constants.TRAINER_PVC_NAME}' already exists in namespace '{namespace}'."
)
break
else:
raise RuntimeError("failed to create pvc")

if isinstance(model_provider_parameters, HuggingFaceModelParams):
mp = "hf"
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kubeflow/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def get_pytorchjob_template(
master_pod_template_spec: models.V1PodTemplateSpec = None,
worker_pod_template_spec: models.V1PodTemplateSpec = None,
num_worker_replicas: Optional[int] = None,
num_procs_per_worker: Optional[int] = None,
num_procs_per_worker: Optional[int] = 0,
elastic_policy: Optional[models.KubeflowOrgV1ElasticPolicy] = None,
):
# Check if at least one replica is set.
Expand Down

0 comments on commit f520329

Please sign in to comment.