From b873be8bdb94757c34e27e59e2640bc540a49791 Mon Sep 17 00:00:00 2001
From: Danny McCormick
Date: Wed, 21 Aug 2024 17:10:13 +0100
Subject: [PATCH] Add timeout to runinference (#32237)

* [WIP] Add timeout to runinference
* More implementation including plumbing through params
* Add tests + fix problems exposed by tests
* Add garbage collection
* typing + feature scoping
* Python version
* Update CHANGES
* Feedback
* lint
---
 CHANGES.md                                    |  21 +-
 sdks/python/apache_beam/ml/inference/base.py  | 222 ++++++++++++++++--
 .../apache_beam/ml/inference/base_test.py     | 216 +++++++++++++++++
 3 files changed, 426 insertions(+), 33 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 7c813d9e5e3f..357e25ef46eb 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -57,12 +57,10 @@

 ## Highlights

-* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)).
-* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)).
+* Added support for setting a configurable timeout when loading a model and performing inference in the [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) transform using [with_exception_handling](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference.with_exception_handling) ([#32137](https://github.com/apache/beam/issues/32137))

 ## I/Os

-* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
 * Improvements to the performance of BigqueryIO when using withPropagateSuccessfulStorageApiWrites(true) method (Java) ([#31840](https://github.com/apache/beam/pull/31840)).
 * [Managed Iceberg] Added support for writing to partitioned tables ([#32102](https://github.com/apache/beam/pull/32102))
 * Update ClickHouseIO to use the latest version of the ClickHouse JDBC driver ([#32228](https://github.com/apache/beam/issues/32228)).
@@ -70,21 +68,13 @@

 ## New Features / Improvements

-* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
 * BigQuery endpoint can be overridden via PipelineOptions, this enables BigQuery emulators (Java) ([#28149](https://github.com/apache/beam/issues/28149)).
 * Go SDK Minimum Go Version updated to 1.21 ([#32092](https://github.com/apache/beam/pull/32092)).
 * [BigQueryIO] Added support for withFormatRecordOnFailureFunction() for STORAGE_WRITE_API and STORAGE_API_AT_LEAST_ONCE methods (Java) ([#31354](https://github.com/apache/beam/issues/31354)).
 * Updated Go protobuf package to new version (Go) ([#21515](https://github.com/apache/beam/issues/21515)).
+* Added support for setting a configurable timeout when loading a model and performing inference in the [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) transform using [with_exception_handling](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference.with_exception_handling) ([#32137](https://github.com/apache/beam/issues/32137))
 * Adds OrderedListState support for Java SDK via FnApi.

-## Breaking Changes
-
-* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).
-
-## Deprecations
-
-* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)).
-
 ## Bugfixes

 * Fixed incorrect service account impersonation flow for Python pipelines using BigQuery IOs ([#32030](https://github.com/apache/beam/issues/32030)).
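The release note above describes the new behavior at a high level; for orientation before the code changes below, here is a minimal, hedged sketch of how a pipeline author could opt into the timeout. It assumes only the API added by this patch (`with_exception_handling(timeout=...)` and the `failed_inferences` dead-letter output); `SleepyModel` and `SleepyModelHandler` are illustrative stand-ins modeled on the `FakeSlowModelHandler` used in the tests, not part of this change.

```python
import time
from typing import Iterable, Optional, Sequence

import apache_beam as beam
from apache_beam.ml.inference import base


class SleepyModel:
  """Toy model whose per-element latency equals the element value (seconds)."""
  def predict(self, example: int) -> int:
    time.sleep(example)
    return example


class SleepyModelHandler(base.ModelHandler[int, int, SleepyModel]):
  """Minimal handler mirroring FakeSlowModelHandler from the tests below."""
  def load_model(self) -> SleepyModel:
    return SleepyModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: SleepyModel,
      inference_args: Optional[dict] = None) -> Iterable[int]:
    for example in batch:
      yield model.predict(example)


with beam.Pipeline() as p:
  examples = p | beam.Create([1, 2, 30])  # the 30-second element should time out
  good, bad = (
      examples
      | base.RunInference(SleepyModelHandler()).with_exception_handling(
          timeout=10))  # seconds; applies to model loading and to each batch
  # Each failed record is a (batch, exception) pair, as asserted in the tests
  # added by this patch.
  _ = bad.failed_inferences | beam.Map(print)
```

In practice the timeout should leave generous headroom (the docstring added below recommends at least twice the combined model-load and per-batch inference time), since a timed-out batch can also force the model to be reloaded.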
@@ -92,13 +82,6 @@
 * (Python) Upgraded google-cloud-storage to version 2.18.2 to fix a data corruption issue ([#32135](https://github.com/apache/beam/pull/32135)).
 * (Go) Fix corruption on State API writes. ([#32245](https://github.com/apache/beam/issues/32245)).

-## Security Fixes
-* Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
-
-## Known Issues
-
-* ([#X](https://github.com/apache/beam/issues/X)).
-
 # [2.58.1] - 2024-08-15

 ## New Features / Improvements
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 29a568def07b..b075f22f15be 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -38,6 +38,8 @@
 from collections import defaultdict
 from copy import deepcopy
 from dataclasses import dataclass
+from datetime import datetime
+from datetime import timedelta
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -48,6 +50,7 @@
 from typing import NamedTuple
 from typing import Optional
 from typing import Sequence
+from typing import Set
 from typing import Tuple
 from typing import TypeVar
 from typing import Union
@@ -328,6 +331,13 @@ def override_metrics(self, metrics_namespace: str = '') -> bool:
     metrics."""
     return False

+  def should_garbage_collect_on_timeout(self) -> bool:
+    """Whether the model should be garbage collected if model loading or
+    inference times out, or whether it should be left for future calls.
+    Usually should not be overridden unless the model handler implements
+    other mechanisms for garbage collection."""
+    return self.share_model_across_processes()
+

 class _ModelManager:
   """
@@ -829,6 +839,9 @@ def override_metrics(self, metrics_namespace: str = '') -> bool:
     return True

+  def should_garbage_collect_on_timeout(self) -> bool:
+    return self._single_model and self.share_model_across_processes()
+

 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                              ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -1168,6 +1181,8 @@ def __init__(
     self._metrics_namespace = metrics_namespace
     self._model_metadata_pcoll = model_metadata_pcoll
     self._with_exception_handling = False
+    self._exception_handling_timeout = None
+    self._timeout = None
     self._watch_model_pattern = watch_model_pattern
     self._kwargs = kwargs
     # Generate a random tag to use for shared.py and multi_process_shared.py to
@@ -1230,7 +1245,8 @@ def _apply_fns(
               fn).with_exception_handling(
                   exc_class=self._exc_class,
                   use_subprocess=self._use_subprocess,
-                  threshold=self._threshold))
+                  threshold=self._threshold,
+                  timeout=self._timeout))
           bad_preprocessed.append(bad)
         else:
           pcoll = pcoll | f"{step_prefix}-{idx}" >> beam.Map(fn)
@@ -1269,12 +1285,17 @@ def expand(
         # batching DoFn APIs.
         | beam.BatchElements(**self._model_handler.batch_elements_kwargs()))

+    # Skip loading in setup if we are dependent on side inputs or we want to
+    # enforce a timeout, since neither of these is available in a helpful way
+    # in setup.
+    load_model_at_runtime = (
+        self._model_metadata_pcoll is not None or self._timeout is not None)
     run_inference_pardo = beam.ParDo(
         _RunInferenceDoFn(
             self._model_handler,
             self._clock,
             self._metrics_namespace,
-            self._model_metadata_pcoll is not None,
+            load_model_at_runtime,
             self._model_tag),
         self._inference_args,
         beam.pvalue.AsSingleton(
@@ -1283,13 +1304,35 @@
             **resource_hints)

     if self._with_exception_handling:
+      # On timeouts, report back to the central model metadata
+      # that the model is invalid.
+      model_tag = self._model_tag
+      share_across_processes = self._model_handler.share_model_across_processes(
+      )
+      timeout = self._timeout
+
+      def failure_callback(exception: Exception, element: Any):
+        if type(exception) is not TimeoutError:
+          return
+        model_metadata = load_model_status(model_tag, share_across_processes)
+        model_metadata.try_mark_current_model_invalid(timeout)
+        logging.warning(
+            "Inference failed with a timeout, marking the current " +
+            "model for garbage collection")
+
+      callback = None
+      if (self._timeout is not None and
+          self._model_handler.should_garbage_collect_on_timeout()):
+        callback = failure_callback
       results, bad_inference = (
           batched_elements_pcoll
           | 'BeamML_RunInference' >> run_inference_pardo.with_exception_handling(
              exc_class=self._exc_class,
              use_subprocess=self._use_subprocess,
-             threshold=self._threshold))
+             threshold=self._threshold,
+             timeout=self._timeout,
+             on_failure_callback=callback))
     else:
       results = (
           batched_elements_pcoll
@@ -1305,7 +1348,12 @@ def expand(
     return results

   def with_exception_handling(
-      self, *, exc_class=Exception, use_subprocess=False, threshold=1):
+      self,
+      *,
+      exc_class=Exception,
+      use_subprocess=False,
+      threshold=1,
+      timeout: Optional[int] = None):
     """Automatically provides a dead letter output for skipping bad records.
     This can allow a pipeline to continue successfully rather than fail or
     continuously throw errors on retry when bad elements are encountered.
@@ -1359,11 +1407,22 @@ def with_exception_handling(
       threshold: An upper bound on the ratio of records that can be bad before
           aborting the entire pipeline. Optional, defaults to 1.0 (meaning
           up to 100% of records can be bad and the pipeline will still succeed).
+      timeout: The maximum amount of time in seconds given to load a model, run
+          inference on a batch of elements, and perform any pre/postprocessing
+          operations. Since the timeout applies in multiple places, it should
+          be equal to the maximum possible timeout for any of these operations.
+          Note in particular that model load and inference on a single batch
+          count toward the same timeout value. When an inference fails, all
+          related resources, including the model, will be deleted and reloaded.
+          As a result, it is recommended to leave significant buffer and set
+          the timeout to at least `2 * (time to load model + time to run
+          inference on a batch of data)`.
     """
     self._with_exception_handling = True
     self._exc_class = exc_class
     self._use_subprocess = use_subprocess
     self._threshold = threshold
+    self._timeout = timeout

     return self

@@ -1445,6 +1504,111 @@ def next_model_index(self, num_models):
     return self._cur_index


+class _ModelStatus():
+  """A class holding any metadata about a model required by RunInference.
+
+  Currently, this only includes whether or not the model is valid. Uses the
+  model tag to map models to metadata.
+  """
+  def __init__(self, share_model_across_processes: bool):
+    self._active_tags = set()
+    self._invalid_tags = set()
+    self._tag_mapping = {}
+    self._model_first_seen = {}
+    self._pending_hard_delete = []
+    self._share_model_across_process = share_model_across_processes
+
+  def try_mark_current_model_invalid(self, min_model_life_seconds):
+    """Mark the current model invalid.
+
+    Since we don't have sufficient information to say which model is being
+    marked invalid, but there may be multiple active models, we will mark all
+    models currently in use as invalid so that they all get reloaded. To
+    avoid thrashing, however, we will only mark models as invalid if they've
+    been active for at least min_model_life_seconds seconds.
+    """
+    cutoff_time = datetime.now() - timedelta(seconds=min_model_life_seconds)
+    for tag in list(self._active_tags):
+      if cutoff_time >= self._model_first_seen[tag]:
+        self._invalid_tags.add(tag)
+        # Delete old models after a grace period of 2 * the model life.
+        # This already happens automatically for shared.Shared models, so
+        # cleanup is only necessary for multi_process_shared models.
+        if self._share_model_across_process:
+          self._pending_hard_delete.append((
+              tag,
+              datetime.now() + 2 * timedelta(seconds=min_model_life_seconds)))
+        self._active_tags.remove(tag)
+
+  def get_valid_tag(self, tag: str) -> str:
+    """Takes in a proposed valid tag and returns a valid one.
+
+    Will always return a valid tag. If the passed in tag is valid, this
+    function will simply return it, otherwise it will deterministically
+    generate a new tag to use instead. The new tag will be the original tag
+    with an incrementing suffix (e.g. `my_tag_reload_1`, `my_tag_reload_2`)
+    for each reload.
+    """
+    if tag not in self._invalid_tags:
+      if tag not in self._model_first_seen:
+        self._model_first_seen[tag] = datetime.now()
+        self._active_tags.add(tag)
+      return tag
+    if (tag in self._tag_mapping and
+        self._tag_mapping[tag] not in self._invalid_tags):
+      return self._tag_mapping[tag]
+    i = 1
+    new_tag = f'{tag}_reload_{i}'
+    while new_tag in self._invalid_tags:
+      i += 1
+      new_tag = f'{tag}_reload_{i}'
+    self._tag_mapping[tag] = new_tag
+    self._model_first_seen[new_tag] = datetime.now()
+    self._active_tags.add(new_tag)
+    return new_tag
+
+  def is_valid_tag(self, tag: str) -> bool:
+    return tag == self.get_valid_tag(tag)
+
+  def get_tags_for_garbage_collection(self) -> List[str]:
+    # Since this function may be in multi_process_shared space, delegate model
+    # deletion to the calling process (which is not in that space). This
+    # avoids holding a multi_process_shared reference within
+    # multi_process_shared space, which can create issues with python's
+    # multiprocessing module.
+    # We will rely on the calling process to report back deleted models so that
+    # we can confirm deletion.
+    to_delete = []
+    cur_time = datetime.now()
+    for i in range(len(self._pending_hard_delete)):
+      delete_time = self._pending_hard_delete[i][1]
+      tag = self._pending_hard_delete[i][0]
+      if delete_time <= cur_time:
+        to_delete.append(tag)
+      else:
+        # Early return once we hit a model that was added later, since models
+        # are added in order.
+ return to_delete + + return to_delete + + def mark_tags_deleted(self, deleted_tags: Set[str]): + while len(self._pending_hard_delete) > 0: + tag = self._pending_hard_delete[0][0] + if tag in deleted_tags: + self._pending_hard_delete.pop(0) + else: + return + + +def load_model_status( + model_tag: str, share_across_processes: bool) -> _ModelStatus: + tag = f'{model_tag}_model_status' + if share_across_processes: + return multi_process_shared.MultiProcessShared( + lambda: _ModelStatus(True), tag=tag, always_proxy=True).acquire() + return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag) + + class _SharedModelWrapper(): """A router class to map incoming calls to the correct model. @@ -1477,7 +1641,7 @@ def __init__( model_handler: ModelHandler[ExampleT, PredictionT, Any], clock, metrics_namespace, - enable_side_input_loading: bool = False, + load_model_at_runtime: bool = False, model_tag: str = "RunInference"): """A DoFn implementation generic to frameworks. @@ -1485,8 +1649,10 @@ def __init__( model_handler: An implementation of ModelHandler. clock: A clock implementing time_ns. *Used for unit testing.* metrics_namespace: Namespace of the transform to collect metrics. - enable_side_input_loading: Bool to indicate if model updates - with side inputs. + load_model_at_runtime: Bool to indicate if model loading should be + deferred to runtime - for example if we are depending on side + inputs to get the model path or we want to enforce a timeout on + model loading. model_tag: Tag to use to disambiguate models in multi-model settings. """ self._model_handler = model_handler @@ -1494,9 +1660,12 @@ def __init__( self._clock = clock self._model = None self._metrics_namespace = metrics_namespace - self._enable_side_input_loading = enable_side_input_loading + self._load_model_at_runtime = load_model_at_runtime self._side_input_path = None + # _model_tag is the original tag passed in. 
+ # _cur_tag is the tag of the actually loaded model self._model_tag = model_tag + self._cur_tag = model_tag def _load_model( self, @@ -1529,16 +1698,19 @@ def load(): model_tag = self._model_tag if isinstance(side_input_model_path, str) and side_input_model_path != '': model_tag = side_input_model_path + # Ensure the tag we're loading is valid, if not replace it with a valid tag + self._cur_tag = self._model_metadata.get_valid_tag(model_tag) if self._model_handler.share_model_across_processes(): models = [] - for i in range(self._model_handler.model_copies()): + for copy_tag in _get_tags_for_copies(self._cur_tag, + self._model_handler.model_copies()): models.append( multi_process_shared.MultiProcessShared( - load, tag=f'{model_tag}{i}', always_proxy=True).acquire()) - model_wrapper = _SharedModelWrapper(models, model_tag) + load, tag=copy_tag, always_proxy=True).acquire()) + model_wrapper = _SharedModelWrapper(models, self._cur_tag) else: - model = self._shared_model_handle.acquire(load, tag=model_tag) - model_wrapper = _SharedModelWrapper([model], model_tag) + model = self._shared_model_handle.acquire(load, tag=self._cur_tag) + model_wrapper = _SharedModelWrapper([model], self._cur_tag) # since shared_model_handle is shared across threads, the model path # might not get updated in the model handler # because we directly get cached weak ref model from shared cache, instead @@ -1568,7 +1740,9 @@ def get_metrics_collector(self, prefix: str = ''): def setup(self): self._metrics_collector = self.get_metrics_collector() self._model_handler.set_environment_vars() - if not self._enable_side_input_loading: + self._model_metadata = load_model_status( + self._model_tag, self._model_handler.share_model_across_processes()) + if not self._load_model_at_runtime: self._model = self._load_model() def update_model( @@ -1621,7 +1795,20 @@ def process( side input is empty or the model has not been updated, the method simply runs inference on the batch of data. """ + # Do garbage collection of old models before starting inference. 
+ tags_to_gc = self._model_metadata.get_tags_for_garbage_collection() + if len(tags_to_gc) > 0: + for unprefixed_tag in tags_to_gc: + for tag in _get_tags_for_copies(unprefixed_tag, + self._model_handler.model_copies()): + multi_process_shared.MultiProcessShared(lambda: None, + tag).unsafe_hard_delete() + self._model_metadata.mark_tags_deleted(tags_to_gc) + if not si_model_metadata: + if (not self._model_metadata.is_valid_tag(self._cur_tag) or + self._model is None): + self.update_model(side_input_model_path=None) return self._run_inference(batch, inference_args) if isinstance(si_model_metadata, beam.pvalue.EmptySideInput): @@ -1668,3 +1855,10 @@ def _get_current_process_memory_in_bytes(): 'Resource module is not available for current platform, ' 'memory usage cannot be fetched.') return 0 + + +def _get_tags_for_copies(base_tag, num_copies): + tags = [] + for i in range(num_copies): + tags.append(f'{base_tag}{i}') + return tags diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index b392ebd3085a..359a372bc5b0 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -19,6 +19,8 @@ import math import os import pickle +import sys +import tempfile import time import unittest from typing import Any @@ -63,6 +65,22 @@ def increment_state(self, amount: int): self._state += amount +class FakeSlowModel: + def __init__(self, sleep_on_load_seconds=0, file_path_write_on_del=None): + + self._file_path_write_on_del = file_path_write_on_del + time.sleep(sleep_on_load_seconds) + + def predict(self, example: int) -> int: + time.sleep(example) + return example + + def __del__(self): + if self._file_path_write_on_del is not None: + with open(self._file_path_write_on_del, 'a') as myfile: + myfile.write('Deleted FakeSlowModel') + + class FakeIncrementingModel: def __init__(self): self._state = 0 @@ -72,6 +90,34 @@ def predict(self, example: int) -> int: return self._state +class FakeSlowModelHandler(base.ModelHandler[int, int, FakeModel]): + def __init__( + self, + sleep_on_load: int, + multi_process_shared=False, + file_path_write_on_del=None): + self._sleep_on_load = sleep_on_load + self._multi_process_shared = multi_process_shared + self._file_path_write_on_del = file_path_write_on_del + + def load_model(self): + return FakeSlowModel(self._sleep_on_load, self._file_path_write_on_del) + + def run_inference( + self, + batch: Sequence[int], + model: FakeModel, + inference_args=None) -> Iterable[int]: + for example in batch: + yield model.predict(example) + + def share_model_across_processes(self) -> bool: + return self._multi_process_shared + + def batch_elements_kwargs(self): + return {'min_batch_size': 1, 'max_batch_size': 1} + + class FakeModelHandler(base.ModelHandler[int, int, FakeModel]): def __init__( self, @@ -783,6 +829,89 @@ def test_run_inference_impl_dlq(self): assert_that( bad_without_error, equal_to(expected_bad), label='assert:failures') + def test_run_inference_timeout_on_load_dlq(self): + with TestPipeline() as pipeline: + examples = [1, 2] + expected_good = [] + expected_bad = [1, 2] + pcoll = pipeline | 'start' >> beam.Create(examples) + main, other = pcoll | base.RunInference( + FakeSlowModelHandler(10)).with_exception_handling(timeout=2) + assert_that(main, equal_to(expected_good), label='assert:inferences') + + # bad.failed_inferences will be in form [batch[elements], error]. + # Just pull out bad element. 
+ bad_without_error = other.failed_inferences | beam.Map(lambda x: x[0][0]) + assert_that( + bad_without_error, equal_to(expected_bad), label='assert:failures') + + def test_run_inference_timeout_on_inference_dlq(self): + with TestPipeline() as pipeline: + examples = [10, 11] + expected_good = [] + expected_bad = [10, 11] + pcoll = pipeline | 'start' >> beam.Create(examples) + main, other = pcoll | base.RunInference( + FakeSlowModelHandler(0)).with_exception_handling(timeout=5) + assert_that(main, equal_to(expected_good), label='assert:inferences') + + # bad.failed_inferences will be in form [batch[elements], error]. + # Just pull out bad element. + bad_without_error = other.failed_inferences | beam.Map(lambda x: x[0][0]) + assert_that( + bad_without_error, equal_to(expected_bad), label='assert:failures') + + def test_run_inference_timeout_not_hit(self): + with TestPipeline() as pipeline: + examples = [1, 2] + expected_good = [1, 2] + expected_bad = [] + pcoll = pipeline | 'start' >> beam.Create(examples) + main, other = pcoll | base.RunInference( + FakeSlowModelHandler(3)).with_exception_handling(timeout=500) + assert_that(main, equal_to(expected_good), label='assert:inferences') + + # bad.failed_inferences will be in form [batch[elements], error]. + # Just pull out bad element. + bad_without_error = other.failed_inferences | beam.Map(lambda x: x[0][0]) + assert_that( + bad_without_error, equal_to(expected_bad), label='assert:failures') + + @unittest.skipIf( + sys.version_info < (3, 11), + "This test relies on the __del__ lifecycle method, but __del__ does " + + "not get invoked in the same way on older versions of Python, " + + "breaking this test. See " + + "github.com/python/cpython/issues/87950#issuecomment-1807570983 " + + "for example.") + def test_run_inference_timeout_does_garbage_collection(self): + with tempfile.TemporaryDirectory() as tmp_dirname: + tmp_path = os.path.join(tmp_dirname, 'tmp_filename') + expected_file_contents = 'Deleted FakeSlowModel' + with TestPipeline() as pipeline: + # Start with bad example which gets timed out. + # Then provide plenty of time for GC to happen. + examples = [20] + [1] * 15 + expected_good = [1] * 15 + expected_bad = [20] + pcoll = pipeline | 'start' >> beam.Create(examples) + main, other = pcoll | base.RunInference( + FakeSlowModelHandler( + 0, True, tmp_path)).with_exception_handling(timeout=5) + + assert_that(main, equal_to(expected_good), label='assert:inferences') + + # # bad.failed_inferences will be in form [batch[elements], error]. + # # Just pull out bad element. 
+ bad_without_error = other.failed_inferences | beam.Map( + lambda x: x[0][0]) + assert_that( + bad_without_error, equal_to(expected_bad), label='assert:failures') + + with open(tmp_path) as f: + s = f.read() + self.assertEqual(s, expected_file_contents) + def test_run_inference_impl_inference_args(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] @@ -1653,6 +1782,93 @@ def test_run_inference_watch_file_pattern_keyword_arg_side_input_label(self): FakeModelHandler(), watch_model_pattern='fake/path/*') assert side_input_str in str(result_pcoll.producer.side_inputs[0]) + def test_model_status_provides_valid_tags(self): + ms = base._ModelStatus(False) + + self.assertTrue(ms.is_valid_tag('tag1')) + self.assertTrue(ms.is_valid_tag('tag2')) + + self.assertEqual('tag1', ms.get_valid_tag('tag1')) + self.assertEqual('tag2', ms.get_valid_tag('tag2')) + + self.assertTrue(ms.is_valid_tag('tag1')) + self.assertTrue(ms.is_valid_tag('tag2')) + + ms.try_mark_current_model_invalid(0) + + self.assertFalse(ms.is_valid_tag('tag1')) + self.assertFalse(ms.is_valid_tag('tag2')) + + self.assertEqual('tag1_reload_1', ms.get_valid_tag('tag1')) + self.assertEqual('tag2_reload_1', ms.get_valid_tag('tag2')) + + self.assertTrue(ms.is_valid_tag('tag1_reload_1')) + self.assertTrue(ms.is_valid_tag('tag2_reload_1')) + + ms.try_mark_current_model_invalid(0) + + self.assertFalse(ms.is_valid_tag('tag1')) + self.assertFalse(ms.is_valid_tag('tag2')) + self.assertFalse(ms.is_valid_tag('tag1_reload_1')) + self.assertFalse(ms.is_valid_tag('tag2_reload_1')) + + self.assertEqual('tag1_reload_2', ms.get_valid_tag('tag1')) + self.assertEqual('tag2_reload_2', ms.get_valid_tag('tag2')) + + self.assertTrue(ms.is_valid_tag('tag1_reload_2')) + self.assertTrue(ms.is_valid_tag('tag2_reload_2')) + + def test_model_status_provides_valid_garbage_collection(self): + ms = base._ModelStatus(True) + + self.assertTrue(ms.is_valid_tag('mspvgc_tag1')) + self.assertTrue(ms.is_valid_tag('mspvgc_tag2')) + + ms.try_mark_current_model_invalid(1) + + tags = ms.get_tags_for_garbage_collection() + + self.assertEqual(0, len(tags)) + + time.sleep(4) + + ms.try_mark_current_model_invalid(1) + + time.sleep(4) + + self.assertTrue(ms.is_valid_tag('mspvgc_tag3')) + + tags = ms.get_tags_for_garbage_collection() + + self.assertEqual(2, len(tags)) + self.assertTrue('mspvgc_tag1' in tags) + self.assertTrue('mspvgc_tag2' in tags) + + ms.try_mark_current_model_invalid(0) + + tags = ms.get_tags_for_garbage_collection() + + self.assertEqual(3, len(tags)) + self.assertTrue('mspvgc_tag1' in tags) + self.assertTrue('mspvgc_tag2' in tags) + self.assertTrue('mspvgc_tag3' in tags) + + ms.mark_tags_deleted(set(['mspvgc_tag1', 'mspvgc_tag2'])) + tags = ms.get_tags_for_garbage_collection() + + self.assertEqual(['mspvgc_tag3'], tags) + + ms.mark_tags_deleted(set(['mspvgc_tag1'])) + tags = ms.get_tags_for_garbage_collection() + + self.assertEqual(1, len(tags)) + self.assertEqual('mspvgc_tag3', tags[0]) + + ms.mark_tags_deleted(set(['mspvgc_tag3'])) + tags = ms.get_tags_for_garbage_collection() + + self.assertEqual(0, len(tags)) + if __name__ == '__main__': unittest.main()
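As a reading aid for the `_ModelStatus` tests above, the following hedged walkthrough shows the tag lifecycle that drives model reloading after a timeout. `_ModelStatus` is a private helper introduced by this patch, so this is for understanding the change rather than a supported API; `'my_model'` is just an example tag, and the assertions mirror `test_model_status_provides_valid_tags`.

```python
from apache_beam.ml.inference import base

ms = base._ModelStatus(share_model_across_processes=False)

# A fresh tag is considered valid and becomes an active model tag.
assert ms.get_valid_tag('my_model') == 'my_model'
assert ms.is_valid_tag('my_model')

# A timeout triggers try_mark_current_model_invalid (via the failure callback
# wired up in expand()); with a minimum model life of 0 seconds every active
# tag is invalidated immediately.
ms.try_mark_current_model_invalid(0)
assert not ms.is_valid_tag('my_model')

# Subsequent lookups deterministically map to a replacement tag, so the next
# bundle reloads the model under a new shared handle.
assert ms.get_valid_tag('my_model') == 'my_model_reload_1'
```

When `share_model_across_processes` is True, the same invalidation also queues the old tags for the hard-delete pass that `process()` performs before running inference, which is what `test_run_inference_timeout_does_garbage_collection` exercises.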