From ea3faa479c526c07e3828f2646fed9e4623bac79 Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Wed, 13 Dec 2023 10:41:21 -0500 Subject: [PATCH] Refactor pipeline --- .../testing/benchmarks/mltransform/criteo.py | 28 ++++++++----------- .../benchmarks/mltransform/criteo_test.py | 3 +- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/sdks/python/apache_beam/testing/benchmarks/mltransform/criteo.py b/sdks/python/apache_beam/testing/benchmarks/mltransform/criteo.py index 50920e9372c9..7f7c4772672d 100644 --- a/sdks/python/apache_beam/testing/benchmarks/mltransform/criteo.py +++ b/sdks/python/apache_beam/testing/benchmarks/mltransform/criteo.py @@ -77,7 +77,7 @@ def parse_known_args(argv): return parser.parse_known_args(argv) -def run(argv=None, ): +def run(argv=None): known_args, pipeline_args = parse_known_args(argv) options = PipelineOptions(flags=pipeline_args) data_path = known_args.input @@ -99,21 +99,17 @@ def run(argv=None, ): for i in range(len(x))}) | beam.Map(convert_str_to_int)) - # processed_lines | beam.Map(logging.info) - - artifact_location = known_args.artifact_location - if not artifact_location: - import tempfile - artifact_location = tempfile.mkdtemp(prefix='criteo-mltransform-') - ml_transform = MLTransform(write_artifact_location=artifact_location) - ml_transform.with_transform( - ComputeAndApplyVocabulary(columns=CATEGORICAL_FEATURE_KEYS)) - ml_transform.with_transform( - Bucketize(columns=NUMERIC_FEATURE_KEYS, num_buckets=_NUM_BUCKETS)) - - transformed_lines = (processed_lines | 'MLTransform' >> ml_transform) - - # _ = transformed_lines | beam.Map(logging.info) + transformed_lines = ( + processed_lines + | "MLTransform" >> + MLTransform(write_artifact_location=known_args.artifact_location). + with_transform( + ComputeAndApplyVocabulary( + columns=CATEGORICAL_FEATURE_KEYS, frequency_threshold=5) + ).with_transform( + Bucketize(columns=NUMERIC_FEATURE_KEYS, num_buckets=_NUM_BUCKETS))) + + transformed_lines | beam.Map(logging.info) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/testing/benchmarks/mltransform/criteo_test.py b/sdks/python/apache_beam/testing/benchmarks/mltransform/criteo_test.py index cf2a1ef27acc..32fca0e51e9f 100644 --- a/sdks/python/apache_beam/testing/benchmarks/mltransform/criteo_test.py +++ b/sdks/python/apache_beam/testing/benchmarks/mltransform/criteo_test.py @@ -47,9 +47,10 @@ def test_process_criteo_10GB_dataset(self): # beam pipeline options extra_opts['input'] = os.path.join( _INPUT_GCS_BUCKET_ROOT, constants.INPUT_CRITEO_10GB) + logging.info("#################") extra_opts['artifact_location'] = os.path.join( _OUTPUT_GCS_BUCKET_ROOT, uuid.uuid4().hex) - + logging.info(extra_opts['artifact_location']) extra_opts['frequency_threshold'] = 0 extra_opts['job_name'] = (