Skip to content

Commit

Permalink
Early fail if coder isn't "known", and plumb from protos.
Browse files Browse the repository at this point in the history
  • Loading branch information
lostluck committed Dec 11, 2024
1 parent 08392c2 commit 1d96a20
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 75 deletions.
38 changes: 2 additions & 36 deletions sdks/go/pkg/beam/runners/prism/internal/engine/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,47 +49,13 @@ type TentativeData struct {
// stateTypeLen is a map from LinkID to a function that returns the encoded
// length in bytes of a single value, used for parsing data.
// Only used by OrderedListState, since Prism must manipulate these individual
// datavalues, which is neither expected nor required for other state values.
stateTypeLen map[LinkID]valueLen
stateTypeLen map[LinkID]func([]byte) int
// state is a map from transformID + UserStateID, to window, to userKey, to datavalues.
state map[LinkID]map[typex.Window]map[string]StateData
// timers is a map from the Timer transform+family to the encoded timer.
timers map[TimerKey][][]byte
}

// valueLen is a function that extracts the length of a value in bytes from a
// byte buffer. It is expected to return the exact number of bytes that must be
// consumed to read one encoded value from the start of the buffer.
type valueLen func([]byte) int

var (
// varIntLen returns the number of bytes occupied by the varint at the
// start of b, as reported by protowire.ConsumeVarint.
// NOTE(review): ConsumeVarint reports a negative length on malformed or
// truncated input, and that value is returned unchecked here — confirm
// callers only pass well-formed data.
varIntLen valueLen = func(b []byte) int {
_, n := protowire.ConsumeVarint(b)
return int(n)
}
// lenPrefiedLen returns the total size of a length-prefixed value: the
// decoded payload length plus the number of bytes used by the varint
// length prefix itself.
// This applies to arbitrary custom coders, bytes, and string values.
// NOTE(review): the result of ConsumeVarint is not checked for errors, so a
// malformed prefix yields a meaningless length — confirm against callers.
lenPrefiedLen = func(b []byte) int {
l, n := protowire.ConsumeVarint(b)
return int(l) + n
}
// oneByteLen returns the number of bytes in a boolean, which is 1.
// The buffer contents are ignored.
oneByteLen = func(_ []byte) int {
return 1
}
// fourByteLen returns the number of bytes in a single-precision float
// (float32) or an int32, which is 4. The buffer contents are ignored.
fourByteLen = func(_ []byte) int {
return 4
}
// eightByteLen returns the number of bytes in a double (float64) or an
// int64, which is 8. The buffer contents are ignored.
eightByteLen = func(_ []byte) int {
return 8
}
)

// WriteData adds data to a given global collectionID.
func (d *TentativeData) WriteData(colID string, data []byte) {
if d.Raw == nil {
Expand Down Expand Up @@ -287,7 +253,7 @@ func (d *TentativeData) AppendOrderedListState(stateID LinkID, wKey, uKey []byte
// We need to parse out all values individually for later sorting.
for i := 0; i < len(data); {
// Get the length of the VarInt for the timestamp.
tn := varIntLen(data[i:])
_, tn := protowire.ConsumeVarint(data[i:])

// Get the length of the encoded value.
vn := typeLen(data[i+tn:])
Expand Down
58 changes: 23 additions & 35 deletions sdks/go/pkg/beam/runners/prism/internal/engine/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ func TestOrderedListState(t *testing.T) {

t.Run("bool", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]valueLen{
linkID: oneByteLen,
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(_ []byte) int {
return 1
},
},
}

Expand Down Expand Up @@ -91,35 +93,12 @@ func TestOrderedListState(t *testing.T) {
t.Errorf("OrderedList booleans, after clear\n%v", d)
}
})
t.Run("int32", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]valueLen{
linkID: fourByteLen,
},
}

d.AppendOrderedListState(linkID, wKey, uKey, cc(time5, 0, 0, 0, 1))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time4, 0, 0, 1, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time3, 0, 1, 0, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time2, 1, 0, 0, 0))
d.AppendOrderedListState(linkID, wKey, uKey, cc(time1, 1, 0, 0, 1))

got := d.GetOrderedListState(linkID, wKey, uKey, 0, 60)
want := [][]byte{
cc(time1, 1, 0, 0, 1),
cc(time2, 1, 0, 0, 0),
cc(time3, 0, 1, 0, 0),
cc(time4, 0, 0, 1, 0),
cc(time5, 0, 0, 0, 1),
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("OrderedList int32 \n%v", d)
}
})
t.Run("float64", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]valueLen{
linkID: eightByteLen,
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(_ []byte) int {
return 8
},
},
}

Expand Down Expand Up @@ -157,8 +136,11 @@ func TestOrderedListState(t *testing.T) {

t.Run("varint", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]valueLen{
linkID: varIntLen,
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
_, n := protowire.ConsumeVarint(b)
return int(n)
},
},
}

Expand All @@ -182,8 +164,11 @@ func TestOrderedListState(t *testing.T) {
})
t.Run("lp", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]valueLen{
linkID: lenPrefiedLen,
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
l, n := protowire.ConsumeVarint(b)
return int(l) + n
},
},
}

Expand All @@ -207,8 +192,11 @@ func TestOrderedListState(t *testing.T) {
})
t.Run("lp_onecall", func(t *testing.T) {
d := TentativeData{
stateTypeLen: map[LinkID]valueLen{
linkID: lenPrefiedLen,
stateTypeLen: map[LinkID]func([]byte) int{
linkID: func(b []byte) int {
l, n := protowire.ConsumeVarint(b)
return int(l) + n
},
},
}
d.AppendOrderedListState(linkID, wKey, uKey, bytes.Join([][]byte{
Expand Down
11 changes: 8 additions & 3 deletions sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,10 @@ func (em *ElementManager) StageAggregates(ID string) {

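// StageStateful marks the given stage as stateful, which means elements are
// processed by key. The stateTypeLen map, keyed by state LinkID, supplies the
// functions that produce the total length in bytes of a single state value.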
func (em *ElementManager) StageStateful(ID string) {
em.stages[ID].stateful = true
func (em *ElementManager) StageStateful(ID string, stateTypeLen map[LinkID]func([]byte) int) {
ss := em.stages[ID]
ss.stateful = true
ss.stateTypeLen = stateTypeLen
}

// StageProcessingTimeTimers indicates which timers are processingTime domain timers.
Expand Down Expand Up @@ -635,7 +637,9 @@ func (em *ElementManager) StateForBundle(rb RunBundle) TentativeData {
ss := em.stages[rb.StageID]
ss.mu.Lock()
defer ss.mu.Unlock()
var ret TentativeData
ret := TentativeData{
stateTypeLen: ss.stateTypeLen,
}
keys := ss.inprogressKeysByBundle[rb.BundleID]
// TODO(lostluck): Also track windows per bundle, to reduce copying.
if len(ss.state) > 0 {
Expand Down Expand Up @@ -1083,6 +1087,7 @@ type stageState struct {
inprogressKeys set[string] // all keys that are assigned to bundles.
inprogressKeysByBundle map[string]set[string] // bundle to key assignments.
state map[LinkID]map[typex.Window]map[string]StateData // state data for this stage, from {tid, stateID} -> window -> userKey
stateTypeLen map[LinkID]func([]byte) int // map from state to a function that will produce the total length of a single value in bytes.

// Accounting for handling watermark holds for timers.
// We track the count of timers with the same hold, and clear it from
Expand Down
2 changes: 1 addition & 1 deletion sdks/go/pkg/beam/runners/prism/internal/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic
sort.Strings(outputs)
em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, stage.sideInputs)
if stage.stateful {
em.StageStateful(stage.ID)
em.StageStateful(stage.ID, stage.stateTypeLen)
}
if len(stage.processingTimeTimers) > 0 {
em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers)
Expand Down
38 changes: 38 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -66,11 +67,16 @@ type stage struct {
envID string
finalize bool
stateful bool

// hasTimers indicates the transform+timerfamily pairs that need to be waited on for
// the stage to be considered complete.
hasTimers []struct{ Transform, TimerFamily string }
processingTimeTimers map[string]bool

// stateTypeLen maps a state's LinkID to a function that returns the encoded
// length in bytes of a single value of that state's type.
// Only used for OrderedListState which must manipulate individual state datavalues.
stateTypeLen map[engine.LinkID]func([]byte) int

exe transformExecuter
inputTransformID string
inputInfo engine.PColInfo
Expand Down Expand Up @@ -436,6 +442,38 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
rewriteCoder(&s.SetSpec.ElementCoderId)
case *pipepb.StateSpec_OrderedListSpec:
rewriteCoder(&s.OrderedListSpec.ElementCoderId)
// Add the length determination helper for OrderedList state values.
if stg.stateTypeLen == nil {
stg.stateTypeLen = map[engine.LinkID]func([]byte) int{}
}
linkID := engine.LinkID{
Transform: tid,
Local: stateID,
}
var fn func([]byte) int
switch v := coders[s.OrderedListSpec.GetElementCoderId()]; v.GetSpec().GetUrn() {
case urns.CoderBool:
fn = func(_ []byte) int {
return 1
}
case urns.CoderDouble:
fn = func(_ []byte) int {
return 8
}
case urns.CoderVarInt:
fn = func(b []byte) int {
_, n := protowire.ConsumeVarint(b)
return int(n)
}
case urns.CoderLengthPrefix, urns.CoderBytes, urns.CoderStringUTF8:
fn = func(b []byte) int {
l, n := protowire.ConsumeVarint(b)
return int(l) + n
}
default:
rewriteErr = fmt.Errorf("unknown coder used for ordered list state after re-write id: %v coder: %v, for state %v for transform %v in stage %v", s.OrderedListSpec.GetElementCoderId(), v, stateID, tid, stg.ID)
}
stg.stateTypeLen[linkID] = fn
case *pipepb.StateSpec_CombiningSpec:
rewriteCoder(&s.CombiningSpec.AccumulatorCoderId)
case *pipepb.StateSpec_MapSpec:
Expand Down

0 comments on commit 1d96a20

Please sign in to comment.