
Commit

Feat: Add Cortex-compatible Snowflake SQL processor for storing vector data; Enable native merge upsert for Snowflake caches (#203)

Co-authored-by: Aaron ("AJ") Steers <[email protected]>
bindipankhudi and aaronsteers authored May 10, 2024
1 parent cd1327a commit 51026ee
Showing 10 changed files with 836 additions and 276 deletions.
16 changes: 16 additions & 0 deletions airbyte/_processors/sql/__init__.py
@@ -1,2 +1,18 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
"""SQL processors."""

from __future__ import annotations

from airbyte._processors.sql.snowflakecortex import (
    SnowflakeCortexSqlProcessor,
    SnowflakeCortexTypeConverter,
)


__all__ = [
    # Classes
    "SnowflakeCortexSqlProcessor",
    "SnowflakeCortexTypeConverter",
    # modules
    "snowflakecortex",
]
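
For orientation, a minimal usage sketch of the names re-exported above (this assumes PyAirbyte is installed at this commit and the module layout matches; the 1536 vector length is an illustrative value, not a default):

from airbyte._processors.sql import SnowflakeCortexSqlProcessor, SnowflakeCortexTypeConverter

# The converter maps JSON-schema array properties to fixed-length Snowflake VECTOR columns.
converter = SnowflakeCortexTypeConverter(vector_length=1536)
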
1 change: 1 addition & 0 deletions airbyte/_processors/sql/snowflake.py
@@ -51,6 +51,7 @@ class SnowflakeSqlProcessor(SqlProcessorBase):

    file_writer_class = JsonlWriter
    type_converter_class = SnowflakeTypeConverter
    supports_merge_insert = True

    @overrides
    def _write_files_to_new_table(
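For context, `supports_merge_insert = True` lets the base SQL processor use Snowflake's native MERGE statement for deduplicated upserts instead of delete-then-insert. A hedged sketch of the kind of statement this enables, written in the same dedent style the processors use (the table, key, and column names are illustrative, not the exact SQL PyAirbyte emits):

from textwrap import dedent

# Illustrative only: a native merge upsert keyed on a primary-key column.
merge_statement = dedent(
    """
    MERGE INTO my_schema.users AS tgt
    USING my_schema.users_batch_tmp AS src
        ON tgt.id = src.id
    WHEN MATCHED THEN UPDATE SET tgt.name = src.name, tgt.email = src.email
    WHEN NOT MATCHED THEN INSERT (id, name, email) VALUES (src.id, src.name, src.email)
    ;
    """
)
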
237 changes: 237 additions & 0 deletions airbyte/_processors/sql/snowflakecortex.py
@@ -0,0 +1,237 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
"""A Snowflake vector store implementation of the SQL processor."""

from __future__ import annotations

from textwrap import dedent, indent
from typing import TYPE_CHECKING

import sqlalchemy
from overrides import overrides
from sqlalchemy import text

from airbyte import exceptions as exc
from airbyte._processors.base import RecordProcessor
from airbyte._processors.sql.snowflake import SnowflakeSqlProcessor, SnowflakeTypeConverter
from airbyte.caches._catalog_manager import CatalogManager


if TYPE_CHECKING:
    from pathlib import Path

    from sqlalchemy.engine import Connection, Engine

    from airbyte_cdk.models import ConfiguredAirbyteCatalog

    from airbyte._processors.file.base import FileWriterBase
    from airbyte.caches.base import CacheBase


class SnowflakeCortexTypeConverter(SnowflakeTypeConverter):
    """A class to convert array type into vector."""

    def __init__(
        self,
        conversion_map: dict | None = None,
        *,
        vector_length: int,
    ) -> None:
        self.vector_length = vector_length
        super().__init__(conversion_map)

    @overrides
    def to_sql_type(
        self,
        json_schema_property_def: dict[str, str | dict | list],
    ) -> sqlalchemy.types.TypeEngine:
        """Convert a value to a SQL type."""
        sql_type = super().to_sql_type(json_schema_property_def)
        if isinstance(sql_type, sqlalchemy.types.ARRAY):
            # SQLAlchemy doesn't yet support the `VECTOR` data type.
            # We may want to remove this or update once this resolves:
            # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/499
            return f"VECTOR(FLOAT, {self.vector_length})"

        return sql_type


class SnowflakeCortexSqlProcessor(SnowflakeSqlProcessor):
    """A Snowflake implementation for use with Cortex functions."""

    supports_merge_insert = True

    def __init__(
        self,
        cache: CacheBase,
        catalog: ConfiguredAirbyteCatalog,
        vector_length: int,
        source_name: str,
        stream_names: set[str],
        *,
        file_writer: FileWriterBase | None = None,
    ) -> None:
        """Custom initialization: Initialize type_converter with vector_length."""
        self._catalog = catalog
        # to-do: see if we can get rid of the following assignment
        self.source_catalog = catalog
        self._vector_length = vector_length
        self._engine: Engine | None = None
        self._connection_to_reuse: Connection | None = None

        # call base class to do necessary initialization
        RecordProcessor.__init__(self, cache=cache, catalog_manager=None)
        self._ensure_schema_exists()
        self._catalog_manager = CatalogManager(
            engine=self.get_sql_engine(),
            table_name_resolver=lambda stream_name: self.get_sql_table_name(stream_name),
        )

        # TODO: read streams and source from catalog if not provided

        # initialize catalog manager by registering source
        self.register_source(
            source_name=source_name,
            incoming_source_catalog=self._catalog,
            stream_names=stream_names,
        )
        self.file_writer = file_writer or self.file_writer_class(cache)
        self.type_converter = SnowflakeCortexTypeConverter(vector_length=vector_length)
        self._cached_table_definitions: dict[str, sqlalchemy.Table] = {}

    def _get_column_list_from_table(
        self,
        table_name: str,
    ) -> list[str]:
        """Get column names for passed stream.

        This is overridden due to lack of SQLAlchemy compatibility for the
        `VECTOR` data type.
        """
        conn: Connection = self.cache.get_vendor_client()
        cursor = conn.cursor()
        cursor.execute(f"DESCRIBE TABLE {table_name};")
        results = cursor.fetchall()
        column_names = [row[0].lower() for row in results]
        cursor.close()
        conn.close()
        return column_names

    @overrides
    def _ensure_compatible_table_schema(
        self,
        stream_name: str,
        *,
        raise_on_error: bool = True,
    ) -> bool:
        """Read the existing table schema using the Snowflake Python connector."""
        json_schema = self.get_stream_json_schema(stream_name)
        stream_column_names: list[str] = json_schema["properties"].keys()
        table_column_names: list[str] = self._get_column_list_from_table(stream_name)

        lower_case_table_column_names = self.normalizer.normalize_set(table_column_names)
        missing_columns = [
            stream_col
            for stream_col in stream_column_names
            if self.normalizer.normalize(stream_col) not in lower_case_table_column_names
        ]
        # TODO: shouldn't we just return False here, so missing tables can be created?
        if missing_columns:
            if raise_on_error:
                raise exc.PyAirbyteCacheTableValidationError(
                    violation="Cache table is missing expected columns.",
                    context={
                        "stream_column_names": stream_column_names,
                        "table_column_names": table_column_names,
                        "missing_columns": missing_columns,
                    },
                )
            return False  # Some columns are missing.

        return True  # All columns exist.

    @overrides
    def _write_files_to_new_table(
        self,
        files: list[Path],
        stream_name: str,
        batch_id: str,
    ) -> str:
        """Write files to a new table."""
        temp_table_name = self._create_table_for_loading(
            stream_name=stream_name,
            batch_id=batch_id,
        )
        internal_sf_stage_name = f"@%{temp_table_name}"

        def path_str(path: Path) -> str:
            return str(path.absolute()).replace("\\", "\\\\")

        put_files_statements = "\n".join(
            [f"PUT 'file://{path_str(file_path)}' {internal_sf_stage_name};" for file_path in files]
        )
        self._execute_sql(put_files_statements)
        columns_list = [
            self._quote_identifier(c)
            for c in list(self._get_sql_column_definitions(stream_name).keys())
        ]
        files_list = ", ".join([f"'{f.name}'" for f in files])
        columns_list_str: str = indent("\n, ".join(columns_list), " " * 12)

        # following two lines are different from SnowflakeSqlProcessor
        vector_suffix = f"::Vector(Float, {self._vector_length})"
        variant_cols_str: str = ("\n" + " " * 21 + ", ").join(
            [
                f"$1:{self.normalizer.normalize(col)}{vector_suffix if 'embedding' in col else ''}"
                for col in columns_list
            ]
        )

        copy_statement = dedent(
            f"""
            COPY INTO {temp_table_name}
            (
                {columns_list_str}
            )
            FROM (
                SELECT {variant_cols_str}
                FROM {internal_sf_stage_name}
            )
            FILES = ( {files_list} )
            FILE_FORMAT = ( TYPE = JSON )
            ;
            """
        )
        self._execute_sql(copy_statement)
        return temp_table_name

    @overrides
    def _add_missing_columns_to_table(
        self,
        stream_name: str,
        table_name: str,
    ) -> None:
        """Use the Snowflake Python connector to add new columns to the table."""
        columns = self._get_sql_column_definitions(stream_name)
        existing_columns = self._get_column_list_from_table(table_name)
        for column_name, column_type in columns.items():
            if column_name not in existing_columns:
                self._add_new_column_to_table(table_name, column_name, column_type)
                self._invalidate_table_cache(table_name)
        pass

    def _add_new_column_to_table(
        self,
        table_name: str,
        column_name: str,
        column_type: sqlalchemy.types.TypeEngine,
    ) -> None:
        conn: Connection = self.cache.get_vendor_client()
        cursor = conn.cursor()
        cursor.execute(
            text(
                f"ALTER TABLE {self._fully_qualified(table_name)} "
                f"ADD COLUMN {column_name} {column_type}"
            ),
        )
        cursor.close()
        conn.close()
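
To make the VECTOR handling concrete, here is roughly the COPY INTO statement `_write_files_to_new_table` would render for a hypothetical "documents" stream whose `embedding` column gets cast to a vector (the table, stage, and file names and the 1536 length are illustrative; the exact whitespace differs slightly):

copy_statement_example = """
COPY INTO documents_batch_tmp
(
    "document_id"
  , "document_content"
  , "embedding"
)
FROM (
    SELECT $1:document_id
         , $1:document_content
         , $1:embedding::Vector(Float, 1536)
    FROM @%documents_batch_tmp
)
FILES = ( 'batch_0001.jsonl.gz' )
FILE_FORMAT = ( TYPE = JSON )
;
"""
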
5 changes: 5 additions & 0 deletions airbyte/caches/base.py
@@ -74,6 +74,11 @@ def get_database_name(self) -> str:
"""Return the name of the database."""
...

def get_vendor_client(self) -> object:
"""Alternate (non-SQLAlchemy) way of getting database connection"""
msg = "This method needs to be implemented for specific databases"
raise NotImplementedError(msg)

@final
@property
def streams(
13 changes: 13 additions & 0 deletions airbyte/caches/snowflake.py
@@ -22,6 +22,7 @@
from __future__ import annotations

from overrides import overrides
from snowflake import connector
from snowflake.sqlalchemy import URL

from airbyte._processors.sql.base import RecordDedupeMode
@@ -62,6 +63,18 @@ def get_sql_alchemy_url(self) -> SecretString:
            )
        )

    def get_vendor_client(self) -> object:
        """Return the Snowflake connection object."""
        return connector.connect(
            user=self.username,
            password=self.password,
            account=self.account,
            warehouse=self.warehouse,
            database=self.database,
            schema=self.schema_name,
            role=self.role,
        )

    @overrides
    def get_database_name(self) -> str:
        """Return the name of the database."""
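A brief usage sketch of the new vendor client, mirroring how `_get_column_list_from_table` in the Cortex processor consumes it (this assumes `cache` is an already-configured SnowflakeCache and that the target table exists):

# The vendor client is a snowflake.connector connection, not a SQLAlchemy connection.
conn = cache.get_vendor_client()
cursor = conn.cursor()
cursor.execute("DESCRIBE TABLE my_table;")
column_names = [row[0].lower() for row in cursor.fetchall()]
cursor.close()
conn.close()
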
9 changes: 8 additions & 1 deletion airbyte/types.py
@@ -25,6 +25,7 @@
    # We include them here for completeness.
    "object": sqlalchemy.types.JSON,
    "array": sqlalchemy.types.JSON,
    "vector_array": sqlalchemy.types.ARRAY,
}


@@ -82,6 +83,9 @@ def _get_airbyte_type(  # noqa: PLR0911  # Too many return statements

return "array", None

if json_schema_type == "vector_array":
return "vector_array", "Float"

err_msg = f"Could not determine airbyte type from JSON schema type: {json_schema_property_def}"
raise SQLTypeConversionError(err_msg)

@@ -110,13 +114,16 @@ def get_json_type(cls) -> sqlalchemy.types.TypeEngine:
"""Get the type to use for nested JSON data."""
return sqlalchemy.types.JSON()

def to_sql_type(
def to_sql_type( # noqa: PLR0911 # Too many return statements
self,
json_schema_property_def: dict[str, str | dict | list],
) -> sqlalchemy.types.TypeEngine:
"""Convert a value to a SQL type."""
try:
airbyte_type, _ = _get_airbyte_type(json_schema_property_def)
# to-do - is there a better way to check the following
if airbyte_type == "vector_array":
return sqlalchemy.types.ARRAY(sqlalchemy.types.Float())
sql_type = self.conversion_map[airbyte_type]
except SQLTypeConversionError:
print(f"Could not determine airbyte type from JSON schema: {json_schema_property_def}")
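To illustrate the new "vector_array" mapping, a small hedged sketch (this assumes the converter class defined in this module is `SQLTypeConverter` and that a minimal property definition with only a "type" key is sufficient):

import sqlalchemy

from airbyte.types import SQLTypeConverter

converter = SQLTypeConverter()
# Generic mapping returns ARRAY(Float); the Cortex converter instead emits VECTOR(FLOAT, n).
sql_type = converter.to_sql_type({"type": "vector_array"})
assert isinstance(sql_type, sqlalchemy.types.ARRAY)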