diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf530c1..6f9c235 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,9 @@
+## v0.8.1 (2023-10-30)
+
+### Fix
+
+- remove local refs
+
 ## v0.8.0 (2023-10-26)
 
 ### Feat
diff --git a/README.md b/README.md
index a5c7e52..54d3ae4 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,9 @@ def foo(df: pb.DataFrame[InputDFSchema]) -> pb.DataFrame[OutputDFSchema]:
     return df
 ```
 
-Now, whenever `foo` is called, you can be sure that the data follows your predefined schemas at input and return. If it does not, an exception will be raised.
+Now, **whenever `foo` is called**, validation is triggered, and you can be sure that the data follows your predefined schemas at input and return. If it does not, an exception is raised.
+
+*This package is heavily inspired by the [`pandera`](https://github.com/unionai-oss/pandera) Python package. Pandera is a fantastic Python library for statistical data testing that offers far more functionality than `pandabear`. Consider this a lighter, `pandas`-only version of `pandera`. If you need a more comprehensive solution that supports backends other than just `pandas` (such as `spark`, `polars`, etc.), we highly recommend you check it out.*
 
 **See package level [README.md](src/pandabear/README.md) for documentation and usage examples**
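To make that behavior concrete, here is a minimal sketch using the `DataFrameModel.validate` entry point exercised by the tests in this change (schema and data are illustrative):

```python
import pandas as pd
from pandabear import DataFrameModel, Field, Index

class MySchema(DataFrameModel):
    index: Index[str]
    a: int = Field()
    b: float = Field()

df = pd.DataFrame(dict(a=[1], b=[1.0]), index=pd.Index(["x"], name="index"))
MySchema.validate(df)                      # passes
MySchema.validate(df.astype({"a": str}))   # raises SchemaValidationError (dtype mismatch)
```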
+ """ + + def __init__(self, check_name: str, check_value: Any, index: Type[pd.Index], result: pd.Series): + self.check_name = check_name + self.check_value = check_value + self.series = index.to_series() + self.result = result + super().__init__(self._get_message()) + + def _get_message(self) -> str: + fail_series = self.series[~self.result] + total = len(self.series) + fails = len(fail_series) + fail_pc = int(round(100 * fails / total)) + check_name = self.check_name.replace("series_", "") + text_msg = ( + f"Column '{self.series.name}' failed check {check_name}({self.check_value}): " + f"{fails} of {total} ({fail_pc} %)" + ) + fails_msg = fail_series.head(MAX_FAILURE_ROWS).to_string() + return f"{text_msg}\n{fails_msg}" diff --git a/src/pandabear/model.py b/src/pandabear/model.py index 959eab2..91715a7 100644 --- a/src/pandabear/model.py +++ b/src/pandabear/model.py @@ -1,6 +1,6 @@ import re -from types import NoneType -from typing import Any, Union +from types import NoneType, UnionType +from typing import Any, Type, Union import numpy as np import pandas as pd @@ -9,23 +9,21 @@ from pandabear.exceptions import ( CoersionError, ColumnCheckError, + IndexCheckError, MissingColumnsError, MissingIndexError, SchemaDefinitionError, SchemaValidationError, + UnsupportedTypeError, ) from pandabear.model_components import ( BaseConfig, Field, FieldInfo, - Index, get_index_type, is_type_index, ) - -TYPE_DTYPE_MAP = { - str: np.dtype("O"), -} +from pandabear.type_checking import is_of_type # @dataclasses.dataclass @@ -51,37 +49,49 @@ def _get_config(cls): return BaseConfig._override(cls.Config) @classmethod - def _validate_series(cls, se: pd.Series, field: Field, typ: Any, coerce: bool) -> pd.Series: + def _validate_series_or_index( + cls, se_or_idx: pd.Series | Type[pd.Index], field: Field, typ: Any, coerce: bool + ) -> pd.Series: """Validate a series against a field and type. Args: - se (pd.Series): The series to validate. - field (Field): The field to validate against. - typ (Any): The type to validate against. - coerce (bool): Whether to coerce the series to the type of the + se_or_idx: The series or Index (or Index substype) to validate. + field: The field to validate against. + typ: The type to validate against. + coerce: Whether to coerce the series to the type of the field. Returns: - pd.Series: The validated series. + se_or_idx: pd.Series | Type[pd.Index]: The validated series or Index. 
""" - dtype = TYPE_DTYPE_MAP.get(typ, typ) - if se.dtype != dtype: + is_index = ~isinstance(se_or_idx, pd.Series) + + if not is_of_type(se_or_idx, typ): if coerce: try: - se = se.astype(typ) + se_or_idx = se_or_idx.astype(typ) except ValueError: - raise CoersionError(f"Could not coerce `{se.name}` with dtype {se.dtype} to {dtype}") + raise CoersionError(f"Could not coerce `{se_or_idx.name}` with dtype {se_or_idx.dtype} to {typ}") else: - raise SchemaValidationError(f"Expected `{se.name}` with dtype {dtype} but found {se.dtype}") + raise SchemaValidationError( + f"Expected {f'`{se_or_idx.name}`' if se_or_idx.name else 'index'} with dtype {typ} but found dtype `{se_or_idx.dtype}`" + ) for check_name, check_func in CHECK_NAME_FUNCTION_MAP.items(): check_value = getattr(field, check_name) if check_value is not None: - result = check_func(series=se, value=check_value) + result = check_func(series=se_or_idx if is_index else se_or_idx.to_series(), value=check_value) if not result.all(): - raise ColumnCheckError(check_name=check_name, check_value=check_value, series=se, result=result) - return se + if is_index: + raise ColumnCheckError( + check_name=check_name, check_value=check_value, series=se_or_idx, result=result + ) + else: + raise IndexCheckError( + check_name=check_name, check_value=check_value, index=se_or_idx, result=result + ) + return se_or_idx class DataFrameModel(BaseModel): @@ -111,6 +121,7 @@ def _get_schema_map(cls) -> dict[str, FieldInfo]: """ schema_map = {} for name, typ in cls.__annotations__.items(): + cls._check_type_is_valid(typ) typ, optional = cls._check_optional_type(typ) is_index = is_type_index(typ, name, cls.__name__) if is_index: @@ -131,31 +142,40 @@ def _check_optional_type(typ: type) -> tuple[type, bool]: return typ, optional @staticmethod - def override_level( - index: pd.MultiIndex | pd.Index, index_level: str, series: pd.Series + def _override_level( + df_index: Type[pd.Index], index_level: str, new_index_values: Type[pd.Index] ) -> pd.MultiIndex | pd.Index: - """Override a level in a MultiIndex or Index with a new series.""" - if isinstance(index, pd.MultiIndex): - df_reset = index.to_frame(index=False) - if index_level not in df_reset.columns: + """Override a level in a MultiIndex or Index with a new index.""" + if isinstance(df_index, pd.MultiIndex): + df_tmp = df_index.to_frame(index=False) + if index_level not in df_tmp.columns: raise ValueError(f"Index level '{index_level}' not found in MultiIndex.") - df_reset[index_level] = series.values - new_index = pd.MultiIndex.from_frame(df_reset) + df_tmp[index_level] = new_index_values + return pd.MultiIndex.from_frame(df_tmp) else: - if index.name != index_level: - raise ValueError(f"Index name '{index.name}' does not match given index_level '{index_level}'.") - new_index = pd.Index(series.values, name=index_level) - return new_index + if df_index.name != index_level: + raise ValueError(f"Index name '{df_index.name}' does not match given index_level '{index_level}'.") + index_type_map = { + pd.Index: pd.Index, + pd.DatetimeIndex: pd.DatetimeIndex, + pd.PeriodIndex: pd.PeriodIndex, + pd.TimedeltaIndex: pd.TimedeltaIndex, + pd.CategoricalIndex: pd.CategoricalIndex, + pd.RangeIndex: pd.RangeIndex, + pd.IntervalIndex: pd.IntervalIndex, + } + index_type = index_type_map.get(type(df_index)) + return index_type(new_index_values, name=index_level) @staticmethod - def _select_index_series(df: pd.DataFrame, level: str, optional: bool = True) -> list[pd.Series]: + def _select_index_series(df: pd.DataFrame, level: str, 
 
     @staticmethod
-    def _select_index_series(df: pd.DataFrame, level: str, optional: bool = True) -> list[pd.Series]:
+    def _select_index_series(df: pd.DataFrame, level: str, optional: bool = True) -> list[pd.Index]:
         """Select a series from a dataframe by column name.
 
         Return a list containing maximally 1 series. Reason for this is that
         series are validated in a loop, so returning a list is convenient.
         """
         try:
-            return [df.index.get_level_values(level).to_series()]
+            return [df.index.get_level_values(level)]
         except KeyError:
             # When this happens we can deduce that the corresponding column is
             # optional (otherwise an error would have been raised in
@@ -184,19 +204,31 @@ def _select_series(df: pd.DataFrame, column_name: str, optional: bool = True) ->
             return []
 
     @staticmethod
-    def _select_index_series_by_regex(df: pd.DataFrame, alias: str) -> list[pd.Series]:
+    def _select_index_series_by_regex(df: pd.DataFrame, alias: str) -> list[pd.Index]:
         """Select a series from a dataframe by regex."""
-        return [
-            df.index.get_level_values(level).to_series()
-            for level in df.index.names
-            if re.match(alias, level) is not None
-        ]
+        return [df.index.get_level_values(level) for level in df.index.names if re.match(alias, level) is not None]
 
     @staticmethod
     def _select_series_by_regex(df: pd.DataFrame, alias: str) -> list[pd.Series]:
         """Select a series from a dataframe by regex."""
         return [df[col] for col in df.filter(regex=alias, axis=1).columns]
 
+    @classmethod
+    def _check_type_is_valid(cls, typ: Any) -> bool:
+        """Recursively check that `typ` is a valid type annotation.
+
+        Raises UnsupportedTypeError for anything else, e.g. dtype instances.
+        """
+        if typ in [int, float, str, bytes, bool, type(None)]:
+            return True
+        if isinstance(typ, type):
+            return True
+        if hasattr(typ, "__origin__") and hasattr(typ, "__args__"):
+            origin = typ.__origin__
+            args = typ.__args__
+            if origin in {list, dict, Union}:
+                return all(cls._check_type_is_valid(arg) for arg in args)
+        if isinstance(typ, UnionType):
+            return all(cls._check_type_is_valid(arg) for arg in typ.__args__)
+        raise UnsupportedTypeError(f"Type `{typ}` is not supported")
+
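To make the accepted surface concrete, a small sketch (this mirrors the new tests at the bottom of this change):

```python
import numpy as np
from pandabear import DataFrameModel

DataFrameModel._check_type_is_valid(int)         # True: plain class
DataFrameModel._check_type_is_valid(int | None)  # True: union of valid members
DataFrameModel._check_type_is_valid(list[str])   # True: generic over a valid arg
DataFrameModel._check_type_is_valid(np.dtype("datetime64[ns]"))  # raises UnsupportedTypeError
                                                 # (a dtype *instance*, not a type)
```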
     @classmethod
     def _validate_custom_checks(cls, df: pd.DataFrame):
         """Validate custom checks defined on the schema.
@@ -266,9 +298,7 @@ def _validate_schema(cls, schema_map: dict[str, FieldInfo]):
         )
 
     @classmethod
-    def _validate_multiindex(
-        cls, df: pd.DataFrame, schema_map: dict[str, FieldInfo], Config: BaseConfig
-    ) -> pd.DataFrame:
+    def _validate_multiindex(cls, df: pd.DataFrame) -> pd.DataFrame:
         """Validate index levels in `df` against the schema.
 
         Raise approproate errors if index levels are missing or if there is an
@@ -278,11 +308,9 @@
         pass `df` through with coerced types, filtered index levels, ordered
         index levels or as-is.
         """
-        matching_index_names_in_df, schema_map = cls._select_matching_names(
-            list(df.index.names), schema_map, match_index=True
-        )
+        matching_index_names_in_df = cls._select_matching_names(list(df.index.names), match_index=True)
 
-        if Config.filter:
+        if cls.Config.filter:
             # Make sure that only the matching index levels are kept
             if df.index.names != [None] and len(matching_index_names_in_df) < len(df.index.names):
                 if len(matching_index_names_in_df) == 0:
@@ -290,25 +318,25 @@
                 else:
                     df = df.droplevel([ind for ind in df.index.names if ind not in matching_index_names_in_df])
 
-        if Config.multiindex_strict:
+        if cls.Config.multiindex_strict:
             if unexpected_indices := set(df.index.names) - set(matching_index_names_in_df) - set([None]):
                 raise SchemaValidationError(
                     f"MultiIndex names {unexpected_indices} are present in `df` but not defined in schema. Use `multiindex_strict=False` to supress this error."
                 )
 
-        if Config.multiindex_ordered:
+        if cls.Config.multiindex_ordered:
             if df.index.names != [None] and matching_index_names_in_df != list(df.index.names):
                 raise SchemaValidationError(
                     "MultiIndex names in `df` are not ordered as in schema. Use `multiindex_ordered=False` to supress this error."
                 )
 
-        if Config.multiindex_sorted:
+        if cls.Config.multiindex_sorted:
             if not (df.index.is_monotonic_increasing or df.index.is_monotonic_decreasing):
                 raise SchemaValidationError(
                     "MultiIndex is not sorted. Use `multiindex_sorted=False` to supress this error."
                 )
 
-        if Config.multiindex_unique:
+        if cls.Config.multiindex_unique:
             if not df.index.is_unique:
                 raise SchemaValidationError(
                     "MultiIndex is not unique. Use `multiindex_unique=False` to supress this error."
@@ -317,7 +345,7 @@
         return df
 
     @classmethod
-    def _validate_columns(cls, df: pd.DataFrame, schema_map: dict[str, FieldInfo], Config: BaseConfig) -> pd.DataFrame:
+    def _validate_columns(cls, df: pd.DataFrame) -> pd.DataFrame:
         """Validate column names in `df` against the schema.
 
         Raise approproate errors if columns are missing or if there is an
@@ -327,15 +355,15 @@
         pass `df` through with coerced types, filtered columns, ordered
         columns or as-is.
         """
-        matching_columns_in_df, _ = cls._select_matching_names(list(df.columns), schema_map)
+        matching_columns_in_df = cls._select_matching_names(list(df.columns))
 
         # Drop columns in `df` that do not match the schema
-        if Config.filter:
+        if cls.Config.filter:
             ordered_columns_in_df = [col for col in df.columns if col in matching_columns_in_df]
             df = df[ordered_columns_in_df].copy()
 
         # Complain about columns in `df` that are not defined in the schema
-        elif Config.strict:
+        elif cls.Config.strict:
             if unexpected_columns := set(df.columns) - set(matching_columns_in_df):
                 raise SchemaValidationError(
                     f"Columns {unexpected_columns} are present in `df` but not in schema. Use `strict=False` or `filter=True` to supress this error."
@@ -343,7 +371,7 @@
 
         # Complain if the order of columns in `df` does not match the order in
         # which they are defined in the schema
-        if Config.ordered:
+        if cls.Config.ordered:
             if matching_columns_in_df != list(df.columns):
                 raise SchemaValidationError(
                     "Columns in `df` are not ordered as in schema. Use `ordered=False` to supress this error."
@@ -353,7 +381,7 @@
 
     @classmethod
     def _select_matching_names(
-        cls, names: list[str], schema_map: dict[str, FieldInfo], match_index: bool = False
-    ) -> tuple[list[str], dict[str, FieldInfo]]:
+        cls, names: list[str], match_index: bool = False
+    ) -> list[str]:
         """Select columns or index levels in `names` that match the schema.
 
@@ -362,6 +390,11 @@
         errors when columns in `df` seem to be *missing* when compared to the
         schema.
 
+        NOTE: This method may modify the schema map. This happens when an Index
+        field is defined with `check_index_name = False`. In that case this
+        method will replace the index name in the schema map with the name of
+        the index in `df`.
+
         Raises:
             SchemaDefinitionError: If a column or alias is not found in `df`,
                 is already matched by another field.
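As a usage sketch of the `Config` switches handled above (illustrative schema; mirrors the pattern in `tests/test_validate_multiindex.py` below):

```python
import pandas as pd
from pandabear import DataFrameModel, Field, Index

class SortedIndexSchema(DataFrameModel):
    ix0: Index[int]
    a: int = Field()

class SortedConfig:
    multiindex_sorted: bool = True

SortedIndexSchema.Config = SortedConfig

# Index [1, 3, 2] is neither increasing nor decreasing, so validation fails:
df = pd.DataFrame(dict(a=[1, 2, 3]), index=pd.Index([1, 3, 2], name="ix0"))
SortedIndexSchema.validate(df)  # raises SchemaValidationError: "MultiIndex is not sorted. ..."
```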
@@ -373,7 +406,7 @@
         MissingNameError = MissingIndexError if match_index else MissingColumnsError
         series_type = "index level" if match_index else "column"
         matching_names = []
-        for series_name, (_, optional, is_index, field) in schema_map.items():
+        for series_name, (_, optional, is_index, field) in cls.schema_map.items():
             if is_index and not match_index:
                 continue
             elif not is_index and match_index:
@@ -410,11 +443,11 @@
                     )
             elif field.check_index_name is False and match_index:
                 assert len(names) == 1, "This should not happen. Looks like columns were not properly filtered."
-                schema_map[names[0]] = schema_map.pop(series_name)
-                return names, schema_map
+                cls.schema_map[names[0]] = cls.schema_map.pop(series_name)
+                return names
             else:
                 matching_names.append(series_name)
-        return matching_names, schema_map
+        return matching_names
 
     @classmethod
     def validate(cls, df: pd.DataFrame) -> pd.DataFrame:
@@ -444,50 +477,52 @@
         """
         df = df.copy()
 
-        schema_map = cls._get_schema_map()
-        Config = cls._get_config()
+        cls.schema_map = cls._get_schema_map()
+        cls.Config = cls._get_config()
 
         # Validate schema definition. Catch errors like, e.g., missing aliases
         # when regex=True, number checks on non-numeric columns, etc.
-        cls._validate_schema(schema_map)
+        cls._validate_schema(cls.schema_map)
 
         # Check that indices and columns in `df` match schema. The only errors
         # that should be thrown here relate to schema errors or missing columns
         # in `df`. Furthermore, this method may filter, coerce or order `df`
         # depending on user-provided specifications in `Config`.
-        df = cls._validate_multiindex(df, schema_map, Config)
-        df = cls._validate_columns(df, schema_map, Config)
+        df = cls._validate_multiindex(df)
+        df = cls._validate_columns(df)
 
         # Validate `df` against schema. The only errors that should be raised
         # in this step are from dtype checks and `Field` checks.
-        for name, (typ, optional, is_index, field) in schema_map.items():
+        for name, (typ, optional, is_index, field) in cls.schema_map.items():
 
             # Select the column (or columns) in `df` that match the field.
             # ... when index column
             if is_index:
                 if field.regex and field.alias is not None:
-                    matched_series = cls._select_index_series_by_regex(df, field.alias)
+                    matched_series_or_index = cls._select_index_series_by_regex(df, field.alias)
                 else:
-                    matched_series = cls._select_index_series(df, field.alias or name, optional)
+                    matched_series_or_index = cls._select_index_series(df, field.alias or name, optional)
 
             # ... when column has aliased name
             elif field.alias is not None:
                 if field.regex:
-                    matched_series = cls._select_series_by_regex(df, field.alias)
+                    matched_series_or_index = cls._select_series_by_regex(df, field.alias)
                 else:
-                    matched_series = cls._select_series(df, field.alias, optional)
+                    matched_series_or_index = cls._select_series(df, field.alias, optional)
 
             # ... when column name is attribute name (not alias)
             else:
-                matched_series = cls._select_series(df, name, optional)
+                matched_series_or_index = cls._select_series(df, name, optional)
 
             # Validate the selected column(s) against the field and type.
-            for series in matched_series:
-                series = cls._validate_series(series, field, typ, Config.coerce or field.coerce)
-                if Config.coerce or field.coerce:
+            for series_or_index in matched_series_or_index:
+                series_or_index = cls._validate_series_or_index(
+                    series_or_index, field, typ, cls.Config.coerce or field.coerce
+                )
+                if cls.Config.coerce or field.coerce:
                     if is_index:
-                        df.index = cls.override_level(df.index, series.name, series)
+                        df.index = cls._override_level(df.index, series_or_index.name, series_or_index.values)
                     else:
-                        df[series.name] = series
+                        df[series_or_index.name] = series_or_index
 
         cls._validate_custom_checks(df)
 
@@ -517,5 +552,5 @@ def validate(cls, series: pd.Series):
         _, value_type = cls._get_value_name_and_type()
         field = cls._get_field()
         Config = cls._get_config()
-        series = cls._validate_series(series, field, value_type, Config.coerce)
+        series = cls._validate_series_or_index(series, field, value_type, Config.coerce)
         return series
diff --git a/src/pandabear/model_components.py b/src/pandabear/model_components.py
index 0bbc9b8..bd6849e 100644
--- a/src/pandabear/model_components.py
+++ b/src/pandabear/model_components.py
@@ -1,9 +1,9 @@
 import dataclasses
-from typing import Any, NamedTuple, Type
+from typing import Any, NamedTuple, Type, Union
 
 import pandas as pd
 
-from pandabear.exceptions import SchemaDefinitionError
+from pandabear.exceptions import SchemaDefinitionError, UnsupportedTypeError
 
 PANDAS_INDEX_TYPES = [
     # pd.DatetimeIndex
@@ -115,19 +115,28 @@ class FieldInfo(NamedTuple):
 
 
 def is_type_index_wrapped(typ):
+    """Check whether type annotation is like `Index[<type>]`."""
     return hasattr(typ, "__args__") and typ.__args__[0] is Index
 
 
 def is_type_index(typ, name, class_name):
+    # Field is like `index: Index` (not allowed)
     if typ is Index:
         raise SchemaDefinitionError(
             f"Index column `{name}` in schema `{class_name}` must be defined as `Index[<type>]`"
         )
-    return is_type_index_wrapped(typ) or typ in PANDAS_INDEX_TYPES
+    # Field is like `index: Index[int]`
+    if is_type_index_wrapped(typ):
+        return True
+    # Field is like `index: pd.DatetimeIndex` or another subclass of `pd.Index`
+    if isinstance(typ, type) and issubclass(typ, pd.Index):
+        return True
+    # The user has provided something that is not meant as an index
+    return False
 
 
-def get_index_type(typ):
+def get_index_type(typ):
     if is_type_index_wrapped(typ):
         return typ.__args__[1]
     return typ
diff --git a/src/pandabear/type_checking.py b/src/pandabear/type_checking.py
index bd9a6f8..44e7168 100644
--- a/src/pandabear/type_checking.py
+++ b/src/pandabear/type_checking.py
@@ -12,6 +12,10 @@ def check_isinstance(series_or_index, typ):
     return isinstance(series_or_index, typ)
 
 
+def check_type_is(series_or_index, typ):
+    return type(series_or_index) is typ
+
+
 def check_str_object(series_or_index, typ):
     return check_dtype_equality(series_or_index, np.dtype("O"))
 
@@ -37,8 +41,10 @@ def check_bare_categorical_dtype(series_or_index, typ):
     str: check_str_object,
     np.datetime64: check_datetime64,
     datetime.datetime: check_datetime64,
+    pd.CategoricalIndex: check_isinstance,
     pd.CategoricalDtype: check_bare_categorical_dtype,
-    pd.DatetimeIndex: check_isinstance,
+    pd.DatetimeIndex: check_type_is,
+    pd.Index: check_type_is,
 }
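The switch from `check_isinstance` to `check_type_is` for `pd.DatetimeIndex` and bare `pd.Index` matters because every index subtype is an instance of `pd.Index`; a quick illustration:

```python
import pandas as pd

idx = pd.DatetimeIndex(["2020-01-01"])
isinstance(idx, pd.Index)       # True  — isinstance also matches subclasses
type(idx) is pd.Index           # False — check_type_is demands the exact type
type(idx) is pd.DatetimeIndex   # True
```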
diff --git a/static/images/coverage-badge.svg b/static/images/coverage-badge.svg
index 3438732..59d64b3 100644
--- a/static/images/coverage-badge.svg
+++ b/static/images/coverage-badge.svg
@@ -9,13 +9,13 @@
 [SVG text markup lost in extraction; the badge label stays "coverage" and the displayed value changes from 97% to 93%]
diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py
new file mode 100644
index 0000000..2518778
--- /dev/null
+++ b/tests/test_dtypes.py
@@ -0,0 +1,57 @@
+import datetime
+
+import numpy as np
+import pandas as pd
+
+from pandabear import DataFrameModel, Field, Index
+
+
+def test_datetime():
+    class MySchema(DataFrameModel):
+        date: Index[pd.DatetimeIndex]
+        column_a: int
+
+    df = pd.DataFrame(data=[200], columns=["column_a"], index=pd.DatetimeIndex(["2020-01-01"], name="date"))
+    MySchema.validate(df)
+
+    class MySchema(DataFrameModel):
+        date: pd.DatetimeIndex
+        column_a: int
+
+    df = pd.DataFrame(data=[200], columns=["column_a"], index=pd.DatetimeIndex(["2020-01-01"], name="date"))
+    MySchema.validate(df)
+
+    class MySchema(DataFrameModel):
+        index_date_datetime64: Index[np.datetime64]
+        index_date_datetime: Index[datetime.datetime]
+        date_datetime64: np.datetime64
+        date_datetime: datetime.datetime
+
+    df = pd.DataFrame(
+        dict(
+            date_datetime64=pd.to_datetime(["2020-01-01"]),
+            date_datetime=pd.to_datetime(["2020-01-01"]),
+        ),
+        index=pd.MultiIndex.from_arrays(
+            [
+                pd.to_datetime(["2020-01-01"]),
+                pd.to_datetime(["2020-01-01"]),
+            ],
+            names=["index_date_datetime64", "index_date_datetime"],
+        ),
+    )
+    MySchema.validate(df)
+
+
+def test_categorical_index():
+    class GenericCategorySchema(DataFrameModel):
+        category: pd.CategoricalIndex = Field()
+        column_a: int
+
+    df = pd.DataFrame(
+        data=[200, 300, 400, 500],
+        columns=["column_a"],
+        index=pd.CategoricalIndex(list("aabb"), categories=["a", "b"], ordered=True, name="category"),
+    )
+
+    GenericCategorySchema.validate(df)
diff --git a/tests/test_unsupported_types.py b/tests/test_unsupported_types.py
new file mode 100644
index 0000000..7bdab2f
--- /dev/null
+++ b/tests/test_unsupported_types.py
@@ -0,0 +1,43 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from pandabear import DataFrameModel, Field, Index
+from pandabear.exceptions import UnsupportedTypeError
+
+
+def test_datetime():
+    class MySchema(DataFrameModel):
+        date_dtype: np.dtype("datetime64[ns]")
+
+    df = pd.DataFrame(
+        dict(
+            date_dtype=pd.to_datetime(["2020-01-01"]),
+        )
+    )
+
+    with pytest.raises(UnsupportedTypeError):
+        MySchema.validate(df)
+
+
+def test_categorical():
+    CategoryABType = pd.CategoricalDtype(["a", "b"], ordered=True)
+
+    class SpecificCategorySchema(DataFrameModel):
+        category: CategoryABType
+
+    df = pd.DataFrame({"category": ["a", "a", "b", "b"]})
+
+    with pytest.raises(UnsupportedTypeError):
+        SpecificCategorySchema.validate(df)
+
+
+def test_obvious_mistakes():
+    class MySchema(DataFrameModel):
+        index: [0, 1, 2]
+        column_a: "lol"
+
+    df = pd.DataFrame({"index": [0, 1, 2, 2, 1, 2], "column_a": ["lol", "lol", "lol", "lol", "lol", "lol"]})
+
+    with pytest.raises(UnsupportedTypeError):
+        MySchema.validate(df)
diff --git a/tests/test_validate_columns.py b/tests/test_validate_columns.py
index 75d6cdc..7a46f06 100644
--- a/tests/test_validate_columns.py
+++ b/tests/test_validate_columns.py
@@ -17,14 +17,16 @@ class FilterConfig:
         filter: bool = True
 
     MySchema.Config = FilterConfig
+    MySchema.Config = MySchema._get_config()
+    MySchema.schema_map = MySchema._get_schema_map()
     df = pd.DataFrame(dict(a=[1], b=[1.0], c=["a"], d=[1]))
-    dfval = MySchema._validate_columns(df, MySchema._get_schema_map(), MySchema._get_config())
+    dfval = MySchema._validate_columns(df)
     assert dfval.shape == (1, 3)
     assert dfval.columns.tolist() == ["a", "b", "c"]
 
     # 2. column order is maintained
     df = pd.DataFrame(dict(b=[1.0], a=[1], d=[1], c=["a"]))
-    dfval = MySchema._validate_columns(df, MySchema._get_schema_map(), MySchema._get_config())
+    dfval = MySchema._validate_columns(df)
     assert dfval.shape == (1, 3)
     assert dfval.columns.tolist() == ["b", "a", "c"]
 
@@ -34,8 +36,9 @@ class FilterConfig:
         strict: bool = True
 
     MySchema.Config = FilterConfig
+    MySchema.Config = MySchema._get_config()
     df = pd.DataFrame(dict(b=[1.0], a=[1], d=[1], c=["a"]))
-    dfval = MySchema._validate_columns(df, MySchema._get_schema_map(), MySchema._get_config())
+    dfval = MySchema._validate_columns(df)
     assert dfval.columns.tolist() == ["b", "a", "c"]
 
     # 4. Changing filter to false, strict will now take effect, and fail:
@@ -44,9 +47,10 @@ class FilterConfig:
         strict: bool = True
 
     MySchema.Config = FilterConfig
+    MySchema.Config = MySchema._get_config()
    df = pd.DataFrame(dict(b=[1.0], a=[1], d=[1], c=["a"]))
     with pytest.raises(SchemaValidationError):
-        dfval = MySchema._validate_columns(df, MySchema._get_schema_map(), MySchema._get_config())
+        dfval = MySchema._validate_columns(df)
 
     # 5. changing to ordered == true, will now fail
     class FilterConfig:
@@ -55,9 +59,10 @@
         ordered: bool = True
 
     MySchema.Config = FilterConfig
+    MySchema.Config = MySchema._get_config()
     df = pd.DataFrame(dict(b=[1.0], a=[1], c=["a"]))
     with pytest.raises(SchemaValidationError):
-        dfval = MySchema._validate_columns(df, MySchema._get_schema_map(), MySchema._get_config())
+        dfval = MySchema._validate_columns(df)
 
     # 6. changing to filter == true, will still fail
     class FilterConfig:
@@ -66,11 +71,12 @@
         ordered: bool = True
 
     MySchema.Config = FilterConfig
+    MySchema.Config = MySchema._get_config()
     df = pd.DataFrame(dict(b=[1.0], a=[1], c=["a"]))
     with pytest.raises(SchemaValidationError):
-        dfval = MySchema._validate_columns(df, MySchema._get_schema_map(), MySchema._get_config())
+        dfval = MySchema._validate_columns(df)
 
     # 7. missing column will fail on filter
     df = pd.DataFrame(dict(b=[1.0], a=[1]))
     with pytest.raises(MissingColumnsError):
-        dfval = MySchema._validate_columns(df, MySchema._get_schema_map(), MySchema._get_config())
+        dfval = MySchema._validate_columns(df)
diff --git a/tests/test_validate_multiindex.py b/tests/test_validate_multiindex.py
index a238b0d..a07ce2d 100644
--- a/tests/test_validate_multiindex.py
+++ b/tests/test_validate_multiindex.py
@@ -27,63 +27,78 @@ class NoIndexSchema(DataFrameModel):
     b: float = Field()
 
 
+NoIndexSchema.schema_map = NoIndexSchema._get_schema_map()
+NoIndexSchema.Config = NoIndexSchema._get_config()
+
+
 class IndexSchema(DataFrameModel):
     index: Index[str]
     a: int = Field()
     b: float = Field()
 
 
+IndexSchema.schema_map = IndexSchema._get_schema_map()
+
+
 def test_no_index_schema__passing():
     df = pd.DataFrame(dict(a=[1], b=[1.0]))
 
     # 1. passes
-    NoIndexSchema._validate_multiindex(df, NoIndexSchema._get_schema_map(), NoIndexSchema._get_config())
+    NoIndexSchema._validate_multiindex(df)
 
     # 2. Strict
     NoIndexSchema.Config = IndexStrictConfig
-    NoIndexSchema._validate_multiindex(df, NoIndexSchema._get_schema_map(), NoIndexSchema._get_config())
+    NoIndexSchema.Config = NoIndexSchema._get_config()
+    NoIndexSchema._validate_multiindex(df)
 
     # 3. Ordered
     NoIndexSchema.Config = IndexOrderedConfig
-    NoIndexSchema._validate_multiindex(df, NoIndexSchema._get_schema_map(), NoIndexSchema._get_config())
+    NoIndexSchema.Config = NoIndexSchema._get_config()
+    NoIndexSchema._validate_multiindex(df)
 
     # 4. Sorted
     NoIndexSchema.Config = IndexSortedConfig
-    NoIndexSchema._validate_multiindex(df, NoIndexSchema._get_schema_map(), NoIndexSchema._get_config())
+    NoIndexSchema.Config = NoIndexSchema._get_config()
+    NoIndexSchema._validate_multiindex(df)
 
     # 5. Unique
     NoIndexSchema.Config = IndexUniqueConfig
-    NoIndexSchema._validate_multiindex(df, NoIndexSchema._get_schema_map(), NoIndexSchema._get_config())
+    NoIndexSchema.Config = NoIndexSchema._get_config()
+    NoIndexSchema._validate_multiindex(df)
 
 
 def test_no_index_schema__failing():
     df = pd.DataFrame(dict(a=[1, 2], b=[1.0, 2.0]), index=pd.Index([1, 2], name="index"))
     with pytest.raises(SchemaValidationError):
-        NoIndexSchema._validate_multiindex(df, NoIndexSchema._get_schema_map(), NoIndexSchema._get_config())
+        NoIndexSchema._validate_multiindex(df)
 
 
 def test_index_schema__passing():
     df = pd.DataFrame(dict(a=[1, 2, 3], b=[1.0, 2.0, 3.0]), index=pd.Index([1, 2, 3], name="index"))
 
     # 1. passes
-    IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+    IndexSchema._validate_multiindex(df)
 
     # 2. Strict
     IndexSchema.Config = IndexStrictConfig
-    IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+    IndexSchema.Config = IndexSchema._get_config()
+    IndexSchema._validate_multiindex(df)
 
     # 3. Ordered
     IndexSchema.Config = IndexOrderedConfig
-    IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+    IndexSchema.Config = IndexSchema._get_config()
+    IndexSchema._validate_multiindex(df)
 
     # 4. Sorted
     IndexSchema.Config = IndexSortedConfig
-    IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+    IndexSchema.Config = IndexSchema._get_config()
+    IndexSchema._validate_multiindex(df)
 
     # 5. Unique
     IndexSchema.Config = IndexUniqueConfig
-    IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+    IndexSchema.Config = IndexSchema._get_config()
+    IndexSchema._validate_multiindex(df)
 
 
 def test_index_schema__failing_sorting_unique():
@@ -94,26 +109,32 @@ class IndexSchema(DataFrameModel):
         a: int = Field()
         b: float = Field()
 
+    IndexSchema.schema_map = IndexSchema._get_schema_map()
+
     # 1. passes
-    IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+    IndexSchema._validate_multiindex(df)
 
     # 2. Strict
     IndexSchema.Config = IndexStrictConfig
-    IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+    IndexSchema.Config = IndexSchema._get_config()
+    IndexSchema._validate_multiindex(df)
 
     # 3. Ordered
     IndexSchema.Config = IndexOrderedConfig
-    IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+    IndexSchema.Config = IndexSchema._get_config()
+    IndexSchema._validate_multiindex(df)
 
     # 4. Sorted fails
     with pytest.raises(SchemaValidationError):
         IndexSchema.Config = IndexSortedConfig
-        IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+        IndexSchema.Config = IndexSchema._get_config()
+        IndexSchema._validate_multiindex(df)
 
     # 5. Unique, failing
     with pytest.raises(SchemaValidationError):
         IndexSchema.Config = IndexUniqueConfig
-        IndexSchema._validate_multiindex(df, IndexSchema._get_schema_map(), IndexSchema._get_config())
+        IndexSchema.Config = IndexSchema._get_config()
+        IndexSchema._validate_multiindex(df)
 
 
 def test_multiindex_schema__passing():
@@ -122,25 +143,30 @@ class MultiIndexSchema(DataFrameModel):
         ix1: Index[int] = Field()
         a: int = Field()
 
+    MultiIndexSchema.schema_map = MultiIndexSchema._get_schema_map()
+
     df = pd.DataFrame(
         dict(a=[1, 2, 3, 4], b=[1.0, 2.0, 3.0, 4.0]),
         index=pd.MultiIndex.from_tuples([(1, 1), (1, 2), (2, 1), (2, 2)], names=["ix0", "ix1"]),
     )
 
     # 1. passes
-    MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+    MultiIndexSchema._validate_multiindex(df)
 
     # 2. ordered
     MultiIndexSchema.Config = IndexOrderedConfig
-    MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+    MultiIndexSchema.Config = MultiIndexSchema._get_config()
+    MultiIndexSchema._validate_multiindex(df)
 
     # 3. sorted
     MultiIndexSchema.Config = IndexSortedConfig
-    MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+    MultiIndexSchema.Config = MultiIndexSchema._get_config()
+    MultiIndexSchema._validate_multiindex(df)
 
     # 4. unique
     MultiIndexSchema.Config = IndexUniqueConfig
-    MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+    MultiIndexSchema.Config = MultiIndexSchema._get_config()
+    MultiIndexSchema._validate_multiindex(df)
 
 
 def test_multiindex_schema__failing():
@@ -149,37 +175,43 @@ class MultiIndexSchema(DataFrameModel):
         ix1: Index[int] = Field()
         a: int = Field()
 
+    MultiIndexSchema.schema_map = MultiIndexSchema._get_schema_map()
+
     df = pd.DataFrame(
         dict(a=[1, 2, 3, 4], b=[1.0, 2.0, 3.0, 4.0]),
         index=pd.MultiIndex.from_tuples([(1, 1), (1, 1), (2, 2), (2, 1)], names=["ix0", "ix1"]),
     )
 
     # 1. passes
-    MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+    MultiIndexSchema._validate_multiindex(df)
 
     # 3. sorted
     with pytest.raises(SchemaValidationError):
         MultiIndexSchema.Config = IndexSortedConfig
-        MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+        MultiIndexSchema.Config = MultiIndexSchema._get_config()
+        MultiIndexSchema._validate_multiindex(df)
 
     # 4. unique
     with pytest.raises(SchemaValidationError):
         MultiIndexSchema.Config = IndexUniqueConfig
-        MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+        MultiIndexSchema.Config = MultiIndexSchema._get_config()
+        MultiIndexSchema._validate_multiindex(df)
 
     # 2. ordered, passing
     MultiIndexSchema.Config = IndexOrderedConfig
-    MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+    MultiIndexSchema.Config = MultiIndexSchema._get_config()
+    MultiIndexSchema._validate_multiindex(df)
 
     # 2. ordered, failing
     with pytest.raises(SchemaValidationError):
-        MultiIndexSchema.Config = IndexOrderedConfig
         df.index.names = ["ix1", "ix0"]
-        MultiIndexSchema._validate_multiindex(df, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+        MultiIndexSchema.Config = IndexOrderedConfig
+        MultiIndexSchema.Config = MultiIndexSchema._get_config()
+        MultiIndexSchema._validate_multiindex(df)
 
     df2 = df.reset_index()
     with pytest.raises(MissingIndexError):
-        MultiIndexSchema._validate_multiindex(df2, MultiIndexSchema._get_schema_map(), MultiIndexSchema._get_config())
+        MultiIndexSchema._validate_multiindex(df2)
 
 
 def test_multiindex_check_index_name__success():
diff --git a/tests/test_validate_series_coerce.py b/tests/test_validate_series_coerce.py
index 8be93d4..cc55a1e 100644
--- a/tests/test_validate_series_coerce.py
+++ b/tests/test_validate_series_coerce.py
@@ -1,8 +1,9 @@
+import numpy as np
 import pandas as pd
 import pytest
 
 from pandabear.exceptions import CoersionError, SchemaValidationError
-from pandabear.model import TYPE_DTYPE_MAP, DataFrameModel
+from pandabear.model import DataFrameModel
 from pandabear.model_components import Field, Index
 
@@ -19,14 +20,16 @@ class MySchema(DataFrameModel):
 
     df = pd.DataFrame(dict(a=["1"], b=[1], c=[2]))
     print(df.dtypes)
-    expected_message = "Expected `a` with dtype <class 'int'> but found object"
+    expected_message = "Expected `a` with dtype <class 'int'> but found dtype `object`"
    with pytest.raises(SchemaValidationError, match=expected_message):
         dfval = MySchema.validate(df)
 
     # 2. will coerce dtypes
     MySchema.Config = CoerceConfig
     dfval = MySchema.validate(df)
-    assert dfval.dtypes.tolist() == [int, float, TYPE_DTYPE_MAP[str]]
+    assert dfval.dtypes.a == int
+    assert dfval.dtypes.b == float
+    assert dfval.dtypes.c == np.dtype("O")
     assert dfval.a.tolist() == [1]
     assert dfval.b.tolist() == [1.0]
     assert dfval.c.tolist() == ["2"]
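For completeness, a minimal sketch of the coercion path this test exercises (illustrative schema, using the same `CoerceConfig` pattern as the test):

```python
import pandas as pd
from pandabear import DataFrameModel, Field

class CoercedSchema(DataFrameModel):
    a: int = Field()

class CoerceConfig:
    coerce: bool = True

CoercedSchema.Config = CoerceConfig

dfval = CoercedSchema.validate(pd.DataFrame(dict(a=["1"])))
print(dfval.a.tolist())  # [1] -- the object column was coerced to int instead of rejected
```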