Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.x] Compute output of LookupJoinExec dynamically (#117763) #117777

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.INLINESTATS;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.INLINESTATS_V2;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_LOOKUP_V2;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_LOOKUP_V3;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_PLANNING_V1;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METADATA_FIELDS_REMOTE_TEST;
import static org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase.Mode.SYNC;
Expand Down Expand Up @@ -125,7 +125,7 @@ protected void shouldSkipTest(String testName) throws IOException {
assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(INLINESTATS.capabilityName()));
assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(INLINESTATS_V2.capabilityName()));
assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_PLANNING_V1.capabilityName()));
assumeFalse("LOOKUP JOIN not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_LOOKUP_V2.capabilityName()));
assumeFalse("LOOKUP JOIN not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_LOOKUP_V3.capabilityName()));
}

private TestFeatureService remoteFeaturesService() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

//TODO: this sometimes returns null instead of the looked up value (likely related to the execution order)
basicOnTheDataNode-Ignore
required_capability: join_lookup_v2
required_capability: join_lookup_v3

FROM employees
| EVAL language_code = languages
Expand All @@ -22,7 +22,7 @@ emp_no:integer | language_code:integer | language_name:keyword
;

basicRow-Ignore
required_capability: join_lookup
required_capability: join_lookup_v3

ROW language_code = 1
| LOOKUP JOIN languages_lookup ON language_code
Expand All @@ -33,7 +33,7 @@ language_code:keyword | language_name:keyword
;

basicOnTheCoordinator
required_capability: join_lookup_v2
required_capability: join_lookup_v3

FROM employees
| SORT emp_no
Expand All @@ -51,7 +51,7 @@ emp_no:integer | language_code:integer | language_name:keyword

//TODO: this sometimes returns null instead of the looked up value (likely related to the execution order)
subsequentEvalOnTheDataNode-Ignore
required_capability: join_lookup_v2
required_capability: join_lookup_v3

FROM employees
| EVAL language_code = languages
Expand All @@ -69,7 +69,7 @@ emp_no:integer | language_code:integer | language_name:keyword | language_code_x
;

subsequentEvalOnTheCoordinator
required_capability: join_lookup_v2
required_capability: join_lookup_v3

FROM employees
| SORT emp_no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ public enum Cap {
/**
* LOOKUP JOIN
*/
JOIN_LOOKUP_V2(Build.current().isSnapshot()),
JOIN_LOOKUP_V3(Build.current().isSnapshot()),

/**
* Fix for https://github.com/elastic/elasticsearch/issues/117054
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

package org.elasticsearch.xpack.esql.optimizer;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
Expand Down Expand Up @@ -62,34 +60,4 @@ public final PhysicalPlan apply(PhysicalPlan plan) {

protected abstract PhysicalPlan rule(SubPlan plan);
}

public abstract static class OptimizerExpressionRule<E extends Expression> extends Rule<PhysicalPlan, PhysicalPlan> {

private final TransformDirection direction;
// overriding type token which returns the correct class but does an uncheck cast to LogicalPlan due to its generic bound
// a proper solution is to wrap the Expression rule into a Plan rule but that would affect the rule declaration
// so instead this is hacked here
private final Class<E> expressionTypeToken = ReflectionUtils.detectSuperTypeForRuleLike(getClass());

public OptimizerExpressionRule(TransformDirection direction) {
this.direction = direction;
}

@Override
public final PhysicalPlan apply(PhysicalPlan plan) {
return direction == TransformDirection.DOWN
? plan.transformExpressionsDown(expressionTypeToken, this::rule)
: plan.transformExpressionsUp(expressionTypeToken, this::rule);
}

protected PhysicalPlan rule(PhysicalPlan plan) {
return plan;
}

protected abstract Expression rule(E e);

public Class<E> expressionToken() {
return expressionTypeToken;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ private static Set<Attribute> missingAttributes(PhysicalPlan p) {
var missing = new LinkedHashSet<Attribute>();
var inputSet = p.inputSet();

// FIXME: the extractors should work on the right side as well
// TODO: We need to extract whatever fields are missing from the left hand side.
// skip the lookup join since the right side is always materialized and a projection
if (p instanceof LookupJoinExec join) {
// collect fields used in the join condition
return Collections.emptySet();
}

var input = inputSet;
// collect field attributes used inside expressions
// TODO: Rather than going over all expressions manually, this should just call .references()
p.forEachExpression(TypedAttribute.class, f -> {
if (f instanceof FieldAttribute || f instanceof MetadataAttribute) {
if (input.contains(f) == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
Expand All @@ -23,12 +23,9 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
import static org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes.LEFT;
import static org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes.RIGHT;

public class Join extends BinaryPlan {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Join", Join::new);
Expand Down Expand Up @@ -100,6 +97,19 @@ public List<Attribute> output() {
return lazyOutput;
}

public List<Attribute> rightOutputFields() {
AttributeSet leftInputs = left().outputSet();

List<Attribute> rightOutputFields = new ArrayList<>();
for (Attribute attr : output()) {
if (leftInputs.contains(attr) == false) {
rightOutputFields.add(attr);
}
}

return rightOutputFields;
}

/**
* Combine the two lists of attributes into one.
* In case of (name) conflicts, specify which sides wins, that is overrides the other column - the left or the right.
Expand All @@ -108,18 +118,11 @@ public static List<Attribute> computeOutput(List<Attribute> leftOutput, List<Att
JoinType joinType = config.type();
List<Attribute> output;
// TODO: make the other side nullable
Set<String> matchFieldNames = config.matchFields().stream().map(NamedExpression::name).collect(Collectors.toSet());
if (LEFT.equals(joinType)) {
// right side becomes nullable and overrides left except for match fields, which we preserve from the left
List<Attribute> rightOutputWithoutMatchFields = rightOutput.stream()
.filter(attr -> matchFieldNames.contains(attr.name()) == false)
.toList();
// right side becomes nullable and overrides left except for join keys, which we preserve from the left
AttributeSet rightKeys = new AttributeSet(config.rightFields());
List<Attribute> rightOutputWithoutMatchFields = rightOutput.stream().filter(attr -> rightKeys.contains(attr) == false).toList();
output = mergeOutputAttributes(rightOutputWithoutMatchFields, leftOutput);
} else if (RIGHT.equals(joinType)) {
List<Attribute> leftOutputWithoutMatchFields = leftOutput.stream()
.filter(attr -> matchFieldNames.contains(attr.name()) == false)
.toList();
output = mergeOutputAttributes(leftOutputWithoutMatchFields, rightOutput);
} else {
throw new IllegalArgumentException(joinType.joinName() + " unsupported");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

Expand All @@ -30,43 +29,43 @@ public class LookupJoinExec extends BinaryExec implements EstimatesRowSize {
LookupJoinExec::new
);

private final List<Attribute> matchFields;
private final List<Attribute> leftFields;
private final List<Attribute> rightFields;
private final List<Attribute> output;
private List<Attribute> lazyAddedFields;
/**
* These cannot be computed from the left + right outputs, because
* {@link org.elasticsearch.xpack.esql.optimizer.rules.physical.local.ReplaceSourceAttributes} will replace the {@link EsSourceExec} on
* the right hand side by a {@link EsQueryExec}, and thus lose the information of which fields we'll get from the lookup index.
*/
private final List<Attribute> addedFields;
private List<Attribute> lazyOutput;

public LookupJoinExec(
Source source,
PhysicalPlan left,
PhysicalPlan lookup,
List<Attribute> matchFields,
List<Attribute> leftFields,
List<Attribute> rightFields,
List<Attribute> output
List<Attribute> addedFields
) {
super(source, left, lookup);
this.matchFields = matchFields;
this.leftFields = leftFields;
this.rightFields = rightFields;
this.output = output;
this.addedFields = addedFields;
}

private LookupJoinExec(StreamInput in) throws IOException {
super(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteable(PhysicalPlan.class));
this.matchFields = in.readNamedWriteableCollectionAsList(Attribute.class);
this.leftFields = in.readNamedWriteableCollectionAsList(Attribute.class);
this.rightFields = in.readNamedWriteableCollectionAsList(Attribute.class);
this.output = in.readNamedWriteableCollectionAsList(Attribute.class);
this.addedFields = in.readNamedWriteableCollectionAsList(Attribute.class);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeNamedWriteableCollection(matchFields);
out.writeNamedWriteableCollection(leftFields);
out.writeNamedWriteableCollection(rightFields);
out.writeNamedWriteableCollection(output);
out.writeNamedWriteableCollection(addedFields);
}

@Override
Expand All @@ -78,10 +77,6 @@ public PhysicalPlan lookup() {
return right();
}

public List<Attribute> matchFields() {
return matchFields;
}

public List<Attribute> leftFields() {
return leftFields;
}
Expand All @@ -91,29 +86,26 @@ public List<Attribute> rightFields() {
}

public List<Attribute> addedFields() {
if (lazyAddedFields == null) {
AttributeSet set = outputSet();
set.removeAll(left().output());
for (Attribute m : matchFields) {
set.removeIf(a -> a.name().equals(m.name()));
return addedFields;
}

@Override
public List<Attribute> output() {
if (lazyOutput == null) {
lazyOutput = new ArrayList<>(left().output());
for (Attribute attr : addedFields) {
lazyOutput.add(attr);
}
lazyAddedFields = new ArrayList<>(set);
lazyAddedFields.sort(Comparator.comparing(Attribute::name));
}
return lazyAddedFields;
return lazyOutput;
}

@Override
public PhysicalPlan estimateRowSize(State state) {
state.add(false, output);
state.add(false, output());
return this;
}

@Override
public List<Attribute> output() {
return output;
}

@Override
public AttributeSet inputSet() {
// TODO: this is a hack since the right side is always materialized - instead this should
Expand All @@ -129,12 +121,12 @@ protected AttributeSet computeReferences() {

@Override
public LookupJoinExec replaceChildren(PhysicalPlan left, PhysicalPlan right) {
return new LookupJoinExec(source(), left, right, matchFields, leftFields, rightFields, output);
return new LookupJoinExec(source(), left, right, leftFields, rightFields, addedFields);
}

@Override
protected NodeInfo<? extends PhysicalPlan> info() {
return NodeInfo.create(this, LookupJoinExec::new, left(), right(), matchFields, leftFields, rightFields, output);
return NodeInfo.create(this, LookupJoinExec::new, left(), right(), leftFields, rightFields, addedFields);
}

@Override
Expand All @@ -148,15 +140,12 @@ public boolean equals(Object o) {
if (super.equals(o) == false) {
return false;
}
LookupJoinExec hash = (LookupJoinExec) o;
return matchFields.equals(hash.matchFields)
&& leftFields.equals(hash.leftFields)
&& rightFields.equals(hash.rightFields)
&& output.equals(hash.output);
LookupJoinExec other = (LookupJoinExec) o;
return leftFields.equals(other.leftFields) && rightFields.equals(other.rightFields) && addedFields.equals(other.addedFields);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), matchFields, leftFields, rightFields, output);
return Objects.hash(super.hashCode(), leftFields, rightFields, addedFields);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,8 @@ private PhysicalOperation planLookupJoin(LookupJoinExec join, LocalExecutionPlan
if (localSourceExec.indexMode() != IndexMode.LOOKUP) {
throw new IllegalArgumentException("can't plan [" + join + "]");
}
List<Layout.ChannelAndType> matchFields = new ArrayList<>(join.matchFields().size());
for (Attribute m : join.matchFields()) {
List<Layout.ChannelAndType> matchFields = new ArrayList<>(join.leftFields().size());
for (Attribute m : join.leftFields()) {
Layout.ChannelAndType t = source.layout.get(m.id());
if (t == null) {
throw new IllegalArgumentException("can't plan [" + join + "][" + m + "]");
Expand All @@ -604,7 +604,7 @@ private PhysicalOperation planLookupJoin(LookupJoinExec join, LocalExecutionPlan
lookupFromIndexService,
matchFields.get(0).type(),
localSourceExec.index().name(),
join.matchFields().get(0).name(),
join.leftFields().get(0).name(),
join.addedFields().stream().map(f -> (NamedExpression) f).toList(),
join.source()
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,7 @@ private PhysicalPlan mapBinary(BinaryPlan binary) {
);
}
if (right instanceof EsSourceExec source && source.indexMode() == IndexMode.LOOKUP) {
return new LookupJoinExec(
join.source(),
left,
right,
config.matchFields(),
config.leftFields(),
config.rightFields(),
join.output()
);
return new LookupJoinExec(join.source(), left, right, config.leftFields(), config.rightFields(), join.rightOutputFields());
}
}

Expand Down
Loading