feat(mode/ingest): Add support for missing Mode datasets in lineage (d…
sagar-salvi-apptware authored Sep 10, 2024
1 parent: f082287 · commit: 852a23b
Showing 9 changed files with 615 additions and 38 deletions.
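In short: Mode reports can import reusable, space-level "datasets" (saved queries); the connector previously dropped these, merely tracking them in the ingestion report, and now emits each one as a first-class DataHub dataset, adds it to its space container, links it from the dashboards that import it, and resolves its warehouse lineage. A minimal sketch of the identity the connector mints per Mode dataset; the id and env values here are illustrative assumptions:

```python
# Mirrors the urn construction this commit adds to construct_dashboard;
# the id 98765 and env "PROD" are assumptions for illustration.
import datahub.emitter.mce_builder as builder

dataset_urn = builder.make_dataset_urn_with_platform_instance(
    platform="mode",
    name=str(98765),  # Mode's numeric dataset id
    platform_instance=None,
    env="PROD",
)
print(dataset_urn)  # urn:li:dataset:(urn:li:dataPlatform:mode,98765,PROD)
```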
@@ -72,6 +72,7 @@ class BIAssetSubTypes(StrEnum):

     # Mode
     MODE_REPORT = "Report"
+    MODE_DATASET = "Dataset"
     MODE_QUERY = "Query"
     MODE_CHART = "Chart"
metadata-ingestion/src/datahub/ingestion/source/mode.py (128 changes: 100 additions & 28 deletions)
@@ -106,7 +106,7 @@
     infer_output_schema,
 )
 from datahub.utilities import config_clean
-from datahub.utilities.lossy_collections import LossyDict, LossyList
+from datahub.utilities.lossy_collections import LossyList

 logger: logging.Logger = logging.getLogger(__name__)
@@ -199,10 +199,6 @@ class ModeSourceReport(StaleEntityRemovalSourceReport):
     num_query_template_render_failures: int = 0
     num_query_template_render_success: int = 0

-    dropped_imported_datasets: LossyDict[str, LossyList[str]] = dataclasses.field(
-        default_factory=LossyDict
-    )
-
     def report_dropped_space(self, ent_name: str) -> None:
         self.filtered_spaces.append(ent_name)
@@ -429,10 +425,25 @@ def construct_dashboard(
         # Last refreshed ts.
         last_refreshed_ts = self._parse_last_run_at(report_info)

+        # Datasets
+        datasets = []
+        for imported_dataset_name in report_info.get("imported_datasets", {}):
+            mode_dataset = self._get_request_json(
+                f"{self.workspace_uri}/reports/{imported_dataset_name.get('token')}"
+            )
+            dataset_urn = builder.make_dataset_urn_with_platform_instance(
+                self.platform,
+                str(mode_dataset.get("id")),
+                platform_instance=None,
+                env=self.config.env,
+            )
+            datasets.append(dataset_urn)
+
         dashboard_info_class = DashboardInfoClass(
             description=description if description else "",
             title=title if title else "",
             charts=self._get_chart_urns(report_token),
+            datasets=datasets if datasets else None,
             lastModified=last_modified,
             lastRefreshed=last_refreshed_ts,
             dashboardUrl=f"{self.config.connect_uri}/{self.config.workspace}/reports/{report_token}",
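The loop above assumes each entry in the report's `imported_datasets` list is an object exposing a `token`, and that fetching `/reports/{token}` returns the dataset's numeric `id` (the `{}` fallback presumably stands in for an empty list). A hedged sketch of those assumed payload fragments, with invented values:

```python
# Assumed shapes only; Mode's actual schema is not documented in this diff.
report_info = {
    "token": "rpt_tok",
    "imported_datasets": [
        {"token": "ds_tok_1", "name": "Daily Orders"},  # hypothetical entry
    ],
}
# GET {workspace_uri}/reports/ds_tok_1 is then assumed to return:
mode_dataset = {"id": 98765, "token": "ds_tok_1", "name": "Daily Orders"}
```

Note this costs one extra API call per imported dataset per report; since these lookups do not appear to be cached, a dataset imported by many dashboards is re-fetched for each of them.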
@@ -725,6 +736,10 @@ def _get_platform_and_dbname(
                 data_source.get("adapter", ""), data_source.get("name", "")
             )
             database = data_source.get("database", "")
+            # This is hacky, but on BigQuery we want to change the database if it's
+            # the default: for lineage we need project_id.db.table.
+            if platform == "bigquery" and database == "default":
+                database = data_source.get("host", "")
             return platform, database
         else:
             self.report.report_warning(
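The comment above is the key constraint: Mode labels a BigQuery connection's logical database `default`, while lineage urns need fully qualified `project_id.dataset.table` names, so the code substitutes the data source's `host` field, which is assumed to carry the GCP project id. A worked example under those assumptions, with invented values:

```python
# Illustrative only: shows why the database swap matters for lineage.
data_source = {
    "adapter": "bigquery",
    "database": "default",     # Mode's logical database name
    "host": "my-gcp-project",  # assumed to hold the GCP project id
}

platform, database = "bigquery", data_source["database"]
if platform == "bigquery" and database == "default":
    database = data_source["host"]

print(f"{database}.analytics.orders")  # my-gcp-project.analytics.orders
```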
@@ -900,24 +915,36 @@ def normalize_mode_query(self, query: str) -> str:

         return rendered_query

-    def construct_query_from_api_data(
+    def construct_query_or_dataset(
         self,
         report_token: str,
         query_data: dict,
         space_token: str,
         report_info: dict,
+        is_mode_dataset: bool,
     ) -> Iterable[MetadataWorkUnit]:
-        query_urn = self.get_dataset_urn_from_query(query_data)
+        query_urn = (
+            self.get_dataset_urn_from_query(query_data)
+            if not is_mode_dataset
+            else self.get_dataset_urn_from_query(report_info)
+        )
+
         query_token = query_data.get("token")

+        externalUrl = (
+            f"{self.config.connect_uri}/{self.config.workspace}/datasets/{report_token}"
+            if is_mode_dataset
+            else f"{self.config.connect_uri}/{self.config.workspace}/reports/{report_token}/details/queries/{query_token}"
+        )
+
         dataset_props = DatasetPropertiesClass(
-            name=query_data.get("name"),
+            name=report_info.get("name") if is_mode_dataset else query_data.get("name"),
             description=f"""### Source Code
 ``` sql
 {query_data.get("raw_query")}
 ```
 """,
-            externalUrl=f"{self.config.connect_uri}/{self.config.workspace}/reports/{report_token}/details/queries/{query_token}",
+            externalUrl=externalUrl,
             customProperties=self.get_custom_props_from_dict(
                 query_data,
                 [
@@ -939,7 +966,22 @@ def construct_query_from_api_data(
             ).as_workunit()
         )

-        subtypes = SubTypesClass(typeNames=([BIAssetSubTypes.MODE_QUERY]))
+        if is_mode_dataset:
+            space_container_key = self.gen_space_key(space_token)
+            yield from add_dataset_to_container(
+                container_key=space_container_key,
+                dataset_urn=query_urn,
+            )
+
+        subtypes = SubTypesClass(
+            typeNames=(
+                [
+                    BIAssetSubTypes.MODE_DATASET
+                    if is_mode_dataset
+                    else BIAssetSubTypes.MODE_QUERY
+                ]
+            )
+        )
         yield (
             MetadataChangeProposalWrapper(
                 entityUrn=query_urn,
@@ -950,15 +992,16 @@ def construct_query_from_api_data(
         yield MetadataChangeProposalWrapper(
             entityUrn=query_urn,
             aspect=BrowsePathsV2Class(
-                path=self._browse_path_query(space_token, report_info)
+                path=self._browse_path_dashboard(space_token)
+                if is_mode_dataset
+                else self._browse_path_query(space_token, report_info)
             ),
         ).as_workunit()

         (
             upstream_warehouse_platform,
             upstream_warehouse_db_name,
         ) = self._get_platform_and_dbname(query_data.get("data_source_id"))
-
         if upstream_warehouse_platform is None:
             # this means we can't infer the platform
             return
@@ -1022,7 +1065,7 @@ def construct_query_from_api_data(
         schema_fields = infer_output_schema(parsed_query_object)
         if schema_fields:
             schema_metadata = SchemaMetadataClass(
-                schemaName="mode_query",
+                schemaName="mode_dataset" if is_mode_dataset else "mode_query",
                 platform=f"urn:li:dataPlatform:{self.platform}",
                 version=0,
                 fields=schema_fields,
@@ -1040,7 +1083,7 @@ def construct_query_from_api_data(
             )

         yield from self.get_upstream_lineage_for_parsed_sql(
-            query_data, parsed_query_object
+            query_urn, query_data, parsed_query_object
         )

         operation = OperationClass(
@@ -1089,10 +1132,9 @@ def construct_query_from_api_data(
         ).as_workunit()

     def get_upstream_lineage_for_parsed_sql(
-        self, query_data: dict, parsed_query_object: SqlParsingResult
+        self, query_urn: str, query_data: dict, parsed_query_object: SqlParsingResult
     ) -> List[MetadataWorkUnit]:
         wu = []
-        query_urn = self.get_dataset_urn_from_query(query_data)

         if parsed_query_object is None:
             logger.info(
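Threading `query_urn` in as a parameter, rather than re-deriving it from `query_data`, is what lets Mode datasets reuse this lineage path: their urn comes from `report_info`, so recomputing it from the query payload would attach lineage to the wrong entity. For orientation, a rough sketch of how an upstream-lineage aspect is typically attached to that urn; the helper below is illustrative, not the connector's actual code:

```python
from typing import List

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.metadata.schema_classes import (
    DatasetLineageTypeClass,
    UpstreamClass,
    UpstreamLineageClass,
)

def lineage_mcp_sketch(
    query_urn: str, upstream_table_urns: List[str]
) -> MetadataChangeProposalWrapper:
    # upstream_table_urns would come from parsed_query_object.in_tables.
    upstreams = [
        UpstreamClass(dataset=urn, type=DatasetLineageTypeClass.TRANSFORMED)
        for urn in upstream_table_urns
    ]
    return MetadataChangeProposalWrapper(
        entityUrn=query_urn,
        aspect=UpstreamLineageClass(upstreams=upstreams),
    )
```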
@@ -1350,6 +1392,24 @@ def _get_reports(self, space_token: str) -> List[dict]:
             )
         return reports

+    @lru_cache(maxsize=None)
+    def _get_datasets(self, space_token: str) -> List[dict]:
+        """
+        Retrieves datasets for a given space token.
+        """
+        datasets = []
+        try:
+            url = f"{self.workspace_uri}/spaces/{space_token}/datasets"
+            datasets_json = self._get_request_json(url)
+            datasets = datasets_json.get("_embedded", {}).get("reports", [])
+        except HTTPError as http_error:
+            self.report.report_failure(
+                title="Failed to Retrieve Datasets for Space",
+                message=f"Unable to retrieve datasets for space token {space_token}.",
+                context=f"Error: {str(http_error)}",
+            )
+        return datasets
+
     @lru_cache(maxsize=None)
     def _get_queries(self, report_token: str) -> list:
         queries = []
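One quirk in the helper above: although the endpoint is `/spaces/{space_token}/datasets`, the embedded collection is read from the `reports` key, so Mode evidently represents datasets as a flavor of report, which also explains why a dataset token can later be fed to `_get_queries` like any report token. A sketch of the assumed response shape, with invented values:

```python
# Assumed JSON for GET {workspace_uri}/spaces/{space_token}/datasets.
datasets_json = {
    "_embedded": {
        "reports": [  # note: datasets arrive under "reports"
            {"token": "ds_tok_1", "name": "Daily Orders", "id": 98765},
        ]
    }
}
datasets = datasets_json.get("_embedded", {}).get("reports", [])
```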
@@ -1523,24 +1583,14 @@ def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
             for report in reports:
                 report_token = report.get("token", "")

-                if report.get("imported_datasets"):
-                    # The connector doesn't support imported datasets yet.
-                    # For now, we just keep this in the report to track what we're missing.
-                    imported_datasets = [
-                        imported_dataset.get("name") or str(imported_dataset)
-                        for imported_dataset in report["imported_datasets"]
-                    ]
-                    self.report.dropped_imported_datasets.setdefault(
-                        report_token, LossyList()
-                    ).extend(imported_datasets)
-
                 queries = self._get_queries(report_token)
                 for query in queries:
-                    query_mcps = self.construct_query_from_api_data(
+                    query_mcps = self.construct_query_or_dataset(
                         report_token,
                         query,
                         space_token=space_token,
                         report_info=report,
+                        is_mode_dataset=False,
                     )
                     chart_fields: Dict[str, SchemaFieldClass] = {}
                     for wu in query_mcps:
@@ -1566,6 +1616,27 @@ def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
                             query_name=query["name"],
                         )

+    def emit_dataset_mces(self):
+        """
+        Emits MetadataChangeEvents (MCEs) for datasets within each space.
+        """
+        for space_token, _ in self.space_tokens.items():
+            datasets = self._get_datasets(space_token)
+
+            for report in datasets:
+                report_token = report.get("token", "")
+                queries = self._get_queries(report_token)
+                for query in queries:
+                    query_mcps = self.construct_query_or_dataset(
+                        report_token,
+                        query,
+                        space_token=space_token,
+                        report_info=report,
+                        is_mode_dataset=True,
+                    )
+                    for wu in query_mcps:
+                        yield wu
+
     @classmethod
     def create(cls, config_dict: dict, ctx: PipelineContext) -> "ModeSource":
         config: ModeConfig = ModeConfig.parse_obj(config_dict)
@@ -1581,6 +1652,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:

     def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
         yield from self.emit_dashboard_mces()
+        yield from self.emit_dataset_mces()
         yield from self.emit_chart_mces()

     def get_report(self) -> SourceReport:
[Diff for the remaining 7 changed files not shown in this view.]