diff --git a/adlfs/spec.py b/adlfs/spec.py index 3352ca3c..610b77cb 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -26,7 +26,7 @@ from azure.storage.blob._shared.base_client import create_configuration from azure.storage.blob.aio import BlobServiceClient as AIOBlobServiceClient from azure.storage.blob.aio._list_blobs_helper import BlobPrefix -from fsspec.asyn import AsyncFileSystem, get_loop, sync, sync_wrapper +from fsspec.asyn import AsyncFileSystem, _get_batch_size, get_loop, sync, sync_wrapper from fsspec.spec import AbstractBufferedFile from fsspec.utils import infer_storage_options @@ -165,6 +165,10 @@ class AzureBlobFileSystem(AsyncFileSystem): False throws if retrieving container properties fails, which might happen if your authentication is only valid at the storage container level, and not the storage account level. + max_concurrency: + The number of concurrent connections to use when uploading or downloading a blob. + If None, it defaults to a quarter of fsspec.asyn._get_batch_size() (i.e. _get_batch_size() // 4). 
+ Pass on to fsspec: skip_instance_cache: to control reuse of instances @@ -227,6 +231,7 @@ def __init__( default_cache_type: str = "bytes", version_aware: bool = False, assume_container_exists: Optional[bool] = None, + max_concurrency: Optional[int] = None, **kwargs, ): super_kwargs = { @@ -292,6 +297,8 @@ def __init__( if self.credential is not None: weakref.finalize(self, sync, self.loop, close_credential, self) + self.max_concurrency = max_concurrency or (_get_batch_size() // 4) + @classmethod def _strip_protocol(cls, path: str): """ @@ -1426,7 +1433,9 @@ async def _dir_exists(self, container, path): except ResourceNotFoundError: return False - async def _pipe_file(self, path, value, overwrite=True, **kwargs): + async def _pipe_file( + self, path, value, overwrite=True, max_concurrency=None, **kwargs + ): """Set the bytes of given file""" container_name, path, _ = self.split_path(path) async with self.service_client.get_blob_client( @@ -1436,6 +1445,7 @@ async def _pipe_file(self, path, value, overwrite=True, **kwargs): data=value, overwrite=overwrite, metadata={"is_directory": "false"}, + max_concurrency=max_concurrency or self.max_concurrency, **kwargs, ) self.invalidate_cache(self._parent(path)) @@ -1443,7 +1453,9 @@ async def _pipe_file(self, path, value, overwrite=True, **kwargs): pipe_file = sync_wrapper(_pipe_file) - async def _cat_file(self, path, start=None, end=None, **kwargs): + async def _cat_file( + self, path, start=None, end=None, max_concurrency=None, **kwargs + ): path = self._strip_protocol(path) if end is not None: start = start or 0 # download_blob requires start if length is provided. 
@@ -1456,7 +1468,10 @@ async def _cat_file(self, path, start=None, end=None, **kwargs): ) as bc: try: stream = await bc.download_blob( - offset=start, length=length, version_id=version_id + offset=start, + length=length, + version_id=version_id, + max_concurrency=max_concurrency or self.max_concurrency, ) except ResourceNotFoundError as e: raise FileNotFoundError from e @@ -1593,7 +1608,14 @@ async def _expand_path( return list(sorted(out)) async def _put_file( - self, lpath, rpath, delimiter="/", overwrite=False, callback=None, **kwargws + self, + lpath, + rpath, + delimiter="/", + overwrite=False, + callback=None, + max_concurrency=None, + **kwargws, ): """ Copy single file to remote @@ -1621,6 +1643,7 @@ async def _put_file( raw_response_hook=make_callback( "upload_stream_current", callback ), + max_concurrency=max_concurrency or self.max_concurrency, ) self.invalidate_cache() except ResourceExistsError: @@ -1668,7 +1691,14 @@ def download(self, rpath, lpath, recursive=False, **kwargs): return self.get(rpath, lpath, recursive=recursive, **kwargs) async def _get_file( - self, rpath, lpath, recursive=False, delimiter="/", callback=None, **kwargs + self, + rpath, + lpath, + recursive=False, + delimiter="/", + callback=None, + max_concurrency=None, + **kwargs, ): """Copy single file remote to local""" if os.path.isdir(lpath): @@ -1683,6 +1713,7 @@ async def _get_file( "download_stream_current", callback ), version_id=version_id, + max_concurrency=max_concurrency or self.max_concurrency, ) with open(lpath, "wb") as my_blob: await stream.readinto(my_blob) diff --git a/adlfs/tests/test_spec.py b/adlfs/tests/test_spec.py index c226e835..3e615ee1 100644 --- a/adlfs/tests/test_spec.py +++ b/adlfs/tests/test_spec.py @@ -1486,7 +1486,10 @@ async def test_cat_file_versioned(storage, mocker): await fs._cat_file(f"data/root/a/file.txt?versionid={DEFAULT_VERSION_ID}") download_blob.assert_called_once_with( - offset=None, length=None, version_id=DEFAULT_VERSION_ID + offset=None, 
+ length=None, + version_id=DEFAULT_VERSION_ID, + max_concurrency=fs.max_concurrency, ) download_blob.reset_mock() @@ -1744,7 +1747,9 @@ async def test_get_file_versioned(storage, mocker, tmp_path): f"data/root/a/file.txt?versionid={DEFAULT_VERSION_ID}", tmp_path / "file.txt" ) download_blob.assert_called_once_with( - raw_response_hook=mocker.ANY, version_id=DEFAULT_VERSION_ID + raw_response_hook=mocker.ANY, + version_id=DEFAULT_VERSION_ID, + max_concurrency=fs.max_concurrency, ) download_blob.reset_mock() download_blob.side_effect = ResourceNotFoundError