diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyClientConfig.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyClientConfig.java index b2e7df75491c..c28288786a3b 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyClientConfig.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyClientConfig.java @@ -53,6 +53,12 @@ public class NettyClientConfig { private boolean disableCallbackExecutor = false; private boolean disableNettyWorkerGroup = false; + private long maxReconnectIntervalTimeSeconds = 60; + + private boolean enableReconnectForGoAway = true; + + private boolean enableTransparentRetry = true; + public boolean isClientCloseSocketIfTimeout() { return clientCloseSocketIfTimeout; } @@ -181,6 +187,30 @@ public void setDisableNettyWorkerGroup(boolean disableNettyWorkerGroup) { this.disableNettyWorkerGroup = disableNettyWorkerGroup; } + public long getMaxReconnectIntervalTimeSeconds() { + return maxReconnectIntervalTimeSeconds; + } + + public void setMaxReconnectIntervalTimeSeconds(long maxReconnectIntervalTimeSeconds) { + this.maxReconnectIntervalTimeSeconds = maxReconnectIntervalTimeSeconds; + } + + public boolean isEnableReconnectForGoAway() { + return enableReconnectForGoAway; + } + + public void setEnableReconnectForGoAway(boolean enableReconnectForGoAway) { + this.enableReconnectForGoAway = enableReconnectForGoAway; + } + + public boolean isEnableTransparentRetry() { + return enableTransparentRetry; + } + + public void setEnableTransparentRetry(boolean enableTransparentRetry) { + this.enableTransparentRetry = enableTransparentRetry; + } + public String getSocksProxyConfig() { return socksProxyConfig; } diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java index 12e66f913cf3..07ace28ea543 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingAbstract.java @@ -40,9 +40,11 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import javax.annotation.Nullable; import org.apache.rocketmq.common.AbortProcessException; +import org.apache.rocketmq.common.MQVersion; import org.apache.rocketmq.common.Pair; import org.apache.rocketmq.common.ServiceThread; import org.apache.rocketmq.common.UtilAll; @@ -60,6 +62,7 @@ import org.apache.rocketmq.remoting.metrics.RemotingMetricsManager; import org.apache.rocketmq.remoting.protocol.RemotingCommand; import org.apache.rocketmq.remoting.protocol.RemotingSysResponseCode; +import org.apache.rocketmq.remoting.protocol.ResponseCode; import static org.apache.rocketmq.remoting.metrics.RemotingMetricsConstant.LABEL_IS_LONG_POLLING; import static org.apache.rocketmq.remoting.metrics.RemotingMetricsConstant.LABEL_REQUEST_CODE; @@ -120,6 +123,8 @@ public abstract class NettyRemotingAbstract { */ protected List rpcHooks = new ArrayList<>(); + protected AtomicBoolean isShuttingDown = new AtomicBoolean(false); + static { NettyLogger.initNettyLogger(); } @@ -264,6 +269,16 @@ public void processRequestCommand(final ChannelHandlerContext ctx, final Remotin Runnable run = buildProcessRequestHandler(ctx, cmd, pair, opaque); + if (isShuttingDown.get()) { + if (cmd.getVersion() > MQVersion.Version.V5_1_4.ordinal()) { + final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.GO_AWAY, + "please go away"); + response.setOpaque(opaque); + writeResponse(ctx.channel(), cmd, response); + return; + } + } + if (pair.getObject1().rejectRequest()) { final RemotingCommand response = RemotingCommand.createResponseCommand(RemotingSysResponseCode.SYSTEM_BUSY, "[REJECTREQUEST]system busy, start flow control for a while"); diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingClient.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingClient.java index 8631d0447d80..4bc51bd833ab 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingClient.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingClient.java @@ -18,6 +18,7 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.TypeReference; +import com.google.common.base.Stopwatch; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; @@ -48,6 +49,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.CertificateException; +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -57,6 +59,7 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; @@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; import org.apache.commons.lang3.StringUtils; import org.apache.rocketmq.common.Pair; import org.apache.rocketmq.common.ThreadFactoryImpl; @@ -82,6 +86,7 @@ import org.apache.rocketmq.remoting.exception.RemotingTimeoutException; import org.apache.rocketmq.remoting.exception.RemotingTooMuchRequestException; import org.apache.rocketmq.remoting.protocol.RemotingCommand; +import org.apache.rocketmq.remoting.protocol.ResponseCode; import org.apache.rocketmq.remoting.proxy.SocksProxyConfig; public class NettyRemotingClient extends NettyRemotingAbstract implements RemotingClient { @@ -97,6 +102,7 @@ public class NettyRemotingClient extends NettyRemotingAbstract implements Remoti private final Map proxyMap = new HashMap<>(); private final ConcurrentHashMap bootstrapMap = new ConcurrentHashMap<>(); private final ConcurrentMap channelTables = new ConcurrentHashMap<>(); + private final ConcurrentMap channelWrapperTables = new ConcurrentHashMap<>(); private final HashedWheelTimer timer = new HashedWheelTimer(r -> new Thread(r, "ClientHouseKeepingService")); @@ -356,9 +362,10 @@ public void shutdown() { this.timer.stop(); for (String addr : this.channelTables.keySet()) { - this.closeChannel(addr, this.channelTables.get(addr).getChannel()); + this.channelTables.get(addr).close(); } + this.channelWrapperTables.clear(); this.channelTables.clear(); this.eventLoopGroupWorker.shutdownGracefully(); @@ -416,7 +423,10 @@ public void closeChannel(final String addr, final Channel channel) { } if (removeItemFromTable) { - this.channelTables.remove(addrRemote); + ChannelWrapper channelWrapper = this.channelWrapperTables.remove(channel); + if (channelWrapper != null && channelWrapper.tryClose(channel)) { + this.channelTables.remove(addrRemote); + } LOGGER.info("closeChannel: the channel[{}] was removed from channel table", addrRemote); } @@ -463,7 +473,10 @@ public void closeChannel(final Channel channel) { } if (removeItemFromTable) { - this.channelTables.remove(addrRemote); + ChannelWrapper channelWrapper = this.channelWrapperTables.remove(channel); + if (channelWrapper != null && channelWrapper.tryClose(channel)) { + this.channelTables.remove(addrRemote); + } LOGGER.info("closeChannel: the channel[{}] was removed from channel table", addrRemote); RemotingHelper.closeChannel(channel); } @@ -511,7 +524,7 @@ public void updateNameServerAddressList(List addrs) { if (addr.contains(namesrvAddr)) { ChannelWrapper channelWrapper = this.channelTables.get(addr); if (channelWrapper != null) { - closeChannel(channelWrapper.getChannel()); + channelWrapper.close(); } } } @@ -689,8 +702,9 @@ private Channel createChannel(final String addr) throws InterruptedException { ChannelFuture channelFuture = fetchBootstrap(addr) .connect(hostAndPort[0], Integer.parseInt(hostAndPort[1])); LOGGER.info("createChannel: begin to connect remote host[{}] asynchronously", addr); - cw = new ChannelWrapper(channelFuture); + cw = new ChannelWrapper(addr, channelFuture); this.channelTables.put(addr, cw); + this.channelWrapperTables.put(channelFuture.channel(), cw); } } catch (Exception e) { LOGGER.error("createChannel: create channel exception", e); @@ -758,6 +772,64 @@ public void invokeOneway(String addr, RemotingCommand request, long timeoutMilli } } + @Override + public CompletableFuture invoke(String addr, RemotingCommand request, + long timeoutMillis) { + CompletableFuture future = new CompletableFuture<>(); + try { + final Channel channel = this.getAndCreateChannel(addr); + if (channel != null && channel.isActive()) { + return invokeImpl(channel, request, timeoutMillis).whenComplete((v, t) -> { + if (t == null) { + updateChannelLastResponseTime(addr); + } + }).thenApply(ResponseFuture::getResponseCommand); + } else { + this.closeChannel(addr, channel); + future.completeExceptionally(new RemotingConnectException(addr)); + } + } catch (Throwable t) { + future.completeExceptionally(t); + } + return future; + } + + @Override + public CompletableFuture invokeImpl(final Channel channel, final RemotingCommand request, + final long timeoutMillis) { + Stopwatch stopwatch = Stopwatch.createStarted(); + return super.invokeImpl(channel, request, timeoutMillis).thenCompose(responseFuture -> { + RemotingCommand response = responseFuture.getResponseCommand(); + if (response.getCode() == ResponseCode.GO_AWAY) { + if (nettyClientConfig.isEnableReconnectForGoAway()) { + ChannelWrapper channelWrapper = channelWrapperTables.computeIfPresent(channel, (channel0, channelWrapper0) -> { + try { + if (channelWrapper0.reconnect()) { + LOGGER.info("Receive go away from channel {}, recreate the channel", channel0); + channelWrapperTables.put(channelWrapper0.getChannel(), channelWrapper0); + } + } catch (Throwable t) { + LOGGER.error("Channel {} reconnect error", channelWrapper0, t); + } + return channelWrapper0; + }); + if (channelWrapper != null) { + if (nettyClientConfig.isEnableTransparentRetry()) { + long duration = stopwatch.elapsed(TimeUnit.MILLISECONDS); + stopwatch.stop(); + RemotingCommand retryRequest = RemotingCommand.createRequestCommand(request.getCode(), request.readCustomHeader()); + Channel retryChannel = channelWrapper.getChannel(); + if (channel != retryChannel) { + return super.invokeImpl(retryChannel, retryRequest, timeoutMillis - duration); + } + } + } + } + } + return CompletableFuture.completedFuture(responseFuture); + }); + } + @Override public void registerProcessor(int requestCode, NettyRequestProcessor processor, ExecutorService executor) { ExecutorService executorThis = executor; @@ -877,30 +949,41 @@ public void run() { } } - static class ChannelWrapper { - private final ChannelFuture channelFuture; + class ChannelWrapper { + private final ReentrantReadWriteLock lock; + private ChannelFuture channelFuture; // only affected by sync or async request, oneway is not included. + private ChannelFuture channelToClose; private long lastResponseTime; + private volatile long lastReconnectTimestamp = 0L; + private final String channelAddress; - public ChannelWrapper(ChannelFuture channelFuture) { + public ChannelWrapper(String address, ChannelFuture channelFuture) { + this.lock = new ReentrantReadWriteLock(); this.channelFuture = channelFuture; this.lastResponseTime = System.currentTimeMillis(); + this.channelAddress = address; } public boolean isOK() { - return this.channelFuture.channel() != null && this.channelFuture.channel().isActive(); + return getChannel() != null && getChannel().isActive(); } public boolean isWritable() { - return this.channelFuture.channel().isWritable(); + return getChannel().isWritable(); } private Channel getChannel() { - return this.channelFuture.channel(); + return getChannelFuture().channel(); } public ChannelFuture getChannelFuture() { - return channelFuture; + lock.readLock().lock(); + try { + return this.channelFuture; + } finally { + lock.readLock().unlock(); + } } public long getLastResponseTime() { @@ -910,6 +993,52 @@ public long getLastResponseTime() { public void updateLastResponseTime() { this.lastResponseTime = System.currentTimeMillis(); } + + public boolean reconnect() { + if (lock.writeLock().tryLock()) { + try { + if (lastReconnectTimestamp == 0L || System.currentTimeMillis() - lastReconnectTimestamp > Duration.ofSeconds(nettyClientConfig.getMaxReconnectIntervalTimeSeconds()).toMillis()) { + channelToClose = channelFuture; + String[] hostAndPort = getHostAndPort(channelAddress); + channelFuture = fetchBootstrap(channelAddress) + .connect(hostAndPort[0], Integer.parseInt(hostAndPort[1])); + lastReconnectTimestamp = System.currentTimeMillis(); + return true; + } + } finally { + lock.writeLock().unlock(); + } + } + return false; + } + + public boolean tryClose(Channel channel) { + try { + lock.readLock().lock(); + if (channelFuture != null) { + if (channelFuture.channel().equals(channel)) { + return true; + } + } + } finally { + lock.readLock().unlock(); + } + return false; + } + + public void close() { + try { + lock.writeLock().lock(); + if (channelFuture != null) { + closeChannel(channelFuture.channel()); + } + if (channelToClose != null) { + closeChannel(channelToClose.channel()); + } + } finally { + lock.writeLock().unlock(); + } + } } class InvokeCallbackWrapper implements InvokeCallback { diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java index aa0d46542bee..735d36168f4d 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java @@ -53,6 +53,19 @@ import io.netty.util.Timeout; import io.netty.util.TimerTask; import io.netty.util.concurrent.DefaultEventExecutorGroup; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.rocketmq.common.Pair; @@ -74,19 +87,6 @@ import org.apache.rocketmq.remoting.exception.RemotingTooMuchRequestException; import org.apache.rocketmq.remoting.protocol.RemotingCommand; -import java.io.IOException; -import java.net.InetSocketAddress; -import java.security.cert.CertificateException; -import java.util.List; -import java.util.NoSuchElementException; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; - @SuppressWarnings("NullableProblems") public class NettyRemotingServer extends NettyRemotingAbstract implements RemotingServer { private static final Logger log = LoggerFactory.getLogger(LoggerName.ROCKETMQ_REMOTING_NAME); @@ -305,6 +305,10 @@ private void addCustomConfig(ServerBootstrap childHandler) { @Override public void shutdown() { try { + if (nettyServerConfig.isEnableShutdownGracefully() && isShuttingDown.compareAndSet(false, true)) { + Thread.sleep(Duration.ofSeconds(nettyServerConfig.getShutdownWaitTimeSeconds()).toMillis()); + } + this.timer.stop(); this.eventLoopGroupBoss.shutdownGracefully(); @@ -736,6 +740,7 @@ public void start() { @Override public void shutdown() { + isShuttingDown.set(true); if (this.serverChannel != null) { try { this.serverChannel.close().await(5, TimeUnit.SECONDS); diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyServerConfig.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyServerConfig.java index 59ef2c84f159..756661f623f4 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyServerConfig.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyServerConfig.java @@ -38,6 +38,9 @@ public class NettyServerConfig implements Cloneable { private int serverSocketBacklog = NettySystemConfig.socketBacklog; private boolean serverPooledByteBufAllocatorEnable = true; + private boolean enableShutdownGracefully = false; + private int shutdownWaitTimeSeconds = 30; + /** * make install * @@ -171,4 +174,20 @@ public int getWriteBufferHighWaterMark() { public void setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { this.writeBufferHighWaterMark = writeBufferHighWaterMark; } + + public boolean isEnableShutdownGracefully() { + return enableShutdownGracefully; + } + + public void setEnableShutdownGracefully(boolean enableShutdownGracefully) { + this.enableShutdownGracefully = enableShutdownGracefully; + } + + public int getShutdownWaitTimeSeconds() { + return shutdownWaitTimeSeconds; + } + + public void setShutdownWaitTimeSeconds(int shutdownWaitTimeSeconds) { + this.shutdownWaitTimeSeconds = shutdownWaitTimeSeconds; + } } diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/protocol/ResponseCode.java b/remoting/src/main/java/org/apache/rocketmq/remoting/protocol/ResponseCode.java index e81dadf2e12e..be945c48fda4 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/protocol/ResponseCode.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/protocol/ResponseCode.java @@ -99,6 +99,8 @@ public class ResponseCode extends RemotingSysResponseCode { public static final int RPC_SEND_TO_CHANNEL_FAILED = -1004; public static final int RPC_TIME_OUT = -1006; + public static final int GO_AWAY = 1500; + /** * Controller response code */ diff --git a/remoting/src/test/java/org/apache/rocketmq/remoting/netty/NettyRemotingClientTest.java b/remoting/src/test/java/org/apache/rocketmq/remoting/netty/NettyRemotingClientTest.java index e72e7bd53ebd..1cc6b4f46876 100644 --- a/remoting/src/test/java/org/apache/rocketmq/remoting/netty/NettyRemotingClientTest.java +++ b/remoting/src/test/java/org/apache/rocketmq/remoting/netty/NettyRemotingClientTest.java @@ -47,7 +47,6 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -74,13 +73,11 @@ public void testInvokeResponse() throws Exception { RemotingCommand response = RemotingCommand.createResponseCommand(null); response.setCode(ResponseCode.SUCCESS); - doAnswer(invocation -> { - InvokeCallback callback = invocation.getArgument(3); - ResponseFuture responseFuture = new ResponseFuture(null, request.getOpaque(), 3 * 1000, null, null); - responseFuture.setResponseCommand(response); - callback.operationSucceed(responseFuture.getResponseCommand()); - return null; - }).when(remotingClient).invokeAsync(anyString(), any(RemotingCommand.class), anyLong(), any(InvokeCallback.class)); + ResponseFuture responseFuture = new ResponseFuture(null, request.getOpaque(), 3 * 1000, null, null); + responseFuture.setResponseCommand(response); + CompletableFuture future0 = new CompletableFuture<>(); + future0.complete(responseFuture.getResponseCommand()); + doReturn(future0).when(remotingClient).invoke(anyString(), any(RemotingCommand.class), anyLong()); CompletableFuture future = remotingClient.invoke("0.0.0.0", request, 1000); RemotingCommand actual = future.get(); @@ -93,11 +90,9 @@ public void testRemotingSendRequestException() throws Exception { RemotingCommand response = RemotingCommand.createResponseCommand(null); response.setCode(ResponseCode.SUCCESS); - doAnswer(invocation -> { - InvokeCallback callback = invocation.getArgument(3); - callback.operationFail(new RemotingSendRequestException(null)); - return null; - }).when(remotingClient).invokeAsync(anyString(), any(RemotingCommand.class), anyLong(), any(InvokeCallback.class)); + CompletableFuture future0 = new CompletableFuture<>(); + future0.completeExceptionally(new RemotingSendRequestException(null)); + doReturn(future0).when(remotingClient).invoke(anyString(), any(RemotingCommand.class), anyLong()); CompletableFuture future = remotingClient.invoke("0.0.0.0", request, 1000); Throwable thrown = catchThrowable(future::get); @@ -110,11 +105,9 @@ public void testRemotingTimeoutException() throws Exception { RemotingCommand response = RemotingCommand.createResponseCommand(null); response.setCode(ResponseCode.SUCCESS); - doAnswer(invocation -> { - InvokeCallback callback = invocation.getArgument(3); - callback.operationFail(new RemotingTimeoutException("")); - return null; - }).when(remotingClient).invokeAsync(anyString(), any(RemotingCommand.class), anyLong(), any(InvokeCallback.class)); + CompletableFuture future0 = new CompletableFuture<>(); + future0.completeExceptionally(new RemotingTimeoutException("")); + doReturn(future0).when(remotingClient).invoke(anyString(), any(RemotingCommand.class), anyLong()); CompletableFuture future = remotingClient.invoke("0.0.0.0", request, 1000); Throwable thrown = catchThrowable(future::get); @@ -125,13 +118,9 @@ public void testRemotingTimeoutException() throws Exception { public void testRemotingException() throws Exception { RemotingCommand request = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, null); - RemotingCommand response = RemotingCommand.createResponseCommand(null); - response.setCode(ResponseCode.SUCCESS); - doAnswer(invocation -> { - InvokeCallback callback = invocation.getArgument(3); - callback.operationFail(new RemotingException(null)); - return null; - }).when(remotingClient).invokeAsync(anyString(), any(RemotingCommand.class), anyLong(), any(InvokeCallback.class)); + CompletableFuture future0 = new CompletableFuture<>(); + future0.completeExceptionally(new RemotingException("")); + doReturn(future0).when(remotingClient).invoke(anyString(), any(RemotingCommand.class), anyLong()); CompletableFuture future = remotingClient.invoke("0.0.0.0", request, 1000); Throwable thrown = catchThrowable(future::get);