From 2a70058c3db8b1cfc2ec4bfe07ea774162306ac6 Mon Sep 17 00:00:00 2001 From: Ben Stafford Date: Wed, 20 Mar 2024 12:07:05 -0400 Subject: [PATCH] VNDLY-42402: Pull in changes from https://github.com/bernardopires/django-tenant-schemas/pull/567 --- tenant_schemas/__init__.py | 2 +- tenant_schemas/postgresql_backend/base.py | 23 +++++++++------- .../postgresql_backend/introspection.py | 26 +++++++++++++++++++ 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/tenant_schemas/__init__.py b/tenant_schemas/__init__.py index fefb37b8..c2fb8cd9 100644 --- a/tenant_schemas/__init__.py +++ b/tenant_schemas/__init__.py @@ -1,3 +1,3 @@ default_app_config = 'tenant_schemas.apps.TenantSchemaConfig' -__version__ = "v1.9.0-vndly-0.0.4" +__version__ = "v1.9.0-vndly-0.0.5" diff --git a/tenant_schemas/postgresql_backend/base.py b/tenant_schemas/postgresql_backend/base.py index 212d3eae..5a539d8d 100644 --- a/tenant_schemas/postgresql_backend/base.py +++ b/tenant_schemas/postgresql_backend/base.py @@ -1,6 +1,5 @@ import re import warnings -import psycopg2 from django.conf import settings from django.contrib.contenttypes.models import ContentType @@ -10,6 +9,15 @@ from tenant_schemas.utils import get_public_schema_name, get_limit_set_calls from tenant_schemas.postgresql_backend.introspection import DatabaseSchemaIntrospection +try: + from django.db.backends.postgresql.psycopg_any import is_psycopg3 +except ImportError: + is_psycopg3 = False + +if is_psycopg3: + import psycopg +else: + import psycopg2 as psycopg ORIGINAL_BACKEND = getattr(settings, 'ORIGINAL_BACKEND', 'django.db.backends.postgresql_psycopg2') # Django 1.9+ takes care to rename the default backend to 'django.db.backends.postgresql' @@ -142,12 +150,9 @@ def _cursor(self, name=None): search_paths.extend(EXTRA_SEARCH_PATHS) - if name: - # Named cursor can only be used once - cursor_for_search_path = self.connection.cursor() - else: - # Reuse - cursor_for_search_path = cursor + # Named cursor can only be used once, just like psycopg3 cursors. + needs_new_cursor = name or is_psycopg3 + cursor_for_search_path = self.connection.cursor() if needs_new_cursor else cursor # In the event that an error already happened in this transaction and we are going # to rollback we should just ignore database error when setting the search_path @@ -155,12 +160,12 @@ def _cursor(self, name=None): # we do not have to worry that it's not the good one try: cursor_for_search_path.execute('SET search_path = {0}'.format(','.join(search_paths))) - except (django.db.utils.DatabaseError, psycopg2.InternalError): + except (django.db.utils.DatabaseError, psycopg.InternalError): self.search_path_set = False else: self.search_path_set = True - if name: + if needs_new_cursor: cursor_for_search_path.close() return cursor diff --git a/tenant_schemas/postgresql_backend/introspection.py b/tenant_schemas/postgresql_backend/introspection.py index d6fb9105..e5fffee6 100644 --- a/tenant_schemas/postgresql_backend/introspection.py +++ b/tenant_schemas/postgresql_backend/introspection.py @@ -174,6 +174,21 @@ class DatabaseSchemaIntrospection(BaseDatabaseIntrospection): GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions; """ + _get_sequences_query = """ + SELECT s.relname as sequence_name, col.attname + FROM pg_class s + JOIN pg_namespace sn ON sn.oid = s.relnamespace + JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid='pg_class'::regclass + JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass + JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum + JOIN pg_class tbl ON tbl.oid = ad.adrelid + JOIN pg_namespace n ON n.oid = tbl.relnamespace + WHERE s.relkind = 'S' + AND d.deptype in ('a', 'n') + AND n.nspname = %(schema)s + AND tbl.relname = %(table)s + """ + def get_field_type(self, data_type, description): field_type = super(DatabaseSchemaIntrospection, self).get_field_type(data_type, description) if description.default and 'nextval' in description.default: @@ -315,3 +330,14 @@ def get_constraints(self, cursor, table_name): "options": options, } return constraints + + def get_sequences(self, cursor, table_name, table_fields=()): + sequences = [] + cursor.execute(self._get_sequences_query, { + 'schema': self.connection.schema_name, + 'table': table_name, + }) + + for row in cursor.fetchall(): + sequences.append({'name': row[0], 'table': table_name, 'column': row[1]}) + return sequences