diff --git a/fsspec/implementations/cached.py b/fsspec/implementations/cached.py index d16e22707..b679cce51 100644 --- a/fsspec/implementations/cached.py +++ b/fsspec/implementations/cached.py @@ -5,6 +5,7 @@ import os import tempfile import time +import weakref from shutil import rmtree from typing import TYPE_CHECKING, Any, Callable, ClassVar @@ -111,7 +112,9 @@ def __init__( "Both filesystems (fs) and target_protocol may not be both given." ) if cache_storage == "TMP": - storage = [tempfile.mkdtemp()] + tempdir = tempfile.mkdtemp() + storage = [tempdir] + weakref.finalize(self, self._remove_tempdir, tempdir) else: if isinstance(cache_storage, str): storage = [cache_storage] @@ -152,6 +155,13 @@ def _strip_protocol(path): self._strip_protocol: Callable = _strip_protocol + @staticmethod + def _remove_tempdir(tempdir): + try: + rmtree(tempdir) + except Exception: + pass + def _mkcache(self): os.makedirs(self.storage[-1], exist_ok=True) diff --git a/fsspec/implementations/tests/test_cached.py b/fsspec/implementations/tests/test_cached.py index 3307495b1..ee69045c3 100644 --- a/fsspec/implementations/tests/test_cached.py +++ b/fsspec/implementations/tests/test_cached.py @@ -1115,3 +1115,40 @@ def test_getitems_errors(tmpdir): assert m.getitems(["afile", "bfile"], on_error="omit") == {"afile": b"test"} with pytest.raises(FileNotFoundError): m.getitems(["afile", "bfile"]) + + +@pytest.mark.parametrize("temp_cache", [False, True]) +def test_cache_dir_auto_deleted(temp_cache, tmpdir): + import gc + + source = os.path.join(tmpdir, "source") + afile = os.path.join(source, "afile") + os.mkdir(source) + open(afile, "w").write("test") + + fs = fsspec.filesystem( + "filecache", + target_protocol="file", + cache_storage="TMP" if temp_cache else os.path.join(tmpdir, "cache"), + skip_instance_cache=True, # Important to avoid fs itself being cached + ) + + cache_dir = fs.storage[-1] + + # Force cache to be created + with fs.open(afile, "rb") as f: + assert f.read(5) == b"test" + + # Confirm cache exists + local = fsspec.filesystem("file") + assert local.exists(cache_dir) + + # Delete file system + del fs + gc.collect() + + # Ensure cache has been deleted, if it is temporary + if temp_cache: + assert not local.exists(cache_dir) + else: + assert local.exists(cache_dir)