Download support for custom Huggingface models (#40)
* Download custom Huggingface models and add progress bars

* Change default repo_version to None in pytest
AyushSawant18588 authored Dec 4, 2023
1 parent 0ba918d commit b629686
Showing 10 changed files with 245 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
@@ -25,7 +25,7 @@ jobs:
python-version: 3.11

- name: Install Python dependencies
- run: pip install pytest black pylint torchserve==0.8.2 torch==2.0.1 transformers==4.33.0 -r llm/requirements.txt
+ run: pip install --no-cache-dir pytest black pylint torchserve==0.8.2 torch==2.0.1 transformers==4.33.0 -r llm/requirements.txt

- name: Run pylint
run: pylint ./llm
44 changes: 36 additions & 8 deletions llm/download.py
@@ -199,7 +199,7 @@ class with relevant information.

hf.hf_token_check(gen_model.repo_info.repo_id, gen_model.repo_info.hf_token)

- if gen_model.repo_info.repo_version == "":
+ if not gen_model.repo_info.repo_version:
gen_model.repo_info.repo_version = model["repo_version"]

gen_model.repo_info.repo_version = hf.get_repo_commit_id(
@@ -225,15 +225,36 @@ class with relevant information.
os.path.dirname(__file__),
HANDLER,
)
- if gen_model.repo_info.repo_version == "":
+ if not gen_model.repo_info.repo_version:
gen_model.repo_info.repo_version = "1.0"
+ elif gen_model.repo_info.repo_id:
+     hf.hf_token_check(gen_model.repo_info.repo_id, gen_model.repo_info.hf_token)
+     gen_model.repo_info.repo_version = hf.get_repo_commit_id(
+         repo_id=gen_model.repo_info.repo_id,
+         revision=gen_model.repo_info.repo_version,
+         token=gen_model.repo_info.hf_token,
+     )
+     gen_model.is_custom = True
+     if gen_model.mar_utils.handler_path == "":
+         gen_model.mar_utils.handler_path = os.path.join(
+             os.path.dirname(__file__),
+             HANDLER,
+         )
else:
print(
"## Please check your model name, it should be one of the following : "
"## If you want to create a model archive file with the supported models, "
"make sure you're model name is present in the below : "
)
print(list(models.keys()))
print(
"If it is a custom model and you have model files include no_download flag : "
"If you want to create a model archive file for a custom model, there "
"are two methods:\n"
"1. If you have already downloaded the custom model files, please include"
" the --no_download flag and provide the model_path directory which contains "
"the model files.\n"
"2. If you need to download the model files, provide the HuggingFace "
"repository ID along with a model_path driectory where the model "
"files are to be downloaded."
)
sys.exit(1)
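
Note: the two methods described in the message above can also be driven programmatically, the same way the test suite does it — by building an argparse.Namespace and calling run_script directly. A minimal sketch, modeled on the set_args() helper in llm/tests/test_download.py; the model name and paths are hypothetical placeholders, not values from this commit:

    import argparse

    import download  # llm/download.py

    args = argparse.Namespace(
        model_name="my_custom_model",  # hypothetical; absent from model_config.json
        output="model_output",         # hypothetical export directory for the MAR file
        model_path="model_files",      # directory that holds (or will receive) model files
        no_download=True,              # default value used by the test helper
        repo_id="",                    # empty -> method 1: use files already in model_path
        repo_version=None,
        handler_path="",               # empty -> fall back to the default handler
        hf_token=None,
        debug=False,
    )
    download.run_script(args)          # method 1: archive pre-downloaded files

    args.repo_id = "gpt2"              # method 2: download the files from this HF repo
    download.run_script(args)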

@@ -295,13 +316,13 @@ class with relevant information.
else:
if check_if_folder_empty(gen_model.mar_utils.model_path):
print(
f"\n##Error: {gen_model.model_name} model files not found"
f" in the provided path: {gen_model.mar_utils.model_path}"
f"\n##Error: {gen_model.model_name} model files for the custom"
f" model not found in the provided path: {gen_model.mar_utils.model_path}"
)
sys.exit(1)
else:
print(
f"\n## Generating MAR file for custom model files: {gen_model.model_name}"
f"\n## Generating MAR file for custom model files: {gen_model.model_name} \n"
)

create_folder_if_not_exists(gen_model.mar_utils.mar_output)
@@ -348,6 +369,13 @@ def run_script(params: argparse.Namespace) -> bool:
metavar="mn",
help="name of the model",
)
+ parser.add_argument(
+     "--repo_id",
+     type=str,
+     default="",
+     metavar="ri",
+     help="HuggingFace repository ID (In case of custom model download)",
+ )
parser.add_argument(
"--no_download", action="store_false", help="flag to not download"
)
@@ -376,7 +404,7 @@ def run_script(params: argparse.Namespace) -> bool:
parser.add_argument(
"--repo_version",
type=str,
default="",
default=None,
metavar="rv",
help="commit id of the HuggingFace Repo",
)
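
One subtlety in the flags above: --no_download is declared with action="store_false", so the parsed attribute args.no_download defaults to True and becomes False only when the flag is passed on the command line. A standalone illustration of that argparse behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--no_download", action="store_false", help="flag to not download")

    print(parser.parse_args([]).no_download)                 # True  (flag absent)
    print(parser.parse_args(["--no_download"]).no_download)  # False (flag present)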
20 changes: 16 additions & 4 deletions llm/kubeflow_inference_run.py
@@ -8,6 +8,7 @@
import os
import time
from typing import List, Dict
+ import tqdm
import utils.tsutils as ts
import utils.hf_utils as hf
from utils.system_utils import check_if_path_exists, get_all_files_in_directory
@@ -300,8 +301,15 @@ def health_check(model_name: str, deploy_name: str, model_timeout: int) -> None:
model_input = os.path.join(os.path.dirname(__file__), PATH_TO_SAMPLE)

retry_count = 0
- sleep_time = 30
+ sleep_time = 15
success = False
+ total_tries = model_timeout / sleep_time
+ progress_bar = tqdm.tqdm(
+     total=total_tries,
+     unit="check",
+     desc="Waiting for Model to be ready",
+     bar_format="{desc}: |{bar}| {n_fmt}/{total_fmt} checks",
+ )
while not success and retry_count * sleep_time < model_timeout:
success = execute_inference_on_inputs(
[model_input], model_name, deploy_name, retry=True
@@ -310,12 +318,16 @@ def health_check(model_name: str, deploy_name: str, model_timeout: int) -> None:
if not success:
time.sleep(sleep_time)
retry_count += 1
+ progress_bar.update(1)

if success:
print("## Health check passed. Model deployed.\n\n")
progress_bar.update(total_tries - retry_count)
progress_bar.close()
print("\n## Health check passed. Model deployed.\n")
else:
+ progress_bar.close()
print(
f"## Failed health check after multiple retries for model - {model_name} \n"
f"\n## Failed health check after multiple retries for model - {model_name} \n"
)
sys.exit(1)
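
The health check above polls inference and mirrors progress on a manually driven tqdm bar: update(1) per check, then a final update(total_tries - retry_count) so the bar fills on early success. The same pattern in isolation — a sketch with a stubbed check callable; the 1500s/15s numbers echo this commit's defaults:

    import time

    import tqdm

    def wait_until_ready(check, timeout=1500, sleep_time=15):
        """Poll check() until it returns True or timeout seconds elapse."""
        total_tries = timeout / sleep_time
        progress_bar = tqdm.tqdm(
            total=total_tries,
            unit="check",
            desc="Waiting for Model to be ready",
            bar_format="{desc}: |{bar}| {n_fmt}/{total_fmt} checks",
        )
        retry_count, success = 0, False
        while not success and retry_count * sleep_time < timeout:
            success = check()
            if not success:
                time.sleep(sleep_time)
            retry_count += 1
            progress_bar.update(1)
        if success:
            # jump the bar to 100% when the model comes up early
            progress_bar.update(total_tries - retry_count)
        progress_bar.close()
        return success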

@@ -377,7 +389,7 @@ def execute(params: argparse.Namespace) -> None:
create_pvc(core_api, deploy_name, storage)
create_isvc(deploy_name, model_info, deployment_resources, model_params)

print("wait for model registration to complete, will take some time")
print("\nWait for model registration to complete, will take some time. \n")
health_check(model_info["model_name"], deploy_name, model_timeout)

if input_path:
2 changes: 1 addition & 1 deletion llm/run.sh
@@ -4,7 +4,7 @@ wdir=$(dirname "$SCRIPT")

CPU_POD="8"
MEM_POD="32Gi"
MODEL_TIMEOUT_IN_SEC="1200"
MODEL_TIMEOUT_IN_SEC="1500"

function helpFunction()
{
73 changes: 71 additions & 2 deletions llm/tests/test_download.py
@@ -29,7 +29,7 @@ def set_args(
model_name="",
output="",
model_path="",
repo_version="",
repo_version=None,
handler_path="",
):
"""
@@ -50,6 +50,7 @@
args.model_path = model_path
args.no_download = True
args.repo_version = repo_version
args.repo_id = ""
args.handler_path = handler_path
args.hf_token = None
args.debug = False
@@ -250,7 +251,7 @@ def test_short_repo_version_success():
assert result is True


- def test_custom_model_success():
+ def test_custom_model_with_modelfiles_success():
"""
This function tests the custom model case.
This is done by clearing the 'model_config.json' and
@@ -285,8 +286,76 @@ def test_custom_model_no_model_files_failure():
args.no_download = False
try:
download.run_script(args)
except SystemExit as e:
custom_model_restore()
assert e.code == 1
else:
assert False


+ def test_custom_model_with_repo_id_success():
+     """
+     This function tests the custom model case where
+     model files are downloaded for the provided repo ID.
+     This is done by clearing the 'model_config.json' and
+     generating the 'GPT2' MAR file.
+     Expected result: Success.
+     """
+     model_path = custom_model_setup()
+     args = set_args(MODEL_NAME, OUTPUT, model_path)
+     args.repo_id = "gpt2"
+     try:
+         result = download.run_script(args)
+         custom_model_restore()
+     except SystemExit:
+         assert False
+     else:
+         assert result is True


+ def test_custom_model_wrong_repo_id_failure():
+     """
+     This function tests the custom model case when
+     the model repo ID is wrong.
+     Expected result: Failure.
+     """
+     model_path = custom_model_setup()
+     model_store_path = os.path.join(
+         os.path.dirname(__file__), MODEL_NAME, "model-store"
+     )
+     empty_folder(model_path)
+     empty_folder(model_store_path)
+     args = set_args(MODEL_NAME, OUTPUT, model_path)
+     args.repo_id = "wrong_repo_id"
+     try:
+         download.run_script(args)
+     except SystemExit as e:
+         custom_model_restore()
+         assert e.code == 1
+     else:
+         assert False


+ def test_custom_model_wrong_repo_version_failure():
+     """
+     This function tests the custom model case when
+     the model repo version is wrong.
+     Expected result: Failure.
+     """
+     model_path = custom_model_setup()
+     model_store_path = os.path.join(
+         os.path.dirname(__file__), MODEL_NAME, "model-store"
+     )
+     empty_folder(model_path)
+     empty_folder(model_store_path)
+     args = set_args(MODEL_NAME, OUTPUT, model_path)
+     args.repo_id = "gpt2"
+     args.repo_version = "wrong_version"
+     try:
+         download.run_script(args)
+     except SystemExit as e:
+         custom_model_restore()
+         assert e.code == 1
+     else:
+         assert False
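
Note: the failure tests above assert on SystemExit with try/except/else; pytest's raises context manager expresses the same check more compactly, should the suite ever adopt it. A sketch reusing the same helpers, not part of this commit:

    import pytest

    def test_custom_model_wrong_repo_id_failure_alt():
        model_path = custom_model_setup()
        args = set_args(MODEL_NAME, OUTPUT, model_path)
        args.repo_id = "wrong_repo_id"
        with pytest.raises(SystemExit) as excinfo:
            download.run_script(args)
        custom_model_restore()
        assert excinfo.value.code == 1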
3 changes: 3 additions & 0 deletions llm/utils/generate_data_model.py
@@ -25,6 +25,8 @@ class MarUtils:
mar_output = str()
model_path = str()
handler_path = str()
+ extra_files = str()
+ requirements_file = str()


@dataclasses.dataclass
@@ -100,6 +102,7 @@ class with values set based on the arguments.

self.mar_utils.handler_path = params.handler_path

+ self.repo_info.repo_id = params.repo_id
self.repo_info.repo_version = params.repo_version
self.repo_info.hf_token = params.hf_token

25 changes: 19 additions & 6 deletions llm/utils/hf_utils.py
@@ -7,6 +7,8 @@
from huggingface_hub.utils import (
RepositoryNotFoundError,
RevisionNotFoundError,
+ HfHubHTTPError,
+ HFValidationError,
)
from utils.generate_data_model import GenerateDataModel

@@ -34,11 +36,17 @@ class with relevant information.
token=gen_model.repo_info.hf_token,
)
return repo_files
- except (RepositoryNotFoundError, RevisionNotFoundError, KeyError):
+ except (
+     HfHubHTTPError,
+     HFValidationError,
+     RepositoryNotFoundError,
+     RevisionNotFoundError,
+     KeyError,
+ ):
print(
(
"## Error: Please check either repo_id, repo_version "
"or huggingface token is not correct"
"\n## Error: Please check either repo_id, repo_version "
"or huggingface token is not correct\n"
)
)
sys.exit(1)
@@ -68,11 +76,16 @@ def get_repo_commit_id(repo_id: str, revision: str, token: str) -> str:
token=token,
)
return commit_info[0].commit_id
- except (RepositoryNotFoundError, RevisionNotFoundError):
+ except (
+     HfHubHTTPError,
+     HFValidationError,
+     RepositoryNotFoundError,
+     RevisionNotFoundError,
+ ):
print(
(
"## Error: Please check either repo_id, repo_version "
"or huggingface token is not correct"
"\n## Error: Please check either repo_id, repo_version "
"or huggingface token is not correct\n"
)
)
sys.exit(1)
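
For context on the call these except clauses guard: get_repo_commit_id resolves a branch, tag, or short hash to a full commit id, and the broadened exception tuple now also catches HTTP failures and malformed repo IDs. A standalone sketch of one way to do the same resolution with huggingface_hub — assuming HfApi.list_repo_commits, which this helper appears to wrap; "gpt2" is only an example repo:

    from huggingface_hub import HfApi
    from huggingface_hub.utils import (
        HfHubHTTPError,
        HFValidationError,
        RepositoryNotFoundError,
        RevisionNotFoundError,
    )

    def resolve_commit_id(repo_id: str, revision: str, token: str = None) -> str:
        """Map a branch/tag/short-hash revision to the full commit id."""
        try:
            commits = HfApi().list_repo_commits(repo_id, revision=revision, token=token)
            return commits[0].commit_id  # most recent commit at that revision
        except (HfHubHTTPError, HFValidationError,
                RepositoryNotFoundError, RevisionNotFoundError):
            raise SystemExit("check that repo_id, repo_version and hf_token are correct")

    # resolve_commit_id("gpt2", "main") -> the 40-character commit hash of main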