From b0d6de9a9460ff495cc6144a9cbd1c4046a4282e Mon Sep 17 00:00:00 2001
From: "Hezhi (Helen) Xie"
Date: Wed, 28 Aug 2024 09:24:04 -0700
Subject: [PATCH] [SDK] Fix trainer error: Update the version of base image and
 add "num_labels" for downloading pretrained models (#2230)

* fix trainer error

Signed-off-by: helenxie-bit

* rerun tests

Signed-off-by: helenxie-bit

* update the process of num_labels in trainer

Signed-off-by: helenxie-bit

* rerun tests

Signed-off-by: helenxie-bit

* adjust the default value of 'num_labels'

Signed-off-by: helenxie-bit

---------

Signed-off-by: helenxie-bit
Signed-off-by: Andrey Velichkevich
---
 .../kubeflow/storage_initializer/hugging_face.py |  1 +
 sdk/python/kubeflow/trainer/Dockerfile           |  2 +-
 .../kubeflow/trainer/hf_llm_training.py          | 26 +++++++++++++------
 .../kubeflow/training/api/training_client.py     |  2 ++
 4 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/sdk/python/kubeflow/storage_initializer/hugging_face.py b/sdk/python/kubeflow/storage_initializer/hugging_face.py
index 4b5b0794a9..33be724cf0 100644
--- a/sdk/python/kubeflow/storage_initializer/hugging_face.py
+++ b/sdk/python/kubeflow/storage_initializer/hugging_face.py
@@ -38,6 +38,7 @@ class HuggingFaceModelParams:
     model_uri: str
     transformer_type: TRANSFORMER_TYPES
     access_token: str = None
+    num_labels: Optional[int] = None
 
     def __post_init__(self):
         # Custom checks or validations can be added here
diff --git a/sdk/python/kubeflow/trainer/Dockerfile b/sdk/python/kubeflow/trainer/Dockerfile
index d0ebee4aa3..6b98e3de31 100644
--- a/sdk/python/kubeflow/trainer/Dockerfile
+++ b/sdk/python/kubeflow/trainer/Dockerfile
@@ -1,5 +1,5 @@
 # Use an official Pytorch runtime as a parent image
-FROM nvcr.io/nvidia/pytorch:23.10-py3
+FROM nvcr.io/nvidia/pytorch:24.06-py3
 
 # Set the working directory in the container
 WORKDIR /app
diff --git a/sdk/python/kubeflow/trainer/hf_llm_training.py b/sdk/python/kubeflow/trainer/hf_llm_training.py
index 26dd4fbe0e..5b3a4360fb 100644
--- a/sdk/python/kubeflow/trainer/hf_llm_training.py
+++ b/sdk/python/kubeflow/trainer/hf_llm_training.py
@@ -29,17 +29,26 @@
 logger.setLevel(logging.INFO)
 
 
-def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
+def setup_model_and_tokenizer(model_uri, transformer_type, model_dir, num_labels):
     # Set up the model and tokenizer
     parsed_uri = urlparse(model_uri)
     model_name = parsed_uri.netloc + parsed_uri.path
 
-    model = transformer_type.from_pretrained(
-        pretrained_model_name_or_path=model_name,
-        cache_dir=model_dir,
-        local_files_only=True,
-        trust_remote_code=True,
-    )
+    if num_labels != "None":
+        model = transformer_type.from_pretrained(
+            pretrained_model_name_or_path=model_name,
+            cache_dir=model_dir,
+            local_files_only=True,
+            trust_remote_code=True,
+            num_labels=int(num_labels),
+        )
+    else:
+        model = transformer_type.from_pretrained(
+            pretrained_model_name_or_path=model_name,
+            cache_dir=model_dir,
+            local_files_only=True,
+            trust_remote_code=True,
+        )
 
     tokenizer = AutoTokenizer.from_pretrained(
         pretrained_model_name_or_path=model_name,
@@ -145,6 +154,7 @@ def parse_arguments():
 
     parser.add_argument("--model_uri", help="model uri")
     parser.add_argument("--transformer_type", help="model transformer type")
+    parser.add_argument("--num_labels", default="None", help="number of classes")
    parser.add_argument("--model_dir", help="directory containing model")
     parser.add_argument("--dataset_dir", help="directory containing dataset")
     parser.add_argument("--lora_config", help="lora_config")
@@ -163,7 +173,7 @@ def parse_arguments():
 
     logger.info("Setup model and tokenizer")
     model, tokenizer = setup_model_and_tokenizer(
-        args.model_uri, transformer_type, args.model_dir
+        args.model_uri, transformer_type, args.model_dir, args.num_labels
     )
 
     logger.info("Preprocess dataset")
diff --git a/sdk/python/kubeflow/training/api/training_client.py b/sdk/python/kubeflow/training/api/training_client.py
index edac130194..4165904b21 100644
--- a/sdk/python/kubeflow/training/api/training_client.py
+++ b/sdk/python/kubeflow/training/api/training_client.py
@@ -264,6 +264,8 @@ def train(
                     model_provider_parameters.model_uri,
                     "--transformer_type",
                     model_provider_parameters.transformer_type.__name__,
+                    "--num_labels",
+                    str(model_provider_parameters.num_labels),
                     "--model_dir",
                     VOLUME_PATH_MODEL,
                     "--dataset_dir",