diff --git a/odd_collector/adapters/mysql/logger.py b/odd_collector/adapters/mysql/logger.py new file mode 100644 index 00000000..d12412d4 --- /dev/null +++ b/odd_collector/adapters/mysql/logger.py @@ -0,0 +1,3 @@ +from odd_collector_sdk.logger import logger + +logger = logger diff --git a/odd_collector/adapters/mysql/mappers/tables.py b/odd_collector/adapters/mysql/mappers/tables.py index b69a716f..6ac8799b 100644 --- a/odd_collector/adapters/mysql/mappers/tables.py +++ b/odd_collector/adapters/mysql/mappers/tables.py @@ -6,6 +6,7 @@ from odd_collector.models import Table +from ..logger import logger from .columns import map_column from .views import map_view @@ -34,8 +35,8 @@ def map_table(generator: MysqlGenerator, table: Table) -> DataEntity: def map_tables( generator: MysqlGenerator, tables: list[Table], -): - data_entities: list[DataEntity] = [] +) -> list[DataEntity]: + data_entities: dict[str, tuple[Table, DataEntity]] = {} for table in tables: if table.type == "VIEW": @@ -43,8 +44,16 @@ def map_tables( elif table.type == "BASE TABLE": data_entity = map_table(generator, table) else: + logger.warning(f"Can't parse {table.type=}. Available [VIEW, BASE_TABLE]") continue - data_entities.append(data_entity) + data_entities[table.uid] = (table, data_entity) - return data_entities + for table, data_entity in data_entities.values(): + for dependency in table.dependencies: + if dependency.uid in data_entities and data_entity.data_transformer: + data_entity.data_transformer.inputs.append( + data_entities[dependency.uid][1].oddrn + ) + + return [data_entity for _, data_entity in data_entities.values()] diff --git a/odd_collector/adapters/mysql/mappers/views.py b/odd_collector/adapters/mysql/mappers/views.py index 843248a2..f9b18735 100644 --- a/odd_collector/adapters/mysql/mappers/views.py +++ b/odd_collector/adapters/mysql/mappers/views.py @@ -1,12 +1,9 @@ from copy import deepcopy -from typing import Optional from odd_collector_sdk.utils.metadata import DefinitionType, extract_metadata from odd_models.models import DataEntity, DataEntityType, DataSet, DataTransformer -from odd_models.utils import SqlParser from oddrn_generator import MysqlGenerator -from odd_collector.logger import logger from odd_collector.models import Table from .columns import map_column @@ -30,43 +27,7 @@ def map_view(generator: MysqlGenerator, table: Table) -> DataEntity: map_column(generator, column, "views") for column in table.columns ], ), - data_transformer=extract_transformer_data( - sql=table.sql_definition, generator=generator + data_transformer=DataTransformer( + sql=table.sql_definition, inputs=[], outputs=[] ), ) - - -def extract_transformer_data( - generator: MysqlGenerator, sql: Optional[str] = None -) -> DataTransformer: - if not sql: - return DataTransformer(sql=sql, inputs=[], outputs=[]) - - if type(sql) == bytes: - sql = sql.decode("utf-8") - sql_parser = SqlParser(sql.replace("(", "").replace(")", "")) - - try: - inputs, outputs = sql_parser.get_response() - except Exception as e: - logger.warning(f"Couldn't parse inputs and outputs from {sql}") - return DataTransformer(sql=sql, inputs=[], outputs=[]) - - return DataTransformer( - inputs=get_oddrn_list(inputs, generator), - outputs=get_oddrn_list(outputs, generator), - sql=sql, - ) - - -def get_oddrn_list(tables, generator: MysqlGenerator) -> list[str]: - response = [] - generator = deepcopy(generator) - - for table in tables: - source = table.split(".") - table_name = source[1] if len(source) > 1 else source[0] - response.append( - generator.get_oddrn_by_path("tables", table_name.replace("`", "")) - ) - return response diff --git a/odd_collector/models/table.py b/odd_collector/models/table.py index 9f1caa7d..c44063cb 100644 --- a/odd_collector/models/table.py +++ b/odd_collector/models/table.py @@ -1,13 +1,26 @@ import dataclasses +import traceback from typing import Any, Optional from odd_collector_sdk.utils.metadata import HasMetadata +from sql_metadata import Parser from odd_collector.helpers.datetime import Datetime +from ..logger import logger from .column import Column +@dataclasses.dataclass +class Dependency: + name: str + schema: str + + @property + def uid(self) -> str: + return f"{self.schema}.{self.name}" + + @dataclasses.dataclass class Table(HasMetadata): catalog: str @@ -25,3 +38,39 @@ class Table(HasMetadata): @property def odd_metadata(self): return self.metadata + + @property + def uid(self) -> str: + return f"{self.schema}.{self.name}" + + @property + def dependencies(self) -> list[Dependency]: + try: + sql = self.sql_definition + + if not sql: + return [] + + if isinstance(sql, bytes): + sql = sql.decode("utf-8") + + parsed = Parser(sql.replace("(", "").replace(")", "")) + dependencies = [] + + for dependency in parsed.tables: + schema_name = dependency.split(".") + + if len(schema_name) != 2: + logger.warning( + f"Dependency must be in format .. got {dependency=}" + ) + continue + + schema, name = schema_name + dependencies.append(Dependency(name=name, schema=schema)) + return dependencies + except Exception as e: + logger.warning(f"Couldn't parse dependencies from {self.uid}. {e}") + logger.debug(self.sql_definition) + logger.debug(traceback.format_exc()) + return [] diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index ef1544f8..fcedc62c 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -1,4 +1,4 @@ -from funcy import lfilter +from funcy import filter, lfilter from odd_models import DataEntity from odd_models.models import DataEntityList, DataEntityType @@ -10,3 +10,12 @@ def find_by_type( return lfilter( lambda data_entity: data_entity.type == data_entity_type, data_entity_list.items ) + + +def find_by_name(data_entity_list: DataEntityList, name: str) -> DataEntity: + return next( + filter( + lambda data_entity: data_entity.name == name, + data_entity_list.items, + ) + ) diff --git a/tests/integration/test_mysql.py b/tests/integration/test_mysql.py index 8daae553..62c1e4a4 100644 --- a/tests/integration/test_mysql.py +++ b/tests/integration/test_mysql.py @@ -5,7 +5,7 @@ from pydantic import SecretStr from testcontainers.mysql import MySqlContainer -from tests.integration.helpers import find_by_type +from tests.integration.helpers import find_by_name, find_by_type create_tables = """ CREATE TABLE Persons ( @@ -23,6 +23,12 @@ WHERE City = 'Sandnes'; """ +create_view_from_view = """ +CREATE VIEW persons_last_names AS +SELECT LastName +FROM persons_names; +""" + from odd_collector.adapters.mysql.adapter import Adapter from odd_collector.domain.plugin import MySQLPlugin @@ -35,6 +41,7 @@ def test_mysql(): with engine.connect() as connection: connection.exec_driver_sql(create_tables) connection.exec_driver_sql(create_view) + connection.exec_driver_sql(create_view_from_view) config = MySQLPlugin( type="mysql", @@ -52,7 +59,7 @@ def test_mysql(): ) assert len(database_services) == 1 database_service = database_services[0] - assert len(database_service.data_entity_group.entities_list) == 2 + assert len(database_service.data_entity_group.entities_list) == 3 tables = find_by_type(data_entities, DataEntityType.TABLE) assert len(tables) == 1 @@ -60,10 +67,16 @@ def test_mysql(): assert len(table.dataset.field_list) == 5 views = find_by_type(data_entities, DataEntityType.VIEW) - assert len(views) == 1 - view = views[0] - assert len(view.dataset.field_list) == 2 - assert len(view.data_transformer.inputs) == 1 - assert view.data_transformer.inputs[0] == table.oddrn + assert len(views) == 2 + + persons_view = find_by_name(data_entities, "persons_names") + assert len(persons_view.dataset.field_list) == 2 + assert len(persons_view.data_transformer.inputs) == 1 + assert persons_view.data_transformer.inputs[0] == table.oddrn + + last_names_view = find_by_name(data_entities, "persons_last_names") + assert len(last_names_view.dataset.field_list) == 1 + assert len(last_names_view.data_transformer.inputs) == 1 + assert last_names_view.data_transformer.inputs[0] == persons_view.oddrn assert data_entities.json()