Commit

Merge branch 'snowflake-imports-in-only-after-all-changes-are-read'
pslavov committed Aug 10, 2023
2 parents 0113bb5 + 32f32a4 commit ad3d068
Showing 3 changed files with 27 additions and 46 deletions.
50 changes: 12 additions & 38 deletions target_snowflake/__init__.py
@@ -334,6 +334,15 @@ def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatT
# emit latest state
emit_state(copy.deepcopy(flushed_state))

# After all changes are flushed to S3 files, trigger the import for each stream
for load_stream_name in stream_to_sync:
LOGGER.info("Start snowflake import from S3 for stream %s", load_stream_name)
tmp_db_sync = stream_to_sync[load_stream_name]
s3_key = f".*pipelinewise_{load_stream_name}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"
tmp_db_sync.load_file(config['s3_key_prefix'], s3_key, 0, 0)
## Deleting stage files is replaced by an S3 lifecycle rule
# tmp_db_sync.delete_from_stage(load_stream_name, config['s3_key_prefix'])

# After all changes are imported to snowflake emit state with new flushed_lsn values
if not flushed_state is None and 'bookmarks' in flushed_state:
lsn_list = [get_bookmark(flushed_state, s, 'lsn') for s in flushed_state["bookmarks"] if 'lsn' in flushed_state["bookmarks"][s]]
@@ -476,48 +485,13 @@ def flush_records(stream: str,

# Get file stats
row_count = len(records)
size_bytes = os.path.getsize(filepath)

# Upload to s3 and load into Snowflake
s3_key = db_sync.put_to_stage(filepath, stream, row_count, temp_dir=temp_dir)
db_sync.load_file(s3_key, row_count, size_bytes)
db_sync.put_to_stage(filepath, stream, row_count, temp_dir=temp_dir)

# Delete file from local disk
os.remove(filepath)

if archive_load_files:
stream_name_parts = stream_utils.stream_name_to_dict(stream)
if 'schema_name' not in stream_name_parts or 'table_name' not in stream_name_parts:
raise Exception(f"Failed to extract schema and table names from stream '{stream}'")

archive_schema = stream_name_parts['schema_name']
archive_table = stream_name_parts['table_name']
archive_tap = archive_load_files['tap']

archive_metadata = {
'tap': archive_tap,
'schema': archive_schema,
'table': archive_table,
'archived-by': 'pipelinewise_target_snowflake'
}

if 'column' in archive_load_files:
archive_metadata.update({
'incremental-key': archive_load_files['column'],
'incremental-key-min': str(archive_load_files['min']),
'incremental-key-max': str(archive_load_files['max'])
})

# Use same file name as in import
archive_file = os.path.basename(s3_key)
archive_key = f"{archive_tap}/{archive_table}/{archive_file}"

db_sync.copy_to_archive(s3_key, archive_key, archive_metadata)

# Delete file from S3
db_sync.delete_from_stage(stream, s3_key)


def main():
"""Main function"""
arg_parser = argparse.ArgumentParser()
@@ -527,8 +501,8 @@ def main():
if args.config:
with open(args.config, encoding="utf8") as config_input:
config = json.load(config_input)
date_is = datetime.now().strftime("%Y-%m-%d")
config["s3_key_prefix"] = f"{config['s3_key_prefix']}{date_is}_pid_{str(os.getpid())}/"
timestamp_is = datetime.now().strftime("%Y-%m-%d-%Hh-%Mm")
config["s3_key_prefix"] = f"{config['s3_key_prefix']}{timestamp_is}_pid_{str(os.getpid())}/"
else:
config = {}

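For illustration only (not part of the commit): a minimal sketch of what the per-stream pattern built in persist_lines() is meant to select under the run prefix created in main(). The prefix, stream name, and staged file name below are hypothetical; the actual file names come from put_to_stage(), which this diff does not show.

import re

# Hypothetical run prefix, shaped like the one built in main():
# <configured prefix> + timestamp + "_pid_<pid>/"
s3_key_prefix = "snowflake-imports/2023-08-10-12h-34m_pid_4242/"

# Per-stream pattern as built in persist_lines(); the escaped {{2}} in the
# f-string becomes the regex quantifier {2}.
load_stream_name = "public-users"
s3_key = f".*pipelinewise_{load_stream_name}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"

# A staged file path in the shape the pattern expects (hypothetical example).
staged_file = s3_key_prefix + "pipelinewise_public-users_12h34m56s_part_1.csv"

# Snowflake applies PATTERN as a regular expression against the file path; the
# leading .* keeps it tolerant of the prefix, so re.fullmatch mirrors it here.
assert re.fullmatch(s3_key, staged_file) is not None
print("pattern selects:", staged_file)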
14 changes: 9 additions & 5 deletions target_snowflake/db_sync.py
@@ -436,7 +436,7 @@ def get_stage_name(self, stream):
table_name = self.table_name(stream, False, without_schema=True)
return f"{self.schema_name}.%{table_name}"

def load_file(self, s3_key, count, size_bytes):
def load_file(self, s3_prefix, s3_key, count, size_bytes):
"""Load a supported file type from snowflake stage into target table"""
stream = self.stream_schema_message['stream']
self.logger.info("Loading %d rows into '%s'", count, self.table_name(stream, False))
@@ -459,6 +459,7 @@ def load_file(self, s3_key, count, size_bytes):
try:
inserts, updates = self._load_file_merge(
s3_key=s3_key,
s3_prefix=s3_prefix,
stream=stream,
columns_with_trans=columns_with_trans
)
@@ -474,6 +475,7 @@ def load_file(s3_key, count, size_bytes):
try:
inserts, updates = (
self._load_file_copy(
s3_prefix=s3_prefix,
s3_key=s3_key,
stream=stream,
columns_with_trans=columns_with_trans
@@ -493,7 +495,7 @@ def load_file(s3_key, count, size_bytes):
json.dumps({'inserts': inserts, 'updates': updates, 'size_bytes': size_bytes})
)

def _load_file_merge(self, s3_key, stream, columns_with_trans) -> Tuple[int, int]:
def _load_file_merge(self, s3_key, s3_prefix, stream, columns_with_trans) -> Tuple[int, int]:
# MERGE does insert and update
inserts = 0
updates = 0
@@ -502,12 +504,13 @@ def _load_file_merge(self, s3_key, stream, columns_with_trans) -> Tuple[int, int
merge_sql = self.file_format.formatter.create_merge_sql(
table_name=self.table_name(stream, False),
stage_name=self.get_stage_name(stream),
s3_prefix=s3_prefix,
s3_key=s3_key,
file_format_name=self.connection_config['file_format'],
columns=columns_with_trans,
pk_merge_condition=self.primary_key_merge_condition()
)
self.logger.debug('Running query: %s', merge_sql)
self.logger.info('Running query: %s', merge_sql)
cur.execute(merge_sql)
# Get number of inserted and updated records
results = cur.fetchall()
@@ -516,19 +519,20 @@ def _load_file_merge(self, s3_key, stream, columns_with_trans) -> Tuple[int, int
updates = results[0].get('number of rows updated', 0)
return inserts, updates

def _load_file_copy(self, s3_key, stream, columns_with_trans) -> int:
def _load_file_copy(self, s3_key, s3_prefix, stream, columns_with_trans) -> int:
# COPY does insert only
inserts = 0
with self.open_connection() as connection:
with connection.cursor(snowflake.connector.DictCursor) as cur:
copy_sql = self.file_format.formatter.create_copy_sql(
table_name=self.table_name(stream, False),
stage_name=self.get_stage_name(stream),
s3_prefix=s3_prefix,
s3_key=s3_key,
file_format_name=self.connection_config['file_format'],
columns=columns_with_trans
)
self.logger.debug('Running query: %s', copy_sql)
self.logger.info('Running query: %s', copy_sql)
cur.execute(copy_sql)
# Get number of inserted records - COPY does insert only
results = cur.fetchall()
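For illustration only (not part of the commit): a small self-contained sketch of the changed call shape. flush_records() now only stages files, and persist_lines() later calls load_file() once per stream with a prefix plus a file-name pattern (count and size_bytes are passed as 0 because per-file stats are no longer tracked at import time). The class and values below are stand-ins, not the project's DbSync.

class FakeDbSync:
    """Stand-in mimicking only the two DbSync calls used in the new flow."""

    def put_to_stage(self, filepath, stream, row_count, temp_dir=None):
        # Phase 1: flush_records() uploads each batch file to the stage.
        print(f"staged {filepath} ({row_count} rows) for stream {stream}")

    def load_file(self, s3_prefix, s3_key, count, size_bytes):
        # Phase 2: the new signature takes a prefix and a regex pattern instead
        # of a single key, and runs one COPY/MERGE covering all staged files.
        print(f"import from @stage/{s3_prefix} with PATTERN {s3_key}")


s3_key_prefix = "snowflake-imports/2023-08-10-12h-34m_pid_4242/"  # hypothetical
stream_to_sync = {"public-users": FakeDbSync()}

# Phase 1: stage batches as they are flushed (previously each was loaded immediately).
stream_to_sync["public-users"].put_to_stage("/tmp/batch_0001.csv", "public-users", 1000)
stream_to_sync["public-users"].put_to_stage("/tmp/batch_0002.csv", "public-users", 1000)

# Phase 2: after all changes are read, trigger a single import per stream.
for stream_name, sync in stream_to_sync.items():
    pattern = f".*pipelinewise_{stream_name}_([0-9]{{2}})h([0-9]{{2}})m([0-9]{{2}})s.*[.]csv"
    sync.load_file(s3_key_prefix, pattern, 0, 0)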
9 changes: 6 additions & 3 deletions target_snowflake/file_formats/csv.py
@@ -11,20 +11,23 @@

def create_copy_sql(table_name: str,
stage_name: str,
s3_prefix: str,
s3_key: str,
file_format_name: str,
columns: List):
"""Generate a CSV compatible snowflake COPY INTO command"""
p_columns = ', '.join([c['name'] for c in columns])

return f"COPY INTO {table_name} ({p_columns}) " \
f"FROM '@{stage_name}/{s3_key}' " \
f"FROM '@{stage_name}/{s3_prefix}' " \
f"PATTERN = '{s3_key}' " \
f"ON_ERROR = CONTINUE " \
f"FILE_FORMAT = (format_name='{file_format_name}')"


def create_merge_sql(table_name: str,
stage_name: str,
s3_prefix: str,
s3_key: str,
file_format_name: str,
columns: List,
@@ -37,8 +40,8 @@ def create_merge_sql(table_name: str,

return f"MERGE INTO {table_name} t USING (" \
f"SELECT {p_source_columns} " \
f"FROM '@{stage_name}/{s3_key}' " \
f"(FILE_FORMAT => '{file_format_name}')) s " \
f"FROM '@{stage_name}/{s3_prefix}' " \
f"(FILE_FORMAT => '{file_format_name}'), PATTERN => '{s3_key}') s " \
f"ON {pk_merge_condition} " \
f"WHEN MATCHED THEN UPDATE SET {p_update} " \
"WHEN NOT MATCHED THEN " \
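For illustration only (not part of the commit): the new create_copy_sql reproduced from the diff above, called with hypothetical table, stage, prefix, file format, and column values to show the shape of the generated COPY INTO statement (FROM points at the run prefix, PATTERN selects the stream's files).

from typing import List


def create_copy_sql(table_name: str,
                    stage_name: str,
                    s3_prefix: str,
                    s3_key: str,
                    file_format_name: str,
                    columns: List):
    """Generate a CSV compatible snowflake COPY INTO command"""
    p_columns = ', '.join([c['name'] for c in columns])

    return f"COPY INTO {table_name} ({p_columns}) " \
           f"FROM '@{stage_name}/{s3_prefix}' " \
           f"PATTERN = '{s3_key}' " \
           f"ON_ERROR = CONTINUE " \
           f"FILE_FORMAT = (format_name='{file_format_name}')"


# Hypothetical arguments, only to show the generated statement.
print(create_copy_sql(
    table_name='MY_DB.MY_SCHEMA."USERS"',
    stage_name='MY_SCHEMA.%"USERS"',
    s3_prefix='snowflake-imports/2023-08-10-12h-34m_pid_4242/',
    s3_key='.*pipelinewise_public-users_([0-9]{2})h([0-9]{2})m([0-9]{2})s.*[.]csv',
    file_format_name='MY_SCHEMA.CSV_FORMAT',
    columns=[{'name': '"ID"'}, {'name': '"EMAIL"'}],
))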
