Commit: Add criteo benchmark for MLTransform

AnandInguva committed Jan 30, 2024
1 parent b9fd39c commit 881a6cf

Showing 5 changed files with 209 additions and 0 deletions.
@@ -57,6 +57,7 @@ def _publish_metrics(pipeline, metric_value, metrics_table, metric_name):
)])


@unittest.skip('Remove this line to run this test.')
@pytest.mark.uses_tft
class LargeMovieReviewDatasetProcessTest(unittest.TestCase):
  def test_process_large_movie_review_dataset(self):
@@ -55,6 +55,7 @@ def _publish_metrics(pipeline, metric_value, metrics_table, metric_name):
)])


@unittest.skip('Remove this line to run this test.')
@pytest.mark.uses_tft
class CloudMLTFTBenchmarkTest(unittest.TestCase):
  def test_cloudml_benchmark_criteo_small(self):
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,121 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: skip-file

import logging
import argparse
import numpy as np

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary
from apache_beam.ml.transforms.tft import Bucketize
from apache_beam.options.pipeline_options import PipelineOptions

NUM_NUMERIC_FEATURES = 13
# Number of buckets for integer columns.
_NUM_BUCKETS = 10
csv_delimiter = '\t'

NUMERIC_FEATURE_KEYS = ["int_feature_%d" % x for x in range(1, 14)]
CATEGORICAL_FEATURE_KEYS = ["categorical_feature_%d" % x for x in range(14, 40)]
LABEL_KEY = "clicked"


class FillMissing(beam.DoFn):
  """Fills missing elements with zero string value."""
  def process(self, element):
    elem_list = element.split(csv_delimiter)
    out_list = []
    for val in elem_list:
      new_val = "0" if not val else val
      out_list.append(new_val)
    yield csv_delimiter.join(out_list)


class NegsToZeroLog(beam.DoFn):
  """For int features, sets negative values to zero and takes log(x+1)."""
  def process(self, element):
    elem_list = element.split(csv_delimiter)
    out_list = []
    for i, val in enumerate(elem_list):
      if 0 < i <= NUM_NUMERIC_FEATURES:
        val = "0" if int(val) < 0 else val
        val = str(np.log(int(val) + 1))
      out_list.append(val)
    yield csv_delimiter.join(out_list)


def convert_str_to_int(element):
  """Casts numeric feature values from str to float."""
  for key, value in element.items():
    if key in NUMERIC_FEATURE_KEYS:
      element[key] = float(value)
  return element


def parse_known_args(argv):
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      default='/usr/local/google/home/anandinguva/Downloads/train.txt')
  parser.add_argument(
      "--artifact_location", help="Artifact location to store artifacts.")
  return parser.parse_known_args(argv)


def run(argv=None):
  known_args, pipeline_args = parse_known_args(argv)
  options = PipelineOptions(flags=pipeline_args)
  data_path = known_args.input
  ordered_columns = [
      LABEL_KEY
  ] + NUMERIC_FEATURE_KEYS + CATEGORICAL_FEATURE_KEYS
  with beam.Pipeline(options=options) as pipeline:
    processed_lines = (
        pipeline
        # Read in TSV data.
        | beam.io.ReadFromText(data_path, coder=beam.coders.StrUtf8Coder())
        # Fill in missing elements with the defaults (zeros).
        | "FillMissing" >> beam.ParDo(FillMissing())
        # For numerical features, set negatives to zero. Then take log(x+1).
        | "NegsToZeroLog" >> beam.ParDo(NegsToZeroLog())
        | beam.Map(lambda x: str(x).split(csv_delimiter))
        # Map each row to a dict keyed by column name. The keyed records are
        # much larger than the raw TSV (roughly 50 GB for this benchmark).
        | beam.Map(
            lambda x: {ordered_columns[i]: x[i]
                       for i in range(len(x))})
        | beam.Map(convert_str_to_int))

    # processed_lines | beam.Map(logging.info)

    artifact_location = known_args.artifact_location
    if not artifact_location:
      import tempfile
      artifact_location = tempfile.mkdtemp(prefix='criteo-mltransform-')
    ml_transform = MLTransform(write_artifact_location=artifact_location)
    ml_transform.with_transform(
        ComputeAndApplyVocabulary(columns=CATEGORICAL_FEATURE_KEYS))
    ml_transform.with_transform(
        Bucketize(columns=NUMERIC_FEATURE_KEYS, num_buckets=_NUM_BUCKETS))

    transformed_lines = (processed_lines | 'MLTransform' >> ml_transform)

    # _ = transformed_lines | beam.Map(logging.info)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()
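A minimal local invocation sketch for the new criteo module (not part of this commit): both paths below are placeholders, the runner flag is optional, and the module path matches the import used in the benchmark test that follows.

from apache_beam.examples.ml_transform import criteo

# Run the preprocessing + MLTransform pipeline on a small local TSV sample.
# Both paths are hypothetical; --runner is optional (DirectRunner is the
# default) and any extra flags are passed through as pipeline options.
criteo.run([
    '--input=/tmp/criteo_sample.txt',
    '--artifact_location=/tmp/criteo_artifacts',
    '--runner=DirectRunner',
])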
@@ -0,0 +1,70 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: skip-file

import logging
import os
# import time
import unittest
import uuid

import pytest

try:
  import apache_beam.testing.benchmarks.cloudml.cloudml_benchmark_constants_lib as constants
  # from apache_beam.examples.ml_transform import vocab_tfidf_processing
  from apache_beam.testing.load_tests.load_test_metrics_utils import InfluxDBMetricsPublisherOptions
  from apache_beam.testing.load_tests.load_test_metrics_utils import MetricsReader
  from apache_beam.testing.test_pipeline import TestPipeline
  from apache_beam.examples.ml_transform import criteo
except ImportError:
  raise unittest.SkipTest('tensorflow_transform is not installed.')

_INPUT_GCS_BUCKET_ROOT = 'gs://apache-beam-ml/datasets/cloudml/criteo'
_OUTPUT_GCS_BUCKET_ROOT = 'gs://temp-storage-for-end-to-end-tests/tft/'
_DISK_SIZE = 150


@pytest.mark.uses_tft
class CriteoTest(unittest.TestCase):
  def test_process_criteo_10GB_dataset(self):
    test_pipeline = TestPipeline(is_integration_test=True)
    extra_opts = {}

    # Beam pipeline options.
    extra_opts['input'] = os.path.join(
        _INPUT_GCS_BUCKET_ROOT, constants.INPUT_CRITEO_10GB)
    extra_opts['artifact_location'] = os.path.join(
        _OUTPUT_GCS_BUCKET_ROOT, uuid.uuid4().hex)

    extra_opts['frequency_threshold'] = 0

    # Dataflow pipeline options.
    extra_opts['disk_size_gb'] = _DISK_SIZE
    extra_opts['machine_type'] = 'e2-highmem-2'
    extra_opts['job_name'] = (
        'mltransform-criteo-dataset-{}-10'.format(uuid.uuid4().hex))
    # start_time = time.time()
    criteo.run(
        test_pipeline.get_full_options_as_args(
            **extra_opts, save_main_session=False))
    # end_time = time.time()


if __name__ == '__main__':
  unittest.main()
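To illustrate the per-row preprocessing that criteo.py applies before MLTransform, here is a small sketch; the three-column row is a hypothetical toy input rather than real Criteo data (a real row has 1 label, 13 integer and 26 categorical columns), and the module path is the one imported in the test above.

from apache_beam.examples.ml_transform import criteo

# A toy tab-delimited row: label, one missing integer feature, one negative
# integer feature.
row = '1\t\t-2'

filled = next(criteo.FillMissing().process(row))
# -> '1\t0\t-2'   (missing values replaced with "0")

logged = next(criteo.NegsToZeroLog().process(filled))
# -> '1\t0.0\t0.0'   (negatives clamped to 0, then log(x + 1) applied)

print(filled)
print(logged)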
