diff --git a/simplekv/__init__.py b/simplekv/__init__.py index ca9ff381..780b90d7 100644 --- a/simplekv/__init__.py +++ b/simplekv/__init__.py @@ -335,6 +335,23 @@ def _put_filename(self, key, filename): with open(filename, 'rb') as source: return self._put_file(key, source) + def close(self): + """Specific store implementations might require teardown methods. + (Dangling ports, unclosed files). This allows calling close also + for stores, which do not require this. + """ + return + + def __enter__(self): + """Support for with clause for automatic calling of close. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Support for with clause for automatic calling of close. + """ + self.close() + class UrlMixin(object): """Supports getting a download URL for keys.""" diff --git a/simplekv/decorator.py b/simplekv/decorator.py index e6f3e0f2..83d52da6 100644 --- a/simplekv/decorator.py +++ b/simplekv/decorator.py @@ -25,6 +25,15 @@ def __contains__(self, *args, **kwargs): def __iter__(self, *args, **kwargs): return self._dstore.__iter__(*args, **kwargs) + def close(self): + self._dstore.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + class KeyTransformingDecorator(StoreDecorator): # currently undocumented (== not advertised as a feature) diff --git a/simplekv/net/_azurestore_new.py b/simplekv/net/_azurestore_new.py index 9548bd90..67643f9c 100644 --- a/simplekv/net/_azurestore_new.py +++ b/simplekv/net/_azurestore_new.py @@ -67,6 +67,8 @@ def __init__( self.max_block_size = max_block_size self.max_single_put_size = max_single_put_size self.checksum = checksum + self._service_client = None + self._container_client = None # Using @lazy_property will (re-)create block_blob_service instance needed. # Together with the __getstate__ implementation below, this allows @@ -83,16 +85,28 @@ def blob_container_client(self): if self.max_block_size: kwargs["max_block_size"] = self.max_block_size - service_client = BlobServiceClient.from_connection_string( + self._service_client = BlobServiceClient.from_connection_string( self.conn_string, **kwargs ) - container_client = service_client.get_container_client(self.container) + self._container_client = self._service_client.get_container_client(self.container) if self.create_if_missing: with map_azure_exceptions(error_codes_pass=("ContainerAlreadyExists")): - container_client.create_container( + self._container_client.create_container( public_access="container" if self.public else None ) - return container_client + return self._container_client + + def close(self): + """ + Method to close container_client and service_client ports, if opened. + """ + if self._container_client: + self._container_client.close() + self._container_client = None + if self._service_client: + self._service_client.close() + self._service_client = None + def _delete(self, key): with map_azure_exceptions(key, error_codes_pass=("BlobNotFound",)): diff --git a/tests/basic_store.py b/tests/basic_store.py index 66211f56..dbc85914 100644 --- a/tests/basic_store.py +++ b/tests/basic_store.py @@ -128,7 +128,8 @@ def test_put_opened_file(self, store, key, value): tmp.write(value) tmp.flush() - store.put_file(key, open(tmp.name, 'rb')) + with open(tmp.name, 'rb') as infile: + store.put_file(key, infile) assert store.get(key) == value @@ -137,8 +138,8 @@ def test_get_into_file(self, store, key, value, tmp_path): out_filename = os.path.join(str(tmp_path), 'output') store.get_file(key, out_filename) - - assert open(out_filename, 'rb').read() == value + with open(out_filename, 'rb') as infile: + assert infile.read() == value def test_get_into_stream(self, store, key, value): store.put(key, value) diff --git a/tests/test_azure_store.py b/tests/test_azure_store.py index d64d19f1..60eb6bd0 100644 --- a/tests/test_azure_store.py +++ b/tests/test_azure_store.py @@ -50,6 +50,7 @@ def _delete_container(conn_string, container): # ignore the ContainerNotFound error: if ex.error_code != 'ContainerNotFound': raise + s.close() except ImportError: # for azure-storage-blob<12 from azure.storage.blob import BlockBlobService @@ -62,8 +63,9 @@ class TestAzureStorage(BasicStore, OpenSeekTellStore): def store(self): container = str(uuid()) conn_string = get_azure_conn_string() - yield AzureBlockBlobStore(conn_string=conn_string, container=container, - public=False) + with AzureBlockBlobStore(conn_string=conn_string, container=container, + public=False) as store: + yield store _delete_container(conn_string, container) @@ -79,10 +81,41 @@ def store(self): class ExtendedKeysStore(ExtendedKeyspaceMixin, AzureBlockBlobStore): pass - yield ExtendedKeysStore(conn_string=conn_string, - container=container, public=False) + with ExtendedKeysStore(conn_string=conn_string, + container=container, public=False) as store: + yield store _delete_container(conn_string, container) +@pytest.mark.filterwarnings("error") +def test_azure_dangling_port_enter_exit(): + container = str(uuid()) + conn_string = get_azure_conn_string() + with AzureBlockBlobStore(conn_string=conn_string, container=container) as store: + container_client = store.blob_container_client + +@pytest.mark.filterwarnings("error") +def test_azure_dangling_port_explicit_close(): + container = str(uuid()) + conn_string = get_azure_conn_string() + store = AzureBlockBlobStore(conn_string=conn_string, container=container) + container_client = store.blob_container_client + store.close() + +@pytest.mark.filterwarnings("error") +def test_azure_dangling_port_explicit_close_multi(): + container = str(uuid()) + conn_string = get_azure_conn_string() + store = AzureBlockBlobStore(conn_string=conn_string, container=container) + container_client = store.blob_container_client + # We check that multiclose and reuse do not do weird things + store.close() + store.close() + container_client = store.blob_container_client + container_client = store.blob_container_client + container_client = store.blob_container_client + store.close() + store.close() + store.close() def test_azure_setgetstate(): container = str(uuid()) @@ -91,9 +124,11 @@ def test_azure_setgetstate(): store.put(u'key1', b'value1') buf = pickle.dumps(store, protocol=2) + store.close() store = pickle.loads(buf) assert store.get(u'key1') == b'value1' + store.close() _delete_container(conn_string, container) @@ -114,6 +149,7 @@ def test_azure_store_attributes(): assert abbs2.create_if_missing is True assert abbs2.max_connections == 42 assert abbs2.checksum is True + abbs.close() def test_azure_special_args(): @@ -134,6 +170,7 @@ def test_azure_special_args(): cfg = abbs.blob_container_client._config assert cfg.max_single_put_size == MSP assert cfg.max_block_size == MBS + abbs.close() class TestAzureExceptionHandling(object): @@ -146,6 +183,7 @@ def test_missing_container(self): with pytest.raises(IOError) as exc: store.keys() assert u"The specified container does not exist." in str(exc.value) + store.close() def test_wrong_endpoint(self): container = str(uuid()) @@ -167,6 +205,7 @@ def test_wrong_endpoint(self): with pytest.raises(IOError) as exc: store.put(u"key", b"data") assert u"connect" in str(exc.value) + store.close() def test_wrong_credentials(self): container = str(uuid()) @@ -188,6 +227,7 @@ def test_wrong_credentials(self): with pytest.raises(IOError) as exc: store.put(u"key", b"data") assert u"Incorrect padding" in str(exc.value) + store.close() class TestChecksum(object): @@ -200,12 +240,13 @@ def store(self): container = str(uuid()) conn_string = get_azure_conn_string() - yield AzureBlockBlobStore( + with AzureBlockBlobStore( conn_string=conn_string, container=container, public=False, checksum=True, - ) + ) as store: + yield store _delete_container(conn_string, container) def _checksum(self, store): @@ -229,7 +270,8 @@ def test_checksum_put(self, store): def test_checksum_put_file(self, store, tmpdir): file_ = tmpdir.join('my_file') file_.write(self.CONTENT) - store.put_file(self.KEY, file_.open('rb')) + with file_.open('rb') as infile: + store.put_file(self.KEY, infile) assert self._checksum(store) == self.EXPECTED_CHECKSUM assert store.get(self.KEY) == self.CONTENT diff --git a/tests/test_keyvaluestore_base.py b/tests/test_keyvaluestore_base.py new file mode 100644 index 00000000..2f213665 --- /dev/null +++ b/tests/test_keyvaluestore_base.py @@ -0,0 +1,8 @@ +from unittest import mock +from simplekv import KeyValueStore + +def test_keyvaluestore_enter_exit(): + with mock.patch("simplekv.KeyValueStore.close") as closefunc: + with KeyValueStore() as kv: + pass + closefunc.assert_called_once()