diff --git a/client/src/main/java/org/apache/rocketmq/client/impl/producer/DefaultMQProducerImpl.java b/client/src/main/java/org/apache/rocketmq/client/impl/producer/DefaultMQProducerImpl.java index bbbb17b07a5..2d6b83ac2c6 100644 --- a/client/src/main/java/org/apache/rocketmq/client/impl/producer/DefaultMQProducerImpl.java +++ b/client/src/main/java/org/apache/rocketmq/client/impl/producer/DefaultMQProducerImpl.java @@ -547,6 +547,8 @@ public void send(Message msg, @Deprecated public void send(final Message msg, final SendCallback sendCallback, final long timeout) throws MQClientException, RemotingException, InterruptedException { + BackpressureSendCallBack newCallBack = new BackpressureSendCallBack(sendCallback); + final long beginStartTime = System.currentTimeMillis(); Runnable runnable = new Runnable() { @Override @@ -554,20 +556,53 @@ public void run() { long costTime = System.currentTimeMillis() - beginStartTime; if (timeout > costTime) { try { - sendDefaultImpl(msg, CommunicationMode.ASYNC, sendCallback, timeout - costTime); + sendDefaultImpl(msg, CommunicationMode.ASYNC, newCallBack, timeout - costTime); } catch (Exception e) { - sendCallback.onException(e); + newCallBack.onException(e); } } else { - sendCallback.onException( + newCallBack.onException( new RemotingTooMuchRequestException("DEFAULT ASYNC send call timeout")); } } }; - executeAsyncMessageSend(runnable, msg, sendCallback, timeout, beginStartTime); + executeAsyncMessageSend(runnable, msg, newCallBack, timeout, beginStartTime); } - public void executeAsyncMessageSend(Runnable runnable, final Message msg, final SendCallback sendCallback, + class BackpressureSendCallBack implements SendCallback { + public boolean isSemaphoreAsyncSizeAquired = false; + public boolean isSemaphoreAsyncNumAquired = false; + public int msgLen; + private final SendCallback sendCallback; + + public BackpressureSendCallBack(final SendCallback sendCallback) { + this.sendCallback = sendCallback; + } + + @Override + public void onSuccess(SendResult sendResult) { + if (isSemaphoreAsyncSizeAquired) { + semaphoreAsyncSendSize.release(msgLen); + } + if (isSemaphoreAsyncNumAquired) { + semaphoreAsyncSendNum.release(); + } + sendCallback.onSuccess(sendResult); + } + + @Override + public void onException(Throwable e) { + if (isSemaphoreAsyncSizeAquired) { + semaphoreAsyncSendSize.release(msgLen); + } + if (isSemaphoreAsyncNumAquired) { + semaphoreAsyncSendNum.release(); + } + sendCallback.onException(e); + } + } + + public void executeAsyncMessageSend(Runnable runnable, final Message msg, final BackpressureSendCallBack sendCallback, final long timeout, final long beginStartTime) throws MQClientException, InterruptedException { ExecutorService executor = this.getAsyncSenderExecutor(); @@ -595,7 +630,9 @@ public void executeAsyncMessageSend(Runnable runnable, final Message msg, final return; } } - + sendCallback.isSemaphoreAsyncSizeAquired = isSemaphoreAsyncSizeAquired; + sendCallback.isSemaphoreAsyncNumAquired = isSemaphoreAsyncNumAquired; + sendCallback.msgLen = msgLen; executor.submit(runnable); } catch (RejectedExecutionException e) { if (isEnableBackpressureForAsyncMode) { @@ -603,15 +640,7 @@ public void executeAsyncMessageSend(Runnable runnable, final Message msg, final } else { throw new MQClientException("executor rejected ", e); } - } finally { - if (isSemaphoreAsyncSizeAquired) { - semaphoreAsyncSendSize.release(msgLen); - } - if (isSemaphoreAsyncNumAquired) { - semaphoreAsyncSendNum.release(); - } } - } public MessageQueue invokeMessageQueueSelector(Message msg, MessageQueueSelector selector, Object arg, @@ -1188,7 +1217,7 @@ public void send(Message msg, MessageQueue mq, SendCallback sendCallback) @Deprecated public void send(final Message msg, final MessageQueue mq, final SendCallback sendCallback, final long timeout) throws MQClientException, RemotingException, InterruptedException { - + BackpressureSendCallBack newCallBack = new BackpressureSendCallBack(sendCallback); final long beginStartTime = System.currentTimeMillis(); Runnable runnable = new Runnable() { @Override @@ -1203,22 +1232,22 @@ public void run() { long costTime = System.currentTimeMillis() - beginStartTime; if (timeout > costTime) { try { - sendKernelImpl(msg, mq, CommunicationMode.ASYNC, sendCallback, null, + sendKernelImpl(msg, mq, CommunicationMode.ASYNC, newCallBack, null, timeout - costTime); } catch (MQBrokerException e) { throw new MQClientException("unknown exception", e); } } else { - sendCallback.onException(new RemotingTooMuchRequestException("call timeout")); + newCallBack.onException(new RemotingTooMuchRequestException("call timeout")); } } catch (Exception e) { - sendCallback.onException(e); + newCallBack.onException(e); } } }; - executeAsyncMessageSend(runnable, msg, sendCallback, timeout, beginStartTime); + executeAsyncMessageSend(runnable, msg, newCallBack, timeout, beginStartTime); } /** @@ -1315,7 +1344,7 @@ public void send(Message msg, MessageQueueSelector selector, Object arg, SendCal public void send(final Message msg, final MessageQueueSelector selector, final Object arg, final SendCallback sendCallback, final long timeout) throws MQClientException, RemotingException, InterruptedException { - + BackpressureSendCallBack newCallBack = new BackpressureSendCallBack(sendCallback); final long beginStartTime = System.currentTimeMillis(); Runnable runnable = new Runnable() { @Override @@ -1324,21 +1353,21 @@ public void run() { if (timeout > costTime) { try { try { - sendSelectImpl(msg, selector, arg, CommunicationMode.ASYNC, sendCallback, + sendSelectImpl(msg, selector, arg, CommunicationMode.ASYNC, newCallBack, timeout - costTime); } catch (MQBrokerException e) { throw new MQClientException("unknown exception", e); } } catch (Exception e) { - sendCallback.onException(e); + newCallBack.onException(e); } } else { - sendCallback.onException(new RemotingTooMuchRequestException("call timeout")); + newCallBack.onException(new RemotingTooMuchRequestException("call timeout")); } } }; - executeAsyncMessageSend(runnable, msg, sendCallback, timeout, beginStartTime); + executeAsyncMessageSend(runnable, msg, newCallBack, timeout, beginStartTime); } /**