Skip to content

Commit

Permalink
Refactor response flag parsing
Browse files Browse the repository at this point in the history
## Motivation / Description
We spent a lot of time tokenizing and then parsing the response flags.
This change parses the flags in a single loop, which speeds things up.
It also rewrites the response flags so they are no longer Flags/IntFlags,
as they can't be used for requests. Having them as attributes of the
Value dataclass speeds things up and also simplifies the usage of the
library.

## Changes introduced
- Avoid header tokenization phase
- Move flags to Success/Value class as attributes
- Move flag parsing to those classes.
  • Loading branch information
bisho committed Nov 6, 2023
1 parent ded6e9b commit 9027621
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 214 deletions.
16 changes: 7 additions & 9 deletions src/meta_memcache/commands/high_level_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,24 +289,23 @@ def get_or_lease_cas(

if isinstance(result, Value):
# It is a hit.
cas_token = result.int_flags.get(IntFlag.RETURNED_CAS_TOKEN)
if Flag.WIN in result.flags:
if result.win:
# Win flag present, meaning we got the lease to
# recache/cache the item. We need to mimic a miss.
return None, cas_token
if result.size == 0 and Flag.LOST in result.flags:
return None, result.cas_token
if result.size == 0 and result.win is False:
# The value is empty, this is a miss lease,
# and we lost, so we must keep retrying and
# wait for the winner to populate the value.
if i < lease_policy.miss_retries:
continue
else:
# We run out of retries, behave as a miss
return None, cas_token
return None, result.cas_token
else:
# There is data, either there is no lease or
# we lost and should use the stale value.
return result.value, cas_token
return result.value, result.cas_token
else:
# With MISS_LEASE_TTL we should always get a value
# because on miss a lease empty value is generated
Expand Down Expand Up @@ -381,8 +380,7 @@ def get_cas(
if result is None:
return None, None
else:
cas_token = result.int_flags.get(IntFlag.RETURNED_CAS_TOKEN)
return result.value, cas_token
return result.value, result.cas_token

def _get(
self: HighLevelCommandMixinWithMetaCommands,
Expand Down Expand Up @@ -418,7 +416,7 @@ def _process_get_result(
) -> Optional[Value]:
if isinstance(result, Value):
# It is a hit
if Flag.WIN in result.flags:
if result.win:
# Win flag present, meaning we got the lease to
# recache the item. We need to mimic a miss, so
# we set the value to None.
Expand Down
106 changes: 28 additions & 78 deletions src/meta_memcache/connection/memcache_socket.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
import logging
import socket
from typing import Iterable, List, Union
from typing import Union

from meta_memcache.errors import MemcacheError
from meta_memcache.protocol import (
ENDL,
ENDL_LEN,
NOOP,
SPACE,
Conflict,
Miss,
NotStored,
ServerVersion,
Success,
Value,
flag_values,
get_store_success_response_header,
int_flags_values,
token_flags_values,
)

_log: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -119,73 +115,34 @@ def _reset_buffer(self) -> None:
Reset buffer moving remaining bytes in it (if any)
"""
remaining_data = self._read - self._pos
if remaining_data > 0:
if self._pos <= self._reset_buffer_size:
# Avoid moving memory if buffer still has
# spare capacity for new responses. If the
# whole buffer us used, we will just reset
# the pointers and save a lot of memory
# data copies
return
self._buf_view[0:remaining_data] = self._buf_view[self._pos : self._read]
self._buf_view[0:remaining_data] = self._buf_view[self._pos : self._read]
self._pos = 0
self._read = remaining_data

def _recv_header(self) -> memoryview:
endl_pos = self._buf.find(ENDL, self._pos, self._read)
while endl_pos < 0 and self._read < self._buffer_size:
def _get_single_header(self) -> memoryview:
# Reset buffer for new data
if self._read == self._pos:
self._read = 0
self._pos = 0
elif self._pos > self._reset_buffer_size:
self._reset_buffer()

endl_pos = -1
while True:
if self._read - self._pos > ENDL_LEN:
endl_pos = self._buf.find(ENDL, self._pos, self._read)
if endl_pos >= 0:
break
# Missing data, but still space in buffer, so read more
if self._recv_info_buffer() <= 0:
break
endl_pos = self._buf.find(ENDL, self._pos, self._read)

if endl_pos < 0:
raise MemcacheError("Bad response. Socket might have closed unexpectedly")

header = self._buf_view[self._pos : endl_pos]
pos = self._pos
self._pos = endl_pos + ENDL_LEN
return header

def _add_flags(self, success: Success, chunks: Iterable[memoryview]) -> None:
"""
Each flag starts with one byte for the flag, and and optional int/byte
value depending on the flag.
"""
for chunk in chunks:
flag = chunk[0]
if len(chunk) == 1:
# Flag without value
if f := flag_values.get(flag):
success.flags.add(f)
else:
_log.warning(f"Unrecognized flag {bytes(chunk)!r}")
else:
# Value flag
if int_flag := int_flags_values.get(flag):
success.int_flags[int_flag] = int(chunk[1:])
elif token_flag := token_flags_values.get(flag):
success.token_flags[token_flag] = bytes(chunk[1:])
else:
_log.warning(f"Unrecognized flag {bytes(chunk)!r}")

def _tokenize_header(self, header: memoryview) -> List[memoryview]:
"""
Slice header by spaces into memoryview chunks
"""
chunks = []
prev, i = 0, -1
for i, v in enumerate(header):
if v == SPACE:
if i > prev:
chunks.append(header[prev:i])
prev = i + 1
if prev <= i:
chunks.append(header[prev:])
return chunks

def _get_single_header(self) -> List[memoryview]:
self._reset_buffer()
return self._tokenize_header(self._recv_header())
return self._buf_view[pos:endl_pos]

def sendall(self, data: bytes, with_noop: bool = False) -> None:
if with_noop:
Expand All @@ -195,11 +152,11 @@ def sendall(self, data: bytes, with_noop: bool = False) -> None:

def _read_until_noop_header(self) -> None:
while self._noop_expected > 0:
response_code, *_chunks = self._get_single_header()
if response_code == b"MN":
header = self._get_single_header()
if header[0:2] == b"MN":
self._noop_expected -= 1

def _get_header(self) -> List[memoryview]:
def _get_header(self) -> memoryview:
try:
if self._noop_expected > 0:
self._read_until_noop_header()
Expand All @@ -212,37 +169,30 @@ def _get_header(self) -> List[memoryview]:
def get_response(
self,
) -> Union[Value, Success, NotStored, Conflict, Miss]:
header = self._get_header()
header = self._get_header().tobytes()
response_code = header[0:2]
result: Union[Value, Success, NotStored, Conflict, Miss]
try:
response_code, *chunks = header
if response_code == b"VA":
# Value response, parse size and flags
value_size = int(chunks.pop(0))
result = Value(value_size)
self._add_flags(result, chunks)
# Value response
result = Value.from_header(header)
elif response_code == self._store_success_response_header:
# Stored or no value, return Success
result = Success()
self._add_flags(result, chunks)
result = Success.from_header(header)
elif response_code == b"NS":
# Not stored
result = NOT_STORED
assert len(chunks) == 0 # noqa: S101
elif response_code == b"EX":
# Already exists, not changed, CAS conflict
result = CONFLICT
assert len(chunks) == 0 # noqa: S101
elif response_code == b"EN" or response_code == b"NF":
# Not Found, Miss.
result = MISS
assert len(chunks) == 0 # noqa: S101
else:
raise MemcacheError(f"Unknown response: {bytes(response_code)!r}")
except Exception as e:
response = b" ".join(header).decode()
_log.warning(f"Error parsing response header in {self}: {response}")
raise MemcacheError(f"Error parsing response header {response}") from e
_log.warning(f"Error parsing response header in {self}: {header!r}")
raise MemcacheError(f"Error parsing response header {header!r}") from e

return result

Expand Down
4 changes: 2 additions & 2 deletions src/meta_memcache/executors/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,12 @@ def _conn_recv_response(
Read response on a connection
"""
if flags and Flag.NOREPLY in flags:
return Success(flags=set([Flag.NOREPLY]))
return Success()
result = conn.get_response()
if isinstance(result, Value):
data = conn.get_value(result.size)
if result.size > 0:
encoding_id = result.int_flags.get(IntFlag.CLIENT_FLAG, 0)
encoding_id = result.client_flag or 0
try:
result.value = self._serializer.unserialize(data, encoding_id)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion src/meta_memcache/extras/migrating_cache_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_migration_mode(self) -> MigrationMode:
return current_mode

def _get_value_ttl(self, value: Value) -> int:
ttl = value.int_flags.get(IntFlag.TTL, self._default_read_backfill_ttl)
ttl = value.ttl if value.ttl is not None else self._default_read_backfill_ttl
if ttl < 0:
# TTL for items marked to be stored forever is returned as -1
ttl = 0
Expand Down
6 changes: 3 additions & 3 deletions src/meta_memcache/extras/probabilistic_hot_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from meta_memcache.extras.client_wrapper import ClientWrapper
from meta_memcache.interfaces.cache_api import CacheApi
from meta_memcache.metrics.base import BaseMetricsCollector, MetricDefinition
from meta_memcache.protocol import IntFlag, Key, Value
from meta_memcache.protocol import Key, Value


@dataclass
Expand Down Expand Up @@ -124,8 +124,8 @@ def _store_in_hot_cache_if_necessary(
allowed: bool,
) -> None:
if not is_hot:
hit_after_write = value.int_flags.get(IntFlag.HIT_AFTER_WRITE, 0)
last_read_age = value.int_flags.get(IntFlag.LAST_READ_AGE, 9999)
hit_after_write = value.fetched or 0
last_read_age = value.last_access if value.last_access is not None else 9999
if (
hit_after_write > 0
and last_read_age <= self._max_last_access_age_seconds
Expand Down
Loading

0 comments on commit 9027621

Please sign in to comment.