From 0e8059dc2e79d5c20b77821e1f6ae7fbf143d3fb Mon Sep 17 00:00:00 2001 From: Puja Jagani Date: Wed, 13 Nov 2024 15:25:56 +0530 Subject: [PATCH] [bidi][java] Add network request handler APIs (#14424) --- .../openqa/selenium/bidi/network/Header.java | 2 +- .../selenium/bidi/network/ResponseData.java | 10 +- .../org/openqa/selenium/remote/Network.java | 8 + .../openqa/selenium/remote/RemoteNetwork.java | 114 +++++++ .../selenium/remote/http/HttpMethod.java | 14 +- .../org/openqa/selenium/WebNetworkTest.java | 278 ++++++++++++++++++ .../bidi/network/NetworkEventsTest.java | 6 +- 7 files changed, 422 insertions(+), 10 deletions(-) diff --git a/java/src/org/openqa/selenium/bidi/network/Header.java b/java/src/org/openqa/selenium/bidi/network/Header.java index b98d3b26587ab..b57a85900d4d9 100644 --- a/java/src/org/openqa/selenium/bidi/network/Header.java +++ b/java/src/org/openqa/selenium/bidi/network/Header.java @@ -25,7 +25,7 @@ public class Header { private final String name; private final BytesValue value; - private Header(String name, BytesValue value) { + public Header(String name, BytesValue value) { this.name = name; this.value = value; } diff --git a/java/src/org/openqa/selenium/bidi/network/ResponseData.java b/java/src/org/openqa/selenium/bidi/network/ResponseData.java index efeedb0b57761..a6c16abb7ca2e 100644 --- a/java/src/org/openqa/selenium/bidi/network/ResponseData.java +++ b/java/src/org/openqa/selenium/bidi/network/ResponseData.java @@ -28,7 +28,7 @@ public class ResponseData { private final String url; private final String protocol; - private final long status; + private final int status; private final String statusText; private final boolean fromCache; private final List
headers; @@ -42,7 +42,7 @@ public class ResponseData { private ResponseData( String url, String protocol, - long status, + int status, String statusText, boolean fromCache, List
headers, @@ -69,7 +69,7 @@ private ResponseData( public static ResponseData fromJson(JsonInput input) { String url = null; String protocol = null; - long status = 0; + int status = 0; String statusText = null; boolean fromCache = false; List
headers = new ArrayList<>(); @@ -89,7 +89,7 @@ public static ResponseData fromJson(JsonInput input) { protocol = input.read(String.class); break; case "status": - status = input.read(Long.class); + status = input.read(Integer.class); break; case "statusText": statusText = input.read(String.class); @@ -150,7 +150,7 @@ public String getProtocol() { return protocol; } - public long getStatus() { + public int getStatus() { return status; } diff --git a/java/src/org/openqa/selenium/remote/Network.java b/java/src/org/openqa/selenium/remote/Network.java index b64060b980237..6407d3de3745b 100644 --- a/java/src/org/openqa/selenium/remote/Network.java +++ b/java/src/org/openqa/selenium/remote/Network.java @@ -19,8 +19,10 @@ import java.net.URI; import java.util.function.Predicate; +import java.util.function.UnaryOperator; import org.openqa.selenium.Beta; import org.openqa.selenium.UsernameAndPassword; +import org.openqa.selenium.remote.http.HttpRequest; @Beta public interface Network { @@ -32,4 +34,10 @@ public interface Network { void removeAuthenticationHandler(long id); void clearAuthenticationHandlers(); + + long addRequestHandler(Predicate filter, UnaryOperator handler); + + void removeRequestHandler(long id); + + void clearRequestHandlers(); } diff --git a/java/src/org/openqa/selenium/remote/RemoteNetwork.java b/java/src/org/openqa/selenium/remote/RemoteNetwork.java index 4a1f921921d66..2848895667f1f 100644 --- a/java/src/org/openqa/selenium/remote/RemoteNetwork.java +++ b/java/src/org/openqa/selenium/remote/RemoteNetwork.java @@ -18,18 +18,28 @@ package org.openqa.selenium.remote; import java.net.URI; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Predicate; +import java.util.function.UnaryOperator; import org.openqa.selenium.Beta; import org.openqa.selenium.UsernameAndPassword; import org.openqa.selenium.WebDriver; import org.openqa.selenium.bidi.BiDi; import org.openqa.selenium.bidi.HasBiDi; import org.openqa.selenium.bidi.network.AddInterceptParameters; +import org.openqa.selenium.bidi.network.BytesValue; +import org.openqa.selenium.bidi.network.ContinueRequestParameters; +import org.openqa.selenium.bidi.network.Header; import org.openqa.selenium.bidi.network.InterceptPhase; +import org.openqa.selenium.bidi.network.RequestData; +import org.openqa.selenium.remote.http.Contents; +import org.openqa.selenium.remote.http.HttpMethod; +import org.openqa.selenium.remote.http.HttpRequest; @Beta class RemoteNetwork implements Network { @@ -39,6 +49,8 @@ class RemoteNetwork implements Network { private final Map authHandlers = new ConcurrentHashMap<>(); + private final Map requestHandlers = new ConcurrentHashMap<>(); + private final AtomicLong callBackId = new AtomicLong(1); public RemoteNetwork(WebDriver driver) { @@ -46,6 +58,7 @@ public RemoteNetwork(WebDriver driver) { this.network = new org.openqa.selenium.bidi.module.Network(driver); interceptAuthTraffic(); + interceptRequest(); } private void interceptAuthTraffic() { @@ -73,6 +86,71 @@ private Optional getAuthCredentials(URI uri) { .findFirst(); } + private void interceptRequest() { + this.network.addIntercept(new AddInterceptParameters(InterceptPhase.BEFORE_REQUEST_SENT)); + + this.network.onBeforeRequestSent( + beforeRequestSent -> { + String requestId = beforeRequestSent.getRequest().getRequestId(); + URI uri = URI.create(beforeRequestSent.getRequest().getUrl()); + + ContinueRequestParameters continueRequestParameters = + new ContinueRequestParameters(requestId); + + Optional> requestHandler = getRequestHandler(uri); + + if (requestHandler.isPresent()) { + RequestData interceptedRequest = beforeRequestSent.getRequest(); + + // Build the originalRequest object from the intercepted request details. + HttpRequest originalRequest = + new HttpRequest( + HttpMethod.getHttpMethod(interceptedRequest.getMethod()), + interceptedRequest.getUrl()); + + // Populate the headers of the original request. + interceptedRequest + .getHeaders() + .forEach( + header -> + originalRequest.addHeader(header.getName(), header.getValue().getValue())); + + HttpRequest modifiedRequest = requestHandler.get().apply(originalRequest); + + continueRequestParameters.method(modifiedRequest.getMethod()); + + if (!uri.toString().equals(modifiedRequest.getUri())) { + continueRequestParameters.url(modifiedRequest.getUri()); + } + + List
headerList = new ArrayList<>(); + modifiedRequest.forEachHeader( + (name, value) -> + headerList.add( + new Header(name, new BytesValue(BytesValue.Type.STRING, value)))); + + if (!headerList.isEmpty()) { + continueRequestParameters.headers(headerList); + } + + Contents.Supplier content = modifiedRequest.getContent(); + + if (content.length() > 0) { + continueRequestParameters.body( + new BytesValue(BytesValue.Type.STRING, Contents.utf8String(content))); + } + } + network.continueRequest(continueRequestParameters); + }); + } + + private Optional> getRequestHandler(URI uri) { + return requestHandlers.values().stream() + .filter(requestDetails -> requestDetails.getFilter().test(uri)) + .map(RequestDetails::getHandler) + .findFirst(); + } + @Override public long addAuthenticationHandler(UsernameAndPassword usernameAndPassword) { return addAuthenticationHandler(url -> true, usernameAndPassword); @@ -97,6 +175,24 @@ public void clearAuthenticationHandlers() { authHandlers.clear(); } + @Override + public long addRequestHandler(Predicate filter, UnaryOperator handler) { + long id = this.callBackId.incrementAndGet(); + + requestHandlers.put(id, new RequestDetails(filter, handler)); + return id; + } + + @Override + public void removeRequestHandler(long id) { + requestHandlers.remove(id); + } + + @Override + public void clearRequestHandlers() { + requestHandlers.clear(); + } + private class AuthDetails { private final Predicate filter; private final UsernameAndPassword usernameAndPassword; @@ -114,4 +210,22 @@ public UsernameAndPassword getUsernameAndPassword() { return usernameAndPassword; } } + + private class RequestDetails { + private final Predicate filter; + private final UnaryOperator handler; + + public RequestDetails(Predicate filter, UnaryOperator handler) { + this.filter = filter; + this.handler = handler; + } + + public Predicate getFilter() { + return this.filter; + } + + public UnaryOperator getHandler() { + return this.handler; + } + } } diff --git a/java/src/org/openqa/selenium/remote/http/HttpMethod.java b/java/src/org/openqa/selenium/remote/http/HttpMethod.java index b10cd75a01b8a..73ed0f52aa4b8 100644 --- a/java/src/org/openqa/selenium/remote/http/HttpMethod.java +++ b/java/src/org/openqa/selenium/remote/http/HttpMethod.java @@ -26,5 +26,17 @@ public enum HttpMethod { PATCH, HEAD, CONNECT, - TRACE, + TRACE; + + public static HttpMethod getHttpMethod(String method) { + if (method == null) { + throw new IllegalArgumentException("Method cannot be null"); + } + + try { + return HttpMethod.valueOf(method.toUpperCase()); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("No enum constant for method: " + method); + } + } } diff --git a/java/test/org/openqa/selenium/WebNetworkTest.java b/java/test/org/openqa/selenium/WebNetworkTest.java index 3ad732b8ff6b2..fb7f3023784c6 100644 --- a/java/test/org/openqa/selenium/WebNetworkTest.java +++ b/java/test/org/openqa/selenium/WebNetworkTest.java @@ -18,12 +18,24 @@ package org.openqa.selenium; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.openqa.selenium.remote.http.Contents.utf8String; import java.net.URI; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.Predicate; import org.junit.jupiter.api.Test; +import org.openqa.selenium.bidi.module.Network; +import org.openqa.selenium.bidi.network.Header; +import org.openqa.selenium.environment.webserver.NettyAppServer; import org.openqa.selenium.remote.RemoteWebDriver; +import org.openqa.selenium.remote.http.HttpMethod; +import org.openqa.selenium.remote.http.HttpRequest; +import org.openqa.selenium.remote.http.HttpResponse; +import org.openqa.selenium.remote.http.Route; import org.openqa.selenium.testing.Ignore; import org.openqa.selenium.testing.JupiterTestBase; import org.openqa.selenium.testing.NeedsFreshDriver; @@ -163,4 +175,270 @@ void canClearAuthenticationHandlers() { assertThatExceptionOfType(UnhandledAlertException.class) .isThrownBy(() -> driver.findElement(By.tagName("h1"))); } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddRequestHandler() { + Predicate filter = uri -> uri.getPath().contains("logEntry"); + + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + ((RemoteWebDriver) driver).network().addRequestHandler(filter, httpRequest -> httpRequest); + + driver.get(page); + + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddRequestHandlerToModifyMethod() { + Predicate filter = uri -> uri.getPath().contains("logEntry"); + + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + ((RemoteWebDriver) driver) + .network() + .addRequestHandler(filter, httpRequest -> new HttpRequest(HttpMethod.HEAD, page)); + + driver.get(page); + + assertThatThrownBy(() -> driver.findElement(By.tagName("h1"))) + .isInstanceOf(NoSuchElementException.class); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddRequestHandlerToModifyHeaders() throws InterruptedException { + Route route = + Route.matching(req -> req.getUri().contains("network")) + .to( + () -> + req -> { + HttpResponse response = new HttpResponse(); + + req.getHeaderNames() + .forEach( + header -> { + String value = req.getHeader(header); + response.addHeader(header, value); + }); + return response.setContent(utf8String("Received response for network")); + }); + + appServer = new NettyAppServer(route); + appServer.start(); + + Predicate filter = uri -> uri.getPath().contains("network"); + + CountDownLatch latch = new CountDownLatch(1); + + page = appServer.whereIs("network.html"); + + ((RemoteWebDriver) driver) + .network() + .addRequestHandler( + filter, + httpRequest -> + new HttpRequest(HttpMethod.HEAD, page).addHeader("test", "network-intercept")); + + Network network = new Network(driver); + network.onResponseCompleted( + responseDetails -> { + List
headers = responseDetails.getResponseData().getHeaders(); + headers.forEach( + header -> { + if (header.getName().equals("test")) { + assertThat(header.getValue().getValue()).isEqualTo("network-intercept"); + latch.countDown(); + } + }); + }); + + driver.get(page); + + latch.await(5, TimeUnit.SECONDS); + + assertThat(latch.getCount()).isEqualTo(0); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddRequestHandlerToModifyBody() throws InterruptedException { + Route route = + Route.matching(req -> req.getUri().contains("network")) + .to( + () -> + req -> { + HttpResponse response = new HttpResponse(); + return response.setContent(req.getContent()); + }); + + appServer = new NettyAppServer(route); + appServer.start(); + + Predicate filter = uri -> uri.getPath().contains("network"); + + page = appServer.whereIs("network.html"); + + ((RemoteWebDriver) driver) + .network() + .addRequestHandler( + filter, + httpRequest -> + new HttpRequest(HttpMethod.POST, page) + .setContent(utf8String("Received response for the request"))); + + driver.get(page); + + assertThat(driver.getPageSource().contains("Received response for the request")).isTrue(); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddMultipleRequestHandlers() { + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + ((RemoteWebDriver) driver) + .network() + .addRequestHandler(uri -> uri.getPath().contains("logEntry"), httpRequest -> httpRequest); + + ((RemoteWebDriver) driver) + .network() + .addRequestHandler( + uri -> uri.getPath().contains("hello"), + httpRequest -> new HttpRequest(HttpMethod.HEAD, page)); + + driver.get(page); + + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canAddMultipleRequestHandlersWithTheSameFilter() { + ((RemoteWebDriver) driver) + .network() + .addRequestHandler(uri -> uri.getPath().contains("logEntry"), httpRequest -> httpRequest); + + ((RemoteWebDriver) driver) + .network() + .addRequestHandler(uri -> uri.getPath().contains("logEntry"), httpRequest -> httpRequest); + + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + driver.get(page); + + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canRemoveRequestHandler() throws InterruptedException { + Route route = + Route.matching(req -> req.getUri().contains("network")) + .to( + () -> + req -> { + HttpResponse response = new HttpResponse(); + + req.getHeaderNames() + .forEach( + header -> { + String value = req.getHeader(header); + response.addHeader(header, value); + }); + return response.setContent(utf8String("Received response for network")); + }); + + appServer = new NettyAppServer(route); + appServer.start(); + + Predicate filter = uri -> uri.getPath().contains("network"); + + CountDownLatch latch = new CountDownLatch(1); + + page = appServer.whereIs("network.html"); + + long id = + ((RemoteWebDriver) driver) + .network() + .addRequestHandler( + filter, + httpRequest -> + new HttpRequest(HttpMethod.HEAD, page).addHeader("test", "network-intercept")); + + ((RemoteWebDriver) driver).network().removeRequestHandler(id); + + Network network = new Network(driver); + network.onResponseCompleted( + responseDetails -> { + List
headers = responseDetails.getResponseData().getHeaders(); + headers.forEach( + header -> { + if (header.getName().equals("test")) { + assertThat(header.getValue().getValue()).isEqualTo("network-intercept"); + latch.countDown(); + } + }); + }); + + driver.get(page); + + latch.await(5, TimeUnit.SECONDS); + + assertThat(latch.getCount()).isEqualTo(1); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canRemoveRequestHandlerThatDoesNotExist() { + ((RemoteWebDriver) driver).network().removeAuthenticationHandler(5); + page = appServer.whereIs("/bidi/logEntryAdded.html"); + driver.get(page); + + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + } + + @Test + @NeedsFreshDriver + @Ignore(Browser.CHROME) + @Ignore(Browser.EDGE) + void canClearRequestHandlers() { + page = appServer.whereIs("/bidi/logEntryAdded.html"); + + ((RemoteWebDriver) driver) + .network() + .addRequestHandler( + uri -> uri.getPath().contains("logEntryAdded"), + httpRequest -> new HttpRequest(HttpMethod.DELETE, page)); + + ((RemoteWebDriver) driver) + .network() + .addRequestHandler( + uri -> uri.getPath().contains("hello"), + httpRequest -> new HttpRequest(HttpMethod.HEAD, page)); + + ((RemoteWebDriver) driver).network().clearRequestHandlers(); + + driver.get(page); + + assertThat(driver.findElement(By.tagName("h1")).getText()).isEqualTo("Long entry added events"); + } } diff --git a/java/test/org/openqa/selenium/bidi/network/NetworkEventsTest.java b/java/test/org/openqa/selenium/bidi/network/NetworkEventsTest.java index 7f221f6fc20bc..f4a0b10f5bd03 100644 --- a/java/test/org/openqa/selenium/bidi/network/NetworkEventsTest.java +++ b/java/test/org/openqa/selenium/bidi/network/NetworkEventsTest.java @@ -77,7 +77,7 @@ void canListenToResponseStartedEvent() assertThat(response.getRequest().getUrl()).isNotNull(); assertThat(response.getResponseData().getHeaders().size()).isGreaterThanOrEqualTo(1); assertThat(response.getResponseData().getUrl()).contains("/bidi/logEntryAdded.html"); - assertThat(response.getResponseData().getStatus()).isEqualTo(200L); + assertThat(response.getResponseData().getStatus()).isEqualTo(200); } } @@ -100,7 +100,7 @@ void canListenToResponseCompletedEvent() assertThat(response.getRequest().getUrl()).isNotNull(); assertThat(response.getResponseData().getHeaders().size()).isGreaterThanOrEqualTo(1); assertThat(response.getResponseData().getUrl()).contains("/bidi/logEntryAdded.html"); - assertThat(response.getResponseData().getStatus()).isEqualTo(200L); + assertThat(response.getResponseData().getStatus()).isEqualTo(200); } } @@ -147,7 +147,7 @@ void canListenToOnAuthRequiredEvent() assertThat(response.getRequest().getUrl()).isNotNull(); assertThat(response.getResponseData().getHeaders().size()).isGreaterThanOrEqualTo(1); assertThat(response.getResponseData().getUrl()).contains("basicAuth"); - assertThat(response.getResponseData().getStatus()).isEqualTo(401L); + assertThat(response.getResponseData().getStatus()).isEqualTo(401); } }