fix onnx input ordering, add onnxruntime-gpu timings
bminixhofer committed Sep 24, 2024
1 parent b9198d0 commit b549ab6
Showing 4 changed files with 54 additions and 14 deletions.
18 changes: 10 additions & 8 deletions README.md
@@ -50,31 +50,33 @@ sat_adapted.split("This is a test This is another test.")
🚀 You can now enable even faster ONNX inference for `sat` and `sat-sm` models! 🚀

```python
sat = SaT("sat-3l-sm", onnx_providers=["CUDAExecutionProvider"])
sat = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
```
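
A quick way to check which providers your ONNX Runtime install actually exposes (a small illustrative snippet; it assumes `onnxruntime` is importable):

```python
import onnxruntime as ort

# "CUDAExecutionProvider" only appears with a GPU-enabled build
# (e.g. the onnxruntime-gpu package).
print(ort.get_available_providers())
```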

```python
>>> from wtpsplit import SaT
>>> texts = ["This is a sentence. This is another sentence."] * 1000

# PyTorch GPU
>>> model = SaT("sat-3l-sm")
>>> model.half().to("cuda")
>>> list(model.split(texts))
>>> model_pytorch = SaT("sat-3l-sm")
>>> model_pytorch.half().to("cuda");
>>> %timeit list(model_pytorch.split(texts))
# 144 ms ± 252 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# quite fast already, but...

# onnxruntime GPU
>>> model = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"])
>>> %timeit list(model.split(texts))
# ...this should be ~50% faster!
>>> model_ort = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
>>> %timeit list(model_ort.split(texts))
# 94.9 ms ± 165 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# ...this should be ~50% faster! (tested on RTX 3090)
```

If you wish to use LoRA in combination with an ONNX model:
- Run `scripts/export_to_onnx_sat.py` with `use_lora: True` and an appropriate `output_dir: <OUTPUT_DIR>`.
- If you have a local LoRA module, use `lora_path`.
- If you wish to load a LoRA module from the HuggingFace hub, use `style_or_domain` and `language`.
- Load the ONNX model with merged LoRA weights:
`sat = SaT(<OUTPUT_DIR>, onnx_providers=["CUDAExecutionProvider"])`
`sat = SaT(<OUTPUT_DIR>, onnx_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])`
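
A minimal sketch of these two steps end to end; the output directory, LoRA selection, and example text below are illustrative placeholders rather than values prescribed by the commit:

```python
# Step 1 (shell): export with merged LoRA weights. Use style_or_domain/language
# to pick a LoRA module from the HuggingFace hub, or lora_path for a local one:
#
#   python scripts/export_to_onnx_sat.py \
#       --model_name_or_path=segment-any-text/sat-3l-sm \
#       --output_dir=output_onnx_lora \
#       --use_lora=True --style_or_domain=ud --language=en

# Step 2 (Python): load the exported model, mirroring the load call above.
from wtpsplit import SaT

sat = SaT("output_onnx_lora", onnx_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
print(sat.split("This is a test This is another test."))
```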


## Available Models
23 changes: 23 additions & 0 deletions scripts/export_all_to_onnx.sh
@@ -0,0 +1,23 @@
# manually defined list of models to export
models=(
"sat-1l-sm"
"sat-3l-sm"
"sat-6l-sm"
"sat-9l-sm"
"sat-12l-sm"
"sat-1l"
"sat-3l"
"sat-6l"
"sat-9l"
"sat-12l"
"sat-1l-no-limited-lookahead"
"sat-3l-no-limited-lookahead"
"sat-6l-no-limited-lookahead"
"sat-9l-no-limited-lookahead"
"sat-12l-no-limited-lookahead"
)

for model in "${models[@]}"
do
python scripts/export_to_onnx_sat.py --model_name_or_path=segment-any-text/$model --output_dir=output_onnx_exports/$model --upload_to_hub=True
done
23 changes: 19 additions & 4 deletions scripts/export_to_onnx_sat.py
@@ -6,7 +6,7 @@
import torch
from adapters.models import MODEL_MIXIN_MAPPING # noqa
from adapters.models.bert.mixin_bert import BertModelAdaptersMixin # noqa
from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_download, HfApi
from onnxruntime.transformers.optimizer import optimize_model # noqa
from transformers import AutoModelForTokenClassification, HfArgumentParser

@@ -27,6 +27,7 @@ class Args:
# otherwise, fetch from HF Hub:
style_or_domain: str = "ud"
language: str = "en"
upload_to_hub: bool = False


if __name__ == "__main__":
@@ -88,16 +89,16 @@ class Args:
torch.onnx.export(
model,
{
"input_ids": torch.randint(0, model.config.vocab_size, (1, 1), dtype=torch.int64, device=args.device),
"attention_mask": torch.randn((1, 1), dtype=torch.float16, device=args.device),
"input_ids": torch.randint(0, 250002, (1, 1), dtype=torch.int64, device=args.device),
},
output_dir / "model.onnx",
verbose=True,
input_names=["attention_mask", "input_ids"],
input_names=["input_ids", "attention_mask"],
output_names=["logits"],
dynamic_axes={
"attention_mask": {0: "batch", 1: "sequence"},
"input_ids": {0: "batch", 1: "sequence"},
"attention_mask": {0: "batch", 1: "sequence"},
"logits": {0: "batch", 1: "sequence"},
},
)
Expand All @@ -117,3 +118,17 @@ class Args:

onnx_model = onnx.load(output_dir / "model.onnx")
print(onnx.checker.check_model(onnx_model, full_check=True))

if args.upload_to_hub:
api = HfApi()

api.upload_file(
path_or_fileobj=output_dir / "model_optimized.onnx",
path_in_repo="model_optimized.onnx",
repo_id=args.model_name_or_path,
)
api.upload_file(
path_or_fileobj=output_dir / "model.onnx",
path_in_repo="model.onnx",
repo_id=args.model_name_or_path,
)
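
The `input_names` reordering above is the core of the fix: `torch.onnx.export` assigns those names to the traced graph inputs positionally (the trace follows the model's forward signature, not the order of the example-input dict), so listing `attention_mask` first mislabeled the two inputs. A small sanity check on the exported file, sketched with an assumed output path:

```python
# Inspect the exported graph's input names, shapes and types.
import onnxruntime as ort

session = ort.InferenceSession(
    "output_onnx_exports/sat-3l-sm/model.onnx",  # adjust to your output_dir
    providers=["CPUExecutionProvider"],
)
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)
# Expected after this fix: input_ids (tensor(int64)), then attention_mask (tensor(float16)).
```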
4 changes: 2 additions & 2 deletions wtpsplit/extract.py
@@ -45,8 +45,8 @@ def __call__(self, input_ids, attention_mask):
logits = self.ort_session.run(
["logits"],
{
self.ort_session.get_inputs()[0].name: input_ids.astype(np.int64),
self.ort_session.get_inputs()[1].name: attention_mask.astype(np.float16),
"attention_mask": attention_mask.astype(np.float16),
"input_ids": input_ids.astype(np.int64),
},
)[0]
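
Feeding the session by input name rather than by `get_inputs()` index makes inference independent of how the graph happens to order its inputs, which is exactly what went wrong before the export fix. The same pattern in isolation, sketched with placeholder shapes and model path:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

input_ids = np.zeros((1, 8), dtype=np.int64)         # token ids
attention_mask = np.ones((1, 8), dtype=np.float16)   # 1.0 = attend, 0.0 = padding

# Keys are matched to graph input names, so tensors cannot be silently swapped.
(logits,) = session.run(["logits"], {"input_ids": input_ids, "attention_mask": attention_mask})
print(logits.shape)
```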

