From 3d40bd7dd197b12b7b156bd758b4129148e885d1 Mon Sep 17 00:00:00 2001 From: Dawid Wysakowicz Date: Fri, 24 May 2024 10:25:35 +0200 Subject: [PATCH] [FLINK-35437] Rewrite BlockStatementGrouper so that it uses less memory (#24834) --- .../codesplit/BlockStatementGrouper.java | 118 ++++++++---------- .../functions/CaseFunctionsITCase.java | 50 ++++++++ 2 files changed, 100 insertions(+), 68 deletions(-) create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CaseFunctionsITCase.java diff --git a/flink-table/flink-table-code-splitter/src/main/java/org/apache/flink/table/codesplit/BlockStatementGrouper.java b/flink-table/flink-table-code-splitter/src/main/java/org/apache/flink/table/codesplit/BlockStatementGrouper.java index 0d6854ab7486e..e43a34c41b815 100644 --- a/flink-table/flink-table-code-splitter/src/main/java/org/apache/flink/table/codesplit/BlockStatementGrouper.java +++ b/flink-table/flink-table-code-splitter/src/main/java/org/apache/flink/table/codesplit/BlockStatementGrouper.java @@ -28,10 +28,9 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.TokenStreamRewriter; import org.antlr.v4.runtime.atn.PredictionMode; -import org.apache.commons.lang3.tuple.Pair; import java.util.ArrayList; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -107,12 +106,14 @@ @Internal public class BlockStatementGrouper { - private final String code; - private final long maxMethodLength; private final String parameters; + private final TokenStreamRewriter rewriter; + + private final StatementContext topStatement; + /** * Initialize new BlockStatementGrouper. * @@ -121,9 +122,14 @@ public class BlockStatementGrouper { * @param parameters parameters definition that should be used for extracted methods. */ public BlockStatementGrouper(String code, long maxMethodLength, String parameters) { - this.code = code; this.maxMethodLength = maxMethodLength; this.parameters = parameters; + CommonTokenStream tokenStream = + new CommonTokenStream(new JavaLexer(CharStreams.fromString(code))); + JavaParser javaParser = new JavaParser(tokenStream); + javaParser.getInterpreter().setPredictionMode(PredictionMode.SLL); + this.topStatement = javaParser.statement(); + this.rewriter = new TokenStreamRewriter(tokenStream); } /** @@ -135,38 +141,21 @@ public BlockStatementGrouper(String code, long maxMethodLength, String parameter * groups with their names and content. */ public RewriteGroupedCode rewrite(String context) { - BlockStatementGrouperVisitor visitor = new BlockStatementGrouperVisitor(maxMethodLength, parameters); - CommonTokenStream tokenStream = - new CommonTokenStream(new JavaLexer(CharStreams.fromString(code))); - JavaParser javaParser = new JavaParser(tokenStream); - javaParser.getInterpreter().setPredictionMode(PredictionMode.SLL); - TokenStreamRewriter rewriter = new TokenStreamRewriter(tokenStream); - visitor.visitStatement(javaParser.statement(), context, rewriter); - - visitor.rewrite(); - Map>> groups = visitor.groups; - - Map> groupStrings = - CollectionUtil.newHashMapWithExpectedSize(groups.size()); - for (Entry>> group : - groups.entrySet()) { - List collectedStringGroups = - group.getValue().getValue().stream() - .map(LocalGroupElement::getBody) - .collect(Collectors.toList()); - - groupStrings.put(group.getKey(), collectedStringGroups); - } + + visitor.visitStatement(topStatement, context); + final Map> groupStrings = visitor.rewrite(rewriter); return new RewriteGroupedCode(rewriter.getText(), groupStrings); } private static class BlockStatementGrouperVisitor { - private final Map>> groups = - new HashMap<>(); + // Needs to be an ordered map, so that we later apply the innermost nested + // groups/transformations first and work on the results of such extractions with outer + // groups/transformations. + private final Map> groups = new LinkedHashMap<>(); private final long maxMethodLength; @@ -179,8 +168,7 @@ private BlockStatementGrouperVisitor(long maxMethodLength, String parameters) { this.parameters = parameters; } - public void visitStatement( - StatementContext ctx, String context, TokenStreamRewriter rewriter) { + public void visitStatement(StatementContext ctx, String context) { if (ctx.getChildCount() == 0) { return; @@ -193,21 +181,20 @@ public void visitStatement( for (StatementContext statement : ctx.statement()) { if (shouldExtract(statement)) { String localContext = String.format("%s_%d", context, counter++); - groupBlock(statement, localContext, rewriter); + groupBlock(statement, localContext); } } } else { // The block did not start from IF/ELSE/WHILE statement if (shouldExtract(ctx)) { - groupBlock(ctx, context, rewriter); + groupBlock(ctx, context); } } } // Group continuous block of statements together. If Statement is an IF/ELSE/WHILE, // its body can be further grouped by recursive call to visitStatement method. - private void groupBlock( - StatementContext ctx, String context, TokenStreamRewriter rewriter) { + private void groupBlock(StatementContext ctx, String context) { int localGroupCodeLength = 0; List localGroup = new ArrayList<>(); for (BlockStatementContext bsc : ctx.block().blockStatement()) { @@ -218,17 +205,9 @@ private void groupBlock( || statement.WHILE() != null) { String localContext = context + "_rewriteGroup" + this.counter++; - CommonTokenStream tokenStream = - new CommonTokenStream( - new JavaLexer( - CharStreams.fromString( - CodeSplitUtil.getContextString(statement)))); - TokenStreamRewriter localRewriter = new TokenStreamRewriter(tokenStream); - JavaParser javaParser = new JavaParser(tokenStream); - javaParser.getInterpreter().setPredictionMode(PredictionMode.SLL); - visitStatement(javaParser.statement(), localContext, localRewriter); + visitStatement(statement, localContext); - localGroup.add(new RewriteContextGroupElement(statement, localRewriter)); + localGroup.add(new RewriteContextGroupElement(statement)); // new method call length to the localGroupCodeLength. The "3" contains two // brackets for parameters and semicolon at the end of method call @@ -239,7 +218,7 @@ private void groupBlock( localGroup.add(new ContextGroupElement(bsc)); localGroupCodeLength += bsc.getText().length(); } else { - if (addLocalGroup(localGroup, context, rewriter)) { + if (addLocalGroup(localGroup, context)) { localGroup = new ArrayList<>(); localGroupCodeLength = 0; } @@ -251,16 +230,15 @@ private void groupBlock( // Groups that have only one statement that is "single line statement" such as // "a[2] += b[2];" will not be extracted. - addLocalGroup(localGroup, context, rewriter); + addLocalGroup(localGroup, context); } - private boolean addLocalGroup( - List localGroup, String context, TokenStreamRewriter rewriter) { + private boolean addLocalGroup(List localGroup, String context) { if (localGroup.size() > 1 || (localGroup.size() == 1 && canGroupAsSingleStatement(localGroup.get(0).getContext()))) { String localContext = context + "_rewriteGroup" + this.counter++; - groups.put(localContext, Pair.of(rewriter, localGroup)); + groups.put(localContext, localGroup); return true; } @@ -302,17 +280,25 @@ private int getNumOfReturnOrJumpStatements(ParserRuleContext ctx) { return counter.getCounter(); } - private void rewrite() { - for (Entry>> group : - groups.entrySet()) { - Pair> pair = group.getValue(); - TokenStreamRewriter rewriter = pair.getKey(); - List value = pair.getValue(); + private Map> rewrite(TokenStreamRewriter rewriter) { + Map> groupStrings = + CollectionUtil.newHashMapWithExpectedSize(groups.size()); + + for (Entry> group : groups.entrySet()) { + List groupElements = group.getValue(); + List collectedStringGroups = + groupElements.stream() + .map(el -> el.getBody(rewriter)) + .collect(Collectors.toList()); rewriter.replace( - value.get(0).getStart(), - value.get(value.size() - 1).getStop(), + groupElements.get(0).getStart(), + groupElements.get(groupElements.size() - 1).getStop(), group.getKey() + "(" + this.parameters + ");"); + + groupStrings.put(group.getKey(), collectedStringGroups); } + + return groupStrings; } } @@ -330,7 +316,7 @@ private interface LocalGroupElement { Token getStop(); /** @return String representation of this group element. */ - String getBody(); + String getBody(TokenStreamRewriter rewriter); ParserRuleContext getContext(); } @@ -359,8 +345,8 @@ public Token getStop() { } @Override - public String getBody() { - return CodeSplitUtil.getContextString(this.parserRuleContext); + public String getBody(TokenStreamRewriter rewriter) { + return rewriter.getText(this.parserRuleContext.getSourceInterval()); } @Override @@ -422,12 +408,8 @@ private static class RewriteContextGroupElement implements LocalGroupElement { private final ParserRuleContext parserRuleContext; - private final TokenStreamRewriter rewriter; - - private RewriteContextGroupElement( - ParserRuleContext parserRuleContext, TokenStreamRewriter rewriter) { + private RewriteContextGroupElement(ParserRuleContext parserRuleContext) { this.parserRuleContext = parserRuleContext; - this.rewriter = rewriter; } @Override @@ -441,8 +423,8 @@ public Token getStop() { } @Override - public String getBody() { - return this.rewriter.getText(); + public String getBody(TokenStreamRewriter rewriter) { + return rewriter.getText(parserRuleContext.getSourceInterval()); } @Override diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CaseFunctionsITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CaseFunctionsITCase.java new file mode 100644 index 0000000000000..befd8f7f94c0a --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CaseFunctionsITCase.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions; + +import org.apache.flink.table.api.DataTypes; + +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +/** Tests for CASE WHEN expression. */ +class CaseFunctionsITCase extends BuiltInFunctionTestBase { + + @Override + Stream getTestSetSpecs() { + + String caseStatement = + "case " + + IntStream.range(0, 250) + .mapToObj(idx -> String.format("when f0 = %d then %d", idx, idx)) + .collect(Collectors.joining(" ")) + + "else 9999 end"; + + // Verify a long case when statement which produces deeply nested if else statements + // works correctly. + return Stream.of( + TestSetSpec.forExpression("CASE WHEN") + .onFieldsWithData(110) + .testSqlResult(caseStatement, 110, DataTypes.INT().notNull()), + TestSetSpec.forExpression("CASE WHEN ELSE") + .onFieldsWithData(450) + .testSqlResult(caseStatement, 9999, DataTypes.INT().notNull())); + } +}