diff --git a/Cargo.lock b/Cargo.lock
index ba91db04..d8640240 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -47,6 +47,21 @@ dependencies = [
  "pin-project-lite",
 ]
 
+[[package]]
+name = "async-compression"
+version = "0.4.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fec134f64e2bc57411226dfc4e52dec859ddfc7e711fc5e07b612584f000e4aa"
+dependencies = [
+ "flate2",
+ "futures-core",
+ "memchr",
+ "pin-project-lite",
+ "tokio",
+ "zstd",
+ "zstd-safe",
+]
+
 [[package]]
 name = "async-stream"
 version = "0.3.6"
@@ -187,6 +202,17 @@ version = "1.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da"
 
+[[package]]
+name = "cc"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8"
+dependencies = [
+ "jobserver",
+ "libc",
+ "shlex",
+]
+
 [[package]]
 name = "cfg-if"
 version = "1.0.0"
@@ -680,6 +706,15 @@ version = "1.0.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
 
+[[package]]
+name = "jobserver"
+version = "0.1.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "lazy_static"
 version = "1.5.0"
@@ -771,7 +806,7 @@ dependencies = [
 
 [[package]]
 name = "mosec"
-version = "0.8.9"
+version = "0.9.0"
 dependencies = [
  "async-channel",
  "async-stream",
@@ -782,6 +817,8 @@ dependencies = [
  "serde",
  "serde_json",
  "tokio",
+ "tower",
+ "tower-http",
  "tracing",
  "tracing-subscriber",
  "utoipa",
@@ -881,6 +918,12 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
 
+[[package]]
+name = "pkg-config"
+version = "0.3.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
+
 [[package]]
 name = "powerfmt"
 version = "0.2.0"
@@ -1107,6 +1150,12 @@ dependencies = [
  "lazy_static",
 ]
 
+[[package]]
+name = "shlex"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
+
 [[package]]
 name = "signal-hook-registry"
 version = "1.4.2"
@@ -1316,6 +1365,26 @@ dependencies = [
  "tower-service",
 ]
 
+[[package]]
+name = "tower-http"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8437150ab6bbc8c5f0f519e3d5ed4aa883a83dd4cdd3d1b21f9482936046cb97"
+dependencies = [
+ "async-compression",
+ "bitflags",
+ "bytes",
+ "futures-core",
+ "http",
+ "http-body",
+ "http-body-util",
+ "pin-project-lite",
+ "tokio",
+ "tokio-util",
+ "tower-layer",
+ "tower-service",
+]
+
 [[package]]
 name = "tower-layer"
 version = "0.3.3"
@@ -1736,3 +1805,31 @@ dependencies = [
  "once_cell",
  "simd-adler32",
 ]
+
+[[package]]
+name = "zstd"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
+dependencies = [
+ "zstd-safe",
+]
+
+[[package]]
+name = "zstd-safe"
+version = "7.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
+dependencies = [
+ "zstd-sys",
+]
+
+[[package]]
+name = "zstd-sys"
+version = "2.0.13+zstd.1.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa"
+dependencies = [
+ "cc",
+ "pkg-config",
+]
diff --git a/Cargo.toml b/Cargo.toml
index 2a41fcd3..4979a77d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "mosec"
-version = "0.8.9"
+version = "0.9.0"
 authors = ["Keming ", "Zichen "]
 edition = "2021"
 license = "Apache-2.0"
@@ -25,3 +25,5 @@ serde = "1.0"
 serde_json = "1.0"
 utoipa = "5"
 utoipa-swagger-ui = { version = "8", features = ["axum"] }
+tower = "0.5.1"
+tower-http = {version = "0.6.1", features = ["compression-zstd", "decompression-zstd", "compression-gzip", "decompression-gzip"]}
diff --git a/README.md b/README.md
index e5a8f124..3aac16ae 100644
--- a/README.md
+++ b/README.md
@@ -193,6 +193,7 @@ More ready-to-use examples can be found in the [Example](https://mosecorg.github.
 - [Customized GPU allocation](https://mosecorg.github.io/mosec/examples/env.html): deploy multiple replicas, each using different GPUs.
 - [Customized metrics](https://mosecorg.github.io/mosec/examples/metric.html): record your own metrics for monitoring.
 - [Jax jitted inference](https://mosecorg.github.io/mosec/examples/jax.html): just-in-time compilation speeds up the inference.
+- [Compression](https://mosecorg.github.io/mosec/examples/compression.html): enable request/response compression.
 - PyTorch deep learning models:
   - [sentiment analysis](https://mosecorg.github.io/mosec/examples/pytorch.html#natural-language-processing): infer the sentiment of a sentence.
   - [image recognition](https://mosecorg.github.io/mosec/examples/pytorch.html#computer-vision): categorize a given image.
diff --git a/docs/source/examples/compression.md b/docs/source/examples/compression.md
new file mode 100644
index 00000000..3bfa2085
--- /dev/null
+++ b/docs/source/examples/compression.md
@@ -0,0 +1,33 @@
+# Compression
+
+This example demonstrates how to use the `--compression` feature with a segmentation task, following the [Segment Anything Model 2](https://github.com/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb) image predictor notebook. The request contains an image and its low-resolution mask; the response is the final mask. Since the mask consists mostly of repeated values, it compresses well with `gzip` or `zstd`.
+
+## Server
+
+```shell
+python examples/segment/server.py --compression
+```
+
+<details>
+<summary>server.py</summary>
+
+```{include} ../../../examples/segment/server.py
+:code: python
+```
+
+</details>
+
+## Client
+
+```shell
+python examples/segment/client.py
+```
+
+<details>
+<summary>client.py</summary>
+
+```{include} ../../../examples/segment/client.py
+:code: python
+```
+
+</details>
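The claim above, that the mask compresses well because of repeated values, is easy to sanity-check. The snippet below is an illustration only, not part of this patch; it assumes the `numbin` package and a mostly-zero 256x256 stand-in mask, so the ratio for a real mask will differ:

```python
# Rough estimate of how well a mostly-uniform mask compresses (illustration only).
import gzip

import numbin
import numpy as np

mask = np.zeros((256, 256), dtype=np.float32)  # stand-in for a low-resolution mask
raw = numbin.dumps(mask)  # serialized payload, as the example client sends it
packed = gzip.compress(raw)  # same codec the example client uses
print(f"raw={len(raw)}B gzip={len(packed)}B ratio={len(raw) / len(packed):.0f}x")
```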
diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md
index 26462559..77388969 100644
--- a/docs/source/examples/index.md
+++ b/docs/source/examples/index.md
@@ -16,6 +16,7 @@ pytorch
 rerank
 stable_diffusion
 validate
+compression
 ```
 
 We provide examples across different ML frameworks and for various tasks in this section.
diff --git a/examples/segment/client.py b/examples/segment/client.py
new file mode 100644
index 00000000..4cdef905
--- /dev/null
+++ b/examples/segment/client.py
@@ -0,0 +1,52 @@
+# Copyright 2023 MOSEC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gzip
+from http import HTTPStatus
+from io import BytesIO
+
+import httpx
+import msgpack  # type: ignore
+import numbin
+import numpy as np
+from PIL import Image  # type: ignore
+
+truck_image = Image.open(
+    BytesIO(
+        httpx.get(
+            "https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg"
+        ).content
+    )
+)
+array = np.array(truck_image.convert("RGB"))
+# assume we have obtained the low-resolution mask from the previous step
+mask = np.zeros((256, 256))
+
+resp = httpx.post(
+    "http://127.0.0.1:8000/inference",
+    content=gzip.compress(
+        msgpack.packb(  # type: ignore
+            {
+                "image": numbin.dumps(array),
+                "mask": numbin.dumps(mask),
+                "labels": [1, 1],
+                "point_coords": [[500, 375], [1125, 625]],
+            }
+        )
+    ),
+    headers={"Accept-Encoding": "gzip", "Content-Encoding": "gzip"},
+)
+assert resp.status_code == HTTPStatus.OK, resp.status_code
+res = numbin.loads(msgpack.loads(resp.content))
+assert res.shape == array.shape[:2], f"expect {array.shape[:2]}, got {res.shape}"
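The client above compresses the request with `gzip`; since the server is built with both the gzip and zstd tower-http features, the same payload could be sent with zstandard instead. The following is a minimal sketch, not part of the patch, that continues from the client script above (`array` and `mask` are the arrays prepared there) and assumes the `zstandard` package is installed:

```python
# Hypothetical zstd variant of the request above (not part of the patch).
# `array` and `mask` are the numpy arrays prepared in client.py above.
import httpx
import msgpack  # type: ignore
import numbin
import zstandard

compressed = zstandard.ZstdCompressor().compress(
    msgpack.packb(
        {
            "image": numbin.dumps(array),
            "mask": numbin.dumps(mask),
            "labels": [1, 1],
            "point_coords": [[500, 375], [1125, 625]],
        }
    )
)
resp = httpx.post(
    "http://127.0.0.1:8000/inference",
    # `Content-Encoding: zstd` tells the server how to decompress the request;
    # the response is requested as gzip since httpx decodes gzip transparently.
    headers={"Accept-Encoding": "gzip", "Content-Encoding": "zstd"},
    content=compressed,
)
assert resp.status_code == 200, resp.status_code
```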
diff --git a/examples/segment/server.py b/examples/segment/server.py
new file mode 100644
index 00000000..7a9f431d
--- /dev/null
+++ b/examples/segment/server.py
@@ -0,0 +1,66 @@
+# Copyright 2023 MOSEC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# refer to https://github.com/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb
+
+import numbin
+import torch  # type: ignore
+from sam2.sam2_image_predictor import SAM2ImagePredictor  # type: ignore
+
+from mosec import Server, Worker, get_logger
+from mosec.mixin import MsgpackMixin
+
+logger = get_logger()
+MIN_TF32_MAJOR = 8
+
+
+class SegmentAnything(MsgpackMixin, Worker):
+    def __init__(self):
+        # select the device for computation
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+        logger.info("using device: %s", device)
+
+        self.predictor = SAM2ImagePredictor.from_pretrained(
+            "facebook/sam2-hiera-large", device=device
+        )
+
+        if device.type == "cuda":
+            # use bfloat16
+            torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
+            # turn on tf32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+            if torch.cuda.get_device_properties(0).major >= MIN_TF32_MAJOR:
+                torch.backends.cuda.matmul.allow_tf32 = True
+                torch.backends.cudnn.allow_tf32 = True
+
+    def forward(self, data: dict) -> bytes:
+        with torch.inference_mode():
+            self.predictor.set_image(numbin.loads(data["image"]))
+            masks, _, _ = self.predictor.predict(
+                point_coords=data["point_coords"],
+                point_labels=data["labels"],
+                mask_input=numbin.loads(data["mask"])[None, :, :],
+                multimask_output=False,
+            )
+        return numbin.dumps(masks[0])
+
+
+if __name__ == "__main__":
+    server = Server()
+    server.append_worker(SegmentAnything, num=1, max_batch_size=1)
+    server.run()
diff --git a/mosec/args.py b/mosec/args.py
index 10229692..afef5e08 100644
--- a/mosec/args.py
+++ b/mosec/args.py
@@ -134,6 +134,13 @@ def build_arguments_parser() -> argparse.ArgumentParser:
         "This will omit the worker number for each stage.",
         action="store_true",
     )
+
+    parser.add_argument(
+        "--compression",
+        help="Enable `zstd` & `gzip` compression for the request & response",
+        action="store_true",
+    )
+
     return parser
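Since `--compression` is a plain `store_true` switch, its effect on the parsed arguments can be checked directly. Below is a small sketch, not part of this patch, which assumes the parser accepts an empty argument list (i.e. all other options have defaults):

```python
# Illustrative check: the new flag surfaces as a boolean on the parsed namespace.
from mosec.args import build_arguments_parser

parser = build_arguments_parser()
assert parser.parse_args([]).compression is False  # defaults to off
assert parser.parse_args(["--compression"]).compression is True
```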
diff --git a/pyproject.toml b/pyproject.toml
index 8ec652c5..a95071a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ classifiers = [
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
     "Programming Language :: Python :: Implementation :: CPython",
     "Programming Language :: Rust",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
diff --git a/requirements/dev.txt b/requirements/dev.txt
index e7f153d2..0aea3765 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -7,3 +7,4 @@ ruff>=0.7
 pre-commit>=2.15.0
 httpx[http2]==0.27.2
 httpx-sse==0.4.0
+zstandard~=0.23
diff --git a/src/config.rs b/src/config.rs
index 1a1cc739..f905650a 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -65,6 +65,8 @@ pub(crate) struct Config {
     pub namespace: String,
     // log level: (debug, info, warning, error)
     pub log_level: String,
+    // `zstd` & `gzip` compression
+    pub compression: bool,
     pub runtimes: Vec<Runtime>,
     pub routes: Vec<Route>,
 }
@@ -79,6 +81,7 @@ impl Default for Config {
             port: 8000,
             namespace: String::from("mosec_service"),
             log_level: String::from("info"),
+            compression: false,
             runtimes: vec![Runtime {
                 max_batch_size: 64,
                 max_wait_time: 3000,
diff --git a/src/main.rs b/src/main.rs
index 4221155f..4bc89dc5 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#![forbid(unsafe_code)]
+
 mod apidoc;
 mod config;
 mod errors;
@@ -27,6 +29,9 @@ use std::net::SocketAddr;
 use axum::routing::{get, post};
 use axum::Router;
 use tokio::signal::unix::{signal, SignalKind};
+use tower::ServiceBuilder;
+use tower_http::compression::CompressionLayer;
+use tower_http::decompression::RequestDecompressionLayer;
 use tracing::{debug, info};
 use tracing_subscriber::fmt::time::UtcTime;
 use tracing_subscriber::prelude::*;
@@ -90,12 +95,20 @@ async fn run(conf: &Config) {
         }
     }
 
+    if conf.compression {
+        router = router.layer(
+            ServiceBuilder::new()
+                .layer(RequestDecompressionLayer::new())
+                .layer(CompressionLayer::new()),
+        );
+    }
+
     // wait until each stage has at least one worker alive
     barrier.wait().await;
     let addr: SocketAddr = format!("{}:{}", conf.address, conf.port).parse().unwrap();
     let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
     info!(?addr, "http service is running");
-    axum::serve(listener, router.into_make_service())
+    axum::serve(listener, router)
         .with_graceful_shutdown(shutdown_signal())
         .await
         .unwrap();
diff --git a/src/routes.rs b/src/routes.rs
index f4ef6ba9..d7d012d4 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -60,7 +60,7 @@ fn build_response(status: StatusCode, content: Bytes) -> Response {
         ),
     ),
 )]
-pub(crate) async fn index(_: Request) -> Response {
+pub(crate) async fn index() -> Response {
     let task_manager = TaskManager::global();
     if task_manager.is_shutdown() {
         build_response(
@@ -79,7 +79,7 @@ pub(crate) async fn index(_: Request) -> Response {
         (status = StatusCode::OK, description = "Get metrics", body = String),
     ),
 )]
-pub(crate) async fn metrics(_: Request) -> Response {
+pub(crate) async fn metrics() -> Response {
     let mut encoded = String::new();
     let registry = REGISTRY.get().unwrap();
     encode(&mut encoded, registry).unwrap();
diff --git a/tests/test_service.py b/tests/test_service.py
index 9368f38f..87facf4f 100644
--- a/tests/test_service.py
+++ b/tests/test_service.py
@@ -14,6 +14,8 @@
 
 """End-to-end service tests."""
 
+import gzip
+import json
 import random
 import re
 import shlex
@@ -26,6 +28,7 @@
 import msgpack  # type: ignore
 import pytest
 from httpx_sse import connect_sse
+from zstandard import ZstdCompressor
 
 from mosec.server import GUARD_CHECK_INTERVAL
 from tests.utils import wait_for_port_free, wait_for_port_open
@@ -349,3 +352,41 @@ def test_multi_route_service(mosec_service, http_client):
     assert resp.status_code == HTTPStatus.OK, resp
     assert resp.headers["content-type"] == "application/msgpack"
     assert msgpack.unpackb(resp.content) == {"length": len(data)}
+
+
+@pytest.mark.parametrize(
+    "mosec_service, http_client",
+    [
+        pytest.param("square_service --compression --debug", "", id="compression"),
+    ],
+    indirect=["mosec_service", "http_client"],
+)
+def test_compression_service(mosec_service, http_client):
+    zstd_compressor = ZstdCompressor()
+    req = {"x": 2}
+    expect = {"x": 4}
+
+    # test without compression
+    resp = http_client.post("/inference", json=req)
+    assert resp.status_code == HTTPStatus.OK, resp
+    assert resp.json() == expect, resp.content
+
+    # test with gzip compression
+    binary = gzip.compress(json.dumps(req).encode())
+    resp = http_client.post(
+        "/inference",
+        content=binary,
+        headers={"Accept-Encoding": "gzip", "Content-Encoding": "gzip"},
+    )
+    assert resp.status_code == HTTPStatus.OK, resp
+    assert resp.json() == expect, resp.content
+
+    # test with zstd compression
+    binary = zstd_compressor.compress(json.dumps(req).encode())
+    resp = http_client.post(
+        "/inference",
+        content=binary,
+        headers={"Accept-Encoding": "zstd", "Content-Encoding": "zstd"},
+    )
+    assert resp.status_code == HTTPStatus.OK, resp
+    assert resp.json() == expect, resp.content