From 88bdc91b75627da16cfb3d54fc8aa973031f7b01 Mon Sep 17 00:00:00 2001
From: markus583
Date: Tue, 10 Sep 2024 14:46:15 +0200
Subject: [PATCH] add ONNX support

---
 README.md                     | 30 ++++++++++++++
 scripts/export_to_onnx_sat.py | 74 ++++++++++++++++++++++++++++++-----
 setup.py                      |  2 +-
 test.py                       |  8 ++--
 wtpsplit/__init__.py          | 63 ++++++++++++++++-------------
 wtpsplit/extract.py           | 10 ++---
 6 files changed, 140 insertions(+), 47 deletions(-)

diff --git a/README.md b/README.md
index d01ed335..a0f04d69 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,36 @@ sat_adapted.split("This is a test This is another test.")
 # returns ['This is a test ', 'This is another test']
 ```
 
+## ONNX Support
+🚀 You can now enable even faster ONNX inference for `sat` and `sat-sm` models! 🚀
+
+```python
+sat = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"])
+```
+
+```python
+>>> from wtpsplit import SaT
+>>> texts = ["This is a sentence. This is another sentence."] * 1000
+
+# PyTorch GPU
+>>> model = SaT("sat-3l-sm")
+>>> model.half().to("cuda")
+>>> %timeit list(model.split(texts))
+138 ms ± 8.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
+
+# onnxruntime GPU
+>>> model = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"])
+>>> %timeit list(model.split(texts))
+198 ms ± 1.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
+```
+
+If you wish to use LoRA in combination with an ONNX model:
+- Run `scripts/export_to_onnx_sat.py` with `use_lora: True` and an appropriate `output_dir`.
+  - If you have a local LoRA module, use `lora_path`.
+  - If you wish to load a LoRA module from the HuggingFace Hub, use `style_or_domain` and `language`.
+- Load the ONNX model with merged LoRA weights:
+`sat = SaT("<output_dir>", ort_providers=["CUDAExecutionProvider"])`
+
 ## Available Models
 
diff --git a/scripts/export_to_onnx_sat.py b/scripts/export_to_onnx_sat.py
index dd82baa5..7727304a 100644
--- a/scripts/export_to_onnx_sat.py
+++ b/scripts/export_to_onnx_sat.py
@@ -1,21 +1,32 @@
 from dataclasses import dataclass
 from pathlib import Path
 
+import adapters  # noqa
 import onnx
 import torch
+from adapters.models import MODEL_MIXIN_MAPPING  # noqa
+from adapters.models.bert.mixin_bert import BertModelAdaptersMixin  # noqa
+from huggingface_hub import hf_hub_download
 from onnxruntime.transformers.optimizer import optimize_model  # noqa
 from transformers import AutoModelForTokenClassification, HfArgumentParser
 
 import wtpsplit  # noqa
 import wtpsplit.models  # noqa
+from wtpsplit.utils import Constants
+
+MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin
 
 
 @dataclass
 class Args:
-    model_name_or_path: str = "segment-any-text/sat-1l-sm"
-    output_dir: str = "sat-1l-sm"
+    model_name_or_path: str = "segment-any-text/sat-1l-sm"  # model from HF Hub: https://huggingface.co/segment-any-text
+    output_dir: str = "sat-1l-sm"  # output directory, saves to current directory
     device: str = "cuda"
-    # TODO: lora merging here
+    use_lora: bool = False
+    lora_path: str = None  # local path to lora weights
+    # otherwise, fetch from HF Hub:
+    style_or_domain: str = "ud"
+    language: str = "en"
 
 
 if __name__ == "__main__":
@@ -25,25 +36,70 @@ class Args:
     output_dir.mkdir(exist_ok=True, parents=True)
 
     model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, force_download=False)
-    model = model.half()  # CUDA ONLY!
+    model = model.to(args.device)
+    # fetch config.json from huggingface hub
+    hf_hub_download(
+        repo_id=args.model_name_or_path,
+        filename="config.json",
+        local_dir=output_dir,
+    )
+
+    # LoRA SETUP
+    if args.use_lora:
+        # adapters need xlm-roberta as model type.
+        model_type = model.config.model_type
+        model.config.model_type = "xlm-roberta"
+        adapters.init(model)
+        # reset model type (used later)
+        model.config.model_type = model_type
+        if not args.lora_path:
+            for file in [
+                "adapter_config.json",
+                "head_config.json",
+                "pytorch_adapter.bin",
+                "pytorch_model_head.bin",
+            ]:
+                hf_hub_download(
+                    repo_id=args.model_name_or_path,
+                    subfolder=f"loras/{args.style_or_domain}/{args.language}",
+                    filename=file,
+                    local_dir=Constants.CACHE_DIR,
+                )
+            lora_load_path = str(Constants.CACHE_DIR / "loras" / args.style_or_domain / args.language)
+        else:
+            lora_load_path = args.lora_path
+
+        print(f"Using LoRA weights from {lora_load_path}.")
+        model.load_adapter(
+            lora_load_path,
+            set_active=True,
+            with_head=True,
+            load_as="sat-lora",
+        )
+        # merge lora weights into transformer for 0 efficiency overhead
+        model.merge_adapter("sat-lora")
+        print("LoRA setup done.")
+    # LoRA setup done, model is now ready for export.
+
+    model = model.half()
+
     torch.onnx.export(
         model,
         {
-            "attention_mask": torch.zeros((1, 1), dtype=torch.float16, device=args.device),
-            "input_ids": torch.zeros((1, 1), dtype=torch.int64, device=args.device),
+            "attention_mask": torch.randn((1, 1), dtype=torch.float16, device=args.device),
+            "input_ids": torch.randint(0, 250002, (1, 1), dtype=torch.int64, device=args.device),
         },
         output_dir / "model.onnx",
         verbose=True,
         input_names=["attention_mask", "input_ids"],
         output_names=["logits"],
         dynamic_axes={
-            "input_ids": {0: "batch", 1: "sequence"},
             "attention_mask": {0: "batch", 1: "sequence"},
+            "input_ids": {0: "batch", 1: "sequence"},
             "logits": {0: "batch", 1: "sequence"},
         },
-        # opset_version=11
     )
 
     m = optimize_model(
@@ -60,4 +116,4 @@ class Args:
     onnx.save_model(m.model, optimized_model_path)
 
     onnx_model = onnx.load(output_dir / "model.onnx")
-    onnx.checker.check_model(onnx_model, full_check=True)
+    print(onnx.checker.check_model(onnx_model, full_check=True))
diff --git a/setup.py b/setup.py
index 269a7f0f..132081f1 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="wtpsplit",
-    version="2.0.8",
+    version="2.1.0",
     packages=find_packages(),
     description="Universal Robust, Efficient and Adaptable Sentence Segmentation",
     author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer",
diff --git a/test.py b/test.py
index 40d44a0f..25e86ac3 100644
--- a/test.py
+++ b/test.py
@@ -2,11 +2,11 @@
 from wtpsplit import WtP, SaT
 
 
-# def test_split_ort():
-#     sat = SaT("segment-any-text/sat-3l", ort_providers=["CPUExecutionProvider"])
+def test_split_ort():
+    sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])
 
-#     splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.005)
-#     assert splits == ["This is a test sentence ", "This is another test sentence."]
+    splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.25)
+    assert splits == ["This is a test sentence ", "This is another test sentence."]
 
 
 def test_split_torch():
diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py
index a98c6dfc..cf643a14 100644
--- a/wtpsplit/__init__.py
+++ b/wtpsplit/__init__.py
@@ -15,10 +15,10 @@
 from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
 from transformers.utils.hub import cached_file
 
-from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, extract
+from wtpsplit.extract import BertCharORTWrapper, SaTORTWrapper, PyTorchWrapper, extract
 from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs
 
-__version__ = "2.0.8"
+__version__ = "2.1.0"
 
 warnings.simplefilter("default", DeprecationWarning)  # show by default
 warnings.simplefilter("ignore", category=FutureWarning)  # for tranformers
@@ -88,8 +88,6 @@ def __init__(
 
         try:
             import onnxruntime as ort  # noqa
-
-            ort.set_default_logger_severity(0)
         except ModuleNotFoundError:
             raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.")
 
@@ -449,38 +447,39 @@ def __init__(
 
         if is_local:
            model_path = Path(model_name)
-            onnx_path = model_path / "model.onnx"
+            onnx_path = model_path / "model_optimized.onnx"
            if not onnx_path.exists():
                onnx_path = None
        else:
            # no need to load if no ort_providers set
            if ort_providers is not None:
-                onnx_path = cached_file(model_name_to_fetch, "model.onnx", **(from_pretrained_kwargs or {}))
+                onnx_path = cached_file(model_name_to_fetch, "model_optimized.onnx", **(from_pretrained_kwargs or {}))
            else:
                onnx_path = None
 
        if ort_providers is not None:
-            raise NotImplementedError("ONNX is not supported for SaT *yet*.")
-            # if onnx_path is None:
-            #     raise ValueError(
-            #         "Could not find an ONNX model in the model directory. Try `use_ort=False` to run with PyTorch."
-            #     )
-
-            # try:
-            #     import onnxruntime as ort  # noqa
-
-            #     ort.set_default_logger_severity(0)
-            # except ModuleNotFoundError:
-            #     raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.")
-
-            # # to register models for AutoConfig
-            # import wtpsplit.configs  # noqa
-
-            # # TODO: ONNX integration
-            # self.model = SaTORTWrapper(
-            #     AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})),
-            #     ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})),
-            # )
+            if onnx_path is None:
+                raise ValueError(
+                    "Could not find an ONNX model in the model directory. Try `use_ort=False` to run with PyTorch."
+                )
+
+            try:
+                import onnxruntime as ort  # noqa
+            except ModuleNotFoundError:
+                raise ValueError("Please install `onnxruntime` to use SaT with an ONNX model.")
+
+            # to register models for AutoConfig
+            import wtpsplit.configs  # noqa
+
+            self.model = SaTORTWrapper(
+                AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})),
+                ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})),
+            )
+            if lora_path:
+                raise ValueError(
+                    "If using ONNX with LoRA, execute `scripts/export_to_onnx_sat.py` with `use_lora=True`. "
+                    "Then pass the chosen `output_dir` here as `model_name_or_model` and set `lora_path=None`."
+                )
        else:
            # to register models for AutoConfig
            try:
@@ -496,7 +495,6 @@ def __init__(
                )
            )
            # LoRA LOADING
-            # TODO: LoRA + ONNX ?
            if not lora_path:
                if (style_or_domain and not language) or (language and not style_or_domain):
                    raise ValueError("Please specify both language and style_or_domain!")
@@ -792,3 +790,12 @@ def get_default_threshold(model_str: str):
                text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace
            )
            yield sentences
+
+
+if __name__ == "__main__":
+    sat = SaT("sat-3l-lora", ort_providers=["CPUExecutionProvider"])
+    print(sat.split("Hello, World! Next."))
+
+    wtp = WtP("wtp-bert-tiny", ort_providers=["CPUExecutionProvider"])
Next.")) + print("DONE!") diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index 6bdf1411..a3c93880 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -43,11 +43,11 @@ def __getattr__(self, name): def __call__(self, input_ids, attention_mask): logits = self.ort_session.run( - output_names=["logits"], - input_feed={ - "attention_mask": attention_mask.astype(np.int64), - "input_ids": input_ids.astype(np.float16), - }, # .astype(np.int64)}, + ["logits"], + { + self.ort_session.get_inputs()[0].name: input_ids.astype(np.int64), + self.ort_session.get_inputs()[1].name: attention_mask.astype(np.float16), + }, )[0] return {"logits": logits}