diff --git a/adlfs/spec.py b/adlfs/spec.py
index af0a60ef..4ff92e7c 100644
--- a/adlfs/spec.py
+++ b/adlfs/spec.py
@@ -4,6 +4,7 @@
 from __future__ import absolute_import, division, print_function
 
 import asyncio
+import contextlib
 import io
 import logging
 import os
@@ -12,6 +13,7 @@
 import weakref
 from datetime import datetime, timedelta
 from glob import has_magic
+from typing import Optional
 
 from azure.core.exceptions import (
     HttpResponseError,
@@ -31,6 +33,7 @@
 from fsspec.utils import infer_storage_options, tokenize
 
 from .utils import (
+    _nullcontext,
     close_container_client,
     close_credential,
     close_service_client,
@@ -349,6 +352,12 @@ class AzureBlobFileSystem(AsyncFileSystem):
     default_cache_type: string ('bytes')
         If given, the default cache_type value used for "open()". Set to none if no
         caching is desired. Docs in fsspec
+    max_concurrency : int, optional
+        The maximum number of BlobClient connections that can exist simultaneously for this
+        filesystem instance. By default, there is no limit. Setting this might be helpful if
+        you have a very large number of small, independent blob operations to perform. By
+        default a single BlobClient is created per blob, which might cause high memory usage
+        and clog the asyncio event loop as many instances are created and quickly destroyed.
 
     Pass on to fsspec:
 
@@ -410,6 +419,7 @@ def __init__(
         asynchronous: bool = False,
         default_fill_cache: bool = True,
         default_cache_type: str = "bytes",
+        max_concurrency: Optional[int] = None,
         **kwargs,
     ):
         super_kwargs = {
@@ -438,6 +448,13 @@ def __init__(
         self.blocksize = blocksize
         self.default_fill_cache = default_fill_cache
         self.default_cache_type = default_cache_type
+        self.max_concurrency = max_concurrency
+
+        if self.max_concurrency is None:
+            self._blob_client_semaphore = _nullcontext()
+        else:
+            self._blob_client_semaphore = asyncio.Semaphore(max_concurrency)
+
         if (
             self.credential is None
             and self.account_key is None
@@ -519,6 +536,15 @@ def _get_kwargs_from_urls(urlpath):
                 out["account_name"] = account_name
         return out
 
+    @contextlib.asynccontextmanager
+    async def _get_blob_client(self, container_name, path):
+        """
+        Get a blob client, respecting `self.max_concurrency` if set.
+        """
+        async with self._blob_client_semaphore:
+            async with self.service_client.get_blob_client(container_name, path) as bc:
+                yield bc
+
     def _get_credential_from_service_principal(self):
         """
         Create a Credential for authentication. This can include a TokenCredential
@@ -1366,9 +1392,7 @@ async def _isfile(self, path):
             return False
         else:
             try:
-                async with self.service_client.get_blob_client(
-                    container_name, path
-                ) as bc:
+                async with self._get_blob_client(container_name, path) as bc:
                     props = await bc.get_blob_properties()
                     if props["metadata"]["is_directory"] == "false":
                         return True
@@ -1427,7 +1451,7 @@ async def _exists(self, path):
                 # Empty paths exist by definition
                 return True
 
-        async with self.service_client.get_blob_client(container_name, path) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             if await bc.exists():
                 return True
 
@@ -1445,9 +1469,7 @@ async def _exists(self, path):
     async def _pipe_file(self, path, value, overwrite=True, **kwargs):
         """Set the bytes of given file"""
         container_name, path = self.split_path(path)
-        async with self.service_client.get_blob_client(
-            container=container_name, blob=path
-        ) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             result = await bc.upload_blob(
                 data=value, overwrite=overwrite, metadata={"is_directory": "false"}
             )
@@ -1464,9 +1486,7 @@ async def _cat_file(self, path, start=None, end=None, **kwargs):
         else:
             length = None
         container_name, path = self.split_path(path)
-        async with self.service_client.get_blob_client(
-            container=container_name, blob=path
-        ) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             try:
                 stream = await bc.download_blob(offset=start, length=length)
             except ResourceNotFoundError as e:
@@ -1528,7 +1548,7 @@ async def _url(self, path, expires=3600, **kwargs):
             expiry=datetime.utcnow() + timedelta(seconds=expires),
         )
 
-        async with self.service_client.get_blob_client(container_name, blob) as bc:
+        async with self._get_blob_client(container_name, blob) as bc:
             url = f"{bc.url}?{sas_token}"
 
         return url
@@ -1603,9 +1623,7 @@ async def _put_file(
         else:
            try:
                 with open(lpath, "rb") as f1:
-                    async with self.service_client.get_blob_client(
-                        container_name, path
-                    ) as bc:
+                    async with self._get_blob_client(container_name, path) as bc:
                         await bc.upload_blob(
                             f1,
                             overwrite=overwrite,
@@ -1659,7 +1677,7 @@ async def _get_file(
             return
         container_name, path = self.split_path(rpath, delimiter=delimiter)
         try:
-            async with self.service_client.get_blob_client(
+            async with self._get_blob_client(
                 container_name, path.rstrip(delimiter)
             ) as bc:
                 with open(lpath, "wb") as my_blob:
@@ -1681,7 +1699,7 @@ def getxattr(self, path, attr):
     async def _setxattrs(self, rpath, **kwargs):
         container_name, path = self.split_path(rpath)
         try:
-            async with self.service_client.get_blob_client(container_name, path) as bc:
+            async with self._get_blob_client(container_name, path) as bc:
                 await bc.set_blob_metadata(metadata=kwargs)
             self.invalidate_cache(self._parent(rpath))
         except Exception as e:
diff --git a/adlfs/tests/test_spec.py b/adlfs/tests/test_spec.py
index 39848843..ab19c619 100644
--- a/adlfs/tests/test_spec.py
+++ b/adlfs/tests/test_spec.py
@@ -1,6 +1,8 @@
+import asyncio
 import datetime
 import os
 import tempfile
+from unittest import mock
 
 import dask.dataframe as dd
 import numpy as np
@@ -1424,3 +1426,17 @@ def test_find_with_prefix(storage):
     assert test_1s == [test_bucket_name + "/prefixes/test_1"] + [
         test_bucket_name + f"/prefixes/test_{cursor}" for cursor in range(10, 20)
     ]
+
+
+def test_max_concurrency(storage):
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name, connection_string=CONN_STR, max_concurrency=2
+    )
+
+    assert isinstance(fs._blob_client_semaphore, asyncio.Semaphore)
+
+    fs._blob_client_semaphore = mock.MagicMock(fs._blob_client_semaphore)
+    path = {f"/data/{i}": b"value" for i in range(10)}
+    fs.pipe(path)
+
+    assert fs._blob_client_semaphore.__aenter__.call_count == 10
diff --git a/adlfs/utils.py b/adlfs/utils.py
index 15df3e59..f7f4d05c 100644
--- a/adlfs/utils.py
+++ b/adlfs/utils.py
@@ -1,3 +1,7 @@
+import contextlib
+import sys
+
+
 async def filter_blobs(blobs, target_path, delimiter="/"):
     """
     Filters out blobs that do not come from target_path
@@ -45,9 +49,25 @@ async def close_container_client(file_obj):
     await file_obj.container_client.close()
 
 
+if sys.version_info < (3, 10):
+    # Python 3.10 added support for async to nullcontext
+    @contextlib.asynccontextmanager
+    async def _nullcontext(*args):
+        yield
+
+else:
+    _nullcontext = contextlib.nullcontext
+
+
 async def close_credential(file_obj):
     """
     Implements asynchronous closure of credentials for
     AzureBlobFile objects
     """
-    await file_obj.credential.close()
+    try:
+        if file_obj.credential is not None:
+            await file_obj.credential.close()
+        else:
+            pass
+    except AttributeError:
+        pass
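For context, a minimal usage sketch of the new max_concurrency option. This is illustrative only and not part of the patch; the account name, connection string, and blob paths are hypothetical placeholders.

    # Illustrative sketch: cap how many BlobClient instances can exist at once
    # while piping a large number of small blobs. All names/values here are
    # hypothetical placeholders, not part of the patch.
    from adlfs import AzureBlobFileSystem

    fs = AzureBlobFileSystem(
        account_name="myaccount",
        connection_string="<connection-string>",
        max_concurrency=8,  # at most 8 BlobClients alive at any moment
    )

    # Each blob operation acquires the internal semaphore via _get_blob_client,
    # so only 8 uploads run concurrently instead of one BlobClient per blob.
    data = {f"mycontainer/data/{i}": b"value" for i in range(1000)}
    fs.pipe(data)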