Skip to content

Commit

Permalink
[#32004] Ensure all pcollection coders are length prefixed if necessa…
Browse files Browse the repository at this point in the history
…ry. (#32012)

* [#32004] Ensure input collection is wrapped. Send precise PCollections.

* error out if there's an issue rewriting coders.

* Unwrap length prefix coders in element hasher.

---------

Co-authored-by: lostluck <[email protected]>
  • Loading branch information
lostluck and lostluck authored Jul 30, 2024
1 parent 88a0102 commit e7847c9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
7 changes: 7 additions & 0 deletions sdks/go/pkg/beam/core/runtime/exec/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ type elementHasher interface {
func makeElementHasher(c *coder.Coder, wc *coder.WindowCoder) elementHasher {
hasher := &maphash.Hash{}
we := MakeWindowEncoder(wc)

// Unwrap length prefix coders.
// A length prefix changes the hash itself, but shouldn't affect
// that identical elements have the same hash, so skip them here.
if c.Kind == coder.LP {
c = c.Components[0]
}
switch c.Kind {
case coder.Bytes:
return &bytesHasher{hash: hasher, we: we}
Expand Down
33 changes: 26 additions & 7 deletions sdks/go/pkg/beam/runners/prism/internal/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,13 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng

coders := map[string]*pipepb.Coder{}
transforms := map[string]*pipepb.PTransform{}
pcollections := map[string]*pipepb.PCollection{}

clonePColToBundle := func(pid string) *pipepb.PCollection {
col := proto.Clone(comps.GetPcollections()[pid]).(*pipepb.PCollection)
pcollections[pid] = col
return col
}

for _, tid := range stg.transforms {
t := comps.GetTransforms()[tid]
Expand Down Expand Up @@ -408,7 +415,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
sink2Col := map[string]string{}
col2Coders := map[string]engine.PColInfo{}
for _, o := range stg.outputs {
col := comps.GetPcollections()[o.Global]
col := clonePColToBundle(o.Global)
wOutCid, err := makeWindowedValueCoder(o.Global, comps, coders)
if err != nil {
return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for output %+v, pcol %q %v:\n%w %v", stg.ID, o, o.Global, prototext.Format(col), err, stg.transforms)
Expand All @@ -435,7 +442,8 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng

var prepareSides []func(b *worker.B, watermark mtime.Time)
for _, si := range stg.sideInputs {
col := comps.GetPcollections()[si.Global]
col := clonePColToBundle(si.Global)

oCID := col.GetCoderId()
nCID, err := lpUnknownCoders(oCID, coders, comps.GetCoders())
if err != nil {
Expand All @@ -444,7 +452,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
if oCID != nCID {
// Add a synthetic PCollection set with the new coder.
newGlobal := si.Global + "_prismside"
comps.GetPcollections()[newGlobal] = &pipepb.PCollection{
pcollections[newGlobal] = &pipepb.PCollection{
DisplayData: col.GetDisplayData(),
UniqueName: col.GetUniqueName(),
CoderId: nCID,
Expand All @@ -467,7 +475,13 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
// coders used by side inputs to the coders map for the bundle, so
// needs to be run for every ID.

col := comps.GetPcollections()[stg.primaryInput]
col := clonePColToBundle(stg.primaryInput)
if newCID, err := lpUnknownCoders(col.GetCoderId(), coders, comps.GetCoders()); err == nil && col.GetCoderId() != newCID {
col.CoderId = newCID
} else if err != nil {
return fmt.Errorf("buildDescriptor: couldn't rewrite coder %q for primary input pcollection %q: %w", col.GetCoderId(), stg.primaryInput, err)
}

wInCid, err := makeWindowedValueCoder(stg.primaryInput, comps, coders)
if err != nil {
return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for primary input, pcol %q %v:\n%w\n%v", stg.ID, stg.primaryInput, prototext.Format(col), err, stg.transforms)
Expand All @@ -491,9 +505,14 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
stg.inputTransformID = stg.ID + "_source"
transforms[stg.inputTransformID] = sourceTransform(stg.inputTransformID, portFor(wInCid, wk), stg.primaryInput)

// Add coders for internal collections.
// Update coders for internal collections, and add those collections to the bundle descriptor.
for _, pid := range stg.internalCols {
lpUnknownCoders(comps.GetPcollections()[pid].GetCoderId(), coders, comps.GetCoders())
col := clonePColToBundle(pid)
if newCID, err := lpUnknownCoders(col.GetCoderId(), coders, comps.GetCoders()); err == nil && col.GetCoderId() != newCID {
col.CoderId = newCID
} else if err != nil {
return fmt.Errorf("buildDescriptor: coder couldn't rewrite coder %q for internal pcollection %q: %w", col.GetCoderId(), pid, err)
}
}
// Add coders for all windowing strategies.
// TODO: filter PCollections, filter windowing strategies by Pcollections instead.
Expand All @@ -514,7 +533,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
Id: stg.ID,
Transforms: transforms,
WindowingStrategies: comps.GetWindowingStrategies(),
Pcollections: comps.GetPcollections(),
Pcollections: pcollections,
Coders: coders,
StateApiServiceDescriptor: &pipepb.ApiServiceDescriptor{
Url: wk.Endpoint(),
Expand Down

0 comments on commit e7847c9

Please sign in to comment.