diff --git a/engine_test.go b/engine_test.go index 1ba6ef32a..536e02074 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1570,6 +1570,20 @@ var queries = []struct { `SELECT (SELECT i FROM mytable ORDER BY i ASC LIMIT 1) AS x`, []sql.Row{{int64(1)}}, }, + { + `SELECT DISTINCT n FROM bigtable ORDER BY t`, + []sql.Row{ + {int64(1)}, + {int64(9)}, + {int64(7)}, + {int64(3)}, + {int64(2)}, + {int64(8)}, + {int64(6)}, + {int64(5)}, + {int64(4)}, + }, + }, } func TestQueries(t *testing.T) { diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index 0037d9b0d..88283cf9f 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -34,17 +34,19 @@ func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, e a.Log("optimize distinct, node of type: %T", node) if n, ok := node.(*plan.Distinct); ok { - var isSorted bool + var sortField *expression.GetField plan.Inspect(n, func(node sql.Node) bool { a.Log("checking for optimization in node of type: %T", node) - if _, ok := node.(*plan.Sort); ok { - isSorted = true + if sort, ok := node.(*plan.Sort); ok && sortField == nil { + if col, ok := sort.SortFields[0].Column.(*expression.GetField); ok { + sortField = col + } return false } return true }) - if isSorted { + if sortField != nil && n.Schema().Contains(sortField.Name(), sortField.Table()) { a.Log("distinct optimized for ordered output") return plan.NewOrderedDistinct(n.Child), nil } diff --git a/sql/analyzer/optimization_rules_test.go b/sql/analyzer/optimization_rules_test.go index 285117251..60dee3b0f 100644 --- a/sql/analyzer/optimization_rules_test.go +++ b/sql/analyzer/optimization_rules_test.go @@ -186,24 +186,54 @@ func TestEraseProjection(t *testing.T) { } func TestOptimizeDistinct(t *testing.T) { - require := require.New(t) - - t1 := memory.NewTable("foo", nil) - t2 := memory.NewTable("foo", nil) + t1 := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + }) - notSorted := plan.NewDistinct(plan.NewResolvedTable(t1)) - sorted := plan.NewDistinct(plan.NewSort(nil, plan.NewResolvedTable(t2))) + testCases := []struct { + name string + child sql.Node + optimized bool + }{ + { + "without sort", + plan.NewResolvedTable(t1), + false, + }, + { + "sort but column not projected", + plan.NewSort( + []plan.SortField{ + {Column: gf(0, "foo", "c")}, + }, + plan.NewResolvedTable(t1), + ), + false, + }, + { + "sort and column projected", + plan.NewSort( + []plan.SortField{ + {Column: gf(0, "foo", "a")}, + }, + plan.NewResolvedTable(t1), + ), + true, + }, + } rule := getRule("optimize_distinct") - analyzedNotSorted, err := rule.Apply(sql.NewEmptyContext(), nil, notSorted) - require.NoError(err) - - analyzedSorted, err := rule.Apply(sql.NewEmptyContext(), nil, sorted) - require.NoError(err) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + node, err := rule.Apply(sql.NewEmptyContext(), nil, plan.NewDistinct(tt.child)) + require.NoError(t, err) - require.Equal(notSorted, analyzedNotSorted) - require.Equal(plan.NewOrderedDistinct(sorted.Child), analyzedSorted) + _, ok := node.(*plan.OrderedDistinct) + require.Equal(t, tt.optimized, ok) + }) + } } func TestMoveJoinConditionsToFilter(t *testing.T) {