Skip to content

Commit

Permalink
update based on review
Browse files Browse the repository at this point in the history
  • Loading branch information
medsriha committed Sep 15, 2024
1 parent abbc2d6 commit 89ac0c9
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 94 deletions.
2 changes: 1 addition & 1 deletion integrations/snowflake/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
## [integrations/snowflake-v0.0.0] - 2024-09-06
## [integrations/snowflake-v0.0.1] - 2024-09-06
4 changes: 2 additions & 2 deletions integrations/snowflake/example/text2sql_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret

from haystack_integrations.components.retrievers.snowflake import SnowflakeRetriever
from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever

load_dotenv()

Expand Down Expand Up @@ -89,7 +89,7 @@
generation_kwargs={"temperature": 0.0, "max_tokens": 2000},
)

snowflake = SnowflakeRetriever(
snowflake = SnowflakeTableRetriever(
user="<ACCOUNT-USER>",
account="<ACCOUNT-IDENTIFIER>",
api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

from .snowflake_retriever import SnowflakeRetriever
from .snowflake_table_retriever import SnowflakeTableRetriever

__all__ = ["SnowflakeRetriever"]
__all__ = ["SnowflakeTableRetriever"]
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@


@component
class SnowflakeRetriever:
class SnowflakeTableRetriever:
"""
Connects to a Snowflake database to execute a SQL query.
For more information, see [Snowflake documentation](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector).
### Usage example:
```python
executor = SnowflakeRetriever(
executor = SnowflakeTableRetriever(
user="<ACCOUNT-USER>",
account="<ACCOUNT-IDENTIFIER>",
api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"),
Expand All @@ -51,12 +51,12 @@ class SnowflakeRetriever:
results = executor.run(query=query)
print(results["dataframe"].head(2))
print(results["dataframe"].head(2)) # Pandas dataframe
# Column 1 Column 2
# 0 Value1 Value2
# 1 Value1 Value2
print(results["table"])
print(results["table"]) # Markdown
# | Column 1 | Column 2 |
# |:----------|:----------|
# | Value1 | Value2 |
Expand Down Expand Up @@ -111,7 +111,7 @@ def to_dict(self) -> Dict[str, Any]:
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeRetriever":
def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeTableRetriever":
"""
Deserializes the component from a dictionary.
Expand Down Expand Up @@ -140,14 +140,13 @@ def _snowflake_connector(connect_params: Dict[str, Any]) -> Union[SnowflakeConne
@staticmethod
def _extract_table_names(query: str) -> list:
"""
Extract table names from a SQL query using regex.
Extract table names from an SQL query using regex.
The extracted table names will be checked for privilege.
:param query: SQL query to extract table names from.
"""

# Regular expressions to match table names in various clauses
suffix = "\\s+([a-zA-Z0-9_.]+)"
suffix = "\\s+([a-zA-Z0-9_.]+)" # Regular expressions to match table names in various clauses

patterns = [
"MERGE\\s+INTO",
Expand All @@ -168,7 +167,7 @@ def _extract_table_names(query: str) -> list:
# Find all matches in the query
matches = re.findall(pattern=combined_pattern, string=query, flags=re.IGNORECASE)

# Flatten list of tuples and remove duplication
# Flatten the list of tuples and remove duplication
matches = list(set(sum(matches, ())))

# Clean and return unique table names
Expand All @@ -177,26 +176,29 @@ def _extract_table_names(query: str) -> list:
@staticmethod
def _execute_sql_query(conn: SnowflakeConnection, query: str) -> pd.DataFrame:
"""
Execute a SQL query and fetch the results.
Execute an SQL query and fetch the results.
:param conn: An open connection to Snowflake.
:param query: The query to execute.
"""
cur = conn.cursor()
try:
cur.execute(query)
# set a limit to MAX_SYS_ROWS rows to avoid fetching too many rows
rows = cur.fetchmany(size=MAX_SYS_ROWS)
# Convert data to a dataframe
df = pd.DataFrame(rows, columns=[desc.name for desc in cur.description])
rows = cur.fetchmany(size=MAX_SYS_ROWS) # set a limit to avoid fetching too many rows

df = pd.DataFrame(rows, columns=[desc.name for desc in cur.description]) # Convert data to a dataframe
return df
except ProgrammingError as e:
logger.warning(
"{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})",
error_msg=e.msg,
sfqid=e.sfqid,
)
return pd.DataFrame()
except Exception as e:
if isinstance(e, ProgrammingError):
logger.warning(
"{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})",
error_msg=e.msg,
sfqid=e.sfqid,
)
else:
logger.warning("An unexpected error occurred: {error_msg}", error_msg=e)

return pd.DataFrame()

@staticmethod
def _has_select_privilege(privileges: list, table_name: str) -> bool:
Expand All @@ -213,7 +215,6 @@ def _has_select_privilege(privileges: list, table_name: str) -> bool:
string=privilege[1],
flags=re.IGNORECASE,
):
logger.error("User does not have `Select` privilege on the table.")
return False

return True
Expand Down Expand Up @@ -304,6 +305,8 @@ def _fetch_data(
user=self.user,
):
df = self._execute_sql_query(conn=conn, query=query)
else:
logger.error("User does not have `Select` privilege on the table.")

except Exception as e:
logger.error("An unexpected error has occurred: {error}", error=e)
Expand Down
Loading

0 comments on commit 89ac0c9

Please sign in to comment.