diff --git a/edx/analytics/tasks/common/bigquery_load.py b/edx/analytics/tasks/common/bigquery_load.py
index fd39ec80ec..6c36ab0850 100644
--- a/edx/analytics/tasks/common/bigquery_load.py
+++ b/edx/analytics/tasks/common/bigquery_load.py
@@ -53,9 +53,9 @@ def __init__(self, credentials_target, dataset_id, table, update_id):
         self.update_id = update_id
         with credentials_target.open('r') as credentials_file:
             json_creds = json.load(credentials_file)
-            self.project_id = json_creds['project_id']
-            credentials = service_account.Credentials.from_service_account_info(json_creds)
-            self.client = bigquery.Client(credentials=credentials, project=self.project_id)
+        self.project_id = json_creds['project_id']
+        credentials = service_account.Credentials.from_service_account_info(json_creds)
+        self.client = bigquery.Client(credentials=credentials, project=self.project_id)

     def touch(self):
         self.create_marker_table()
diff --git a/edx/analytics/tasks/common/snowflake_load.py b/edx/analytics/tasks/common/snowflake_load.py
index 1cf1c14d3a..40771df9de 100644
--- a/edx/analytics/tasks/common/snowflake_load.py
+++ b/edx/analytics/tasks/common/snowflake_load.py
@@ -11,6 +11,7 @@
 from snowflake.connector import ProgrammingError

 from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin
+from edx.analytics.tasks.util.s3_util import canonicalize_s3_url
 from edx.analytics.tasks.util.url import ExternalURL

 log = logging.getLogger(__name__)
@@ -90,9 +91,22 @@ def touch(self, connection):
         self.create_marker_table()

         connection.cursor().execute(
-            """INSERT INTO {database}.{schema}.{marker_table} (update_id, target_table)
-            VALUES (%s, %s)""".format(database=self.database, schema=self.schema, marker_table=self.marker_table),
-            (self.update_id, "{database}.{schema}.{table}".format(database=self.database, schema=self.schema, table=self.table))
+            """
+            INSERT INTO {database}.{schema}.{marker_table} (update_id, target_table)
+            VALUES (%s, %s)
+            """.format(
+                database=self.database,
+                schema=self.schema,
+                marker_table=self.marker_table,
+            ),
+            (
+                self.update_id,
+                "{database}.{schema}.{table}".format(
+                    database=self.database,
+                    schema=self.schema,
+                    table=self.table,
+                ),
+            )
         )

         # make sure update is properly marked
@@ -112,8 +126,17 @@ def exists(self, connection=None):
             return False

         cursor = connection.cursor()
-        query = "SELECT 1 FROM {database}.{schema}.{marker_table} WHERE update_id='{update_id}' AND target_table='{database}.{schema}.{table}'".format(
-            database=self.database, schema=self.schema, marker_table=self.marker_table, update_id=self.update_id, table=self.table)
+        query = """
+            SELECT 1
+            FROM {database}.{schema}.{marker_table}
+            WHERE update_id='{update_id}' AND target_table='{database}.{schema}.{table}'
+        """.format(
+            database=self.database,
+            schema=self.schema,
+            marker_table=self.marker_table,
+            update_id=self.update_id,
+            table=self.table,
+        )
         log.debug(query)
         cursor.execute(query)
         row = cursor.fetchone()
@@ -135,8 +158,8 @@ def marker_table_exists(self, connection):
                 schema=self.schema,
             ))
             row = cursor.fetchone()
-        except ProgrammingError as e:
-            if "does not exist" in e.msg:
+        except ProgrammingError as err:
+            if "does not exist" in err.msg:
                 # If so then the query failed because the database or schema doesn't exist.
                 row = None
             else:
@@ -169,8 +192,14 @@ def clear_marker_table(self, connection):
         """
         Delete all markers related to this table update.
""" if self.marker_table_exists(connection): - query = "DELETE FROM {database}.{schema}.{marker_table} where target_table='{database}.{schema}.{table}'".format( - database=self.database, schema=self.schema, marker_table=self.marker_table, table=self.table, + query = """ + DELETE FROM {database}.{schema}.{marker_table} + WHERE target_table='{database}.{schema}.{table}' + """.format( + database=self.database, + schema=self.schema, + marker_table=self.marker_table, + table=self.table, ) connection.cursor().execute(query) @@ -233,20 +262,6 @@ def columns(self): def file_format_name(self): raise NotImplementedError - @property - def field_delimiter(self): - """ - The delimiter in the data to be copied. Default is tab (\t). - """ - return "\t" - - @property - def null_marker(self): - """ - The null sequence in the data to be copied. Default is Hive NULL (\\N). - """ - return r'\\N' - @property def pattern(self): """ @@ -260,7 +275,11 @@ def create_database(self, connection): def create_schema(self, connection): cursor = connection.cursor() - cursor.execute("CREATE SCHEMA IF NOT EXISTS {database}.{schema}".format(database=self.database, schema=self.schema)) + cursor.execute( + "CREATE SCHEMA IF NOT EXISTS {database}.{schema}".format( + database=self.database, schema=self.schema, + ) + ) def create_table(self, connection): coldefs = ','.join( @@ -273,29 +292,15 @@ def create_table(self, connection): def create_format(self, connection): """ - Creates a named file format used for bulk loading data into Snowflake tables. + Invoke Snowflake's CREATE FILE FORMAT statement to create the named file format which + configures the loading. + + The resulting file format name should be: {self.database}.{self.schema}.{self.file_format_name} """ - query = """ - CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name} - TYPE = 'CSV' COMPRESSION = 'AUTO' FIELD_DELIMITER = '{field_delimiter}' - FIELD_OPTIONALLY_ENCLOSED_BY = 'NONE' ERROR_ON_COLUMN_COUNT_MISMATCH = TRUE - EMPTY_FIELD_AS_NULL = FALSE ESCAPE_UNENCLOSED_FIELD = 'NONE' - NULL_IF = ('{null_marker}') - """.format( - database=self.database, - schema=self.schema, - file_format_name=self.file_format_name, - field_delimiter=self.field_delimiter, - null_marker=self.null_marker, - ) - log.debug(query) - connection.cursor().execute(query) + raise NotImplementedError def create_stage(self, connection): - """ - Creates a named external stage to use for loading data into Snowflake. - """ - stage_url = self.input()['insert_source_task'].path + stage_url = canonicalize_s3_url(self.input()['insert_source_task'].path) query = """ CREATE OR REPLACE STAGE {database}.{schema}.{table}_stage URL = '{stage_url}' @@ -377,3 +382,79 @@ def output(self): def update_id(self): return '{task_name}(date={key})'.format(task_name=self.task_family, key=self.date.isoformat()) + + +class SnowflakeLoadCSVTask(SnowflakeLoadTask): # pylint: disable=abstract-method + """ + Abstract Task for loading CSV data from s3 into a table in Snowflake. + + Implementations should define the following properties: + + - self.insert_source_task + - self.table + - self.columns + - self.file_format_name + """ + + @property + def field_delimiter(self): + """ + The delimiter in the data to be copied. Default is tab (\t). + """ + return "\t" + + @property + def null_marker(self): + """ + The null sequence in the data to be copied. Default is Hive NULL (\\N). 
+        """
+        return r'\\N'
+
+    def create_format(self, connection):
+        query = """
+            CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
+            TYPE = 'CSV' COMPRESSION = 'AUTO' FIELD_DELIMITER = '{field_delimiter}'
+            FIELD_OPTIONALLY_ENCLOSED_BY = 'NONE' ERROR_ON_COLUMN_COUNT_MISMATCH = TRUE
+            EMPTY_FIELD_AS_NULL = FALSE ESCAPE_UNENCLOSED_FIELD = 'NONE'
+            NULL_IF = ('{null_marker}')
+        """.format(
+            database=self.database,
+            schema=self.schema,
+            file_format_name=self.file_format_name,
+            field_delimiter=self.field_delimiter,
+            null_marker=self.null_marker,
+        )
+        log.debug(query)
+        connection.cursor().execute(query)
+
+
+class SnowflakeLoadJSONTask(SnowflakeLoadTask):  # pylint: disable=abstract-method
+    """
+    Abstract Task for loading JSON data from S3 into a table in Snowflake. The resulting table will
+    contain a single VARIANT column called raw_json.
+
+    Implementations should define the following properties:
+
+    - self.insert_source_task
+    - self.table
+    - self.file_format_name
+    """
+
+    @property
+    def columns(self):
+        return [
+            ('raw_json', 'VARIANT'),
+        ]
+
+    def create_format(self, connection):
+        query = """
+            CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
+            TYPE = 'JSON'
+            COMPRESSION = 'AUTO'
+        """.format(
+            database=self.database,
+            schema=self.schema,
+            file_format_name=self.file_format_name,
+        )
+        log.debug(query)
+        connection.cursor().execute(query)
diff --git a/edx/analytics/tasks/util/s3_util.py b/edx/analytics/tasks/util/s3_util.py
index bd99b88855..18305e859b 100644
--- a/edx/analytics/tasks/util/s3_util.py
+++ b/edx/analytics/tasks/util/s3_util.py
@@ -5,7 +5,7 @@
 import os
 import time
 from fnmatch import fnmatch
-from urlparse import urlparse
+from urlparse import urlparse, urlunparse

 from luigi.contrib.hdfs.format import Plain
 from luigi.contrib.hdfs.target import HdfsTarget
@@ -125,19 +125,17 @@ def func(name):
 class ScalableS3Client(S3Client):
     """
     S3 client that adds support for defaulting host name.
-    """

-    # TODO: Make this behavior configurable and submit this change upstream.
-    def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, **kwargs):
+    DEPRECATED: Just specify `host` in the `[s3]` configuration section, e.g.:

-        if not aws_access_key_id:
-            aws_access_key_id = self._get_s3_config('aws_access_key_id')
-        if not aws_secret_access_key:
-            aws_secret_access_key = self._get_s3_config('aws_secret_access_key')
-        if 'host' not in kwargs:
-            kwargs['host'] = self._get_s3_config('host') or 's3.amazonaws.com'
+        [s3]
+        host = s3.amazonaws.com

-        super(ScalableS3Client, self).__init__(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, **kwargs)
+    NOTE: In future versions of Luigi, we must NOT pass `host` to the s3 client
+    or else it will throw a KeyError. boto3 will already default to
+    s3.amazonaws.com.
+    """
+    pass


 class S3HdfsTarget(HdfsTarget):
@@ -162,3 +160,31 @@ def open(self, mode='r'):
         if not hasattr(self, 's3_client'):
             self.s3_client = ScalableS3Client()
         return AtomicS3File(safe_path, self.s3_client, policy=DEFAULT_KEY_ACCESS_POLICY)
+
+
+def canonicalize_s3_url(url):
+    """
+    Convert the given S3 URL into a form that is safe to use with external tools.
+
+    Specifically, URL schemes such as "s3+https" are unrecognized by gsutil and Snowflake, and must
+    be converted to "s3".
+
+    Args:
+        url (str): An S3 URL.
+
+    Raises:
+        ValueError: if the scheme of the input url is unrecognized.
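+
+    Example (illustrative; the bucket name is hypothetical):
+
+        canonicalize_s3_url('s3+https://my-bucket/path')  # => 's3://my-bucket/path'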
+    """
+    parsed_url = urlparse(url)
+    if parsed_url.scheme == 's3':
+        canonical_url = url  # Simple passthrough, no change needed.
+    elif parsed_url.scheme == 's3+https':
+        new_url_parts = parsed_url._replace(scheme='s3')
+        canonical_url = urlunparse(new_url_parts)
+    else:
+        raise ValueError('The S3 URL scheme "{}" is unrecognized.'.format(parsed_url.scheme))
+    return canonical_url
diff --git a/edx/analytics/tasks/util/tests/test_s3_util.py b/edx/analytics/tasks/util/tests/test_s3_util.py
index 186dc616ef..9b51d998a3 100644
--- a/edx/analytics/tasks/util/tests/test_s3_util.py
+++ b/edx/analytics/tasks/util/tests/test_s3_util.py
@@ -1,4 +1,7 @@
-"""Tests for S3-related utility functionality."""
+"""
+Tests for S3-related utility functionality.
+"""
+from __future__ import print_function

 from unittest import TestCase

@@ -25,7 +28,7 @@ def _make_s3_generator(self, bucket_name, root, path_info, patterns):
         target_list = [self._make_key("{root}/{path}".format(root=root, path=path), size)
                        for path, size in path_info.iteritems()]
         s3_bucket.list = MagicMock(return_value=target_list)
-        print [(k.key, k.size) for k in target_list]
+        print([(k.key, k.size) for k in target_list])

         s3_bucket.name = bucket_name
         source = "s3://{bucket}/{root}".format(bucket=bucket_name, root=root)
diff --git a/edx/analytics/tasks/warehouse/load_warehouse_snowflake.py b/edx/analytics/tasks/warehouse/load_warehouse_snowflake.py
index 8fb532ee30..eee5e5f516 100644
--- a/edx/analytics/tasks/warehouse/load_warehouse_snowflake.py
+++ b/edx/analytics/tasks/warehouse/load_warehouse_snowflake.py
@@ -4,7 +4,7 @@
 import luigi

 from edx.analytics.tasks.common.pathutil import PathSetTask
-from edx.analytics.tasks.common.snowflake_load import SnowflakeLoadDownstreamMixin, SnowflakeLoadTask
+from edx.analytics.tasks.common.snowflake_load import SnowflakeLoadCSVTask, SnowflakeLoadDownstreamMixin
 from edx.analytics.tasks.insights.enrollments import EnrollmentSummaryRecord
 from edx.analytics.tasks.util.hive import HivePartition, WarehouseMixin
 from edx.analytics.tasks.util.url import ExternalURL, url_path_join
@@ -14,7 +14,7 @@
 from edx.analytics.tasks.warehouse.load_internal_reporting_course_structure import CourseBlockRecord


-class LoadInternalReportingCertificatesToSnowflake(WarehouseMixin, SnowflakeLoadTask):
+class LoadInternalReportingCertificatesToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask):

     @property
     def table(self):
@@ -42,7 +42,7 @@ def insert_source_task(self):
         return ExternalURL(url=self.hive_partition_path('internal_reporting_certificates', self.date))


-class LoadInternalReportingCountryToSnowflake(WarehouseMixin, SnowflakeLoadTask):
+class LoadInternalReportingCountryToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask):
     """
     Loads the country table from Hive into the Vertica data warehouse.
""" @@ -72,7 +72,7 @@ def insert_source_task(self): return ExternalURL(url=self.hive_partition_path('internal_reporting_d_country', self.date)) -class LoadInternalReportingCourseToSnowflake(WarehouseMixin, SnowflakeLoadTask): +class LoadInternalReportingCourseToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask): @property def insert_source_task(self): @@ -92,7 +92,7 @@ def columns(self): return CourseRecord.get_sql_schema() -class LoadInternalReportingCourseSeatToSnowflake(WarehouseMixin, SnowflakeLoadTask): +class LoadInternalReportingCourseSeatToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask): @property def insert_source_task(self): @@ -112,7 +112,7 @@ def columns(self): return CourseSeatRecord.get_sql_schema() -class LoadInternalReportingCourseSubjectToSnowflake(WarehouseMixin, SnowflakeLoadTask): +class LoadInternalReportingCourseSubjectToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask): @property def insert_source_task(self): @@ -132,7 +132,7 @@ def columns(self): return CourseSubjectRecord.get_sql_schema() -class LoadInternalReportingProgramCourseToSnowflake(WarehouseMixin, SnowflakeLoadTask): +class LoadInternalReportingProgramCourseToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask): @property def insert_source_task(self): @@ -177,7 +177,7 @@ def complete(self): return all(r.complete() for r in luigi.task.flatten(self.requires())) -class LoadInternalReportingCourseStructureToSnowflake(WarehouseMixin, SnowflakeLoadTask): +class LoadInternalReportingCourseStructureToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask): def __init__(self, *args, **kwargs): super(LoadInternalReportingCourseStructureToSnowflake, self).__init__(*args, **kwargs) @@ -208,7 +208,7 @@ def columns(self): return CourseBlockRecord.get_sql_schema() -class LoadUserCourseSummaryToSnowflake(WarehouseMixin, SnowflakeLoadTask): +class LoadUserCourseSummaryToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask): @property def insert_source_task(self): @@ -227,7 +227,7 @@ def columns(self): return EnrollmentSummaryRecord.get_sql_schema() -class LoadInternalReportingUserActivityToSnowflake(WarehouseMixin, SnowflakeLoadTask): +class LoadInternalReportingUserActivityToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask): def __init__(self, *args, **kwargs): super(LoadInternalReportingUserActivityToSnowflake, self).__init__(*args, **kwargs) @@ -271,7 +271,7 @@ def columns(self): ] -class LoadInternalReportingUserToSnowflake(WarehouseMixin, SnowflakeLoadTask): +class LoadInternalReportingUserToSnowflake(WarehouseMixin, SnowflakeLoadCSVTask): @property def partition(self):