diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index dd29068495ba..407681a41830 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -1,5 +1,6 @@ import datetime import decimal +import json from importlib import import_module import sqlparse @@ -575,6 +576,9 @@ def adapt_ipaddressfield_value(self, value): """ return value or None + def adapt_json_value(self, value, encoder): + return json.dumps(value, cls=encoder) + def year_lookup_bounds_for_date_field(self, value, iso_year=False): """ Return a two-elements list with the lower and upper bound to be used diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 2303703ebcc5..62273fc43c0b 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -1,4 +1,8 @@ +import json +from functools import lru_cache, partial + from psycopg2.extras import Inet +from psycopg2.extras import Json as Jsonb from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations @@ -6,6 +10,13 @@ from django.db.models.constants import OnConflict +@lru_cache +def get_json_dumps(encoder): + if encoder is None: + return json.dumps + return partial(json.dumps, cls=encoder) + + class DatabaseOperations(BaseDatabaseOperations): cast_char_field_without_max_length = "varchar" explain_prefix = "EXPLAIN" @@ -308,6 +319,9 @@ def adapt_ipaddressfield_value(self, value): return Inet(value) return None + def adapt_json_value(self, value, encoder): + return Jsonb(value, dumps=get_json_dumps(encoder)) + def subtract_temporals(self, internal_type, lhs, rhs): if internal_type == "DateField": lhs_sql, lhs_params = lhs diff --git a/django/db/models/fields/json.py b/django/db/models/fields/json.py index 22c7e2ad005f..c0242bd7bee6 100644 --- a/django/db/models/fields/json.py +++ b/django/db/models/fields/json.py @@ -6,7 +6,11 @@ from django.db.models import lookups from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import TextField -from django.db.models.lookups import PostgresOperatorLookup, Transform +from django.db.models.lookups import ( + FieldGetDbPrepValueMixin, + PostgresOperatorLookup, + Transform, +) from django.utils.translation import gettext_lazy as _ from . import Field @@ -92,10 +96,15 @@ def from_db_value(self, value, expression, connection): def get_internal_type(self): return "JSONField" - def get_prep_value(self, value): + def get_db_prep_value(self, value, connection, prepared=False): + if hasattr(value, "as_sql"): + return value + return connection.ops.adapt_json_value(value, self.encoder) + + def get_db_prep_save(self, value, connection): if value is None: return value - return json.dumps(value, cls=self.encoder) + return self.get_db_prep_value(value, connection) def get_transform(self, name): transform = super().get_transform(name) @@ -141,7 +150,7 @@ def compile_json_path(key_transforms, include_root=True): return "".join(path) -class DataContains(PostgresOperatorLookup): +class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup): lookup_name = "contains" postgres_operator = "@>" @@ -156,7 +165,7 @@ def as_sql(self, compiler, connection): return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params -class ContainedBy(PostgresOperatorLookup): +class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup): lookup_name = "contained_by" postgres_operator = "<@"