Add recipe check vllm e2e #929

Open · wants to merge 12 commits into base: main
10 changes: 8 additions & 2 deletions src/llmcompressor/core/session.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

from loguru import logger

from llmcompressor.core.events import EventType
from llmcompressor.core.helpers import log_model_info, should_log_model_info
from llmcompressor.core.lifecycle import CompressionLifecycle
@@ -260,12 +262,16 @@ def reset_stage(self):
self.lifecycle.initialized_ = False
self.lifecycle.finalized = False

def get_serialized_recipe(self) -> str:
def get_serialized_recipe(self) -> Optional[str]:
"""
:return: serialized string of the current compiled recipe
"""
recipe = self.lifecycle.recipe_container.compiled_recipe
return recipe.yaml()

if recipe is not None and hasattr(recipe, "yaml"):
return recipe.yaml()

logger.warning("Recipe not found in session - it may have been reset")

def _log_model_info(self):
# Log model level logs if cadence reached
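With this change, get_serialized_recipe() returns None (and logs a warning) instead of raising an AttributeError when the session holds no compiled recipe, so callers now need to guard the result. A minimal sketch of the calling pattern (write_recipe_if_present is a hypothetical helper; active_session and get_serialized_recipe come from the diff above):

import os

from llmcompressor.core import active_session


def write_recipe_if_present(save_directory: str) -> bool:
    """Persist the session recipe to disk; skip quietly when absent."""
    session = active_session()
    recipe_yaml = session.get_serialized_recipe()
    if recipe_yaml is None:
        # Session was reset (or oneshot never ran) - nothing to write.
        return False
    with open(os.path.join(save_directory, "recipe.yaml"), "w") as fp:
        fp.write(recipe_yaml)
    return True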
@@ -191,9 +191,10 @@ def skip(*args, **kwargs):

recipe_path = os.path.join(save_directory, "recipe.yaml")
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

if (recipe_yaml_str := session.get_serialized_recipe()) is not None:
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_directory)
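Because the write is now guarded, recipe.yaml is only present in a checkpoint when a recipe was active at save time. A standalone sketch for sanity-checking a saved directory (recipe_was_saved is a hypothetical helper; assumes PyYAML is installed):

import os

import yaml


def recipe_was_saved(save_directory: str) -> bool:
    """Return True when save_directory contains a parseable recipe.yaml."""
    recipe_path = os.path.join(save_directory, "recipe.yaml")
    if not os.path.isfile(recipe_path):
        return False
    with open(recipe_path) as fp:
        # safe_load returns None for an empty file, which we treat as missing.
        return yaml.safe_load(fp) is not None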
78 changes: 67 additions & 11 deletions tests/e2e/vLLM/test_vllm.py
@@ -1,12 +1,16 @@
import os
import re
import shutil
import unittest
from typing import Callable

import pytest
from datasets import load_dataset
from loguru import logger
from parameterized import parameterized, parameterized_class
from transformers import AutoTokenizer

from llmcompressor.core import active_session
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from tests.testing_utils import (
@@ -22,6 +26,7 @@
vllm_installed = True
except ImportError:
vllm_installed = False
logger.warning("vllm is not installed. This test will be skipped")

# Defines the file paths to the directories containing the test configs
# for each of the quantization schemes
@@ -32,6 +37,15 @@
WNA16_2of4 = "tests/e2e/vLLM/configs/WNA16_2of4"
CONFIGS = [WNA16, FP8, INT8, ACTORDER, WNA16_2of4]

HF_MODEL_HUB_NAME = "nm-testing"

EXPECTED_SAVED_FILES = [
"config.json",
r"^model(?:-\d{5}-of-\d{5})?\.safetensors$",
"recipe.yaml",
"tokenizer.json",
Collaborator Author: any other default file suggestions welcome here
]

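The safetensors entry in EXPECTED_SAVED_FILES is a regular expression rather than a literal filename because Hugging Face shards large checkpoints into model-XXXXX-of-XXXXX.safetensors files. A quick standalone illustration of what the pattern does and does not accept (filenames are illustrative):

import re

SAFETENSORS_RE = r"^model(?:-\d{5}-of-\d{5})?\.safetensors$"

# Both the single-file and the sharded layout match:
assert re.fullmatch(SAFETENSORS_RE, "model.safetensors")
assert re.fullmatch(SAFETENSORS_RE, "model-00001-of-00002.safetensors")
# The shard index file does not:
assert re.fullmatch(SAFETENSORS_RE, "model.safetensors.index.json") is None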

def gen_test_name(testcase_func: Callable, param_num: int, param: dict) -> str:
return "_".join(
@@ -76,8 +90,8 @@ class TestvLLM(unittest.TestCase):
save_dir = None

def setUp(self):
print("========== RUNNING ==============")
print(self.scheme)
logger.info("========== RUNNING ==============")
logger.debug(self.scheme)

self.device = "cuda:0"
self.oneshot_kwargs = {}
@@ -124,34 +138,76 @@ def test_vllm(self):
)

# Apply quantization.
print("ONESHOT KWARGS", self.oneshot_kwargs)
logger.debug("ONESHOT KWARGS", self.oneshot_kwargs)
oneshot(
**self.oneshot_kwargs,
clear_sparse_session=True,
oneshot_device=self.device,
)

# check that session contains recipe
self._check_session_contains_recipe()

self.oneshot_kwargs["model"].save_pretrained(self.save_dir)
tokenizer.save_pretrained(self.save_dir)

# check that expected files exist
self._check_save_dir_has_expected_files()

# Run vLLM with saved model
print("================= RUNNING vLLM =========================")
logger.info("================= RUNNING vLLM =========================")
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)
if "W4A16_2of4" in self.scheme:
# required by the kernel
llm = LLM(model=self.save_dir, dtype=torch.float16)
else:
llm = LLM(model=self.save_dir)
outputs = llm.generate(self.prompts, sampling_params)
print("================= vLLM GENERATION ======================")

logger.info("================= vLLM GENERATION ======================")
for output in outputs:
assert output
prompt = output.prompt
generated_text = output.outputs[0].text
print("PROMPT", prompt)
print("GENERATED TEXT", generated_text)
print("================= UPLOADING TO HUB ======================")
self.oneshot_kwargs["model"].push_to_hub(f"nm-testing/{self.save_dir}-e2e")
tokenizer.push_to_hub(f"nm-testing/{self.save_dir}-e2e")
logger.debug("PROMPT", prompt)
logger.debug("GENERATED TEXT", generated_text)

logger.info("================= UPLOADING TO HUB ======================")
hf_upload_path = os.path.join(HF_MODEL_HUB_NAME, f"{self.save_dir}-e2e")
self.oneshot_kwargs["model"].push_to_hub(hf_upload_path)
tokenizer.push_to_hub(hf_upload_path)

def tearDown(self):
if self.save_dir is not None:
shutil.rmtree(self.save_dir)

def _check_session_contains_recipe(self) -> None:
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
assert recipe_yaml_str is not None

def _check_save_dir_has_expected_files(self):
files = os.listdir(self.save_dir)
logger.debug("Saved files: ", files)

matched_patterns = set()

for expected in EXPECTED_SAVED_FILES:
# Find all files matching the expected pattern
matches = [
file
for file in files
if (
re.fullmatch(expected, file)
if expected.startswith("^")
else file == expected
)
]
if matches:  # a non-empty list means some file satisfied this pattern
matched_patterns.add(expected)

assert len(matched_patterns) == len(EXPECTED_SAVED_FILES), (
f"expected: {EXPECTED_SAVED_FILES}\nsaved: {sorted(matched_patterns)}"
)
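The literal-versus-regex matching in _check_save_dir_has_expected_files is easy to exercise on its own. A minimal standalone sketch of the same logic (find_missing is a hypothetical name; the directory listing is made up):

import re


def find_missing(expected_files, actual_files):
    """Return expected entries (literals or ^-anchored regexes) with no match."""
    missing = []
    for expected in expected_files:
        if expected.startswith("^"):
            matched = any(re.fullmatch(expected, f) for f in actual_files)
        else:
            matched = expected in actual_files
        if not matched:
            missing.append(expected)
    return missing


# A sharded checkpoint without tokenizer.json should flag exactly one miss.
saved = ["config.json", "model-00001-of-00002.safetensors", "recipe.yaml"]
expected = ["config.json", r"^model(?:-\d{5}-of-\d{5})?\.safetensors$", "tokenizer.json"]
assert find_missing(expected, saved) == ["tokenizer.json"]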