Skip to content

Commit

Permalink
Faster default coder for unknown windows. (apache#33382)
Browse files Browse the repository at this point in the history
This will get used in a windowed reshuffle, among other places.
  • Loading branch information
robertwb authored Dec 17, 2024
1 parent d34409e commit f7a7bdd
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 1 deletion.
12 changes: 12 additions & 0 deletions sdks/python/apache_beam/coders/coder_impl.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ cdef libc.stdint.int64_t MIN_TIMESTAMP_micros
cdef libc.stdint.int64_t MAX_TIMESTAMP_micros


cdef class _OrderedUnionCoderImpl(StreamCoderImpl):
cdef tuple _types
cdef tuple _coder_impls
cdef CoderImpl _fallback_coder_impl

@cython.locals(ix=int, c=CoderImpl)
cpdef encode_to_stream(self, value, OutputStream stream, bint nested)

@cython.locals(ix=int, c=CoderImpl)
cpdef decode_from_stream(self, InputStream stream, bint nested)


cdef class WindowedValueCoderImpl(StreamCoderImpl):
"""A coder for windowed values."""
cdef CoderImpl _value_coder
Expand Down
31 changes: 31 additions & 0 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,37 @@ def estimate_size(self, value, nested=False):
return size


class _OrderedUnionCoderImpl(StreamCoderImpl):
def __init__(self, coder_impl_types, fallback_coder_impl):
assert len(coder_impl_types) < 128
self._types, self._coder_impls = zip(*coder_impl_types)
self._fallback_coder_impl = fallback_coder_impl

def encode_to_stream(self, value, out, nested):
value_t = type(value)
for (ix, t) in enumerate(self._types):
if value_t is t:
out.write_byte(ix)
c = self._coder_impls[ix] # for typing
c.encode_to_stream(value, out, nested)
break
else:
if self._fallback_coder_impl is None:
raise ValueError("No fallback.")
out.write_byte(0xFF)
self._fallback_coder_impl.encode_to_stream(value, out, nested)

def decode_from_stream(self, in_stream, nested):
ix = in_stream.read_byte()
if ix == 0xFF:
if self._fallback_coder_impl is None:
raise ValueError("No fallback.")
return self._fallback_coder_impl.decode_from_stream(in_stream, nested)
else:
c = self._coder_impls[ix] # for typing
return c.decode_from_stream(in_stream, nested)


class WindowedValueCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees.
Expand Down
38 changes: 37 additions & 1 deletion sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,12 +1350,48 @@ def __hash__(self):
common_urns.coders.INTERVAL_WINDOW.urn, IntervalWindowCoder)


class _OrderedUnionCoder(FastCoder):
def __init__(
self, *coder_types: Tuple[type, Coder], fallback_coder: Optional[Coder]):
self._coder_types = coder_types
self._fallback_coder = fallback_coder

def _create_impl(self):
return coder_impl._OrderedUnionCoderImpl(
[(t, c.get_impl()) for t, c in self._coder_types],
fallback_coder_impl=self._fallback_coder.get_impl()
if self._fallback_coder else None)

def is_deterministic(self) -> bool:
return (
all(c.is_deterministic for _, c in self._coder_types) and (
self._fallback_coder is None or
self._fallback_coder.is_deterministic()))

def to_type_hint(self):
return Any

def __eq__(self, other):
return (
type(self) == type(other) and
self._coder_types == other._coder_types and
self._fallback_coder == other._fallback_coder)

def __hash__(self):
return hash((type(self), tuple(self._coder_types), self._fallback_coder))


class WindowedValueCoder(FastCoder):
"""Coder for windowed values."""
def __init__(self, wrapped_value_coder, window_coder=None):
# type: (Coder, Optional[Coder]) -> None
if not window_coder:
window_coder = PickleCoder()
# Avoid circular imports.
from apache_beam.transforms import window
window_coder = _OrderedUnionCoder(
(window.GlobalWindow, GlobalWindowCoder()),
(window.IntervalWindow, IntervalWindowCoder()),
fallback_coder=PickleCoder())
self.wrapped_value_coder = wrapped_value_coder
self.timestamp_coder = TimestampCoder()
self.window_coder = window_coder
Expand Down
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/coders/coders_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,14 @@ def test_decimal_coder(self):
test_encodings[idx],
base64.b64encode(test_coder.encode(value)).decode().rstrip("="))

def test_OrderedUnionCoder(self):
test_coder = coders._OrderedUnionCoder((str, coders.StrUtf8Coder()),
(int, coders.VarIntCoder()),
fallback_coder=coders.FloatCoder())
self.check_coder(test_coder, 's')
self.check_coder(test_coder, 123)
self.check_coder(test_coder, 1.5)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down

0 comments on commit f7a7bdd

Please sign in to comment.