Skip to content

Commit

Permalink
Refs #33374 -- Adjusted full match condition handling.
Browse files Browse the repository at this point in the history
Adjusting WhereNode.as_sql() to raise an exception when encoutering a
full match just like with empty matches ensures that all case are
explicitly handled.
  • Loading branch information
charettes authored and felixxm committed Nov 7, 2022
1 parent 4b702c8 commit 76e3751
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 61 deletions.
6 changes: 6 additions & 0 deletions django/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,12 @@ class EmptyResultSet(Exception):
pass


class FullResultSet(Exception):
"""A database query predicate is matches everything."""

pass


class SynchronousOnlyOperation(Exception):
"""The user tried to call a sync-only function from an async context."""

Expand Down
14 changes: 9 additions & 5 deletions django/db/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from django.core.exceptions import FieldError
from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Col
from django.db.models.sql import compiler

Expand Down Expand Up @@ -40,12 +40,16 @@ def as_sql(self):
"DELETE %s FROM"
% self.quote_name_unless_alias(self.query.get_initial_alias())
]
from_sql, from_params = self.get_from_clause()
from_sql, params = self.get_from_clause()
result.extend(from_sql)
where_sql, where_params = self.compile(where)
if where_sql:
try:
where_sql, where_params = self.compile(where)
except FullResultSet:
pass
else:
result.append("WHERE %s" % where_sql)
return " ".join(result), tuple(from_params) + tuple(where_params)
params.extend(where_params)
return " ".join(result), tuple(params)


class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
Expand Down
9 changes: 6 additions & 3 deletions django/db/models/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError
from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
Expand Down Expand Up @@ -104,8 +104,11 @@ def as_sql(self, compiler, connection, **extra_context):
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
if self.filter:
if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
if filter_sql:
try:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
except FullResultSet:
pass
else:
template = self.filter_template % extra_context.get(
"template", self.template
)
Expand Down
17 changes: 7 additions & 10 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from decimal import Decimal
from uuid import UUID

from django.core.exceptions import EmptyResultSet, FieldError
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP
Expand Down Expand Up @@ -955,6 +955,8 @@ def as_sql(
if empty_result_set_value is NotImplemented:
raise
arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
except FullResultSet:
arg_sql, arg_params = compiler.compile(Value(True))
sql_parts.append(arg_sql)
params.extend(arg_params)
data = {**self.extra, **extra_context}
Expand Down Expand Up @@ -1367,14 +1369,6 @@ def as_sql(self, compiler, connection, template=None, **extra_context):
template_params = extra_context
sql_params = []
condition_sql, condition_params = compiler.compile(self.condition)
# Filters that match everything are handled as empty strings in the
# WHERE clause, but in a CASE WHEN expression they must use a predicate
# that's always True.
if condition_sql == "":
if connection.features.supports_boolean_expr_in_select_clause:
condition_sql, condition_params = compiler.compile(Value(True))
else:
condition_sql, condition_params = "1=1", ()
template_params["condition"] = condition_sql
result_sql, result_params = compiler.compile(self.result)
template_params["result"] = result_sql
Expand Down Expand Up @@ -1461,14 +1455,17 @@ def as_sql(
template_params = {**self.extra, **extra_context}
case_parts = []
sql_params = []
default_sql, default_params = compiler.compile(self.default)
for case in self.cases:
try:
case_sql, case_params = compiler.compile(case)
except EmptyResultSet:
continue
except FullResultSet:
default_sql, default_params = compiler.compile(case.result)
break
case_parts.append(case_sql)
sql_params.extend(case_params)
default_sql, default_params = compiler.compile(self.default)
if not case_parts:
return default_sql, default_params
case_joiner = case_joiner or self.case_joiner
Expand Down
9 changes: 0 additions & 9 deletions django/db/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,15 +1103,6 @@ def formfield(self, **kwargs):
defaults = {"form_class": form_class, "required": False}
return super().formfield(**{**defaults, **kwargs})

def select_format(self, compiler, sql, params):
sql, params = super().select_format(compiler, sql, params)
# Filters that match everything are handled as empty strings in the
# WHERE clause, but in SELECT or GROUP BY list they must use a
# predicate that's always True.
if sql == "":
sql = "1"
return sql, params


class CharField(Field):
description = _("String (up to %(max_length)s)")
Expand Down
37 changes: 25 additions & 12 deletions django/db/models/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from itertools import chain

from django.core.exceptions import EmptyResultSet, FieldError
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import DatabaseError, NotSupportedError
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_group_by(self, select, order_by):
expr = Ref(alias, expr)
try:
sql, params = self.compile(expr)
except EmptyResultSet:
except (EmptyResultSet, FullResultSet):
continue
sql, params = expr.select_format(self, sql, params)
params_hash = make_hashable(params)
Expand Down Expand Up @@ -287,6 +287,8 @@ def get_select_from_parent(klass_info):
sql, params = "0", ()
else:
sql, params = self.compile(Value(empty_result_set_value))
except FullResultSet:
sql, params = self.compile(Value(True))
else:
sql, params = col.select_format(self, sql, params)
if alias is None and with_col_aliases:
Expand Down Expand Up @@ -721,9 +723,16 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
raise
# Use a predicate that's always False.
where, w_params = "0 = 1", []
having, h_params = (
self.compile(self.having) if self.having is not None else ("", [])
)
except FullResultSet:
where, w_params = "", []
try:
having, h_params = (
self.compile(self.having)
if self.having is not None
else ("", [])
)
except FullResultSet:
having, h_params = "", []
result = ["SELECT"]
params = []

Expand Down Expand Up @@ -1817,11 +1826,12 @@ def contains_self_reference_subquery(self):
)

def _as_sql(self, query):
result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
where, params = self.compile(query.where)
if where:
result.append("WHERE %s" % where)
return " ".join(result), tuple(params)
delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)
try:
where, params = self.compile(query.where)
except FullResultSet:
return delete, ()
return f"{delete} WHERE {where}", tuple(params)

def as_sql(self):
"""
Expand Down Expand Up @@ -1906,8 +1916,11 @@ def as_sql(self):
"UPDATE %s SET" % qn(table),
", ".join(values),
]
where, params = self.compile(self.query.where)
if where:
try:
where, params = self.compile(self.query.where)
except FullResultSet:
params = []
else:
result.append("WHERE %s" % where)
return " ".join(result), tuple(update_params + params)

Expand Down
8 changes: 6 additions & 2 deletions django/db/models/sql/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Useful auxiliary data structures for query construction. Not useful outside
the SQL domain.
"""
from django.core.exceptions import FullResultSet
from django.db.models.sql.constants import INNER, LOUTER


Expand Down Expand Up @@ -100,8 +101,11 @@ def as_sql(self, compiler, connection):
join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if self.filtered_relation:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql:
try:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
except FullResultSet:
pass
else:
join_conditions.append("(%s)" % extra_sql)
params.extend(extra_params)
if not join_conditions:
Expand Down
25 changes: 14 additions & 11 deletions django/db/models/sql/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import operator
from functools import reduce

from django.core.exceptions import EmptyResultSet
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db.models.expressions import Case, When
from django.db.models.lookups import Exact
from django.utils import tree
Expand Down Expand Up @@ -145,6 +145,8 @@ def as_sql(self, compiler, connection):
sql, params = compiler.compile(child)
except EmptyResultSet:
empty_needed -= 1
except FullResultSet:
full_needed -= 1
else:
if sql:
result.append(sql)
Expand All @@ -158,24 +160,25 @@ def as_sql(self, compiler, connection):
# counts.
if empty_needed == 0:
if self.negated:
return "", []
raise FullResultSet
else:
raise EmptyResultSet
if full_needed == 0:
if self.negated:
raise EmptyResultSet
else:
return "", []
raise FullResultSet
conn = " %s " % self.connector
sql_string = conn.join(result)
if sql_string:
if self.negated:
# Some backends (Oracle at least) need parentheses
# around the inner SQL in the negated case, even if the
# inner SQL contains just a single expression.
sql_string = "NOT (%s)" % sql_string
elif len(result) > 1 or self.resolved:
sql_string = "(%s)" % sql_string
if not sql_string:
raise FullResultSet
if self.negated:
# Some backends (Oracle at least) need parentheses around the inner
# SQL in the negated case, even if the inner SQL contains just a
# single expression.
sql_string = "NOT (%s)" % sql_string
elif len(result) > 1 or self.resolved:
sql_string = "(%s)" % sql_string
return sql_string, result_params

def get_group_by_cols(self):
Expand Down
11 changes: 11 additions & 0 deletions docs/ref/exceptions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ Django core exception classes are defined in ``django.core.exceptions``.
return any results. Most Django projects won't encounter this exception,
but it might be useful for implementing custom lookups and expressions.

``FullResultSet``
-----------------

.. exception:: FullResultSet

.. versionadded:: 4.2

``FullResultSet`` may be raised during query generation if a query will
match everything. Most Django projects won't encounter this exception, but
it might be useful for implementing custom lookups and expressions.

``FieldDoesNotExist``
---------------------

Expand Down
19 changes: 17 additions & 2 deletions tests/annotations/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@
When,
)
from django.db.models.expressions import RawSQL
from django.db.models.functions import Coalesce, ExtractYear, Floor, Length, Lower, Trim
from django.db.models.functions import (
Cast,
Coalesce,
ExtractYear,
Floor,
Length,
Lower,
Trim,
)
from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import register_lookup

Expand Down Expand Up @@ -282,6 +290,13 @@ def test_full_expression_annotation(self):
self.assertEqual(len(books), Book.objects.count())
self.assertTrue(all(book.selected for book in books))

def test_full_expression_wrapped_annotation(self):
books = Book.objects.annotate(
selected=Coalesce(~Q(pk__in=[]), True),
)
self.assertEqual(len(books), Book.objects.count())
self.assertTrue(all(book.selected for book in books))

def test_full_expression_annotation_with_aggregation(self):
qs = Book.objects.filter(isbn="159059725").annotate(
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
Expand All @@ -292,7 +307,7 @@ def test_full_expression_annotation_with_aggregation(self):
def test_aggregate_over_full_expression_annotation(self):
qs = Book.objects.annotate(
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
).aggregate(Sum("selected"))
).aggregate(selected__sum=Sum(Cast("selected", IntegerField())))
self.assertEqual(qs["selected__sum"], Book.objects.count())

def test_empty_queryset_annotation(self):
Expand Down
Loading

0 comments on commit 76e3751

Please sign in to comment.