Skip to content

Commit

Permalink
Compute output of LookupJoinExec dynamically (#117763) (#117777)
Browse files Browse the repository at this point in the history
LookupJoinExec should not assume its output but instead compute it from
- Its input fields from the left - The fields added from the lookup
index

Currently, LookupJoinExec's output is determined when the logical plan
is mapped to a physical one, and thereafter the output cannot be changed
anymore. This makes it impossible to have late materialization of fields
from the left hand side via field extractions, because we are forced to
extract *all* fields before the LookupJoinExec, otherwise we do not
achieve the prescribed output.

Avoid that by tracking only which fields the LookupJoinExec will add
from the lookup index instead of tracking the whole output (that was
only correct for the logical plan).

**Note:** While this PR is a refactoring for the current functionality,
it should unblock @craigtaverner 's ongoing work related to field
extractions and getting multiple LOOKUP JOIN queries to work correctly
without adding hacks.

(cherry picked from commit 64107e0)

# Conflicts:
#	x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
  • Loading branch information
alex-spies authored Dec 2, 2024
1 parent 4b7c3b6 commit 2eb635b
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.INLINESTATS;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.INLINESTATS_V2;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_LOOKUP_V2;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_LOOKUP_V3;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_PLANNING_V1;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METADATA_FIELDS_REMOTE_TEST;
import static org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase.Mode.SYNC;
Expand Down Expand Up @@ -125,7 +125,7 @@ protected void shouldSkipTest(String testName) throws IOException {
assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(INLINESTATS.capabilityName()));
assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(INLINESTATS_V2.capabilityName()));
assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_PLANNING_V1.capabilityName()));
assumeFalse("LOOKUP JOIN not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_LOOKUP_V2.capabilityName()));
assumeFalse("LOOKUP JOIN not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_LOOKUP_V3.capabilityName()));
}

private TestFeatureService remoteFeaturesService() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

//TODO: this sometimes returns null instead of the looked up value (likely related to the execution order)
basicOnTheDataNode-Ignore
required_capability: join_lookup_v2
required_capability: join_lookup_v3

FROM employees
| EVAL language_code = languages
Expand All @@ -22,7 +22,7 @@ emp_no:integer | language_code:integer | language_name:keyword
;

basicRow-Ignore
required_capability: join_lookup
required_capability: join_lookup_v3

ROW language_code = 1
| LOOKUP JOIN languages_lookup ON language_code
Expand All @@ -33,7 +33,7 @@ language_code:keyword | language_name:keyword
;

basicOnTheCoordinator
required_capability: join_lookup_v2
required_capability: join_lookup_v3

FROM employees
| SORT emp_no
Expand All @@ -51,7 +51,7 @@ emp_no:integer | language_code:integer | language_name:keyword

//TODO: this sometimes returns null instead of the looked up value (likely related to the execution order)
subsequentEvalOnTheDataNode-Ignore
required_capability: join_lookup_v2
required_capability: join_lookup_v3

FROM employees
| EVAL language_code = languages
Expand All @@ -69,7 +69,7 @@ emp_no:integer | language_code:integer | language_name:keyword | language_code_x
;

subsequentEvalOnTheCoordinator
required_capability: join_lookup_v2
required_capability: join_lookup_v3

FROM employees
| SORT emp_no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ public enum Cap {
/**
* LOOKUP JOIN
*/
JOIN_LOOKUP_V2(Build.current().isSnapshot()),
JOIN_LOOKUP_V3(Build.current().isSnapshot()),

/**
* Fix for https://github.com/elastic/elasticsearch/issues/117054
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

package org.elasticsearch.xpack.esql.optimizer;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
Expand Down Expand Up @@ -62,34 +60,4 @@ public final PhysicalPlan apply(PhysicalPlan plan) {

protected abstract PhysicalPlan rule(SubPlan plan);
}

public abstract static class OptimizerExpressionRule<E extends Expression> extends Rule<PhysicalPlan, PhysicalPlan> {

private final TransformDirection direction;
// overriding type token which returns the correct class but does an uncheck cast to LogicalPlan due to its generic bound
// a proper solution is to wrap the Expression rule into a Plan rule but that would affect the rule declaration
// so instead this is hacked here
private final Class<E> expressionTypeToken = ReflectionUtils.detectSuperTypeForRuleLike(getClass());

public OptimizerExpressionRule(TransformDirection direction) {
this.direction = direction;
}

@Override
public final PhysicalPlan apply(PhysicalPlan plan) {
return direction == TransformDirection.DOWN
? plan.transformExpressionsDown(expressionTypeToken, this::rule)
: plan.transformExpressionsUp(expressionTypeToken, this::rule);
}

protected PhysicalPlan rule(PhysicalPlan plan) {
return plan;
}

protected abstract Expression rule(E e);

public Class<E> expressionToken() {
return expressionTypeToken;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ private static Set<Attribute> missingAttributes(PhysicalPlan p) {
var missing = new LinkedHashSet<Attribute>();
var inputSet = p.inputSet();

// FIXME: the extractors should work on the right side as well
// TODO: We need to extract whatever fields are missing from the left hand side.
// skip the lookup join since the right side is always materialized and a projection
if (p instanceof LookupJoinExec join) {
// collect fields used in the join condition
return Collections.emptySet();
}

var input = inputSet;
// collect field attributes used inside expressions
// TODO: Rather than going over all expressions manually, this should just call .references()
p.forEachExpression(TypedAttribute.class, f -> {
if (f instanceof FieldAttribute || f instanceof MetadataAttribute) {
if (input.contains(f) == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
Expand All @@ -23,12 +23,9 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
import static org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes.LEFT;
import static org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes.RIGHT;

public class Join extends BinaryPlan {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Join", Join::new);
Expand Down Expand Up @@ -100,6 +97,19 @@ public List<Attribute> output() {
return lazyOutput;
}

public List<Attribute> rightOutputFields() {
AttributeSet leftInputs = left().outputSet();

List<Attribute> rightOutputFields = new ArrayList<>();
for (Attribute attr : output()) {
if (leftInputs.contains(attr) == false) {
rightOutputFields.add(attr);
}
}

return rightOutputFields;
}

/**
* Combine the two lists of attributes into one.
* In case of (name) conflicts, specify which sides wins, that is overrides the other column - the left or the right.
Expand All @@ -108,18 +118,11 @@ public static List<Attribute> computeOutput(List<Attribute> leftOutput, List<Att
JoinType joinType = config.type();
List<Attribute> output;
// TODO: make the other side nullable
Set<String> matchFieldNames = config.matchFields().stream().map(NamedExpression::name).collect(Collectors.toSet());
if (LEFT.equals(joinType)) {
// right side becomes nullable and overrides left except for match fields, which we preserve from the left
List<Attribute> rightOutputWithoutMatchFields = rightOutput.stream()
.filter(attr -> matchFieldNames.contains(attr.name()) == false)
.toList();
// right side becomes nullable and overrides left except for join keys, which we preserve from the left
AttributeSet rightKeys = new AttributeSet(config.rightFields());
List<Attribute> rightOutputWithoutMatchFields = rightOutput.stream().filter(attr -> rightKeys.contains(attr) == false).toList();
output = mergeOutputAttributes(rightOutputWithoutMatchFields, leftOutput);
} else if (RIGHT.equals(joinType)) {
List<Attribute> leftOutputWithoutMatchFields = leftOutput.stream()
.filter(attr -> matchFieldNames.contains(attr.name()) == false)
.toList();
output = mergeOutputAttributes(leftOutputWithoutMatchFields, rightOutput);
} else {
throw new IllegalArgumentException(joinType.joinName() + " unsupported");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

Expand All @@ -30,43 +29,43 @@ public class LookupJoinExec extends BinaryExec implements EstimatesRowSize {
LookupJoinExec::new
);

private final List<Attribute> matchFields;
private final List<Attribute> leftFields;
private final List<Attribute> rightFields;
private final List<Attribute> output;
private List<Attribute> lazyAddedFields;
/**
* These cannot be computed from the left + right outputs, because
* {@link org.elasticsearch.xpack.esql.optimizer.rules.physical.local.ReplaceSourceAttributes} will replace the {@link EsSourceExec} on
* the right hand side by a {@link EsQueryExec}, and thus lose the information of which fields we'll get from the lookup index.
*/
private final List<Attribute> addedFields;
private List<Attribute> lazyOutput;

public LookupJoinExec(
Source source,
PhysicalPlan left,
PhysicalPlan lookup,
List<Attribute> matchFields,
List<Attribute> leftFields,
List<Attribute> rightFields,
List<Attribute> output
List<Attribute> addedFields
) {
super(source, left, lookup);
this.matchFields = matchFields;
this.leftFields = leftFields;
this.rightFields = rightFields;
this.output = output;
this.addedFields = addedFields;
}

private LookupJoinExec(StreamInput in) throws IOException {
super(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteable(PhysicalPlan.class));
this.matchFields = in.readNamedWriteableCollectionAsList(Attribute.class);
this.leftFields = in.readNamedWriteableCollectionAsList(Attribute.class);
this.rightFields = in.readNamedWriteableCollectionAsList(Attribute.class);
this.output = in.readNamedWriteableCollectionAsList(Attribute.class);
this.addedFields = in.readNamedWriteableCollectionAsList(Attribute.class);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeNamedWriteableCollection(matchFields);
out.writeNamedWriteableCollection(leftFields);
out.writeNamedWriteableCollection(rightFields);
out.writeNamedWriteableCollection(output);
out.writeNamedWriteableCollection(addedFields);
}

@Override
Expand All @@ -78,10 +77,6 @@ public PhysicalPlan lookup() {
return right();
}

public List<Attribute> matchFields() {
return matchFields;
}

public List<Attribute> leftFields() {
return leftFields;
}
Expand All @@ -91,29 +86,26 @@ public List<Attribute> rightFields() {
}

public List<Attribute> addedFields() {
if (lazyAddedFields == null) {
AttributeSet set = outputSet();
set.removeAll(left().output());
for (Attribute m : matchFields) {
set.removeIf(a -> a.name().equals(m.name()));
return addedFields;
}

@Override
public List<Attribute> output() {
if (lazyOutput == null) {
lazyOutput = new ArrayList<>(left().output());
for (Attribute attr : addedFields) {
lazyOutput.add(attr);
}
lazyAddedFields = new ArrayList<>(set);
lazyAddedFields.sort(Comparator.comparing(Attribute::name));
}
return lazyAddedFields;
return lazyOutput;
}

@Override
public PhysicalPlan estimateRowSize(State state) {
state.add(false, output);
state.add(false, output());
return this;
}

@Override
public List<Attribute> output() {
return output;
}

@Override
public AttributeSet inputSet() {
// TODO: this is a hack since the right side is always materialized - instead this should
Expand All @@ -129,12 +121,12 @@ protected AttributeSet computeReferences() {

@Override
public LookupJoinExec replaceChildren(PhysicalPlan left, PhysicalPlan right) {
return new LookupJoinExec(source(), left, right, matchFields, leftFields, rightFields, output);
return new LookupJoinExec(source(), left, right, leftFields, rightFields, addedFields);
}

@Override
protected NodeInfo<? extends PhysicalPlan> info() {
return NodeInfo.create(this, LookupJoinExec::new, left(), right(), matchFields, leftFields, rightFields, output);
return NodeInfo.create(this, LookupJoinExec::new, left(), right(), leftFields, rightFields, addedFields);
}

@Override
Expand All @@ -148,15 +140,12 @@ public boolean equals(Object o) {
if (super.equals(o) == false) {
return false;
}
LookupJoinExec hash = (LookupJoinExec) o;
return matchFields.equals(hash.matchFields)
&& leftFields.equals(hash.leftFields)
&& rightFields.equals(hash.rightFields)
&& output.equals(hash.output);
LookupJoinExec other = (LookupJoinExec) o;
return leftFields.equals(other.leftFields) && rightFields.equals(other.rightFields) && addedFields.equals(other.addedFields);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), matchFields, leftFields, rightFields, output);
return Objects.hash(super.hashCode(), leftFields, rightFields, addedFields);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,8 @@ private PhysicalOperation planLookupJoin(LookupJoinExec join, LocalExecutionPlan
if (localSourceExec.indexMode() != IndexMode.LOOKUP) {
throw new IllegalArgumentException("can't plan [" + join + "]");
}
List<Layout.ChannelAndType> matchFields = new ArrayList<>(join.matchFields().size());
for (Attribute m : join.matchFields()) {
List<Layout.ChannelAndType> matchFields = new ArrayList<>(join.leftFields().size());
for (Attribute m : join.leftFields()) {
Layout.ChannelAndType t = source.layout.get(m.id());
if (t == null) {
throw new IllegalArgumentException("can't plan [" + join + "][" + m + "]");
Expand All @@ -604,7 +604,7 @@ private PhysicalOperation planLookupJoin(LookupJoinExec join, LocalExecutionPlan
lookupFromIndexService,
matchFields.get(0).type(),
localSourceExec.index().name(),
join.matchFields().get(0).name(),
join.leftFields().get(0).name(),
join.addedFields().stream().map(f -> (NamedExpression) f).toList(),
join.source()
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,7 @@ private PhysicalPlan mapBinary(BinaryPlan binary) {
);
}
if (right instanceof EsSourceExec source && source.indexMode() == IndexMode.LOOKUP) {
return new LookupJoinExec(
join.source(),
left,
right,
config.matchFields(),
config.leftFields(),
config.rightFields(),
join.output()
);
return new LookupJoinExec(join.source(), left, right, config.leftFields(), config.rightFields(), join.rightOutputFields());
}
}

Expand Down
Loading

0 comments on commit 2eb635b

Please sign in to comment.