diff --git a/src/create_labeled_dataset.py b/src/create_labeled_dataset.py
index f9e47f42..317fb593 100644
--- a/src/create_labeled_dataset.py
+++ b/src/create_labeled_dataset.py
@@ -18,6 +18,7 @@ merged with the original Tensorflow Examples in order to create a labeled
 training and test set.
 """
 
+import multiprocessing
 import random
 
 from absl import app
@@ -61,6 +62,11 @@
     True,
     'If true, starts multiple processes to run task.',
 )
+flags.DEFINE_integer(
+    'max_processes',
+    multiprocessing.cpu_count(),
+    'If using multiprocessing, the maximum number of processes to use.',
+)
 
 
 def main(unused_argv):
@@ -89,7 +95,9 @@ def main(unused_argv):
       FLAGS.train_output_path,
       FLAGS.test_output_path,
       FLAGS.connecting_distance_meters,
-      FLAGS.use_multiprocessing)
+      FLAGS.use_multiprocessing,
+      None,
+      FLAGS.max_processes)
 
 
 if __name__ == '__main__':
diff --git a/src/create_labeling_examples.py b/src/create_labeling_examples.py
index 46c2d175..2a30ba18 100644
--- a/src/create_labeling_examples.py
+++ b/src/create_labeling_examples.py
@@ -30,6 +30,7 @@
 """
 # pylint: enable=line-too-long
 
+import multiprocessing
 import sys
 
 from absl import app
@@ -59,6 +60,11 @@
     True,
     'If true, starts multiple processes to run task.',
 )
+flags.DEFINE_integer(
+    'max_processes',
+    multiprocessing.cpu_count(),
+    'If using multiprocessing, the maximum number of processes to use.',
+)
 flags.DEFINE_float(
     'buffered_sampling_radius',
     70.0,
@@ -85,6 +91,7 @@ def main(unused_argv):
       FLAGS.output_dir,
       FLAGS.use_multiprocessing,
       None,
+      FLAGS.max_processes,
       FLAGS.buffered_sampling_radius,
   )
 
diff --git a/src/skai/labeling.py b/src/skai/labeling.py
index 9ac1c8bb..c3f27619 100644
--- a/src/skai/labeling.py
+++ b/src/skai/labeling.py
@@ -213,6 +213,7 @@ def create_labeling_images(
     output_dir: str,
     use_multiprocessing: bool,
     multiprocessing_context: Any,
+    max_processes: int,
     buffered_sampling_radius: float,
 ) -> Tuple[int, Optional[str]]:
   """Creates PNGs used for labeling from TFRecords.
@@ -233,6 +234,7 @@
       images.
     multiprocessing_context: Context to spawn processes with when using
       multiprocessing.
+    max_processes: Maximum number of processes.
     buffered_sampling_radius: The minimum distance between two examples for
       the two examples to be in the labeling task.
 
@@ -305,7 +307,8 @@ def accumulate(images: list[tuple[int, str, str, str]]) -> None:
     all_images.extend(images)
 
-  num_workers = min(multiprocessing.cpu_count(), len(example_files), 10)
+  num_workers = min(
+      multiprocessing.cpu_count(), len(example_files), max_processes)
   if multiprocessing_context:
     pool = multiprocessing_context.Pool(num_workers)
   else:
     pool = multiprocessing.Pool(num_workers)
@@ -360,7 +363,7 @@ def _tfrecord_iterator(path: str) -> tf.train.Example:
   Yields:
     Examples from the TFRecord file.
   """
-  ds = tf.data.TFRecordDataset([path])
+  ds = tf.data.TFRecordDataset([path]).prefetch(tf.data.AUTOTUNE)
   if tf.executing_eagerly():
     for record in ds:
       example = tf.train.Example()
@@ -628,9 +631,7 @@ def _merge_single_example_file_and_labels(
     List of TF examples merged with labels for a single example_file.
   """
   labeled_examples = []
-  for record in tf.data.TFRecordDataset([example_file]):
-    example = Example()
-    example.ParseFromString(record.numpy())
+  for example in _tfrecord_iterator(example_file):
     if 'example_id' in example.features.feature:
       example_id = (
           example.features.feature['example_id'].bytes_list.value[0].decode()
@@ -683,6 +684,8 @@ def _merge_examples_and_labels(
     test_output_path: str,
     connecting_distance_meters: float,
     use_multiprocessing: bool,
+    multiprocessing_context: Any,
+    max_processes: int,
 ) -> None:
   """Merges examples with labels into train and test TFRecords.
 
@@ -696,6 +699,9 @@ connecting_distance_meters: Maximum distance for two points to be
       connected.
     use_multiprocessing: If true, create multiple processes to create labeled
       examples.
+    multiprocessing_context: Context to spawn processes with when using
+      multiprocessing.
+    max_processes: Maximum number of processes.
   """
   example_files = tf.io.gfile.glob(examples_pattern)
 
@@ -708,15 +714,21 @@
   )
 
   if use_multiprocessing:
-    num_workers = min(multiprocessing.cpu_count(), len(example_files))
-    with multiprocessing.Pool(num_workers) as pool_executor:
-      logging.info('Using multiprocessing with %d processes.', num_workers)
-      results = pool_executor.map(
-          functools.partial(
-              _merge_single_example_file_and_labels, labels=labels
-          ),
-          example_files,
-      )
+    num_workers = min(
+        multiprocessing.cpu_count(), len(example_files), max_processes
+    )
+    if multiprocessing_context:
+      pool = multiprocessing_context.Pool(num_workers)
+    else:
+      pool = multiprocessing.Pool(num_workers)
+
+    logging.info('Using multiprocessing with %d processes.', num_workers)
+    results = pool.map(
+        functools.partial(
+            _merge_single_example_file_and_labels, labels=labels
+        ),
+        example_files,
+    )
   else:
     logging.info('Not using multiprocessing.')
     results = [
@@ -775,7 +787,9 @@ def create_labeled_examples(
     train_output_path: str,
     test_output_path: str,
     connecting_distance_meters: float,
-    use_multiprocessing: bool) -> None:
+    use_multiprocessing: bool,
+    multiprocessing_context: Any,
+    max_processes: int) -> None:
   """Creates a labeled dataset by merging cloud labels and unlabeled examples.
 
   Args:
@@ -789,6 +803,9 @@ connecting_distance_meters: Maximum distance for two points to be
       connected.
     use_multiprocessing: If true, create multiple processes to create labeled
       examples.
+    multiprocessing_context: Context to spawn processes with when using
+      multiprocessing.
+    max_processes: Maximum number of processes.
   """
   string_to_numeric_map = {}
   for string_to_numeric_label in string_to_numeric_labels:
@@ -829,4 +846,6 @@
       test_output_path,
       connecting_distance_meters,
       use_multiprocessing,
+      multiprocessing_context,
+      max_processes
   )