Skip to content

Commit

Permalink
fix: allow to set table location when output location is configured b…
Browse files Browse the repository at this point in the history
…ut not enforced (#223)

Co-authored-by: Jérémy Guiselin <[email protected]>
  • Loading branch information
juliansteger-sc and Jrmyy authored Apr 14, 2023
1 parent 209720e commit 0e4bd45
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
14 changes: 11 additions & 3 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def add_lf_tags(
logger.debug(self.parse_lf_response(response, database, table, columns, {tag_key: tag_value}))

@available
def get_work_group_output_location(self) -> Optional[str]:
def is_work_group_output_location_enforced(self) -> bool:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
Expand All @@ -154,13 +154,21 @@ def get_work_group_output_location(self) -> Optional[str]:

if creds.work_group:
work_group = athena_client.get_work_group(WorkGroup=creds.work_group)
return (
output_location = (
work_group.get("WorkGroup", {})
.get("Configuration", {})
.get("ResultConfiguration", {})
.get("OutputLocation")
.get("OutputLocation", None)
)

output_location_enforced = (
work_group.get("WorkGroup", {}).get("Configuration", {}).get("EnforceWorkGroupConfiguration", False)
)

return output_location is not None and output_location_enforced
else:
return False

@available
def s3_table_prefix(self, s3_data_dir: Optional[str]) -> str:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

{%- set location_property = 'external_location' -%}
{%- set partition_property = 'partitioned_by' -%}
{%- set work_group_output_location = adapter.get_work_group_output_location() -%}
{%- set work_group_output_location_enforced = adapter.is_work_group_output_location_enforced() -%}
{%- set location = adapter.s3_table_location(s3_data_dir, s3_data_naming, relation.schema, relation.identifier, external_location, temporary) -%}

{%- if materialized == 'table_hive_ha' -%}
Expand Down Expand Up @@ -48,7 +48,7 @@
with (
table_type='{{ table_type }}',
is_external={%- if table_type == 'iceberg' -%}false{%- else -%}true{%- endif %},
{%- if work_group_output_location is none -%}
{%- if not work_group_output_location_enforced -%}
{{ location_property }}='{{ location }}',
{%- endif %}
{%- if partitioned_by is not none %}
Expand Down
17 changes: 12 additions & 5 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,16 +723,23 @@ def test_upload_seed_to_s3_external_location(self, aws_credentials):
@mock_athena
def test_get_work_group_output_location(self, aws_credentials):
self.adapter.acquire_connection("dummy")
self.mock_aws_service.create_work_group_with_output_location(ATHENA_WORKGROUP)
work_group_location = self.adapter.get_work_group_output_location()
assert work_group_location is not None
self.mock_aws_service.create_work_group_with_output_location_enforced(ATHENA_WORKGROUP)
work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
assert work_group_location_enforced

@mock_athena
def test_get_work_group_output_location_no_location(self, aws_credentials):
self.adapter.acquire_connection("dummy")
self.mock_aws_service.create_work_group_no_output_location(ATHENA_WORKGROUP)
work_group_location = self.adapter.get_work_group_output_location()
assert work_group_location is None
work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
assert not work_group_location_enforced

@mock_athena
def test_get_work_group_output_location_not_enforced(self, aws_credentials):
self.adapter.acquire_connection("dummy")
self.mock_aws_service.create_work_group_with_output_location_not_enforced(ATHENA_WORKGROUP)
work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
assert not work_group_location_enforced

@mock_athena
@mock_glue
Expand Down
19 changes: 18 additions & 1 deletion tests/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def create_table_without_table_type(self, table_name: str):
},
)

def create_work_group_with_output_location(self, work_group_name: str):
def create_work_group_with_output_location_enforced(self, work_group_name: str):
athena = boto3.client("athena", region_name=AWS_REGION)
athena.create_work_group(
Name=work_group_name,
Expand All @@ -308,6 +308,23 @@ def create_work_group_with_output_location(self, work_group_name: str):
},
)

def create_work_group_with_output_location_not_enforced(self, work_group_name: str):
athena = boto3.client("athena", region_name=AWS_REGION)
athena.create_work_group(
Name=work_group_name,
Configuration={
"ResultConfiguration": {
"OutputLocation": "s3://pre-configured-output-location/",
},
"EnforceWorkGroupConfiguration": False,
"PublishCloudWatchMetricsEnabled": True,
"EngineVersion": {
"SelectedEngineVersion": "Athena engine version 2",
"EffectiveEngineVersion": "Athena engine version 2",
},
},
)

def create_work_group_no_output_location(self, work_group_name: str):
athena = boto3.client("athena", region_name=AWS_REGION)
athena.create_work_group(
Expand Down

0 comments on commit 0e4bd45

Please sign in to comment.