Add TensorflowHub embeddings to MLTransform (#30289)
* Add tensorflow hub embeddings

* Add test suite to gradle

* Fix TF tests

* Update sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py

* Address comments

* Add tensorflow_text to pydocs

* Update sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py

* Revert to ignore

* Fix pydocs

* Fix pydocs

* Another attempt to fix pydocs, lint

* Remove type ignore
AnandInguva authored Feb 21, 2024
1 parent 12eee95 commit bd7ed76
Showing 5 changed files with 335 additions and 1 deletion.
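The new config plugs into MLTransform like the existing embedding configs. A minimal usage sketch, mirroring the tests in this commit (the nnlm model URL is the one the tests use; the column name and artifact location are placeholders):

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings

# Model URL taken from the test file below; '/tmp/artifacts' is a placeholder.
embedding_config = TensorflowHubTextEmbeddings(
    columns=['text'],
    hub_url='https://tfhub.dev/google/nnlm-en-dim128/2')

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'text': 'This is a test query'}])
      | MLTransform(write_artifact_location='/tmp/artifacts').with_transform(
          embedding_config))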
134 changes: 134 additions & 0 deletions sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py
@@ -0,0 +1,134 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable
from typing import List
from typing import Optional

import apache_beam as beam
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text # required to register TF ops. # pylint: disable=unused-import
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler

__all__ = ['TensorflowHubTextEmbeddings']


# TODO: https://github.com/apache/beam/issues/30288
# Replace with TFModelHandlerTensor when load_model() supports TFHUB models.
class _TensorflowHubModelHandler(TFModelHandlerTensor):
"""
Note: Intended for internal use only. No backwards compatibility guarantees.
"""
def __init__(self, preprocessing_url: Optional[str], *args, **kwargs):
self.preprocessing_url = preprocessing_url
super().__init__(*args, **kwargs)

def load_model(self):
# TF Hub models can't be loaded with tf.keras.models.load_model, so
# use hub.KerasLayer instead.
model = hub.KerasLayer(self._model_uri, **self._load_model_args)
return model

def _convert_prediction_result_to_list(
self, predictions: Iterable[PredictionResult]):
result = []
for prediction in predictions:
inference = prediction.inference.numpy().tolist()
result.append(inference)
return result

def run_inference(self, batch, model, inference_args, model_id=None):
if not inference_args:
inference_args = {}
if not self.preprocessing_url:
predictions = default_tensor_inference_fn(
model=model,
batch=batch,
inference_args=inference_args,
model_id=model_id)
return self._convert_prediction_result_to_list(predictions)

vectorized_batch = tf.stack(batch, axis=0)
preprocessor_fn = hub.KerasLayer(self.preprocessing_url)
vectorized_batch = preprocessor_fn(vectorized_batch)
predictions = model(vectorized_batch)
# https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model # pylint: disable=line-too-long
# pooled_output -> represents the text as a whole: an embedding of the
# entire text with shape [batch_size, embedding_dimension].
# sequence_output -> represents the text as a sequence of tokens: an
# embedding of each token with shape
# [batch_size, max_sequence_length, embedding_dimension].
# Per the documentation, pooled_output is the embedding of the whole
# text, so use that.
embeddings = predictions['pooled_output']
predictions = utils._convert_to_result(batch, embeddings, model_id)
return self._convert_prediction_result_to_list(predictions)


class TensorflowHubTextEmbeddings(EmbeddingsManager):
def __init__(
self,
columns: List[str],
hub_url: str,
preprocessing_url: Optional[str] = None,
**kwargs):
"""
Embedding config for tensorflow hub models. This config can be used with
MLTransform to embed text data. Models are loaded using the RunInference
PTransform with the help of a ModelHandler.
Args:
columns: The columns containing the text to be embedded.
hub_url: The url of the tensorflow hub model.
preprocessing_url: The url of the preprocessing model. This is optional.
If provided, the preprocessing model will be used to preprocess the
text before feeding it to the main model.
min_batch_size: The minimum batch size to be used for inference.
max_batch_size: The maximum batch size to be used for inference.
large_model: Whether to share the model across processes.
"""
super().__init__(columns=columns, **kwargs)
self.model_uri = hub_url
self.preprocessing_url = preprocessing_url

def get_model_handler(self) -> ModelHandler:
# override the default inference function
return _TensorflowHubModelHandler(
model_uri=self.model_uri,
preprocessing_url=self.preprocessing_url,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
large_model=self.large_model,
)

def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
"""
Returns a RunInference object that is used to run inference on the text
input using _TextEmbeddingHandler.
"""
return (
RunInference(
model_handler=_TextEmbeddingHandler(self),
inference_args=self.inference_args,
))
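The preprocessing_url branch of run_inference above is not exercised by the tests in this commit. A hedged sketch of how it might be used, pairing a BERT-style encoder with its matching preprocessor (the TF Hub URLs are illustrative examples, not taken from this change):

from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings

# Illustrative URLs: a small BERT encoder and its matching preprocessor.
# Any encoder whose output dict has a 'pooled_output' key fits the
# run_inference logic above.
bert_config = TensorflowHubTextEmbeddings(
    columns=['text'],
    hub_url='https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2',
    preprocessing_url='https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')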
176 changes: 176 additions & 0 deletions sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py
@@ -0,0 +1,176 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest
import uuid

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform

hub_url = 'https://tfhub.dev/google/nnlm-en-dim128/2'
test_query_column = 'test_query'
test_query = 'This is a test query'

# pylint: disable=ungrouped-imports
try:
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings
except ImportError:
TensorflowHubTextEmbeddings = None # type: ignore

# pylint: disable=ungrouped-imports
try:
import tensorflow_transform as tft
from apache_beam.ml.transforms.tft import ScaleTo01
except ImportError:
tft = None


@unittest.skipIf(
TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.')
class TFHubEmbeddingsTest(unittest.TestCase):
def setUp(self) -> None:
self.artifact_location = tempfile.mkdtemp()

def tearDown(self) -> None:
shutil.rmtree(self.artifact_location)

def test_tfhub_text_embeddings(self):
embedding_config = TensorflowHubTextEmbeddings(
hub_url=hub_url, columns=[test_query_column])
with beam.Pipeline() as pipeline:
transformed_pcoll = (
pipeline
| "CreateData" >> beam.Create([{
test_query_column: test_query
}])
| "MLTransform" >> MLTransform(
write_artifact_location=self.artifact_location).with_transform(
embedding_config))

def assert_element(element):
assert len(element[test_query_column]) == 128

_ = (transformed_pcoll | beam.Map(assert_element))

@unittest.skipIf(tft is None, 'Tensorflow Transform is not installed.')
def test_embeddings_with_scale_to_0_1(self):
embedding_config = TensorflowHubTextEmbeddings(
hub_url=hub_url,
columns=[test_query_column],
)
with beam.Pipeline() as pipeline:
transformed_pcoll = (
pipeline
| "CreateData" >> beam.Create([{
test_query_column: test_query
}])
| "MLTransform" >> MLTransform(
write_artifact_location=self.artifact_location).with_transform(
embedding_config).with_transform(
ScaleTo01(columns=[test_query_column])))

def assert_element(element):
assert max(element[test_query_column]) == 1

_ = (
transformed_pcoll | beam.Map(lambda x: x.as_dict())
| beam.Map(assert_element))

def pipeline_with_configurable_artifact_location(
self,
pipeline,
embedding_config=None,
read_artifact_location=None,
write_artifact_location=None):
if write_artifact_location:
return (
pipeline
| MLTransform(write_artifact_location=write_artifact_location).
with_transform(embedding_config))
elif read_artifact_location:
return (
pipeline
| MLTransform(read_artifact_location=read_artifact_location))
else:
raise NotImplementedError

def test_embeddings_with_read_artifact_location(self):
embedding_config = TensorflowHubTextEmbeddings(
    hub_url=hub_url, columns=[test_query_column])

with beam.Pipeline() as p:
data = (
p
| "CreateData" >> beam.Create([{
test_query_column: test_query
}]))
_ = self.pipeline_with_configurable_artifact_location(
pipeline=data,
embedding_config=embedding_config,
write_artifact_location=self.artifact_location)

with beam.Pipeline() as p:
data = (
p
| "CreateData" >> beam.Create([{
test_query_column: test_query
}, {
test_query_column: test_query
}]))
result_pcoll = self.pipeline_with_configurable_artifact_location(
pipeline=data, read_artifact_location=self.artifact_location)

def assert_element(element):
# 0.29836970567703247
assert round(element, 2) == 0.3

_ = (
result_pcoll
| beam.Map(lambda x: max(x[test_query_column]))
| beam.Map(assert_element))

def test_with_int_data_types(self):
embedding_config = TensorflowHubTextEmbeddings(
hub_url=hub_url, columns=[test_query_column])
with self.assertRaises(TypeError):
with beam.Pipeline() as pipeline:
_ = (
pipeline
| "CreateData" >> beam.Create([{
test_query_column: 1
}])
| "MLTransform" >> MLTransform(
write_artifact_location=self.artifact_location).with_transform(
embedding_config))


@unittest.skipIf(
TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.')
class TFHubEmbeddingsGCSArtifactLocationTest(TFHubEmbeddingsTest):
def setUp(self):
self.artifact_location = os.path.join(
'gs://temp-storage-for-perf-tests/tfhub', uuid.uuid4().hex)

def tearDown(self):
pass


if __name__ == '__main__':
unittest.main()
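As test_embeddings_with_read_artifact_location shows, a pipeline can reuse the artifacts written by an earlier MLTransform run instead of re-specifying the embedding config. A condensed sketch of that read-mode pattern (the path is a placeholder and must match the earlier write_artifact_location):

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform

# Read mode needs no transform config; it is recovered from the artifacts.
with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'test_query': 'This is a test query'}])
      | MLTransform(read_artifact_location='/tmp/artifacts'))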
2 changes: 1 addition & 1 deletion sdks/python/scripts/generate_pydoc.sh
@@ -134,7 +134,7 @@ autodoc_member_order = 'bysource'
autodoc_mock_imports = ["tensorrt", "cuda", "torch",
"onnxruntime", "onnx", "tensorflow", "tensorflow_hub",
"tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost", "datatable", "transformers",
"sentence_transformers", "redis",
"sentence_transformers", "redis", "tensorflow_text",
]
# Allow a special section for documenting DataFrame API
10 changes: 10 additions & 0 deletions sdks/python/test-suites/tox/py38/build.gradle
@@ -149,6 +149,16 @@ toxTask "testPy38embeddingsMLTransform", "py38-embeddings", "${posargs}"
test.dependsOn "testPy38embeddingsMLTransform"
preCommitPyCoverage.dependsOn "testPy38embeddingsMLTransform"

// Part of the MLTransform embeddings test suite, but it requires tensorflow-hub,
// which we need to test on multiple versions, so this suite is kept separate.
toxTask "testPy38TensorflowHubEmbeddings-014", "py38-TFHubEmbeddings-014", "${posargs}"
test.dependsOn "testPy38TensorflowHubEmbeddings-014"
preCommitPyCoverage.dependsOn "testPy38TensorflowHubEmbeddings-014"

toxTask "testPy38TensorflowHubEmbeddings-015", "py38-TFHubEmbeddings-015", "${posargs}"
test.dependsOn "testPy38TensorflowHubEmbeddings-015"
preCommitPyCoverage.dependsOn "testPy38TensorflowHubEmbeddings-015"

toxTask "whitespacelint", "whitespacelint", "${posargs}"

task archiveFilesToLint(type: Zip) {
14 changes: 14 additions & 0 deletions sdks/python/tox.ini
@@ -440,3 +440,17 @@ commands =
/bin/sh -c "pip freeze | grep -E google-cloud-aiplatform"
# Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories.
/bin/sh -c 'pytest apache_beam/ml/transforms/embeddings -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'


[testenv:py{38,39,310,311}-TFHubEmbeddings-{014,015}]
deps =
014: tensorflow-hub>=0.14.0,<0.15.0
015: tensorflow-hub>=0.15.0,<0.16.0
tensorflow-text # required to register ops for text embedding models.

extras = test,gcp
commands =
# Log tensorflow and its dependencies versions for debugging
/bin/sh -c "pip freeze | grep -E tensorflow"
# Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories.
bash {toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/ml/transforms/embeddings'
