-
Notifications
You must be signed in to change notification settings - Fork 4.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[yaml] Add yaml_provider.py Unit Tests #27804
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,7 +44,8 @@ | |
from apache_beam.portability.api import schema_pb2 | ||
from apache_beam.transforms import external | ||
from apache_beam.transforms import window | ||
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform | ||
from apache_beam.transforms.fully_qualified_named_transform import \ | ||
FullyQualifiedNamedTransform | ||
from apache_beam.typehints import schemas | ||
from apache_beam.typehints import trivial_inference | ||
from apache_beam.utils import python_callable | ||
|
@@ -119,7 +120,7 @@ def as_provider_list(name, lst): | |
|
||
class ExternalProvider(Provider): | ||
"""A Provider implemented via the cross language transform service.""" | ||
_provider_types: Dict[str, Callable[..., Provider]] = {} | ||
provider_types: Dict[str, Callable[..., Provider]] = {} | ||
|
||
def __init__(self, urns, service): | ||
self._urns = urns | ||
|
@@ -166,17 +167,18 @@ def provider_from_spec(cls, spec): | |
urns = spec['transforms'] | ||
type = spec['type'] | ||
config = SafeLineLoader.strip_metadata(spec.get('config', {})) | ||
extra_params = set(SafeLineLoader.strip_metadata(spec).keys()) - set( | ||
['transforms', 'type', 'config']) | ||
extra_params = set(SafeLineLoader.strip_metadata(spec).keys()) - { | ||
'transforms', 'type', 'config' | ||
} | ||
if extra_params: | ||
raise ValueError( | ||
f'Unexpected parameters in provider of type {type} ' | ||
f'at line {SafeLineLoader.get_line(spec)}: {extra_params}') | ||
if config.get('version', None) == 'BEAM_VERSION': | ||
config['version'] = beam_version | ||
if type in cls._provider_types: | ||
if type in cls.provider_types: | ||
try: | ||
return cls._provider_types[type](urns, **config) | ||
return cls.provider_types[type](urns, **config) | ||
except Exception as exn: | ||
raise ValueError( | ||
f'Unable to instantiate provider of type {type} ' | ||
|
@@ -189,7 +191,7 @@ def provider_from_spec(cls, spec): | |
@classmethod | ||
def register_provider_type(cls, type_name): | ||
def apply(constructor): | ||
cls._provider_types[type_name] = constructor | ||
cls.provider_types[type_name] = constructor | ||
|
||
return apply | ||
|
||
|
@@ -217,6 +219,7 @@ def maven_jar( | |
urns, | ||
lambda: subprocess_server.JavaJarServer.path_to_maven_jar( | ||
artifact_id=artifact_id, | ||
group_id=group_id, | ||
version=version, | ||
repository=repository, | ||
classifier=classifier, | ||
|
@@ -234,8 +237,10 @@ def beam_jar( | |
return ExternalJavaProvider( | ||
urns, | ||
lambda: subprocess_server.JavaJarServer.path_to_beam_jar( | ||
gradle_target=gradle_target, version=version, artifact_id=artifact_id) | ||
) | ||
gradle_target=gradle_target, | ||
appendix=appendix, | ||
version=version, | ||
artifact_id=artifact_id)) | ||
|
||
|
||
@ExternalProvider.register_provider_type('docker') | ||
|
@@ -282,7 +287,7 @@ def cache_artifacts(self): | |
@ExternalProvider.register_provider_type('python') | ||
def python(urns, packages=()): | ||
if packages: | ||
return ExternalPythonProvider(urns, packages) | ||
return ExternalProvider.provider_types['pythonPackage'](urns, packages) | ||
else: | ||
return InlineProvider({ | ||
name: | ||
|
@@ -372,7 +377,14 @@ def provided_transforms(self): | |
return self._transform_factories.keys() | ||
|
||
def create_transform(self, type, args, yaml_create_transform): | ||
return self._transform_factories[type](**args) | ||
try: | ||
return self._transform_factories[type](**args) | ||
except KeyError: | ||
raise KeyError(f'Invalid transform specified: "{type}".') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There may be other key errors that we catch here. Also, this should never be called unless type is already in self._transform_factories. If we want to raise this error here, do an explicit check for the key in the dictionary, rather than catching all key errors that may occur in the call. |
||
except TypeError: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We already attempt to produce a better error message at a higher level, but this will catch all type errors, not just invalid transform arguments. |
||
raise TypeError( | ||
f'Invalid transform arguments specified for "{type}": ' | ||
f'{list(args.keys())}.') | ||
|
||
def to_json(self): | ||
return {'type': "InlineProvider"} | ||
|
@@ -436,7 +448,7 @@ def parse_type(spec): | |
return schema_pb2.FieldType( | ||
iterable_type=schema_pb2.RowType(schema=parse_schema(spec[0]))) | ||
else: | ||
raise ValueError("Unknown schema type: {spec}") | ||
raise ValueError(f"Unknown schema type: {spec}") | ||
|
||
def parse_schema(spec): | ||
return schema_pb2.Schema( | ||
|
@@ -446,7 +458,9 @@ def parse_schema(spec): | |
], | ||
id=str(uuid.uuid4())) | ||
|
||
named_tuple = schemas.named_tuple_from_schema(parse_schema(args)) | ||
if 'schema' not in args.keys(): | ||
raise ValueError("WithSchema transform missing required 'schema' tag.") | ||
named_tuple = schemas.named_tuple_from_schema(parse_schema(args['schema'])) | ||
Comment on lines
+461
to
+463
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @robertwb Should I remove this now that we have the […]? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The schema needs to be in the config now. |
||
names = list(args.keys()) | ||
|
||
def extract_field(x, name): | ||
|
@@ -575,10 +589,10 @@ def _create_venv_from_scratch(cls, base_python, packages): | |
def _create_venv_from_clone(cls, base_python, packages): | ||
venv = cls._path(base_python, packages) | ||
if not os.path.exists(venv): | ||
clonable_venv = cls._create_venv_to_clone(base_python) | ||
clonable_python = os.path.join(clonable_venv, 'bin', 'python') | ||
cloneable_venv = cls._create_venv_to_clone(base_python) | ||
cloneable_python = os.path.join(cloneable_venv, 'bin', 'python') | ||
subprocess.run( | ||
[clonable_python, '-m', 'clonevirtualenv', clonable_venv, venv], | ||
[cloneable_python, '-m', 'clonevirtualenv', cloneable_venv, venv], | ||
check=True) | ||
venv_binary = os.path.join(venv, 'bin', 'python') | ||
subprocess.run([venv_binary, '-m', 'pip', 'install'] + packages, | ||
|
@@ -628,9 +642,12 @@ def __init__(self, transforms, mappings, underlying_provider): | |
underlying_provider) | ||
self._transforms = transforms | ||
self._underlying_provider = underlying_provider | ||
for transform in transforms.keys(): | ||
if transform not in mappings: | ||
raise ValueError(f'Missing transform {transform} in mappings.') | ||
missing = [ | ||
transform for transform in transforms.keys() | ||
if transform not in mappings | ||
] | ||
if missing: | ||
raise ValueError(f'Missing transforms {missing} in mappings.') | ||
self._mappings = mappings | ||
|
||
def available(self) -> bool: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should revert catching these here (see below).