diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py index bcfa965c0469..c5423e167026 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py @@ -48,7 +48,9 @@ from typing import overload import grpc +from sortedcontainers import SortedSet +from apache_beam import coders from apache_beam.io import filesystems from apache_beam.io.filesystems import CompressionTypes from apache_beam.portability import common_urns @@ -959,7 +961,8 @@ class StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer, 'multimap_keys_values_side_input', 'iterable_side_input', 'bag_user_state', - 'multimap_user_state' + 'multimap_user_state', + 'ordered_list_user_state' ]) class CopyOnWriteState(object): @@ -1021,6 +1024,8 @@ def __init__(self): self._checkpoint = None # type: Optional[StateServicer.StateType] self._use_continuation_tokens = False self._continuations = {} # type: Dict[bytes, Tuple[bytes, ...]] + self._ordered_list_keys = collections.defaultdict( + SortedSet) # type: DefaultDict[bytes, SortedSet] def checkpoint(self): # type: () -> None @@ -1050,6 +1055,14 @@ def process_instruction_id(self, unused_instruction_id): # type: (Any) -> Iterator yield + def _get_one_interval_key(self, state_key, start): + # type: (beam_fn_api_pb2.StateKey, int) -> bytes + state_key_copy = beam_fn_api_pb2.StateKey() + state_key_copy.CopyFrom(state_key) + state_key_copy.ordered_list_user_state.range.start = start + state_key_copy.ordered_list_user_state.range.end = start + 1 + return self._to_key(state_key_copy) + def get_raw(self, state_key, # type: beam_fn_api_pb2.StateKey continuation_token=None # type: Optional[bytes] @@ -1061,7 +1074,30 @@ def get_raw(self, 'Unknown state type: ' + state_key.WhichOneof('type')) with self._lock: - full_state = self._state[self._to_key(state_key)] + if not continuation_token: + # Compute full_state only when no continuation token is provided. + # If there is continuation token, full_state is already in + # continuation cache. No need to recompute. + full_state = [] # type: List[bytes] + if state_key.WhichOneof('type') == 'ordered_list_user_state': + maybe_start = state_key.ordered_list_user_state.range.start + maybe_end = state_key.ordered_list_user_state.range.end + persistent_state_key = beam_fn_api_pb2.StateKey() + persistent_state_key.CopyFrom(state_key) + persistent_state_key.ordered_list_user_state.ClearField("range") + + available_keys = self._ordered_list_keys[self._to_key( + persistent_state_key)] + + for i in available_keys.irange(maybe_start, + maybe_end, + inclusive=(True, False)): + entries = self._state[self._get_one_interval_key( + persistent_state_key, i)] + full_state.extend(entries) + else: + full_state.extend(self._state[self._to_key(state_key)]) + if self._use_continuation_tokens: # The token is "nonce:index". if not continuation_token: @@ -1087,14 +1123,40 @@ def append_raw( ): # type: (...) 
-> _Future with self._lock: - self._state[self._to_key(state_key)].append(data) + if state_key.WhichOneof('type') == 'ordered_list_user_state': + coder = coders.TupleCoder([ + coders.VarIntCoder(), + coders.coders.LengthPrefixCoder(coders.BytesCoder()) + ]).get_impl() + + for key, value in coder.decode_all(data): + self._state[self._get_one_interval_key(state_key, key)].append( + coder.encode((key, value))) + self._ordered_list_keys[self._to_key(state_key)].add(key) + else: + self._state[self._to_key(state_key)].append(data) return _Future.done() def clear(self, state_key): # type: (beam_fn_api_pb2.StateKey) -> _Future with self._lock: try: - del self._state[self._to_key(state_key)] + if state_key.WhichOneof('type') == 'ordered_list_user_state': + start = state_key.ordered_list_user_state.range.start + end = state_key.ordered_list_user_state.range.end + persistent_state_key = beam_fn_api_pb2.StateKey() + persistent_state_key.CopyFrom(state_key) + persistent_state_key.ordered_list_user_state.ClearField("range") + available_keys = self._ordered_list_keys[self._to_key( + persistent_state_key)] + + for i in list(available_keys.irange(start, + end, + inclusive=(True, False))): + del self._state[self._get_one_interval_key(persistent_state_key, i)] + available_keys.remove(i) + else: + del self._state[self._to_key(state_key)] except KeyError: # This may happen with the caching layer across bundles. Caching may # skip this storage layer for a blocking_get(key) request. Without diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index fdb13a03bb94..0f1700f52486 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -19,16 +19,21 @@ # pytype: skip-file +from __future__ import annotations + import base64 import bisect import collections import copy +import heapq +import itertools import json import logging import random import threading from dataclasses import dataclass from dataclasses import field +from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -50,6 +55,8 @@ from google.protobuf import duration_pb2 from google.protobuf import timestamp_pb2 +from sortedcontainers import SortedDict +from sortedcontainers import SortedList import apache_beam as beam from apache_beam import coders @@ -104,7 +111,8 @@ FnApiUserRuntimeStateTypes = Union['ReadModifyWriteRuntimeState', 'CombiningValueRuntimeState', 'SynchronousSetRuntimeState', - 'SynchronousBagRuntimeState'] + 'SynchronousBagRuntimeState', + 'SynchronousOrderedListRuntimeState'] DATA_INPUT_URN = 'beam:runner:source:v1' DATA_OUTPUT_URN = 'beam:runner:sink:v1' @@ -704,6 +712,180 @@ def commit(self): to_await.get() +class RangeSet: + """For Internal Use only. A simple range set for ranges of [x,y).""" + def __init__(self) -> None: + # The start points and end points are stored separately in order. 
+ self._sorted_starts = SortedList() + self._sorted_ends = SortedList() + + def add(self, start: int, end: int) -> None: + if start >= end: + return + + # ranges[:min_idx] and ranges[max_idx:] is unaffected by this insertion + # the first range whose end point >= the start of the new range + min_idx = self._sorted_ends.bisect_left(start) + # the first range whose start point > the end point of the new range + max_idx = self._sorted_starts.bisect_right(end) + + if min_idx >= len(self._sorted_starts) or max_idx <= 0: + # the new range is beyond any current ranges + new_start = start + new_end = end + else: + # the new range overlaps with ranges[min_idx:max_idx] + new_start = min(start, self._sorted_starts[min_idx]) + new_end = max(end, self._sorted_ends[max_idx - 1]) + + del self._sorted_starts[min_idx:max_idx] + del self._sorted_ends[min_idx:max_idx] + + self._sorted_starts.add(new_start) + self._sorted_ends.add(new_end) + + def __contains__(self, key: int) -> bool: + idx = self._sorted_starts.bisect_left(key) + return (idx < len(self._sorted_starts) and self._sorted_starts[idx] == key + ) or (idx > 0 and self._sorted_ends[idx - 1] > key) + + def __len__(self) -> int: + assert len(self._sorted_starts) == len(self._sorted_ends) + return len(self._sorted_starts) + + def __iter__(self) -> Iterator[Tuple[int, int]]: + return zip(self._sorted_starts, self._sorted_ends) + + def __str__(self) -> str: + return str(list(zip(self._sorted_starts, self._sorted_ends))) + + +class SynchronousOrderedListRuntimeState(userstate.OrderedListRuntimeState): + RANGE_MIN = -(1 << 63) + RANGE_MAX = (1 << 63) - 1 + TIMESTAMP_RANGE_MIN = timestamp.Timestamp(micros=RANGE_MIN) + TIMESTAMP_RANGE_MAX = timestamp.Timestamp(micros=RANGE_MAX) + + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: + self._state_handler = state_handler + self._state_key = state_key + self._elem_coder = beam.coders.TupleCoder( + [coders.VarIntCoder(), coders.coders.LengthPrefixCoder(value_coder)]) + self._cleared = False + self._pending_adds = SortedDict() + self._pending_removes = RangeSet() + + def add(self, elem: Tuple[timestamp.Timestamp, Any]) -> None: + assert len(elem) == 2 + key_ts, value = elem + key = key_ts.micros + + if key >= self.RANGE_MAX or key < self.RANGE_MIN: + raise ValueError("key value %d is out of range" % key) + self._pending_adds.setdefault(key, []).append(value) + + def read(self) -> Iterable[Tuple[timestamp.Timestamp, Any]]: + return self.read_range(self.TIMESTAMP_RANGE_MIN, self.TIMESTAMP_RANGE_MAX) + + def read_range( + self, + min_timestamp: timestamp.Timestamp, + limit_timestamp: timestamp.Timestamp + ) -> Iterable[Tuple[timestamp.Timestamp, Any]]: + # convert timestamp to int, as sort keys are stored as int internally. + min_key = min_timestamp.micros + limit_key = limit_timestamp.micros + + keys_to_add = self._pending_adds.irange( + min_key, limit_key, inclusive=(True, False)) + + # use list interpretation here to construct the actual list + # of iterators of the selected range. 
+ local_items = chain.from_iterable([ + itertools.islice( + zip(itertools.cycle([ + k, + ]), self._pending_adds[k]), + len(self._pending_adds[k])) for k in keys_to_add + ]) + + if not self._cleared: + range_query_state_key = beam_fn_api_pb2.StateKey() + range_query_state_key.CopyFrom(self._state_key) + range_query_state_key.ordered_list_user_state.range.start = min_key + range_query_state_key.ordered_list_user_state.range.end = limit_key + + # make a deep copy here because there could be other operations occur in + # the middle of an iteration and change pending_removes + pending_removes_snapshot = copy.deepcopy(self._pending_removes) + persistent_items = filter( + lambda kv: kv[0] not in pending_removes_snapshot, + _StateBackedIterable( + self._state_handler, range_query_state_key, self._elem_coder)) + + return map( + lambda x: (timestamp.Timestamp(micros=x[0]), x[1]), + heapq.merge(persistent_items, local_items)) + + return map(lambda x: (timestamp.Timestamp(micros=x[0]), x[1]), local_items) + + def clear(self) -> None: + self._cleared = True + self._pending_adds = SortedDict() + self._pending_removes = RangeSet() + self._pending_removes.add(self.RANGE_MIN, self.RANGE_MAX) + + def clear_range( + self, + min_timestamp: timestamp.Timestamp, + limit_timestamp: timestamp.Timestamp) -> None: + min_key = min_timestamp.micros + limit_key = limit_timestamp.micros + + # materialize the keys to remove before the actual removal + keys_to_remove = list( + self._pending_adds.irange(min_key, limit_key, inclusive=(True, False))) + for k in keys_to_remove: + del self._pending_adds[k] + + if not self._cleared: + self._pending_removes.add(min_key, limit_key) + + def commit(self) -> None: + futures = [] + if self._pending_removes: + for start, end in self._pending_removes: + range_query_state_key = beam_fn_api_pb2.StateKey() + range_query_state_key.CopyFrom(self._state_key) + range_query_state_key.ordered_list_user_state.range.start = start + range_query_state_key.ordered_list_user_state.range.end = end + futures.append(self._state_handler.clear(range_query_state_key)) + + self._pending_removes = RangeSet() + + if self._pending_adds: + items_to_add = [] + for k in self._pending_adds: + items_to_add.extend(zip(itertools.cycle([ + k, + ]), self._pending_adds[k])) + futures.append( + self._state_handler.extend( + self._state_key, self._elem_coder.get_impl(), items_to_add)) + self._pending_adds = SortedDict() + + if len(futures): + # To commit, we need to wait on every state request futures to complete. + for to_await in futures: + to_await.get() + + self._cleared = False + + class OutputTimer(userstate.BaseTimer): def __init__(self, key, @@ -850,6 +1032,17 @@ def _create_state(self, # State keys are expected in nested encoding format key=self._key_coder.encode_nested(key))), value_coder=state_spec.coder) + elif isinstance(state_spec, userstate.OrderedListStateSpec): + return SynchronousOrderedListRuntimeState( + self._state_handler, + state_key=beam_fn_api_pb2.StateKey( + ordered_list_user_state=beam_fn_api_pb2.StateKey. 
+ OrderedListUserState( + transform_id=self._transform_id, + user_state_id=state_spec.name, + window=self._window_coder.encode(window), + key=self._key_coder.encode_nested(key))), + value_coder=state_spec.coder) else: raise NotImplementedError(state_spec) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py index dafb4dbd4bf0..0eb4dd9485fd 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py @@ -18,24 +18,31 @@ """Unit tests for bundle processing.""" # pytype: skip-file +import random import unittest import apache_beam as beam +from apache_beam.coders import StrUtf8Coder from apache_beam.coders.coders import FastPrimitivesCoder from apache_beam.portability import common_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners import common +from apache_beam.runners.portability.fn_api_runner.worker_handlers import StateServicer from apache_beam.runners.worker import bundle_processor from apache_beam.runners.worker import operations from apache_beam.runners.worker.bundle_processor import BeamTransformFactory from apache_beam.runners.worker.bundle_processor import BundleProcessor from apache_beam.runners.worker.bundle_processor import DataInputOperation from apache_beam.runners.worker.bundle_processor import FnApiUserStateContext +from apache_beam.runners.worker.bundle_processor import SynchronousOrderedListRuntimeState from apache_beam.runners.worker.bundle_processor import TimerInfo from apache_beam.runners.worker.data_plane import SizeBasedBufferingClosableOutputStream from apache_beam.runners.worker.data_sampler import DataSampler +from apache_beam.runners.worker.sdk_worker import GlobalCachingStateHandler +from apache_beam.runners.worker.statecache import StateCache from apache_beam.transforms import userstate from apache_beam.transforms.window import GlobalWindow +from apache_beam.utils import timestamp from apache_beam.utils.windowed_value import WindowedValue @@ -422,5 +429,312 @@ def test_user_modified_sdks_need_to_be_installed_in_runtime_env(self): "beam:version:sdk_base:apache/beam_python3.5_sdk:2.1.0-custom")) +class OrderedListStateTest(unittest.TestCase): + class NoStateCache(StateCache): + def __init__(self): + super().__init__(max_weight=0) + + @staticmethod + def _create_state(window=b"my_window", key=b"my_key", coder=StrUtf8Coder()): + state_handler = GlobalCachingStateHandler( + OrderedListStateTest.NoStateCache(), StateServicer()) + state_key = beam_fn_api_pb2.StateKey( + ordered_list_user_state=beam_fn_api_pb2.StateKey.OrderedListUserState( + window=window, key=key)) + return SynchronousOrderedListRuntimeState(state_handler, state_key, coder) + + def setUp(self): + self.state = self._create_state() + + def test_read_range(self): + T0 = timestamp.Timestamp.of(0) + T1 = timestamp.Timestamp.of(1) + T2 = timestamp.Timestamp.of(2) + T3 = timestamp.Timestamp.of(3) + T4 = timestamp.Timestamp.of(4) + T5 = timestamp.Timestamp.of(5) + T9 = timestamp.Timestamp.of(9) + A1, B1, A4 = [(T1, "a1"), (T1, "b1"), (T4, "a4")] + self.assertEqual([], list(self.state.read_range(T0, T5))) + + self.state.add(A1) + self.assertEqual([A1], list(self.state.read_range(T0, T5))) + + self.state.add(B1) + self.assertEqual([A1, B1], list(self.state.read_range(T0, T5))) + + self.state.add(A4) + self.assertEqual([A1, B1, A4], list(self.state.read_range(T0, T5))) + + self.assertEqual([], 
list(self.state.read_range(T0, T1))) + self.assertEqual([], list(self.state.read_range(T5, T9))) + self.assertEqual([A1, B1], list(self.state.read_range(T1, T2))) + self.assertEqual([], list(self.state.read_range(T2, T3))) + self.assertEqual([], list(self.state.read_range(T2, T4))) + self.assertEqual([A4], list(self.state.read_range(T4, T5))) + + def test_read(self): + T1 = timestamp.Timestamp.of(1) + T4 = timestamp.Timestamp.of(4) + A1, B1, A4 = [(T1, "a1"), (T1, "b1"), (T4, "a4")] + self.assertEqual([], list(self.state.read())) + + self.state.add(A1) + self.assertEqual([A1], list(self.state.read())) + + self.state.add(A1) + self.assertEqual([A1, A1], list(self.state.read())) + + self.state.add(B1) + self.assertEqual([A1, A1, B1], list(self.state.read())) + + self.state.add(A4) + self.assertEqual([A1, A1, B1, A4], list(self.state.read())) + + def test_clear_range(self): + T0 = timestamp.Timestamp.of(0) + T1 = timestamp.Timestamp.of(1) + T2 = timestamp.Timestamp.of(2) + T3 = timestamp.Timestamp.of(3) + T4 = timestamp.Timestamp.of(4) + T5 = timestamp.Timestamp.of(5) + A1, B1, A4, A5 = [(T1, "a1"), (T1, "b1"), (T4, "a4"), (T5, "a5")] + self.state.clear_range(T0, T1) + self.assertEqual([], list(self.state.read())) + + self.state.add(A1) + self.state.add(B1) + self.state.add(A4) + self.state.add(A5) + self.assertEqual([A1, B1, A4, A5], list(self.state.read())) + + self.state.clear_range(T0, T1) + self.assertEqual([A1, B1, A4, A5], list(self.state.read())) + + self.state.clear_range(T1, T2) + self.assertEqual([A4, A5], list(self.state.read())) + + # no side effect on clearing the same range twice + self.state.clear_range(T1, T2) + self.assertEqual([A4, A5], list(self.state.read())) + + self.state.clear_range(T3, T4) + self.assertEqual([A4, A5], list(self.state.read())) + + self.state.clear_range(T3, T5) + self.assertEqual([A5], list(self.state.read())) + + def test_add_and_clear_range_after_commit(self): + T1 = timestamp.Timestamp.of(1) + T4 = timestamp.Timestamp.of(4) + T5 = timestamp.Timestamp.of(5) + T6 = timestamp.Timestamp.of(6) + A1, B1, C1, A4, A5, A6 = [(T1, "a1"), (T1, "b1"), (T1, "c1"), + (T4, "a4"), (T5, "a5"), (T6, "a6")] + self.state.add(A1) + self.state.add(B1) + self.state.add(A4) + self.state.add(A5) + self.state.clear_range(T4, T5) + self.assertEqual([A1, B1, A5], list(self.state.read())) + + self.state.commit() + self.assertEqual(len(self.state._pending_adds), 0) + self.assertEqual(len(self.state._pending_removes), 0) + self.assertEqual([A1, B1, A5], list(self.state.read())) + + self.state.add(C1) + self.state.add(A6) + self.assertEqual([A1, B1, C1, A5, A6], list(self.state.read())) + + self.state.clear_range(T5, T6) + self.assertEqual([A1, B1, C1, A6], list(self.state.read())) + + self.state.commit() + self.assertEqual(len(self.state._pending_adds), 0) + self.assertEqual(len(self.state._pending_removes), 0) + self.assertEqual([A1, B1, C1, A6], list(self.state.read())) + + def test_clear(self): + T1 = timestamp.Timestamp.of(1) + T4 = timestamp.Timestamp.of(4) + T5 = timestamp.Timestamp.of(5) + T9 = timestamp.Timestamp.of(9) + A1, B1, C1, A4, A5, B5 = [(T1, "a1"), (T1, "b1"), (T1, "c1"), + (T4, "a4"), (T5, "a5"), (T5, "b5")] + self.state.add(A1) + self.state.add(B1) + self.state.add(A4) + self.state.add(A5) + self.state.clear_range(T4, T5) + self.assertEqual([A1, B1, A5], list(self.state.read())) + self.state.commit() + + self.state.add(C1) + self.state.clear_range(T5, T9) + self.assertEqual([A1, B1, C1], list(self.state.read())) + self.state.clear() + 
self.assertEqual(len(self.state._pending_adds), 0) + self.assertEqual(len(self.state._pending_removes), 1) + + self.state.add(B5) + self.assertEqual([B5], list(self.state.read())) + self.state.commit() + + self.assertEqual(len(self.state._pending_adds), 0) + self.assertEqual(len(self.state._pending_removes), 0) + + self.assertEqual([B5], list(self.state.read())) + + def test_multiple_iterators(self): + T1 = timestamp.Timestamp.of(1) + T3 = timestamp.Timestamp.of(3) + T9 = timestamp.Timestamp.of(9) + A1, B1, A3, B3 = [(T1, "a1"), (T1, "b1"), (T3, "a3"), (T3, "b3")] + self.state.add(A1) + self.state.add(A3) + self.state.commit() + + iter_before_b1 = iter(self.state.read()) + self.assertEqual(A1, next(iter_before_b1)) + + self.state.add(B1) + self.assertEqual(A3, next(iter_before_b1)) + self.assertRaises(StopIteration, lambda: next(iter_before_b1)) + + self.state.add(B3) + iter_before_clear_range = iter(self.state.read()) + self.assertEqual(A1, next(iter_before_clear_range)) + self.state.clear_range(T3, T9) + self.assertEqual(B1, next(iter_before_clear_range)) + self.assertEqual(A3, next(iter_before_clear_range)) + self.assertEqual(B3, next(iter_before_clear_range)) + self.assertRaises(StopIteration, lambda: next(iter_before_clear_range)) + self.assertEqual([A1, B1], list(self.state.read())) + + iter_before_clear = iter(self.state.read()) + self.assertEqual(A1, next(iter_before_clear)) + self.state.clear() + self.assertEqual(B1, next(iter_before_clear)) + self.assertRaises(StopIteration, lambda: next(iter_before_clear)) + + self.assertEqual([], list(self.state.read())) + + def fuzz_test_helper(self, seed=0, lower=0, upper=20): + class NaiveState: + def __init__(self): + self._data = [[] for i in range((upper - lower + 1))] + self._logs = [] + + def add(self, elem): + k, v = elem + k = k.micros + self._data[k - lower].append(v) + self._logs.append("add(%d, %s)" % (k, v)) + + def clear_range(self, lo, hi): + lo = lo.micros + hi = hi.micros + for i in range(lo, hi): + self._data[i - lower] = [] + self._logs.append("clear_range(%d, %d)" % (lo, hi)) + + def clear(self): + for i in range(len(self._data)): + self._data[i] = [] + self._logs.append("clear()") + + def read(self): + self._logs.append("read()") + for i in range(len(self._data)): + for v in self._data[i]: + yield (timestamp.Timestamp(micros=(i + lower)), v) + + random.seed(seed) + + state = self._create_state() + bench_state = NaiveState() + + steps = random.randint(20, 50) + for i in range(steps): + op = random.randint(1, 100) + if 1 <= op < 70: + num = random.randint(lower, upper) + state.add((timestamp.Timestamp(micros=num), "a%d" % num)) + bench_state.add((timestamp.Timestamp(micros=num), "a%d" % num)) + elif 70 <= op < 95: + num1 = random.randint(lower, upper) + num2 = random.randint(lower, upper) + min_time = timestamp.Timestamp(micros=min(num1, num2)) + max_time = timestamp.Timestamp(micros=max(num1, num2)) + state.clear_range(min_time, max_time) + bench_state.clear_range(min_time, max_time) + elif op >= 95: + state.clear() + bench_state.clear() + + op = random.randint(1, 10) + if 1 <= op <= 9: + pass + else: + state.commit() + + a = list(bench_state.read()) + b = list(state.read()) + self.assertEqual( + a, + b, + "Mismatch occurred on seed=%d, step=%d, logs=%s" % + (seed, i, ';'.join(bench_state._logs))) + + def test_fuzz(self): + for _ in range(1000): + seed = random.randint(0, 0xffffffffffffffff) + try: + self.fuzz_test_helper(seed=seed) + except Exception as e: + raise RuntimeError("Exception occurred on seed=%d: %s" % (seed, 
e)) + + def test_min_max(self): + T_MIN = timestamp.Timestamp(micros=(-(1 << 63))) + T_MAX_MINUS_ONE = timestamp.Timestamp(micros=((1 << 63) - 2)) + T_MAX = timestamp.Timestamp(micros=((1 << 63) - 1)) + T0 = timestamp.Timestamp(micros=0) + INT64_MIN, INT64_MAX_MINUS_ONE, INT64_MAX = [(T_MIN, "min"), + (T_MAX_MINUS_ONE, "max"), + (T_MAX, "err")] + self.state.add(INT64_MIN) + self.state.add(INT64_MAX_MINUS_ONE) + self.assertRaises(ValueError, lambda: self.state.add(INT64_MAX)) + + self.assertEqual([INT64_MIN, INT64_MAX_MINUS_ONE], list(self.state.read())) + self.assertEqual([INT64_MIN], list(self.state.read_range(T_MIN, T0))) + self.assertEqual([INT64_MAX_MINUS_ONE], + list(self.state.read_range(T0, T_MAX))) + + def test_continuation_token(self): + T1 = timestamp.Timestamp.of(1) + T2 = timestamp.Timestamp.of(2) + T7 = timestamp.Timestamp.of(7) + T8 = timestamp.Timestamp.of(8) + A1, A2, A7, B7, A8 = [(T1, "a1"), (T2, "a2"), (T7, "a7"), + (T7, "b7"), (T8, "a8")] + self.state._state_handler._underlying._use_continuation_tokens = True + self.assertEqual([], list(self.state.read_range(T1, T8))) + + self.state.add(A1) + self.state.add(A2) + self.state.add(A7) + self.state.add(B7) + self.state.add(A8) + + self.assertEqual([A2, A7, B7], list(self.state.read_range(T2, T8))) + + self.state.commit() + self.assertEqual([A2, A7, B7], list(self.state.read_range(T2, T8))) + + self.assertEqual([A1, A2, A7, B7, A8], list(self.state.read())) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index ada0b755bd6c..cad733538111 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -150,6 +150,17 @@ def to_runner_api( urn=common_urns.user_state.BAG.urn)) +class OrderedListStateSpec(StateSpec): + """Specification for a user DoFn ordered list state cell.""" + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: + return beam_runner_api_pb2.StateSpec( + ordered_list_spec=beam_runner_api_pb2.OrderedListStateSpec( + element_coder_id=context.coders.get_id(self.coder)), + protocol=beam_runner_api_pb2.FunctionSpec( + urn=common_urns.user_state.ORDERED_LIST.urn)) + + # TODO(BEAM-9562): Update Timer to have of() and clear() APIs. 
Timer = NamedTuple( 'Timer', @@ -372,6 +383,24 @@ class CombiningValueRuntimeState(AccumulatingRuntimeState): """Combining value state interface object passed to user code.""" +class OrderedListRuntimeState(AccumulatingRuntimeState): + """Ordered list state interface object passed to user code.""" + def read(self) -> Iterable[Tuple[Timestamp, Any]]: + raise NotImplementedError(type(self)) + + def add(self, value: Tuple[Timestamp, Any]) -> None: + raise NotImplementedError(type(self)) + + def read_range( + self, min_time_stamp: Timestamp, + limit_time_stamp: Timestamp) -> Iterable[Tuple[Timestamp, Any]]: + raise NotImplementedError(type(self)) + + def clear_range( + self, min_time_stamp: Timestamp, limit_time_stamp: Timestamp) -> None: + raise NotImplementedError(type(self)) + + class UserStateContext(object): """Wrapper allowing user state and timers to be accessed by a DoFnInvoker.""" def get_timer( diff --git a/sdks/python/setup.py b/sdks/python/setup.py index c3189e18d2c8..6eb74e9099c1 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -388,6 +388,7 @@ def get_portability_package_data(): 'redis>=5.0.0,<6', 'regex>=2020.6.8', 'requests>=2.24.0,<3.0.0', + 'sortedcontainers>=2.4.0', 'typing-extensions>=3.7.0', 'zstandard>=0.18.0,<1', # Dynamic dependencies must be specified in a separate list, otherwise
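
Reviewer note: the sketch below shows how a stateful DoFn might exercise the ordered-list state API introduced in this patch. It is illustrative only and not part of the change: the spec name EVENTS, the SortByEventTime DoFn, and the element layout are hypothetical, and OrderedListStateSpec is assumed to inherit the (name, coder) constructor from StateSpec, since it defines no __init__ of its own. Running it end to end additionally assumes a keyed input PCollection and a runner built with this patch, such as the in-process FnApiRunner whose StateServicer is extended above.

import apache_beam as beam
from apache_beam.coders import StrUtf8Coder
from apache_beam.transforms import userstate
from apache_beam.utils import timestamp


class SortByEventTime(beam.DoFn):
  # Hypothetical spec; assumes the inherited StateSpec(name, coder) signature.
  EVENTS = userstate.OrderedListStateSpec('events', StrUtf8Coder())

  def process(
      self,
      element,  # (key, (event_time_micros, payload)) from a keyed PCollection
      events=beam.DoFn.StateParam(EVENTS)):
    _, (event_micros, payload) = element
    ts = timestamp.Timestamp(micros=event_micros)

    # Entries are (Timestamp, value) pairs; duplicate timestamps are allowed.
    events.add((ts, payload))

    # read() yields everything buffered so far, sorted by timestamp.
    for when, value in events.read():
      yield when, value

    # Range operations take a half-open interval [min_timestamp,
    # limit_timestamp), mirroring read_range/clear_range above.
    events.clear_range(timestamp.Timestamp.of(0), ts)

Design-wise, SynchronousOrderedListRuntimeState buffers pending adds in a SortedDict keyed by timestamp micros and pending removals in the RangeSet of half-open intervals, merges both with the _StateBackedIterable results on read, and only issues clear/extend state requests at commit(). The in-process StateServicer used by the FnApiRunner mirrors this by storing each timestamp under its own one-microsecond interval key tracked in a SortedSet, which is what backs the range queries in get_raw() and clear().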