Skip to content

Commit

Permalink
Add ZstdSerializer, change Key to include "domain"
Browse files Browse the repository at this point in the history
  • Loading branch information
bisho committed Mar 5, 2024
1 parent db56483 commit 7e96f8e
Show file tree
Hide file tree
Showing 11 changed files with 702 additions and 107 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ application-import-names = meta_memcache,tests
import-order-style = cryptography
per-file-ignores =
__init__.py:F401
tests/*:S101,S403
tests/*:S101,S403
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,26 @@ will be gone or present, since they are stored in the same server). Note this is
also risky, if you place all keys of a user in the same server, and the server
goes down, the user life will be miserable.

### Unicode keys:
Unicode keys are supported, the keys will be hashed according to Meta commands
### Custom domains:
You can add a domain to keys. This domain can be used for custom per-domain
metrics like hit ratios or to control serialization of the values.
```python:
Key("key:1:2", domain="example")
```
For example the ZstdSerializer allows to configure different dictionaries by
domain, so you can compress more efficiently data of different domains.

### Unicode/binary keys:
Both unicode and binary keys are supported, the keys will be hashed/encoded according to Meta commands
[binary encoded keys](https://github.com/memcached/memcached/wiki/MetaCommands#binary-encoded-keys)
specification.

To use this, mark the key as unicode:
Using binary keys can have benefits, saving space in memory. While over the wire the key
is transmited b64 encoded, the memcache server will use the byte representation, so it will
not have the 1/4 overhead of b64 encoding.

```python:
Key("🍺", unicode=True)
Key("🍺")
```

### Large keys:
Expand Down
234 changes: 184 additions & 50 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ packages = [{include = "meta_memcache", from="src"}]
python = "^3.8"
uhashring = "^2.1"
marisa-trie = "^1.0.0"
meta-memcache-socket = "0.1.1"
meta-memcache-socket = "0.1.3"
zstandard = "^0.22.0"

[tool.poetry.group.extras.dependencies]
prometheus-client = "^0.17.1"
Expand All @@ -27,7 +28,7 @@ testpaths = [

[tool.isort]
profile = "black"
known_third_party = ["uhashring", "pytest", "pytest_mock", "marisa-trie"]
known_third_party = ["uhashring", "pytest", "pytest_mock", "marisa-trie", "zstandard"]

[tool.coverage.paths]
source = ["src", "*/site-packages"]
Expand Down
3 changes: 2 additions & 1 deletion src/meta_memcache/base/base_serializer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, NamedTuple

from meta_memcache.protocol import Blob
from meta_memcache.protocol import Blob, Key


class EncodedValue(NamedTuple):
Expand All @@ -13,6 +13,7 @@ class BaseSerializer(ABC):
@abstractmethod
def serialize(
self,
key: Key,
value: Any,
) -> EncodedValue: ...

Expand Down
7 changes: 4 additions & 3 deletions src/meta_memcache/executors/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ def _build_cmd(

def _prepare_serialized_value_and_flags(
self,
key: Key,
value: ValueContainer,
flags: Optional[RequestFlags],
) -> Tuple[Optional[bytes], RequestFlags]:
encoded_value = self._serializer.serialize(value.value)
encoded_value = self._serializer.serialize(key, value.value)
flags = flags if flags is not None else RequestFlags()
flags.client_flag = encoded_value.encoding_id
return encoded_value.data, flags
Expand Down Expand Up @@ -106,7 +107,7 @@ def exec_on_pool(
cmd_value, flags = (
(None, flags)
if value is None
else self._prepare_serialized_value_and_flags(value, flags)
else self._prepare_serialized_value_and_flags(key, value, flags)
)
try:
conn = pool.pop_connection()
Expand Down Expand Up @@ -159,7 +160,7 @@ def exec_multi_on_pool( # noqa: C901
cmd_value, flags = (
(None, flags)
if value is None
else self._prepare_serialized_value_and_flags(value, flags)
else self._prepare_serialized_value_and_flags(key, value, flags)
)

self._conn_send_cmd(
Expand Down
11 changes: 7 additions & 4 deletions src/meta_memcache/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,23 @@

@dataclass
class Key:
__slots__ = ("key", "routing_key", "is_unicode")
__slots__ = ("key", "routing_key", "domain", "disable_compression")
key: str
routing_key: Optional[str]
is_unicode: bool
domain: Optional[str]
disable_compression: bool

def __init__(
self,
key: str,
routing_key: Optional[str] = None,
is_unicode: bool = False,
domain: Optional[str] = None,
disabled_compression: bool = False,
) -> None:
self.key = key
self.routing_key = routing_key
self.is_unicode = is_unicode
self.domain = domain
self.disable_compression = disabled_compression

def __hash__(self) -> int:
return hash((self.key, self.routing_key))
Expand Down
178 changes: 176 additions & 2 deletions src/meta_memcache/serializer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pickle # noqa: S403
import zlib
from typing import Any
from typing import Any, Dict, List, NamedTuple, Optional, Tuple

from meta_memcache.base.base_serializer import BaseSerializer, EncodedValue
from meta_memcache.protocol import Blob
from meta_memcache.protocol import Blob, Key
import zstandard as zstd


class MixedSerializer(BaseSerializer):
Expand All @@ -13,13 +14,15 @@ class MixedSerializer(BaseSerializer):
LONG = 4
ZLIB_COMPRESSED = 8
BINARY = 16

COMPRESSION_THRESHOLD = 128

def __init__(self, pickle_protocol: int = 0) -> None:
self._pickle_protocol = pickle_protocol

def serialize(
self,
key: Key,
value: Any,
) -> EncodedValue:
if isinstance(value, bytes):
Expand Down Expand Up @@ -53,3 +56,174 @@ def unserialize(self, data: Blob, encoding_id: int) -> Any:
return bytes(data)
else:
return pickle.loads(data) # noqa: S301


class DictionaryMapping(NamedTuple):
dictionary: bytes
active_domains: List[str]


class ZstdSerializer(BaseSerializer):
STR = 0
PICKLE = 1
INT = 2
LONG = 4
ZLIB_COMPRESSED = 8
BINARY = 16
ZSTD_COMPRESSED = 32

ZSTD_MAGIC = b"(\xb5/\xfd"
DEFAULT_PICKLE_PROTOCOL = 5
DEFAULT_COMPRESSION_LEVEL = 9
DEFAULT_COMPRESSION_THRESHOLD = 128
DEFAULT_DICT_COMPRESSION_THRESHOLD = 64

_pickle_protocol: int
_compression_level: int
_default_compression_threshold: int
_dict_compression_threshold: int
_zstd_compressors: Dict[int, zstd.ZstdCompressor]
_zstd_decompressors: Dict[int, zstd.ZstdDecompressor]
_domain_to_dict_id: Dict[str, int]
_default_zstd_compressor: Optional[zstd.ZstdCompressor]

def __init__(
self,
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
compression_level: int = DEFAULT_COMPRESSION_LEVEL,
compression_threshold: int = DEFAULT_COMPRESSION_THRESHOLD,
dict_compression_threshold: int = DEFAULT_DICT_COMPRESSION_THRESHOLD,
dictionary_mappings: Optional[List[DictionaryMapping]] = None,
default_dictionary: Optional[bytes] = None,
default_zstd: bool = True,
) -> None:
self._pickle_protocol = pickle_protocol
self._compression_level = compression_level
self._default_compression_threshold = (
compression_threshold
if not default_dictionary
else dict_compression_threshold
)
self._dict_compression_threshold = dict_compression_threshold
self._zstd_compressors = {}
self._zstd_decompressors = {}
self._domain_to_dict_id = {}

compression_params = zstd.ZstdCompressionParameters.from_level(
compression_level,
format=zstd.FORMAT_ZSTD1_MAGICLESS,
write_content_size=True,
write_checksum=False,
write_dict_id=True,
)

if dictionary_mappings:
for dictionary_mapping in dictionary_mappings:
dict_id, zstd_dict = self._build_dict(dictionary_mapping.dictionary)
self._add_dict_decompressor(dict_id, zstd_dict)
if dictionary_mapping.active_domains:
# The dictionary is active for some domains
self._add_dict_compressor(dict_id, zstd_dict, compression_params)
for domain in dictionary_mapping.active_domains:
self._domain_to_dict_id[domain] = dict_id

if default_dictionary:
dict_id, zstd_dict = self._build_dict(default_dictionary)
self._add_dict_decompressor(dict_id, zstd_dict)

self._default_zstd_compressor = self._add_dict_compressor(
dict_id, zstd_dict, compression_params
)
elif default_zstd:
self._default_zstd_compressor = zstd.ZstdCompressor(
compression_params=compression_params
)
else:
self._default_zstd_compressor = None

self._zstd_decompressors[0] = zstd.ZstdDecompressor()

def _build_dict(self, dictionary: bytes) -> Tuple[int, zstd.ZstdCompressionDict]:
zstd_dict = zstd.ZstdCompressionDict(dictionary)
dict_id = zstd_dict.dict_id()
return dict_id, zstd_dict

def _add_dict_decompressor(
self, dict_id: int, zstd_dict: zstd.ZstdCompressionDict
) -> zstd.ZstdDecompressor:
self._zstd_decompressors[dict_id] = zstd.ZstdDecompressor(dict_data=zstd_dict)
return self._zstd_decompressors[dict_id]

def _add_dict_compressor(
self,
dict_id: int,
zstd_dict: zstd.ZstdCompressionDict,
compression_params: zstd.ZstdCompressionParameters,
) -> zstd.ZstdCompressor:
self._zstd_compressors[dict_id] = zstd.ZstdCompressor(
dict_data=zstd_dict, compression_params=compression_params
)
return self._zstd_compressors[dict_id]

def _compress(self, key: Key, data: bytes) -> Tuple[bytes, int]:
if key.domain and (dict_id := self._domain_to_dict_id.get(key.domain)):
return self._zstd_compressors[dict_id].compress(data), self.ZSTD_COMPRESSED
elif self._default_zstd_compressor:
return self._default_zstd_compressor.compress(data), self.ZSTD_COMPRESSED
else:
return zlib.compress(data), self.ZLIB_COMPRESSED

def _decompress(self, data: bytes) -> bytes:
data = self.ZSTD_MAGIC + data
dict_id = zstd.get_frame_parameters(data).dict_id
if decompressor := self._zstd_decompressors.get(dict_id):
return decompressor.decompress(data)
raise ValueError(f"Unknown dictionary id: {dict_id}")

def _should_compress(self, key: Key, data: bytes) -> bool:
data_len = len(data)
if data_len >= self._default_compression_threshold:
return True
elif data_len >= self._dict_compression_threshold:
return bool(key.domain and self._domain_to_dict_id.get(key.domain))
return False

def serialize(
self,
key: Key,
value: Any,
) -> EncodedValue:
if isinstance(value, bytes):
data = value
encoding_id = self.BINARY
elif isinstance(value, int) and not isinstance(value, bool):
data = str(value).encode("ascii")
encoding_id = self.INT
elif isinstance(value, str):
data = str(value).encode()
encoding_id = self.STR
else:
data = pickle.dumps(value, protocol=self._pickle_protocol)
encoding_id = self.PICKLE

if not key.disable_compression and self._should_compress(key, data):
data, compression_flag = self._compress(key, data)
encoding_id |= compression_flag
return EncodedValue(data=data, encoding_id=encoding_id)

def unserialize(self, data: Blob, encoding_id: int) -> Any:
if encoding_id & self.ZLIB_COMPRESSED:
data = zlib.decompress(data)
encoding_id ^= self.ZLIB_COMPRESSED
elif encoding_id & self.ZSTD_COMPRESSED:
data = self._decompress(data)
encoding_id ^= self.ZSTD_COMPRESSED

if encoding_id == self.STR:
return bytes(data).decode()
elif encoding_id in (self.INT, self.LONG):
return int(data)
elif encoding_id == self.BINARY:
return bytes(data)
else:
return pickle.loads(data) # noqa: S301
14 changes: 7 additions & 7 deletions tests/commands_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def test_get_cmd(memcache_socket: MemcacheSocket, cache_client: CacheClient) ->
memcache_socket.sendall.reset_mock()

cache_client.get(
key=Key("úníçod⍷", is_unicode=True),
key=Key("úníçod⍷"),
touch_ttl=300,
recache_policy=RecachePolicy(),
)
Expand Down Expand Up @@ -614,7 +614,7 @@ def test_get_miss(memcache_socket: MemcacheSocket, cache_client: CacheClient) ->
def test_get_value(memcache_socket: MemcacheSocket, cache_client: CacheClient) -> None:
expected_cas_token = 123
expected_value = Foo("hello world")
encoded_value = MixedSerializer().serialize(expected_value)
encoded_value = MixedSerializer().serialize(Key("foo"), expected_value)
memcache_socket.get_response.return_value = Value(
size=len(encoded_value.data),
value=None,
Expand Down Expand Up @@ -659,7 +659,7 @@ def test_value_wrong_type(
) -> None:
expected_cas_token = 123
expected_value = Foo("hello world")
encoded_value = MixedSerializer().serialize(expected_value)
encoded_value = MixedSerializer().serialize(Key("foo"), expected_value)
memcache_socket.get_response.return_value = Value(
size=len(encoded_value.data),
value=None,
Expand Down Expand Up @@ -722,7 +722,7 @@ def test_recache_win_returns_miss(
) -> None:
expected_cas_token = 123
expected_value = Foo("hello world")
encoded_value = MixedSerializer().serialize(expected_value)
encoded_value = MixedSerializer().serialize(Key("foo"), expected_value)
memcache_socket.get_response.return_value = Value(
size=len(encoded_value.data),
value=None,
Expand All @@ -745,7 +745,7 @@ def test_recache_lost_returns_stale_value(
) -> None:
expected_cas_token = 123
expected_value = Foo("hello world")
encoded_value = MixedSerializer().serialize(expected_value)
encoded_value = MixedSerializer().serialize(Key("foo"), expected_value)
memcache_socket.get_response.return_value = Value(
size=len(encoded_value.data),
value=None,
Expand All @@ -768,7 +768,7 @@ def test_get_or_lease_hit(
) -> None:
expected_cas_token = 123
expected_value = Foo("hello world")
encoded_value = MixedSerializer().serialize(expected_value)
encoded_value = MixedSerializer().serialize(Key("foo"), expected_value)
memcache_socket.get_response.return_value = Value(
size=len(encoded_value.data),
value=None,
Expand Down Expand Up @@ -822,7 +822,7 @@ def test_get_or_lease_miss_lost_then_data(
) -> None:
expected_cas_token = 123
expected_value = Foo("hello world")
encoded_value = MixedSerializer().serialize(expected_value)
encoded_value = MixedSerializer().serialize(Key("foo"), expected_value)
memcache_socket.get_response.side_effect = [
Value(
size=0,
Expand Down
Loading

0 comments on commit 7e96f8e

Please sign in to comment.