From 16852518861514a56ad1b57113e3815cfcc8533a Mon Sep 17 00:00:00 2001
From: Anand Inguva <34158215+AnandInguva@users.noreply.github.com>
Date: Fri, 4 Nov 2022 08:39:34 -0400
Subject: [PATCH] TFX image classification example (#23456)

* TFX image classification example

* TFX image classification with serving model

* Add TF Model Wrapper

* Clean up code

* Refactoring

* Add a test for tfx_bsl RunInference

* fixup lint

* Refactoring

* Fixup lint

* Add brief summary for the example

* Apply suggestions from code review

Co-authored-by: Andy Ye <andyye333@gmail.com>

* Refactoring code and add comments

* Update help description

* reorder imports

* Reorder imports again

* Add docstring

* Refactoring

* Add pillow to tfx pipeline requirements

* Move inferencePostCommitIT to Python 3.9 suite

* Uncomment other postcommit suites

Co-authored-by: Andy Ye <andyye333@gmail.com>
---
 build.gradle.kts                              |   4 +-
 .../examples/inference/tfx_bsl/__init__.py    |  16 ++
 .../tfx_bsl/build_tensorflow_model.py         | 147 +++++++++++++
 .../inference/tfx_bsl/requirements.txt        |  19 ++
 .../tensorflow_image_classification.py        | 194 ++++++++++++++++++
 .../tfx_bsl/tfx_bsl_inference_it_test.py      |  99 +++++++++
 sdks/python/pytest.ini                        |   1 +
 sdks/python/test-suites/direct/common.gradle  |  31 ++-
 8 files changed, 509 insertions(+), 2 deletions(-)
 create mode 100644 sdks/python/apache_beam/examples/inference/tfx_bsl/__init__.py
 create mode 100644 sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py
 create mode 100644 sdks/python/apache_beam/examples/inference/tfx_bsl/requirements.txt
 create mode 100644 sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py
 create mode 100644 sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py

diff --git a/build.gradle.kts b/build.gradle.kts
index 38d2971303b..2556db7bc9a 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -362,6 +362,9 @@ tasks.register("python39PostCommit") {
   dependsOn(":sdks:python:test-suites:direct:py39:postCommitIT")
   dependsOn(":sdks:python:test-suites:direct:py39:hdfsIntegrationTest")
   dependsOn(":sdks:python:test-suites:portable:py39:postCommitPy39")
+  // TODO (https://github.com/apache/beam/issues/23966)
+  // Move this to Python 3.10 test suite once tfx-bsl has python 3.10 wheel.
+  dependsOn(":sdks:python:test-suites:direct:py39:inferencePostCommitIT")
 }
 
 tasks.register("python310PostCommit") {
@@ -369,7 +372,6 @@ tasks.register("python310PostCommit") {
   dependsOn(":sdks:python:test-suites:direct:py310:postCommitIT")
   dependsOn(":sdks:python:test-suites:direct:py310:hdfsIntegrationTest")
   dependsOn(":sdks:python:test-suites:portable:py310:postCommitPy310")
-  dependsOn(":sdks:python:test-suites:direct:py310:inferencePostCommitIT")
 }
 
 task("python37SickbayPostCommit") {
diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/__init__.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/__init__.py
new file mode 100644
index 00000000000..cce3acad34a
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py
new file mode 100644
index 00000000000..9230f84955e
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py
@@ -0,0 +1,147 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Intended only for internal testing.
+
+from typing import Dict
+from typing import Optional
+
+import tensorflow as tf
+
+
+class TFModelWrapperWithSignature(tf.keras.Model):
+  """
+  Helper class used to wrap a based tf.keras.Model object with a serving
+  signature that can passed to the tfx_bsl RunInference transform.
+
+  A TF model saved using this helper class expects inputs as
+    images serialized to tf.string using tf.io.parse_tensor
+    and then passing serialized images to the RunInference transform
+    in the tf.train.Example. More about tf.train.Example at
+    https://www.tensorflow.org/api_docs/python/tf/train/Example
+
+  Usage:
+  Step 1:
+  # Save the base TF model with modified signature .
+  signature_model = TFModelWrapperWithSignature(
+      model=model,
+      preprocess_input=preprocess_input,
+      input_dtype=input_dtype,
+      feature_description=feature_description,
+      **kwargs
+      )
+  tf.saved_model.save(signature_model, path)
+
+  Step 2:
+  # Load the saved_model in the beam pipeline to create ModelHandler.
+  saved_model_spec = model_spec_pb2.SavedModelSpec(
+      model_path=known_args.model_path)
+  inferece_spec_type = model_spec_pb2.InferenceSpecType(
+      saved_model_spec=saved_model_spec)
+  model_handler = CreateModelHandler(inferece_spec_type)
+  """
+  def __init__(
+      self,
+      model,
+      preprocess_input=None,
+      input_dtype=None,
+      feature_description=None,
+      **kwargs):
+    """
+    model: model: Base tensorflow model used for TFX-BSL RunInference transform.
+    preprocess_input: Preprocess method to be included as part of the
+      model's serving signature.
+    input_dtype: tf dtype of the inputs passed to the model.
+      For eg: tf.int32, tf.uint8.
+    feature_description: Feature spec to parse inputs from tf.train.Example
+      using tf.parse_example(). For more details, please take a look at
+      https://www.tensorflow.org/api_docs/python/tf/io/parse_example
+    If there are extra arguments(for eg: training=False) that should be
+    passed to the base tf model during inference, please pass them in kwargs.
+    """
+    super().__init__()
+    self.model = model
+    self.preprocess_input = preprocess_input
+    self.input_dtype = input_dtype
+    self.feature_description = feature_description
+    if not feature_description:
+      self.feature_description = {'image': tf.io.FixedLenFeature((), tf.string)}
+    self._kwargs = kwargs
+
+  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
+  def call(self, serialized_examples):
+    features = tf.io.parse_example(
+        serialized_examples, features=self.feature_description)
+
+    # Initialize a TensorArray to store the deserialized values.
+    # For more details, please look at
+    # https://github.com/tensorflow/tensorflow/issues/39323#issuecomment-627586602
+    num_batches = len(features['image'])
+    deserialized_vectors = tf.TensorArray(
+        self.input_dtype, size=num_batches, dynamic_size=True)
+    # Vectorized version of tf.io.parse_tensor is not available.
+    # Use for loop to vectorize the tensor. For more details, refer
+    # https://github.com/tensorflow/tensorflow/issues/43706
+    for i in range(num_batches):
+      deserialized_value = tf.io.parse_tensor(
+          features['image'][i], out_type=self.input_dtype)
+      # In Graph mode, return value must get assigned in order to
+      # update the array. More details at
+      # http://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873
+      deserialized_vectors = deserialized_vectors.write(i, deserialized_value)
+    deserialized_tensor = deserialized_vectors.stack()
+    if self.preprocess_input:
+      deserialized_tensor = self.preprocess_input(deserialized_tensor)
+    return self.model(deserialized_tensor, **self._kwargs)
+
+
+def save_tf_model_with_signature(
+    path_to_save_model,
+    model=None,
+    preprocess_input=None,
+    input_dtype=tf.float32,
+    feature_description: Optional[Dict] = None,
+    **kwargs,
+):
+  """
+  Helper function used to save the Tensorflow Model with a serving signature.
+  This is intended only for internal testing.
+
+  Args:
+   path_to_save_model: Path to save the model with modified signature.
+  model: model: Base tensorflow model used for TFX-BSL RunInference transform.
+  preprocess_input: Preprocess method to be included as part of the
+    model's serving signature.
+  input_dtype: tf dtype of the inputs passed to the model.
+    For eg: tf.int32, tf.uint8.
+  feature_description: Feature spec to parse inputs from tf.train.Example using
+    tf.parse_example(). For more details, please take a look at
+    https://www.tensorflow.org/api_docs/python/tf/io/parse_example
+
+  If there are extra arguments(for eg: training=False) that should be passed to
+  the base tf model during inference, please pass them in kwargs.
+  """
+  if not model:
+    model = tf.keras.applications.MobileNetV2(weights='imagenet')
+    preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
+  signature_model = TFModelWrapperWithSignature(
+      model=model,
+      preprocess_input=preprocess_input,
+      input_dtype=input_dtype,
+      feature_description=feature_description,
+      **kwargs)
+  tf.saved_model.save(signature_model, path_to_save_model)
diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/requirements.txt b/sdks/python/apache_beam/examples/inference/tfx_bsl/requirements.txt
new file mode 100644
index 00000000000..ea00cb6e312
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/requirements.txt
@@ -0,0 +1,19 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.s
+#
+
+tfx_bsl
+pillow
diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py
new file mode 100644
index 00000000000..09a70caa4ed
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py
@@ -0,0 +1,194 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A sample pipeline illustrating how to use Apache Beam RunInference
+with TFX_BSL CreateModelHandler API. For more details, please look at
+https://www.tensorflow.org/tfx/tfx_bsl/api_docs/python/tfx_bsl/public/beam/run_inference/CreateModelHandler.
+
+Note: A Tensorflow Model needs to be updated with a @tf.function
+      signature in order to accept bytes as inputs, and should have logic
+      to decode bytes to Tensors that is acceptable by the TensorFlow model.
+      Please take a look at TFModelWrapperWithSignature class in
+      build_tensorflow_model.py on how to modify TF Model's signature
+      and the logic to decode the image tensor.
+"""
+
+import argparse
+import io
+import logging
+import os
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Tuple
+
+import apache_beam as beam
+import tensorflow as tf
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+from PIL import Image
+from tfx_bsl.public.beam.run_inference import CreateModelHandler
+from tfx_bsl.public.beam.run_inference import prediction_log_pb2
+from tfx_bsl.public.proto import model_spec_pb2
+
+_IMG_SIZE = (224, 224)
+
+
+def filter_empty_lines(text: str) -> Iterator[str]:
+  if len(text.strip()) > 0:
+    yield text
+
+
+def read_and_process_image(
+    image_file_name: str,
+    path_to_dir: Optional[str] = None) -> Tuple[str, tf.Tensor]:
+  if path_to_dir is not None:
+    image_file_name = os.path.join(path_to_dir, image_file_name)
+  with FileSystems().open(image_file_name, 'r') as file:
+    data = Image.open(io.BytesIO(file.read())).convert('RGB')
+  # Note: Converts the image dtype from uint8 to float32
+  # https://www.tensorflow.org/api_docs/python/tf/image/resize
+  image = tf.keras.preprocessing.image.img_to_array(data)
+  image = tf.image.resize(image, _IMG_SIZE)
+  return image_file_name, image
+
+
+def convert_image_to_example_proto(tensor: tf.Tensor) -> tf.train.Example:
+  """
+  This method performs the following:
+  1. Accepts the tensor as input
+  2. Serializes the tensor into bytes and pass it through
+        tf.train.Feature
+  3. Pass the serialized tensor feature using tf.train.Example
+      Proto to the RunInference transform.
+
+  Args:
+    tensor: A TF tensor.
+  Returns:
+    example_proto: A tf.train.Example containing serialized tensor.
+  """
+  serialized_non_scalar = tf.io.serialize_tensor(tensor)
+  feature_of_bytes = tf.train.Feature(
+      bytes_list=tf.train.BytesList(value=[serialized_non_scalar.numpy()]))
+  features_for_example = {'image': feature_of_bytes}
+  example_proto = tf.train.Example(
+      features=tf.train.Features(feature=features_for_example))
+  return example_proto
+
+
+class ProcessInferenceToString(beam.DoFn):
+  def process(
+      self, element: Tuple[str,
+                           prediction_log_pb2.PredictionLog]) -> Iterable[str]:
+    """
+    Args:
+      element: Tuple of str, and PredictionLog. Inference can be parsed
+        from prediction_log
+    returns:
+      str of filename and inference.
+    """
+    filename, predict_log = element[0], element[1].predict_log
+    output_value = predict_log.response.outputs
+    output_tensor = (
+        tf.io.decode_raw(
+            output_value['output_0'].tensor_content, out_type=tf.float32))
+    max_index_output_tensor = tf.math.argmax(output_tensor, axis=0)
+    yield filename + ',' + str(tf.get_static_value(max_index_output_tensor))
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input',
+      dest='input',
+      required=True,
+      help='Path to the text file containing image names.')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path to save output predictions text file.')
+  parser.add_argument(
+      '--model_path',
+      dest='model_path',
+      required=True,
+      help="Path to the model.")
+  parser.add_argument(
+      '--images_dir',
+      default=None,
+      help='Path to the directory where images are stored.'
+      'Not required if image names in the input file have absolute path.')
+  return parser.parse_known_args(argv)
+
+
+def run(argv=None, save_main_session=True, pipeline=None) -> PipelineResult:
+  """
+  Args:
+    argv: Command line arguments defined for this example.
+    save_main_session: Used for internal testing.
+    test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  saved_model_spec = model_spec_pb2.SavedModelSpec(
+      model_path=known_args.model_path)
+  inferece_spec_type = model_spec_pb2.InferenceSpecType(
+      saved_model_spec=saved_model_spec)
+  model_handler = CreateModelHandler(inferece_spec_type)
+  # create a KeyedModelHandler to accommodate image names as keys.
+  keyed_model_handler = KeyedModelHandler(model_handler)
+
+  if not pipeline:
+    pipeline = beam.Pipeline(options=pipeline_options)
+
+  filename_value_pair = (
+      pipeline
+      | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
+      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
+      | 'ProcessImageData' >> beam.Map(
+          lambda image_name: read_and_process_image(
+              image_file_name=image_name, path_to_dir=known_args.images_dir)))
+
+  predictions = (
+      filename_value_pair
+      | 'ConvertToExampleProto' >>
+      beam.Map(lambda x: (x[0], convert_image_to_example_proto(x[1])))
+      | 'TFXRunInference' >> RunInference(keyed_model_handler)
+      | 'PostProcess' >> beam.ParDo(ProcessInferenceToString()))
+  _ = (
+      predictions
+      | "WriteOutputToGCS" >> beam.io.WriteToText(
+          known_args.output,
+          shard_name_template='',
+          append_trailing_newlines=True))
+
+  result = pipeline.run()
+  result.wait_until_finish()
+  return result
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py
new file mode 100644
index 00000000000..d72794df4f7
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+import uuid
+
+import pytest
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=ungrouped-imports
+try:
+  import tfx_bsl
+  import tensorflow as tf
+  from apache_beam.examples.inference.tfx_bsl import tensorflow_image_classification
+  from apache_beam.examples.inference.tfx_bsl.build_tensorflow_model import save_tf_model_with_signature
+except ImportError as e:
+  tfx_bsl = None
+# pylint: disable=line-too-long
+_EXPECTED_OUTPUTS = {
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005001.JPEG': '681',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005002.JPEG': '333',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005003.JPEG': '711',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005004.JPEG': '286',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005005.JPEG': '445',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005006.JPEG': '288',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005007.JPEG': '880',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005008.JPEG': '534',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005009.JPEG': '888',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005010.JPEG': '996',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005011.JPEG': '327',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005012.JPEG': '573'
+}
+
+
+def process_outputs(filepath):
+  with FileSystems().open(filepath) as f:
+    lines = f.readlines()
+  lines = [l.decode('utf-8').strip('\n') for l in lines]
+  return lines
+
+
+@unittest.skipIf(
+    tfx_bsl is None, 'Missing dependencies. '
+    'Test depends on tfx_bsl')
+class TFXRunInferenceTests(unittest.TestCase):
+  @pytest.mark.uses_tensorflow
+  @pytest.mark.it_postcommit
+  def test_tfx_run_inference_mobilenetv2(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    # Save the Tensorflow model with TF Signature
+    model = tf.keras.applications.MobileNetV2(weights='imagenet')
+    preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
+    path_to_save_model = 'gs://apache-beam-ml/models/tensorflow_models/mobilenet_v2'
+    save_tf_model_with_signature(
+        path_to_save_model, model, preprocess_input, training=False)
+    # Use the saved model to run the test
+    file_of_image_names = (
+        'gs://apache-beam-ml/testing/inputs/it_mobilenetv2_imagenet_validation_inputs.txt'
+    )
+    output_file_dir = 'gs://apache-beam-ml/testing/predictions'
+    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
+    extra_opts = {
+        'input': file_of_image_names,
+        'output': output_file,
+        'model_path': path_to_save_model,
+    }
+
+    tensorflow_image_classification.run(
+        test_pipeline.get_full_options_as_args(**extra_opts),
+        save_main_session=False)
+
+    self.assertEqual(FileSystems().exists(output_file), True)
+    predictions = process_outputs(filepath=output_file)
+
+    for prediction in predictions:
+      filename, prediction = prediction.split(',')
+      self.assertEqual(_EXPECTED_OUTPUTS[filename], prediction)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.DEBUG)
+  unittest.main()
diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini
index 0a9a274f1f1..e2ffe74ac27 100644
--- a/sdks/python/pytest.ini
+++ b/sdks/python/pytest.ini
@@ -49,6 +49,7 @@ markers =
     uses_pyarrow: tests that utilize pyarrow in some way
     uses_pytorch: tests that utilize pytorch in some way
     uses_sklearn: tests that utilize scikit-learn in some way
+    uses_tensorflow: tests that utilize tensorflow in some way
 
 # Default timeout intended for unit tests.
 # If certain tests need a different value, please see the docs on how to
diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle
index 7cd53c0a078..cd23645db15 100644
--- a/sdks/python/test-suites/direct/common.gradle
+++ b/sdks/python/test-suites/direct/common.gradle
@@ -215,6 +215,7 @@ task torchInferenceTest {
     }
 
 }
+
 // Scikit-learn RunInference IT tests
 task sklearnInferenceTest {
   dependsOn 'installGcpTest'
@@ -235,10 +236,38 @@ task sklearnInferenceTest {
     }
 }
 
+// TensorFlow Extended (TFX) RunInference IT tests
+task tfxInferenceTest {
+  dependsOn 'installGcpTest'
+  dependsOn ':sdks:python:sdist'
+  def requirementsFile = "${rootDir}/sdks/python/apache_beam/examples/inference/tfx_bsl/requirements.txt"
+  doFirst {
+      exec {
+        executable 'sh'
+        args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile"
+      }
+    }
+  doLast {
+      def testOpts = basicTestOpts
+      def argMap = [
+          "test_opts": testOpts,
+          "suite": "postCommitIT-direct-py${pythonVersionSuffix}",
+          "collect": "uses_tensorflow and it_postcommit" ,
+          "runner": "TestDirectRunner"
+      ]
+      def cmdArgs = mapToArgString(argMap)
+      exec {
+        executable 'sh'
+        args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs"
+      }
+    }
+}
+
 // Add all the RunInference framework IT tests to this gradle task that runs on Direct Runner Post commit suite.
 project.tasks.register("inferencePostCommitIT") {
   dependsOn = [
   'torchInferenceTest',
-  'sklearnInferenceTest'
+  'sklearnInferenceTest',
+  'tfxInferenceTest'
   ]
 }
-- 
GitLab