From efef7b8c3635cf7d235c11841a2fe49c8ae148e0 Mon Sep 17 00:00:00 2001
From: Brian Walsh
Date: Thu, 17 Nov 2022 16:55:14 -0800
Subject: [PATCH] Optimizes memory use

---
 drs_downloader/cli.py           | 16 +++++++++++
 drs_downloader/clients/mock.py  |  2 +-
 drs_downloader/clients/terra.py | 29 +++++++++++++------
 drs_downloader/manager.py       | 49 ++++++++++++++++++++++++---------
 4 files changed, 73 insertions(+), 23 deletions(-)

diff --git a/drs_downloader/cli.py b/drs_downloader/cli.py
index d07c32c..9d38164 100644
--- a/drs_downloader/cli.py
+++ b/drs_downloader/cli.py
@@ -58,6 +58,13 @@ def mock(silent: bool, destination_dir: str):
         logger.info((drs_object.name, 'OK', drs_object.size, len(drs_object.file_parts)))
 
     logger.info(('done', 'statistics.max_files_open', drs_client.statistics.max_files_open))
+    at_least_one_error = False
+    for drs_object in drs_objects:
+        if len(drs_object.errors) > 0:
+            logger.error((drs_object.name, 'ERROR', drs_object.size, len(drs_object.file_parts), drs_object.errors))
+            at_least_one_error = True
+    if at_least_one_error:
+        exit(99)
 
 
 @cli.command()
 @click.option("--silent", "-s", is_flag=True, show_default=True, default=False, help="Display nothing.")
@@ -119,6 +126,15 @@ def _extract_tsv_info(manifest_path_: Path, drs_header: str = 'pfb:ga4gh_drs_uri'
         logger.info((drs_object.name, 'OK', drs_object.size, len(drs_object.file_parts)))
 
     logger.info(('done', 'statistics.max_files_open', drs_client.statistics.max_files_open))
+    at_least_one_error = False
+    for drs_object in drs_objects:
+        if len(drs_object.errors) > 0:
+            logger.error((drs_object.name, 'ERROR', drs_object.size, len(drs_object.file_parts), drs_object.errors))
+            at_least_one_error = True
+    if at_least_one_error:
+        exit(99)
+
+
 if __name__ == "__main__":
     cli()

diff --git a/drs_downloader/clients/mock.py b/drs_downloader/clients/mock.py
index 93ff1a1..2be7f6d 100644
--- a/drs_downloader/clients/mock.py
+++ b/drs_downloader/clients/mock.py
@@ -59,7 +59,7 @@ async def download_part(self, drs_object: DrsObject, start: int, size: int, dest
 
         # calculate actual part size from range; see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range
-        length_ = size - start
+        length_ = size - start + 1
         # logger.info((drs_object.name, start, length_))
         with open(f'/tmp/testing/{drs_object.name}.golden', 'rb') as f:
             f.seek(start)

diff --git a/drs_downloader/clients/terra.py b/drs_downloader/clients/terra.py
index 8c1d964..31e9ff4 100644
--- a/drs_downloader/clients/terra.py
+++ b/drs_downloader/clients/terra.py
@@ -4,9 +4,13 @@
 
 import aiofiles
 import aiohttp
+import logging
 
+from drs_downloader import MB
 from drs_downloader.models import DrsClient, DrsObject, AccessMethod, Checksum
 
+logger = logging.getLogger(__name__)
+
 
 class TerraDrsClient(DrsClient):
     """
@@ -33,16 +37,23 @@ def _get_auth_token() -> str:
         return token
 
     async def download_part(self, drs_object: DrsObject, start: int, size: int, destination_path: Path) -> Path:
+        try:
+            headers = {'Range': f'bytes={start}-{size}'}
-        headers = {'Range': f'bytes={start}-{size}'}
-        (fd, name,) = tempfile.mkstemp(prefix=f'{drs_object.name}.{start}.{size}.', suffix='.part',
-                                       dir=str(destination_path))
-        async with aiohttp.ClientSession(headers=headers) as session:
-            async with session.get(drs_object.access_methods[0].access_url) as request:
-                file = await aiofiles.open(name, 'wb')
-                self.statistics.set_max_files_open()
-                await file.write(await request.content.read())
-                return Path(name)
+            file_name = destination_path / f'{drs_object.name}.{start}.{size}.part'
+            async with aiohttp.ClientSession(headers=headers) as session:
+                async with session.get(drs_object.access_methods[0].access_url) as request:
+                    file = await aiofiles.open(file_name, 'wb')
+                    self.statistics.set_max_files_open()
+                    async for data in request.content.iter_any():  # uses less memory
+                        await file.write(data)
+                    # await file.write(await request.content.read())  # original, buffered the whole part
+                    await file.close()
+                    return Path(file_name)
+        except Exception as e:
+            logger.error(f"terra.download_part {str(e)}")
+            drs_object.errors.append(str(e))
+            return None
 
     async def sign_url(self, drs_object: DrsObject) -> DrsObject:
         """No-op. terra returns a signed url in `get_object` """
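Note: the heart of the memory optimization above is the switch from `await request.content.read()`, which buffers an entire part in RAM before writing, to `request.content.iter_any()`, which writes chunks to disk as they arrive. A minimal standalone sketch of the same pattern, assuming only aiohttp/aiofiles public APIs; the URL and destination path are placeholders, not part of this patch:

    import asyncio

    import aiofiles
    import aiohttp


    async def download_range(url: str, start: int, end: int, destination: str) -> None:
        # HTTP Range is inclusive on both ends: bytes=0-9 returns 10 bytes,
        # which is why mock.py computes length_ = size - start + 1 above.
        headers = {'Range': f'bytes={start}-{end}'}
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.get(url) as response:
                async with aiofiles.open(destination, 'wb') as file:
                    # iter_any() yields data as it arrives, so peak memory is
                    # bounded by the chunk size rather than the whole part size.
                    async for chunk in response.content.iter_any():
                        await file.write(chunk)

    # asyncio.run(download_range('https://example.com/object', 0, 1023, '/tmp/object.0.1023.part'))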
diff --git a/drs_downloader/manager.py b/drs_downloader/manager.py
index df22083..bea7645 100644
--- a/drs_downloader/manager.py
+++ b/drs_downloader/manager.py
@@ -1,6 +1,7 @@
 import asyncio
 import hashlib
 import logging
+import math
 import shutil
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -10,7 +11,7 @@
 import tqdm.asyncio
 
 from drs_downloader import DEFAULT_MAX_SIMULTANEOUS_OBJECT_RETRIEVERS, DEFAULT_MAX_SIMULTANEOUS_PART_HANDLERS, \
-    DEFAULT_MAX_SIMULTANEOUS_DOWNLOADERS, DEFAULT_PART_SIZE
+    DEFAULT_MAX_SIMULTANEOUS_DOWNLOADERS, DEFAULT_PART_SIZE, MB
 
 from drs_downloader.models import DrsClient, DrsObject
 
@@ -126,8 +127,8 @@ def _parts_generator(size: int, start: int = 0, part_size: int = None) -> Iterat
         """
         while size - start > part_size:
             yield start, start + part_size
-            # start += part_size + 1
-            start += part_size
+            start += part_size + 1
+            # start += part_size
         yield start, size
 
     async def _run_download_parts(self, drs_object: DrsObject, destination_path: Path) -> DrsObject:
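Note: the `start += part_size + 1` fix above follows from the inclusive Range semantics. Each yielded `(start, end)` pair is an inclusive byte range of `end - start + 1` bytes, so with the old increment consecutive parts shared a boundary byte. A small sketch of the corrected arithmetic, mirroring `_parts_generator` with toy numbers:

    def parts(size: int, part_size: int):
        # yields inclusive (start, end) byte ranges covering 0..size
        start = 0
        while size - start > part_size:
            yield start, start + part_size
            start += part_size + 1
        yield start, size

    # contiguous and non-overlapping: each range starts one byte after the last ends
    assert list(parts(10, 3)) == [(0, 3), (4, 7), (8, 10)]
    # with the old `start += part_size`, this was [(0, 3), (3, 6), (6, 9), (9, 10)]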
@@ -139,24 +140,41 @@ async def _run_download_parts(self, drs_object: DrsObject, destination_path: Pat
         Returns:
             list of paths to files for each part, in order.
         """
-        tasks = []
+        # create a list of parts
+        parts = []
         for start, size in self._parts_generator(size=drs_object.size, part_size=self.part_size):
-            task = asyncio.create_task(self._drs_client.download_part(drs_object=drs_object, start=start, size=size,
-                                                                      destination_path=destination_path))
-            tasks.append(task)
+            parts.append((start, size, ))
 
-        if len(tasks) > 1000:
-            logger.warning(f'tasks > 1000 {drs_object.name} has over 1000 parts, consider optimization.')
+        if len(parts) > 1000:
+            logger.error(f'tasks > 1000 {drs_object.name} has over 1000 parts, consider optimization. ({len(parts)})')
 
         paths = []
-        for chunk_tasks in DrsAsyncManager._chunker(tasks, self.max_simultaneous_part_handlers):
+        # TODO - tqdm ugly here?
+        for chunk_parts in \
+                tqdm.tqdm(DrsAsyncManager._chunker(parts, self.max_simultaneous_part_handlers),
+                          total=math.ceil(len(parts) / self.max_simultaneous_part_handlers),
+                          desc="  * batch",
+                          leave=False,
+                          disable=self.disable):
+            chunk_tasks = []
+            for start, size in chunk_parts:
+                task = asyncio.create_task(self._drs_client.download_part(drs_object=drs_object, start=start,
+                                                                          size=size,
+                                                                          destination_path=destination_path))
+                chunk_tasks.append(task)
+
             chunk_paths = [
                 await f
                 for f in tqdm.tqdm(asyncio.as_completed(chunk_tasks), total=len(chunk_tasks), leave=False,
                                    desc=f"  * {drs_object.name}", disable=self.disable)
             ]
+            # a None here means a part failed to download; stop and return the object with its errors
+            if None in chunk_paths:
+                logger.error(f"{drs_object.name} had missing part.")
+                return drs_object
+
             paths.extend(chunk_paths)
 
         drs_object.file_parts = paths
 
         # re-assemble and test the file parts
@@ -165,7 +183,9 @@ async def _run_download_parts(self, drs_object: DrsObject, destination_path: Pat
         assert checksum_type in hashlib.algorithms_available, f"Checksum {checksum_type} not supported."
         md5_hash = hashlib.new(checksum_type)
         with open(destination_path.joinpath(drs_object.name), 'wb') as wfd:
-            for f in sorted(drs_object.file_parts):
+            # sort the part paths in place, numerically by start offset, i.e. '{name}.{start}.{size}.part'
+            drs_object.file_parts.sort(key=lambda x: int(str(x).split('.')[-3]))
+            for f in drs_object.file_parts:
                 fd = open(f, 'rb')
                 wrapped_fd = Wrapped(fd, md5_hash)
                 # efficient way to write
@@ -219,8 +239,7 @@ async def _run_download(self, drs_objects: List[DrsObject], destination_path: Pa
         drs_objects_with_file_parts = [
             await f
-            for f in tqdm.tqdm(asyncio.as_completed(tasks), total=len(tasks), leave=leave, desc="  * batch",
-                               disable=self.disable)
+            for f in asyncio.as_completed(tasks)
         ]
 
         return drs_objects_with_file_parts
@@ -357,4 +376,8 @@ def optimize_workload(self, drs_objects: List[DrsObject]) -> List[DrsObject]:
         """
         # TODO - now that we have the objects to download, we have an opportunity to shape the downloads
         # TODO - e.g. smallest files first? tweak MAX_* to optimize per workload
+        # for example, open it up for 1 big file.
+        if len(drs_objects) == 1:
+            self.max_simultaneous_part_handlers = 50
+            self.part_size = 64 * MB
        return drs_objects
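Note: the new sort key matters because part files are named `{name}.{start}.{size}.part` and `sorted()` on paths is lexicographic, which misorders numeric offsets. A small sketch of the ordering, with hypothetical file names:

    from pathlib import Path

    parts = [Path('sample.bam.1000.1999.part'), Path('sample.bam.200.999.part')]
    # lexicographic order is wrong: '1000' sorts before '200' as a string
    assert sorted(parts)[0].name == 'sample.bam.1000.1999.part'
    # numeric sort on the start offset (third field from the end) is correct
    parts.sort(key=lambda x: int(str(x).split('.')[-3]))
    assert parts[0].name == 'sample.bam.200.999.part'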