From c5946327c8324e28467f71653a4e2f50adfa356d Mon Sep 17 00:00:00 2001 From: qasimgulzar Date: Mon, 18 Nov 2024 14:34:04 +0500 Subject: [PATCH 1/2] fix: make database function compatible with postgresql and mysql both --- openedx_tagging/core/tagging/models/utils.py | 42 ++++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/openedx_tagging/core/tagging/models/utils.py b/openedx_tagging/core/tagging/models/utils.py index 86a5f128..e653df2b 100644 --- a/openedx_tagging/core/tagging/models/utils.py +++ b/openedx_tagging/core/tagging/models/utils.py @@ -3,6 +3,8 @@ """ from django.db.models import Aggregate, CharField from django.db.models.expressions import Func +from django.db import connection + RESERVED_TAG_CHARS = [ '\t', # Used in the database to separate tag levels in the "lineage" field @@ -34,21 +36,47 @@ def as_sqlite(self, compiler, connection, **extra_context): ) -class StringAgg(Aggregate): # pylint: disable=abstract-method +class StringAgg(Aggregate): """ Aggregate function that collects the values of some column across all rows, - and creates a string by concatenating those values, with "," as a separator. + and creates a string by concatenating those values, with a specified separator. - This is the same as Django's django.contrib.postgres.aggregates.StringAgg, - but this version works with MySQL and SQLite. + This version supports PostgreSQL (STRING_AGG), MySQL (GROUP_CONCAT), and SQLite. """ + # Default function is for MySQL (GROUP_CONCAT) function = 'GROUP_CONCAT' - template = '%(function)s(%(distinct)s%(expressions)s)' + template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)' + + def __init__(self, expression, distinct=False, delimiter=',', **extra): + + self.delimiter=delimiter + # Handle the distinct option and output type + distinct_str = 'DISTINCT ' if distinct else '' - def __init__(self, expression, distinct=False, **extra): + # Check the database backend (PostgreSQL, MySQL, or SQLite) + if 'postgresql' in connection.vendor.lower(): + self.function = 'STRING_AGG' + self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)' + elif 'mysql' in connection.vendor.lower() or 'sqlite' in connection.vendor.lower(): + self.function = 'GROUP_CONCAT' + self.template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)' + + # Initialize the parent class with the necessary parameters super().__init__( expression, - distinct='DISTINCT ' if distinct else '', + distinct=distinct_str, + delimiter=delimiter, output_field=CharField(), **extra, ) + + def as_sql(self, compiler, connection, **extra_context): + # If PostgreSQL, we use STRING_AGG with a separator + if 'postgresql' in connection.vendor.lower(): + # Ensure that expressions are cast to TEXT for PostgreSQL + expressions_sql, params = compiler.compile(self.source_expressions[0]) + expressions_sql = f"({expressions_sql})::TEXT" # Cast to TEXT for PostgreSQL + return f"{self.function}({expressions_sql}, {self.delimiter!r})", params + else: + # MySQL/SQLite handles GROUP_CONCAT with SEPARATOR + return super().as_sql(compiler, connection, **extra_context) From bc7a493f7ad50a14875648c42f23427580e783a5 Mon Sep 17 00:00:00 2001 From: qasimgulzar Date: Thu, 12 Dec 2024 16:09:27 +0500 Subject: [PATCH 2/2] fix: tests and quality issues --- openedx_tagging/core/tagging/models/utils.py | 37 +++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/openedx_tagging/core/tagging/models/utils.py b/openedx_tagging/core/tagging/models/utils.py index e653df2b..c9539299 100644 --- a/openedx_tagging/core/tagging/models/utils.py +++ b/openedx_tagging/core/tagging/models/utils.py @@ -3,8 +3,7 @@ """ from django.db.models import Aggregate, CharField from django.db.models.expressions import Func -from django.db import connection - +from django.db import connection as db_connection RESERVED_TAG_CHARS = [ '\t', # Used in the database to separate tag levels in the "lineage" field @@ -36,7 +35,10 @@ def as_sqlite(self, compiler, connection, **extra_context): ) -class StringAgg(Aggregate): +from django.db.models import Aggregate, CharField +from django.db.models.expressions import Combinable + +class StringAgg(Aggregate, Combinable): """ Aggregate function that collects the values of some column across all rows, and creates a string by concatenating those values, with a specified separator. @@ -45,28 +47,27 @@ class StringAgg(Aggregate): """ # Default function is for MySQL (GROUP_CONCAT) function = 'GROUP_CONCAT' - template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)' + template = '%(function)s(%(distinct)s%(expressions)s)' def __init__(self, expression, distinct=False, delimiter=',', **extra): - - self.delimiter=delimiter + self.delimiter = delimiter # Handle the distinct option and output type distinct_str = 'DISTINCT ' if distinct else '' + extra.update(dict( + distinct=distinct_str, + output_field=CharField() + )) + # Check the database backend (PostgreSQL, MySQL, or SQLite) - if 'postgresql' in connection.vendor.lower(): + if 'postgresql' in db_connection.vendor.lower(): self.function = 'STRING_AGG' self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)' - elif 'mysql' in connection.vendor.lower() or 'sqlite' in connection.vendor.lower(): - self.function = 'GROUP_CONCAT' - self.template = '%(function)s(%(distinct)s%(expressions)s SEPARATOR %(delimiter)s)' + extra.update({"delimiter": delimiter}) # Initialize the parent class with the necessary parameters super().__init__( expression, - distinct=distinct_str, - delimiter=delimiter, - output_field=CharField(), **extra, ) @@ -80,3 +81,13 @@ def as_sql(self, compiler, connection, **extra_context): else: # MySQL/SQLite handles GROUP_CONCAT with SEPARATOR return super().as_sql(compiler, connection, **extra_context) + + # Implementing abstract methods from Combinable + def __rand__(self, other): + return self._combine(other, 'AND', is_combinable=True) + + def __ror__(self, other): + return self._combine(other, 'OR', is_combinable=True) + + def __rxor__(self, other): + return self._combine(other, 'XOR', is_combinable=True)