-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TensorflowHub embeddings to MLTransform (#30289)
* Add tensorflow hub embeddings * Add test suite to gradle * Fix TF tests * Update sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py * Address comments * Add tensorflow_text to pydocs * Update sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py * Revert to ignore * Fix pydocs * Fix pydocs * Another attempt to fix pydocs, lint * Remove type ignore
- Loading branch information
1 parent
12eee95
commit bd7ed76
Showing
5 changed files
with
335 additions
and
1 deletion.
There are no files selected for viewing
134 changes: 134 additions & 0 deletions
134
sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Iterable | ||
from typing import List | ||
from typing import Optional | ||
|
||
import apache_beam as beam | ||
import tensorflow as tf | ||
import tensorflow_hub as hub | ||
import tensorflow_text as text # required to register TF ops. # pylint: disable=unused-import | ||
from apache_beam.ml.inference import utils | ||
from apache_beam.ml.inference.base import ModelHandler | ||
from apache_beam.ml.inference.base import PredictionResult | ||
from apache_beam.ml.inference.base import RunInference | ||
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor | ||
from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn | ||
from apache_beam.ml.transforms.base import EmbeddingsManager | ||
from apache_beam.ml.transforms.base import _TextEmbeddingHandler | ||
|
||
__all__ = ['TensorflowHubTextEmbeddings'] | ||
|
||
|
||
# TODO: https://github.com/apache/beam/issues/30288 | ||
# Replace with TFModelHandlerTensor when load_model() supports TFHUB models. | ||
class _TensorflowHubModelHandler(TFModelHandlerTensor): | ||
""" | ||
Note: Intended for internal use only. No backwards compatibility guarantees. | ||
""" | ||
def __init__(self, preprocessing_url: Optional[str], *args, **kwargs): | ||
self.preprocessing_url = preprocessing_url | ||
super().__init__(*args, **kwargs) | ||
|
||
def load_model(self): | ||
# unable to load the models with tf.keras.models.load_model so | ||
# using hub.KerasLayer instead | ||
model = hub.KerasLayer(self._model_uri, **self._load_model_args) | ||
return model | ||
|
||
def _convert_prediction_result_to_list( | ||
self, predictions: Iterable[PredictionResult]): | ||
result = [] | ||
for prediction in predictions: | ||
inference = prediction.inference.numpy().tolist() | ||
result.append(inference) | ||
return result | ||
|
||
def run_inference(self, batch, model, inference_args, model_id=None): | ||
if not inference_args: | ||
inference_args = {} | ||
if not self.preprocessing_url: | ||
predictions = default_tensor_inference_fn( | ||
model=model, | ||
batch=batch, | ||
inference_args=inference_args, | ||
model_id=model_id) | ||
return self._convert_prediction_result_to_list(predictions) | ||
|
||
vectorized_batch = tf.stack(batch, axis=0) | ||
preprocessor_fn = hub.KerasLayer(self.preprocessing_url) | ||
vectorized_batch = preprocessor_fn(vectorized_batch) | ||
predictions = model(vectorized_batch) | ||
# https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model # pylint: disable=line-too-long | ||
# pooled_output -> represents the text as a whole. This is an embeddings | ||
# of the whole text. The shape is [batch_size, embedding_dimension] | ||
# sequence_output -> represents the text as a sequence of tokens. This is | ||
# an embeddings of each token in the text. The shape is | ||
# [batch_size, max_sequence_length, embedding_dimension] | ||
# pooled output is the embeedings as per the documentation. so let's use | ||
# that. | ||
embeddings = predictions['pooled_output'] | ||
predictions = utils._convert_to_result(batch, embeddings, model_id) | ||
return self._convert_prediction_result_to_list(predictions) | ||
|
||
|
||
class TensorflowHubTextEmbeddings(EmbeddingsManager): | ||
def __init__( | ||
self, | ||
columns: List[str], | ||
hub_url: str, | ||
preprocessing_url: Optional[str] = None, | ||
**kwargs): | ||
""" | ||
Embedding config for tensorflow hub models. This config can be used with | ||
MLTransform to embed text data. Models are loaded using the RunInference | ||
PTransform with the help of a ModelHandler. | ||
Args: | ||
columns: The columns containing the text to be embedded. | ||
hub_url: The url of the tensorflow hub model. | ||
preprocessing_url: The url of the preprocessing model. This is optional. | ||
If provided, the preprocessing model will be used to preprocess the | ||
text before feeding it to the main model. | ||
min_batch_size: The minimum batch size to be used for inference. | ||
max_batch_size: The maximum batch size to be used for inference. | ||
large_model: Whether to share the model across processes. | ||
""" | ||
super().__init__(columns=columns, **kwargs) | ||
self.model_uri = hub_url | ||
self.preprocessing_url = preprocessing_url | ||
|
||
def get_model_handler(self) -> ModelHandler: | ||
# override the default inference function | ||
return _TensorflowHubModelHandler( | ||
model_uri=self.model_uri, | ||
preprocessing_url=self.preprocessing_url, | ||
min_batch_size=self.min_batch_size, | ||
max_batch_size=self.max_batch_size, | ||
large_model=self.large_model, | ||
) | ||
|
||
def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: | ||
""" | ||
Returns a RunInference object that is used to run inference on the text | ||
input using _TextEmbeddingHandler. | ||
""" | ||
return ( | ||
RunInference( | ||
model_handler=_TextEmbeddingHandler(self), | ||
inference_args=self.inference_args, | ||
)) |
176 changes: 176 additions & 0 deletions
176
sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import shutil | ||
import tempfile | ||
import unittest | ||
import uuid | ||
|
||
import apache_beam as beam | ||
from apache_beam.ml.transforms.base import MLTransform | ||
|
||
hub_url = 'https://tfhub.dev/google/nnlm-en-dim128/2' | ||
test_query_column = 'test_query' | ||
test_query = 'This is a test query' | ||
|
||
# pylint: disable=ungrouped-imports | ||
try: | ||
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings | ||
except ImportError: | ||
TensorflowHubTextEmbeddings = None # type: ignore | ||
|
||
# pylint: disable=ungrouped-imports | ||
try: | ||
import tensorflow_transform as tft | ||
from apache_beam.ml.transforms.tft import ScaleTo01 | ||
except ImportError: | ||
tft = None | ||
|
||
|
||
@unittest.skipIf( | ||
TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') | ||
class TFHubEmbeddingsTest(unittest.TestCase): | ||
def setUp(self) -> None: | ||
self.artifact_location = tempfile.mkdtemp() | ||
|
||
def tearDown(self) -> None: | ||
shutil.rmtree(self.artifact_location) | ||
|
||
def test_tfhub_text_embeddings(self): | ||
embedding_config = TensorflowHubTextEmbeddings( | ||
hub_url=hub_url, columns=[test_query_column]) | ||
with beam.Pipeline() as pipeline: | ||
transformed_pcoll = ( | ||
pipeline | ||
| "CreateData" >> beam.Create([{ | ||
test_query_column: test_query | ||
}]) | ||
| "MLTransform" >> MLTransform( | ||
write_artifact_location=self.artifact_location).with_transform( | ||
embedding_config)) | ||
|
||
def assert_element(element): | ||
assert len(element[test_query_column]) == 128 | ||
|
||
_ = (transformed_pcoll | beam.Map(assert_element)) | ||
|
||
@unittest.skipIf(tft is None, 'Tensorflow Transform is not installed.') | ||
def test_embeddings_with_scale_to_0_1(self): | ||
embedding_config = TensorflowHubTextEmbeddings( | ||
hub_url=hub_url, | ||
columns=[test_query_column], | ||
) | ||
with beam.Pipeline() as pipeline: | ||
transformed_pcoll = ( | ||
pipeline | ||
| "CreateData" >> beam.Create([{ | ||
test_query_column: test_query | ||
}]) | ||
| "MLTransform" >> MLTransform( | ||
write_artifact_location=self.artifact_location).with_transform( | ||
embedding_config).with_transform( | ||
ScaleTo01(columns=[test_query_column]))) | ||
|
||
def assert_element(element): | ||
assert max(element[test_query_column]) == 1 | ||
|
||
_ = ( | ||
transformed_pcoll | beam.Map(lambda x: x.as_dict()) | ||
| beam.Map(assert_element)) | ||
|
||
def pipeline_with_configurable_artifact_location( | ||
self, | ||
pipeline, | ||
embedding_config=None, | ||
read_artifact_location=None, | ||
write_artifact_location=None): | ||
if write_artifact_location: | ||
return ( | ||
pipeline | ||
| MLTransform(write_artifact_location=write_artifact_location). | ||
with_transform(embedding_config)) | ||
elif read_artifact_location: | ||
return ( | ||
pipeline | ||
| MLTransform(read_artifact_location=read_artifact_location)) | ||
else: | ||
raise NotImplementedError | ||
|
||
def test_embeddings_with_read_artifact_location(self): | ||
with beam.Pipeline() as p: | ||
embedding_config = TensorflowHubTextEmbeddings( | ||
hub_url=hub_url, columns=[test_query_column]) | ||
|
||
with beam.Pipeline() as p: | ||
data = ( | ||
p | ||
| "CreateData" >> beam.Create([{ | ||
test_query_column: test_query | ||
}])) | ||
_ = self.pipeline_with_configurable_artifact_location( | ||
pipeline=data, | ||
embedding_config=embedding_config, | ||
write_artifact_location=self.artifact_location) | ||
|
||
with beam.Pipeline() as p: | ||
data = ( | ||
p | ||
| "CreateData" >> beam.Create([{ | ||
test_query_column: test_query | ||
}, { | ||
test_query_column: test_query | ||
}])) | ||
result_pcoll = self.pipeline_with_configurable_artifact_location( | ||
pipeline=data, read_artifact_location=self.artifact_location) | ||
|
||
def assert_element(element): | ||
# 0.29836970567703247 | ||
assert round(element, 2) == 0.3 | ||
|
||
_ = ( | ||
result_pcoll | ||
| beam.Map(lambda x: max(x[test_query_column])) | ||
| beam.Map(assert_element)) | ||
|
||
def test_with_int_data_types(self): | ||
embedding_config = TensorflowHubTextEmbeddings( | ||
hub_url=hub_url, columns=[test_query_column]) | ||
with self.assertRaises(TypeError): | ||
with beam.Pipeline() as pipeline: | ||
_ = ( | ||
pipeline | ||
| "CreateData" >> beam.Create([{ | ||
test_query_column: 1 | ||
}]) | ||
| "MLTransform" >> MLTransform( | ||
write_artifact_location=self.artifact_location).with_transform( | ||
embedding_config)) | ||
|
||
|
||
@unittest.skipIf( | ||
TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') | ||
class TFHubEmbeddingsGCSArtifactLocationTest(TFHubEmbeddingsTest): | ||
def setUp(self): | ||
self.artifact_location = os.path.join( | ||
'gs://temp-storage-for-perf-tests/tfhub', uuid.uuid4().hex) | ||
|
||
def tearDown(self): | ||
pass | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters