From dbe72830b11c2603e062e32295dfb2e3efedbaa9 Mon Sep 17 00:00:00 2001 From: mls3odp Date: Wed, 3 Jul 2024 11:55:16 -0700 Subject: [PATCH] Add a test for getting state with MultimapSideInput StateKey (#31757) --- .../prism/internal/worker/worker_test.go | 81 ++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go index e5b03214ae0..469e0e2f3d8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go @@ -18,13 +18,14 @@ package worker import ( "bytes" "context" - "github.com/google/go-cmp/cmp" "net" "sort" "sync" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" @@ -386,3 +387,81 @@ func TestWorker_State_MultimapKeysSideInput(t *testing.T) { }) } } + +func TestWorker_State_MultimapSideInput(t *testing.T) { + for _, tt := range []struct { + name string + w typex.Window + }{ + { + name: "global window", + w: window.GlobalWindow{}, + }, + { + name: "interval window", + w: window.IntervalWindow{ + Start: 1000, + End: 2000, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + var encW []byte + if !tt.w.Equals(window.GlobalWindow{}) { + buf := bytes.Buffer{} + if err := exec.MakeWindowEncoder(coder.NewIntervalWindow()).EncodeSingle(tt.w, &buf); err != nil { + t.Fatalf("error encoding window: %v, err: %v", tt.w, err) + } + encW = buf.Bytes() + } + wk, stateStream, done := serveTestWorkerStateStream(t) + defer done() + instID := wk.NextInst() + wk.activeInstructions[instID] = &B{ + MultiMapSideInputData: map[SideInputKey]map[typex.Window]map[string][][]byte{ + SideInputKey{ + TransformID: "transformID", + Local: "i1", + }: { + tt.w: map[string][][]byte{"a": {{5}}, "b": {{12}}}, + }, + }, + } + var testKey = []string{"a", "b", "x"} + expectedResult := map[string][]int{ + "a": {5}, + "b": {12}, + } + for _, key := range testKey { + stateStream.Send(&fnpb.StateRequest{ + Id: "first", + InstructionId: instID, + Request: &fnpb.StateRequest_Get{ + Get: &fnpb.StateGetRequest{}, + }, + StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_MultimapSideInput_{ + MultimapSideInput: &fnpb.StateKey_MultimapSideInput{ + TransformId: "transformID", + SideInputId: "i1", + Window: encW, + Key: []byte(key), + }, + }}, + }) + + resp, err := stateStream.Recv() + if err != nil { + t.Fatal("Couldn't receive state response:", err) + } + + var got []int + for _, b := range resp.GetGet().GetData() { + got = append(got, int(b)) + } + if !cmp.Equal(got, expectedResult[key]) { + t.Errorf("For test key: %v, didn't receive expected state response data: got %v, want %v", key, got, expectedResult[key]) + } + } + }) + } +}