diff --git a/CHANGELOG.md b/CHANGELOG.md index c9f4083a..90ad0e8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ Unreleased ---------- +- `AzureBlobFileSystem` and `AzureBlobFile` support pickling. - Handle mixed casing for `hdi_isfolder` metadata when determining whether a blob should be treated as a folder. - `_put_file`: `overwrite` now defaults to `True`. diff --git a/adlfs/spec.py b/adlfs/spec.py index d6c52af4..513882de 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -259,6 +259,7 @@ def __init__( read_timeout: Optional[int] = None, **kwargs, ): + self.kwargs = kwargs.copy() super_kwargs = { k: kwargs.pop(k) for k in ["use_listings_cache", "listings_expiry_time", "max_paths"] @@ -1915,22 +1916,8 @@ def __init__( else None ) - try: - # Need to confirm there is an event loop running in - # the thread. If not, create the fsspec loop - # and set it. This is to handle issues with - # Async Credentials from the Azure SDK - loop = get_running_loop() - - except RuntimeError: - loop = get_loop() - asyncio.set_event_loop(loop) - - self.loop = self.fs.loop or get_loop() - self.container_client = ( - fs.service_client.get_container_client(self.container_name) - or self.connect_client() - ) + self.loop = self._get_loop() + self.container_client = self._get_container_client() self.blocksize = ( self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size ) @@ -1969,6 +1956,8 @@ def __init__( self.path, version_id=self.version_id, refresh=True ) self.size = self.details["size"] + self.cache_type = cache_type + self.cache_options = cache_options self.cache = caches[cache_type]( blocksize=self.blocksize, fetcher=self._fetch_range, @@ -1990,6 +1979,26 @@ def __init__( self.forced = False self.location = None + def _get_loop(self): + try: + # Need to confirm there is an event loop running in + # the thread. If not, create the fsspec loop + # and set it. This is to handle issues with + # Async Credentials from the Azure SDK + loop = get_running_loop() + + except RuntimeError: + loop = get_loop() + asyncio.set_event_loop(loop) + + return self.fs.loop or get_loop() + + def _get_container_client(self): + return ( + self.fs.service_client.get_container_client(self.container_name) + or self.connect_client() + ) + def close(self): """Close file and azure client.""" asyncio.run_coroutine_threadsafe(close_container_client(self), loop=self.loop) @@ -2179,3 +2188,14 @@ def __del__(self): self.close() except TypeError: pass + + def __getstate__(self): + state = self.__dict__.copy() + del state["container_client"] + del state["loop"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.loop = self._get_loop() + self.container_client = self._get_container_client() diff --git a/adlfs/tests/conftest.py b/adlfs/tests/conftest.py index c54f3865..cc3a7dc5 100644 --- a/adlfs/tests/conftest.py +++ b/adlfs/tests/conftest.py @@ -37,7 +37,8 @@ def storage(host): conn_str = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};" # NOQA bbs = BlobServiceClient.from_connection_string(conn_str=conn_str) - bbs.create_container("data") + if "data" not in [c["name"] for c in bbs.list_containers()]: + bbs.create_container("data") container_client = bbs.get_container_client(container="data") bbs.insert_time = datetime.datetime.utcnow().replace( microsecond=0, tzinfo=datetime.timezone.utc diff --git a/adlfs/tests/test_pickling.py b/adlfs/tests/test_pickling.py new file mode 100644 index 00000000..4e46d14f --- /dev/null +++ b/adlfs/tests/test_pickling.py @@ -0,0 +1,32 @@ +import pickle + +from adlfs import AzureBlobFileSystem + +URL = "http://127.0.0.1:10000" +ACCOUNT_NAME = "devstoreaccount1" +KEY = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" # NOQA +CONN_STR = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};" # NOQA + + +def test_fs_pickling(storage): + fs = AzureBlobFileSystem( + account_name=storage.account_name, + connection_string=CONN_STR, + kwarg1="some_value", + ) + fs2: AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) + assert "data" in fs.ls("") + assert "data" in fs2.ls("") + assert fs2.kwargs["kwarg1"] == "some_value" + + +def test_blob_pickling(storage): + fs = AzureBlobFileSystem( + account_name=storage.account_name, connection_string=CONN_STR + ) + fs2: AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) + blob = fs2.open("data/root/a/file.txt") + assert blob.read() == b"0123456789" + blob2 = pickle.loads(pickle.dumps(blob)) + blob2.seek(0) + assert blob2.read() == b"0123456789"