Optimize load to Snowflake process #1

Merged
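
Taken together, these changes move the target from a load-per-file model (stage one CSV, COPY/MERGE it into Snowflake, then delete or archive it) to a load-per-stream model: flushed batch files are only staged to S3, and once the whole batch set is flushed, a single COPY/MERGE per stream picks up all of the current run's files by prefix and file-name pattern. A minimal sketch of the new sequence, with names taken from the diff below and control flow simplified:

# flush_records(): the batch file is only staged now; no per-file load, archive or delete
db_sync.put_to_stage(filepath, stream, row_count, temp_dir=temp_dir)

# persist_lines(), after all batches are flushed: one import per stream
for load_stream_name in stream_to_sync:
    s3_key = f".*pipelinewise_{load_stream_name}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"
    stream_to_sync[load_stream_name].load_file(config['s3_key_prefix'], s3_key, 0, 0)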
50 changes: 12 additions & 38 deletions target_snowflake/__init__.py
@@ -334,6 +334,15 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
# emit latest state
emit_state(copy.deepcopy(flushed_state))

+ # After all changes are flushed to S3 files, we trigger an import for each stream
+ for load_stream_name in stream_to_sync:
+ LOGGER.info("Starting Snowflake import from S3 for stream %s", load_stream_name)
+ tmp_db_sync = stream_to_sync[load_stream_name]
+ s3_key = f".*pipelinewise_{load_stream_name}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"
+ tmp_db_sync.load_file(config['s3_key_prefix'], s3_key, 0, 0)
+ ## Replacing deletion of stage files with an S3 lifecycle rule
+ # tmp_db_sync.delete_from_stage(load_stream_name, config['s3_key_prefix'])

# After all changes are imported to snowflake emit state with new flushed_lsn values
if not flushed_state is None and 'bookmarks' in flushed_state:
lsn_list = [get_bookmark(flushed_state, s, 'lsn') for s in flushed_state["bookmarks"] if 'lsn' in flushed_state["bookmarks"][s]]
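
The s3_key built above is no longer a single object key but a regular expression that ends up as the PATTERN option of the generated COPY/MERGE, applied to the staged file paths under the run's prefix. A small self-contained check of that regex against a hypothetical staged file name (the exact naming produced by put_to_stage is an assumption here):

import re

load_stream_name = "public-orders"  # hypothetical stream name
s3_key = f".*pipelinewise_{load_stream_name}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"

# Hypothetical file name as it might appear under the run prefix in the stage
staged_name = "pipelinewise_public-orders_09h30m12s_batch_1.csv"

print(bool(re.fullmatch(s3_key, staged_name)))  # True: the hour/minute/second groups match 09/30/12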
@@ -476,48 +485,13 @@ def flush_records(stream: str,

# Get file stats
row_count = len(records)
- size_bytes = os.path.getsize(filepath)

# Upload to s3 and load into Snowflake
- s3_key = db_sync.put_to_stage(filepath, stream, row_count, temp_dir=temp_dir)
- db_sync.load_file(s3_key, row_count, size_bytes)
+ db_sync.put_to_stage(filepath, stream, row_count, temp_dir=temp_dir)

# Delete file from local disk
os.remove(filepath)

- if archive_load_files:
- stream_name_parts = stream_utils.stream_name_to_dict(stream)
- if 'schema_name' not in stream_name_parts or 'table_name' not in stream_name_parts:
- raise Exception(f"Failed to extract schema and table names from stream '{stream}'")
-
- archive_schema = stream_name_parts['schema_name']
- archive_table = stream_name_parts['table_name']
- archive_tap = archive_load_files['tap']
-
- archive_metadata = {
- 'tap': archive_tap,
- 'schema': archive_schema,
- 'table': archive_table,
- 'archived-by': 'pipelinewise_target_snowflake'
- }
-
- if 'column' in archive_load_files:
- archive_metadata.update({
- 'incremental-key': archive_load_files['column'],
- 'incremental-key-min': str(archive_load_files['min']),
- 'incremental-key-max': str(archive_load_files['max'])
- })
-
- # Use same file name as in import
- archive_file = os.path.basename(s3_key)
- archive_key = f"{archive_tap}/{archive_table}/{archive_file}"
-
- db_sync.copy_to_archive(s3_key, archive_key, archive_metadata)
-
- # Delete file from S3
- db_sync.delete_from_stage(stream, s3_key)
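
With the per-file copy_to_archive and delete_from_stage calls removed here, and the per-stream delete commented out above in favor of an S3 lifecycle rule, cleanup of staged CSVs is expected to happen on the bucket side. A minimal hedged boto3 sketch of such a rule; the bucket name, base prefix and retention period are assumptions, not part of this change:

import boto3

s3 = boto3.client("s3")
s3.put_bucket_lifecycle_configuration(
    Bucket="my-pipelinewise-stage-bucket",  # assumption: the bucket behind the Snowflake stage
    LifecycleConfiguration={
        "Rules": [
            {
                "ID": "expire-pipelinewise-stage-files",
                "Filter": {"Prefix": "snowflake-imports/"},  # assumption: the configured s3_key_prefix base
                "Status": "Enabled",
                "Expiration": {"Days": 7},  # assumption: how long staged files are kept
            }
        ]
    },
)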


def main():
"""Main function"""
arg_parser = argparse.ArgumentParser()
@@ -527,8 +501,8 @@ def main():
if args.config:
with open(args.config, encoding="utf8") as config_input:
config = json.load(config_input)
- date_is = datetime.now().strftime("%Y-%m-%d")
- config["s3_key_prefix"] = f"{config['s3_key_prefix']}{date_is}_pid_{str(os.getpid())}/"
+ timestamp_is = datetime.now().strftime("%Y-%m-%d–%Hh-%Mm")
+ config["s3_key_prefix"] = f"{config['s3_key_prefix']}{timestamp_is}_pid_{str(os.getpid())}/"
else:
config = {}
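
Switching the prefix from a plain date to a minute-level timestamp plus the process id gives each run its own folder under the configured prefix, and that folder is exactly what the prefix+PATTERN import in persist_lines reads from. For illustration, assuming a base s3_key_prefix of "snowflake-imports/":

from datetime import datetime
import os

base_prefix = "snowflake-imports/"  # assumption: s3_key_prefix value from the config file
run_prefix = f"{base_prefix}{datetime.now().strftime('%Y-%m-%d–%Hh-%Mm')}_pid_{str(os.getpid())}/"
# e.g. "snowflake-imports/2024-05-01–09h-30m_pid_1234/"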

14 changes: 9 additions & 5 deletions target_snowflake/db_sync.py
@@ -436,7 +436,7 @@ def get_stage_name(self, stream):
table_name = self.table_name(stream, False, without_schema=True)
return f"{self.schema_name}.%{table_name}"

- def load_file(self, s3_key, count, size_bytes):
+ def load_file(self, s3_prefix, s3_key, count, size_bytes):
"""Load a supported file type from snowflake stage into target table"""
stream = self.stream_schema_message['stream']
self.logger.info("Loading %d rows into '%s'", count, self.table_name(stream, False))
@@ -459,6 +459,7 @@ def load_file(self, s3_key, count, size_bytes):
try:
inserts, updates = self._load_file_merge(
s3_key=s3_key,
+ s3_prefix=s3_prefix,
stream=stream,
columns_with_trans=columns_with_trans
)
@@ -474,6 +475,7 @@ def load_file(self, s3_key, count, size_bytes):
try:
inserts, updates = (
self._load_file_copy(
+ s3_prefix=s3_prefix,
s3_key=s3_key,
stream=stream,
columns_with_trans=columns_with_trans
@@ -493,7 +495,7 @@ def load_file(self, s3_key, count, size_bytes):
json.dumps({'inserts': inserts, 'updates': updates, 'size_bytes': size_bytes})
)

- def _load_file_merge(self, s3_key, stream, columns_with_trans) -> Tuple[int, int]:
+ def _load_file_merge(self, s3_key, s3_prefix, stream, columns_with_trans) -> Tuple[int, int]:
# MERGE does insert and update
inserts = 0
updates = 0
@@ -502,12 +504,13 @@ def _load_file_merge(self, s3_key, stream, columns_with_trans) -> Tuple[int, int
merge_sql = self.file_format.formatter.create_merge_sql(
table_name=self.table_name(stream, False),
stage_name=self.get_stage_name(stream),
+ s3_prefix=s3_prefix,
s3_key=s3_key,
file_format_name=self.connection_config['file_format'],
columns=columns_with_trans,
pk_merge_condition=self.primary_key_merge_condition()
)
- self.logger.debug('Running query: %s', merge_sql)
+ self.logger.info('Running query: %s', merge_sql)
cur.execute(merge_sql)
# Get number of inserted and updated records
results = cur.fetchall()
@@ -516,19 +519,20 @@ def _load_file_merge(self, s3_key, stream, columns_with_trans) -> Tuple[int, int
updates = results[0].get('number of rows updated', 0)
return inserts, updates
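
For context on the two .get() calls: a Snowflake MERGE returns a single result row whose columns are named 'number of rows inserted' and 'number of rows updated', so with the DictCursor the counts can be read straight off that row. A small sketch with a hypothetical result set:

# hypothetical DictCursor output of cur.fetchall() after a MERGE
results = [{'number of rows inserted': 120, 'number of rows updated': 35}]
inserts = results[0].get('number of rows inserted', 0)  # 120
updates = results[0].get('number of rows updated', 0)   # 35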

- def _load_file_copy(self, s3_key, stream, columns_with_trans) -> int:
+ def _load_file_copy(self, s3_key, s3_prefix, stream, columns_with_trans) -> int:
# COPY does insert only
inserts = 0
with self.open_connection() as connection:
with connection.cursor(snowflake.connector.DictCursor) as cur:
copy_sql = self.file_format.formatter.create_copy_sql(
table_name=self.table_name(stream, False),
stage_name=self.get_stage_name(stream),
+ s3_prefix=s3_prefix,
s3_key=s3_key,
file_format_name=self.connection_config['file_format'],
columns=columns_with_trans
)
- self.logger.debug('Running query: %s', copy_sql)
+ self.logger.info('Running query: %s', copy_sql)
cur.execute(copy_sql)
# Get number of inserted records - COPY does insert only
results = cur.fetchall()
9 changes: 6 additions & 3 deletions target_snowflake/file_formats/csv.py
@@ -11,20 +11,23 @@

def create_copy_sql(table_name: str,
stage_name: str,
+ s3_prefix: str,
s3_key: str,
file_format_name: str,
columns: List):
"""Generate a CSV compatible snowflake COPY INTO command"""
p_columns = ', '.join([c['name'] for c in columns])

return f"COPY INTO {table_name} ({p_columns}) " \
f"FROM '@{stage_name}/{s3_key}' " \
f"FROM '@{stage_name}/{s3_prefix}' " \
f"PATTERN = '{s3_key}' " \
f"ON_ERROR = CONTINUE " \
f"FILE_FORMAT = (format_name='{file_format_name}')"


def create_merge_sql(table_name: str,
stage_name: str,
+ s3_prefix: str,
s3_key: str,
file_format_name: str,
columns: List,
@@ -37,8 +40,8 @@ def create_merge_sql(table_name: str,

return f"MERGE INTO {table_name} t USING (" \
f"SELECT {p_source_columns} " \
f"FROM '@{stage_name}/{s3_key}' " \
f"(FILE_FORMAT => '{file_format_name}')) s " \
f"FROM '@{stage_name}/{s3_prefix}' " \
f"(FILE_FORMAT => '{file_format_name}'), PATTERN => '{s3_key}') s " \
f"ON {pk_merge_condition} " \
f"WHEN MATCHED THEN UPDATE SET {p_update} " \
"WHEN NOT MATCHED THEN " \