Skip to content

Commit

Permalink
Fail-safe mechanism for patch updates (#101)
Browse files Browse the repository at this point in the history
* fall back on full update if existing patch is found with .failed suffix

* flag target with .failed suffix if update fails

* simplify patch failure test

* ruff formatting...

* remove unused test stub
  • Loading branch information
dennisvang authored Feb 6, 2024
1 parent ef3fbad commit 7733ea1
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 38 deletions.
63 changes: 43 additions & 20 deletions src/tufup/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import bsdiff4
from copy import deepcopy
import logging
import pathlib
import shutil
Expand All @@ -7,6 +7,7 @@
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
from urllib import parse

import bsdiff4
import requests
from requests.auth import AuthBase
from tuf.api.exceptions import DownloadError, UnsignedMetadataError
Expand All @@ -18,6 +19,7 @@
logger = logging.getLogger(__name__)

DEFAULT_EXTRACT_DIR = pathlib.Path(tempfile.gettempdir()) / 'tufup'
SUFFIX_FAILED = '.failed'


class Client(tuf.ngclient.Updater):
Expand Down Expand Up @@ -194,17 +196,27 @@ def check_for_updates(
total_patch_size = sum(
target_file.length for target_file in new_patches.values()
)
# abort patch update if any of the new patches have failed on a previous run
abort_patch = False
for patch_info in new_patches.values():
patch_info_mod = deepcopy(patch_info) # modify a copy, just to be sure
patch_info_mod.path += SUFFIX_FAILED
if self.find_cached_target(targetinfo=patch_info_mod):
logger.debug(f'aborting patch due to {patch_info_mod.path}')
abort_patch = True
# use file size to decide if we want to do a patch update or a
# full update (if there are no patches, or if the current archive
# is not available, we must do a full update)
self.new_targets = new_patches
no_patches = total_patch_size == 0
patches_too_big = total_patch_size > self.new_archive_info.length
current_archive_not_found = not self.current_archive_local_path.exists()
if not patch or no_patches or patches_too_big or current_archive_not_found:
patch_too_big = total_patch_size > self.new_archive_info.length
no_archive = not self.current_archive_local_path.exists()
if not patch or no_patches or patch_too_big or no_archive or abort_patch:
# fall back on full update
self.new_targets = {new_archive_meta: self.new_archive_info}
logger.debug('full update available')
else:
# continue with patch update
logger.debug('patch update(s) available')
else:
self.new_targets = {}
Expand Down Expand Up @@ -241,22 +253,33 @@ def _apply_updates(
"""
# patch current archive (if we have patches) or use new full archive
archive_bytes = None
for target, file_path in sorted(self.downloaded_target_files.items()):
if target.is_archive:
# just ensure the full archive file is available
assert len(self.downloaded_target_files) == 1
assert self.new_archive_local_path.exists()
elif target.is_patch:
# create new archive by patching current archive (patches
# must be sorted by increasing version)
if archive_bytes is None:
archive_bytes = self.current_archive_local_path.read_bytes()
archive_bytes = bsdiff4.patch(archive_bytes, file_path.read_bytes())
if archive_bytes:
# verify the patched archive length and hash
self.new_archive_info.verify_length_and_hashes(data=archive_bytes)
# write the patched new archive
self.new_archive_local_path.write_bytes(archive_bytes)
file_path = None
target = None
try:
for target, file_path in sorted(self.downloaded_target_files.items()):
if target.is_archive:
# just ensure the full archive file is available
assert len(self.downloaded_target_files) == 1, 'too many targets'
assert self.new_archive_local_path.exists(), 'new archive missing'
elif target.is_patch:
# create new archive by patching current archive (patches
# must be sorted by increasing version)
if archive_bytes is None:
archive_bytes = self.current_archive_local_path.read_bytes()
archive_bytes = bsdiff4.patch(archive_bytes, file_path.read_bytes())
if archive_bytes:
# verify the patched archive length and hash
self.new_archive_info.verify_length_and_hashes(data=archive_bytes)
# write the patched new archive
self.new_archive_local_path.write_bytes(archive_bytes)
except Exception as e:
if target and file_path and file_path.exists():
renamed_path = file_path.replace(
file_path.with_suffix(file_path.suffix + SUFFIX_FAILED)
)
logger.debug(f'update failed: target renamed to {renamed_path}')
logger.error(f'update aborted: {e}')
return
# extract archive to temporary directory
if self.extract_dir is None:
self.extract_dir = DEFAULT_EXTRACT_DIR
Expand Down
2 changes: 1 addition & 1 deletion src/tufup/utils/platform_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _install_update_win(
batch_template_extra_kwargs: Optional[dict] = None,
log_file_name: Optional[str] = None,
robocopy_options_override: Optional[List[str]] = None,
process_creation_flags = None,
process_creation_flags=None,
):
"""
Create a batch script that moves files from src to dst, then run the
Expand Down
70 changes: 53 additions & 17 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import packaging.version
from requests.auth import HTTPBasicAuth
import tuf.api.exceptions
from tuf.api.exceptions import LengthOrHashMismatchError
from tuf.ngclient import TargetFile

from tests import TempDirTestCase, TEST_REPO_DIR
from tufup.client import AuthRequestsFetcher, Client
from tufup.client import AuthRequestsFetcher, Client, SUFFIX_FAILED
from tufup.common import TargetMeta

ROOT_FILENAME = 'root.json'
Expand Down Expand Up @@ -182,6 +183,27 @@ def test_check_for_updates_current_archive_missing(self):
target_meta = next(iter(client.new_targets.keys()))
self.assertTrue(target_meta.is_archive)

def test_check_for_updates_failed_patch(self):
client = self.get_refreshed_client()
# first verify that we would normally get a patch update
with patch.object(client, 'refresh', Mock()):
client.check_for_updates()
target_meta = next(iter(client.new_targets.keys()))
self.assertTrue(target_meta.is_patch)
# copy the patch into the client cache to simulate existing failed patch
client_cache_dir = pathlib.Path(client.target_dir)
shutil.copy(
src=TEST_REPO_DIR / 'targets' / target_meta.filename,
dst=client_cache_dir / (target_meta.filename + SUFFIX_FAILED),
)
# test: should fall back on full archive update
with patch.object(client, 'refresh', Mock()):
with self.assertLogs(level='DEBUG') as logs:
client.check_for_updates()
target_meta = next(iter(client.new_targets.keys()))
self.assertTrue(target_meta.is_archive)
self.assertIn('aborting', ''.join(logs.output))

def test__download_updates(self):
client = Client(**self.client_kwargs)
client.new_targets = {Mock(): Mock()}
Expand All @@ -200,12 +222,14 @@ def test__download_updates(self):

def test__apply_updates(self):
client = self.get_refreshed_client()
# directly use target files from test repo as downloaded files
client.downloaded_target_files = {
target_meta: TEST_REPO_DIR / 'targets' / str(target_meta)
for target_meta in client.trusted_target_metas
if target_meta.is_patch and str(target_meta.version) in ['2.0', '3.0rc0']
}
# copy files from test data to temporary client cache
client.downloaded_target_files = dict()
for target_meta in client.trusted_target_metas:
if target_meta.is_patch and str(target_meta.version) in ['2.0', '3.0rc0']:
src_path = TEST_REPO_DIR / 'targets' / target_meta.filename
dst_path = pathlib.Path(client.target_dir, target_meta.filename)
shutil.copy(src=src_path, dst=dst_path)
client.downloaded_target_files[target_meta] = dst_path
# specify new archive (normally done in _check_updates)
archives = [
tp
Expand All @@ -216,16 +240,28 @@ def test__apply_updates(self):
client.new_archive_local_path = pathlib.Path(
client.target_dir, client.new_archive_info.path
)
# test confirmation
mock_install = Mock()
with patch('builtins.input', Mock(return_value='y')):
client._apply_updates(install=mock_install, skip_confirmation=False)
self.assertTrue(any(client.extract_dir.iterdir()))
self.assertTrue(mock_install.called)
# test skip confirmation
mock_install = Mock()
client._apply_updates(install=mock_install, skip_confirmation=True)
mock_install.assert_called()
# tests
with self.subTest(msg='with confirmation'):
mock_install = Mock()
with patch('builtins.input', Mock(return_value='y')):
client._apply_updates(install=mock_install, skip_confirmation=False)
self.assertTrue(any(client.extract_dir.iterdir()))
self.assertTrue(mock_install.called)
with self.subTest(msg='skip confirmation'):
mock_install = Mock()
client._apply_updates(install=mock_install, skip_confirmation=True)
mock_install.assert_called()
with self.subTest(msg='patch failure due to mismatch'):
mock_install = Mock()
with patch.object(
client.new_archive_info,
'verify_length_and_hashes',
Mock(side_effect=LengthOrHashMismatchError()),
):
client._apply_updates(install=mock_install, skip_confirmation=True)
mock_install.assert_not_called()
target_paths = pathlib.Path(client.target_dir).iterdir()
self.assertTrue(any(path.suffix == SUFFIX_FAILED for path in target_paths))

def test_version_comparison(self):
# verify assumed version hierarchy
Expand Down

0 comments on commit 7733ea1

Please sign in to comment.