diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index 6b88600a97be..d1839e9de0f3 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -504,7 +504,8 @@ def __init__( given their final names. By default, the temporary directory will be within the temp_location of your pipeline. sink (callable, FileSink): The sink to use to write into a file. It should - implement the methods of a ``FileSink``. If none is provided, a + implement the methods of a ``FileSink``. Pass a FileSink subclass or an + instance of one to this parameter. If none is provided, a ``TextSink`` is used. shards (int): The number of shards per destination and trigger firing. max_writers_per_bundle (int): The number of writers that can be open @@ -525,8 +526,11 @@ def __init__( @staticmethod def _get_sink_fn(input_sink): # type: (...) -> Callable[[Any], FileSink] - if isinstance(input_sink, FileSink): - return lambda x: input_sink + if isinstance(input_sink, type) and issubclass(input_sink, FileSink): + return lambda x: input_sink() + elif isinstance(input_sink, FileSink): + kls = input_sink.__class__ + return lambda x: kls() elif callable(input_sink): return input_sink else: @@ -791,7 +795,6 @@ def process( def _get_or_create_writer_and_sink(self, destination, window): """Returns a tuple of writer, sink.""" writer_key = (destination, window) - if writer_key in self._writers_and_sinks: return self._writers_and_sinks.get(writer_key) elif len(self._writers_and_sinks) >= self.max_num_writers_per_bundle: @@ -807,7 +810,6 @@ def _get_or_create_writer_and_sink(self, destination, window): create_metadata_fn=sink.create_metadata) sink.open(writer) - self._writers_and_sinks[writer_key] = (writer, sink) self._file_names[writer_key] = full_file_name return self._writers_and_sinks[writer_key] diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py index f21fb8d17962..ab4dba2366c8 100644 
--- a/sdks/python/apache_beam/io/fileio_test.py +++ b/sdks/python/apache_beam/io/fileio_test.py @@ -459,6 +459,42 @@ def test_write_to_single_file_batch(self): assert_that(result, equal_to([row for row in self.SIMPLE_COLLECTION])) + def test_write_to_dynamic_destination(self): + + sink_params = [ + fileio.TextSink, # pass the FileSink subclass itself + fileio.TextSink() # pass a FileSink instance + ] + + for sink in sink_params: + dir = self._new_tempdir() + + with TestPipeline() as p: + _ = ( + p + | "Create" >> beam.Create(range(100)) + | beam.Map(lambda x: str(x)) + | fileio.WriteToFiles( + path=dir, + destination=lambda n: "odd" if int(n) % 2 else "even", + sink=sink, + file_naming=fileio.destination_prefix_naming("test"))) + + with TestPipeline() as p: + result = ( + p + | fileio.MatchFiles(FileSystems.join(dir, '*')) + | fileio.ReadMatches() + | beam.Map( + lambda f: ( + os.path.basename(f.metadata.path).split('-')[0], + sorted(map(int, f.read_utf8().strip().split('\n')))))) + + assert_that( + result, + equal_to([('odd', list(range(1, 100, 2))), + ('even', list(range(0, 100, 2)))])) + def test_write_to_different_file_types_some_spilling(self): dir = self._new_tempdir()