From 1243079ba501e1dba0a6ebb02289eb7e3531e9e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Sautter?= Date: Sun, 24 Nov 2024 17:57:53 +0100 Subject: [PATCH] [java] reuse the classes created by the WebDriverDecorator #14789 --- .../decorators/WebDriverDecorator.java | 175 +++++++++++++----- .../decorators/DecoratedWebDriverTest.java | 50 +++++ 2 files changed, 181 insertions(+), 44 deletions(-) diff --git a/java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java b/java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java index 0ced6710e2f99..f500774138277 100644 --- a/java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java +++ b/java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java @@ -20,14 +20,20 @@ import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Function; import java.util.stream.Collectors; import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.modifier.Visibility; import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.FieldAccessor; import net.bytebuddy.implementation.InvocationHandlerAdapter; import net.bytebuddy.matcher.ElementMatchers; import org.openqa.selenium.Alert; @@ -183,6 +189,65 @@ @Beta public class WebDriverDecorator { + protected static class Definition { + private final Class decoratedClass; + private final Class originalClass; + + public Definition(Decorated decorated) { + this.decoratedClass = decorated.getClass(); + this.originalClass = decorated.getOriginal().getClass(); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Definition definition = (Definition) o; + // intentionally an identity check, to ensure we get no false positive lookup due to an + // unknown implementation of decoratedClass.equals or originalClass.equals + return (decoratedClass == definition.decoratedClass) + && (originalClass == definition.originalClass); + } + + @Override + public int hashCode() { + return Arrays.hashCode( + new int[] { + System.identityHashCode(decoratedClass), System.identityHashCode(originalClass) + }); + } + } + + public interface HasTarget { + Decorated getTarget(); + + void setTarget(Decorated target); + } + + protected static class ProxyFactory { + private final Class clazz; + + private ProxyFactory(Class clazz) { + this.clazz = clazz; + } + + public T newInstance(Decorated target) { + T instance; + try { + instance = (T) clazz.newInstance(); + } catch (ReflectiveOperationException e) { + throw new AssertionError("Unable to create new proxy", e); + } + + // ensure we can later find the target to call + //noinspection unchecked + ((HasTarget) instance).setTarget(target); + + return instance; + } + } + + private final ConcurrentMap> cache; + private final Class targetWebDriverClass; private Decorated decorated; @@ -194,6 +259,7 @@ public WebDriverDecorator() { public WebDriverDecorator(Class targetClass) { this.targetWebDriverClass = targetClass; + this.cache = new ConcurrentHashMap<>(); } public final T decorate(T original) { @@ -295,18 +361,36 @@ private Object decorateResult(Object toDecorate) { return toDecorate; } - protected final Z createProxy(final Decorated decorated, Class clazz) { - Set> decoratedInterfaces = extractInterfaces(decorated); - Set> originalInterfaces = extractInterfaces(decorated.getOriginal()); - Map, InvocationHandler> derivedInterfaces = - deriveAdditionalInterfaces(decorated.getOriginal()); + protected final Z createProxy(final Decorated decorated, Class clazz) { + @SuppressWarnings("unchecked") + ProxyFactory factory = + (ProxyFactory) + cache.computeIfAbsent( + new Definition(decorated), (key) -> createProxyFactory(key, decorated, clazz)); + + return factory.newInstance(decorated); + } + + protected final ProxyFactory createProxyFactory( + Definition definition, final Decorated sample, Class clazz) { + Set> decoratedInterfaces = extractInterfaces(definition.decoratedClass); + Set> originalInterfaces = extractInterfaces(definition.originalClass); + // all samples with the same definition should have the same derivedInterfaces + Map, Function> derivedInterfaces = + deriveAdditionalInterfaces(sample.getOriginal()); final InvocationHandler handler = (proxy, method, args) -> { + // Lookup the instance to call, to reuse the clazz and handler. + @SuppressWarnings("unchecked") + Decorated instance = ((HasTarget) proxy).getTarget(); + if (instance == null) { + throw new AssertionError("Failed to get instance to call"); + } try { if (method.getDeclaringClass().equals(Object.class) || decoratedInterfaces.contains(method.getDeclaringClass())) { - return method.invoke(decorated, args); + return method.invoke(instance, args); } // Check if the class in which the method resides, implements any one of the // interfaces that we extracted from the decorated class. @@ -317,9 +401,9 @@ protected final Z createProxy(final Decorated decorated, Class clazz) eachInterface.isAssignableFrom(method.getDeclaringClass())); if (isCompatible) { - decorated.beforeCall(method, args); - Object result = decorated.call(method, args); - decorated.afterCall(method, result, args); + instance.beforeCall(method, args); + Object result = instance.call(method, args); + instance.afterCall(method, result, args); return result; } @@ -333,12 +417,15 @@ protected final Z createProxy(final Decorated decorated, Class clazz) eachInterface.isAssignableFrom(method.getDeclaringClass())); if (isCompatible) { - return derivedInterfaces.get(method.getDeclaringClass()).invoke(proxy, method, args); + return derivedInterfaces + .get(method.getDeclaringClass()) + .apply(instance.getOriginal()) + .invoke(proxy, method, args); } - return method.invoke(decorated.getOriginal(), args); + return method.invoke(instance.getOriginal(), args); } catch (InvocationTargetException e) { - return decorated.onError(method, e, args); + return instance.onError(method, e, args); } }; @@ -346,6 +433,8 @@ protected final Z createProxy(final Decorated decorated, Class clazz) allInterfaces.addAll(decoratedInterfaces); allInterfaces.addAll(originalInterfaces); allInterfaces.addAll(derivedInterfaces.keySet()); + // ensure a decorated driver can get decorated again + allInterfaces.remove(HasTarget.class); Class[] allInterfacesArray = allInterfaces.toArray(new Class[0]); Class proxy = @@ -354,20 +443,15 @@ protected final Z createProxy(final Decorated decorated, Class clazz) .implement(allInterfacesArray) .method(ElementMatchers.any()) .intercept(InvocationHandlerAdapter.of(handler)) + .defineField("target", Decorated.class, Visibility.PRIVATE) + .implement(HasTarget.class) + .intercept(FieldAccessor.ofField("target")) .make() .load(clazz.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) .getLoaded() .asSubclass(clazz); - try { - return proxy.newInstance(); - } catch (ReflectiveOperationException e) { - throw new IllegalStateException("Unable to create new proxy", e); - } - } - - static Set> extractInterfaces(final Object object) { - return extractInterfaces(object.getClass()); + return new ProxyFactory(proxy); } private static Set> extractInterfaces(final Class clazz) { @@ -393,43 +477,46 @@ private static void extractInterfaces(final Set> collector, final Class extractInterfaces(collector, clazz.getSuperclass()); } - private Map, InvocationHandler> deriveAdditionalInterfaces(Object object) { - Map, InvocationHandler> handlers = new HashMap<>(); + private Map, Function> deriveAdditionalInterfaces(Z sample) { + Map, Function> handlers = new HashMap<>(); - if (object instanceof WebDriver && !(object instanceof WrapsDriver)) { + if (sample instanceof WebDriver && !(sample instanceof WrapsDriver)) { handlers.put( WrapsDriver.class, - (proxy, method, args) -> { - if ("getWrappedDriver".equals(method.getName())) { - return object; - } - throw new UnsupportedOperationException(method.getName()); - }); + (instance) -> + (proxy, method, args) -> { + if ("getWrappedDriver".equals(method.getName())) { + return instance; + } + throw new UnsupportedOperationException(method.getName()); + }); } - if (object instanceof WebElement && !(object instanceof WrapsElement)) { + if (sample instanceof WebElement && !(sample instanceof WrapsElement)) { handlers.put( WrapsElement.class, - (proxy, method, args) -> { - if ("getWrappedElement".equals(method.getName())) { - return object; - } - throw new UnsupportedOperationException(method.getName()); - }); + (instance) -> + (proxy, method, args) -> { + if ("getWrappedElement".equals(method.getName())) { + return instance; + } + throw new UnsupportedOperationException(method.getName()); + }); } try { - Method toJson = object.getClass().getDeclaredMethod("toJson"); + Method toJson = sample.getClass().getDeclaredMethod("toJson"); toJson.setAccessible(true); handlers.put( JsonSerializer.class, - ((proxy, method, args) -> { - if ("toJson".equals(method.getName())) { - return toJson.invoke(object); - } - throw new UnsupportedOperationException(method.getName()); - })); + (instance) -> + ((proxy, method, args) -> { + if ("toJson".equals(method.getName())) { + return toJson.invoke(instance); + } + throw new UnsupportedOperationException(method.getName()); + })); } catch (NoSuchMethodException e) { // Fine. Just fall through } diff --git a/java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java b/java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java index ac61d70a166ef..82355b11ba2c6 100644 --- a/java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java +++ b/java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java @@ -31,6 +31,7 @@ import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.UUID; import java.util.function.Consumer; import java.util.function.Function; import org.junit.jupiter.api.Tag; @@ -163,6 +164,55 @@ void findElement() { verifyDecoratingFunction($ -> $.findElement(By.id("test")), found, WebElement::click); } + @Test + void doesNotCreateTooManyClasses() { + final WebElement found0 = mock(WebElement.class); + final WebElement found1 = mock(WebElement.class); + final WebElement found2 = mock(WebElement.class); + Function f = $ -> $.findElement(By.id("test")); + Function> f2 = $ -> $.findElements(By.id("test")); + Fixture fixture = new Fixture(); + when(f.apply(fixture.original)).thenReturn(found0); + when(f2.apply(fixture.original)).thenReturn(List.of(found0, found1, found2)); + + WebElement proxy0 = f.apply(fixture.decorated); + WebElement proxy1 = f.apply(fixture.decorated); + WebElement proxy2 = f.apply(fixture.decorated); + + assertThat(proxy0.getClass()).isSameAs(proxy1.getClass()); + assertThat(proxy1.getClass()).isSameAs(proxy2.getClass()); + + List proxies = f2.apply(fixture.decorated); + + assertThat(proxy0.getClass()).isSameAs(proxies.get(0).getClass()); + assertThat(proxy0.getClass()).isSameAs(proxies.get(1).getClass()); + assertThat(proxy0.getClass()).isSameAs(proxies.get(2).getClass()); + } + + @Test + void doesHitTheCorrectInstance() { + String uuid0 = UUID.randomUUID().toString(); + String uuid1 = UUID.randomUUID().toString(); + String uuid2 = UUID.randomUUID().toString(); + final WebElement found0 = mock(WebElement.class); + final WebElement found1 = mock(WebElement.class); + final WebElement found2 = mock(WebElement.class); + when(found0.getTagName()).thenReturn(uuid0); + when(found1.getTagName()).thenReturn(uuid1); + when(found2.getTagName()).thenReturn(uuid2); + + Fixture fixture = new Fixture(); + Function> f = $ -> $.findElements(By.id("test")); + + when(f.apply(fixture.original)).thenReturn(List.of(found0, found1, found2)); + + List proxies = f.apply(fixture.decorated); + + assertThat(proxies.get(0).getTagName()).isEqualTo(uuid0); + assertThat(proxies.get(1).getTagName()).isEqualTo(uuid1); + assertThat(proxies.get(2).getTagName()).isEqualTo(uuid2); + } + @Test void findElementNotFound() { Fixture fixture = new Fixture();