From 88bdc91b75627da16cfb3d54fc8aa973031f7b01 Mon Sep 17 00:00:00 2001 From: markus583 Date: Tue, 10 Sep 2024 14:46:15 +0200 Subject: [PATCH 1/4] add ONNX support --- README.md | 30 ++++++++++++++ scripts/export_to_onnx_sat.py | 74 ++++++++++++++++++++++++++++++----- setup.py | 2 +- test.py | 8 ++-- wtpsplit/__init__.py | 63 ++++++++++++++++------------- wtpsplit/extract.py | 10 ++--- 6 files changed, 140 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index d01ed335..a0f04d69 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,36 @@ sat_adapted.split("This is a test This is another test.") # returns ['This is a test ', 'This is another test'] ``` +## ONNX Support +🚀 You can now enable even faster ONNX inference for `sat` and `sat-sm` models! 🚀 + +```python +sat = SaT("sat-3l-sm", onnx_providers=["CUDAExecutionProvider"]) +``` + +```python +>>> from wtpsplit import SaT +>>> texts = ["This is a sentence. This is another sentence."] * 1000 + +# PyTorch GPU +>>> model = SaT("sat-3l-sm") +>>> model.half().to("cuda") +>>> %timeit list(model.split(texts)) +138 ms ± 8.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) + +# onnxruntime GPU +>>> model = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"]) +>>> %timeit list(model.split(texts)) +198 ms ± 1.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) +``` + +If you wish to use LoRA in combination with an ONNX model: +- Run `scripts/export_to_onnx_sat.py` with `use_lora: True` and an appropriate `output_dir: `. + - If you have a local LoRA module, use `lora_path`. + - If you wish to load a LoRA module from the HuggingFace hub, use `style_or_domain` and `language`. 
+- Load the ONNX model with merged LoRA weights: +`sat = SaT(, onnx_providers=["CUDAExecutionProvider"])` + ## Available Models diff --git a/scripts/export_to_onnx_sat.py b/scripts/export_to_onnx_sat.py index dd82baa5..7727304a 100644 --- a/scripts/export_to_onnx_sat.py +++ b/scripts/export_to_onnx_sat.py @@ -1,21 +1,32 @@ from dataclasses import dataclass from pathlib import Path +import adapters # noqa import onnx import torch +from adapters.models import MODEL_MIXIN_MAPPING # noqa +from adapters.models.bert.mixin_bert import BertModelAdaptersMixin # noqa +from huggingface_hub import hf_hub_download from onnxruntime.transformers.optimizer import optimize_model # noqa from transformers import AutoModelForTokenClassification, HfArgumentParser import wtpsplit # noqa import wtpsplit.models # noqa +from wtpsplit.utils import Constants + +MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin @dataclass class Args: - model_name_or_path: str = "segment-any-text/sat-1l-sm" - output_dir: str = "sat-1l-sm" + model_name_or_path: str = "segment-any-text/sat-1l-sm" # model from HF Hub: https://huggingface.co/segment-any-text + output_dir: str = "sat-1l-sm" # output directory, saves to current directory device: str = "cuda" - # TODO: lora merging here + use_lora: bool = False + lora_path: str = None # local path to lora weights + # otherwise, fetch from HF Hub: + style_or_domain: str = "ud" + language: str = "en" if __name__ == "__main__": @@ -25,25 +36,70 @@ class Args: output_dir.mkdir(exist_ok=True, parents=True) model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, force_download=False) - model = model.half() # CUDA ONLY! + model = model.to(args.device) + # fetch config.json from huggingface hub + hf_hub_download( + repo_id=args.model_name_or_path, + filename="config.json", + local_dir=output_dir, + ) + + # LoRA SETUP + if args.use_lora: + # adapters need xlm-roberta as model type. 
+ model_type = model.config.model_type + model.config.model_type = "xlm-roberta" + adapters.init(model) + # reset model type (used later) + model.config.model_type = model_type + if not args.lora_path: + for file in [ + "adapter_config.json", + "head_config.json", + "pytorch_adapter.bin", + "pytorch_model_head.bin", + ]: + hf_hub_download( + repo_id=args.model_name_or_path, + subfolder=f"loras/{args.style_or_domain}/{args.language}", + filename=file, + local_dir=Constants.CACHE_DIR, + ) + lora_load_path = str(Constants.CACHE_DIR / "loras" / args.style_or_domain / args.language) + else: + lora_load_path = args.lora_path + + print(f"Using LoRA weights from {lora_load_path}.") + model.load_adapter( + lora_load_path, + set_active=True, + with_head=True, + load_as="sat-lora", + ) + # merge lora weights into transformer for 0 efficiency overhead + model.merge_adapter("sat-lora") + print("LoRA setup done.") + # LoRA setup done, model is now ready for export. + + model = model.half() + torch.onnx.export( model, { - "attention_mask": torch.zeros((1, 1), dtype=torch.float16, device=args.device), - "input_ids": torch.zeros((1, 1), dtype=torch.int64, device=args.device), + "attention_mask": torch.randn((1, 1), dtype=torch.float16, device=args.device), + "input_ids": torch.randint(0, 250002, (1, 1), dtype=torch.int64, device=args.device), }, output_dir / "model.onnx", verbose=True, input_names=["attention_mask", "input_ids"], output_names=["logits"], dynamic_axes={ - "input_ids": {0: "batch", 1: "sequence"}, "attention_mask": {0: "batch", 1: "sequence"}, + "input_ids": {0: "batch", 1: "sequence"}, "logits": {0: "batch", 1: "sequence"}, }, - # opset_version=11 ) m = optimize_model( @@ -60,4 +116,4 @@ class Args: onnx.save_model(m.model, optimized_model_path) onnx_model = onnx.load(output_dir / "model.onnx") - onnx.checker.check_model(onnx_model, full_check=True) + print(onnx.checker.check_model(onnx_model, full_check=True)) diff --git a/setup.py b/setup.py index 
269a7f0f..132081f1 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="wtpsplit", - version="2.0.8", + version="2.1.0", packages=find_packages(), description="Universal Robust, Efficient and Adaptable Sentence Segmentation", author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer", diff --git a/test.py b/test.py index 40d44a0f..25e86ac3 100644 --- a/test.py +++ b/test.py @@ -2,11 +2,11 @@ from wtpsplit import WtP, SaT -# def test_split_ort(): -# sat = SaT("segment-any-text/sat-3l", ort_providers=["CPUExecutionProvider"]) +def test_split_ort(): + sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) -# splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.005) -# assert splits == ["This is a test sentence ", "This is another test sentence."] + splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.25) + assert splits == ["This is a test sentence ", "This is another test sentence."] def test_split_torch(): diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py index a98c6dfc..cf643a14 100644 --- a/wtpsplit/__init__.py +++ b/wtpsplit/__init__.py @@ -15,10 +15,10 @@ from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer from transformers.utils.hub import cached_file -from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, extract +from wtpsplit.extract import BertCharORTWrapper, SaTORTWrapper, PyTorchWrapper, extract from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs -__version__ = "2.0.8" +__version__ = "2.1.0" warnings.simplefilter("default", DeprecationWarning) # show by default warnings.simplefilter("ignore", category=FutureWarning) # for tranformers @@ -88,8 +88,6 @@ def __init__( try: import onnxruntime as ort # noqa - - ort.set_default_logger_severity(0) except ModuleNotFoundError: raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.") @@ -449,38 +447,39 @@ def 
__init__( if is_local: model_path = Path(model_name) - onnx_path = model_path / "model.onnx" + onnx_path = model_path / "model_optimized.onnx" if not onnx_path.exists(): onnx_path = None else: # no need to load if no ort_providers set if ort_providers is not None: - onnx_path = cached_file(model_name_to_fetch, "model.onnx", **(from_pretrained_kwargs or {})) + onnx_path = cached_file(model_name_to_fetch, "model_optimized.onnx", **(from_pretrained_kwargs or {})) else: onnx_path = None if ort_providers is not None: - raise NotImplementedError("ONNX is not supported for SaT *yet*.") - # if onnx_path is None: - # raise ValueError( - # "Could not find an ONNX model in the model directory. Try `use_ort=False` to run with PyTorch." - # ) - - # try: - # import onnxruntime as ort # noqa - - # ort.set_default_logger_severity(0) - # except ModuleNotFoundError: - # raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.") - - # # to register models for AutoConfig - # import wtpsplit.configs # noqa - - # # TODO: ONNX integration - # self.model = SaTORTWrapper( - # AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})), - # ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})), - # ) + if onnx_path is None: + raise ValueError( + "Could not find an ONNX model in the model directory. Try `use_ort=False` to run with PyTorch." + ) + + try: + import onnxruntime as ort # noqa + except ModuleNotFoundError: + raise ValueError("Please install `onnxruntime` to use SaT with an ONNX model.") + + # to register models for AutoConfig + import wtpsplit.configs # noqa + + self.model = SaTORTWrapper( + AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})), + ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})), + ) + if lora_path: + raise ValueError( + "If using ONNX with LoRA, execute `scripts/export_to_onnx_sat.py` with `use_lora=True`." 
+ "Reference the chosen `output_dir` here for `model_name_or_model`. and set `lora_path=None`." + ) else: # to register models for AutoConfig try: @@ -496,7 +495,6 @@ def __init__( ) ) # LoRA LOADING - # TODO: LoRA + ONNX ? if not lora_path: if (style_or_domain and not language) or (language and not style_or_domain): raise ValueError("Please specify both language and style_or_domain!") @@ -792,3 +790,12 @@ def get_default_threshold(model_str: str): text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace ) yield sentences + + +if __name__ == "__main__": + sat = SaT("sat-3l-lora", ort_providers=["CPUExecutionProvider"]) + print(sat.split("Hello, World! Next.")) + + wtp = WtP("wtp-bert-tiny", ort_providers=["CPUExecutionProvider"]) + print(wtp.split("Hello, World! Next.")) + print("DONE!") diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index 6bdf1411..a3c93880 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -43,11 +43,11 @@ def __getattr__(self, name): def __call__(self, input_ids, attention_mask): logits = self.ort_session.run( - output_names=["logits"], - input_feed={ - "attention_mask": attention_mask.astype(np.int64), - "input_ids": input_ids.astype(np.float16), - }, # .astype(np.int64)}, + ["logits"], + { + self.ort_session.get_inputs()[0].name: input_ids.astype(np.int64), + self.ort_session.get_inputs()[1].name: attention_mask.astype(np.float16), + }, )[0] return {"logits": logits} From 1e155602fadb463099ef16dd0ce26d86f11fd1ee Mon Sep 17 00:00:00 2001 From: markus583 Date: Tue, 10 Sep 2024 15:03:56 +0200 Subject: [PATCH 2/4] remove if __name__ --- wtpsplit/__init__.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py index cf643a14..538c2bfe 100644 --- a/wtpsplit/__init__.py +++ b/wtpsplit/__init__.py @@ -453,7 +453,9 @@ def __init__( else: # no need to load if no ort_providers set if ort_providers is not None: - onnx_path = 
cached_file(model_name_to_fetch, "model_optimized.onnx", **(from_pretrained_kwargs or {})) + onnx_path = cached_file( + model_name_to_fetch, "model_optimized.onnx", **(from_pretrained_kwargs or {}) + ) else: onnx_path = None @@ -790,12 +792,3 @@ def get_default_threshold(model_str: str): text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace ) yield sentences - - -if __name__ == "__main__": - sat = SaT("sat-3l-lora", ort_providers=["CPUExecutionProvider"]) - print(sat.split("Hello, World! Next.")) - - wtp = WtP("wtp-bert-tiny", ort_providers=["CPUExecutionProvider"]) - print(wtp.split("Hello, World! Next.")) - print("DONE!") From b9198d076b43a3769494fb0ee2682096fdf37ded Mon Sep 17 00:00:00 2001 From: markus583 Date: Sun, 22 Sep 2024 15:22:08 +0200 Subject: [PATCH 3/4] remove exact timings. --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a0f04d69..914c56a5 100644 --- a/README.md +++ b/README.md @@ -60,13 +60,13 @@ sat = SaT("sat-3l-sm", onnx_providers=["CUDAExecutionProvider"]) # PyTorch GPU >>> model = SaT("sat-3l-sm") >>> model.half().to("cuda") ->>> %timeit list(model.split(texts)) -138 ms ± 8.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) +>>> list(model.split(texts)) +# quite fast already, but... # onnxruntime GPU >>> model = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"]) >>> %timeit list(model.split(texts)) -198 ms ± 1.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) +# ...this should be ~50% faster! 
``` If you wish to use LoRA in combination with an ONNX model: From b549ab6999ad4f1aae6bdf665e2d6eb0c298fed6 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 24 Sep 2024 13:39:41 +0100 Subject: [PATCH 4/4] fix onnx input ordering, add onnxruntime-gpu timings --- README.md | 18 ++++++++++-------- scripts/export_all_to_onnx.sh | 23 +++++++++++++++++++++++ scripts/export_to_onnx_sat.py | 23 +++++++++++++++++++---- wtpsplit/extract.py | 4 ++-- 4 files changed, 54 insertions(+), 14 deletions(-) create mode 100644 scripts/export_all_to_onnx.sh diff --git a/README.md b/README.md index 914c56a5..76f63adc 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ sat_adapted.split("This is a test This is another test.") 🚀 You can now enable even faster ONNX inference for `sat` and `sat-sm` models! 🚀 ```python -sat = SaT("sat-3l-sm", onnx_providers=["CUDAExecutionProvider"]) +sat = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) ``` ```python @@ -58,15 +58,17 @@ sat = SaT("sat-3l-sm", onnx_providers=["CUDAExecutionProvider"]) >>> texts = ["This is a sentence. This is another sentence."] * 1000 # PyTorch GPU ->>> model = SaT("sat-3l-sm") ->>> model.half().to("cuda") ->>> list(model.split(texts)) +>>> model_pytorch = SaT("sat-3l-sm") +>>> model_pytorch.half().to("cuda"); +>>> %timeit list(model_pytorch.split(texts)) +# 144 ms ± 252 μs per loop (mean ± std. dev. of 7 runs, 10 loops each) # quite fast already, but... # onnxruntime GPU ->>> model = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"]) ->>> %timeit list(model.split(texts)) -# ...this should be ~50% faster! +>>> model_ort = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) +>>> %timeit list(model_ort.split(texts)) +# 94.9 ms ± 165 μs per loop (mean ± std. dev. of 7 runs, 10 loops each) +# ...this should be ~50% faster! 
(tested on RTX 3090) ``` If you wish to use LoRA in combination with an ONNX model: @@ -74,7 +76,7 @@ If you wish to use LoRA in combination with an ONNX model: - If you have a local LoRA module, use `lora_path`. - If you wish to load a LoRA module from the HuggingFace hub, use `style_or_domain` and `language`. - Load the ONNX model with merged LoRA weights: -`sat = SaT(, onnx_providers=["CUDAExecutionProvider"])` +`sat = SaT(OUTPUT_DIR, ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])` ## Available Models diff --git a/scripts/export_all_to_onnx.sh b/scripts/export_all_to_onnx.sh new file mode 100644 index 00000000..7c21a00f --- /dev/null +++ b/scripts/export_all_to_onnx.sh @@ -0,0 +1,23 @@ +# all models in manually defined array of models +models=( + "sat-1l-sm" + "sat-3l-sm" + "sat-6l-sm" + "sat-9l-sm" + "sat-12l-sm" + "sat-1l" + "sat-3l" + "sat-6l" + "sat-9l" + "sat-12l" + "sat-1l-no-limited-lookahead" + "sat-3l-no-limited-lookahead" + "sat-6l-no-limited-lookahead" + "sat-9l-no-limited-lookahead" + "sat-12l-no-limited-lookahead" +) + +for model in "${models[@]}" +do + python scripts/export_to_onnx_sat.py --model_name_or_path=segment-any-text/$model --output_dir=output_onnx_exports/$model --upload_to_hub=True +done \ No newline at end of file diff --git a/scripts/export_to_onnx_sat.py b/scripts/export_to_onnx_sat.py index 7727304a..c6fadf24 100644 --- a/scripts/export_to_onnx_sat.py +++ b/scripts/export_to_onnx_sat.py @@ -6,7 +6,7 @@ import torch from adapters.models import MODEL_MIXIN_MAPPING # noqa from adapters.models.bert.mixin_bert import BertModelAdaptersMixin # noqa -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, HfApi from onnxruntime.transformers.optimizer import optimize_model # noqa from transformers import AutoModelForTokenClassification, HfArgumentParser @@ -27,6 +27,7 @@ class Args: # otherwise, fetch from HF Hub: style_or_domain: str = "ud" language: str = "en" + upload_to_hub: bool = False if 
__name__ == "__main__": @@ -88,16 +89,16 @@ class Args: torch.onnx.export( model, { + "input_ids": torch.randint(0, model.config.vocab_size, (1, 1), dtype=torch.int64, device=args.device), "attention_mask": torch.randn((1, 1), dtype=torch.float16, device=args.device), - "input_ids": torch.randint(0, 250002, (1, 1), dtype=torch.int64, device=args.device), }, output_dir / "model.onnx", verbose=True, - input_names=["attention_mask", "input_ids"], + input_names=["input_ids", "attention_mask"], output_names=["logits"], dynamic_axes={ - "attention_mask": {0: "batch", 1: "sequence"}, "input_ids": {0: "batch", 1: "sequence"}, + "attention_mask": {0: "batch", 1: "sequence"}, "logits": {0: "batch", 1: "sequence"}, }, ) @@ -117,3 +118,17 @@ class Args: onnx_model = onnx.load(output_dir / "model.onnx") print(onnx.checker.check_model(onnx_model, full_check=True)) + + if args.upload_to_hub: + api = HfApi() + + api.upload_file( + path_or_fileobj=output_dir / "model_optimized.onnx", + path_in_repo="model_optimized.onnx", + repo_id=args.model_name_or_path, + ) + api.upload_file( + path_or_fileobj=output_dir / "model.onnx", + path_in_repo="model.onnx", + repo_id=args.model_name_or_path, + ) \ No newline at end of file diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index a3c93880..7ecd2c2f 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -45,8 +45,8 @@ def __call__(self, input_ids, attention_mask): logits = self.ort_session.run( ["logits"], { - self.ort_session.get_inputs()[0].name: input_ids.astype(np.int64), - self.ort_session.get_inputs()[1].name: attention_mask.astype(np.float16), + "attention_mask": attention_mask.astype(np.float16), + "input_ids": input_ids.astype(np.int64), }, )[0]