diff --git a/src/tufup/client.py b/src/tufup/client.py index d73cf25..b1974c7 100644 --- a/src/tufup/client.py +++ b/src/tufup/client.py @@ -153,7 +153,11 @@ def check_for_updates( If `patch` is `False`, a full update is enforced. """ - included = {None: '', '': '', 'a': 'abrc', 'b': 'brc', 'rc': 'rc'} + # invalid pre-release specifiers are ignored, with a warning + pre_map = dict(a='abrc', b='brc', rc='rc') + prereleases = pre_map.get(pre, '') + if pre and not prereleases: + logger.warning(f'ignoring invalid pre-release specifier: "{pre}"') # refresh top-level metadata (root -> timestamp -> snapshot -> targets) try: self.refresh() @@ -177,7 +181,7 @@ def check_for_updates( item for item in all_new_targets.items() if item[0].is_archive - and (not item[0].version.pre or item[0].version.pre[0] in included[pre]) + and (not item[0].version.pre or item[0].version.pre[0] in prereleases) ) new_archive_meta = None if new_archives: diff --git a/tests/test_client.py b/tests/test_client.py index 39922fd..a2e8c15 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -150,16 +150,18 @@ def test_download_and_apply_update(self): def test_check_for_updates(self): # expectations (based on targets in tests/data/repository): - # - pre=None: only full releases are included, so finds 2.0 patch + # - pre=None, '', or 'invalid': only full releases are included, finds 2.0 patch # - pre='a': finds all, but total patch size exceeds archive size # - pre='b': there is no 'b' release, so this finds same as 'rc' # - pre='rc': finds 2.0 and 3.0rc0, total patch size smaller than archive client = self.get_refreshed_client() with patch.object(client, 'refresh', Mock()): - for pre, expected in [(None, 1), ('a', 1), ('b', 2), ('rc', 2)]: + for pre, expected in [ + (None, 1), ('', 1), ('a', 1), ('b', 2), ('rc', 2), ('invalid', 1) + ]: with self.subTest(msg=pre): target_meta = client.check_for_updates(pre=pre) - self.assertTrue(target_meta) + self.assertTrue(expected and target_meta) self.assertEqual(expected, len(client.new_targets)) if pre == 'a': self.assertTrue(