diff --git a/sdk/python/kubeflow/trainer/hf_llm_training.py b/sdk/python/kubeflow/trainer/hf_llm_training.py index 26c48c08dd..c39c547c83 100644 --- a/sdk/python/kubeflow/trainer/hf_llm_training.py +++ b/sdk/python/kubeflow/trainer/hf_llm_training.py @@ -16,7 +16,7 @@ import json -def setup_model_and_tokenizer(model_uri, transformer_type, model_dir, train_args): +def setup_model_and_tokenizer(model_uri, transformer_type, model_dir): # Set up the model and tokenizer parsed_uri = urlparse(model_uri) model_name = parsed_uri.netloc + parsed_uri.path @@ -115,7 +115,7 @@ def parse_arguments(): args = parse_arguments() train_args = TrainingArguments(**json.loads(args.training_parameters)) model, tokenizer = setup_model_and_tokenizer( - args.model_uri, args.transformer_type, args.model_dir, train_args + args.model_uri, args.transformer_type, args.model_dir ) train_data, eval_data = load_and_preprocess_data( args.dataset_name, args.dataset_dir, args.transformer_type, tokenizer diff --git a/sdk/python/kubeflow/training/api/training_client.py b/sdk/python/kubeflow/training/api/training_client.py index a8187de7e0..fe62e0c271 100644 --- a/sdk/python/kubeflow/training/api/training_client.py +++ b/sdk/python/kubeflow/training/api/training_client.py @@ -171,7 +171,16 @@ def train( ), ) except Exception as e: - raise RuntimeError("failed to create pvc") + pvc_list = self.core_api.list_namespaced_persistent_volume_claim(namespace) + # Check if the PVC with the specified name exists + for pvc in pvc_list.items: + if pvc.metadata.name == constants.TRAINER_PVC_NAME: + print( + f"PVC '{constants.TRAINER_PVC_NAME}' already exists in namespace '{namespace}'." + ) + break + else: + raise RuntimeError("failed to create pvc") if isinstance(model_provider_parameters, HuggingFaceModelParams): mp = "hf" diff --git a/sdk/python/kubeflow/training/utils/utils.py b/sdk/python/kubeflow/training/utils/utils.py index 655839225b..d4a9a0e011 100644 --- a/sdk/python/kubeflow/training/utils/utils.py +++ b/sdk/python/kubeflow/training/utils/utils.py @@ -300,7 +300,7 @@ def get_pytorchjob_template( master_pod_template_spec: models.V1PodTemplateSpec = None, worker_pod_template_spec: models.V1PodTemplateSpec = None, num_worker_replicas: Optional[int] = None, - num_procs_per_worker: Optional[int] = None, + num_procs_per_worker: Optional[int] = 0, elastic_policy: Optional[models.KubeflowOrgV1ElasticPolicy] = None, ): # Check if at least one replica is set.