diff --git a/src/sodaspark/scan.py b/src/sodaspark/scan.py
index 79ce6a6..d813d69 100644
--- a/src/sodaspark/scan.py
+++ b/src/sodaspark/scan.py
@@ -9,6 +9,7 @@
 from pyspark.sql import types as T  # noqa: N812
 from sodasql.common.yaml_helper import YamlHelper
 from sodasql.dialects.spark_dialect import SparkDialect
+from sodasql.scan.failed_rows_processor import FailedRowsProcessor
 from sodasql.scan.file_system import FileSystemSingleton
 from sodasql.scan.measurement import Measurement
 from sodasql.scan.scan import Scan
@@ -255,6 +256,7 @@ def create_scan(
     warehouse_name: str = "sodaspark",
     soda_server_client: SodaServerClient | None = None,
     time: str | None = None,
+    failed_rows_processor: FailedRowsProcessor | None = None,
 ) -> Scan:
     """
     Create a scan object.
@@ -263,11 +265,16 @@ def create_scan(
     ----------
     scan_yml : ScanYml
         The scan yml.
-    variables: variables to be substituted in scan yml
+    variables: Optional[dict] (default: None)
+        Variables to be substituted in scan yml
+    warehouse_name: Optional[str] (default: sodaspark)
+        The name of the warehouse
     soda_server_client : Optional[SodaServerClient] (default : None)
         A soda server client.
     time: Optional[str] (default: None)
         Timestamp date in ISO8601 format. If None, use datetime.now() in ISO8601 format.
+    failed_rows_processor: Optional[FailedRowsProcessor] (default: None)
+        A FailedRowsProcessor implementation

     Returns
     -------
@@ -285,6 +292,7 @@ def create_scan(
         soda_server_client=soda_server_client,
         variables=variables,
         time=time,
+        failed_rows_processor=failed_rows_processor,
     )
     return scan

@@ -430,6 +438,7 @@ def execute(
     soda_server_client: SodaServerClient | None = None,
     as_frames: bool | None = False,
     time: str | None = None,
+    failed_rows_processor: FailedRowsProcessor | None = None,
 ) -> ScanResult:
     """
     Execute a scan on a data frame.
@@ -442,12 +451,16 @@ def execute(
         The data frame to be scanned.
     variables: Optional[dict] (default : None)
         Variables to be substituted in scan yml
+    warehouse_name: Optional[str] (default: sodaspark)
+        The name of the warehouse
     soda_server_client : Optional[SodaServerClient] (default : None)
         A soda server client.
     as_frames : bool (default : False)
         Flag to return results in DataFrames
     time: str (default : None)
         Timestamp date in ISO8601 format at the start of a scan
+    failed_rows_processor: Optional[FailedRowsProcessor] (default: None)
+        A FailedRowsProcessor implementation

     Returns
     -------
@@ -463,6 +476,7 @@ def execute(
         soda_server_client=soda_server_client,
         time=time,
         warehouse_name=warehouse_name,
+        failed_rows_processor=failed_rows_processor,
     )

     scan.execute()
diff --git a/tests/test_scan.py b/tests/test_scan.py
index abb17b8..9fd6702 100644
--- a/tests/test_scan.py
+++ b/tests/test_scan.py
@@ -1,15 +1,19 @@
 from __future__ import annotations

+import ast
 import datetime as dt
 import json
 from dataclasses import dataclass
 from typing import BinaryIO

 import pytest
+from _pytest.capture import CaptureFixture
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql import functions as F  # noqa: N812
 from pyspark.sql import types as T  # noqa: N812
+from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 from sodasql.dialects.spark_dialect import SparkDialect
+from sodasql.scan.failed_rows_processor import FailedRowsProcessor
 from sodasql.scan.group_value import GroupValue
 from sodasql.scan.measurement import Measurement
 from sodasql.scan.scan_error import TestExecutionScanError
@@ -183,6 +187,19 @@ def df(spark_session: SparkSession) -> DataFrame:
     return df


+class PrintFailedRowProcessor(FailedRowsProcessor):
+    def process(self, context: dict) -> dict:
+
+        print(context)
+
+        return {"message": "All failed rows were printed in your terminal"}
+
+
+@pytest.fixture
+def failed_rows_processor() -> FailedRowsProcessor:
+    return PrintFailedRowProcessor()
+
+
 def test_create_scan_yml_table_name_is_demodata(
     scan_definition: str,
 ) -> None:
@@ -507,3 +524,59 @@ def test_scan_execute_return_as_data_frame(
         (scan_result[1].count(), len(scan_result[1].columns)),
         (scan_result[2].count(), len(scan_result[2].columns)),
     )
+
+
+def test_failed_rows_processor_return_correct_values(
+    spark_session: SparkSession,
+    failed_rows_processor: FailedRowsProcessor,
+    capsys: CaptureFixture,
+) -> None:
+    """We expect the failed rows to show up in the system output."""
+
+    expected_output = {
+        "sample_name": "missing",
+        "column_name": "number",
+        "test_ids": ['{"column":"number","expression":"missing_count == 0"}'],
+        "sample_columns": [
+            {"name": "id", "type": "string"},
+            {"name": "number", "type": "int"},
+        ],
+        "sample_rows": [["3", None]],
+        "sample_description": "my_table.number.missing",
+        "total_row_count": 1,
+    }
+
+    data = [("1", 100), ("2", 200), ("3", None), ("4", 400)]
+
+    schema = StructType(
+        [
+            StructField("id", StringType(), True),
+            StructField("number", IntegerType(), True),
+        ]
+    )
+
+    df = spark_session.createDataFrame(data=data, schema=schema)
+
+    scan_definition = """
+    table_name: my_table
+    metric_groups:
+    - all
+    samples:
+        failed_limit: 5
+    tests:
+    - row_count > 0
+    columns:
+        number:
+            tests:
+            - duplicate_count == 0
+            - missing_count == 0
+    """
+
+    scan.execute(
+        scan_definition=scan_definition,
+        df=df,
+        failed_rows_processor=failed_rows_processor,
+    )
+
+    out, err = capsys.readouterr()
+    assert expected_output == ast.literal_eval(out)
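
Below is a minimal usage sketch of the new failed_rows_processor hook, separate from the diff itself. It assumes this branch of sodaspark is installed and a local SparkSession is available; the LogFailedRows class, table name, and column names are illustrative only and mirror the test above.

```python
from pyspark.sql import SparkSession
from sodasql.scan.failed_rows_processor import FailedRowsProcessor

from sodaspark import scan


class LogFailedRows(FailedRowsProcessor):
    # Illustrative processor (not part of the PR): process() is called with a
    # context dict describing one failed-rows sample (column name, sample
    # rows, ...) and returns a small dict, e.g. a message for the scan log.
    def process(self, context: dict) -> dict:
        print(context)
        return {"message": "Failed rows were printed to stdout"}


spark = SparkSession.builder.getOrCreate()  # assumed local session
df = spark.createDataFrame(
    [("1", 100), ("2", 200), ("3", None)], ["id", "number"]
)

# Hypothetical table/column names; samples must be enabled for failed rows
# to be collected.
scan_definition = """
table_name: demo_table
metric_groups:
- all
samples:
    failed_limit: 5
columns:
    number:
        tests:
        - missing_count == 0
"""

# failed_rows_processor is the keyword argument added by this PR; rows that
# fail the tests are handed to the processor's process() method.
scan_result = scan.execute(
    scan_definition=scan_definition,
    df=df,
    failed_rows_processor=LogFailedRows(),
)
```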