From 22aafedf9bf6d87c8e3816728d89b15481ef74d1 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 3 Aug 2023 15:00:26 +0900 Subject: [PATCH 1/2] spec: expose max_concurrency param in blob_upload/blob_download --- adlfs/spec.py | 43 ++++++++++++++++++++++++++++++++++------ adlfs/tests/test_spec.py | 13 +++++++++--- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/adlfs/spec.py b/adlfs/spec.py index 3352ca3c..610b77cb 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -26,7 +26,7 @@ from azure.storage.blob._shared.base_client import create_configuration from azure.storage.blob.aio import BlobServiceClient as AIOBlobServiceClient from azure.storage.blob.aio._list_blobs_helper import BlobPrefix -from fsspec.asyn import AsyncFileSystem, get_loop, sync, sync_wrapper +from fsspec.asyn import AsyncFileSystem, _get_batch_size, get_loop, sync, sync_wrapper from fsspec.spec import AbstractBufferedFile from fsspec.utils import infer_storage_options @@ -165,6 +165,10 @@ class AzureBlobFileSystem(AsyncFileSystem): False throws if retrieving container properties fails, which might happen if your authentication is only valid at the storage container level, and not the storage account level. + max_concurrency: + The number of concurrent connections to use when uploading or downloading a blob. + If None it will be inferred from fsspec.asyn._get_batch_size(). + Pass on to fsspec: skip_instance_cache: to control reuse of instances @@ -227,6 +231,7 @@ def __init__( default_cache_type: str = "bytes", version_aware: bool = False, assume_container_exists: Optional[bool] = None, + max_concurrency: Optional[int] = None, **kwargs, ): super_kwargs = { @@ -292,6 +297,8 @@ def __init__( if self.credential is not None: weakref.finalize(self, sync, self.loop, close_credential, self) + self.max_concurrency = max_concurrency or (_get_batch_size() // 4) + @classmethod def _strip_protocol(cls, path: str): """ @@ -1426,7 +1433,9 @@ async def _dir_exists(self, container, path): except ResourceNotFoundError: return False - async def _pipe_file(self, path, value, overwrite=True, **kwargs): + async def _pipe_file( + self, path, value, overwrite=True, max_concurrency=None, **kwargs + ): """Set the bytes of given file""" container_name, path, _ = self.split_path(path) async with self.service_client.get_blob_client( @@ -1436,6 +1445,7 @@ async def _pipe_file(self, path, value, overwrite=True, **kwargs): data=value, overwrite=overwrite, metadata={"is_directory": "false"}, + max_concurrency=max_concurrency or self.max_concurrency, **kwargs, ) self.invalidate_cache(self._parent(path)) @@ -1443,7 +1453,9 @@ async def _pipe_file(self, path, value, overwrite=True, **kwargs): pipe_file = sync_wrapper(_pipe_file) - async def _cat_file(self, path, start=None, end=None, **kwargs): + async def _cat_file( + self, path, start=None, end=None, max_concurrency=None, **kwargs + ): path = self._strip_protocol(path) if end is not None: start = start or 0 # download_blob requires start if length is provided. @@ -1456,7 +1468,10 @@ async def _cat_file(self, path, start=None, end=None, **kwargs): ) as bc: try: stream = await bc.download_blob( - offset=start, length=length, version_id=version_id + offset=start, + length=length, + version_id=version_id, + max_concurrency=max_concurrency or self.max_concurrency, ) except ResourceNotFoundError as e: raise FileNotFoundError from e @@ -1593,7 +1608,14 @@ async def _expand_path( return list(sorted(out)) async def _put_file( - self, lpath, rpath, delimiter="/", overwrite=False, callback=None, **kwargws + self, + lpath, + rpath, + delimiter="/", + overwrite=False, + callback=None, + max_concurrency=None, + **kwargws, ): """ Copy single file to remote @@ -1621,6 +1643,7 @@ async def _put_file( raw_response_hook=make_callback( "upload_stream_current", callback ), + max_concurrency=max_concurrency or self.max_concurrency, ) self.invalidate_cache() except ResourceExistsError: @@ -1668,7 +1691,14 @@ def download(self, rpath, lpath, recursive=False, **kwargs): return self.get(rpath, lpath, recursive=recursive, **kwargs) async def _get_file( - self, rpath, lpath, recursive=False, delimiter="/", callback=None, **kwargs + self, + rpath, + lpath, + recursive=False, + delimiter="/", + callback=None, + max_concurrency=None, + **kwargs, ): """Copy single file remote to local""" if os.path.isdir(lpath): @@ -1683,6 +1713,7 @@ async def _get_file( "download_stream_current", callback ), version_id=version_id, + max_concurrency=max_concurrency or self.max_concurrency, ) with open(lpath, "wb") as my_blob: await stream.readinto(my_blob) diff --git a/adlfs/tests/test_spec.py b/adlfs/tests/test_spec.py index c226e835..86023ca5 100644 --- a/adlfs/tests/test_spec.py +++ b/adlfs/tests/test_spec.py @@ -392,7 +392,9 @@ def test_info_missing(storage, path): def test_time_info(storage): fs = AzureBlobFileSystem( - account_name=storage.account_name, connection_string=CONN_STR + account_name=storage.account_name, + connection_string=CONN_STR, + max_concurrency=1, ) creation_time = fs.created("data/root/d/file_with_metadata.txt") @@ -1486,7 +1488,10 @@ async def test_cat_file_versioned(storage, mocker): await fs._cat_file(f"data/root/a/file.txt?versionid={DEFAULT_VERSION_ID}") download_blob.assert_called_once_with( - offset=None, length=None, version_id=DEFAULT_VERSION_ID + offset=None, + length=None, + version_id=DEFAULT_VERSION_ID, + max_concurrency=fs.max_concurrency, ) download_blob.reset_mock() @@ -1744,7 +1749,9 @@ async def test_get_file_versioned(storage, mocker, tmp_path): f"data/root/a/file.txt?versionid={DEFAULT_VERSION_ID}", tmp_path / "file.txt" ) download_blob.assert_called_once_with( - raw_response_hook=mocker.ANY, version_id=DEFAULT_VERSION_ID + raw_response_hook=mocker.ANY, + version_id=DEFAULT_VERSION_ID, + max_concurrency=fs.max_concurrency, ) download_blob.reset_mock() download_blob.side_effect = ResourceNotFoundError From a7eeb893f5f8365743966336710dbda1820bb588 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 3 Aug 2023 15:10:18 +0900 Subject: [PATCH 2/2] prefer batch_size in batched async methods --- adlfs/spec.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/adlfs/spec.py b/adlfs/spec.py index 610b77cb..eb160e05 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -297,7 +297,11 @@ def __init__( if self.credential is not None: weakref.finalize(self, sync, self.loop, close_credential, self) - self.max_concurrency = max_concurrency or (_get_batch_size() // 4) + if max_concurrency is None: + batch_size = _get_batch_size() + if batch_size > 0: + max_concurrency = batch_size + self.max_concurrency = max_concurrency @classmethod def _strip_protocol(cls, path: str): @@ -1453,6 +1457,12 @@ async def _pipe_file( pipe_file = sync_wrapper(_pipe_file) + async def _pipe(self, *args, batch_size=None, max_concurrency=None, **kwargs): + max_concurrency = max_concurrency or 1 + return await super()._pipe( + *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs + ) + async def _cat_file( self, path, start=None, end=None, max_concurrency=None, **kwargs ): @@ -1512,6 +1522,12 @@ def cat(self, path, recursive=False, on_error="raise", **kwargs): else: return self.cat_file(paths[0]) + async def _cat_ranges(self, *args, batch_size=None, max_concurrency=None, **kwargs): + max_concurrency = max_concurrency or 1 + return await super()._cat_ranges( + *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs + ) + def url(self, path, expires=3600, **kwargs): return sync(self.loop, self._url, path, expires, **kwargs) @@ -1656,6 +1672,12 @@ async def _put_file( put_file = sync_wrapper(_put_file) + async def _put(self, *args, batch_size=None, max_concurrency=None, **kwargs): + max_concurrency = max_concurrency or 1 + return await super()._put( + *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs + ) + async def _cp_file(self, path1, path2, **kwargs): """Copy the file at path1 to path2""" container1, path1, version_id = self.split_path(path1, delimiter="/") @@ -1722,6 +1744,12 @@ async def _get_file( get_file = sync_wrapper(_get_file) + async def _get(self, *args, batch_size=None, max_concurrency=None, **kwargs): + max_concurrency = max_concurrency or 1 + return await super()._get( + *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs + ) + def getxattr(self, path, attr): meta = self.info(path).get("metadata", {}) return meta[attr]