Optimizes memory use
bwalsh committed Nov 18, 2022
1 parent 2f2f371 commit efef7b8
Showing 4 changed files with 73 additions and 23 deletions.
16 changes: 16 additions & 0 deletions drs_downloader/cli.py
@@ -58,6 +58,13 @@ def mock(silent: bool, destination_dir: str):
logger.info((drs_object.name, 'OK', drs_object.size, len(drs_object.file_parts)))
logger.info(('done', 'statistics.max_files_open', drs_client.statistics.max_files_open))

at_least_one_error = False
for drs_object in drs_objects:
    if len(drs_object.errors) > 0:
        logger.error((drs_object.name, 'ERROR', drs_object.size, len(drs_object.file_parts), drs_object.errors))
        at_least_one_error = True
if at_least_one_error:
    exit(99)

@cli.command()
@click.option("--silent", "-s", is_flag=True, show_default=True, default=False, help="Display nothing.")
@@ -119,6 +126,15 @@ def _extract_tsv_info(manifest_path_: Path, drs_header: str = 'pfb:ga4gh_drs_uri
logger.info((drs_object.name, 'OK', drs_object.size, len(drs_object.file_parts)))
logger.info(('done', 'statistics.max_files_open', drs_client.statistics.max_files_open))

at_least_one_error = False
for drs_object in drs_objects:
    if len(drs_object.errors) > 0:
        logger.error((drs_object.name, 'ERROR', drs_object.size, len(drs_object.file_parts), drs_object.errors))
        at_least_one_error = True
if at_least_one_error:
    exit(99)



if __name__ == "__main__":
cli()
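With the new loop, the command exits with code 99 whenever any DRS object finished with errors, so callers and CI scripts can detect partial failures. A minimal sketch of checking that behaviour with click's test runner; invoking `mock` with its defaults is an assumption about the command's options:

from click.testing import CliRunner

from drs_downloader.cli import cli  # the click group defined in this file

runner = CliRunner()
result = runner.invoke(cli, ["mock"])  # assumes the command's defaults are usable as-is

# 0 when every object downloaded cleanly, 99 if at least one object recorded errors
assert result.exit_code in (0, 99)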
2 changes: 1 addition & 1 deletion drs_downloader/clients/mock.py
@@ -59,7 +59,7 @@ async def download_part(self, drs_object: DrsObject, start: int, size: int, dest

# calculate actual part size from range see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range

length_ = size - start
length_ = size - start + 1
# logger.info((drs_object.name, start, length_))
with open(f'/tmp/testing/{drs_object.name}.golden', 'rb') as f:
f.seek(start)
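The `+ 1` matters because HTTP Range requests are inclusive on both ends, so `Range: bytes=0-1023` returns 1024 bytes (see the MDN link above). A tiny worked check, using start/end names for clarity:

start, end = 0, 1023
length = end - start + 1   # inclusive range: both the start and end bytes are returned
assert length == 1024
# the old `end - start` calculation under-counted every part by one byte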
29 changes: 20 additions & 9 deletions drs_downloader/clients/terra.py
@@ -4,9 +4,13 @@

import aiofiles
import aiohttp
import logging

from drs_downloader import MB
from drs_downloader.models import DrsClient, DrsObject, AccessMethod, Checksum

logger = logging.getLogger(__name__)


class TerraDrsClient(DrsClient):
"""
@@ -33,16 +37,23 @@ def _get_auth_token() -> str:
return token

async def download_part(self, drs_object: DrsObject, start: int, size: int, destination_path: Path) -> Path:
try:
headers = {'Range': f'bytes={start}-{size}'}

headers = {'Range': f'bytes={start}-{size}'}
(fd, name,) = tempfile.mkstemp(prefix=f'{drs_object.name}.{start}.{size}.', suffix='.part',
dir=str(destination_path))
async with aiohttp.ClientSession(headers=headers) as session:
async with session.get(drs_object.access_methods[0].access_url) as request:
file = await aiofiles.open(name, 'wb')
self.statistics.set_max_files_open()
await file.write(await request.content.read())
return Path(name)
file_name = destination_path / f'{drs_object.name}.{start}.{size}.part'
async with aiohttp.ClientSession(headers=headers) as session:
async with session.get(drs_object.access_methods[0].access_url) as request:
file = await aiofiles.open(file_name, 'wb')
self.statistics.set_max_files_open()
async for data in request.content.iter_any(): # uses less memory
await file.write(data)
# await file.write(await request.content.read()) # original
await file.close()
return Path(file_name)
except Exception as e:
logger.error(f"terra.download_part {str(e)}")
drs_object.errors.append(str(e))
return None

async def sign_url(self, drs_object: DrsObject) -> DrsObject:
"""No-op. terra returns a signed url in `get_object` """
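The memory saving comes from reading the response with iter_any(), which yields chunks as they arrive, instead of buffering the whole part via content.read(). A self-contained sketch of the same pattern outside the Terra client; the URL, path, and range are placeholders:

import asyncio

import aiofiles
import aiohttp


async def stream_to_file(url: str, path: str, start: int, end: int) -> None:
    # inclusive byte range, matching the headers built in download_part above
    headers = {'Range': f'bytes={start}-{end}'}
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.get(url) as response:
            async with aiofiles.open(path, 'wb') as file:
                # iter_any() yields whatever bytes have arrived, so peak memory is
                # roughly one network buffer rather than the entire part:
                # data = await response.content.read()  # buffers the whole body
                async for data in response.content.iter_any():
                    await file.write(data)


# illustrative invocation; the URL is a placeholder
# asyncio.run(stream_to_file("https://example.org/object", "/tmp/object.part", 0, 1024 * 1024 - 1))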
49 changes: 36 additions & 13 deletions drs_downloader/manager.py
@@ -1,6 +1,7 @@
import asyncio
import hashlib
import logging
import math
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
@@ -10,7 +11,7 @@
import tqdm.asyncio

from drs_downloader import DEFAULT_MAX_SIMULTANEOUS_OBJECT_RETRIEVERS, DEFAULT_MAX_SIMULTANEOUS_PART_HANDLERS, \
DEFAULT_MAX_SIMULTANEOUS_DOWNLOADERS, DEFAULT_PART_SIZE
DEFAULT_MAX_SIMULTANEOUS_DOWNLOADERS, DEFAULT_PART_SIZE, MB

from drs_downloader.models import DrsClient, DrsObject

@@ -126,8 +127,8 @@ def _parts_generator(size: int, start: int = 0, part_size: int = None) -> Iterat
"""
while size - start > part_size:
yield start, start + part_size
# start += part_size + 1
start += part_size
start += part_size + 1
# start += part_size
yield start, size

async def _run_download_parts(self, drs_object: DrsObject, destination_path: Path) -> DrsObject:
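Because each (start, end) pair is now treated as an inclusive HTTP byte range, the generator steps by part_size + 1 so consecutive parts no longer share a boundary byte. A small trace with toy numbers (the real default part size is much larger):

def parts(size, part_size, start=0):
    # same shape as _parts_generator above, reproduced only to trace its output
    while size - start > part_size:
        yield start, start + part_size
        start += part_size + 1
    yield start, size

print(list(parts(size=10, part_size=3)))
# [(0, 3), (4, 7), (8, 10)]: inclusive ranges with no overlap; the final request
# is clamped by the server to the last byte (offset 9).
# The previous `start += part_size` step yielded (0, 3), (3, 6), (6, 9), (9, 10),
# which re-requests the boundary byte of every part once ranges are inclusive.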
@@ -139,24 +140,41 @@ async def _run_download_parts(drs_object: DrsObject, destination_path: Pat
Returns:
list of paths to files for each part, in order.
"""
tasks = []
# create a list of parts
parts = []
for start, size in self._parts_generator(size=drs_object.size, part_size=self.part_size):
task = asyncio.create_task(self._drs_client.download_part(drs_object=drs_object, start=start, size=size,
destination_path=destination_path))
tasks.append(task)
parts.append((start, size, ))

if len(tasks) > 1000:
logger.warning(f'tasks > 1000 {drs_object.name} has over 1000 parts, consider optimization.')
if len(parts) > 1000:
logger.error(f'tasks > 1000 {drs_object.name} has over 1000 parts, consider optimization. ({len(parts)})')

paths = []
for chunk_tasks in DrsAsyncManager._chunker(tasks, self.max_simultaneous_part_handlers):
# TODO - tqdm ugly here?
for chunk_parts in \
tqdm.tqdm(DrsAsyncManager._chunker(parts, self.max_simultaneous_part_handlers),
total=math.ceil(len(parts)/self.max_simultaneous_part_handlers),
desc=" * batch",
leave=False,
disable=self.disable):
chunk_tasks = []
for start, size in chunk_parts:
task = asyncio.create_task(self._drs_client.download_part(drs_object=drs_object, start=start, size=size,
destination_path=destination_path))
chunk_tasks.append(task)

chunk_paths = [
await f
for f in
tqdm.tqdm(asyncio.as_completed(chunk_tasks), total=len(chunk_tasks), leave=False,
desc=f" * {drs_object.name}", disable=self.disable)
]
# something bad happened
if None in chunk_paths:
logger.error(f"{drs_object.name} had missing part.")
return drs_object

paths.extend(chunk_paths)

drs_object.file_parts = paths

# re-assemble and test the file parts
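Creating the part tasks chunk by chunk, rather than all up front, is what keeps the number of in-flight coroutines (and open temp files) bounded by max_simultaneous_part_handlers. A reduced sketch of that batching pattern; chunker, fetch_part, and the sizes are illustrative stand-ins, not the project's real helpers:

import asyncio
from typing import Iterator, List, Tuple


def chunker(seq: List, size: int) -> Iterator[List]:
    # split a list into consecutive batches of at most `size` items
    for i in range(0, len(seq), size):
        yield seq[i:i + size]


async def fetch_part(start: int, end: int) -> str:
    await asyncio.sleep(0)  # placeholder for the ranged HTTP download
    return f"part.{start}.{end}"


async def download_all(parts: List[Tuple[int, int]], handlers: int) -> List[str]:
    paths = []
    for batch in chunker(parts, handlers):
        # tasks are created only for the current batch, so at most `handlers`
        # downloads (and temp files) are in flight at any time
        tasks = [asyncio.create_task(fetch_part(s, e)) for s, e in batch]
        paths.extend(await asyncio.gather(*tasks))
    return paths


print(asyncio.run(download_all([(0, 3), (4, 7), (8, 10)], handlers=2)))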
@@ -165,7 +183,9 @@ async def _run_download_parts(drs_object: DrsObject, destination_path: Pat
assert checksum_type in hashlib.algorithms_available, f"Checksum {checksum_type} not supported."
md5_hash = hashlib.new(checksum_type)
with open(destination_path.joinpath(drs_object.name), 'wb') as wfd:
for f in sorted(drs_object.file_parts):
# sort the items of the list in place - Numerically based on start i.e. "xxxxxx.start.end.part"
drs_object.file_parts.sort(key=lambda x: int(str(x).split('.')[-3]))
for f in drs_object.file_parts:
fd = open(f, 'rb')
wrapped_fd = Wrapped(fd, md5_hash)
# efficient way to write
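The explicit numeric key replaces plain `sorted()`, which compares the part paths as strings and misorders offsets once they differ in digit count. A quick illustration with the `name.start.end.part` layout produced by `download_part`:

parts = ['sample.bam.0.1048575.part',
         'sample.bam.9437184.10485759.part',
         'sample.bam.10485760.11534335.part']

# plain sorted() compares strings, so '10485760' sorts before '9437184'
print(sorted(parts))
# ['sample.bam.0.1048575.part', 'sample.bam.10485760.11534335.part', 'sample.bam.9437184.10485759.part']

# numeric sort on the start offset (third-from-last dotted field), as in the diff above
print(sorted(parts, key=lambda p: int(p.split('.')[-3])))
# ['sample.bam.0.1048575.part', 'sample.bam.9437184.10485759.part', 'sample.bam.10485760.11534335.part']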
@@ -219,8 +239,7 @@ async def _run_download(self, drs_objects: List[DrsObject], destination_path: Pa

drs_objects_with_file_parts = [
await f
for f in tqdm.tqdm(asyncio.as_completed(tasks), total=len(tasks), leave=leave, desc=" * batch",
disable=self.disable)
for f in asyncio.as_completed(tasks)
]

return drs_objects_with_file_parts
@@ -357,4 +376,8 @@ def optimize_workload(self, drs_objects: List[DrsObject]) -> List[DrsObject]:
"""
# TODO - now that we have the objects to download, we have an opportunity to shape the downloads
# TODO - e.g. smallest files first? tweak MAX_* to optimize per workload
# for example, open it up for 1 big file.
if len(drs_objects) == 1:
self.max_simultaneous_part_handlers = 50
self.part_size = 64 * MB
return drs_objects
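Relaxing the limits for a single object trades a little memory for throughput, and 64 MB parts keep even large files well under the 1000-part warning in `_run_download_parts`. A back-of-the-envelope check (the 30 GiB file size is just an example, and MB is assumed to be 1 MiB):

import math

MB = 1024 * 1024            # assumed value of the MB constant imported above
part_size = 64 * MB         # single-object setting from optimize_workload
file_size = 30 * 1024 * MB  # e.g. a 30 GiB BAM file

print(math.ceil(file_size / part_size))  # 480 parts, comfortably below the 1000-part warning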
