Skip to content

Commit

Permalink
Fix #80 -- Address a crash on unrestricted select_related() usage.
Browse files Browse the repository at this point in the history
Thanks @JiriKr for the report.
  • Loading branch information
charettes committed Sep 4, 2024
1 parent 9797c00 commit 3806fae
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 9 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
1.6.3
=====
:release-date: unreleased

- Address a crash on unrestricted ``select_related()`` usage (#81)

1.6.2
=====
:release-date: 2024-08-12
Expand Down
40 changes: 34 additions & 6 deletions seal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,36 @@
from operator import attrgetter

from django.db import models
from django.db.models.query_utils import select_related_descend

cached_value_getter = attrgetter("get_cached_value")


def get_select_related_getters(lookups, opts):
def get_restricted_select_related_getters(lookups, opts):
"""Turn a select_related dict structure into a tree of attribute getters"""
for lookup, nested_lookups in lookups.items():
field = opts.get_field(lookup)
lookup_opts = field.related_model._meta
yield (
cached_value_getter(field),
tuple(get_select_related_getters(nested_lookups, lookup_opts)),
tuple(get_restricted_select_related_getters(nested_lookups, lookup_opts)),
)


def get_unrestricted_select_related_getters(opts, max_depth, cur_depth=1):
if cur_depth > max_depth:
return
for field in opts.fields:
if not select_related_descend(field, False, None, {}):
continue
related_model_meta = field.related_model._meta
yield (
cached_value_getter(field),
tuple(
get_unrestricted_select_related_getters(
related_model_meta, max_depth=max_depth, cur_depth=cur_depth + 1
)
),
)


Expand Down Expand Up @@ -44,12 +62,22 @@ def _sealed_related_iterator(self, related_walker):
yield obj

def __iter__(self):
select_related = self.queryset.query.select_related
query = self.queryset.query
select_related = query.select_related
if select_related:
opts = self.queryset.model._meta
select_related_getters = tuple(
get_select_related_getters(self.queryset.query.select_related, opts)
)
if isinstance(select_related, dict):
select_related_getters = tuple(
get_restricted_select_related_getters(
self.queryset.query.select_related, opts
)
)
else:
select_related_getters = tuple(
get_unrestricted_select_related_getters(
opts, max_depth=query.max_depth
)
)
related_walker = partial(
walk_select_relateds, getters=select_related_getters
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Location(SealableModel):
related_locations = models.ManyToManyField("self")


class Island(models.Model):
class Island(SealableModel):
# Explicitly avoid setting a related_name.
location = models.ForeignKey(Location, on_delete=models.CASCADE)

Expand Down
17 changes: 15 additions & 2 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def setUpTestData(cls):
cls.nickname = Nickname.objects.create(
name="Jonathan Livingston", content_object=cls.gull
)
cls.island = Island.objects.create(location=cls.location)
tests_models = tuple(apps.get_app_config("tests").get_models())
ContentType.objects.get_for_models(*tests_models, for_concrete_models=True)

Expand Down Expand Up @@ -182,6 +183,19 @@ def test_sealed_select_related_reverse_one_to_one(self):
with self.assertRaises(SeaLion.gull.RelatedObjectDoesNotExist):
instance.gull

def test_sealed_select_related_unrestricted(self):
instance = Island.objects.select_related().seal().get()
self.assertEqual(instance.location, self.location)
instance = SeaLion.objects.select_related().seal().get()
message = (
'Attempt to fetch related field "location" on sealed <SeaLion instance>'
)
with self.assertWarnsMessage(UnsealedAttributeAccess, message) as ctx:
# null=True relationships are not followed when using
# an unrestricted select_related()
instance.location
self.assertEqual(ctx.filename, __file__)

def test_sealed_prefetch_related_reverse_one_to_one(self):
instance = SeaLion.objects.prefetch_related("gull").seal().get()
self.assertEqual(instance.gull, self.gull)
Expand Down Expand Up @@ -439,9 +453,8 @@ def test_sealed_prefetched_select_related_many_to_many(self):
self.assertSequenceEqual(instance.location.climates.all(), [self.climate])

def test_prefetch_without_related_name(self):
island = Island.objects.create(location=self.location)
location = Location.objects.prefetch_related("island_set").seal().get()
self.assertSequenceEqual(location.island_set.all(), [island])
self.assertSequenceEqual(location.island_set.all(), [self.island])

def test_prefetch_combine(self):
with self.assertNumQueries(6):
Expand Down

0 comments on commit 3806fae

Please sign in to comment.