Skip to content

Commit

Permalink
[YAML] Transform schema introspection. (apache#28478)
Browse files Browse the repository at this point in the history
This allows one to enumerate the set of provided transforms together with their config schemas.

Future work would be to pull out documentation as well. It would also be valuable, if possible, to trace through the forwarding of *args and **kwargs.
  • Loading branch information
robertwb authored Sep 21, 2023
1 parent af877ff commit b5b69b1
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from typing import ByteString
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
Expand Down Expand Up @@ -308,6 +309,11 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType:
return schema_pb2.FieldType(
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str):
element_type = self.typing_to_runner_api(_get_args(type_)[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))

try:
logical_type = LogicalType.from_typing(type_)
except ValueError:
Expand Down
92 changes: 91 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import collections
import hashlib
import inspect
import json
import os
import subprocess
Expand All @@ -45,8 +46,12 @@
from apache_beam.transforms import external
from apache_beam.transforms import window
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.schemas import named_tuple_to_schema
from apache_beam.typehints.schemas import typing_from_runner_api
from apache_beam.typehints.schemas import typing_to_runner_api
from apache_beam.utils import python_callable
from apache_beam.utils import subprocess_server
from apache_beam.version import __version__ as beam_version
Expand All @@ -65,6 +70,9 @@ def provided_transforms(self) -> Iterable[str]:
"""Returns a list of transform type names this provider can handle."""
raise NotImplementedError(type(self))

def config_schema(self, type):
  """Returns the configuration schema for the given transform type.

  This implementation does not look at ``type`` and always returns None,
  i.e. it reports that no schema is available.
  """
  return None

def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool:
"""Returns whether this transform requires inputs.
Expand Down Expand Up @@ -140,6 +148,8 @@ def provided_transforms(self):
return self._urns.keys()

def schema_transforms(self):
if callable(self._service):
self._service = self._service()
if self._schema_transforms is None:
try:
self._schema_transforms = {
Expand All @@ -152,6 +162,11 @@ def schema_transforms(self):
self._schema_transforms = {}
return self._schema_transforms

def config_schema(self, type):
  """Returns the config schema of the schema transform behind ``type``.

  ``type`` is translated to a URN through ``self._urns``. When that URN is
  among ``self.schema_transforms()``, the transform's
  ``configuration_schema`` is converted with ``named_tuple_to_schema`` and
  returned; otherwise the result is None.
  """
  urn = self._urns[type]
  known_transforms = self.schema_transforms()
  if urn not in known_transforms:
    return None
  return named_tuple_to_schema(known_transforms[urn].configuration_schema)

def requires_inputs(self, typ, args):
if self._urns[type] in self.schema_transforms():
return bool(self.schema_transforms()[self._urns[type]].inputs)
Expand Down Expand Up @@ -392,6 +407,31 @@ def cache_artifacts(self):
def provided_transforms(self):
  """Returns the keys of ``self._transform_factories``."""
  factories = self._transform_factories
  return factories.keys()

def config_schema(self, typ):
  """Derives a config schema from the signature of the factory for ``typ``.

  Every parameter of the factory becomes one schema field, named after the
  parameter. The field type is the parameter's annotation converted with
  ``typing_to_runner_api``; a parameter without an annotation is treated as
  ``Any``.

  Args:
    typ: key into ``self._transform_factories``.

  Returns:
    A ``schema_pb2.Schema`` with one field per factory parameter, in
    signature order.
  """
  factory = self._transform_factories[typ]
  if isinstance(factory, type) and issubclass(factory, beam.PTransform):
    # Inspect __init__ instead of the class itself, see
    # https://bugs.python.org/issue40897
    params = dict(inspect.signature(factory.__init__).parameters)
    del params['self']
  else:
    params = inspect.signature(factory).parameters

  def type_of(p):
    # Parameter.empty is a sentinel marker: test identity, not equality,
    # so an annotation's own __eq__ is never consulted.
    return Any if p.annotation is p.empty else p.annotation

  # The loop variables deliberately avoid the name `type`, which this method
  # also relies on as the builtin in the isinstance check above.
  return schema_pb2.Schema(
      fields=[
          schema_pb2.Field(
              name=name, type=typing_to_runner_api(type_of(param)))
          for name, param in params.items()
      ])

def create_transform(self, type, args, yaml_create_transform):
  """Builds a transform by calling the factory registered under ``type``.

  The entries of ``args`` are passed to the factory as keyword arguments.
  ``yaml_create_transform`` is accepted but not used by this implementation.
  """
  factory = self._transform_factories[type]
  return factory(**args)

Expand Down Expand Up @@ -490,7 +530,10 @@ def extract_field(x, name):

# Or should this be posargs, args?
# pylint: disable=dangerous-default-value
def fully_qualified_named_transform(
    constructor: str,
    args: Iterable[Any] = (),
    kwargs: Mapping[str, Any] = {}):
  """Applies the transform identified by a fully qualified constructor name.

  Args:
    constructor: name handed to ``FullyQualifiedNamedTransform``; it is also
      used as the label on the left-hand side of ``>>``.
    args: positional arguments forwarded to ``FullyQualifiedNamedTransform``.
    kwargs: keyword arguments forwarded to ``FullyQualifiedNamedTransform``.
      The default dict is shared between calls; this function only forwards
      it and never modifies it.

  Returns:
    The result of ``constructor >> FullyQualifiedNamedTransform(...)``,
    evaluated while ``FullyQualifiedNamedTransform.with_filter('*')`` is
    active.
  """
  with FullyQualifiedNamedTransform.with_filter('*'):
    return constructor >> FullyQualifiedNamedTransform(
        constructor, args, kwargs)
Expand Down Expand Up @@ -662,6 +705,19 @@ def available(self) -> bool:
def provided_transforms(self) -> Iterable[str]:
  """Returns the keys of ``self._transforms``."""
  transforms = self._transforms
  return transforms.keys()

def config_schema(self, type):
  """Returns the underlying provider's schema with renamed fields.

  The schema is requested from ``self._underlying_provider`` for the name
  ``self._transforms[type]``. If that yields None, None is returned.
  Otherwise one field is emitted per ``src -> dest`` entry of
  ``self._mappings[type]``: it is named ``src`` and carries the type of the
  underlying field named ``dest``.
  """
  delegate_schema = self._underlying_provider.config_schema(
      self._transforms[type])
  if delegate_schema is None:
    return None

  # Index the underlying field types by field name for the lookups below.
  type_by_name = {}
  for field in delegate_schema.fields:
    type_by_name[field.name] = field.type

  renamed_fields = []
  for src, dest in self._mappings[type].items():
    renamed_fields.append(schema_pb2.Field(name=src, type=type_by_name[dest]))
  return schema_pb2.Schema(fields=renamed_fields)

def requires_inputs(self, typ, args):
  """Delegates to the underlying provider's ``requires_inputs``.

  Both ``typ`` and ``args`` are forwarded unchanged.
  """
  # NOTE(review): config_schema in this file looks the name up in
  # self._transforms before delegating, whereas typ is passed through
  # untranslated here - confirm this is intended.
  underlying = self._underlying_provider
  return underlying.requires_inputs(typ, args)

Expand Down Expand Up @@ -723,8 +779,42 @@ def standard_providers():
with open(os.path.join(os.path.dirname(__file__),
'standard_providers.yaml')) as fin:
standard_providers = yaml.load(fin, Loader=SafeLoader)

return merge_providers(
create_builtin_provider(),
create_mapping_providers(),
io_providers(),
parse_providers(standard_providers))


def list_providers():
  """Prints each standard transform type with its providers' config schemas.

  The entries of ``standard_providers()`` are visited in sorted order. For
  every transform type the name is printed on its own line, followed by one
  tab-prefixed line per provider showing the provider's class name and a
  compact rendering of ``provider.config_schema(name)``.
  """
  def describe_schema(schema):
    # Providers may report that they have no schema at all.
    if schema is None:
      return '[no schema]'
    rendered = [
        f'{field.name}={describe_type(field.type)}' for field in schema.fields
    ]
    return 'Row(' + ', '.join(rendered) + ')'

  def describe_type(field_type):
    # Nested rows are rendered recursively as schemas.
    if field_type.WhichOneof('type_info') == 'row_type':
      return describe_schema(field_type.row_type.schema)
    py_type = typing_from_runner_api(field_type)
    inner = native_type_compatibility.extract_optional_type(py_type)
    marker = ''
    if inner:
      # Show an optional type as its base type followed by '?'.
      py_type, marker = inner, '?'
    text = str(py_type)
    if text.startswith('<class '):
      # Plain classes read better by bare name than by their repr.
      text = py_type.__name__
    return text + marker

  for name, providers in sorted(standard_providers().items()):
    print(name)
    for provider in providers:
      print(
          '\t',
          type(provider).__name__,
          describe_schema(provider.config_schema(name)))


# Script entry point: when this module is run directly, print the listing
# produced by list_providers().
if __name__ == '__main__':
list_providers()

0 comments on commit b5b69b1

Please sign in to comment.