diff --git a/adlfs/spec.py b/adlfs/spec.py index 610b77cb..eb160e05 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -297,7 +297,11 @@ def __init__( if self.credential is not None: weakref.finalize(self, sync, self.loop, close_credential, self) - self.max_concurrency = max_concurrency or (_get_batch_size() // 4) + if max_concurrency is None: + batch_size = _get_batch_size() + if batch_size > 0: + max_concurrency = batch_size + self.max_concurrency = max_concurrency @classmethod def _strip_protocol(cls, path: str): @@ -1453,6 +1457,12 @@ async def _pipe_file( pipe_file = sync_wrapper(_pipe_file) + async def _pipe(self, *args, batch_size=None, max_concurrency=None, **kwargs): + max_concurrency = max_concurrency or 1 + return await super()._pipe( + *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs + ) + async def _cat_file( self, path, start=None, end=None, max_concurrency=None, **kwargs ): @@ -1512,6 +1522,12 @@ def cat(self, path, recursive=False, on_error="raise", **kwargs): else: return self.cat_file(paths[0]) + async def _cat_ranges(self, *args, batch_size=None, max_concurrency=None, **kwargs): + max_concurrency = max_concurrency or 1 + return await super()._cat_ranges( + *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs + ) + def url(self, path, expires=3600, **kwargs): return sync(self.loop, self._url, path, expires, **kwargs) @@ -1656,6 +1672,12 @@ async def _put_file( put_file = sync_wrapper(_put_file) + async def _put(self, *args, batch_size=None, max_concurrency=None, **kwargs): + max_concurrency = max_concurrency or 1 + return await super()._put( + *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs + ) + async def _cp_file(self, path1, path2, **kwargs): """Copy the file at path1 to path2""" container1, path1, version_id = self.split_path(path1, delimiter="/") @@ -1722,6 +1744,12 @@ async def _get_file( get_file = sync_wrapper(_get_file) + async def _get(self, *args, batch_size=None, max_concurrency=None, **kwargs): + max_concurrency = max_concurrency or 1 + return await super()._get( + *args, batch_size=batch_size, max_concurrency=max_concurrency, **kwargs + ) + def getxattr(self, path, attr): meta = self.info(path).get("metadata", {}) return meta[attr]