Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[yaml] Add yaml_provider.py Unit Tests #27804

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
from apache_beam.portability.api import schema_pb2
from apache_beam.transforms import external
from apache_beam.transforms import window
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.transforms.fully_qualified_named_transform import \
FullyQualifiedNamedTransform
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.utils import python_callable
Expand Down Expand Up @@ -119,7 +120,7 @@ def as_provider_list(name, lst):

class ExternalProvider(Provider):
"""A Provider implemented via the cross language transform service."""
_provider_types: Dict[str, Callable[..., Provider]] = {}
provider_types: Dict[str, Callable[..., Provider]] = {}

def __init__(self, urns, service):
self._urns = urns
Expand Down Expand Up @@ -166,17 +167,18 @@ def provider_from_spec(cls, spec):
urns = spec['transforms']
type = spec['type']
config = SafeLineLoader.strip_metadata(spec.get('config', {}))
extra_params = set(SafeLineLoader.strip_metadata(spec).keys()) - set(
['transforms', 'type', 'config'])
extra_params = set(SafeLineLoader.strip_metadata(spec).keys()) - {
'transforms', 'type', 'config'
}
if extra_params:
raise ValueError(
f'Unexpected parameters in provider of type {type} '
f'at line {SafeLineLoader.get_line(spec)}: {extra_params}')
if config.get('version', None) == 'BEAM_VERSION':
config['version'] = beam_version
if type in cls._provider_types:
if type in cls.provider_types:
try:
return cls._provider_types[type](urns, **config)
return cls.provider_types[type](urns, **config)
except Exception as exn:
raise ValueError(
f'Unable to instantiate provider of type {type} '
Expand All @@ -189,7 +191,7 @@ def provider_from_spec(cls, spec):
@classmethod
def register_provider_type(cls, type_name):
def apply(constructor):
cls._provider_types[type_name] = constructor
cls.provider_types[type_name] = constructor

return apply

Expand Down Expand Up @@ -217,6 +219,7 @@ def maven_jar(
urns,
lambda: subprocess_server.JavaJarServer.path_to_maven_jar(
artifact_id=artifact_id,
group_id=group_id,
version=version,
repository=repository,
classifier=classifier,
Expand All @@ -234,8 +237,10 @@ def beam_jar(
return ExternalJavaProvider(
urns,
lambda: subprocess_server.JavaJarServer.path_to_beam_jar(
gradle_target=gradle_target, version=version, artifact_id=artifact_id)
)
gradle_target=gradle_target,
appendix=appendix,
version=version,
artifact_id=artifact_id))


@ExternalProvider.register_provider_type('docker')
Expand Down Expand Up @@ -282,7 +287,7 @@ def cache_artifacts(self):
@ExternalProvider.register_provider_type('python')
def python(urns, packages=()):
if packages:
return ExternalPythonProvider(urns, packages)
return ExternalProvider.provider_types['pythonPackage'](urns, packages)
else:
return InlineProvider({
name:
Expand Down Expand Up @@ -372,7 +377,14 @@ def provided_transforms(self):
return self._transform_factories.keys()

def create_transform(self, type, args, yaml_create_transform):
return self._transform_factories[type](**args)
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should revert catching these here (see below).

return self._transform_factories[type](**args)
except KeyError:
raise KeyError(f'Invalid transform specified: "{type}".')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There may be other key errors that we catch here. Also, this should never be called unless type is already in self._transform_factories. If we want to raise this error here, do an explicit check for the key in the dictionary, rather than catch all key errors that may occur in the call.

except TypeError:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already attempt a better error message at a higher level. But this will catch all type errors, not just invalid transform arguments.

raise TypeError(
f'Invalid transform arguments specified for "{type}": '
f'{list(args.keys())}.')

def to_json(self):
return {'type': "InlineProvider"}
Expand Down Expand Up @@ -436,7 +448,7 @@ def parse_type(spec):
return schema_pb2.FieldType(
iterable_type=schema_pb2.RowType(schema=parse_schema(spec[0])))
else:
raise ValueError("Unknown schema type: {spec}")
raise ValueError(f"Unknown schema type: {spec}")

def parse_schema(spec):
return schema_pb2.Schema(
Expand All @@ -446,7 +458,9 @@ def parse_schema(spec):
],
id=str(uuid.uuid4()))

named_tuple = schemas.named_tuple_from_schema(parse_schema(args))
if 'schema' not in args.keys():
raise ValueError("WithSchema transform missing required 'schema' tag.")
named_tuple = schemas.named_tuple_from_schema(parse_schema(args['schema']))
Comment on lines +461 to +463
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertwb Should I remove this now that we have the config tag?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The schema needs to be in the config now.

names = list(args.keys())

def extract_field(x, name):
Expand Down Expand Up @@ -575,10 +589,10 @@ def _create_venv_from_scratch(cls, base_python, packages):
def _create_venv_from_clone(cls, base_python, packages):
venv = cls._path(base_python, packages)
if not os.path.exists(venv):
clonable_venv = cls._create_venv_to_clone(base_python)
clonable_python = os.path.join(clonable_venv, 'bin', 'python')
cloneable_venv = cls._create_venv_to_clone(base_python)
cloneable_python = os.path.join(cloneable_venv, 'bin', 'python')
subprocess.run(
[clonable_python, '-m', 'clonevirtualenv', clonable_venv, venv],
[cloneable_python, '-m', 'clonevirtualenv', cloneable_venv, venv],
check=True)
venv_binary = os.path.join(venv, 'bin', 'python')
subprocess.run([venv_binary, '-m', 'pip', 'install'] + packages,
Expand Down Expand Up @@ -628,9 +642,12 @@ def __init__(self, transforms, mappings, underlying_provider):
underlying_provider)
self._transforms = transforms
self._underlying_provider = underlying_provider
for transform in transforms.keys():
if transform not in mappings:
raise ValueError(f'Missing transform {transform} in mappings.')
missing = [
transform for transform in transforms.keys()
if transform not in mappings
]
if missing:
raise ValueError(f'Missing transforms {missing} in mappings.')
self._mappings = mappings

def available(self) -> bool:
Expand Down
Loading