Skip to content

Commit

Permalink
fix: Use except distinct and intersect distinct (#1094)
Browse files Browse the repository at this point in the history
Co-authored-by: Lingqing Gan <[email protected]>
  • Loading branch information
aholyoke and Linchin authored Jul 18, 2024
1 parent 9e0b117 commit 80781ef
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sqlalchemy_bigquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ class BigQueryCompiler(_struct.SQLCompiler, vendored_postgresql.PGCompiler):
compound_keywords = SQLCompiler.compound_keywords.copy()
compound_keywords[selectable.CompoundSelect.UNION] = "UNION DISTINCT"
compound_keywords[selectable.CompoundSelect.UNION_ALL] = "UNION ALL"
compound_keywords[selectable.CompoundSelect.EXCEPT] = "EXCEPT DISTINCT"
compound_keywords[selectable.CompoundSelect.INTERSECT] = "INTERSECT DISTINCT"

def __init__(self, dialect, statement, *args, **kwargs):
if isinstance(statement, Column):
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,94 @@ def test_typed_parameters(faux_conn, type_, val, btype, vrep):
)


def test_except(faux_conn):
table = setup_table(
faux_conn,
"table",
sqlalchemy.Column("id", sqlalchemy.Integer),
sqlalchemy.Column("foo", sqlalchemy.Integer),
)

s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2)
s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4)

s3 = s1.except_(s2)

result = s3.compile(faux_conn).string

expected = (
"SELECT `table`.`foo` \n"
"FROM `table` \n"
"WHERE `table`.`id` >= %(id_1:INT64)s EXCEPT DISTINCT SELECT `table`.`foo` \n"
"FROM `table` \n"
"WHERE `table`.`id` >= %(id_2:INT64)s"
)
assert result == expected


def test_intersect(faux_conn):
table = setup_table(
faux_conn,
"table",
sqlalchemy.Column("id", sqlalchemy.Integer),
sqlalchemy.Column("foo", sqlalchemy.Integer),
)

s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2)
s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4)

s3 = s1.intersect(s2)

result = s3.compile(faux_conn).string

expected = (
"SELECT `table`.`foo` \n"
"FROM `table` \n"
"WHERE `table`.`id` >= %(id_1:INT64)s INTERSECT DISTINCT SELECT `table`.`foo` \n"
"FROM `table` \n"
"WHERE `table`.`id` >= %(id_2:INT64)s"
)
assert result == expected


def test_union(faux_conn):
table = setup_table(
faux_conn,
"table",
sqlalchemy.Column("id", sqlalchemy.Integer),
sqlalchemy.Column("foo", sqlalchemy.Integer),
)

s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2)
s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4)

s3 = s1.union(s2)

result = s3.compile(faux_conn).string

expected = (
"SELECT `table`.`foo` \n"
"FROM `table` \n"
"WHERE `table`.`id` >= %(id_1:INT64)s UNION DISTINCT SELECT `table`.`foo` \n"
"FROM `table` \n"
"WHERE `table`.`id` >= %(id_2:INT64)s"
)
assert result == expected

s4 = s1.union_all(s2)

result = s4.compile(faux_conn).string

expected = (
"SELECT `table`.`foo` \n"
"FROM `table` \n"
"WHERE `table`.`id` >= %(id_1:INT64)s UNION ALL SELECT `table`.`foo` \n"
"FROM `table` \n"
"WHERE `table`.`id` >= %(id_2:INT64)s"
)
assert result == expected


def test_select_struct(faux_conn, metadata):
from sqlalchemy_bigquery import STRUCT

Expand Down

0 comments on commit 80781ef

Please sign in to comment.