Skip to content

Commit

Permalink
Support intersection type casts (#3652)
Browse files Browse the repository at this point in the history
* Support intersection type casts

Fixes: #3651

* Add missing `ReloadableJava21JavadocVisitor`

* Complete type attribution for `var` variables

* Completed Java 8, 11, and 21

* Let `J.IntersectionType` implement `Expression`

* Extend `JavaTypeVisitor` for `Intersection`

* Implement `TypeUtils#isAssignableTo()` for `Intersection`
  • Loading branch information
knutwannheden authored Nov 21, 2023
1 parent 9965bbf commit 552a4a8
Show file tree
Hide file tree
Showing 23 changed files with 415 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,14 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
typeMapping.type(node));
}

@Override
public J visitIntersectionType(IntersectionTypeTree node, Space fmt) {
JContainer<TypeTree> bounds = node.getBounds().isEmpty() ? null :
JContainer.build(EMPTY,
convertAll(node.getBounds(), t -> sourceBefore("&"), noDelim), Markers.EMPTY);
return new J.IntersectionType(randomId(), fmt, Markers.EMPTY, bounds);
}

@Override
public J visitLabeledStatement(LabeledStatementTree node, Space fmt) {
skip(node.getLabel().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.openrewrite.java.tree.TypeUtils;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
Expand Down Expand Up @@ -59,7 +60,9 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
return existing;
}

if (type instanceof Type.ClassType) {
if (type instanceof Type.IntersectionClassType) {
return intersectionType((Type.IntersectionClassType) type, signature);
} else if (type instanceof Type.ClassType) {
return classType((Type.ClassType) type, signature);
} else if (type instanceof Type.TypeVar) {
return generic((Type.TypeVar) type, signature);
Expand All @@ -78,6 +81,19 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
throw new UnsupportedOperationException("Unknown type " + type.getClass().getName());
}

private JavaType intersectionType(Type.IntersectionClassType type, String signature) {
JavaType.Intersection intersection = new JavaType.Intersection(null);
typeCache.put(signature, intersection);
JavaType[] types = new JavaType[type.getBounds().size()];
List<? extends TypeMirror> bounds = type.getBounds();
for (int i = 0; i < bounds.size(); i++) {
TypeMirror bound = bounds.get(i);
types[i] = type((Type) bound);
}
intersection.unsafeSet(types);
return intersection;
}

private JavaType array(Type type, String signature) {
JavaType.Array arr = new JavaType.Array(null, null);
typeCache.put(signature, arr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.openrewrite.java.tree.JavaType;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -39,6 +40,13 @@ public String signature(@Nullable Object t) {
private String signature(@Nullable Type type) {
if (type == null || type instanceof Type.UnknownType || type instanceof NullType) {
return "{undefined}";
} else if (type instanceof Type.IntersectionClassType) {
Type.IntersectionClassType intersectionClassType = (Type.IntersectionClassType) type;
StringJoiner joiner = new StringJoiner(" & ");
for (TypeMirror typeArg : intersectionClassType.getBounds()) {
joiner.add(signature(typeArg));
}
return joiner.toString();
} else if (type instanceof Type.ClassType) {
try {
return ((Type.ClassType) type).typarams_field != null && ((Type.ClassType) type).typarams_field.length() > 0 ? parameterizedSignature(type) : classSignature(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,14 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
type);
}

@Override
public J visitIntersectionType(IntersectionTypeTree node, Space fmt) {
JContainer<TypeTree> bounds = node.getBounds().isEmpty() ? null :
JContainer.build(EMPTY,
convertAll(node.getBounds(), t -> sourceBefore("&"), noDelim), Markers.EMPTY);
return new J.IntersectionType(randomId(), fmt, Markers.EMPTY, bounds);
}

@Override
public J visitLabeledStatement(LabeledStatementTree node, Space fmt) {
skip(node.getLabel().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.openrewrite.java.tree.TypeUtils;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
Expand Down Expand Up @@ -56,7 +57,9 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
return existing;
}

if (type instanceof Type.ClassType) {
if (type instanceof Type.IntersectionClassType) {
return intersectionType((Type.IntersectionClassType) type, signature);
} else if (type instanceof Type.ClassType) {
return classType((Type.ClassType) type, signature);
} else if (type instanceof Type.TypeVar) {
return generic((Type.TypeVar) type, signature);
Expand All @@ -75,6 +78,19 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
throw new UnsupportedOperationException("Unknown type " + type.getClass().getName());
}

private JavaType intersectionType(Type.IntersectionClassType type, String signature) {
JavaType.Intersection intersection = new JavaType.Intersection(null);
typeCache.put(signature, intersection);
JavaType[] types = new JavaType[type.getBounds().size()];
List<? extends TypeMirror> bounds = type.getBounds();
for (int i = 0; i < bounds.size(); i++) {
TypeMirror bound = bounds.get(i);
types[i] = type((Type) bound);
}
intersection.unsafeSet(types);
return intersection;
}

private JavaType array(Type type, String signature) {
JavaType.Array arr = new JavaType.Array(null, null);
typeCache.put(signature, arr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.openrewrite.java.tree.JavaType;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -39,6 +40,13 @@ public String signature(@Nullable Object t) {
private String signature(@Nullable Type type) {
if (type == null || type instanceof Type.UnknownType || type instanceof NullType) {
return "{undefined}";
} else if (type instanceof Type.IntersectionClassType) {
Type.IntersectionClassType intersectionClassType = (Type.IntersectionClassType) type;
StringJoiner joiner = new StringJoiner(" & ");
for (TypeMirror typeArg : intersectionClassType.getBounds()) {
joiner.add(signature(typeArg));
}
return joiner.toString();
} else if (type instanceof Type.ClassType) {
try {
return ((Type.ClassType) type).typarams_field != null && ((Type.ClassType) type).typarams_field.length() > 0 ? parameterizedSignature(type) : classSignature(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,14 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
type);
}

@Override
public J visitIntersectionType(IntersectionTypeTree node, Space fmt) {
JContainer<TypeTree> bounds = node.getBounds().isEmpty() ? null :
JContainer.build(EMPTY,
convertAll(node.getBounds(), t -> sourceBefore("&"), noDelim), Markers.EMPTY);
return new J.IntersectionType(randomId(), fmt, Markers.EMPTY, bounds);
}

@Override
public J visitLabeledStatement(LabeledStatementTree node, Space fmt) {
skip(node.getLabel().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.openrewrite.java.tree.TypeUtils;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
Expand Down Expand Up @@ -56,7 +57,9 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
return existing;
}

if (type instanceof Type.ClassType) {
if (type instanceof Type.IntersectionClassType) {
return intersectionType((Type.IntersectionClassType) type, signature);
} else if (type instanceof Type.ClassType) {
return classType((Type.ClassType) type, signature);
} else if (type instanceof Type.TypeVar) {
return generic((Type.TypeVar) type, signature);
Expand All @@ -75,6 +78,19 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
throw new UnsupportedOperationException("Unknown type " + type.getClass().getName());
}

private JavaType intersectionType(Type.IntersectionClassType type, String signature) {
JavaType.Intersection intersection = new JavaType.Intersection(null);
typeCache.put(signature, intersection);
JavaType[] types = new JavaType[type.getBounds().size()];
List<? extends TypeMirror> bounds = type.getBounds();
for (int i = 0; i < bounds.size(); i++) {
TypeMirror bound = bounds.get(i);
types[i] = type((Type) bound);
}
intersection.unsafeSet(types);
return intersection;
}

private JavaType array(Type type, String signature) {
JavaType.Array arr = new JavaType.Array(null, null);
typeCache.put(signature, arr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.openrewrite.java.tree.JavaType;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -39,6 +40,13 @@ public String signature(@Nullable Object t) {
private String signature(@Nullable Type type) {
if (type == null || type instanceof Type.UnknownType || type instanceof NullType) {
return "{undefined}";
} else if (type instanceof Type.IntersectionClassType) {
Type.IntersectionClassType intersectionClassType = (Type.IntersectionClassType) type;
StringJoiner joiner = new StringJoiner(" & ");
for (TypeMirror typeArg : intersectionClassType.getBounds()) {
joiner.add(signature(typeArg));
}
return joiner.toString();
} else if (type instanceof Type.ClassType) {
try {
return ((Type.ClassType) type).typarams_field != null && ((Type.ClassType) type).typarams_field.length() > 0 ? parameterizedSignature(type) : classSignature(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,14 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
typeMapping.type(node));
}

@Override
public J visitIntersectionType(IntersectionTypeTree node, Space fmt) {
JContainer<TypeTree> bounds = node.getBounds().isEmpty() ? null :
JContainer.build(EMPTY,
convertAll(node.getBounds(), t -> sourceBefore("&"), noDelim), Markers.EMPTY);
return new J.IntersectionType(randomId(), fmt, Markers.EMPTY, bounds);
}

@Override
public J visitLabeledStatement(LabeledStatementTree node, Space fmt) {
skip(node.getLabel().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.openrewrite.java.tree.TypeUtils;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
Expand Down Expand Up @@ -58,7 +59,9 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
return existing;
}

if (type instanceof Type.ClassType) {
if (type instanceof Type.IntersectionClassType) {
return intersectionType((Type.IntersectionClassType) type, signature);
} else if (type instanceof Type.ClassType) {
return classType((Type.ClassType) type, signature);
} else if (type instanceof Type.TypeVar) {
return generic((Type.TypeVar) type, signature);
Expand All @@ -79,6 +82,19 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
throw new UnsupportedOperationException("Unknown type " + type.getClass().getName());
}

private JavaType intersectionType(Type.IntersectionClassType type, String signature) {
JavaType.Intersection intersection = new JavaType.Intersection(null);
typeCache.put(signature, intersection);
JavaType[] types = new JavaType[type.getBounds().size()];
List<? extends TypeMirror> bounds = type.getBounds();
for (int i = 0; i < bounds.size(); i++) {
TypeMirror bound = bounds.get(i);
types[i] = type((Type) bound);
}
intersection.unsafeSet(types);
return intersection;
}

private JavaType array(Type type, String signature) {
JavaType.Array arr = new JavaType.Array(null, null);
typeCache.put(signature, arr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.openrewrite.java.tree.JavaType;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -38,6 +39,13 @@ public String signature(@Nullable Object t) {
private String signature(@Nullable Type type) {
if (type == null || type instanceof Type.UnknownType || type instanceof NullType) {
return "{undefined}";
} else if (type instanceof Type.IntersectionClassType) {
Type.IntersectionClassType intersectionClassType = (Type.IntersectionClassType) type;
StringJoiner joiner = new StringJoiner(" & ");
for (TypeMirror typeArg : intersectionClassType.getBounds()) {
joiner.add(signature(typeArg));
}
return joiner.toString();
} else if (type instanceof Type.ClassType) {
try {
return ((Type.ClassType) type).typarams_field != null && ((Type.ClassType) type).typarams_field.length() > 0 ? parameterizedSignature(type) : classSignature(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
package org.openrewrite.java.tree;

import org.junit.jupiter.api.Test;
import org.openrewrite.java.MinimumJava11;
import org.openrewrite.test.RewriteTest;

import static org.assertj.core.api.Assertions.assertThat;
import static org.openrewrite.java.Assertions.java;

class TypeCastTest implements RewriteTest {
Expand All @@ -35,4 +37,59 @@ class Test {
)
);
}

@Test
void intersectionCast() {
rewriteRun(
java(
"""
import java.io.Serializable;
import java.util.function.BiFunction;
class Test {
Serializable s = (Serializable & BiFunction<Integer, Integer, Integer>) Integer::sum;
}
"""
)
);
}

@MinimumJava11
@Test
void intersectionCastAssignedToVar() {
rewriteRun(
java(
"""
import java.io.Serializable;
import java.util.function.BiFunction;
class Test {
void m() {
var s = (Serializable & BiFunction<Integer, Integer, Integer>) Integer::sum;
}
}
""",
spec -> spec.afterRecipe(cu -> {
J.MethodDeclaration m = (J.MethodDeclaration) cu.getClasses().get(0).getBody().getStatements().get(0);
J.VariableDeclarations s = (J.VariableDeclarations) m.getBody().getStatements().get(0);
assertThat(s.getType()).isInstanceOf(JavaType.Intersection.class);
JavaType.Intersection intersection = (JavaType.Intersection) s.getType();
assertThat(intersection.getBounds()).satisfiesExactly(
b1 -> assertThat(b1).satisfies(
t -> assertThat(t).isInstanceOf(JavaType.Class.class),
t -> assertThat(((JavaType.Class) t).getFullyQualifiedName()).isEqualTo("java.io.Serializable")
),
b2 -> assertThat(b2).satisfies(
t -> assertThat(t).isInstanceOf(JavaType.Parameterized.class),
t -> assertThat(((JavaType.Parameterized) t).getFullyQualifiedName()).isEqualTo("java.util.function.BiFunction"),
t -> assertThat(((JavaType.Parameterized) t).getTypeParameters()).hasSize(3),
t -> assertThat(((JavaType.Parameterized) t).getTypeParameters()).allSatisfy(
p -> assertThat(((JavaType.Class) p).getFullyQualifiedName()).isEqualTo("java.lang.Integer")
)
)
);
})
)
);
}
}
Loading

0 comments on commit 552a4a8

Please sign in to comment.