[FLINK-35437] Rewrite BlockStatementGrouper so that it uses less memory (apache#24834)
dawidwys authored May 24, 2024
1 parent 54f037f commit 3d40bd7
Showing 2 changed files with 100 additions and 68 deletions.
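
The gist of the change, for orientation: instead of building a fresh JavaLexer, JavaParser, and TokenStreamRewriter for every nested statement group, the grouper now lexes and parses the input once in its constructor and shares a single TokenStreamRewriter; group bodies are read back from that shared rewriter through the source interval of each parse-tree context. A minimal sketch of that shared-rewriter pattern follows, assuming the generated JavaLexer/JavaParser from the code-splitter grammar are on the classpath; the class name SharedRewriterSketch and the method bodyOf are illustrative, not part of the patch.

    import org.antlr.v4.runtime.CharStreams;
    import org.antlr.v4.runtime.CommonTokenStream;
    import org.antlr.v4.runtime.ParserRuleContext;
    import org.antlr.v4.runtime.TokenStreamRewriter;
    import org.antlr.v4.runtime.atn.PredictionMode;

    public class SharedRewriterSketch {

        private final TokenStreamRewriter rewriter;
        private final JavaParser.StatementContext topStatement;

        public SharedRewriterSketch(String code) {
            // Lex and parse exactly once; all nested groups work on this one tree.
            CommonTokenStream tokens =
                    new CommonTokenStream(new JavaLexer(CharStreams.fromString(code)));
            JavaParser parser = new JavaParser(tokens);
            parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
            this.topStatement = parser.statement();
            // A single rewriter over the single token stream, shared by all groups.
            this.rewriter = new TokenStreamRewriter(tokens);
        }

        // A group body is materialized on demand from the shared token stream,
        // so no per-group lexer, parser, or rewriter has to be kept alive.
        public String bodyOf(ParserRuleContext ctx) {
            return rewriter.getText(ctx.getSourceInterval());
        }
    }

Memory now scales with one token stream for the whole input rather than one token stream and rewriter per nested group, which is what the diff below implements.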
BlockStatementGrouper.java
@@ -28,10 +28,9 @@
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenStreamRewriter;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.apache.commons.lang3.tuple.Pair;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
@@ -107,12 +106,14 @@
@Internal
public class BlockStatementGrouper {

private final String code;

private final long maxMethodLength;

private final String parameters;

private final TokenStreamRewriter rewriter;

private final StatementContext topStatement;

/**
* Initialize new BlockStatementGrouper.
*
@@ -121,9 +122,14 @@ public class BlockStatementGrouper {
* @param parameters parameters definition that should be used for extracted methods.
*/
public BlockStatementGrouper(String code, long maxMethodLength, String parameters) {
this.code = code;
this.maxMethodLength = maxMethodLength;
this.parameters = parameters;
CommonTokenStream tokenStream =
new CommonTokenStream(new JavaLexer(CharStreams.fromString(code)));
JavaParser javaParser = new JavaParser(tokenStream);
javaParser.getInterpreter().setPredictionMode(PredictionMode.SLL);
this.topStatement = javaParser.statement();
this.rewriter = new TokenStreamRewriter(tokenStream);
}

/**
@@ -135,38 +141,21 @@ public BlockStatementGrouper(String code, long maxMethodLength, String parameter
* groups with their names and content.
*/
public RewriteGroupedCode rewrite(String context) {

BlockStatementGrouperVisitor visitor =
new BlockStatementGrouperVisitor(maxMethodLength, parameters);
CommonTokenStream tokenStream =
new CommonTokenStream(new JavaLexer(CharStreams.fromString(code)));
JavaParser javaParser = new JavaParser(tokenStream);
javaParser.getInterpreter().setPredictionMode(PredictionMode.SLL);
TokenStreamRewriter rewriter = new TokenStreamRewriter(tokenStream);
visitor.visitStatement(javaParser.statement(), context, rewriter);

visitor.rewrite();
Map<String, Pair<TokenStreamRewriter, List<LocalGroupElement>>> groups = visitor.groups;

Map<String, List<String>> groupStrings =
CollectionUtil.newHashMapWithExpectedSize(groups.size());
for (Entry<String, Pair<TokenStreamRewriter, List<LocalGroupElement>>> group :
groups.entrySet()) {
List<String> collectedStringGroups =
group.getValue().getValue().stream()
.map(LocalGroupElement::getBody)
.collect(Collectors.toList());

groupStrings.put(group.getKey(), collectedStringGroups);
}

visitor.visitStatement(topStatement, context);
final Map<String, List<String>> groupStrings = visitor.rewrite(rewriter);

return new RewriteGroupedCode(rewriter.getText(), groupStrings);
}

private static class BlockStatementGrouperVisitor {

private final Map<String, Pair<TokenStreamRewriter, List<LocalGroupElement>>> groups =
new HashMap<>();
// Needs to be an ordered map, so that we later apply the innermost nested
// groups/transformations first and work on the results of such extractions with outer
// groups/transformations.
private final Map<String, List<LocalGroupElement>> groups = new LinkedHashMap<>();

private final long maxMethodLength;

@@ -179,8 +168,7 @@ private BlockStatementGrouperVisitor(long maxMethodLength, String parameters) {
this.parameters = parameters;
}

public void visitStatement(
StatementContext ctx, String context, TokenStreamRewriter rewriter) {
public void visitStatement(StatementContext ctx, String context) {

if (ctx.getChildCount() == 0) {
return;
@@ -193,21 +181,20 @@ public void visitStatement(
for (StatementContext statement : ctx.statement()) {
if (shouldExtract(statement)) {
String localContext = String.format("%s_%d", context, counter++);
groupBlock(statement, localContext, rewriter);
groupBlock(statement, localContext);
}
}
} else {
// The block did not start from IF/ELSE/WHILE statement
if (shouldExtract(ctx)) {
groupBlock(ctx, context, rewriter);
groupBlock(ctx, context);
}
}
}

// Group continuous block of statements together. If Statement is an IF/ELSE/WHILE,
// its body can be further grouped by recursive call to visitStatement method.
private void groupBlock(
StatementContext ctx, String context, TokenStreamRewriter rewriter) {
private void groupBlock(StatementContext ctx, String context) {
int localGroupCodeLength = 0;
List<LocalGroupElement> localGroup = new ArrayList<>();
for (BlockStatementContext bsc : ctx.block().blockStatement()) {
@@ -218,17 +205,9 @@ private void groupBlock(
|| statement.WHILE() != null) {
String localContext = context + "_rewriteGroup" + this.counter++;

CommonTokenStream tokenStream =
new CommonTokenStream(
new JavaLexer(
CharStreams.fromString(
CodeSplitUtil.getContextString(statement))));
TokenStreamRewriter localRewriter = new TokenStreamRewriter(tokenStream);
JavaParser javaParser = new JavaParser(tokenStream);
javaParser.getInterpreter().setPredictionMode(PredictionMode.SLL);
visitStatement(javaParser.statement(), localContext, localRewriter);
visitStatement(statement, localContext);

localGroup.add(new RewriteContextGroupElement(statement, localRewriter));
localGroup.add(new RewriteContextGroupElement(statement));

// new method call length to the localGroupCodeLength. The "3" contains two
// brackets for parameters and semicolon at the end of method call
@@ -239,7 +218,7 @@
localGroup.add(new ContextGroupElement(bsc));
localGroupCodeLength += bsc.getText().length();
} else {
if (addLocalGroup(localGroup, context, rewriter)) {
if (addLocalGroup(localGroup, context)) {
localGroup = new ArrayList<>();
localGroupCodeLength = 0;
}
@@ -251,16 +230,15 @@

// Groups that have only one statement that is "single line statement" such as
// "a[2] += b[2];" will not be extracted.
addLocalGroup(localGroup, context, rewriter);
addLocalGroup(localGroup, context);
}

private boolean addLocalGroup(
List<LocalGroupElement> localGroup, String context, TokenStreamRewriter rewriter) {
private boolean addLocalGroup(List<LocalGroupElement> localGroup, String context) {
if (localGroup.size() > 1
|| (localGroup.size() == 1
&& canGroupAsSingleStatement(localGroup.get(0).getContext()))) {
String localContext = context + "_rewriteGroup" + this.counter++;
groups.put(localContext, Pair.of(rewriter, localGroup));
groups.put(localContext, localGroup);
return true;
}

@@ -302,17 +280,25 @@ private int getNumOfReturnOrJumpStatements(ParserRuleContext ctx) {
return counter.getCounter();
}

private void rewrite() {
for (Entry<String, Pair<TokenStreamRewriter, List<LocalGroupElement>>> group :
groups.entrySet()) {
Pair<TokenStreamRewriter, List<LocalGroupElement>> pair = group.getValue();
TokenStreamRewriter rewriter = pair.getKey();
List<LocalGroupElement> value = pair.getValue();
private Map<String, List<String>> rewrite(TokenStreamRewriter rewriter) {
Map<String, List<String>> groupStrings =
CollectionUtil.newHashMapWithExpectedSize(groups.size());

for (Entry<String, List<LocalGroupElement>> group : groups.entrySet()) {
List<LocalGroupElement> groupElements = group.getValue();
List<String> collectedStringGroups =
groupElements.stream()
.map(el -> el.getBody(rewriter))
.collect(Collectors.toList());
rewriter.replace(
value.get(0).getStart(),
value.get(value.size() - 1).getStop(),
groupElements.get(0).getStart(),
groupElements.get(groupElements.size() - 1).getStop(),
group.getKey() + "(" + this.parameters + ");");

groupStrings.put(group.getKey(), collectedStringGroups);
}

return groupStrings;
}
}

@@ -330,7 +316,7 @@ private interface LocalGroupElement {
Token getStop();

/** @return String representation of this group element. */
String getBody();
String getBody(TokenStreamRewriter rewriter);

ParserRuleContext getContext();
}
@@ -359,8 +345,8 @@ public Token getStop() {
}

@Override
public String getBody() {
return CodeSplitUtil.getContextString(this.parserRuleContext);
public String getBody(TokenStreamRewriter rewriter) {
return rewriter.getText(this.parserRuleContext.getSourceInterval());
}

@Override
@@ -422,12 +408,8 @@ private static class RewriteContextGroupElement implements LocalGroupElement {

private final ParserRuleContext parserRuleContext;

private final TokenStreamRewriter rewriter;

private RewriteContextGroupElement(
ParserRuleContext parserRuleContext, TokenStreamRewriter rewriter) {
private RewriteContextGroupElement(ParserRuleContext parserRuleContext) {
this.parserRuleContext = parserRuleContext;
this.rewriter = rewriter;
}

@Override
@@ -441,8 +423,8 @@ public Token getStop() {
}

@Override
public String getBody() {
return this.rewriter.getText();
public String getBody(TokenStreamRewriter rewriter) {
return rewriter.getText(parserRuleContext.getSourceInterval());
}

@Override
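
A hedged usage sketch of the reworked grouper, based only on the signatures visible above (the constructor, which now does the parsing, and rewrite(String)). The input snippet, the method-length limit of 10, the parameter list "a, b", and the context name "myFun" are made-up values, and the package of the grouper classes is assumed rather than shown in this diff.

    // Illustrative inputs only; the two grouper calls mirror the API in the file above.
    String code = "{ a[0] = b[0]; if (a[0] > 0) { a[1] = b[1]; a[2] = b[2]; } }";

    // Parsing happens once, inside the constructor.
    BlockStatementGrouper grouper = new BlockStatementGrouper(code, 10L, "a, b");

    // rewrite() walks the already-parsed statement and returns the rewritten text
    // together with the extracted groups keyed by their generated method names.
    RewriteGroupedCode grouped = grouper.rewrite("myFun");
    // grouped carries the rewritten top-level code, where extracted groups are
    // replaced by calls of the form myFun_rewriteGroupN(a, b);, plus a map from
    // each generated group name to the statements that were moved into it.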
CaseFunctionsITCase.java (new file)
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.functions;

import org.apache.flink.table.api.DataTypes;

import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/** Tests for CASE WHEN expression. */
class CaseFunctionsITCase extends BuiltInFunctionTestBase {

@Override
Stream<TestSetSpec> getTestSetSpecs() {

String caseStatement =
"case "
+ IntStream.range(0, 250)
.mapToObj(idx -> String.format("when f0 = %d then %d", idx, idx))
.collect(Collectors.joining(" "))
+ "else 9999 end";

// Verify a long case when statement which produces deeply nested if else statements
// works correctly.
return Stream.of(
TestSetSpec.forExpression("CASE WHEN")
.onFieldsWithData(110)
.testSqlResult(caseStatement, 110, DataTypes.INT().notNull()),
TestSetSpec.forExpression("CASE WHEN ELSE")
.onFieldsWithData(450)
.testSqlResult(caseStatement, 9999, DataTypes.INT().notNull()));
}
}
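
For context, the expression built above expands to a 250-branch statement of the form case when f0 = 0 then 0 when f0 = 1 then 1 ... when f0 = 249 then 249 else 9999 end, which the code generator compiles into if/else nesting deep enough to exercise the grouping logic changed in this commit.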
