diff --git a/java/src/org/openqa/selenium/netty/server/SeleniumHandler.java b/java/src/org/openqa/selenium/netty/server/SeleniumHandler.java index a13ddc33ae68d..73dbd2a45aadf 100644 --- a/java/src/org/openqa/selenium/netty/server/SeleniumHandler.java +++ b/java/src/org/openqa/selenium/netty/server/SeleniumHandler.java @@ -19,8 +19,10 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import org.openqa.selenium.internal.Require; import org.openqa.selenium.remote.ErrorFilter; import org.openqa.selenium.remote.http.HttpHandler; @@ -31,18 +33,27 @@ class SeleniumHandler extends SimpleChannelInboundHandler { private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(); private final HttpHandler seleniumHandler; + private Future lastOne; public SeleniumHandler(HttpHandler seleniumHandler) { super(HttpRequest.class); this.seleniumHandler = Require.nonNull("HTTP handler", seleniumHandler).with(new ErrorFilter()); + this.lastOne = CompletableFuture.completedFuture(null); } @Override protected void channelRead0(ChannelHandlerContext ctx, HttpRequest msg) { - EXECUTOR.submit( - () -> { - HttpResponse res = seleniumHandler.execute(msg); - ctx.writeAndFlush(res); - }); + lastOne = + EXECUTOR.submit( + () -> { + HttpResponse res = seleniumHandler.execute(msg); + ctx.writeAndFlush(res); + }); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + lastOne.cancel(true); + super.channelInactive(ctx); } } diff --git a/java/test/org/openqa/selenium/netty/server/NettyServerTest.java b/java/test/org/openqa/selenium/netty/server/NettyServerTest.java index 04ba014a5be17..fd110c233e80e 100644 --- a/java/test/org/openqa/selenium/netty/server/NettyServerTest.java +++ b/java/test/org/openqa/selenium/netty/server/NettyServerTest.java @@ -28,14 +28,20 @@ import com.google.common.collect.ImmutableMap; import java.net.URL; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.openqa.selenium.TimeoutException; import org.openqa.selenium.grid.config.CompoundConfig; import org.openqa.selenium.grid.config.Config; import org.openqa.selenium.grid.config.MapConfig; import org.openqa.selenium.grid.server.BaseServerOptions; import org.openqa.selenium.grid.server.Server; import org.openqa.selenium.net.PortProber; +import org.openqa.selenium.remote.http.ClientConfig; import org.openqa.selenium.remote.http.HttpClient; import org.openqa.selenium.remote.http.HttpRequest; import org.openqa.selenium.remote.http.HttpResponse; @@ -141,6 +147,43 @@ void shouldNotBindToHost() { assertEquals("anyRandomHost", server.getUrl().getHost()); } + @Test + void doesInterruptPending() throws Exception { + CountDownLatch interrupted = new CountDownLatch(1); + Config cfg = new MapConfig(ImmutableMap.of()); + BaseServerOptions options = new BaseServerOptions(cfg); + + Server server = + new NettyServer( + options, + req -> { + try { + Thread.sleep(800); + } catch (InterruptedException ex) { + interrupted.countDown(); + } + return new HttpResponse(); + }) + .start(); + ClientConfig config = + ClientConfig.defaultConfig() + .readTimeout(Duration.ofMillis(400)) + .baseUri(server.getUrl().toURI()); + + // provoke a client timeout + Assertions.assertThrows( + TimeoutException.class, + () -> { + try (HttpClient client = HttpClient.Factory.createDefault().createClient(config)) { + HttpRequest request = new HttpRequest(DELETE, "/session"); + request.setHeader("Accept", "*/*"); + client.execute(request); + } + }); + + assertTrue(interrupted.await(1000, TimeUnit.MILLISECONDS), "The handling was interrupted"); + } + private void outputHeaders(HttpResponse res) { res.forEachHeader((name, value) -> System.out.printf("%s -> %s\n", name, value)); }