Skip to content

Commit

Permalink
Support ''Specs', Contains', 'In', and 'NotIn' queries on PostgreSQL …
Browse files Browse the repository at this point in the history
…backend (#622)

* Base support of "Contains" operator

* In ⎆ & NotIn 🚷 support (similar to eq, contains)

* In works, NotIn does not

* Trouble with test comparing the wrong list

* Weird logical differentes between pgsql and sqlte❔

* Primitive specs support, missing test data

* specs postgres

* Some absolutley mental stuff going on with SQLite 🧠

* Specs 🕶️ in SQLite

* Clean up 🧹 development artifacts
  • Loading branch information
Kezzsim authored Jan 2, 2024
1 parent 3adcc9c commit e7c370d
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 70 deletions.
88 changes: 34 additions & 54 deletions tiled/_tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
keys = list(string.ascii_lowercase)
mapping = {
letter: ArrayAdapter.from_array(
number * numpy.ones(10), metadata={"letter": letter, "number": number}
number * numpy.ones(10),
metadata={"letter": letter, "number": number},
specs=[letter],
)
for letter, number in zip(keys, range(26))
}
Expand Down Expand Up @@ -75,7 +77,9 @@ async def client(request, tmpdir_module):
with Context.from_app(app) as context:
client = from_context(context)
for k, v in mapping.items():
client.write_array(v.read(), key=k, metadata=dict(v.metadata()))
client.write_array(
v.read(), key=k, metadata=dict(v.metadata()), specs=v.specs
)
yield client
elif request.param == "postgresql":
if not TILED_TEST_POSTGRESQL_URI:
Expand Down Expand Up @@ -104,7 +108,12 @@ async def client(request, tmpdir_module):
client = from_context(context)
# Write data into catalog.
for k, v in mapping.items():
client.write_array(v.read(), key=k, metadata=dict(v.metadata()))
client.write_array(
v.read(),
key=k,
metadata=dict(v.metadata()),
specs=v.specs,
)
yield client
else:
assert False
Expand Down Expand Up @@ -146,15 +155,7 @@ def test_comparison(client):


def test_contains(client):
if client.metadata["backend"] == "postgresql":

def cm():
return fail_with_status_code(400)

else:
cm = nullcontext
with cm():
assert list(client.search(Contains("letters", "z"))) == ["does_contain_z"]
assert list(client.search(Contains("letters", "z"))) == ["does_contain_z"]


def test_full_text(client):
Expand Down Expand Up @@ -215,19 +216,11 @@ def test_not_and_and_or(client):
],
)
def test_in(client, query_values):
if client.metadata["backend"] == "postgresql":

def cm():
return fail_with_status_code(400)

else:
cm = nullcontext
with cm():
assert sorted(list(client.search(In("letter", query_values)))) == [
"a",
"k",
"z",
]
assert sorted(list(client.search(In("letter", query_values)))) == [
"a",
"k",
"z",
]


@pytest.mark.parametrize(
Expand All @@ -240,17 +233,20 @@ def cm():
],
)
def test_notin(client, query_values):
if client.metadata["backend"] == "postgresql":

def cm():
return fail_with_status_code(400)

else:
cm = nullcontext
with cm():
assert sorted(list(client.search(NotIn("letter", query_values)))) == sorted(
list(set(keys) - set(["a", "k", "z"]))
# TODO: Postgres and SQlite ACTUALLY treat this query differently in external testing.
# SQLite WILL NOT include fields that do not have the key, which is correct.
# Postgres WILL include fields that do not have the key,
# because by extension they do not have the value. Also correct. Why?
assert sorted(list(client.search(NotIn("letter", query_values)))) == sorted(
list(
set(
list(mapping.keys())
if client.metadata["backend"] == "postgresql"
else keys
)
- set(["a", "k", "z"])
)
)


@pytest.mark.parametrize(
Expand All @@ -263,25 +259,9 @@ def cm():
],
)
def test_specs(client, include_values, exclude_values):
if client.metadata["backend"] in {"postgresql", "sqlite"}:

def cm():
return fail_with_status_code(400)

else:
cm = nullcontext
with pytest.raises(TypeError):
SpecsQuery("foo")

with cm():
assert sorted(
list(client.search(SpecsQuery(include=include_values)))
) == sorted(["specs_foo_bar", "specs_foo_bar_baz"])

with cm():
assert list(
client.search(SpecsQuery(include=include_values, exclude=exclude_values))
) == ["specs_foo_bar"]
assert list(
client.search(SpecsQuery(include=include_values, exclude=exclude_values))
) == ["specs_foo_bar"]


def test_structure_families(client):
Expand Down
83 changes: 67 additions & 16 deletions tiled/catalog/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import anyio
import httpx
from fastapi import HTTPException
from sqlalchemy import delete, event, func, select, text, type_coerce, update
from sqlalchemy import delete, event, func, not_, or_, select, text, type_coerce, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import create_async_engine

Expand All @@ -26,6 +26,7 @@
NotEq,
NotIn,
Operator,
SpecsQuery,
StructureFamilyQuery,
)

Expand Down Expand Up @@ -966,30 +967,58 @@ def comparison(query, tree):


def contains(query, tree):
dialect_name = tree.engine.url.get_dialect().name
attr = orm.Node.metadata_[query.key.split(".")]
if dialect_name == "sqlite":
condition = _get_value(attr, type(query.value)).contains(query.value)
else:
raise UnsupportedQueryType("Contains")
condition = _get_value(attr, type(query.value)).contains(query.value)
return tree.new_variation(conditions=tree.conditions + [condition])


def specs(query, tree):
raise UnsupportedQueryType("Specs")
# conditions = []
# for spec in query.include:
# conditions.append(func.json_contains(orm.Node.specs, spec))
# for spec in query.exclude:
# conditions.append(not_(func.json_contains(orm.Node.specs.contains, spec)))
# return tree.new_variation(conditions=tree.conditions + conditions)
dialect_name = tree.engine.url.get_dialect().name
conditions = []
attr = orm.Node.specs
if dialect_name == "sqlite":
# Construct the conditions for includes
for i, name in enumerate(query.include):
conditions.append(attr.like(f'%{{"name":"{name}",%'))
# Construct the conditions for excludes
for i, name in enumerate(query.exclude):
conditions.append(not_(attr.like(f'%{{"name":"{name}",%')))
elif dialect_name == "postgresql":
if query.include:
conditions.append(attr.op("@>")(specs_array_to_json(query.include)))
if query.exclude:
conditions.append(not_(attr.op("@>")(specs_array_to_json(query.exclude))))
else:
raise UnsupportedQueryType("specs")
return tree.new_variation(conditions=tree.conditions + conditions)


def in_or_not_in(query, tree, method):
dialect_name = tree.engine.url.get_dialect().name
attr = orm.Node.metadata_[query.key.split(".")]
keys = query.key.split(".")
attr = orm.Node.metadata_[keys]
if dialect_name == "sqlite":
condition = getattr(_get_value(attr, type(query.value[0])), method)(query.value)
elif dialect_name == "postgresql":
# Engage btree_gin index with @> operator
if method == "in_":
condition = or_(
*(
orm.Node.metadata_.op("@>")(key_array_to_json(keys, item))
for item in query.value
)
)
elif method == "not_in":
condition = not_(
or_(
*(
orm.Node.metadata_.op("@>")(key_array_to_json(keys, item))
for item in query.value
)
)
)
else:
raise UnsupportedQueryType("NotIn")
else:
raise UnsupportedQueryType("In/NotIn")
return tree.new_variation(conditions=tree.conditions + [condition])
Expand All @@ -1013,8 +1042,8 @@ def structure_family(query, tree):
CatalogNodeAdapter.register_query(NotIn, partial(in_or_not_in, method="not_in"))
CatalogNodeAdapter.register_query(KeysFilter, keys_filter)
CatalogNodeAdapter.register_query(StructureFamilyQuery, structure_family)
# CatalogNodeAdapter.register_query(Specs, specs)
# TODO: FullText, Regex, Specs
CatalogNodeAdapter.register_query(SpecsQuery, specs)
# TODO: FullText, Regex


def in_memory(
Expand Down Expand Up @@ -1129,6 +1158,28 @@ def key_array_to_json(keys, value):
return {keys[0]: reduce(lambda x, y: {y: x}, keys[1:][::-1], value)}


def specs_array_to_json(specs):
"""Take array of Specs strings and convert them to a `penguin` @> friendly array
Assume constructed array will feature keys called "name"
Parameters
----------
specs : iterable
An array of specs strings to be searched for.
Returns
-------
json
JSON object for use in postgresql queries.
Examples
--------
>>> specs_array_to_json(['foo','bar'])
[{"name":"foo"},{"name":"bar"}]
"""
return [{"name": spec} for spec in specs]


STRUCTURES = {
StructureFamily.container: CatalogContainerAdapter,
StructureFamily.array: CatalogArrayAdapter,
Expand Down

0 comments on commit e7c370d

Please sign in to comment.