Commit eb79bdb

Scientific Metadata Search Engine (Fulltext) implementation for PostgreSQL 🔎 (#640)

* Merge and reset with head upstream 👷‍♀️ Conform to pep8

* 🗂️ Implement `jsonb_to_tsvector` which actually engages the index for real this time.

* 📜 Implement fulltext query in adapter.py

* ⬆️ Alembic migration creates text search index

* Upstream merge removed paren 🔣

* Add new migration to list in `core.py` ➕

* Black formatting for precommit 🧹

* Verbose index naming convention 🪪

* Change from a list of conditions to a single condition

* Skip unsupported `ts_vector` index creation for sqlite

* Preformat with black ⬛️ and enable tests ✅

* Change op from SQLAlchemy `match` to `to_tsquery` 🪄

* Fix oddity with `plainto_tsquery` vs `to_tsquery`

* Isolate migration from ORM; the ORM may change!

* Provide a more useful high-level comment.

* Put index creation in its own function.

* Remove all allusions to case sensitivity from fulltext 🔠🧽

* Taking a heavier hand to some light-touch changes

---------

Co-authored-by: Dan Allan <[email protected]>
Kezzsim and danielballan authored Jan 29, 2024
1 parent 1ff0885 commit eb79bdb
Showing 9 changed files with 165 additions and 40 deletions.
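For orientation (not part of the diff): on PostgreSQL, the changes below pair a GIN index over a tsvector projection of the JSONB metadata column with a `to_tsquery` match at search time. A minimal sketch, driving the raw SQL through SQLAlchemy, of what the migration and `adapter.py` produce together; the connection URL is a placeholder:

import sqlalchemy as sa

engine = sa.create_engine("postgresql://...")  # placeholder URL

with engine.begin() as connection:
    # Index DDL, verbatim from the Alembic migration below.
    connection.execute(
        sa.text(
            """
            CREATE INDEX metadata_tsvector_search
            ON nodes
            USING gin (jsonb_to_tsvector('simple', metadata, '["string"]'))
            """
        )
    )
    # The search condition that full_text() in adapter.py builds via
    # SQLAlchemy, written as raw SQL. The WHERE expression must match the
    # index expression exactly for the planner to use the index.
    rows = connection.execute(
        sa.text(
            """
            SELECT id FROM nodes
            WHERE jsonb_to_tsvector('simple', metadata, '["string"]')
                  @@ to_tsquery('simple', :q)
            """
        ),
        {"q": "purple"},
    ).fetchall()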
8 changes: 7 additions & 1 deletion tiled/_tests/test_queries.py
@@ -45,6 +45,9 @@
mapping["does_not_contain_z"] = ArrayAdapter.from_array(
numpy.ones(10), metadata={"letters": list(string.ascii_lowercase[:-1])}
)
mapping["full_text_test_case"] = ArrayAdapter.from_array(
numpy.ones(10), metadata={"color": "purple"}
)

mapping["specs_foo_bar"] = ArrayAdapter.from_array(numpy.ones(10), specs=["foo", "bar"])
mapping["specs_foo_bar_baz"] = ArrayAdapter.from_array(
@@ -159,7 +162,7 @@ def test_contains(client):


def test_full_text(client):
if client.metadata["backend"] in {"postgresql", "sqlite"}:
if client.metadata["backend"] in {"sqlite"}:

def cm():
return fail_with_status_code(400)
@@ -168,6 +171,9 @@ def cm():
cm = nullcontext
with cm():
assert list(client.search(FullText("z"))) == ["z", "does_contain_z"]
# plainto_tsquery fails to find certain words, weirdly, so it is a useful
# test that we are using to_tsquery
assert list(client.search(FullText("purple"))) == ["full_text_test_case"]


def test_regex(client):
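Why the "purple" assertion is a meaningful regression test (an inference; the diff only hints at it): SQLAlchemy's `match` operator emits `plainto_tsquery` with the server's default configuration (typically 'english'), which stems "purple" to "purpl" and therefore never matches the unstemmed tokens produced by `jsonb_to_tsvector('simple', ...)`; `to_tsquery('simple', ...)` keeps the token verbatim. A sketch that makes the mismatch visible, assuming a reachable PostgreSQL server:

import sqlalchemy as sa

engine = sa.create_engine("postgresql://...")  # placeholder URL

with engine.connect() as conn:
    # The 'english' configuration stems the word...
    print(conn.execute(sa.text("SELECT plainto_tsquery('english', 'purple')")).scalar())
    # expected: 'purpl'
    # ...while 'simple' keeps it verbatim, matching the index's tokens.
    print(conn.execute(sa.text("SELECT to_tsquery('simple', 'purple')")).scalar())
    # expected: 'purple'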
14 changes: 3 additions & 11 deletions tiled/adapters/mapping.py
@@ -75,10 +75,10 @@ def __init__(
specs : List[str], optional
access_policy : AccessPolicy, optional
entries_stale_after: timedelta
This server uses this to communite to the client how long
This server uses this to communicate to the client how long
it should rely on a local cache before checking back for changes.
metadata_stale_after: timedelta
This server uses this to communite to the client how long
This server uses this to communicate to the client how long
it should rely on a local cache before checking back for changes.
must_revalidate : bool
Whether the client should strictly refresh stale cache items.
@@ -336,20 +336,12 @@ def iter_child_metadata(query_key, tree):
def full_text_search(query, tree):
matches = {}
text = query.text
if query.case_sensitive:

def maybe_lower(s):
# no-op
return s

else:
maybe_lower = str.lower
query_words = set(text.split())
for key, value in tree.items():
words = set(
word
for s in walk_string_values(value.metadata())
for word in maybe_lower(s).split()
for word in s.lower().split()
)
# Note that `not set.isdisjoint` is faster than `set.intersection`. At
# the C level, `isdisjoint` loops over the set until it finds one match,
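The set logic above, pulled out as a standalone sketch (not from the diff): a node matches when at least one query word appears among the lowercased words in its metadata.

query_words = {"purple"}
node_words = {"color", "purple"}  # lowercased words gathered by walk_string_values

# isdisjoint short-circuits on the first common element, so negating it is
# cheaper than materializing set.intersection just to test for emptiness.
assert not query_words.isdisjoint(node_words)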
20 changes: 19 additions & 1 deletion tiled/catalog/adapter.py
@@ -13,13 +13,16 @@
import anyio
from fastapi import HTTPException
from sqlalchemy import delete, event, func, not_, or_, select, text, type_coerce, update
from sqlalchemy.dialects.postgresql import JSONB, REGCONFIG
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql.expression import cast

from tiled.queries import (
Comparison,
Contains,
Eq,
FullText,
In,
KeysFilter,
NotEq,
@@ -997,6 +1000,20 @@ def contains(query, tree):
return tree.new_variation(conditions=tree.conditions + [condition])


def full_text(query, tree):
dialect_name = tree.engine.url.get_dialect().name
if dialect_name == "sqlite":
raise UnsupportedQueryType("full_text")
elif dialect_name == "postgresql":
tsvector = func.jsonb_to_tsvector(
cast("simple", REGCONFIG), orm.Node.metadata_, cast(["string"], JSONB)
)
condition = tsvector.op("@@")(func.to_tsquery("simple", query.text))
else:
raise UnsupportedQueryType("full_text")
return tree.new_variation(conditions=tree.conditions + [condition])


def specs(query, tree):
dialect_name = tree.engine.url.get_dialect().name
conditions = []
@@ -1068,7 +1085,8 @@ def structure_family(query, tree):
CatalogNodeAdapter.register_query(KeysFilter, keys_filter)
CatalogNodeAdapter.register_query(StructureFamilyQuery, structure_family)
CatalogNodeAdapter.register_query(SpecsQuery, specs)
# TODO: FullText, Regex
CatalogNodeAdapter.register_query(FullText, full_text)
# TODO: Regex


def in_memory(
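Client-side, the newly registered query is exercised exactly as in the test above; a usage sketch (the server URI is a placeholder for a PostgreSQL-backed catalog):

from tiled.client import from_uri
from tiled.queries import FullText

client = from_uri("http://localhost:8000")  # placeholder URI

# On PostgreSQL this dispatches to full_text() above; on SQLite the server
# rejects the query, which the test exercises as a 400 response.
results = list(client.search(FullText("purple")))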
1 change: 1 addition & 0 deletions tiled/catalog/core.py
@@ -5,6 +5,7 @@

# This is a list of all valid revisions (from current to oldest).
ALL_REVISIONS = [
"1cd99c02d0c7",
"a66028395cab",
"3db11ff95b6c",
"0b033e7fbe30",
@@ -0,0 +1,39 @@
"""Create index for fulltext search
Revision ID: 1cd99c02d0c7
Revises: a66028395cab
Create Date: 2024-01-24 15:53:12.348880
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import JSONB

# revision identifiers, used by Alembic.
revision = "1cd99c02d0c7"
down_revision = "a66028395cab"
branch_labels = None
depends_on = None

# Use JSONB on PostgreSQL; fall back to generic JSON elsewhere.
JSONVariant = sa.JSON().with_variant(JSONB(), "postgresql")


def upgrade():
connection = op.get_bind()
if connection.engine.dialect.name == "postgresql":
with op.get_context().autocommit_block():
# There is no sane way to perform this using op.create_index()
op.execute(
"""
CREATE INDEX metadata_tsvector_search
ON nodes
USING gin (jsonb_to_tsvector('simple', metadata, '["string"]'))
"""
)


def downgrade():
# This _could_ be implemented but we will wait for a need since we are
# still in alpha releases.
raise NotImplementedError
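To confirm the new index is actually engaged — the point of the first commit bullet — one can inspect the query plan. A sketch, assuming a populated PostgreSQL database:

import sqlalchemy as sa

engine = sa.create_engine("postgresql://...")  # placeholder URL

with engine.connect() as conn:
    plan = conn.execute(
        sa.text(
            """
            EXPLAIN
            SELECT id FROM nodes
            WHERE jsonb_to_tsvector('simple', metadata, '["string"]')
                  @@ to_tsquery('simple', 'purple')
            """
        )
    ).fetchall()
    # On a table large enough for the planner to prefer the index, the plan
    # should reference metadata_tsvector_search rather than a sequential scan.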
@@ -8,11 +8,7 @@
import sqlalchemy as sa
from alembic import op

from tiled.catalog.orm import (
DataSourceAssetAssociation,
JSONVariant,
unique_parameter_num_null_check,
)
from tiled.catalog.orm import JSONVariant

# revision identifiers, used by Alembic.
revision = "a66028395cab"
@@ -32,7 +28,7 @@ def upgrade():
sa.Column("structure", JSONVariant),
)
data_source_asset_association = sa.Table(
DataSourceAssetAssociation.__tablename__,
"data_source_asset_association",
sa.MetaData(),
sa.Column("asset_id", sa.Integer),
sa.Column("data_source_id", sa.Integer),
@@ -67,11 +63,11 @@

# Add columns 'parameter' and 'num' to association table.
op.add_column(
DataSourceAssetAssociation.__tablename__,
"data_source_asset_association",
sa.Column("parameter", sa.Unicode(255), nullable=True),
)
op.add_column(
DataSourceAssetAssociation.__tablename__,
"data_source_asset_association",
sa.Column("num", sa.Integer, nullable=True),
)

@@ -162,7 +158,7 @@ def upgrade():
if connection.engine.dialect.name == "sqlite":
# SQLite does not support adding constraints to an existing table.
# We invoke its 'copy and move' functionality.
with op.batch_alter_table(DataSourceAssetAssociation.__tablename__) as batch_op:
with op.batch_alter_table("data_source_asset_association") as batch_op:
# Gotcha: This does not take table_name because it is bound into batch_op.
batch_op.create_unique_constraint(
"parameter_num_unique_constraint",
@@ -172,11 +168,15 @@
"num",
],
)
# This creates a pair of triggers on the data_source_asset_association
# table. Each pair includes one trigger that runs when NEW.num IS NULL and
# one trigger that runs when NEW.num IS NOT NULL. Thus, for a given insert,
# only one of these triggers is run.
with op.get_context().autocommit_block():
connection.execute(
sa.text(
"""
CREATE TRIGGER cannot_insert_num_null_if_num_int_exists
CREATE TRIGGER cannot_insert_num_null_if_num_exists
BEFORE INSERT ON data_source_asset_association
WHEN NEW.num IS NULL
BEGIN
@@ -214,14 +214,72 @@
# PostgreSQL
op.create_unique_constraint(
"parameter_num_unique_constraint",
DataSourceAssetAssociation.__tablename__,
"data_source_asset_association",
[
"data_source_id",
"parameter",
"num",
],
)
unique_parameter_num_null_check(data_source_asset_association, connection)
connection.execute(
sa.text(
"""
CREATE OR REPLACE FUNCTION raise_if_parameter_exists()
RETURNS TRIGGER AS $$
BEGIN
IF EXISTS (
SELECT 1
FROM data_source_asset_association
WHERE parameter = NEW.parameter
AND data_source_id = NEW.data_source_id
) THEN
RAISE EXCEPTION 'Can only insert num=NULL if no other row exists for the same parameter';
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;"""
)
)
connection.execute(
sa.text(
"""
CREATE TRIGGER cannot_insert_num_null_if_num_exists
BEFORE INSERT ON data_source_asset_association
FOR EACH ROW
WHEN (NEW.num IS NULL)
EXECUTE FUNCTION raise_if_parameter_exists();"""
)
)
connection.execute(
sa.text(
"""
CREATE OR REPLACE FUNCTION raise_if_null_parameter_exists()
RETURNS TRIGGER AS $$
BEGIN
IF EXISTS (
SELECT 1
FROM data_source_asset_association
WHERE parameter = NEW.parameter
AND data_source_id = NEW.data_source_id
AND num IS NULL
) THEN
RAISE EXCEPTION 'Can only insert INTEGER num if no NULL row exists for the same parameter';
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;"""
)
)
connection.execute(
sa.text(
"""
CREATE TRIGGER cannot_insert_num_int_if_num_null_exists
BEFORE INSERT ON data_source_asset_association
FOR EACH ROW
WHEN (NEW.num IS NOT NULL)
EXECUTE FUNCTION raise_if_null_parameter_exists();"""
)
)


def downgrade():
27 changes: 23 additions & 4 deletions tiled/catalog/orm.py
@@ -94,7 +94,7 @@ class Node(Timestamped, Base):
"id",
"metadata",
postgresql_using="gin",
),
)
# This is used by ORDER BY with the default sorting.
# Index("ancestors_time_created", "ancestors", "time_created"),
)
@@ -149,9 +149,12 @@ class DataSourceAssetAssociation(Base):

@event.listens_for(DataSourceAssetAssociation.__table__, "after_create")
def unique_parameter_num_null_check(target, connection, **kw):
# Ensure that we cannot mix NULL and INTEGER values of num for
# a given data_source_id and parameter, and that there cannot be multiple
# instances of NULL.
# This creates a pair of triggers on the data_source_asset_association
# table. (There are a total of four defined below, two for the SQLite
# branch and two for the PostgreSQL branch.) Each pair includes one trigger
# that runs when NEW.num IS NULL and one trigger that runs when
# NEW.num IS NOT NULL. Thus, for a given insert, only one of these
# triggers is run.
if connection.engine.dialect.name == "sqlite":
connection.execute(
text(
@@ -252,6 +255,22 @@ def unique_parameter_num_null_check(target, connection, **kw):
)


@event.listens_for(DataSourceAssetAssociation.__table__, "after_create")
def create_index_metadata_tsvector_search(target, connection, **kw):
# This creates a tsvector-based metadata search index for fulltext search.
# This is a PostgreSQL-only feature.
if connection.engine.dialect.name == "postgresql":
connection.execute(
text(
"""
CREATE INDEX metadata_tsvector_search
ON nodes
USING gin (jsonb_to_tsvector('simple', metadata, '["string"]'))
"""
)
)


class DataSource(Timestamped, Base):
"""
This describes how to open one or more files/blobs to extract data for a Node.
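Both `after_create` listeners fire when SQLAlchemy first creates the tables, e.g. on a fresh catalog. A sketch, assuming `Base` is the declarative base defined in this module:

from sqlalchemy import create_engine
from tiled.catalog.orm import Base

engine = create_engine("postgresql://...")  # placeholder URL

# Creating the tables emits "after_create" for data_source_asset_association,
# which runs both listeners above: the num/NULL triggers and, on PostgreSQL
# only, the metadata_tsvector_search index.
Base.metadata.create_all(engine)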
13 changes: 3 additions & 10 deletions tiled/queries.py
@@ -42,23 +42,16 @@ class FullText(NoBool):
Parameters
----------
text : str
case_sensitive : bool, optional
Default False (case-insensitive).
"""

text: str
case_sensitive: bool = False

def encode(self):
return {"text": self.text, "case_sensitive": json.dumps(self.case_sensitive)}
return {"text": self.text}

@classmethod
def decode(cls, *, text, case_sensitive=False):
# Note: FastAPI decodes case_sensitive into a boolean for us.
return cls(
text=text,
case_sensitive=case_sensitive,
)
def decode(cls, *, text):
return cls(text=text)


@register(name="lookup")
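With `case_sensitive` removed, the wire format is just the text. A round-trip sketch (assuming FullText is a dataclass with value equality, as elsewhere in tiled.queries):

from tiled.queries import FullText

q = FullText("purple")
assert q.encode() == {"text": "purple"}
assert FullText.decode(text="purple") == q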
1 change: 0 additions & 1 deletion web-frontend/src/openapi_schemas.ts
@@ -409,7 +409,6 @@ export interface operations {
sort?: string;
omit_links?: boolean;
"filter[fulltext][condition][text]"?: string[];
"filter[fulltext][condition][case_sensitive]"?: boolean[];
"filter[lookup][condition][key]"?: string[];
};
};
