Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

support max_concurrency in upload_blob and download_blob operations #420

Merged
merged 2 commits into from
Aug 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions adlfs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from azure.storage.blob._shared.base_client import create_configuration
from azure.storage.blob.aio import BlobServiceClient as AIOBlobServiceClient
from azure.storage.blob.aio._list_blobs_helper import BlobPrefix
from fsspec.asyn import AsyncFileSystem, get_loop, sync, sync_wrapper
from fsspec.asyn import AsyncFileSystem, _get_batch_size, get_loop, sync, sync_wrapper
TomAugspurger marked this conversation as resolved.
Show resolved Hide resolved
from fsspec.spec import AbstractBufferedFile
from fsspec.utils import infer_storage_options

Expand Down Expand Up @@ -165,6 +165,10 @@ class AzureBlobFileSystem(AsyncFileSystem):
False throws if retrieving container properties fails, which might happen if your
authentication is only valid at the storage container level, and not the
storage account level.
max_concurrency:
The number of concurrent connections to use when uploading or downloading a blob.
If None it will be inferred from fsspec.asyn._get_batch_size().

Pass on to fsspec:

skip_instance_cache: to control reuse of instances
Expand Down Expand Up @@ -227,6 +231,7 @@ def __init__(
default_cache_type: str = "bytes",
version_aware: bool = False,
assume_container_exists: Optional[bool] = None,
max_concurrency: Optional[int] = None,
**kwargs,
):
super_kwargs = {
Expand Down Expand Up @@ -292,6 +297,12 @@ def __init__(
if self.credential is not None:
weakref.finalize(self, sync, self.loop, close_credential, self)

if max_concurrency is None:
batch_size = _get_batch_size()
if batch_size > 0:
max_concurrency = batch_size
self.max_concurrency = max_concurrency

@classmethod
def _strip_protocol(cls, path: str):
"""
Expand Down Expand Up @@ -1426,7 +1437,9 @@ async def _dir_exists(self, container, path):
except ResourceNotFoundError:
return False

async def _pipe_file(self, path, value, overwrite=True, **kwargs):
async def _pipe_file(
self, path, value, overwrite=True, max_concurrency=None, **kwargs
):
"""Set the bytes of given file"""
container_name, path, _ = self.split_path(path)
async with self.service_client.get_blob_client(
Expand All @@ -1436,14 +1449,23 @@ async def _pipe_file(self, path, value, overwrite=True, **kwargs):
data=value,
overwrite=overwrite,
metadata={"is_directory": "false"},
max_concurrency=max_concurrency or self.max_concurrency,
**kwargs,
)
self.invalidate_cache(self._parent(path))
return result

pipe_file = sync_wrapper(_pipe_file)

async def _cat_file(self, path, start=None, end=None, **kwargs):
async def _pipe(self, *args, batch_size=None, max_concurrency=None, **kwargs):
    """Bulk in-memory upload (fsspec ``pipe``).

    Per-blob ``max_concurrency`` defaults to 1 here (any falsy value is
    coerced to 1): the bulk operation already fans out across blobs via
    ``batch_size``, so intra-blob chunk concurrency is opt-in only.
    """
    return await super()._pipe(
        *args,
        batch_size=batch_size,
        max_concurrency=max_concurrency or 1,
        **kwargs,
    )

async def _cat_file(
self, path, start=None, end=None, max_concurrency=None, **kwargs
):
path = self._strip_protocol(path)
if end is not None:
start = start or 0 # download_blob requires start if length is provided.
Expand All @@ -1456,7 +1478,10 @@ async def _cat_file(self, path, start=None, end=None, **kwargs):
) as bc:
try:
stream = await bc.download_blob(
offset=start, length=length, version_id=version_id
offset=start,
length=length,
version_id=version_id,
max_concurrency=max_concurrency or self.max_concurrency,
)
except ResourceNotFoundError as e:
raise FileNotFoundError from e
Expand Down Expand Up @@ -1497,6 +1522,12 @@ def cat(self, path, recursive=False, on_error="raise", **kwargs):
else:
return self.cat_file(paths[0])

async def _cat_ranges(self, *args, batch_size=None, max_concurrency=None, **kwargs):
    # Bulk ranged reads: unless the caller supplies a truthy value,
    # per-request concurrency is pinned to 1, matching the other bulk
    # wrappers (_pipe/_put/_get) in this class.
    per_request = max_concurrency or 1
    return await super()._cat_ranges(
        *args, batch_size=batch_size, max_concurrency=per_request, **kwargs
    )

def url(self, path, expires=3600, **kwargs):
    """Synchronous wrapper around the async ``_url``.

    Runs ``_url(path, expires, **kwargs)`` to completion on this
    filesystem's event loop and returns its result.
    """
    result = sync(self.loop, self._url, path, expires, **kwargs)
    return result

Expand Down Expand Up @@ -1593,7 +1624,14 @@ async def _expand_path(
return list(sorted(out))

async def _put_file(
self, lpath, rpath, delimiter="/", overwrite=False, callback=None, **kwargws
self,
lpath,
rpath,
delimiter="/",
overwrite=False,
callback=None,
max_concurrency=None,
**kwargws,
):
"""
Copy single file to remote
Expand Down Expand Up @@ -1621,6 +1659,7 @@ async def _put_file(
raw_response_hook=make_callback(
"upload_stream_current", callback
),
max_concurrency=max_concurrency or self.max_concurrency,
)
self.invalidate_cache()
except ResourceExistsError:
Expand All @@ -1633,6 +1672,12 @@ async def _put_file(

put_file = sync_wrapper(_put_file)

async def _put(self, *args, batch_size=None, max_concurrency=None, **kwargs):
    """Bulk upload (fsspec ``put``); defaults per-file concurrency to 1.

    A falsy ``max_concurrency`` (None or 0) becomes 1 so that batch-level
    parallelism (``batch_size``) is the only default source of concurrency.
    """
    if not max_concurrency:
        max_concurrency = 1
    return await super()._put(
        *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs
    )

async def _cp_file(self, path1, path2, **kwargs):
"""Copy the file at path1 to path2"""
container1, path1, version_id = self.split_path(path1, delimiter="/")
Expand Down Expand Up @@ -1668,7 +1713,14 @@ def download(self, rpath, lpath, recursive=False, **kwargs):
return self.get(rpath, lpath, recursive=recursive, **kwargs)

async def _get_file(
self, rpath, lpath, recursive=False, delimiter="/", callback=None, **kwargs
self,
rpath,
lpath,
recursive=False,
delimiter="/",
callback=None,
max_concurrency=None,
**kwargs,
):
"""Copy single file remote to local"""
if os.path.isdir(lpath):
Expand All @@ -1683,6 +1735,7 @@ async def _get_file(
"download_stream_current", callback
),
version_id=version_id,
max_concurrency=max_concurrency or self.max_concurrency,
)
with open(lpath, "wb") as my_blob:
await stream.readinto(my_blob)
Expand All @@ -1691,6 +1744,12 @@ async def _get_file(

get_file = sync_wrapper(_get_file)

async def _get(self, *args, batch_size=None, max_concurrency=None, **kwargs):
    """Bulk download (fsspec ``get``); defaults per-file concurrency to 1.

    Mirrors ``_put``/``_pipe``: any falsy ``max_concurrency`` is replaced
    with 1 before delegating to the fsspec base implementation.
    """
    if not max_concurrency:
        max_concurrency = 1
    return await super()._get(
        *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs
    )

def getxattr(self, path, attr):
    """Return the metadata entry *attr* stored on the blob at *path*.

    Raises KeyError if the blob has no metadata entry named *attr*
    (an absent ``metadata`` mapping is treated as empty).
    """
    return self.info(path).get("metadata", {})[attr]
Expand Down
13 changes: 10 additions & 3 deletions adlfs/tests/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,9 @@ def test_info_missing(storage, path):

def test_time_info(storage):
fs = AzureBlobFileSystem(
account_name=storage.account_name, connection_string=CONN_STR
account_name=storage.account_name,
connection_string=CONN_STR,
max_concurrency=1,
Copy link
Contributor Author

@pmrowla pmrowla Aug 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Concurrency cannot be used for this test, otherwise storage.insert_time will be the timestamp of the first completed chunk and creation_time will be the timestamp of the finished operation (after all chunks are uploaded).

)

creation_time = fs.created("data/root/d/file_with_metadata.txt")
Expand Down Expand Up @@ -1486,7 +1488,10 @@ async def test_cat_file_versioned(storage, mocker):

await fs._cat_file(f"data/root/a/file.txt?versionid={DEFAULT_VERSION_ID}")
download_blob.assert_called_once_with(
offset=None, length=None, version_id=DEFAULT_VERSION_ID
offset=None,
length=None,
version_id=DEFAULT_VERSION_ID,
max_concurrency=fs.max_concurrency,
)

download_blob.reset_mock()
Expand Down Expand Up @@ -1744,7 +1749,9 @@ async def test_get_file_versioned(storage, mocker, tmp_path):
f"data/root/a/file.txt?versionid={DEFAULT_VERSION_ID}", tmp_path / "file.txt"
)
download_blob.assert_called_once_with(
raw_response_hook=mocker.ANY, version_id=DEFAULT_VERSION_ID
raw_response_hook=mocker.ANY,
version_id=DEFAULT_VERSION_ID,
max_concurrency=fs.max_concurrency,
)
download_blob.reset_mock()
download_blob.side_effect = ResourceNotFoundError
Expand Down