Skip to content

Commit

Permalink
JMockit to Mockito Recipe - Handle Typed Class Argument Matching and …
Browse files Browse the repository at this point in the history
…Collections (#420)

* JMockit to Mockito Recipe - handle typed any argument matcher

* Handle Mockito collection argument matchers

* polish

* polish

* polish

* format

* polish

* polish

* catch exceptions and return original LST

* polish

* error handling
  • Loading branch information
tinder-dthomson authored Nov 3, 2023
1 parent 71b6e6d commit 5c2367e
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package org.openrewrite.java.testing.jmockit;

import java.util.*;
import java.util.regex.Pattern;

import lombok.EqualsAndHashCode;
import lombok.Value;
Expand Down Expand Up @@ -69,6 +68,14 @@ private static class RewriteExpectationsVisitor extends JavaIsoVisitor<Execution
JMOCKIT_ARGUMENT_MATCHERS.add("anyShort");
JMOCKIT_ARGUMENT_MATCHERS.add("any");
}
private static final Map<String, String> MOCKITO_COLLECTION_MATCHERS = new HashMap<>();
static {
MOCKITO_COLLECTION_MATCHERS.put("java.util.List", "anyList");
MOCKITO_COLLECTION_MATCHERS.put("java.util.Set", "anySet");
MOCKITO_COLLECTION_MATCHERS.put("java.util.Collection", "anyCollection");
MOCKITO_COLLECTION_MATCHERS.put("java.util.Iterable", "anyIterable");
MOCKITO_COLLECTION_MATCHERS.put("java.util.Map", "anyMap");
}

@Override
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDeclaration, ExecutionContext ctx) {
Expand All @@ -81,77 +88,86 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl
J.Block newBody = md.getBody();
List<Statement> statements = md.getBody().getStatements();

// iterate over each statement in the method body, find Expectations blocks and rewrite them
for (int bodyStatementIndex = 0; bodyStatementIndex < statements.size(); bodyStatementIndex++) {
Statement s = statements.get(bodyStatementIndex);
if (!(s instanceof J.NewClass)) {
continue;
}
J.NewClass nc = (J.NewClass) s;
if (!(nc.getClazz() instanceof J.Identifier)) {
continue;
}
J.Identifier clazz = (J.Identifier) nc.getClazz();
if (!TypeUtils.isAssignableTo("mockit.Expectations", clazz.getType())) {
continue;
}
// empty Expectations block is considered invalid
assert nc.getBody() != null && !nc.getBody().getStatements().isEmpty() : "Expectations block is empty";
// Expectations block should be composed of a block within another block
assert nc.getBody().getStatements().size() == 1 : "Expectations block is malformed";

// we have a valid Expectations block, update imports and rewrite with Mockito statements
maybeRemoveImport("mockit.Expectations");

// the first coordinates are the coordinates of the Expectations block, replacing it
JavaCoordinates coordinates = nc.getCoordinates().replace();
J.Block expectationsBlock = (J.Block) nc.getBody().getStatements().get(0);
List<Object> templateParams = new ArrayList<>();

// iterate over the expectations statements and rebuild the method body
int mockitoStatementIndex = 0;
for (Statement expectationStatement : expectationsBlock.getStatements()) {
// TODO: handle additional jmockit expectations features

if (expectationStatement instanceof J.MethodInvocation) {
if (!templateParams.isEmpty()) {
// apply template to build new method body
newBody = applyTemplate(ctx, templateParams, cursorLocation, coordinates);

// next statement coordinates are immediately after the statement just added
int newStatementIndex = bodyStatementIndex + mockitoStatementIndex;
coordinates = newBody.getStatements().get(newStatementIndex).getCoordinates().after();

// cursor location is now the new body
cursorLocation = newBody;

// reset template params for next expectation
templateParams = new ArrayList<>();
mockitoStatementIndex += 1;
try {
// iterate over each statement in the method body, find Expectations blocks and rewrite them
for (int bodyStatementIndex = 0; bodyStatementIndex < statements.size(); bodyStatementIndex++) {
Statement s = statements.get(bodyStatementIndex);
if (!(s instanceof J.NewClass)) {
continue;
}
J.NewClass nc = (J.NewClass) s;
if (!(nc.getClazz() instanceof J.Identifier)) {
continue;
}
J.Identifier clazz = (J.Identifier) nc.getClazz();
if (!TypeUtils.isAssignableTo("mockit.Expectations", clazz.getType())) {
continue;
}
// empty Expectations block is considered invalid
assert nc.getBody() != null && !nc.getBody().getStatements().isEmpty() : "Expectations block is empty";
// Expectations block should be composed of a block within another block
assert nc.getBody().getStatements().size() == 1 : "Expectations block is malformed";

// we have a valid Expectations block, update imports and rewrite with Mockito statements
maybeRemoveImport("mockit.Expectations");

// the first coordinates are the coordinates of the Expectations block, replacing it
JavaCoordinates coordinates = nc.getCoordinates().replace();
J.Block expectationsBlock = (J.Block) nc.getBody().getStatements().get(0);
List<Object> templateParams = new ArrayList<>();

// iterate over the expectations statements and rebuild the method body
int mockitoStatementIndex = 0;
for (Statement expectationStatement : expectationsBlock.getStatements()) {
// TODO: handle additional jmockit expectations features

if (expectationStatement instanceof J.MethodInvocation) {
if (!templateParams.isEmpty()) {
// apply template to build new method body
newBody = rewriteMethodBody(ctx, templateParams, cursorLocation, coordinates);

// next statement coordinates are immediately after the statement just added
int newStatementIndex = bodyStatementIndex + mockitoStatementIndex;
coordinates = newBody.getStatements().get(newStatementIndex).getCoordinates().after();

// cursor location is now the new body
cursorLocation = newBody;

// reset template params for next expectation
templateParams = new ArrayList<>();
mockitoStatementIndex += 1;
}
templateParams.add(expectationStatement);
} else {
// assignment
templateParams.add(((J.Assignment) expectationStatement).getAssignment());
}
templateParams.add(expectationStatement);
} else {
// assignment
templateParams.add(((J.Assignment) expectationStatement).getAssignment());
}
}

// handle the last statement
if (!templateParams.isEmpty()) {
newBody = applyTemplate(ctx, templateParams, cursorLocation, coordinates);
// handle the last statement
if (!templateParams.isEmpty()) {
newBody = rewriteMethodBody(ctx, templateParams, cursorLocation, coordinates);
}
}
} catch (Exception e) {
// if anything goes wrong, just return the original method declaration
return md;
}

return md.withBody(newBody);
}

private J.Block applyTemplate(ExecutionContext ctx, List<Object> templateParams, Object cursorLocation,
JavaCoordinates coordinates) {
private J.Block rewriteMethodBody(ExecutionContext ctx, List<Object> templateParams, Object cursorLocation,
JavaCoordinates coordinates) {
Expression result = null;
String methodName = "doNothing";
if (templateParams.size() > 1) {
String methodName;
if (templateParams.size() == 1) {
methodName = "doNothing";
} else if (templateParams.size() == 2) {
methodName = "when";
result = (Expression) templateParams.get(1);
} else {
throw new IllegalStateException("Unexpected number of template params: " + templateParams.size());
}
maybeAddImport("org.mockito.Mockito", methodName);
rewriteArgumentMatchers(ctx, templateParams);
Expand All @@ -166,29 +182,78 @@ private J.Block applyTemplate(ExecutionContext ctx, List<Object> templateParams,
);
}

private void rewriteArgumentMatchers(ExecutionContext ctx, List<Object> templateParams) {
J.MethodInvocation invocation = (J.MethodInvocation) templateParams.get(0);
private void rewriteArgumentMatchers(ExecutionContext ctx, List<Object> bodyTemplateParams) {
J.MethodInvocation invocation = (J.MethodInvocation) bodyTemplateParams.get(0);
List<Expression> newArguments = new ArrayList<>(invocation.getArguments().size());
for (Expression methodArgument : invocation.getArguments()) {
if (!isArgumentMatcher(methodArgument)) {
newArguments.add(methodArgument);
continue;
}
String argumentMatcher = ((J.Identifier) methodArgument).getSimpleName();
maybeAddImport("org.mockito.Mockito", argumentMatcher);
newArguments.add(JavaTemplate.builder(argumentMatcher + "()")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-3.12"))
.staticImports("org.mockito.Mockito." + argumentMatcher)
.build()
.apply(
new Cursor(getCursor(), methodArgument),
methodArgument.getCoordinates().replace()
));
String argumentMatcher, template;
List<Object> argumentTemplateParams = new ArrayList<>();
if (!(methodArgument instanceof J.TypeCast)) {
argumentMatcher = ((J.Identifier) methodArgument).getSimpleName();
template = argumentMatcher + "()";
newArguments.add(rewriteMethodArgument(ctx, argumentMatcher, template, methodArgument,
argumentTemplateParams));
continue;
}
J.TypeCast tc = (J.TypeCast) methodArgument;
argumentMatcher = ((J.Identifier) tc.getExpression()).getSimpleName();
String className, fqn;
JavaType typeCastType = tc.getType();
if (typeCastType instanceof JavaType.Parameterized) {
// strip the raw type from the parameterized type
className = ((JavaType.Parameterized) typeCastType).getType().getClassName();
fqn = ((JavaType.Parameterized) typeCastType).getType().getFullyQualifiedName();
} else if (typeCastType instanceof JavaType.FullyQualified) {
className = ((JavaType.FullyQualified) typeCastType).getClassName();
fqn = ((JavaType.FullyQualified) typeCastType).getFullyQualifiedName();
} else {
throw new IllegalStateException("Unexpected J.TypeCast type: " + typeCastType);
}
if (MOCKITO_COLLECTION_MATCHERS.containsKey(fqn)) {
// mockito has specific argument matchers for collections
argumentMatcher = MOCKITO_COLLECTION_MATCHERS.get(fqn);
template = argumentMatcher + "()";
} else {
// rewrite parameter from ((<type>) any) to <type>.class
argumentTemplateParams.add(JavaTemplate.builder("#{}.class")
.javaParser(JavaParser.fromJavaVersion())
.imports(fqn)
.build()
.apply(
new Cursor(getCursor(), tc),
tc.getCoordinates().replace(),
className
));
template = argumentMatcher + "(#{any(java.lang.Class)})";
}
newArguments.add(rewriteMethodArgument(ctx, argumentMatcher, template, methodArgument,
argumentTemplateParams));
}
templateParams.set(0, invocation.withArguments(newArguments));
bodyTemplateParams.set(0, invocation.withArguments(newArguments));
}

private Expression rewriteMethodArgument(ExecutionContext ctx, String argumentMatcher, String template,
Expression methodArgument, List<Object> templateParams) {
maybeAddImport("org.mockito.Mockito", argumentMatcher);
return JavaTemplate.builder(template)
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-3.12"))
.staticImports("org.mockito.Mockito." + argumentMatcher)
.build()
.apply(
new Cursor(getCursor(), methodArgument),
methodArgument.getCoordinates().replace(),
templateParams.toArray()
);
}

private static boolean isArgumentMatcher(Expression expression) {
if (expression instanceof J.TypeCast) {
expression = ((J.TypeCast) expression).getExpression();
}
if (!(expression instanceof J.Identifier)) {
return false;
}
Expand All @@ -209,7 +274,7 @@ private static String getMockitoStatementTemplate(Expression result) {
? THROWABLE_RESULT_TEMPLATE
: OBJECT_RESULT_TEMPLATE;
} else {
throw new IllegalStateException("Unexpected value: " + result.getType());
throw new IllegalStateException("Unexpected expression type for template: " + result.getType());
}
return template;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,74 @@ void test() throws RuntimeException {
);
}

@Test
void jMockitExpectationsToMockitoWhenClassArgumentMatcher() {
//language=java
rewriteRun(
java(
"""
import java.util.List;
class MyObject {
public String getSomeField(List<String> input) {
return "X";
}
}
"""
),
java(
"""
import java.util.ArrayList;
import java.util.List;
import mockit.Expectations;
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
void test() {
new Expectations() {{
myObject.getSomeField((List<String>) any);
result = null;
}};
assertNotNull(myObject.getSomeField(new ArrayList<>()));
}
}
""",
"""
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;
void test() {
when(myObject.getSomeField(anyList())).thenReturn(null);
assertNotNull(myObject.getSomeField(new ArrayList<>()));
}
}
"""
)
);
}

@Test
void jMockitExpectationsToMockitoWhenMultipleStatements() {
//language=java
Expand All @@ -480,18 +548,18 @@ public void doSomething() {}
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;
@Mocked
MyObject myOtherObject;
void test() {
new Expectations() {{
myObject.getSomeIntField();
Expand Down

0 comments on commit 5c2367e

Please sign in to comment.