diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index 8b571d9a685e..f1ff0e4dfe57 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -172,8 +172,8 @@ def __init__( num_oov_buckets: Any lookup of an out-of-vocabulary token will return a bucket ID based on its hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the `default_value`. - vocab_filename: The file name for the vocabulary file. If not provided, - the default name would be `compute_and_apply_vocab' + vocab_filename: The file name for the vocabulary file. The vocab file + will be suffixed with the column name. NOTE in order to make your pipelines resilient to implementation details please set `vocab_filename` when you are using the vocab_filename on a downstream component. @@ -183,8 +183,7 @@ def __init__( self._top_k = top_k self._frequency_threshold = frequency_threshold self._num_oov_buckets = num_oov_buckets - self._vocab_filename = vocab_filename if vocab_filename else ( - 'compute_and_apply_vocab') + self._vocab_filename = vocab_filename self._name = name self.split_string_by_delimiter = split_string_by_delimiter @@ -196,6 +195,9 @@ def apply_transform( data = self._split_string_with_delimiter( data, self.split_string_by_delimiter) + vocab_filename = self._vocab_filename + if vocab_filename: + vocab_filename = vocab_filename + f'_{output_column_name}' return { output_column_name: tft.compute_and_apply_vocabulary( x=data, @@ -203,7 +205,7 @@ def apply_transform( top_k=self._top_k, frequency_threshold=self._frequency_threshold, num_oov_buckets=self._num_oov_buckets, - vocab_filename=self._vocab_filename, + vocab_filename=vocab_filename, name=self._name) } @@ -535,7 +537,7 @@ def __init__( ngram_range: Tuple[int, int] = (1, 1), ngrams_separator: Optional[str] = None, compute_word_count: bool = False, - key_vocab_filename: str = 'key_vocab_mapping', + key_vocab_filename: Optional[str] = None, name: Optional[str] = None, ): """ @@ -558,7 +560,9 @@ def __init__( compute_word_count: A boolean that specifies whether to compute the unique word count over the entire dataset. Defaults to False. key_vocab_filename: The file name for the key vocabulary file when - compute_word_count is True. + compute_word_count is True. If empty, a file name + will be chosen based on the current scope. If provided, the vocab + file will be suffixed with the column name. name: A name for the operation (optional). Note that original order of the input may not be preserved. @@ -585,10 +589,14 @@ def apply_transform(self, data: tf.SparseTensor, output_col_name: str): data, self.split_string_by_delimiter) output = tft.bag_of_words( data, self.ngram_range, self.ngrams_separator, self.name) - # word counts are written to the key_vocab_filename - self.compute_word_count_fn(data, self.key_vocab_filename) + # word counts are written to the file only if compute_word_count is True + key_vocab_filename = self.key_vocab_filename + if key_vocab_filename: + key_vocab_filename = key_vocab_filename + f'_{output_col_name}' + self.compute_word_count_fn(data, key_vocab_filename) return {output_col_name: output} -def count_unqiue_words(data: tf.SparseTensor, output_vocab_name: str) -> None: +def count_unqiue_words( + data: tf.SparseTensor, output_vocab_name: Optional[str]) -> None: tft.count_per_key(data, key_vocabulary_filename=output_vocab_name) diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py index 9f15db45bd28..558b4ede2ec6 100644 --- a/sdks/python/apache_beam/ml/transforms/tft_test.py +++ b/sdks/python/apache_beam/ml/transforms/tft_test.py @@ -357,6 +357,85 @@ def test_string_split_with_multiple_delimiters(self): ] assert_that(result, equal_to(expected_result, equals_fn=np.array_equal)) + def test_multiple_columns_with_default_vocab_name(self): + data = [{ + 'x': ['I', 'like', 'pie'], 'y': ['Apach', 'Beam', 'is', 'awesome'] + }, + { + 'x': ['yum', 'yum', 'pie'], + 'y': ['Beam', 'is', 'a', 'unified', 'model'] + }] + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + tft.ComputeAndApplyVocabulary(columns=['x', 'y']))) + + expected_data_x = [np.array([3, 2, 1]), np.array([0, 0, 1])] + + expected_data_y = [np.array([6, 1, 0, 4]), np.array([1, 0, 5, 2, 3])] + + actual_data_x = (result | beam.Map(lambda x: x.x)) + actual_data_y = (result | beam.Map(lambda x: x.y)) + + assert_that( + actual_data_x, + equal_to(expected_data_x, equals_fn=np.array_equal), + label='x') + assert_that( + actual_data_y, + equal_to(expected_data_y, equals_fn=np.array_equal), + label='y') + files = os.listdir(self.artifact_location) + files.remove(base._ATTRIBUTE_FILE_NAME) + assert len(files) == 1 + tft_vocab_assets = os.listdir( + os.path.join( + self.artifact_location, files[0], 'transform_fn', 'assets')) + assert len(tft_vocab_assets) == 2 + + def test_multiple_columns_with_vocab_name(self): + data = [{ + 'x': ['I', 'like', 'pie'], 'y': ['Apach', 'Beam', 'is', 'awesome'] + }, + { + 'x': ['yum', 'yum', 'pie'], + 'y': ['Beam', 'is', 'a', 'unified', 'model'] + }] + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + tft.ComputeAndApplyVocabulary( + columns=['x', 'y'], vocab_filename='my_vocab'))) + + expected_data_x = [np.array([3, 2, 1]), np.array([0, 0, 1])] + + expected_data_y = [np.array([6, 1, 0, 4]), np.array([1, 0, 5, 2, 3])] + + actual_data_x = (result | beam.Map(lambda x: x.x)) + actual_data_y = (result | beam.Map(lambda x: x.y)) + + assert_that( + actual_data_x, + equal_to(expected_data_x, equals_fn=np.array_equal), + label='x') + assert_that( + actual_data_y, + equal_to(expected_data_y, equals_fn=np.array_equal), + label='y') + files = os.listdir(self.artifact_location) + files.remove(base._ATTRIBUTE_FILE_NAME) + assert len(files) == 1 + tft_vocab_assets = os.listdir( + os.path.join( + self.artifact_location, files[0], 'transform_fn', 'assets')) + assert len(tft_vocab_assets) == 2 + class TFIDIFTest(unittest.TestCase): def setUp(self) -> None: @@ -717,7 +796,7 @@ def validate_count_per_key(key_vocab_filename): self.artifact_location, files[0], 'transform_fn/assets', - key_vocab_filename) + key_vocab_filename + '_x') with open(key_vocab_location, 'r') as f: key_vocab_list = [line.strip() for line in f] return key_vocab_list