Skip to content

Commit

Permalink
Add multiprocessing context to create_labeled_examples. Add max number of processes parameter.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 609015514
  • Loading branch information
jzxu authored and copybara-github committed Feb 21, 2024
1 parent 40943cf commit eee1b70
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 16 deletions.
10 changes: 9 additions & 1 deletion src/create_labeled_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
merged with the original Tensorflow Examples in order to create a labeled
training and test set.
"""
import multiprocessing
import random

from absl import app
Expand Down Expand Up @@ -61,6 +62,11 @@
True,
'If true, starts multiple processes to run task.',
)
# Upper bound on the number of worker processes. The default is whatever
# multiprocessing.cpu_count() returns when this module is imported; main()
# forwards the flag value to create_labeled_examples as max_processes.
flags.DEFINE_integer(
'max_processes',
multiprocessing.cpu_count(),
'If using multiprocessing, the maximum number of processes to use.',
)


def main(unused_argv):
Expand Down Expand Up @@ -89,7 +95,9 @@ def main(unused_argv):
FLAGS.train_output_path,
FLAGS.test_output_path,
FLAGS.connecting_distance_meters,
FLAGS.use_multiprocessing)
FLAGS.use_multiprocessing,
None,
FLAGS.max_processes)


if __name__ == '__main__':
Expand Down
7 changes: 7 additions & 0 deletions src/create_labeling_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"""
# pylint: enable=line-too-long

import multiprocessing
import sys

from absl import app
Expand Down Expand Up @@ -59,6 +60,11 @@
True,
'If true, starts multiple processes to run task.',
)
# Upper bound on the number of worker processes. The default is whatever
# multiprocessing.cpu_count() returns when this module is imported; main()
# forwards the flag value to create_labeling_images as max_processes.
flags.DEFINE_integer(
'max_processes',
multiprocessing.cpu_count(),
'If using multiprocessing, the maximum number of processes to use.',
)
flags.DEFINE_float(
'buffered_sampling_radius',
70.0,
Expand All @@ -85,6 +91,7 @@ def main(unused_argv):
FLAGS.output_dir,
FLAGS.use_multiprocessing,
None,
FLAGS.max_processes,
FLAGS.buffered_sampling_radius,
)

Expand Down
49 changes: 34 additions & 15 deletions src/skai/labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def create_labeling_images(
output_dir: str,
use_multiprocessing: bool,
multiprocessing_context: Any,
max_processes: int,
buffered_sampling_radius: float,
) -> Tuple[int, Optional[str]]:
"""Creates PNGs used for labeling from TFRecords.
Expand All @@ -233,6 +234,7 @@ def create_labeling_images(
images.
multiprocessing_context: Context to spawn processes with when using
multiprocessing.
max_processes: Maximum number of processes.
buffered_sampling_radius: The minimum distance between two examples for the
two examples to be in the labeling task.
Expand Down Expand Up @@ -305,7 +307,8 @@ def create_labeling_images(
def accumulate(images: list[tuple[int, str, str, str]]) -> None:
all_images.extend(images)

num_workers = min(multiprocessing.cpu_count(), len(example_files), 10)
num_workers = min(
multiprocessing.cpu_count(), len(example_files), max_processes)
if multiprocessing_context:
pool = multiprocessing_context.Pool(num_workers)
else:
Expand Down Expand Up @@ -360,7 +363,7 @@ def _tfrecord_iterator(path: str) -> tf.train.Example:
Yields:
Examples from the TFRecord file.
"""
ds = tf.data.TFRecordDataset([path])
ds = tf.data.TFRecordDataset([path]).prefetch(tf.data.AUTOTUNE)
if tf.executing_eagerly():
for record in ds:
example = tf.train.Example()
Expand Down Expand Up @@ -628,9 +631,7 @@ def _merge_single_example_file_and_labels(
List of TF examples merged with labels for a single example_file.
"""
labeled_examples = []
for record in tf.data.TFRecordDataset([example_file]):
example = Example()
example.ParseFromString(record.numpy())
for example in _tfrecord_iterator(example_file):
if 'example_id' in example.features.feature:
example_id = (
example.features.feature['example_id'].bytes_list.value[0].decode()
Expand Down Expand Up @@ -683,6 +684,8 @@ def _merge_examples_and_labels(
test_output_path: str,
connecting_distance_meters: float,
use_multiprocessing: bool,
multiprocessing_context: Any,
max_processes: int,
) -> None:
"""Merges examples with labels into train and test TFRecords.
Expand All @@ -696,6 +699,9 @@ def _merge_examples_and_labels(
connecting_distance_meters: Maximum distance for two points to be connected.
use_multiprocessing: If true, create multiple processes to create labeled
examples.
multiprocessing_context: Context to spawn processes with when using
multiprocessing.
max_processes: Maximum number of processes.
"""
example_files = tf.io.gfile.glob(examples_pattern)

Expand All @@ -708,15 +714,21 @@ def _merge_examples_and_labels(
)

if use_multiprocessing:
num_workers = min(multiprocessing.cpu_count(), len(example_files))
with multiprocessing.Pool(num_workers) as pool_executor:
logging.info('Using multiprocessing with %d processes.', num_workers)
results = pool_executor.map(
functools.partial(
_merge_single_example_file_and_labels, labels=labels
),
example_files,
)
num_workers = min(
multiprocessing.cpu_count(), len(example_files), max_processes
)
if multiprocessing_context:
pool = multiprocessing_context.Pool(num_workers)
else:
pool = multiprocessing.Pool(num_workers)

logging.info('Using multiprocessing with %d processes.', num_workers)
results = pool.map(
functools.partial(
_merge_single_example_file_and_labels, labels=labels
),
example_files,
)
else:
logging.info('Not using multiprocessing.')
results = [
Expand Down Expand Up @@ -775,7 +787,9 @@ def create_labeled_examples(
train_output_path: str,
test_output_path: str,
connecting_distance_meters: float,
use_multiprocessing: bool) -> None:
use_multiprocessing: bool,
multiprocessing_context: Any,
max_processes: int) -> None:
"""Creates a labeled dataset by merging cloud labels and unlabeled examples.
Args:
Expand All @@ -789,6 +803,9 @@ def create_labeled_examples(
connecting_distance_meters: Maximum distance for two points to be connected.
use_multiprocessing: If true, create multiple processes to create labeled
examples.
multiprocessing_context: Context to spawn processes with when using
multiprocessing.
max_processes: Maximum number of processes.
"""
string_to_numeric_map = {}
for string_to_numeric_label in string_to_numeric_labels:
Expand Down Expand Up @@ -829,4 +846,6 @@ def create_labeled_examples(
test_output_path,
connecting_distance_meters,
use_multiprocessing,
multiprocessing_context,
max_processes
)

0 comments on commit eee1b70

Please sign in to comment.