From 8e1e12453baa8ffa564922496994446c9e41003c Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Tue, 17 Dec 2024 10:45:33 -0800 Subject: [PATCH] [#32929] Add OrderedListState support to Prism. (#33350) --- CHANGES.md | 1 + runners/prism/java/build.gradle | 4 - .../runners/prism/internal/engine/data.go | 97 ++++++++ .../prism/internal/engine/data_test.go | 222 ++++++++++++++++++ .../prism/internal/engine/elementmanager.go | 11 +- .../beam/runners/prism/internal/execute.go | 2 +- .../prism/internal/jobservices/management.go | 3 +- .../pkg/beam/runners/prism/internal/stage.go | 37 +++ .../beam/runners/prism/internal/urns/urns.go | 5 +- .../runners/prism/internal/worker/worker.go | 14 ++ 10 files changed, 385 insertions(+), 11 deletions(-) create mode 100644 sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go diff --git a/CHANGES.md b/CHANGES.md index 7a8ed493c216..deaa8bfcd471 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -74,6 +74,7 @@ * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)). * This enables initial Java GroupIntoBatches support. +* Support OrderedListState in Prism ([#32929](https://github.com/apache/beam/issues/32929)). ## Breaking Changes diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle index ce71151099bd..cd2e90fde67c 100644 --- a/runners/prism/java/build.gradle +++ b/runners/prism/java/build.gradle @@ -233,10 +233,6 @@ def createPrismValidatesRunnerTask = { name, environmentType -> excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' - // Not yet implemented in Prism - // https://github.com/apache/beam/issues/32929 - excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState' - // Not supported in Portable Java SDK yet. 
// https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go index 7b8689f95112..380b6e2f31d1 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -17,13 +17,17 @@ package engine import ( "bytes" + "cmp" "fmt" "log/slog" + "slices" + "sort" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "google.golang.org/protobuf/encoding/protowire" ) // StateData is a "union" between Bag state and MultiMap state to increase common code. @@ -42,6 +46,10 @@ type TimerKey struct { type TentativeData struct { Raw map[string][][]byte + // stateTypeLen is a map from LinkID to valueLen function for parsing data. + // Only used by OrderedListState, since Prism must manipulate these datavalues, + // which isn't expected, or a requirement of other state values. + stateTypeLen map[LinkID]func([]byte) int // state is a map from transformID + UserStateID, to window, to userKey, to datavalues. state map[LinkID]map[typex.Window]map[string]StateData // timers is a map from the Timer transform+family to the encoded timer. @@ -220,3 +228,92 @@ func (d *TentativeData) ClearMultimapKeysState(stateID LinkID, wKey, uKey []byte kmap[string(uKey)] = StateData{} slog.Debug("State() MultimapKeys.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey)) } + +// AppendOrderedListState appends the incoming timestamped data to the existing tentative data bundle. +// Assumes the data is TimestampedValue encoded, which has a BigEndian int64 suffixed to the data. 
+// This means we may always use the last 8 bytes to determine the value sorting. +// +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. +func (d *TentativeData) AppendOrderedListState(stateID LinkID, wKey, uKey []byte, data []byte) { + kmap := d.appendState(stateID, wKey) + typeLen := d.stateTypeLen[stateID] + var datums [][]byte + + // We need to parse out all values individually for later sorting. + // + // OrderedListState is encoded as KVs with varint encoded millis followed by the value. + // This is not the standard TimestampValueCoder encoding, which + // uses a big-endian long as a suffix to the value. This is important since + // values may be concatenated, and we'll need to split them out. + // + // The TentativeData.stateTypeLen is populated with a function to extract + // the length of the next value, so we can skip through elements individually. + for i := 0; i < len(data); { + // Get the length of the VarInt for the timestamp. + _, tn := protowire.ConsumeVarint(data[i:]) + + // Get the length of the encoded value. + vn := typeLen(data[i+tn:]) + prev := i + i += tn + vn + datums = append(datums, data[prev:i]) + } + + s := StateData{Bag: append(kmap[string(uKey)].Bag, datums...)} + sort.SliceStable(s.Bag, func(i, j int) bool { + vi := s.Bag[i] + vj := s.Bag[j] + return compareTimestampSuffixes(vi, vj) + }) + kmap[string(uKey)] = s + slog.Debug("State() OrderedList.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Any("NewData", s)) +} + +func compareTimestampSuffixes(vi, vj []byte) bool { + ims, _ := protowire.ConsumeVarint(vi) + jms, _ := protowire.ConsumeVarint(vj) + return (int64(ims)) < (int64(jms)) +} + +// GetOrderedListState returns available state from the tentative bundle data. +// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively. 
+func (d *TentativeData) GetOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) [][]byte { + winMap := d.state[stateID] + w := d.toWindow(wKey) + data := winMap[w][string(uKey)] + + lo, hi := findRange(data.Bag, start, end) + slog.Debug("State() OrderedList.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), slog.Group("outrange", slog.Int("lo", lo), slog.Int("hi", hi)), slog.Any("Data", data.Bag[lo:hi])) + return data.Bag[lo:hi] +} + +func cmpSuffix(vs [][]byte, target int64) func(i int) int { + return func(i int) int { + v := vs[i] + ims, _ := protowire.ConsumeVarint(v) + tvsbi := cmp.Compare(target, int64(ims)) + slog.Debug("cmpSuffix", "target", target, "bi", ims, "tvsbi", tvsbi) + return tvsbi + } +} + +func findRange(bag [][]byte, start, end int64) (int, int) { + lo, _ := sort.Find(len(bag), cmpSuffix(bag, start)) + hi, _ := sort.Find(len(bag), cmpSuffix(bag, end)) + return lo, hi +} + +func (d *TentativeData) ClearOrderedListState(stateID LinkID, wKey, uKey []byte, start, end int64) { + winMap := d.state[stateID] + w := d.toWindow(wKey) + kMap := winMap[w] + data := kMap[string(uKey)] + + lo, hi := findRange(data.Bag, start, end) + slog.Debug("State() OrderedList.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Group("range", slog.Int64("start", start), slog.Int64("end", end)), "lo", lo, "hi", hi, slog.Any("PreClearData", data.Bag)) + + cleared := slices.Delete(data.Bag, lo, hi) + // Zero the current entry to clear. + // Delete makes it difficult to delete the persisted stage state for the key. 
+ kMap[string(uKey)] = StateData{Bag: cleared} +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go new file mode 100644 index 000000000000..1d0497104182 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package engine + +import ( + "bytes" + "encoding/binary" + "math" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/encoding/protowire" +) + +func TestCompareTimestampSuffixes(t *testing.T) { + t.Run("simple", func(t *testing.T) { + loI := int64(math.MinInt64) + hiI := int64(math.MaxInt64) + + loB := binary.BigEndian.AppendUint64(nil, uint64(loI)) + hiB := binary.BigEndian.AppendUint64(nil, uint64(hiI)) + + if compareTimestampSuffixes(loB, hiB) != (loI < hiI) { + t.Errorf("lo vs Hi%v < %v: bytes %v vs %v, %v %v", loI, hiI, loB, hiB, loI < hiI, compareTimestampSuffixes(loB, hiB)) + } + }) +} + +func TestOrderedListState(t *testing.T) { + time1 := protowire.AppendVarint(nil, 11) + time2 := protowire.AppendVarint(nil, 22) + time3 := protowire.AppendVarint(nil, 33) + time4 := protowire.AppendVarint(nil, 44) + time5 := protowire.AppendVarint(nil, 55) + + wKey := []byte{} // global window. + uKey := []byte("\u0007userkey") + linkID := LinkID{ + Transform: "dofn", + Local: "localStateName", + } + cc := func(a []byte, b ...byte) []byte { + return bytes.Join([][]byte{a, b}, []byte{}) + } + + t.Run("bool", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(_ []byte) int { + return 1 + }, + }, + } + + d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 1)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 1)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 1)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, 1), + cc(time2, 0), + cc(time3, 1), + cc(time4, 0), + cc(time5, 1), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList booleans \n%v", d) + } + + d.ClearOrderedListState(linkID, wKey, uKey, 12, 54) + got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want = [][]byte{ + cc(time1, 1), + 
cc(time5, 1), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList booleans, after clear\n%v", d) + } + }) + t.Run("float64", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(_ []byte) int { + return 8 + }, + }, + } + + d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 0, 0, 0, 0, 0, 0, 0, 1)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 0, 0, 0, 0, 0, 0, 1, 0)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 0, 0, 0, 0, 0, 1, 0, 0)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 0, 0, 0, 0, 1, 0, 0, 0)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0, 0, 0, 1, 0, 0, 0, 0)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, 0, 0, 0, 0, 0, 0, 1, 0), + cc(time2, 0, 0, 0, 0, 1, 0, 0, 0), + cc(time3, 0, 0, 0, 0, 0, 1, 0, 0), + cc(time4, 0, 0, 0, 1, 0, 0, 0, 0), + cc(time5, 0, 0, 0, 0, 0, 0, 0, 1), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList float64s \n%v", d) + } + + d.ClearOrderedListState(linkID, wKey, uKey, 11, 12) + d.ClearOrderedListState(linkID, wKey, uKey, 33, 34) + d.ClearOrderedListState(linkID, wKey, uKey, 55, 56) + + got = d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want = [][]byte{ + cc(time2, 0, 0, 0, 0, 1, 0, 0, 0), + cc(time4, 0, 0, 0, 1, 0, 0, 0, 0), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList float64s, after clear \n%v", d) + } + }) + + t.Run("varint", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(b []byte) int { + _, n := protowire.ConsumeVarint(b) + return int(n) + }, + }, + } + + d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, protowire.AppendVarint(nil, 56)...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, protowire.AppendVarint(nil, 20067)...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, protowire.AppendVarint(nil, 7777777)...)) + 
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, protowire.AppendVarint(nil, 424242)...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, protowire.AppendVarint(nil, 0)...)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, protowire.AppendVarint(nil, 424242)...), + cc(time2, protowire.AppendVarint(nil, 56)...), + cc(time3, protowire.AppendVarint(nil, 7777777)...), + cc(time4, protowire.AppendVarint(nil, 20067)...), + cc(time5, protowire.AppendVarint(nil, 0)...), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList int32 \n%v", d) + } + }) + t.Run("lp", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(b []byte) int { + l, n := protowire.ConsumeVarint(b) + return int(l) + n + }, + }, + } + + d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, []byte("\u0003one")...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, []byte("\u0003two")...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, []byte("\u0005three")...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, []byte("\u0004four")...)) + d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, []byte("\u0003one")...), + cc(time2, []byte("\u0003two")...), + cc(time3, []byte("\u0005three")...), + cc(time4, []byte("\u0004four")...), + cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList int32 \n%v", d) + } + }) + t.Run("lp_onecall", func(t *testing.T) { + d := TentativeData{ + stateTypeLen: map[LinkID]func([]byte) int{ + linkID: func(b []byte) int { + l, n := protowire.ConsumeVarint(b) + return int(l) + n + }, + }, + } + d.AppendOrderedListState(linkID, wKey, uKey, bytes.Join([][]byte{ + time5, []byte("\u0019FourHundredAndEleventyTwo"), + time3, 
[]byte("\u0005three"), + time2, []byte("\u0003two"), + time1, []byte("\u0003one"), + time4, []byte("\u0004four"), + }, nil)) + + got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60) + want := [][]byte{ + cc(time1, []byte("\u0003one")...), + cc(time2, []byte("\u0003two")...), + cc(time3, []byte("\u0005three")...), + cc(time4, []byte("\u0004four")...), + cc(time5, []byte("\u0019FourHundredAndEleventyTwo")...), + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("OrderedList int32 \n%v", d) + } + }) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index 00e18c669afa..7180bb456f1a 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -269,8 +269,10 @@ func (em *ElementManager) StageAggregates(ID string) { // StageStateful marks the given stage as stateful, which means elements are // processed by key. -func (em *ElementManager) StageStateful(ID string) { - em.stages[ID].stateful = true +func (em *ElementManager) StageStateful(ID string, stateTypeLen map[LinkID]func([]byte) int) { + ss := em.stages[ID] + ss.stateful = true + ss.stateTypeLen = stateTypeLen } // StageOnWindowExpiration marks the given stage as stateful, which means elements are @@ -669,7 +671,9 @@ func (em *ElementManager) StateForBundle(rb RunBundle) TentativeData { ss := em.stages[rb.StageID] ss.mu.Lock() defer ss.mu.Unlock() - var ret TentativeData + ret := TentativeData{ + stateTypeLen: ss.stateTypeLen, + } keys := ss.inprogressKeysByBundle[rb.BundleID] // TODO(lostluck): Also track windows per bundle, to reduce copying. if len(ss.state) > 0 { @@ -1136,6 +1140,7 @@ type stageState struct { inprogressKeys set[string] // all keys that are assigned to bundles. inprogressKeysByBundle map[string]set[string] // bundle to key assignments. 
state map[LinkID]map[typex.Window]map[string]StateData // state data for this stage, from {tid, stateID} -> window -> userKey + stateTypeLen map[LinkID]func([]byte) int // map from state to a function that will produce the total length of a single value in bytes. // Accounting for handling watermark holds for timers. // We track the count of timers with the same hold, and clear it from diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index fde62f00c7c1..8b56c30eb61b 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -316,7 +316,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic sort.Strings(outputs) em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, stage.sideInputs) if stage.stateful { - em.StageStateful(stage.ID) + em.StageStateful(stage.ID, stage.stateTypeLen) } if stage.onWindowExpiration.TimerFamily != "" { slog.Debug("OnWindowExpiration", slog.String("stage", stage.ID), slog.Any("values", stage.onWindowExpiration)) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index 894a6e1427a2..af559a92ab46 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -174,7 +174,8 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * // Validate all the state features for _, spec := range pardo.GetStateSpecs() { isStateful = true - check("StateSpec.Protocol.Urn", spec.GetProtocol().GetUrn(), urns.UserStateBag, urns.UserStateMultiMap) + check("StateSpec.Protocol.Urn", spec.GetProtocol().GetUrn(), + urns.UserStateBag, urns.UserStateMultiMap, urns.UserStateOrderedList) } // Validate all the timer features for _, spec := range pardo.GetTimerFamilySpecs() { diff --git 
a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index 9dd6cbdafec8..e1e942a06f0c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -35,6 +35,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" ) @@ -73,6 +74,10 @@ type stage struct { hasTimers []engine.StaticTimerID processingTimeTimers map[string]bool + // stateTypeLen maps state values to encoded lengths for the type. + // Only used for OrderedListState which must manipulate individual state datavalues. + stateTypeLen map[engine.LinkID]func([]byte) int + exe transformExecuter inputTransformID string inputInfo engine.PColInfo @@ -438,6 +443,38 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng rewriteCoder(&s.SetSpec.ElementCoderId) case *pipepb.StateSpec_OrderedListSpec: rewriteCoder(&s.OrderedListSpec.ElementCoderId) + // Add the length determination helper for OrderedList state values. 
+ if stg.stateTypeLen == nil { + stg.stateTypeLen = map[engine.LinkID]func([]byte) int{} + } + linkID := engine.LinkID{ + Transform: tid, + Local: stateID, + } + var fn func([]byte) int + switch v := coders[s.OrderedListSpec.GetElementCoderId()]; v.GetSpec().GetUrn() { + case urns.CoderBool: + fn = func(_ []byte) int { + return 1 + } + case urns.CoderDouble: + fn = func(_ []byte) int { + return 8 + } + case urns.CoderVarInt: + fn = func(b []byte) int { + _, n := protowire.ConsumeVarint(b) + return int(n) + } + case urns.CoderLengthPrefix, urns.CoderBytes, urns.CoderStringUTF8: + fn = func(b []byte) int { + l, n := protowire.ConsumeVarint(b) + return int(l) + n + } + default: + rewriteErr = fmt.Errorf("unknown coder used for ordered list state after re-write id: %v coder: %v, for state %v for transform %v in stage %v", s.OrderedListSpec.GetElementCoderId(), v, stateID, tid, stg.ID) + } + stg.stateTypeLen[linkID] = fn case *pipepb.StateSpec_CombiningSpec: rewriteCoder(&s.CombiningSpec.AccumulatorCoderId) case *pipepb.StateSpec_MapSpec: diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go index 5312fd799c89..12e62ef84a81 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go +++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go @@ -95,8 +95,9 @@ var ( SideInputMultiMap = siUrn(pipepb.StandardSideInputTypes_MULTIMAP) // UserState kinds - UserStateBag = usUrn(pipepb.StandardUserStateTypes_BAG) - UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP) + UserStateBag = usUrn(pipepb.StandardUserStateTypes_BAG) + UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP) + UserStateOrderedList = usUrn(pipepb.StandardUserStateTypes_ORDERED_LIST) // WindowsFns WindowFnGlobal = quickUrn(pipepb.GlobalWindowsPayload_PROPERTIES) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index 
c2c988aa097f..9d9058975b26 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -554,6 +554,11 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case *fnpb.StateKey_MultimapKeysUserState_: mmkey := key.GetMultimapKeysUserState() data = b.OutputData.GetMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey()) + case *fnpb.StateKey_OrderedListUserState_: + olkey := key.GetOrderedListUserState() + data = b.OutputData.GetOrderedListState( + engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()}, + olkey.GetWindow(), olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd()) default: panic(fmt.Sprintf("unsupported StateKey Get type: %T: %v", key.GetType(), prototext.Format(key))) } @@ -578,6 +583,11 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case *fnpb.StateKey_MultimapUserState_: mmkey := key.GetMultimapUserState() b.OutputData.AppendMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey(), req.GetAppend().GetData()) + case *fnpb.StateKey_OrderedListUserState_: + olkey := key.GetOrderedListUserState() + b.OutputData.AppendOrderedListState( + engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()}, + olkey.GetWindow(), olkey.GetKey(), req.GetAppend().GetData()) default: panic(fmt.Sprintf("unsupported StateKey Append type: %T: %v", key.GetType(), prototext.Format(key))) } @@ -601,6 +611,10 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case *fnpb.StateKey_MultimapKeysUserState_: mmkey := key.GetMultimapUserState() b.OutputData.ClearMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey()) + case *fnpb.StateKey_OrderedListUserState_: 
+ olkey := key.GetOrderedListUserState() + b.OutputData.ClearOrderedListState(engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()}, + olkey.GetWindow(), olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd()) default: panic(fmt.Sprintf("unsupported StateKey Clear type: %T: %v", key.GetType(), prototext.Format(key))) }