Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb committed Sep 15, 2023
1 parent 893b0f9 commit 7ff7f85
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import itertools
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Mapping
Expand Down Expand Up @@ -187,9 +188,9 @@ def expand(pcoll, error_handling=None, **kwargs):
class _Explode(beam.PTransform):
def __init__(
self,
fields: Union[str, Iterable[str]],
fields: Union[str, Collection[str]],
cross_product: Optional[bool] = None,
error_handling: Mapping[str, Any] = None):
error_handling: Optional[Mapping[str, Any]] = None):
if isinstance(fields, str):
fields = [fields]
if cross_product is None:
Expand Down Expand Up @@ -254,11 +255,12 @@ def with_exception_handling(self, **kwargs):

@beam.ptransform.ptransform_fn
@maybe_with_exception_handling_transform_fn
def _PyJsFilter(pcoll, keep: Union[str, Dict[str, str]], language: str = None):
def _PyJsFilter(
pcoll, keep: Union[str, Dict[str, str]], language: Optional[str] = None):

input_schema = dict(named_fields_from_element_type(pcoll.element_type))
if isinstance(keep, str) and keep in input_schema:
keep_fn = lambda row: getatr(row, keep)
keep_fn = lambda row: getattr(row, keep)
else:
keep_fn = _as_callable(list(input_schema.keys()), keep, "keep", language)
return pcoll | beam.Filter(keep_fn)
Expand All @@ -272,9 +274,9 @@ def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'):
try:
input_schema = dict(named_fields_from_element_type(pcoll.element_type))
except ValueError as exn:
if 'drop':
if drop:
raise ValueError("Can only drop fields on a schema'd input.") from exn
if 'append':
if append:
raise ValueError("Can only append fields on a schema'd input.") from exn
elif any(is_expr(x) for x in fields.values()):
raise ValueError("Can only use expressions on a schema'd input.") from exn
Expand Down

0 comments on commit 7ff7f85

Please sign in to comment.