From b549ab6999ad4f1aae6bdf665e2d6eb0c298fed6 Mon Sep 17 00:00:00 2001
From: Benjamin Minixhofer
Date: Tue, 24 Sep 2024 13:39:41 +0100
Subject: [PATCH] fix onnx input ordering, add onnxruntime-gpu timings

---
 README.md                     | 18 ++++++++++--------
 scripts/export_all_to_onnx.sh | 23 +++++++++++++++++++++++
 scripts/export_to_onnx_sat.py | 23 +++++++++++++++++++----
 wtpsplit/extract.py           |  4 ++--
 4 files changed, 54 insertions(+), 14 deletions(-)
 create mode 100644 scripts/export_all_to_onnx.sh

diff --git a/README.md b/README.md
index 914c56a5..76f63adc 100644
--- a/README.md
+++ b/README.md
@@ -50,7 +50,7 @@ sat_adapted.split("This is a test This is another test.")
 🚀 You can now enable even faster ONNX inference for `sat` and `sat-sm` models! 🚀
 
 ```python
-sat = SaT("sat-3l-sm", onnx_providers=["CUDAExecutionProvider"])
+sat = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
 ```
 
 ```python
@@ -58,15 +58,17 @@ sat = SaT("sat-3l-sm", onnx_providers=["CUDAExecutionProvider"])
 >>> texts = ["This is a sentence. This is another sentence."] * 1000
 
 # PyTorch GPU
->>> model = SaT("sat-3l-sm")
->>> model.half().to("cuda")
->>> list(model.split(texts))
+>>> model_pytorch = SaT("sat-3l-sm")
+>>> model_pytorch.half().to("cuda");
+>>> %timeit list(model_pytorch.split(texts))
+# 144 ms ± 252 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
 # quite fast already, but...
 
 # onnxruntime GPU
->>> model = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"])
->>> %timeit list(model.split(texts))
-# ...this should be ~50% faster!
+>>> model_ort = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+>>> %timeit list(model_ort.split(texts))
+# 94.9 ms ± 165 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
+# ...this should be ~50% faster! (tested on RTX 3090)
 ```
 
 If you wish to use LoRA in combination with an ONNX model:
@@ -74,7 +76,7 @@ If you wish to use LoRA in combination with an ONNX model:
 - If you have a local LoRA module, use `lora_path`.
 - If you wish to load a LoRA module from the HuggingFace hub, use `style_or_domain` and `language`.
 - Load the ONNX model with merged LoRA weights:
-`sat = SaT(, onnx_providers=["CUDAExecutionProvider"])`
+`sat = SaT(, ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])`
 
 ## Available Models
 
diff --git a/scripts/export_all_to_onnx.sh b/scripts/export_all_to_onnx.sh
new file mode 100644
index 00000000..7c21a00f
--- /dev/null
+++ b/scripts/export_all_to_onnx.sh
@@ -0,0 +1,23 @@
+# manually defined array of all models to export
+models=(
+    "sat-1l-sm"
+    "sat-3l-sm"
+    "sat-6l-sm"
+    "sat-9l-sm"
+    "sat-12l-sm"
+    "sat-1l"
+    "sat-3l"
+    "sat-6l"
+    "sat-9l"
+    "sat-12l"
+    "sat-1l-no-limited-lookahead"
+    "sat-3l-no-limited-lookahead"
+    "sat-6l-no-limited-lookahead"
+    "sat-9l-no-limited-lookahead"
+    "sat-12l-no-limited-lookahead"
+)
+
+for model in "${models[@]}"
+do
+    python scripts/export_to_onnx_sat.py --model_name_or_path=segment-any-text/$model --output_dir=output_onnx_exports/$model --upload_to_hub=True
+done
\ No newline at end of file
diff --git a/scripts/export_to_onnx_sat.py b/scripts/export_to_onnx_sat.py
index 7727304a..c6fadf24 100644
--- a/scripts/export_to_onnx_sat.py
+++ b/scripts/export_to_onnx_sat.py
@@ -6,7 +6,7 @@
 import torch
 from adapters.models import MODEL_MIXIN_MAPPING  # noqa
 from adapters.models.bert.mixin_bert import BertModelAdaptersMixin  # noqa
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, HfApi
 from onnxruntime.transformers.optimizer import optimize_model  # noqa
 from transformers import AutoModelForTokenClassification, HfArgumentParser
 
@@ -27,6 +27,7 @@ class Args:
     # otherwise, fetch from HF Hub:
     style_or_domain: str = "ud"
     language: str = "en"
+    upload_to_hub: bool = False
 
 
 if __name__ == "__main__":
@@ -88,16 +89,16 @@ class Args:
     torch.onnx.export(
         model,
         {
+            "input_ids": torch.randint(0, model.config.vocab_size, (1, 1), dtype=torch.int64, device=args.device),
             "attention_mask": torch.randn((1, 1), dtype=torch.float16, device=args.device),
-            "input_ids": torch.randint(0, 250002, (1, 1), dtype=torch.int64, device=args.device),
         },
         output_dir / "model.onnx",
         verbose=True,
-        input_names=["attention_mask", "input_ids"],
+        input_names=["input_ids", "attention_mask"],
         output_names=["logits"],
         dynamic_axes={
-            "attention_mask": {0: "batch", 1: "sequence"},
             "input_ids": {0: "batch", 1: "sequence"},
+            "attention_mask": {0: "batch", 1: "sequence"},
             "logits": {0: "batch", 1: "sequence"},
         },
     )
@@ -117,3 +118,17 @@ class Args:
 
     onnx_model = onnx.load(output_dir / "model.onnx")
     print(onnx.checker.check_model(onnx_model, full_check=True))
+
+    if args.upload_to_hub:
+        api = HfApi()
+
+        api.upload_file(
+            path_or_fileobj=output_dir / "model_optimized.onnx",
+            path_in_repo="model_optimized.onnx",
+            repo_id=args.model_name_or_path,
+        )
+        api.upload_file(
+            path_or_fileobj=output_dir / "model.onnx",
+            path_in_repo="model.onnx",
+            repo_id=args.model_name_or_path,
+        )
\ No newline at end of file
diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index a3c93880..7ecd2c2f 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -45,8 +45,8 @@ def __call__(self, input_ids, attention_mask):
         logits = self.ort_session.run(
             ["logits"],
             {
-                self.ort_session.get_inputs()[0].name: input_ids.astype(np.int64),
-                self.ort_session.get_inputs()[1].name: attention_mask.astype(np.float16),
+                "attention_mask": attention_mask.astype(np.float16),
+                "input_ids": input_ids.astype(np.int64),
             },
         )[0]
 
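Not part of the patch, but a quick way to sanity-check the corrected input ordering is to load an exported model directly with `onnxruntime` and feed it by input name, which is exactly what the updated `wtpsplit/extract.py` relies on. The sketch below is illustrative: the model path and dummy shapes are assumptions, and any export produced by `scripts/export_to_onnx_sat.py` should work in their place.

```python
import numpy as np
import onnxruntime as ort

# Hypothetical path: point this at any export produced by scripts/export_to_onnx_sat.py.
session = ort.InferenceSession(
    "output_onnx_exports/sat-3l-sm/model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# After this patch the graph declares its inputs as input_ids, then attention_mask.
print([inp.name for inp in session.get_inputs()])  # expected: ['input_ids', 'attention_mask']

# Feeding by name (rather than by position) stays correct even if the export order changes.
input_ids = np.zeros((1, 8), dtype=np.int64)        # dummy token ids
attention_mask = np.ones((1, 8), dtype=np.float16)  # the export uses a float16 mask
logits = session.run(["logits"], {"input_ids": input_ids, "attention_mask": attention_mask})[0]
print(logits.shape)  # (batch, sequence, num_labels)
```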