From a9d2d8d1c36a4338758a792c475965180629a59f Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Wed, 9 Nov 2022 21:55:47 -0500 Subject: [PATCH] Refs #28477 -- Reduced complexity of aggregation over qualify queries. --- django/db/models/sql/query.py | 34 ++++++++++++++++++------------- tests/expressions_window/tests.py | 22 ++++++++++++-------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c9e296001231..775e2668b05b 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -447,13 +447,15 @@ def get_aggregation(self, using, added_aggregate_names): if alias not in added_aggregate_names } # Existing usage of aggregation can be determined by the presence of - # selected aggregate and window annotations but also by filters against - # aliased aggregate and windows via HAVING / QUALIFY. - has_existing_aggregation = any( - getattr(annotation, "contains_aggregate", True) - or getattr(annotation, "contains_over_clause", True) - for annotation in existing_annotations.values() - ) or any(self.where.split_having_qualify()[1:]) + # selected aggregates but also by filters against aliased aggregates. + _, having, qualify = self.where.split_having_qualify() + has_existing_aggregation = ( + any( + getattr(annotation, "contains_aggregate", True) + for annotation in existing_annotations.values() + ) + or having + ) # Decide if we need to use a subquery. # # Existing aggregations would cause incorrect results as @@ -468,6 +470,7 @@ def get_aggregation(self, using, added_aggregate_names): isinstance(self.group_by, tuple) or self.is_sliced or has_existing_aggregation + or qualify or self.distinct or self.combinator ): @@ -494,13 +497,16 @@ def get_aggregation(self, using, added_aggregate_names): self.model._meta.pk.get_col(inner_query.get_initial_alias()), ) inner_query.default_cols = False - # Mask existing annotations that are not referenced by - # aggregates to be pushed to the outer query. - annotation_mask = set() - for name in added_aggregate_names: - annotation_mask.add(name) - annotation_mask |= inner_query.annotations[name].get_refs() - inner_query.set_annotation_mask(annotation_mask) + if not qualify: + # Mask existing annotations that are not referenced by + # aggregates to be pushed to the outer query unless + # filtering against window functions is involved as it + # requires complex realising. + annotation_mask = set() + for name in added_aggregate_names: + annotation_mask.add(name) + annotation_mask |= inner_query.annotations[name].get_refs() + inner_query.set_annotation_mask(annotation_mask) relabels = {t: "subquery" for t in inner_query.alias_map} relabels[None] = "subquery" diff --git a/tests/expressions_window/tests.py b/tests/expressions_window/tests.py index cac611490494..027fc9c25c90 100644 --- a/tests/expressions_window/tests.py +++ b/tests/expressions_window/tests.py @@ -42,6 +42,7 @@ ) from django.db.models.lookups import Exact from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature +from django.test.utils import CaptureQueriesContext from .models import Classification, Detail, Employee, PastEmployeeDepartment @@ -1157,16 +1158,21 @@ def test_limited_filter(self): ) def test_filter_count(self): - self.assertEqual( - Employee.objects.annotate( - department_salary_rank=Window( - Rank(), partition_by="department", order_by="-salary" + with CaptureQueriesContext(connection) as ctx: + self.assertEqual( + Employee.objects.annotate( + department_salary_rank=Window( + Rank(), partition_by="department", order_by="-salary" + ) ) + .filter(department_salary_rank=1) + .count(), + 5, ) - .filter(department_salary_rank=1) - .count(), - 5, - ) + self.assertEqual(len(ctx.captured_queries), 1) + sql = ctx.captured_queries[0]["sql"].lower() + self.assertEqual(sql.count("select"), 3) + self.assertNotIn("group by", sql) @skipUnlessDBFeature("supports_frame_range_fixed_distance") def test_range_n_preceding_and_following(self):