From 0e4bd45aaac5c2da5aebb35576ba30d6935c9913 Mon Sep 17 00:00:00 2001 From: Julian Steger <108534789+juliansteger-sc@users.noreply.github.com> Date: Fri, 14 Apr 2023 10:46:18 +0200 Subject: [PATCH] fix: allow to set table location when output location is configured but not enforced (#223) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémy Guiselin <9251353+Jrmyy@users.noreply.github.com> --- dbt/adapters/athena/impl.py | 14 +++++++++++--- .../models/table/create_table_as.sql | 4 ++-- tests/unit/test_adapter.py | 17 ++++++++++++----- tests/unit/utils.py | 19 ++++++++++++++++++- 4 files changed, 43 insertions(+), 11 deletions(-) diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index 95410a42..dd51a3ed 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -144,7 +144,7 @@ def add_lf_tags( logger.debug(self.parse_lf_response(response, database, table, columns, {tag_key: tag_value})) @available - def get_work_group_output_location(self) -> Optional[str]: + def is_work_group_output_location_enforced(self) -> bool: conn = self.connections.get_thread_connection() creds = conn.credentials client = conn.handle @@ -154,13 +154,21 @@ def get_work_group_output_location(self) -> Optional[str]: if creds.work_group: work_group = athena_client.get_work_group(WorkGroup=creds.work_group) - return ( + output_location = ( work_group.get("WorkGroup", {}) .get("Configuration", {}) .get("ResultConfiguration", {}) - .get("OutputLocation") + .get("OutputLocation", None) ) + output_location_enforced = ( + work_group.get("WorkGroup", {}).get("Configuration", {}).get("EnforceWorkGroupConfiguration", False) + ) + + return output_location is not None and output_location_enforced + else: + return False + @available def s3_table_prefix(self, s3_data_dir: Optional[str]) -> str: """ diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql index 49b3ce4e..ccfcffb1 100644 --- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -14,7 +14,7 @@ {%- set location_property = 'external_location' -%} {%- set partition_property = 'partitioned_by' -%} - {%- set work_group_output_location = adapter.get_work_group_output_location() -%} + {%- set work_group_output_location_enforced = adapter.is_work_group_output_location_enforced() -%} {%- set location = adapter.s3_table_location(s3_data_dir, s3_data_naming, relation.schema, relation.identifier, external_location, temporary) -%} {%- if materialized == 'table_hive_ha' -%} @@ -48,7 +48,7 @@ with ( table_type='{{ table_type }}', is_external={%- if table_type == 'iceberg' -%}false{%- else -%}true{%- endif %}, - {%- if work_group_output_location is none -%} + {%- if not work_group_output_location_enforced -%} {{ location_property }}='{{ location }}', {%- endif %} {%- if partitioned_by is not none %} diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index d69bdc69..68c592ea 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -723,16 +723,23 @@ def test_upload_seed_to_s3_external_location(self, aws_credentials): @mock_athena def test_get_work_group_output_location(self, aws_credentials): self.adapter.acquire_connection("dummy") - self.mock_aws_service.create_work_group_with_output_location(ATHENA_WORKGROUP) - work_group_location = self.adapter.get_work_group_output_location() - assert work_group_location is not None + self.mock_aws_service.create_work_group_with_output_location_enforced(ATHENA_WORKGROUP) + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert work_group_location_enforced @mock_athena def test_get_work_group_output_location_no_location(self, aws_credentials): self.adapter.acquire_connection("dummy") self.mock_aws_service.create_work_group_no_output_location(ATHENA_WORKGROUP) - work_group_location = self.adapter.get_work_group_output_location() - assert work_group_location is None + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert not work_group_location_enforced + + @mock_athena + def test_get_work_group_output_location_not_enforced(self, aws_credentials): + self.adapter.acquire_connection("dummy") + self.mock_aws_service.create_work_group_with_output_location_not_enforced(ATHENA_WORKGROUP) + work_group_location_enforced = self.adapter.is_work_group_output_location_enforced() + assert not work_group_location_enforced @mock_athena @mock_glue diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 70fa3aac..9607b432 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -291,7 +291,7 @@ def create_table_without_table_type(self, table_name: str): }, ) - def create_work_group_with_output_location(self, work_group_name: str): + def create_work_group_with_output_location_enforced(self, work_group_name: str): athena = boto3.client("athena", region_name=AWS_REGION) athena.create_work_group( Name=work_group_name, @@ -308,6 +308,23 @@ def create_work_group_with_output_location(self, work_group_name: str): }, ) + def create_work_group_with_output_location_not_enforced(self, work_group_name: str): + athena = boto3.client("athena", region_name=AWS_REGION) + athena.create_work_group( + Name=work_group_name, + Configuration={ + "ResultConfiguration": { + "OutputLocation": "s3://pre-configured-output-location/", + }, + "EnforceWorkGroupConfiguration": False, + "PublishCloudWatchMetricsEnabled": True, + "EngineVersion": { + "SelectedEngineVersion": "Athena engine version 2", + "EffectiveEngineVersion": "Athena engine version 2", + }, + }, + ) + def create_work_group_no_output_location(self, work_group_name: str): athena = boto3.client("athena", region_name=AWS_REGION) athena.create_work_group(