Skip to content

Commit

Permalink
Use child dir for branch taken
Browse files Browse the repository at this point in the history
We currently reuse the same data directory for the branch node and the branch taken (child) node. This change adds an additional set of directories to decouple the two, and makes sure to copy the child node's output up to the branch node's output, as we already do for subworkflows.

Example
```
from dataclasses import dataclass
from flytekit import conditional, task, workflow

@dataclass
class MyData:
    foo: str

@workflow
def root_wf(data: MyData) -> str:
    return sub_wf(data=data)

@workflow
def sub_wf(data: MyData) -> str:
    check = always_true()
    return conditional("decision").if_(check.is_true()).then(conditional_wf(data=data)).else_().fail("not done")

@task
def always_true() -> bool:
    return True

@workflow
def conditional_wf(data: MyData) -> str:
    return done(data)

@task
def done(data: MyData) -> str:
    return f"done ({data.foo})"
```

- [x] Add unittests
- [x] Run locally using sandbox and verify behavior with working example. Inspect paths

Signed-off-by: Andrew Dye <[email protected]>
  • Loading branch information
andrewwdye committed Dec 20, 2024
1 parent 27b2f3a commit 97e17ab
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 24 deletions.
31 changes: 22 additions & 9 deletions flytepropeller/pkg/controller/nodes/branch/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package branch
import (
"context"
"fmt"
"strconv"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
Expand All @@ -15,6 +16,7 @@ import (
stdErrors "github.com/flyteorg/flyte/flytestdlib/errors"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/flyteorg/flyte/flytestdlib/storage"
)

type metrics struct {
Expand Down Expand Up @@ -74,8 +76,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha
childNodeStatus.SetParentNodeID(&i)

logger.Debugf(ctx, "Recursively executing branchNode's chosen path")
nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID())
return b.recurseDownstream(ctx, nCtx, nodeStatus, finalNode)
return b.recurseDownstream(ctx, nCtx, finalNode)

Check warning on line 79 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L79

Added line #L79 was not covered by tests
}

// If the branchNodestatus was already evaluated i.e, Node is in Running status
Expand All @@ -99,8 +100,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha
}

// Recurse downstream
nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID())
return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode)
return b.recurseDownstream(ctx, nCtx, branchTakenNode)

Check warning on line 103 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L103

Added line #L103 was not covered by tests
}

func (b *branchHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) {
Expand All @@ -123,7 +123,7 @@ func (b *branchHandler) getExecutionContextForDownstream(nCtx interfaces.NodeExe
return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil
}

func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) {
func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) {
// TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time
// There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc
// The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node.
Expand All @@ -134,8 +134,16 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.N
}

childNodeStatus := nl.GetNodeExecutionStatus(ctx, branchTakenNode.GetID())
childNodeStatus.SetDataDir(nodeStatus.GetDataDir())
childNodeStatus.SetOutputDir(nodeStatus.GetOutputDir())
childDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), branchTakenNode.GetID())
if err != nil {
return handler.UnknownTransition, err
}

Check warning on line 140 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L139-L140

Added lines #L139 - L140 were not covered by tests
childOutputDir, err := nCtx.DataStore().ConstructReference(ctx, childDataDir, strconv.Itoa(int(childNodeStatus.GetAttempts())))
if err != nil {
return handler.UnknownTransition, err
}

Check warning on line 144 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L143-L144

Added lines #L143 - L144 were not covered by tests
childNodeStatus.SetDataDir(childDataDir)
childNodeStatus.SetOutputDir(childOutputDir)
upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID())
if err != nil {
return handler.UnknownTransition, err
Expand All @@ -151,9 +159,14 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.N
}

if downstreamStatus.IsComplete() {
// For branch node we set the output node to be the same as the child nodes output
childOutputsPath := v1alpha1.GetOutputsFile(childOutputDir)
outputsPath := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir())
if err := nCtx.DataStore().CopyRaw(ctx, childOutputsPath, outputsPath, storage.Options{}); err != nil {
errMsg := fmt.Sprintf("Failed to copy child node outputs from [%v] to [%v]", childOutputsPath, outputsPath)
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.OutputsNotFoundError, errMsg, nil)), nil
}

Check warning on line 167 in flytepropeller/pkg/controller/nodes/branch/handler.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/branch/handler.go#L165-L167

Added lines #L165 - L167 were not covered by tests
phase := handler.PhaseInfoSuccess(&handler.ExecutionInfo{
OutputInfo: &handler.OutputInfo{OutputURI: v1alpha1.GetOutputsFile(childNodeStatus.GetOutputDir())},
OutputInfo: &handler.OutputInfo{OutputURI: outputsPath},
})

return handler.DoTransition(handler.TransitionTypeEphemeral, phase), nil
Expand Down
31 changes: 16 additions & 15 deletions flytepropeller/pkg/controller/nodes/branch/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package branch

import (
"bytes"
"context"
"fmt"
"testing"
Expand Down Expand Up @@ -110,6 +111,7 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod

ns := &mocks2.ExecutableNodeStatus{}
ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))
ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted)

ir := &mocks3.InputReader{}
Expand Down Expand Up @@ -162,25 +164,24 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
name string
ns interfaces.NodeStatus
err error
nodeStatus *mocks2.ExecutableNodeStatus
branchTakenNode v1alpha1.ExecutableNode
isErr bool
expectedPhase handler.EPhase
childPhase v1alpha1.NodePhase
upstreamNodeID string
}{
{"upstreamNodeExists", interfaces.NodeStatusPending, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"},
{"childNodeError", interfaces.NodeStatusUndefined, fmt.Errorf("err"),
&mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""},
bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""},
{"childPending", interfaces.NodeStatusPending, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""},
{"childStillRunning", interfaces.NodeStatusRunning, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""},
{"childFailure", interfaces.NodeStatusFailed(expectedError), nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""},
bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""},
{"childComplete", interfaces.NodeStatusComplete, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""},
bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand Down Expand Up @@ -222,16 +223,16 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
).Return(test.ns, test.err)

childNodeStatus := &mocks2.ExecutableNodeStatus{}
if mockNodeLookup != nil {
childNodeStatus.OnGetOutputDir().Return("parent-output-dir")
test.nodeStatus.OnGetDataDir().Return("parent-data-dir")
test.nodeStatus.OnGetOutputDir().Return("parent-output-dir")
mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus)
childNodeStatus.On("SetDataDir", storage.DataReference("parent-data-dir")).Once()
childNodeStatus.On("SetOutputDir", storage.DataReference("parent-output-dir")).Once()
childNodeStatus.OnGetAttempts().Return(0)
childNodeStatus.On("SetDataDir", storage.DataReference("/output-dir/child")).Once()
childNodeStatus.On("SetOutputDir", storage.DataReference("/output-dir/child/0")).Once()
mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus)
if test.childPhase == v1alpha1.NodePhaseSucceeded {
_ = nCtx.DataStore().WriteRaw(ctx, storage.DataReference("/output-dir/child/0/outputs.pb"), 0, storage.Options{}, bytes.NewReader([]byte{}))
}

branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()).(*branchHandler)
h, err := branch.recurseDownstream(ctx, nCtx, test.nodeStatus, test.branchTakenNode)
h, err := branch.recurseDownstream(ctx, nCtx, test.branchTakenNode)
if test.isErr {
assert.Error(t, err)
} else {
Expand Down

0 comments on commit 97e17ab

Please sign in to comment.