Skip to content

Commit

Permalink
[javac] AST conversion for guarded record patterns
Browse files Browse the repository at this point in the history
eg.

```java
public class A {
	static void myMethod() {
		MyRecord myRecordInstance = new MyRecord(new ChildRecord(1));

		switch (myRecordInstance) {
			case MyRecord(ChildRecord(int value)) when value /* <-- go to definition here */ == 1:
				System.out.println("asdf");
				break;
			default:
				break;
		}
	}

	static record ChildRecord(int a) {
	}

	static record MyRecord(ChildRecord b) {
	}
}
```

Signed-off-by: David Thompson <[email protected]>
  • Loading branch information
datho7561 authored and mickaelistria committed May 10, 2024
1 parent 2bb0a2c commit 47222c9
Showing 1 changed file with 63 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@
import com.sun.tools.javac.tree.JCTree.JCBlock;
import com.sun.tools.javac.tree.JCTree.JCBreak;
import com.sun.tools.javac.tree.JCTree.JCCase;
import com.sun.tools.javac.tree.JCTree.JCCaseLabel;
import com.sun.tools.javac.tree.JCTree.JCCatch;
import com.sun.tools.javac.tree.JCTree.JCClassDecl;
import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
import com.sun.tools.javac.tree.JCTree.JCConditional;
import com.sun.tools.javac.tree.JCTree.JCConstantCaseLabel;
import com.sun.tools.javac.tree.JCTree.JCContinue;
import com.sun.tools.javac.tree.JCTree.JCDirective;
import com.sun.tools.javac.tree.JCTree.JCDoWhileLoop;
Expand Down Expand Up @@ -91,8 +93,10 @@
import com.sun.tools.javac.tree.JCTree.JCPackageDecl;
import com.sun.tools.javac.tree.JCTree.JCParens;
import com.sun.tools.javac.tree.JCTree.JCPattern;
import com.sun.tools.javac.tree.JCTree.JCPatternCaseLabel;
import com.sun.tools.javac.tree.JCTree.JCPrimitiveTypeTree;
import com.sun.tools.javac.tree.JCTree.JCProvides;
import com.sun.tools.javac.tree.JCTree.JCRecordPattern;
import com.sun.tools.javac.tree.JCTree.JCRequires;
import com.sun.tools.javac.tree.JCTree.JCReturn;
import com.sun.tools.javac.tree.JCTree.JCSkip;
Expand Down Expand Up @@ -1263,14 +1267,7 @@ private Expression convertExpression(JCExpression javac) {
PatternInstanceofExpression res = this.ast.newPatternInstanceofExpression();
commonSettings(res, javac);
res.setLeftOperand(convertExpression(jcInstanceOf.getExpression()));
if (jcPattern instanceof JCBindingPattern jcBindingPattern) {
TypePattern jdtPattern = this.ast.newTypePattern();
commonSettings(jdtPattern, jcBindingPattern);
jdtPattern.setPatternVariable((SingleVariableDeclaration)convertVariableDeclaration(jcBindingPattern.var));
res.setPattern(jdtPattern);
} else {
throw new UnsupportedOperationException("Missing support to convert '" + jcPattern + "' of type " + javac.getClass().getSimpleName());
}
res.setPattern(convert(jcPattern));
return res;
}
if (javac instanceof JCArrayAccess jcArrayAccess) {
Expand Down Expand Up @@ -1325,7 +1322,7 @@ private Expression convertExpression(JCExpression javac) {
ASTNode body = jcLambda.getBody() instanceof JCExpression expr ? convertExpression(expr) :
jcLambda.getBody() instanceof JCStatement stmt ? convertStatement(stmt, res) :
null;
if( body != null )
if( body != null )
res.setBody(body);
// TODO set parenthesis looking at the next non-whitespace char after the last parameter
int endPos = jcLambda.getEndPosition(this.javacCompilationUnit.endPositions);
Expand Down Expand Up @@ -1377,6 +1374,24 @@ private Expression convertExpression(JCExpression javac) {
return substitute;
}

private Pattern convert(JCPattern jcPattern) {
if (jcPattern instanceof JCBindingPattern jcBindingPattern) {
TypePattern jdtPattern = this.ast.newTypePattern();
commonSettings(jdtPattern, jcBindingPattern);
jdtPattern.setPatternVariable((SingleVariableDeclaration)convertVariableDeclaration(jcBindingPattern.var));
return jdtPattern;
} else if (jcPattern instanceof JCRecordPattern jcRecordPattern) {
RecordPattern jdtPattern = this.ast.newRecordPattern();
commonSettings(jdtPattern, jcRecordPattern);
jdtPattern.setPatternType(convertToType(jcRecordPattern.deconstructor));
for (JCPattern nestedJcPattern : jcRecordPattern.nested) {
jdtPattern.patterns().add(convert(nestedJcPattern));
}
return jdtPattern;
}
throw new UnsupportedOperationException("Missing support to convert '" + jcPattern);
}

private ArrayInitializer createArrayInitializerFromJCNewArray(JCNewArray jcNewArray) {
ArrayInitializer initializer = this.ast.newArrayInitializer();
commonSettings(initializer, jcNewArray);
Expand Down Expand Up @@ -1635,7 +1650,7 @@ private Statement convertStatement(JCStatement javac, ASTNode parent) {
if (javac instanceof JCForLoop jcForLoop) {
ForStatement res = this.ast.newForStatement();
commonSettings(res, javac);
Statement stmt = convertStatement(jcForLoop.getStatement(), res);
Statement stmt = convertStatement(jcForLoop.getStatement(), res);
if( stmt != null )
res.setBody(stmt);
var initializerIt = jcForLoop.getInitializer().iterator();
Expand Down Expand Up @@ -1710,8 +1725,43 @@ private Statement convertStatement(JCStatement javac, ASTNode parent) {
SwitchCase res = this.ast.newSwitchCase();
commonSettings(res, javac);
if( this.ast.apiLevel >= AST.JLS14_INTERNAL) {
if (jcCase.getGuard() != null && (jcCase.getLabels().size() > 1 || jcCase.getLabels().get(0) instanceof JCPatternCaseLabel)) {
GuardedPattern guardedPattern = this.ast.newGuardedPattern();
guardedPattern.setExpression(convertExpression(jcCase.getGuard()));
if (jcCase.getLabels().length() > 1) {
int start = Integer.MAX_VALUE;
int end = Integer.MIN_VALUE;
EitherOrMultiPattern eitherOrMultiPattern = this.ast.newEitherOrMultiPattern();
for (JCCaseLabel label : jcCase.getLabels()) {
if (label.pos < start) {
start = label.pos;
}
if (end < label.getEndPosition(this.javacCompilationUnit.endPositions)) {
end = label.getEndPosition(this.javacCompilationUnit.endPositions);
}
if (label instanceof JCPatternCaseLabel jcPattern) {
eitherOrMultiPattern.patterns().add(convert(jcPattern.getPattern()));
}
// skip over any constants, they are not valid anyways
}
eitherOrMultiPattern.setSourceRange(start, end - start);
guardedPattern.setPattern(eitherOrMultiPattern);
} else if (jcCase.getLabels().length() == 1) {
if (jcCase.getLabels().get(0) instanceof JCPatternCaseLabel jcPattern) {
guardedPattern.setPattern(convert(jcPattern.getPattern()));
} else {
// see same above note regarding guarded case labels using constants
throw new UnsupportedOperationException("cannot convert case label: " + jcCase.getLabels().get(0));
}
}
int start = guardedPattern.getPattern().getStartPosition();
int end = guardedPattern.getExpression().getStartPosition() + guardedPattern.getExpression().getLength();
guardedPattern.setSourceRange(start, end - start);
res.expressions().add(guardedPattern);
} else {
jcCase.getExpressions().stream().map(this::convertExpression).forEach(res.expressions()::add);
}
res.setSwitchLabeledRule(jcCase.getCaseKind() == CaseKind.RULE);
jcCase.getExpressions().stream().map(this::convertExpression).forEach(res.expressions()::add);
} else {
List<JCExpression> l = jcCase.getExpressions();
if( l.size() == 1 ) {
Expand All @@ -1731,7 +1781,7 @@ private Statement convertStatement(JCStatement javac, ASTNode parent) {
expr = jcp.getExpression();
}
res.setExpression(convertExpression(expr));
Statement body = convertStatement(jcWhile.getStatement(), res);
Statement body = convertStatement(jcWhile.getStatement(), res);
if( body != null )
res.setBody(body);
return res;
Expand All @@ -1746,7 +1796,7 @@ private Statement convertStatement(JCStatement javac, ASTNode parent) {
Expression expr1 = convertExpression(expr);
if( expr != null )
res.setExpression(expr1);

Statement body = convertStatement(jcDoWhile.getStatement(), res);
if( body != null )
res.setBody(body);
Expand Down

0 comments on commit 47222c9

Please sign in to comment.