diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertArrayEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertArrayEqualsToAssertThat.java index d28542df6..ecdde2479 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertArrayEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertArrayEqualsToAssertThat.java @@ -73,7 +73,8 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu Expression expected = args.get(0); Expression actual = args.get(1); - maybeAddImport("org.assertj.core.api.Assertions", "assertThat"); + // Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME); if (args.size() == 2) { @@ -93,7 +94,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu .build() .apply(getCursor(), method.getCoordinates().replace(), actual, message, expected); } else if (args.size() == 3) { - maybeAddImport("org.assertj.core.api.Assertions", "within"); + maybeAddImport("org.assertj.core.api.Assertions", "within", false); // assert is using floating points with a delta and no message. return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()}, within(#{any()}));") .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") @@ -104,7 +105,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu // The assertEquals is using a floating point with a delta argument and a message. Expression message = args.get(3); - maybeAddImport("org.assertj.core.api.Assertions", "within"); + maybeAddImport("org.assertj.core.api.Assertions", "within", false); JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()}, within(#{any()}));") : diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java index de96ac7fd..276c3870d 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java @@ -49,6 +49,16 @@ public TreeVisitor getVisitor() { } public static class AssertEqualsToAssertThatVisitor extends JavaIsoVisitor { + private JavaParser.Builder assertionsParser; + + private JavaParser.Builder assertionsParser(ExecutionContext ctx) { + if (assertionsParser == null) { + assertionsParser = JavaParser.fromJavaVersion() + .classpathFromResources(ctx, "assertj-core-3.24"); + } + return assertionsParser; + } + private static final MethodMatcher JUNIT_ASSERT_EQUALS = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertEquals(..)"); @Override @@ -63,13 +73,14 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu //always add the import (even if not referenced) maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + + // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. maybeRemoveImport("org.junit.jupiter.api.Assertions"); if (args.size() == 2) { return JavaTemplate.builder("assertThat(#{any()}).isEqualTo(#{any()});") - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .javaParser(assertionsParser(ctx)) .build() .apply(getCursor(), method.getCoordinates().replace(), actual, expected); } else if (args.size() == 3 && !isFloatingPointType(args.get(2))) { @@ -78,10 +89,9 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isEqualTo(#{any()});") : JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isEqualTo(#{any()});"); return template - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .imports("java.util.function.Supplier") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .javaParser(assertionsParser(ctx)) .build() .apply( getCursor(), @@ -91,11 +101,11 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu expected ); } else if (args.size() == 3) { - maybeAddImport("org.assertj.core.api.Assertions", "within"); + //always add the import (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "within", false); return JavaTemplate.builder("assertThat(#{any()}).isCloseTo(#{any()}, within(#{any()}));") - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .javaParser(assertionsParser(ctx)) .build() .apply(getCursor(), method.getCoordinates().replace(), actual, expected, args.get(2)); @@ -104,15 +114,15 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu // The assertEquals is using a floating point with a delta argument and a message. Expression message = args.get(3); - maybeAddImport("org.assertj.core.api.Assertions", "within"); + //always add the import (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "within", false); JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ? JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isCloseTo(#{any()}, within(#{any()}));") : JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isCloseTo(#{any()}, within(#{any()}));"); return template - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") .imports("java.util.function.Supplier") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) + .javaParser(assertionsParser(ctx)) .build() .apply( getCursor(), diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertFalseToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertFalseToAssertThat.java index 4f879c413..149e5b3df 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertFalseToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertFalseToAssertThat.java @@ -97,7 +97,10 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu ); } - maybeAddImport("org.assertj.core.api.Assertions", "assertThat"); + //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + + // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. maybeRemoveImport("org.junit.jupiter.api.Assertions"); return method; diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java index 20f83f869..d78506810 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java @@ -74,7 +74,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 2) { method = JavaTemplate.builder("assertThat(#{any()}).isNotEqualTo(#{any()});") - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -93,7 +92,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu method = template - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -106,7 +104,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu ); } else if (args.size() == 3) { method = JavaTemplate.builder("assertThat(#{any()}).isNotCloseTo(#{any()}, within(#{any()}));") - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") .javaParser(assertionsParser(ctx)) .build() @@ -117,7 +114,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu expected, args.get(2) ); - maybeAddImport("org.assertj.core.api.Assertions", "within"); + maybeAddImport("org.assertj.core.api.Assertions", "within", false); } else { Expression message = args.get(3); @@ -126,7 +123,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotCloseTo(#{any()}, within(#{any()}));"); method = template - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") .javaParser(assertionsParser(ctx)) .build() @@ -139,12 +135,13 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu args.get(2) ); - maybeAddImport("org.assertj.core.api.Assertions", "within"); + maybeAddImport("org.assertj.core.api.Assertions", "within", false); } - //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" - maybeAddImport("org.assertj.core.api.Assertions", "assertThat"); - //And if there are no longer references to the JUnit assertions class, we can remove the import. + //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + + // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. maybeRemoveImport("org.junit.jupiter.api.Assertions"); return method; diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java index 2fde4a8b4..dea4c6ffe 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java @@ -71,7 +71,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 1) { method = JavaTemplate.builder("assertThat(#{any()}).isNotNull();") - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -89,7 +88,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotNull();"); method = template - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -101,8 +99,11 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu ); } + //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + + //And if there are no longer references to the JUnit assertions class, we can remove the import. maybeRemoveImport("org.junit.jupiter.api.Assertions"); - maybeAddImport("org.assertj.core.api.Assertions", "assertThat"); return method; } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java index e94af26e1..c2c916f20 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java @@ -71,7 +71,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 1) { method = JavaTemplate.builder("assertThat(#{any()}).isNull();") - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -88,7 +87,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNull();"); method = template - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -100,12 +98,12 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu ); } + // Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. maybeRemoveImport("org.junit.jupiter.api.Assertions"); - // Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat". - maybeAddImport("org.assertj.core.api.Assertions", "assertThat"); - return method; } } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java index 67b4e4d83..1241d70e0 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java @@ -72,7 +72,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 2) { method = JavaTemplate.builder("assertThat(#{any()}).isSameAs(#{any()});") - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -90,7 +89,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isSameAs(#{any()});"); method = template - .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -103,8 +101,11 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu ); } + // Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + + // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. maybeRemoveImport("org.junit.jupiter.api.Assertions"); - maybeAddImport("org.assertj.core.api.Assertions", "assertThat"); return method; } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java index 4cf686c56..07f55f5b1 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertThrowsToAssertExceptionType.java @@ -85,7 +85,7 @@ && getCursor().getParentTreeCursor().getValue() instanceof J.Block) { mi.getCoordinates().replace(), mi.getArguments().get(0), executable ); - maybeAddImport("org.assertj.core.api.AssertionsForClassTypes", "assertThatExceptionOfType"); + maybeAddImport("org.assertj.core.api.AssertionsForClassTypes", "assertThatExceptionOfType", false); maybeRemoveImport("org.junit.jupiter.api.Assertions.assertThrows"); maybeRemoveImport("org.junit.jupiter.api.Assertions"); } diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertTrueToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertTrueToAssertThat.java index f41b4c736..b478716dc 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertTrueToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertTrueToAssertThat.java @@ -97,7 +97,10 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu ); } - maybeAddImport("org.assertj.core.api.Assertions", "assertThat"); + //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); + + // Remove import for "org.junit.jupiter.api.Assertions" if no longer used. maybeRemoveImport("org.junit.jupiter.api.Assertions"); return method; diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitFailToAssertJFail.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitFailToAssertJFail.java index c0407af8f..d2f1e7118 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitFailToAssertJFail.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitFailToAssertJFail.java @@ -147,7 +147,8 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu method.getCoordinates().replace(), arguments.toArray() ); - maybeAddImport("org.assertj.core.api.Assertions", "fail"); + //Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "fail", false); return super.visitMethodInvocation(method, ctx); } } diff --git a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThatTest.java b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThatTest.java index 8a4bcb5ec..14dc3b999 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThatTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThatTest.java @@ -246,8 +246,9 @@ private String notification() { @Test @Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/491") void importAddedForCustomArguments() { + //language=java rewriteRun( - //language=java + spec -> spec.typeValidationOptions(TypeValidation.none()), java( """ import org.junit.jupiter.api.Test;