Limit concurrency #329

Draft: wants to merge 10 commits into base: main

Changes from 1 commit
Add max_concurrency option to limit concurrency
This adds a new keyword to AzureBlobFileSystem to limit the number of
concurrent connections. See pangeo-forge/pangeo-forge-recipes#227 (comment)
for some motivation. In that situation, we had a single FileSystem
instance that was generating so many concurrent write requests through
`.pipe` that we were seeing memory issues from creating all the
BlobClient connections simultaneously.

This adds an asyncio.Semaphore instance to the AzureBlobFileSystem that
controls the number of concurrent BlobClient connections. The default of
None is backwards-compatible (no limit).
Tom Augspurger committed Nov 9, 2021
commit 6cc062d3b7f88899fd200a1ee077e46cdcedf541
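For readers unfamiliar with the idiom the commit message describes, here is a
minimal standalone sketch of how an asyncio.Semaphore caps concurrency. This is
not code from the PR; the names write_blob and main are illustrative.

    import asyncio

    async def write_blob(sem: asyncio.Semaphore, i: int) -> None:
        # Only N tasks may be inside this block at once; the rest wait
        # here instead of all opening connections simultaneously.
        async with sem:
            await asyncio.sleep(0.1)  # stand-in for opening a BlobClient and uploading

    async def main() -> None:
        sem = asyncio.Semaphore(2)  # analogous to max_concurrency=2
        await asyncio.gather(*(write_blob(sem, i) for i in range(10)))

    asyncio.run(main())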
58 changes: 34 additions & 24 deletions adlfs/spec.py
@@ -4,12 +4,14 @@
 from __future__ import absolute_import, division, print_function
 
 import asyncio
+import contextlib
 from glob import has_magic
 import io
 import logging
 import os
 import warnings
 import weakref
+from typing import Optional
 
 from azure.core.exceptions import (
     ClientAuthenticationError,
@@ -354,6 +356,8 @@ class AzureBlobFileSystem(AsyncFileSystem):
     default_cache_type: string ('bytes')
         If given, the default cache_type value used for "open()". Set to none if no caching
         is desired. Docs in fsspec
+    max_concurrency : int, optional
+        The maximum number of BlobClient connections to make for this filesystem instance.
 
     Pass on to fsspec:
 
@@ -412,6 +416,7 @@ def __init__(
         asynchronous: bool = False,
         default_fill_cache: bool = True,
         default_cache_type: str = "bytes",
+        max_concurrency: Optional[int] = None,
         **kwargs,
     ):
         super_kwargs = {
@@ -440,6 +445,13 @@ def __init__(
         self.blocksize = blocksize
         self.default_fill_cache = default_fill_cache
         self.default_cache_type = default_cache_type
+        self.max_concurrency = max_concurrency
+
+        if self.max_concurrency is None:
+            self._blob_client_semaphore = contextlib.nullcontext()
+        else:
+            self._blob_client_semaphore = asyncio.Semaphore(max_concurrency)
+
         if (
             self.credential is None
             and self.account_key is None
@@ -452,6 +464,7 @@ def __init__(
             ) = self._get_credential_from_service_principal()
         else:
             self.sync_credential = None
+
         self.do_connect()
         weakref.finalize(self, sync, self.loop, close_service_client, self)
 
@@ -491,6 +504,15 @@ def _strip_protocol(cls, path: str):
         logger.debug(f"_strip_protocol({path}) = {ops}")
         return ops["path"]
 
+    @contextlib.asynccontextmanager
+    async def _get_blob_client(self, container_name, path):
+        """
+        Get a blob client, respecting `self.max_concurrency` if set.
+        """
+        async with self._blob_client_semaphore:
+            async with self.service_client.get_blob_client(container_name, path) as bc:
+                yield bc
+
     def _get_credential_from_service_principal(self):
         """
         Create a Credential for authentication. This can include a TokenCredential
@@ -1332,9 +1354,7 @@ async def _isfile(self, path):
                 return False
         else:
             try:
-                async with self.service_client.get_blob_client(
-                    container_name, path
-                ) as bc:
+                async with self._get_blob_client(container_name, path) as bc:
                     props = await bc.get_blob_properties()
                     if props["metadata"]["is_directory"] == "false":
                         return True
@@ -1393,7 +1413,7 @@ async def _exists(self, path):
             # Empty paths exist by definition
             return True
 
-        async with self.service_client.get_blob_client(container_name, path) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             if await bc.exists():
                 return True
 
@@ -1411,9 +1431,7 @@ async def _exists(self, path):
     async def _pipe_file(self, path, value, overwrite=True, **kwargs):
         """Set the bytes of given file"""
         container_name, path = self.split_path(path)
-        async with self.service_client.get_blob_client(
-            container=container_name, blob=path
-        ) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             result = await bc.upload_blob(
                 data=value, overwrite=overwrite, metadata={"is_directory": "false"}
             )
@@ -1430,9 +1448,7 @@ async def _cat_file(self, path, start=None, end=None, **kwargs):
         else:
            length = None
         container_name, path = self.split_path(path)
-        async with self.service_client.get_blob_client(
-            container=container_name, blob=path
-        ) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             try:
                 stream = await bc.download_blob(offset=start, length=length)
             except ResourceNotFoundError as e:
@@ -1494,7 +1510,7 @@ async def _url(self, path, expires=3600, **kwargs):
             expiry=datetime.utcnow() + timedelta(seconds=expires),
         )
 
-        async with self.service_client.get_blob_client(container_name, blob) as bc:
+        async with self._get_blob_client(container_name, blob) as bc:
             url = f"{bc.url}?{sas_token}"
         return url
 
@@ -1569,9 +1585,7 @@ async def _put_file(
         else:
             try:
                 with open(lpath, "rb") as f1:
-                    async with self.service_client.get_blob_client(
-                        container_name, path
-                    ) as bc:
+                    async with self._get_blob_client(container_name, path) as bc:
                         await bc.upload_blob(
                             f1,
                             overwrite=overwrite,
@@ -1596,14 +1610,10 @@ async def _cp_file(self, path1, path2, **kwargs):
         container1, path1 = self.split_path(path1, delimiter="/")
         container2, path2 = self.split_path(path2, delimiter="/")
 
-        cc1 = self.service_client.get_container_client(container1)
-        blobclient1 = cc1.get_blob_client(blob=path1)
-        if container1 == container2:
-            blobclient2 = cc1.get_blob_client(blob=path2)
-        else:
-            cc2 = self.service_client.get_container_client(container2)
-            blobclient2 = cc2.get_blob_client(blob=path2)
-        await blobclient2.start_copy_from_url(blobclient1.url)
+        # TODO: this could cause a deadlock. Can we protect the user?
+        async with self._get_blob_client(container1, path1) as blobclient1:
+            async with self._get_blob_client(container2, path2) as blobclient2:
+                await blobclient2.start_copy_from_url(blobclient1.url)
         self.invalidate_cache(container1)
         self.invalidate_cache(container2)
@@ -1623,7 +1633,7 @@ async def _get_file(
         """ Copy single file remote to local """
         container_name, path = self.split_path(rpath, delimiter=delimiter)
         try:
-            async with self.service_client.get_blob_client(
+            async with self._get_blob_client(
                 container_name, path.rstrip(delimiter)
             ) as bc:
                 with open(lpath, "wb") as my_blob:
@@ -1645,7 +1655,7 @@ def getxattr(self, path, attr):
     async def _setxattrs(self, rpath, **kwargs):
         container_name, path = self.split_path(rpath)
         try:
-            async with self.service_client.get_blob_client(container_name, path) as bc:
+            async with self._get_blob_client(container_name, path) as bc:
                 await bc.set_blob_metadata(metadata=kwargs)
                 self.invalidate_cache(self._parent(rpath))
         except Exception as e:
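The TODO kept in _cp_file above flags a real hazard: the method now acquires
the semaphore once per blob client, so a filesystem created with
max_concurrency=1 would block forever on the inner acquisition. A
self-contained sketch of that failure mode, using plain asyncio and
illustrative names (not code from this PR):

    import asyncio

    async def copy_one(sem: asyncio.Semaphore) -> None:
        async with sem:      # source client takes the only permit
            async with sem:  # destination client waits forever
                pass

    async def main() -> None:
        sem = asyncio.Semaphore(1)  # like max_concurrency=1
        try:
            await asyncio.wait_for(copy_one(sem), timeout=1)
        except asyncio.TimeoutError:
            print("deadlocked: the outer block still holds the only permit")

    asyncio.run(main())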
16 changes: 16 additions & 0 deletions adlfs/tests/test_spec.py
@@ -1,5 +1,7 @@
 import asyncio
 import os
+import tempfile
+from unittest import mock
 import datetime
 import dask.dataframe as dd
 from fsspec.implementations.local import LocalFileSystem
@@ -1348,3 +1350,17 @@ def test_find_with_prefix(storage):
     assert test_1s == [test_bucket_name + "/prefixes/test_1"] + [
         test_bucket_name + f"/prefixes/test_{cursor}" for cursor in range(10, 20)
     ]
+
+
+def test_max_concurrency(storage):
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name, connection_string=CONN_STR, max_concurrency=2
+    )
+
+    assert isinstance(fs._blob_client_semaphore, asyncio.Semaphore)
+
+    fs._blob_client_semaphore = mock.MagicMock(fs._blob_client_semaphore)
+    path = {f"/data/{i}": b"value" for i in range(10)}
+    fs.pipe(path)
+
+    assert fs._blob_client_semaphore.__aenter__.call_count == 10
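Assuming this branch is installed, a caller would opt in through the usual
fsspec constructor arguments. A hedged usage sketch: the account values are
placeholders, and only max_concurrency is the keyword this PR adds.

    import fsspec

    fs = fsspec.filesystem(
        "abfs",
        account_name="myaccount",   # placeholder
        connection_string="...",    # placeholder
        max_concurrency=4,          # at most 4 concurrent BlobClient connections
    )
    # Many concurrent writes through .pipe now share four connection slots.
    fs.pipe({f"container/data/{i}": b"value" for i in range(100)})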