Skip to content

Commit

Permalink
refactor to share column selections
Browse files Browse the repository at this point in the history
  • Loading branch information
pnadolny13 committed Sep 12, 2023
1 parent 288d6cd commit b9d91ca
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions target_snowflake/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,19 @@ def _get_put_statement(self, sync_id: str, file_uri: str) -> Tuple[text, dict]:
"""Get Snowflake PUT statement."""
return (text(f"put :file_uri '@~/target-snowflake/{sync_id}'"), {})

@staticmethod
def _format_column_selections(column_selections: dict, format: str) -> str:
if format == "json_casting":
return ', '.join(
[
f"$1:{col['clean_property_name']}::{col['sql_type']} as {col['clean_alias']}" for col in column_selections
]
)
elif format == "col_alias":
return f"({', '.join([col['clean_alias'] for col in column_selections])})"
else:
raise NotImplementedError(f"Column format not implemented: {format}")

def _get_column_selections(self, schema: dict, formatter: SnowflakeIdentifierPreparer) -> list:
column_selections = []
for property_name, property_def in schema["properties"].items():
Expand All @@ -306,7 +319,11 @@ def _get_column_selections(self, schema: dict, formatter: SnowflakeIdentifierPre
if '"' in clean_property_name:
clean_alias = clean_property_name.upper()
column_selections.append(
f"$1:{clean_property_name}::{self.to_sql_type(property_def)} as {clean_alias}"
{
"clean_property_name": clean_property_name,
"sql_type": self.to_sql_type(property_def),
"clean_alias": clean_alias,
}
)
return column_selections

Expand All @@ -317,6 +334,7 @@ def _get_merge_from_stage_statement(

formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
column_selections = self._get_column_selections(schema, formatter)
json_casting_selects = self._format_column_selections(column_selections, "json_casting")

# use UPPER from here onwards
formatted_properties = [formatter.format_collation(col) for col in schema["properties"].keys()]
Expand All @@ -336,7 +354,7 @@ def _get_merge_from_stage_statement(
return (
text(
f"merge into {full_table_name} d using "
+ f"(select {', '.join(column_selections)} from '@~/target-snowflake/{sync_id}'"
+ f"(select {json_casting_selects} from '@~/target-snowflake/{sync_id}'"
+ f"(file_format => {file_format}) {dedup}) s "
+ f"on {join_expr} "
+ f"when matched then update set {matched_clause} "
Expand All @@ -349,12 +367,13 @@ def _get_merge_from_stage_statement(
def _get_copy_statement(self, full_table_name, schema, sync_id, file_format):
"""Get Snowflake COPY statement."""
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
formatted_properties = ", ".join([formatter.format_collation(col) for col in schema["properties"].keys()])
column_selections = self._get_column_selections(schema, formatter)
json_casting_selects = self._format_column_selections(column_selections, "json_casting")
col_alias_selects = self._format_column_selections(column_selections, "col_alias")
return (
text(
f"copy into {full_table_name} ({formatted_properties}) from "
+ f"(select {', '.join(column_selections)} from "
f"copy into {full_table_name} {col_alias_selects} from "
+ f"(select {json_casting_selects} from "
+ f"'@~/target-snowflake/{sync_id}')"
+ f"file_format = (format_name='{file_format}')"
),
Expand Down

0 comments on commit b9d91ca

Please sign in to comment.