Skip to content

Commit

Permalink
[bidi][java] Add network request handler APIs (#14424)
Browse files Browse the repository at this point in the history
  • Loading branch information
pujagani authored Nov 13, 2024
1 parent 25551ad commit 0e8059d
Show file tree
Hide file tree
Showing 7 changed files with 422 additions and 10 deletions.
2 changes: 1 addition & 1 deletion java/src/org/openqa/selenium/bidi/network/Header.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class Header {
private final String name;
private final BytesValue value;

private Header(String name, BytesValue value) {
public Header(String name, BytesValue value) {
this.name = name;
this.value = value;
}
Expand Down
10 changes: 5 additions & 5 deletions java/src/org/openqa/selenium/bidi/network/ResponseData.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class ResponseData {
private final String url;

private final String protocol;
private final long status;
private final int status;
private final String statusText;
private final boolean fromCache;
private final List<Header> headers;
Expand All @@ -42,7 +42,7 @@ public class ResponseData {
private ResponseData(
String url,
String protocol,
long status,
int status,
String statusText,
boolean fromCache,
List<Header> headers,
Expand All @@ -69,7 +69,7 @@ private ResponseData(
public static ResponseData fromJson(JsonInput input) {
String url = null;
String protocol = null;
long status = 0;
int status = 0;
String statusText = null;
boolean fromCache = false;
List<Header> headers = new ArrayList<>();
Expand All @@ -89,7 +89,7 @@ public static ResponseData fromJson(JsonInput input) {
protocol = input.read(String.class);
break;
case "status":
status = input.read(Long.class);
status = input.read(Integer.class);
break;
case "statusText":
statusText = input.read(String.class);
Expand Down Expand Up @@ -150,7 +150,7 @@ public String getProtocol() {
return protocol;
}

public long getStatus() {
public int getStatus() {
return status;
}

Expand Down
8 changes: 8 additions & 0 deletions java/src/org/openqa/selenium/remote/Network.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import java.net.URI;
import java.util.function.Predicate;
import java.util.function.UnaryOperator;
import org.openqa.selenium.Beta;
import org.openqa.selenium.UsernameAndPassword;
import org.openqa.selenium.remote.http.HttpRequest;

@Beta
public interface Network {
Expand All @@ -32,4 +34,10 @@ public interface Network {
void removeAuthenticationHandler(long id);

void clearAuthenticationHandlers();

long addRequestHandler(Predicate<URI> filter, UnaryOperator<HttpRequest> handler);

void removeRequestHandler(long id);

void clearRequestHandlers();
}
114 changes: 114 additions & 0 deletions java/src/org/openqa/selenium/remote/RemoteNetwork.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,28 @@
package org.openqa.selenium.remote;

import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Predicate;
import java.util.function.UnaryOperator;
import org.openqa.selenium.Beta;
import org.openqa.selenium.UsernameAndPassword;
import org.openqa.selenium.WebDriver;
import org.openqa.selenium.bidi.BiDi;
import org.openqa.selenium.bidi.HasBiDi;
import org.openqa.selenium.bidi.network.AddInterceptParameters;
import org.openqa.selenium.bidi.network.BytesValue;
import org.openqa.selenium.bidi.network.ContinueRequestParameters;
import org.openqa.selenium.bidi.network.Header;
import org.openqa.selenium.bidi.network.InterceptPhase;
import org.openqa.selenium.bidi.network.RequestData;
import org.openqa.selenium.remote.http.Contents;
import org.openqa.selenium.remote.http.HttpMethod;
import org.openqa.selenium.remote.http.HttpRequest;

@Beta
class RemoteNetwork implements Network {
Expand All @@ -39,13 +49,16 @@ class RemoteNetwork implements Network {

private final Map<Long, AuthDetails> authHandlers = new ConcurrentHashMap<>();

private final Map<Long, RequestDetails> requestHandlers = new ConcurrentHashMap<>();

private final AtomicLong callBackId = new AtomicLong(1);

public RemoteNetwork(WebDriver driver) {
this.biDi = ((HasBiDi) driver).getBiDi();
this.network = new org.openqa.selenium.bidi.module.Network(driver);

interceptAuthTraffic();
interceptRequest();
}

private void interceptAuthTraffic() {
Expand Down Expand Up @@ -73,6 +86,71 @@ private Optional<UsernameAndPassword> getAuthCredentials(URI uri) {
.findFirst();
}

private void interceptRequest() {
this.network.addIntercept(new AddInterceptParameters(InterceptPhase.BEFORE_REQUEST_SENT));

this.network.onBeforeRequestSent(
beforeRequestSent -> {
String requestId = beforeRequestSent.getRequest().getRequestId();
URI uri = URI.create(beforeRequestSent.getRequest().getUrl());

ContinueRequestParameters continueRequestParameters =
new ContinueRequestParameters(requestId);

Optional<UnaryOperator<HttpRequest>> requestHandler = getRequestHandler(uri);

if (requestHandler.isPresent()) {
RequestData interceptedRequest = beforeRequestSent.getRequest();

// Build the originalRequest object from the intercepted request details.
HttpRequest originalRequest =
new HttpRequest(
HttpMethod.getHttpMethod(interceptedRequest.getMethod()),
interceptedRequest.getUrl());

// Populate the headers of the original request.
interceptedRequest
.getHeaders()
.forEach(
header ->
originalRequest.addHeader(header.getName(), header.getValue().getValue()));

HttpRequest modifiedRequest = requestHandler.get().apply(originalRequest);

continueRequestParameters.method(modifiedRequest.getMethod());

if (!uri.toString().equals(modifiedRequest.getUri())) {
continueRequestParameters.url(modifiedRequest.getUri());
}

List<Header> headerList = new ArrayList<>();
modifiedRequest.forEachHeader(
(name, value) ->
headerList.add(
new Header(name, new BytesValue(BytesValue.Type.STRING, value))));

if (!headerList.isEmpty()) {
continueRequestParameters.headers(headerList);
}

Contents.Supplier content = modifiedRequest.getContent();

if (content.length() > 0) {
continueRequestParameters.body(
new BytesValue(BytesValue.Type.STRING, Contents.utf8String(content)));
}
}
network.continueRequest(continueRequestParameters);
});
}

private Optional<UnaryOperator<HttpRequest>> getRequestHandler(URI uri) {
return requestHandlers.values().stream()
.filter(requestDetails -> requestDetails.getFilter().test(uri))
.map(RequestDetails::getHandler)
.findFirst();
}

@Override
public long addAuthenticationHandler(UsernameAndPassword usernameAndPassword) {
return addAuthenticationHandler(url -> true, usernameAndPassword);
Expand All @@ -97,6 +175,24 @@ public void clearAuthenticationHandlers() {
authHandlers.clear();
}

@Override
public long addRequestHandler(Predicate<URI> filter, UnaryOperator<HttpRequest> handler) {
long id = this.callBackId.incrementAndGet();

requestHandlers.put(id, new RequestDetails(filter, handler));
return id;
}

@Override
public void removeRequestHandler(long id) {
requestHandlers.remove(id);
}

@Override
public void clearRequestHandlers() {
requestHandlers.clear();
}

private class AuthDetails {
private final Predicate<URI> filter;
private final UsernameAndPassword usernameAndPassword;
Expand All @@ -114,4 +210,22 @@ public UsernameAndPassword getUsernameAndPassword() {
return usernameAndPassword;
}
}

private class RequestDetails {
private final Predicate<URI> filter;
private final UnaryOperator<HttpRequest> handler;

public RequestDetails(Predicate<URI> filter, UnaryOperator<HttpRequest> handler) {
this.filter = filter;
this.handler = handler;
}

public Predicate<URI> getFilter() {
return this.filter;
}

public UnaryOperator<HttpRequest> getHandler() {
return this.handler;
}
}
}
14 changes: 13 additions & 1 deletion java/src/org/openqa/selenium/remote/http/HttpMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,17 @@ public enum HttpMethod {
PATCH,
HEAD,
CONNECT,
TRACE,
TRACE;

public static HttpMethod getHttpMethod(String method) {
if (method == null) {
throw new IllegalArgumentException("Method cannot be null");
}

try {
return HttpMethod.valueOf(method.toUpperCase());
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException("No enum constant for method: " + method);
}
}
}
Loading

0 comments on commit 0e8059d

Please sign in to comment.