Skip to content

Commit

Permalink
Fixed #31679 -- Delayed annotating aggregations.
Browse files Browse the repository at this point in the history
By deferring the annotation of aggregations that may be pushed to an
outer query until their references are resolved, it becomes possible to
aggregate over a query that already has an annotation with the same alias.

Even if #34176 is a convoluted case to support, this refactor seems
worth it given the reduction in complexity it brings with regard to
annotation removal when performing a subquery pushdown.
  • Loading branch information
charettes authored and felixxm committed Nov 23, 2022
1 parent d526d15 commit 1297c0d
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 65 deletions.
22 changes: 2 additions & 20 deletions django/db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector
from django.db.models.expressions import Case, F, Ref, Value, When
from django.db.models.expressions import Case, F, Value, When
from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
Expand Down Expand Up @@ -589,24 +589,7 @@ def aggregate(self, *args, **kwargs):
raise TypeError("Complex aggregates require an alias")
kwargs[arg.default_alias] = arg

query = self.query.chain()
for (alias, aggregate_expr) in kwargs.items():
query.add_annotation(aggregate_expr, alias, is_summary=True)
annotation = query.annotations[alias]
if not annotation.contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
for expr in annotation.get_source_expressions():
if (
expr.contains_aggregate
and isinstance(expr, Ref)
and expr.refs in kwargs
):
name = expr.refs
raise exceptions.FieldError(
"Cannot compute %s('%s'): '%s' is an aggregate"
% (annotation.name, name, name)
)
return query.get_aggregation(self.db, kwargs)
return self.query.chain().get_aggregation(self.db, kwargs)

async def aaggregate(self, *args, **kwargs):
return await sync_to_async(self.aggregate)(*args, **kwargs)
Expand Down Expand Up @@ -1655,7 +1638,6 @@ def _annotate(self, args, kwargs, select=True):
clone.query.add_annotation(
annotation,
alias,
is_summary=False,
select=select,
)
for alias, annotation in clone.query.annotations.items():
Expand Down
75 changes: 34 additions & 41 deletions django/db/models/sql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,24 +381,28 @@ def _get_col(self, target, field, alias):
alias = None
return target.get_col(alias, field)

def get_aggregation(self, using, added_aggregate_names):
def get_aggregation(self, using, aggregate_exprs):
"""
Return the dictionary with the values of the existing aggregations.
"""
if not self.annotation_select:
if not aggregate_exprs:
return {}
existing_annotations = {
alias: annotation
for alias, annotation in self.annotations.items()
if alias not in added_aggregate_names
}
aggregates = {}
for alias, aggregate_expr in aggregate_exprs.items():
self.check_alias(alias)
aggregate = aggregate_expr.resolve_expression(
self, allow_joins=True, reuse=None, summarize=True
)
if not aggregate.contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
aggregates[alias] = aggregate
# Existing usage of aggregation can be determined by the presence of
# selected aggregates but also by filters against aliased aggregates.
_, having, qualify = self.where.split_having_qualify()
has_existing_aggregation = (
any(
getattr(annotation, "contains_aggregate", True)
for annotation in existing_annotations.values()
for annotation in self.annotations.values()
)
or having
)
Expand Down Expand Up @@ -449,25 +453,19 @@ def get_aggregation(self, using, added_aggregate_names):
# filtering against window functions is involved as it
# requires complex realising.
annotation_mask = set()
for name in added_aggregate_names:
annotation_mask.add(name)
annotation_mask |= inner_query.annotations[name].get_refs()
for aggregate in aggregates.values():
annotation_mask |= aggregate.get_refs()
inner_query.set_annotation_mask(annotation_mask)

# Remove any aggregates marked for reduction from the subquery and
# move them to the outer AggregateQuery. This requires making sure
# all columns referenced by the aggregates are selected in the
# subquery. It is achieved by retrieving all column references from
# the aggregates, explicitly selecting them if they are not
# already, and making sure the aggregates are repointed to
# referenced to them.
# Add aggregates to the outer AggregateQuery. This requires making
# sure all columns referenced by the aggregates are selected in the
# inner query. It is achieved by retrieving all column references
# by the aggregates, explicitly selecting them in the inner query,
# and making sure the aggregates are repointed to them.
col_refs = {}
for alias, expression in list(inner_query.annotation_select.items()):
if not expression.is_summary:
continue
annotation_select_mask = inner_query.annotation_select_mask
for alias, aggregate in aggregates.items():
replacements = {}
for col in self._gen_cols([expression], resolve_refs=False):
for col in self._gen_cols([aggregate], resolve_refs=False):
if not (col_ref := col_refs.get(col)):
index = len(col_refs) + 1
col_alias = f"__col{index}"
Expand All @@ -476,13 +474,9 @@ def get_aggregation(self, using, added_aggregate_names):
inner_query.annotations[col_alias] = col
inner_query.append_annotation_mask([col_alias])
replacements[col] = col_ref
outer_query.annotations[alias] = expression.replace_expressions(
outer_query.annotations[alias] = aggregate.replace_expressions(
replacements
)
del inner_query.annotations[alias]
annotation_select_mask.remove(alias)
# Make sure the annotation_select wont use cached results.
inner_query.set_annotation_mask(inner_query.annotation_select_mask)
if (
inner_query.select == ()
and not inner_query.default_cols
Expand All @@ -499,19 +493,21 @@ def get_aggregation(self, using, added_aggregate_names):
self.select = ()
self.default_cols = False
self.extra = {}
if existing_annotations:
if self.annotations:
# Inline reference to existing annotations and mask them as
# they are unnecessary given only the summarized aggregations
# are requested.
replacements = {
Ref(alias, annotation): annotation
for alias, annotation in existing_annotations.items()
for alias, annotation in self.annotations.items()
}
for name in added_aggregate_names:
self.annotations[name] = self.annotations[name].replace_expressions(
replacements
)
self.set_annotation_mask(added_aggregate_names)
self.annotations = {
alias: aggregate.replace_expressions(replacements)
for alias, aggregate in aggregates.items()
}
else:
self.annotations = aggregates
self.set_annotation_mask(aggregates)

empty_set_result = [
expression.empty_result_set_value
Expand All @@ -537,8 +533,7 @@ def get_count(self, using):
Perform a COUNT() query using the current filter constraints.
"""
obj = self.clone()
obj.add_annotation(Count("*"), alias="__count", is_summary=True)
return obj.get_aggregation(using, ["__count"])["__count"]
return obj.get_aggregation(using, {"__count": Count("*")})["__count"]

def has_filters(self):
return self.where
Expand Down Expand Up @@ -1085,12 +1080,10 @@ def check_alias(self, alias):
"semicolons, or SQL comments."
)

def add_annotation(self, annotation, alias, is_summary=False, select=True):
def add_annotation(self, annotation, alias, select=True):
"""Add a single annotation expression to the Query."""
self.check_alias(alias)
annotation = annotation.resolve_expression(
self, allow_joins=True, reuse=None, summarize=is_summary
)
annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
if select:
self.append_annotation_mask([alias])
else:
Expand Down
3 changes: 3 additions & 0 deletions docs/releases/4.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ Miscellaneous
* The undocumented ``negated`` parameter of the
:class:`~django.db.models.Exists` expression is removed.

* The ``is_summary`` argument of the undocumented ``Query.add_annotation()``
method is removed.

.. _deprecated-features-4.2:

Features deprecated in 4.2
Expand Down
16 changes: 12 additions & 4 deletions tests/aggregation/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,11 +1258,11 @@ def test_annotate_over_annotate(self):
self.assertEqual(author.sum_age, other_author.sum_age)

def test_aggregate_over_aggregate(self):
msg = "Cannot compute Avg('age'): 'age' is an aggregate"
msg = "Cannot resolve keyword 'age_agg' into field."
with self.assertRaisesMessage(FieldError, msg):
Author.objects.annotate(age_alias=F("age"),).aggregate(
age=Sum(F("age")),
avg_age=Avg(F("age")),
Author.objects.aggregate(
age_agg=Sum(F("age")),
avg_age=Avg(F("age_agg")),
)

def test_annotated_aggregate_over_annotated_aggregate(self):
Expand Down Expand Up @@ -2086,6 +2086,14 @@ def test_exists_extra_where_with_aggregate(self):
)
self.assertEqual(len(qs), 6)

def test_aggregation_over_annotation_shared_alias(self):
self.assertEqual(
Publisher.objects.annotate(agg=Count("book__authors"),).aggregate(
agg=Count("agg"),
),
{"agg": 5},
)


class AggregateAnnotationPruningTests(TestCase):
@classmethod
Expand Down

0 comments on commit 1297c0d

Please sign in to comment.