From 9b0c8c9d030bec8a84625ca85ab2869b451ea6c7 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 31 Oct 2023 14:23:21 -0700 Subject: [PATCH] [api] Refactor PublisherBytesSupplier.java (#2831) --- .../streaming/PublisherBytesSupplier.java | 96 ++++++------------- .../streaming/PublisherBytesSupplierTest.java | 36 ++++--- 2 files changed, 48 insertions(+), 84 deletions(-) diff --git a/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java b/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java index d83c4678f33..d5fdfda878b 100644 --- a/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java +++ b/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java @@ -14,13 +14,10 @@ import ai.djl.ndarray.BytesSupplier; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; /** @@ -29,16 +26,14 @@ */ public class PublisherBytesSupplier implements BytesSupplier { - private final List allData; - private final AtomicBoolean completed; private Consumer subscriber; - private final AtomicInteger dataPushed; + private CountDownLatch latch; + private CompletableFuture future; /** Constructs a {@link PublisherBytesSupplier}. */ public PublisherBytesSupplier() { - allData = new ArrayList<>(); - completed = new AtomicBoolean(); - dataPushed = new AtomicInteger(); + latch = new CountDownLatch(1); + future = new CompletableFuture<>(); } /** @@ -48,13 +43,24 @@ public PublisherBytesSupplier() { * @param lastChunk true if this is the last chunk */ public void appendContent(byte[] data, boolean lastChunk) { - synchronized (allData) { - allData.add(data); + if (subscriber == null) { + try { + if (!latch.await(2, TimeUnit.MINUTES)) { + throw new IllegalStateException("Wait for subscriber timeout."); + } + if (subscriber == null) { + // workaround Spotbugs + throw new IllegalStateException("subscriber is not set."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Append content interrupted.", e); + } } + subscriber.accept(data); if (lastChunk) { - completed.set(true); + subscriber.accept(null); + future.complete(null); } - pushData(); } /** @@ -62,69 +68,21 @@ public void appendContent(byte[] data, boolean lastChunk) { * * @param subscriber a consumer function that will receive bytes when new daata is added and * null when completed + * @return a {@code CompletableFuture} object */ - public void subscribe(Consumer subscriber) { + public CompletableFuture subscribe(Consumer subscriber) { if (this.subscriber != null) { throw new IllegalStateException( "The PublisherBytesSupplier only allows a single Subscriber"); } this.subscriber = subscriber; - pushData(); - } - - private void pushData() { - if (subscriber == null) { - return; - } - - int dataAvailable; - synchronized (allData) { - dataAvailable = allData.size(); - } - - int sent = dataPushed.getAndSet(dataAvailable); - if (sent < dataAvailable) { - synchronized (this) { - for (; sent < dataAvailable; sent++) { - subscriber.accept(allData.get(sent)); - } - if (completed.get()) { - subscriber.accept(null); - } - } - } - } - - /** Waits until completed before passing thread (BLOCKS THREAD!). */ - @SuppressWarnings("PMD.EmptyControlStatement") - public void waitToRead() { - // Block until complete!!! - while (!completed.get()) { - // Do nothing - } - } - - /** {@inheritDoc} */ - @Override - public byte[] getAsBytes() { - if (!completed.get()) { - throw new IllegalStateException( - "PublisherByteSupplier must be completely filled before reading."); - } - - try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { - for (byte[] data : allData) { - bos.write(data); - } - return bos.toByteArray(); - } catch (IOException e) { - throw new AssertionError("Failed to read BytesSupplier", e); - } + latch.countDown(); + return future; } /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(getAsBytes()); + throw new UnsupportedOperationException("Not supported."); } } diff --git a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java index 8c140688124..a8b2bdfab62 100644 --- a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java +++ b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java @@ -15,32 +15,38 @@ import org.testng.Assert; import org.testng.annotations.Test; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; public class PublisherBytesSupplierTest { @Test - public void test() { + public void test() throws ExecutionException, InterruptedException { AtomicInteger contentCount = new AtomicInteger(); PublisherBytesSupplier supplier = new PublisherBytesSupplier(); - // Add to supplier without subscriber - supplier.appendContent(new byte[] {1}, false); - Assert.assertEquals(contentCount.get(), 0); + new Thread( + () -> { + // Add to supplier without subscriber + supplier.appendContent(new byte[] {1}, false); + // Add to supplier with subscriber + supplier.appendContent(new byte[] {1}, true); + }) + .start(); // Subscribing with data should trigger subscriptions - supplier.subscribe( - d -> { - if (d == null) { - // Do nothing on completion - return; - } - contentCount.getAndIncrement(); - }); - Assert.assertEquals(contentCount.get(), 1); + CompletableFuture future = + supplier.subscribe( + d -> { + if (d == null) { + // Do nothing on completion + return; + } + contentCount.getAndIncrement(); + }); - // Add to supplier with subscriber - supplier.appendContent(new byte[] {1}, true); + future.get(); Assert.assertEquals(contentCount.get(), 2); } }