From 3d7e4595c4705a9bf983e050bd730fe9eb78fb8a Mon Sep 17 00:00:00 2001 From: Gustavo Hidalgo Date: Fri, 21 Jun 2024 12:45:15 -0400 Subject: [PATCH 1/7] initial implementation --- adlfs/spec.py | 45 ++++++++++++++++++++++++++++++++++++ adlfs/tests/test_pickling.py | 25 ++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 adlfs/tests/test_pickling.py diff --git a/adlfs/spec.py b/adlfs/spec.py index d6c52af4..bd2b1d70 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -259,6 +259,7 @@ def __init__( read_timeout: Optional[int] = None, **kwargs, ): + self.kwargs = kwargs.copy() super_kwargs = { k: kwargs.pop(k) for k in ["use_listings_cache", "listings_expiry_time", "max_paths"] @@ -1843,6 +1844,34 @@ def _open( **kwargs, ) + def __getnewargs__(self): + """Return the arguments that would be passed to __init__, useful for pickling""" + # Tuple elements and __init__ parameters must be identical! + return ( + self.account_name, + self.account_key, + self.connection_string, + self.credential, + self.sas_token, + None, + _SOCKET_TIMEOUT_DEFAULT, + self.blocksize, + self.client_id, + self.client_secret, + self.tenant_id, + self.anon, + self.location_mode, + None, + self.asynchronous, + self.default_fill_cache, + self.default_cache_type, + self.version_aware, + self.assume_container_exists, + self.max_concurrency, + self._timeout_kwargs.get("timeout", None), + self._timeout_kwargs.get("connection_timeout", None), + self._timeout_kwargs.get("read_timeout", None), + self.kwargs) class AzureBlobFile(AbstractBufferedFile): """File-like operations on Azure Blobs""" @@ -1969,6 +1998,8 @@ def __init__( self.path, version_id=self.version_id, refresh=True ) self.size = self.details["size"] + self.cache_type = cache_type + self.cache_options = cache_options self.cache = caches[cache_type]( blocksize=self.blocksize, fetcher=self._fetch_range, @@ -2179,3 +2210,17 @@ def __del__(self): self.close() except TypeError: pass + + def __getnewargs__(self): + return ( + self.fs, + self.path, + self.mode, + self.block_size, + self.autocommit, + self.cache_type, + self.cache_options, + self.metadata, + self.version_id, + self.kwargs, + ) diff --git a/adlfs/tests/test_pickling.py b/adlfs/tests/test_pickling.py new file mode 100644 index 00000000..0ed92e07 --- /dev/null +++ b/adlfs/tests/test_pickling.py @@ -0,0 +1,25 @@ +import pickle +from adlfs import AzureBlobFileSystem + +URL = "http://127.0.0.1:10000" +ACCOUNT_NAME = "devstoreaccount1" +KEY = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" # NOQA +CONN_STR = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};" # NOQA + +def test_fs_pickling(storage): + fs = AzureBlobFileSystem( + account_name=storage.account_name, connection_string=CONN_STR + ) + fs2 : AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) + assert fs.ls("") == ["data"] + assert fs2.ls("") == ["data"] + +def test_blob_pickling(storage): + fs = AzureBlobFileSystem( + account_name=storage.account_name, connection_string=CONN_STR + ) + fs2 : AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) + blob = fs2.open("data/root/a/file.txt") + assert blob.read() == b"0123456789" + blob2 = pickle.loads(pickle.dumps(blob)) + assert blob2.read() == b"0123456789" \ No newline at end of file From 371e8334f15752a86a03d464f9f511be980412e0 Mon Sep 17 00:00:00 2001 From: Gustavo Hidalgo Date: Fri, 21 Jun 2024 13:13:05 -0400 Subject: [PATCH 2/7] Successfully pickle blob --- adlfs/spec.py | 63 +++++++++++++++++++----------------- adlfs/tests/test_pickling.py | 1 + 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/adlfs/spec.py b/adlfs/spec.py index bd2b1d70..d36edab9 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -1944,22 +1944,8 @@ def __init__( else None ) - try: - # Need to confirm there is an event loop running in - # the thread. If not, create the fsspec loop - # and set it. This is to handle issues with - # Async Credentials from the Azure SDK - loop = get_running_loop() - - except RuntimeError: - loop = get_loop() - asyncio.set_event_loop(loop) - - self.loop = self.fs.loop or get_loop() - self.container_client = ( - fs.service_client.get_container_client(self.container_name) - or self.connect_client() - ) + self.loop = self._get_loop() + self.container_client = self._get_container_client() self.blocksize = ( self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size ) @@ -2021,6 +2007,26 @@ def __init__( self.forced = False self.location = None + def _get_loop(self): + try: + # Need to confirm there is an event loop running in + # the thread. If not, create the fsspec loop + # and set it. This is to handle issues with + # Async Credentials from the Azure SDK + loop = get_running_loop() + + except RuntimeError: + loop = get_loop() + asyncio.set_event_loop(loop) + + return self.fs.loop or get_loop() + + def _get_container_client(self): + return ( + self.fs.service_client.get_container_client(self.container_name) + or self.connect_client() + ) + def close(self): """Close file and azure client.""" asyncio.run_coroutine_threadsafe(close_container_client(self), loop=self.loop) @@ -2211,16 +2217,15 @@ def __del__(self): except TypeError: pass - def __getnewargs__(self): - return ( - self.fs, - self.path, - self.mode, - self.block_size, - self.autocommit, - self.cache_type, - self.cache_options, - self.metadata, - self.version_id, - self.kwargs, - ) + def __getstate__(self): + # loop and container client, can be reconstructed after pickling + # Anyway they don't allow us to pickly because they are weak refs + state = self.__dict__.copy() + del state['container_client'] + del state['loop'] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.loop = self._get_loop() + self.container_client = self._get_container_client() diff --git a/adlfs/tests/test_pickling.py b/adlfs/tests/test_pickling.py index 0ed92e07..270ae326 100644 --- a/adlfs/tests/test_pickling.py +++ b/adlfs/tests/test_pickling.py @@ -22,4 +22,5 @@ def test_blob_pickling(storage): blob = fs2.open("data/root/a/file.txt") assert blob.read() == b"0123456789" blob2 = pickle.loads(pickle.dumps(blob)) + blob2.seek(0) assert blob2.read() == b"0123456789" \ No newline at end of file From d8da35179afaae8dc8f1ab06dfbbf17f0f9e72fd Mon Sep 17 00:00:00 2001 From: Gustavo Hidalgo Date: Fri, 21 Jun 2024 13:16:03 -0400 Subject: [PATCH 3/7] Remove empty line, update changelog --- CHANGELOG.md | 1 + adlfs/tests/test_pickling.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9f4083a..90ad0e8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ Unreleased ---------- +- `AzureBlobFileSystem` and `AzureBlobFile` support pickling. - Handle mixed casing for `hdi_isfolder` metadata when determining whether a blob should be treated as a folder. - `_put_file`: `overwrite` now defaults to `True`. diff --git a/adlfs/tests/test_pickling.py b/adlfs/tests/test_pickling.py index 270ae326..5ebf751c 100644 --- a/adlfs/tests/test_pickling.py +++ b/adlfs/tests/test_pickling.py @@ -13,7 +13,7 @@ def test_fs_pickling(storage): fs2 : AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) assert fs.ls("") == ["data"] assert fs2.ls("") == ["data"] - + def test_blob_pickling(storage): fs = AzureBlobFileSystem( account_name=storage.account_name, connection_string=CONN_STR From 39c73d02aabf7d38f2c724c5364a7a24118ace4f Mon Sep 17 00:00:00 2001 From: Gustavo Hidalgo Date: Fri, 21 Jun 2024 16:36:55 -0400 Subject: [PATCH 4/7] remove unnecessary getnewargs --- adlfs/spec.py | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/adlfs/spec.py b/adlfs/spec.py index d36edab9..21f6c013 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -1841,38 +1841,9 @@ def _open( cache_type=cache_type, metadata=metadata, version_id=version_id, - **kwargs, + **kwargs, ) - def __getnewargs__(self): - """Return the arguments that would be passed to __init__, useful for pickling""" - # Tuple elements and __init__ parameters must be identical! - return ( - self.account_name, - self.account_key, - self.connection_string, - self.credential, - self.sas_token, - None, - _SOCKET_TIMEOUT_DEFAULT, - self.blocksize, - self.client_id, - self.client_secret, - self.tenant_id, - self.anon, - self.location_mode, - None, - self.asynchronous, - self.default_fill_cache, - self.default_cache_type, - self.version_aware, - self.assume_container_exists, - self.max_concurrency, - self._timeout_kwargs.get("timeout", None), - self._timeout_kwargs.get("connection_timeout", None), - self._timeout_kwargs.get("read_timeout", None), - self.kwargs) - class AzureBlobFile(AbstractBufferedFile): """File-like operations on Azure Blobs""" From 826b6d6988a7d8be7acdbe13b45d811fe38b2d2c Mon Sep 17 00:00:00 2001 From: Gustavo Hidalgo Date: Fri, 21 Jun 2024 16:37:08 -0400 Subject: [PATCH 5/7] fix up tests --- adlfs/tests/conftest.py | 3 ++- adlfs/tests/test_pickling.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/adlfs/tests/conftest.py b/adlfs/tests/conftest.py index c54f3865..cc3a7dc5 100644 --- a/adlfs/tests/conftest.py +++ b/adlfs/tests/conftest.py @@ -37,7 +37,8 @@ def storage(host): conn_str = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};" # NOQA bbs = BlobServiceClient.from_connection_string(conn_str=conn_str) - bbs.create_container("data") + if "data" not in [c["name"] for c in bbs.list_containers()]: + bbs.create_container("data") container_client = bbs.get_container_client(container="data") bbs.insert_time = datetime.datetime.utcnow().replace( microsecond=0, tzinfo=datetime.timezone.utc diff --git a/adlfs/tests/test_pickling.py b/adlfs/tests/test_pickling.py index 5ebf751c..5bcb41da 100644 --- a/adlfs/tests/test_pickling.py +++ b/adlfs/tests/test_pickling.py @@ -1,5 +1,7 @@ import pickle -from adlfs import AzureBlobFileSystem +import pytest +from adlfs import AzureBlobFileSystem, AzureBlobFile +import asyncio URL = "http://127.0.0.1:10000" ACCOUNT_NAME = "devstoreaccount1" @@ -8,11 +10,14 @@ def test_fs_pickling(storage): fs = AzureBlobFileSystem( - account_name=storage.account_name, connection_string=CONN_STR + account_name=storage.account_name, + connection_string=CONN_STR, + kwarg1= "some_value", ) fs2 : AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) - assert fs.ls("") == ["data"] - assert fs2.ls("") == ["data"] + assert "data" in fs.ls("") + assert "data" in fs2.ls("") + assert fs2.kwargs["kwarg1"] == "some_value" def test_blob_pickling(storage): fs = AzureBlobFileSystem( From c32cae7bf5f0f0c8bc70d922f801ef06fe7a6004 Mon Sep 17 00:00:00 2001 From: Gustavo Hidalgo Date: Sat, 29 Jun 2024 12:10:40 -0400 Subject: [PATCH 6/7] Remove unncessary comment --- adlfs/spec.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/adlfs/spec.py b/adlfs/spec.py index 21f6c013..db17e81c 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -2189,8 +2189,6 @@ def __del__(self): pass def __getstate__(self): - # loop and container client, can be reconstructed after pickling - # Anyway they don't allow us to pickly because they are weak refs state = self.__dict__.copy() del state['container_client'] del state['loop'] From b6ec194a1467515407db405b71658ffebd4d6074 Mon Sep 17 00:00:00 2001 From: Gustavo Hidalgo Date: Sat, 29 Jun 2024 12:16:16 -0400 Subject: [PATCH 7/7] formatting fixes --- adlfs/spec.py | 13 +++++++------ adlfs/tests/test_pickling.py | 15 ++++++++------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/adlfs/spec.py b/adlfs/spec.py index db17e81c..513882de 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -1841,9 +1841,10 @@ def _open( cache_type=cache_type, metadata=metadata, version_id=version_id, - **kwargs, + **kwargs, ) + class AzureBlobFile(AbstractBufferedFile): """File-like operations on Azure Blobs""" @@ -1915,7 +1916,7 @@ def __init__( else None ) - self.loop = self._get_loop() + self.loop = self._get_loop() self.container_client = self._get_container_client() self.blocksize = ( self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size @@ -2187,13 +2188,13 @@ def __del__(self): self.close() except TypeError: pass - + def __getstate__(self): state = self.__dict__.copy() - del state['container_client'] - del state['loop'] + del state["container_client"] + del state["loop"] return state - + def __setstate__(self, state): self.__dict__.update(state) self.loop = self._get_loop() diff --git a/adlfs/tests/test_pickling.py b/adlfs/tests/test_pickling.py index 5bcb41da..4e46d14f 100644 --- a/adlfs/tests/test_pickling.py +++ b/adlfs/tests/test_pickling.py @@ -1,31 +1,32 @@ import pickle -import pytest -from adlfs import AzureBlobFileSystem, AzureBlobFile -import asyncio + +from adlfs import AzureBlobFileSystem URL = "http://127.0.0.1:10000" ACCOUNT_NAME = "devstoreaccount1" KEY = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" # NOQA CONN_STR = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};" # NOQA + def test_fs_pickling(storage): fs = AzureBlobFileSystem( account_name=storage.account_name, connection_string=CONN_STR, - kwarg1= "some_value", + kwarg1="some_value", ) - fs2 : AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) + fs2: AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) assert "data" in fs.ls("") assert "data" in fs2.ls("") assert fs2.kwargs["kwarg1"] == "some_value" + def test_blob_pickling(storage): fs = AzureBlobFileSystem( account_name=storage.account_name, connection_string=CONN_STR ) - fs2 : AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) + fs2: AzureBlobFileSystem = pickle.loads(pickle.dumps(fs)) blob = fs2.open("data/root/a/file.txt") assert blob.read() == b"0123456789" blob2 = pickle.loads(pickle.dumps(blob)) blob2.seek(0) - assert blob2.read() == b"0123456789" \ No newline at end of file + assert blob2.read() == b"0123456789"