Tasks for loading GA data into Snowflake (PART 1)
This is part 1 of the GA loading pipeline which DOES NOT depend on a
Luigi upgrade.

DE-1374 (PART 1)
pwnage101 committed Apr 19, 2019
1 parent 6db66ef commit 35301c3
Showing 5 changed files with 177 additions and 71 deletions.
6 changes: 3 additions & 3 deletions edx/analytics/tasks/common/bigquery_load.py
@@ -53,9 +53,9 @@ def __init__(self, credentials_target, dataset_id, table, update_id):
         self.update_id = update_id
         with credentials_target.open('r') as credentials_file:
             json_creds = json.load(credentials_file)
-        self.project_id = json_creds['project_id']
-        credentials = service_account.Credentials.from_service_account_info(json_creds)
-        self.client = bigquery.Client(credentials=credentials, project=self.project_id)
+            self.project_id = json_creds['project_id']
+            credentials = service_account.Credentials.from_service_account_info(json_creds)
+            self.client = bigquery.Client(credentials=credentials, project=self.project_id)

     def touch(self):
         self.create_marker_table()
169 changes: 125 additions & 44 deletions edx/analytics/tasks/common/snowflake_load.py
@@ -11,6 +11,7 @@
 from snowflake.connector import ProgrammingError
 
 from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin
+from edx.analytics.tasks.util.s3_util import canonicalize_s3_url
 from edx.analytics.tasks.util.url import ExternalURL
 
 log = logging.getLogger(__name__)
@@ -90,9 +91,22 @@ def touch(self, connection):
         self.create_marker_table()
 
         connection.cursor().execute(
-            """INSERT INTO {database}.{schema}.{marker_table} (update_id, target_table)
-               VALUES (%s, %s)""".format(database=self.database, schema=self.schema, marker_table=self.marker_table),
-            (self.update_id, "{database}.{schema}.{table}".format(database=self.database, schema=self.schema, table=self.table))
+            """
+            INSERT INTO {database}.{schema}.{marker_table} (update_id, target_table)
+            VALUES (%s, %s)
+            """.format(
+                database=self.database,
+                schema=self.schema,
+                marker_table=self.marker_table,
+            ),
+            (
+                self.update_id,
+                "{database}.{schema}.{table}".format(
+                    database=self.database,
+                    schema=self.schema,
+                    table=self.table,
+                ),
+            )
         )
 
         # make sure update is properly marked
@@ -112,8 +126,17 @@ def exists(self, connection=None):
             return False
 
         cursor = connection.cursor()
-        query = "SELECT 1 FROM {database}.{schema}.{marker_table} WHERE update_id='{update_id}' AND target_table='{database}.{schema}.{table}'".format(
-            database=self.database, schema=self.schema, marker_table=self.marker_table, update_id=self.update_id, table=self.table)
+        query = """
+        SELECT 1
+        FROM {database}.{schema}.{marker_table}
+        WHERE update_id='{update_id}' AND target_table='{database}.{schema}.{table}'
+        """.format(
+            database=self.database,
+            schema=self.schema,
+            marker_table=self.marker_table,
+            update_id=self.update_id,
+            table=self.table,
+        )
         log.debug(query)
         cursor.execute(query)
         row = cursor.fetchone()
@@ -135,8 +158,8 @@ def marker_table_exists(self, connection):
                 schema=self.schema,
             ))
             row = cursor.fetchone()
-        except ProgrammingError as e:
-            if "does not exist" in e.msg:
+        except ProgrammingError as err:
+            if "does not exist" in err.msg:
                 # If so then the query failed because the database or schema doesn't exist.
                 row = None
             else:
@@ -169,8 +192,14 @@ def clear_marker_table(self, connection):
         Delete all markers related to this table update.
         """
         if self.marker_table_exists(connection):
-            query = "DELETE FROM {database}.{schema}.{marker_table} where target_table='{database}.{schema}.{table}'".format(
-                database=self.database, schema=self.schema, marker_table=self.marker_table, table=self.table,
+            query = """
+            DELETE FROM {database}.{schema}.{marker_table}
+            WHERE target_table='{database}.{schema}.{table}'
+            """.format(
+                database=self.database,
+                schema=self.schema,
+                marker_table=self.marker_table,
+                table=self.table,
             )
             connection.cursor().execute(query)
 
@@ -233,20 +262,6 @@ def columns(self):
     def file_format_name(self):
         raise NotImplementedError
 
-    @property
-    def field_delimiter(self):
-        """
-        The delimiter in the data to be copied. Default is tab (\t).
-        """
-        return "\t"
-
-    @property
-    def null_marker(self):
-        """
-        The null sequence in the data to be copied. Default is Hive NULL (\\N).
-        """
-        return r'\\N'
-
     @property
     def pattern(self):
         """
@@ -260,7 +275,11 @@ def create_database(self, connection):
 
     def create_schema(self, connection):
         cursor = connection.cursor()
-        cursor.execute("CREATE SCHEMA IF NOT EXISTS {database}.{schema}".format(database=self.database, schema=self.schema))
+        cursor.execute(
+            "CREATE SCHEMA IF NOT EXISTS {database}.{schema}".format(
+                database=self.database, schema=self.schema,
+            )
+        )
 
     def create_table(self, connection):
         coldefs = ','.join(
@@ -273,29 +292,15 @@
 
     def create_format(self, connection):
         """
-        Creates a named file format used for bulk loading data into Snowflake tables.
+        Invoke Snowflake's CREATE FILE FORMAT statement to create the named file format which
+        configures the loading.
+        The resulting file format name should be: {self.database}.{self.schema}.{self.file_format_name}
         """
-        query = """
-        CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
-        TYPE = 'CSV' COMPRESSION = 'AUTO' FIELD_DELIMITER = '{field_delimiter}'
-        FIELD_OPTIONALLY_ENCLOSED_BY = 'NONE' ERROR_ON_COLUMN_COUNT_MISMATCH = TRUE
-        EMPTY_FIELD_AS_NULL = FALSE ESCAPE_UNENCLOSED_FIELD = 'NONE'
-        NULL_IF = ('{null_marker}')
-        """.format(
-            database=self.database,
-            schema=self.schema,
-            file_format_name=self.file_format_name,
-            field_delimiter=self.field_delimiter,
-            null_marker=self.null_marker,
-        )
-        log.debug(query)
-        connection.cursor().execute(query)
+        raise NotImplementedError
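Subclasses satisfy this contract by issuing their own CREATE FILE FORMAT statement. As a shape-only sketch (hypothetical; a Parquet variant is not part of this commit — the real overrides are the CSV and JSON ones added below):

```python
# Hypothetical sketch, not part of this commit: the shape of a
# create_format() override, here for a Parquet source.
class SnowflakeLoadParquetTask(SnowflakeLoadTask):  # pylint: disable=abstract-method

    def create_format(self, connection):
        query = """
        CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
        TYPE = 'PARQUET' COMPRESSION = 'AUTO'
        """.format(
            database=self.database,
            schema=self.schema,
            file_format_name=self.file_format_name,
        )
        log.debug(query)
        connection.cursor().execute(query)
```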

     def create_stage(self, connection):
         """
         Creates a named external stage to use for loading data into Snowflake.
         """
-        stage_url = self.input()['insert_source_task'].path
+        stage_url = canonicalize_s3_url(self.input()['insert_source_task'].path)
         query = """
         CREATE OR REPLACE STAGE {database}.{schema}.{table}_stage
         URL = '{stage_url}'
@@ -377,3 +382,79 @@ def output(self):
 
     def update_id(self):
         return '{task_name}(date={key})'.format(task_name=self.task_family, key=self.date.isoformat())
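To make the marker mechanics concrete: update_id() yields the string that SnowflakeTarget.touch() inserts into the marker table and exists() later looks up. For a hypothetical task it would be:

```python
# Hypothetical task name and date, for illustration only.
'ExampleSnowflakeLoadTask(date=2019-04-19)'
```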


+class SnowflakeLoadCSVTask(SnowflakeLoadTask):  # pylint: disable=abstract-method
+    """
+    Abstract Task for loading CSV data from s3 into a table in Snowflake.
+    Implementations should define the following properties:
+    - self.insert_source_task
+    - self.table
+    - self.columns
+    - self.file_format_name
+    """
+
+    @property
+    def field_delimiter(self):
+        """
+        The delimiter in the data to be copied. Default is tab (\t).
+        """
+        return "\t"
+
+    @property
+    def null_marker(self):
+        """
+        The null sequence in the data to be copied. Default is Hive NULL (\\N).
+        """
+        return r'\\N'
+
+    def create_format(self, connection):
+        query = """
+        CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
+        TYPE = 'CSV' COMPRESSION = 'AUTO' FIELD_DELIMITER = '{field_delimiter}'
+        FIELD_OPTIONALLY_ENCLOSED_BY = 'NONE' ERROR_ON_COLUMN_COUNT_MISMATCH = TRUE
+        EMPTY_FIELD_AS_NULL = FALSE ESCAPE_UNENCLOSED_FIELD = 'NONE'
+        NULL_IF = ('{null_marker}')
+        """.format(
+            database=self.database,
+            schema=self.schema,
+            file_format_name=self.file_format_name,
+            field_delimiter=self.field_delimiter,
+            null_marker=self.null_marker,
+        )
+        log.debug(query)
+        connection.cursor().execute(query)
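A concrete subclass only needs to point at its source data and describe the target table. A hypothetical usage sketch (the S3 path, table, and column names are invented, not part of this commit):

```python
# Hypothetical usage sketch -- bucket, table, and columns are invented.
class LoadExampleReportToSnowflake(SnowflakeLoadCSVTask):

    @property
    def insert_source_task(self):
        # An existing tab-delimited extract in S3.
        return ExternalURL(url='s3://example-bucket/example-report/dt=2019-04-19/')

    @property
    def table(self):
        return 'example_report'

    @property
    def columns(self):
        return [
            ('course_id', 'VARCHAR'),
            ('enrollment_count', 'NUMBER'),
        ]

    @property
    def file_format_name(self):
        return 'example_report_csv_format'
```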


+class SnowflakeLoadJSONTask(SnowflakeLoadTask):  # pylint: disable=abstract-method
+    """
+    Abstract Task for loading JSON data from s3 into a table in Snowflake. The resulting table will
+    contain a single VARIANT column called raw_json.
+    Implementations should define the following properties:
+    - self.insert_source_task
+    - self.table
+    - self.file_format_name
+    """
+
+    @property
+    def columns(self):
+        return [
+            ('raw_json', 'VARIANT'),
+        ]
+
+    def create_format(self, connection):
+        query = """
+        CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
+        TYPE = 'JSON'
+        COMPRESSION = 'AUTO'
+        """.format(
+            database=self.database,
+            schema=self.schema,
+            file_format_name=self.file_format_name,
+        )
+        log.debug(query)
+        connection.cursor().execute(query)
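Once loaded, individual fields can be pulled out of the raw_json VARIANT column with Snowflake's colon path syntax. A hypothetical sketch (table and field names invented, and `connection` assumed to be an open snowflake.connector connection as in the methods above):

```python
# Hypothetical query -- the table and JSON field names are invented.
cursor = connection.cursor()
cursor.execute("""
    SELECT raw_json:visitId, raw_json:totals.pageviews
    FROM ga_imports.ga.ga_sessions
    LIMIT 10
""")
for row in cursor.fetchall():
    print(row)
```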
44 changes: 33 additions & 11 deletions edx/analytics/tasks/util/s3_util.py
@@ -5,7 +5,7 @@
 import os
 import time
 from fnmatch import fnmatch
-from urlparse import urlparse
+from urlparse import urlparse, urlunparse
 
 from luigi.contrib.hdfs.format import Plain
 from luigi.contrib.hdfs.target import HdfsTarget
@@ -125,19 +125,17 @@ def func(name):
 class ScalableS3Client(S3Client):
     """
     S3 client that adds support for defaulting host name.
-    """
-    # TODO: Make this behavior configurable and submit this change upstream.
-    def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, **kwargs):
-        if not aws_access_key_id:
-            aws_access_key_id = self._get_s3_config('aws_access_key_id')
-        if not aws_secret_access_key:
-            aws_secret_access_key = self._get_s3_config('aws_secret_access_key')
-        if 'host' not in kwargs:
-            kwargs['host'] = self._get_s3_config('host') or 's3.amazonaws.com'
-        super(ScalableS3Client, self).__init__(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, **kwargs)
+
+    DEPRECATED: Just specify `host` in the `[s3]` configuration section, e.g.:
+
+        [s3]
+        host = s3.amazonaws.com
+
+    NOTE: In future versions of Luigi, we must NOT pass `host` to the s3 client
+    or else it will throw a KeyError. boto3 will already default to
+    s3.amazonaws.com.
+    """
+    pass
 
 
 class S3HdfsTarget(HdfsTarget):
@@ -162,3 +160,27 @@ def open(self, mode='r'):
         if not hasattr(self, 's3_client'):
             self.s3_client = ScalableS3Client()
         return AtomicS3File(safe_path, self.s3_client, policy=DEFAULT_KEY_ACCESS_POLICY)
+
+
+def canonicalize_s3_url(url):
+    """
+    Convert the given s3 URL into a form which is safe to use with external tools.
+    Specifically, URL schemes such as "s3+https" are unrecognized by gsutil and Snowflake, and
+    must be converted to "s3".
+
+    Args:
+        url (str): An s3 URL.
+
+    Raises:
+        ValueError: if the scheme of the input url is unrecognized.
+    """
+    parsed_url = urlparse(url)
+    if parsed_url.scheme == 's3':
+        canonical_url = url  # Simple passthrough, no change needed.
+    elif parsed_url.scheme == 's3+https':
+        new_url_parts = parsed_url._replace(scheme='s3')
+        canonical_url = urlunparse(new_url_parts)
+    else:
+        raise ValueError('The S3 URL scheme "{}" is unrecognized.'.format(parsed_url.scheme))
+    return canonical_url
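The intended behavior, sketched with an invented bucket name:

```python
# Invented bucket and path, for illustration only.
canonicalize_s3_url('s3://example-bucket/data/')        # -> 's3://example-bucket/data/'
canonicalize_s3_url('s3+https://example-bucket/data/')  # -> 's3://example-bucket/data/'
canonicalize_s3_url('https://example-bucket/data/')     # raises ValueError
```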
7 changes: 5 additions & 2 deletions edx/analytics/tasks/util/tests/test_s3_util.py
@@ -1,4 +1,7 @@
-"""Tests for S3-related utility functionality."""
+"""
+Tests for S3-related utility functionality.
+"""
+from __future__ import print_function
 
 from unittest import TestCase
 
@@ -25,7 +28,7 @@ def _make_s3_generator(self, bucket_name, root, path_info, patterns):
         target_list = [self._make_key("{root}/{path}".format(root=root, path=path), size)
                        for path, size in path_info.iteritems()]
         s3_bucket.list = MagicMock(return_value=target_list)
-        print [(k.key, k.size) for k in target_list]
+        print([(k.key, k.size) for k in target_list])
 
         s3_bucket.name = bucket_name
         source = "s3://{bucket}/{root}".format(bucket=bucket_name, root=root)