Move iterator code to a single module
For the Django 4 upgrade fix, these objects get more dependencies,
so it's better to have them in a separate module. This also keeps all
iterator magic in a single place.
vdboor committed Jul 18, 2024
1 parent ae01f55 commit aad92f3
Showing 9 changed files with 430 additions and 408 deletions.
169 changes: 3 additions & 166 deletions src/rest_framework_dso/embedding.py
@@ -8,23 +8,20 @@
from __future__ import annotations

import logging
-from collections import defaultdict
-from collections.abc import Callable, Iterable, Iterator
+from collections.abc import Iterator
from copy import copy
from dataclasses import dataclass
from functools import cached_property
-from itertools import islice
-from typing import TypeVar

from django.db import models
from django.db.models import ForeignObjectRel
from drf_spectacular.drainage import get_override
-from lru import LRU
from rest_framework import serializers
from rest_framework.exceptions import ParseError

from rest_framework_dso.fields import AbstractEmbeddedField
-from rest_framework_dso.serializer_helpers import ReturnGenerator, peek_iterable
+from rest_framework_dso.iterators import ChunkedQuerySetIterator
+from rest_framework_dso.serializer_helpers import ReturnGenerator
from rest_framework_dso.utils import (
    DictOfDicts,
    get_serializer_relation_lookups,
@@ -33,10 +30,6 @@

logger = logging.getLogger(__name__)

-T = TypeVar("T")
-M = TypeVar("M", bound=models.Model)
-
-DEFAULT_SQL_CHUNK_SIZE = 2000  # allow unit tests to alter this.
MAX_EXPAND_ALL_DEPTH = 2


@@ -333,162 +326,6 @@ def get_embedded_field(
    return _real_get_embedded_field(field_name, prefix=prefix)


class ChunkedQuerySetIterator(Iterable[M]):
    """An optimal strategy to perform ``prefetch_related()`` on large datasets.

    It fetches data from the queryset in chunks,
    and performs ``prefetch_related()`` behavior on each chunk.

    Django's ``QuerySet.prefetch_related()`` works by loading the whole queryset into memory,
    and performing an analysis of the related objects to fetch. When working on large datasets,
    this is very inefficient as more memory is consumed. Instead, ``QuerySet.iterator()``
    is preferred here as it returns instances while reading them. Nothing is stored in memory.
    Hence, both approaches are fundamentally incompatible. This class performs a
    mixed strategy: load a chunk, and perform prefetches for that particular batch.

    As an extra performance benefit, a local cache avoids prefetching the same records
    again when the next chunk is analysed. It has a "least recently used" cache to avoid
    flooding the caches when foreign keys constantly point to different unique objects.
    """

    def __init__(self, queryset: models.QuerySet, chunk_size=None, sql_chunk_size=None):
        """
        :param queryset: The queryset to iterate over, that has ``prefetch_related()`` data.
        :param chunk_size: The size of each segment to analyse in-memory for related objects.
        :param sql_chunk_size: The size of each segment to fetch from the database,
            used when server-side cursors are not available. The default follows Django behavior.
        """
        self.queryset = queryset
        self.sql_chunk_size = sql_chunk_size or DEFAULT_SQL_CHUNK_SIZE
        self.chunk_size = chunk_size or self.sql_chunk_size
        self._fk_caches = defaultdict(lambda: LRU(self.chunk_size // 2))

    def __iter__(self):
        # Using iter() ensures the ModelIterable is resumed with the next chunk.
        qs_iter = iter(self.queryset.iterator(chunk_size=self.sql_chunk_size))
        chunk_id = 0

        # Keep fetching chunks
        while instances := list(islice(qs_iter, self.chunk_size)):
            # Perform prefetches on this chunk:
            if self.queryset._prefetch_related_lookups:
                self._add_prefetches(instances, chunk_id)
            chunk_id += 1

            yield from instances

    def _add_prefetches(self, instances: list[M], chunk_id):
        """Merge the prefetched objects for this batch with the model instances."""
        if self._fk_caches:
            # Make sure prefetch_related_objects() doesn't have to fetch items again
            # that infrequently change (e.g. a "wijk" or "stadsdeel").
            all_restored = self._restore_caches(instances)
            if all_restored:
                logger.debug("[chunk %d] No additional prefetches needed.", chunk_id)
                return

        logger.debug("[chunk %d] Prefetching related objects...", chunk_id)

        # Reuse the Django machinery for retrieving missing sub-objects,
        # and analyse the ForeignKey caches to allow faster prefetches next time.
        models.prefetch_related_objects(instances, *self.queryset._prefetch_related_lookups)
        self._persist_prefetch_cache(instances)
        logger.debug("[chunk %d] ...done prefetching related objects.", chunk_id)

    def _persist_prefetch_cache(self, instances):
        """Store the prefetched data so it can be applied to the next batch"""
        for instance in instances:
            for lookup, obj in instance._state.fields_cache.items():
                if obj is not None:
                    cache = self._fk_caches[lookup]
                    cache[obj.pk] = obj

    def _restore_caches(self, instances) -> bool:
        """Restore prefetched data to the new set of instances.

        This avoids unneeded prefetching of the same ForeignKey relation.
        """
        if not instances:
            return True
        if not self._fk_caches:
            return False

        all_restored = True

        for lookup, cache in self._fk_caches.items():
            field = instances[0]._meta.get_field(lookup)
            for instance in instances:
                id_value = getattr(instance, field.attname)
                if id_value is None:
                    continue

                if (obj := cache.get(id_value, None)) is not None:
                    instance._state.fields_cache[lookup] = obj
                else:
                    all_restored = False

        if all_restored:
            logger.debug("All prefetches restored from cache")

        return all_restored
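
A minimal usage sketch for the class above, assuming a hypothetical Movie model with a category foreign key. As the docstring explains, this iterates the queryset in constant memory while the ``prefetch_related()`` lookups are still resolved per chunk:

# Hypothetical Movie/Category models; ChunkedQuerySetIterator is a drop-in
# wherever the queryset would otherwise be iterated directly.
queryset = Movie.objects.prefetch_related("category")

for movie in ChunkedQuerySetIterator(queryset, chunk_size=500):
    # Related objects come from the per-chunk prefetch, or from the
    # internal LRU cache when an earlier chunk already fetched them.
    print(movie.title, movie.category.name)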


class ObservableIterator(Iterator[T]):
    """Observe the objects that are being returned.

    Unlike itertools.tee(), retrieved objects are directly processed by other functions.
    As a built-in feature, the number of returned objects is also counted.
    """

    def __init__(self, iterable: Iterable[T], observers=None):
        self.number_returned = 0
        self._iterable = iter(iterable)
        self._item_callbacks = list(observers) if observers else []
        self._has_items = None
        self._is_iterated = False

    def add_observer(self, callback: Callable[[T], None]):
        """Install an observer callback that is notified when items are iterated"""
        self._item_callbacks.append(callback)

    def clear_observers(self):
        """Remove all observers"""
        self._item_callbacks = []

    def __iter__(self) -> ObservableIterator[T]:
        return self

    def __next__(self) -> T:
        """Keep a count of the returned items, and allow notifying other generators"""
        try:
            value = next(self._iterable)
        except StopIteration:
            self._is_iterated = True
            raise

        self.number_returned += 1
        self._has_items = True

        # Notify observers
        for notify_callback in self._item_callbacks:
            notify_callback(value)

        return value

    def is_iterated(self):
        """Tell whether the iterator has finished."""
        return self._is_iterated

    def __bool__(self):
        """Tell whether the generator would contain items."""
        if self._has_items is None:
            # Perform an inspection of the generator:
            first_item, items = peek_iterable(self._iterable)
            self._iterable = items
            self._has_items = first_item is not None

        return self._has_items
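
The observer mechanism can be sketched in isolation with plain Python data; the values and names below are illustrative only:

# Record and count items as a consumer reads them, without tee()-style buffering.
seen = []
observable = ObservableIterator(range(3), observers=[seen.append])

if observable:  # __bool__ peeks at the first item without consuming it
    total = sum(1 for _ in observable)  # drains the iterator, total == 3

assert seen == [0, 1, 2]
assert observable.number_returned == 3
assert observable.is_iterated()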


class EmbeddedResultSet(ReturnGenerator):
    """A wrapper for the returned expanded fields.

    This is used in combination with the ObservableIterator.