Optimize load to Snowflake process
While reading replication data from Postgres, only create S3 files and
do not load them into Snowflake. When the replication is finished, load
all files for each table at once - this limits the number of import
calls to Snowflake and imports bigger batches, which is always better
for Snowflake.
Files are also no longer deleted after they are loaded; instead we keep
them in S3 for a week and clean them up later with an S3 lifecycle rule.
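The lifecycle cleanup itself is configured outside this change; as a rough illustration, a rule of the kind the message refers to could look like the sketch below (bucket name and prefix are placeholders, not part of this commit):

import boto3

# Hypothetical bucket/prefix - the real values come from the target's S3 config.
s3 = boto3.client('s3')
s3.put_bucket_lifecycle_configuration(
    Bucket='my-pipelinewise-stage-bucket',
    LifecycleConfiguration={
        'Rules': [{
            'ID': 'expire-pipelinewise-stage-files',
            'Filter': {'Prefix': 'my-prefix/'},
            'Status': 'Enabled',
            # Keep staged CSV files for a week, then let S3 delete them.
            'Expiration': {'Days': 7},
        }]
    },
)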

Also update the S3 prefix to include the hour and minute, to avoid
importing data multiple times when the same PID occurs on the same day.
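For illustration, a minimal sketch of the new prefix and of the per-stream file pattern used for the batched import (base prefix and stream name are hypothetical):

from datetime import datetime
import os

# Hypothetical base prefix; the real one comes from config['s3_key_prefix'].
base_prefix = 'my-prefix/'
prefix = f"{base_prefix}{datetime.now().strftime('%Y-%m-%d–%Hh-%Mm')}_pid_{os.getpid()}/"
# e.g. my-prefix/2023-08-09–14h-35m_pid_4242/

# Regex passed to COPY INTO ... PATTERN at the end of replication,
# matching every file that was flushed for this stream during the run.
stream = 'public-users'
pattern = f".*pipelinewise_{stream}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"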
pslavov committed Aug 9, 2023
1 parent 0113bb5 commit df8504b
Showing 3 changed files with 20 additions and 42 deletions.
50 changes: 12 additions & 38 deletions target_snowflake/__init__.py
@@ -334,6 +334,15 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
# emit latest state
emit_state(copy.deepcopy(flushed_state))

# After all changes are flushed to S3 files we are triggering imports for each stream
for load_stream_name in stream_to_sync:
LOGGER.info("Start snowflake import from S3 for stream %s", load_stream_name)
tmp_db_sync = stream_to_sync[load_stream_name]
s3_key = f".*pipelinewise_{load_stream_name}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"
tmp_db_sync.load_file(config['s3_key_prefix'], s3_key, 0, 0)
## Replacing deletion of stage files with S3 lifecycle
# tmp_db_sync.delete_from_stage(load_stream_name, config['s3_key_prefix'])

# After all changes are imported to snowflake emit state with new flushed_lsn values
if not flushed_state is None and 'bookmarks' in flushed_state:
lsn_list = [get_bookmark(flushed_state, s, 'lsn') for s in flushed_state["bookmarks"] if 'lsn' in flushed_state["bookmarks"][s]]
@@ -476,48 +485,13 @@ def flush_records(stream: str,

# Get file stats
row_count = len(records)
size_bytes = os.path.getsize(filepath)

# Upload to s3 and load into Snowflake
s3_key = db_sync.put_to_stage(filepath, stream, row_count, temp_dir=temp_dir)
db_sync.load_file(s3_key, row_count, size_bytes)
db_sync.put_to_stage(filepath, stream, row_count, temp_dir=temp_dir)

# Delete file from local disk
os.remove(filepath)

if archive_load_files:
stream_name_parts = stream_utils.stream_name_to_dict(stream)
if 'schema_name' not in stream_name_parts or 'table_name' not in stream_name_parts:
raise Exception(f"Failed to extract schema and table names from stream '{stream}'")

archive_schema = stream_name_parts['schema_name']
archive_table = stream_name_parts['table_name']
archive_tap = archive_load_files['tap']

archive_metadata = {
'tap': archive_tap,
'schema': archive_schema,
'table': archive_table,
'archived-by': 'pipelinewise_target_snowflake'
}

if 'column' in archive_load_files:
archive_metadata.update({
'incremental-key': archive_load_files['column'],
'incremental-key-min': str(archive_load_files['min']),
'incremental-key-max': str(archive_load_files['max'])
})

# Use same file name as in import
archive_file = os.path.basename(s3_key)
archive_key = f"{archive_tap}/{archive_table}/{archive_file}"

db_sync.copy_to_archive(s3_key, archive_key, archive_metadata)

# Delete file from S3
db_sync.delete_from_stage(stream, s3_key)


def main():
"""Main function"""
arg_parser = argparse.ArgumentParser()
@@ -527,8 +501,8 @@ def main():
if args.config:
with open(args.config, encoding="utf8") as config_input:
config = json.load(config_input)
date_is = datetime.now().strftime("%Y-%m-%d")
config["s3_key_prefix"] = f"{config['s3_key_prefix']}{date_is}_pid_{str(os.getpid())}/"
timestamp_is = datetime.now().strftime("%Y-%m-%d–%Hh-%Mm")
config["s3_key_prefix"] = f"{config['s3_key_prefix']}{timestamp_is}_pid_{str(os.getpid())}/"
else:
config = {}

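Condensed, the flow after this change looks roughly like the sketch below (simplified; stream_to_sync and config are the variables used in persist_lines above):

# During replication each flush only uploads a CSV file to the S3 stage:
#     db_sync.put_to_stage(filepath, stream, row_count, temp_dir=temp_dir)
# Once every stream has been flushed, one pattern-based load per table imports all of its files:
for load_stream_name, tmp_db_sync in stream_to_sync.items():
    s3_key = f".*pipelinewise_{load_stream_name}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"
    tmp_db_sync.load_file(config['s3_key_prefix'], s3_key, 0, 0)
# Staged files remain in S3 and are expired later by the lifecycle rule.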
8 changes: 5 additions & 3 deletions target_snowflake/db_sync.py
@@ -436,7 +436,7 @@ def get_stage_name(self, stream):
table_name = self.table_name(stream, False, without_schema=True)
return f"{self.schema_name}.%{table_name}"

def load_file(self, s3_key, count, size_bytes):
def load_file(self, s3_prefix, s3_key, count, size_bytes):
"""Load a supported file type from snowflake stage into target table"""
stream = self.stream_schema_message['stream']
self.logger.info("Loading %d rows into '%s'", count, self.table_name(stream, False))
@@ -474,6 +474,7 @@ def load_file(self, s3_key, count, size_bytes):
try:
inserts, updates = (
self._load_file_copy(
s3_prefix=s3_prefix,
s3_key=s3_key,
stream=stream,
columns_with_trans=columns_with_trans
@@ -516,19 +517,20 @@ def _load_file_merge(self, s3_key, stream, columns_with_trans) -> Tuple[int, int
updates = results[0].get('number of rows updated', 0)
return inserts, updates

def _load_file_copy(self, s3_key, stream, columns_with_trans) -> int:
def _load_file_copy(self, s3_key, s3_prefix, stream, columns_with_trans) -> int:
# COPY does insert only
inserts = 0
with self.open_connection() as connection:
with connection.cursor(snowflake.connector.DictCursor) as cur:
copy_sql = self.file_format.formatter.create_copy_sql(
table_name=self.table_name(stream, False),
stage_name=self.get_stage_name(stream),
s3_prefix=s3_prefix,
s3_key=s3_key,
file_format_name=self.connection_config['file_format'],
columns=columns_with_trans
)
self.logger.debug('Running query: %s', copy_sql)
self.logger.info('Running query: %s', copy_sql)
cur.execute(copy_sql)
# Get number of inserted records - COPY does insert only
results = cur.fetchall()
4 changes: 3 additions & 1 deletion target_snowflake/file_formats/csv.py
@@ -11,14 +11,16 @@

def create_copy_sql(table_name: str,
stage_name: str,
s3_prefix: str,
s3_key: str,
file_format_name: str,
columns: List):
"""Generate a CSV compatible snowflake COPY INTO command"""
p_columns = ', '.join([c['name'] for c in columns])

return f"COPY INTO {table_name} ({p_columns}) " \
f"FROM '@{stage_name}/{s3_key}' " \
f"FROM '@{stage_name}/{s3_prefix}' " \
f"PATTERN = '{s3_key}' " \
f"ON_ERROR = CONTINUE " \
f"FILE_FORMAT = (format_name='{file_format_name}')"

