From 97e17abf65e68789c2eee15e313b1cd3bc2cb401 Mon Sep 17 00:00:00 2001 From: Andrew Dye Date: Thu, 19 Dec 2024 08:00:03 -0800 Subject: [PATCH] Use child dir for branch taken We reuse the same data directory for the branch node and the branch taken (child). This change adds an additional set of directories to decouple these. It makes sure to copy the child output up to the branch node output like we do in subworkflow Example ``` from dataclasses import dataclass from flytekit import conditional, task, workflow @dataclass class MyData: foo: str @workflow def root_wf(data: MyData) -> str: return sub_wf(data=data) @workflow def sub_wf(data: MyData) -> str: check = always_true() return conditional("decision").if_(check.is_true()).then(conditional_wf(data=data)).else_().fail("not done") @task def always_true() -> bool: return True @workflow def conditional_wf(data: MyData) -> str: return done(data) @task def done(data: MyData) -> str: return f"done ({data.foo})" ``` - [x] Add unittests - [x] Run locally using sandbox and verify behavior with working example. Inspect paths Signed-off-by: Andrew Dye --- .../pkg/controller/nodes/branch/handler.go | 31 +++++++++++++------ .../controller/nodes/branch/handler_test.go | 31 ++++++++++--------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/branch/handler.go b/flytepropeller/pkg/controller/nodes/branch/handler.go index 9789b65c22..a869569680 100644 --- a/flytepropeller/pkg/controller/nodes/branch/handler.go +++ b/flytepropeller/pkg/controller/nodes/branch/handler.go @@ -3,6 +3,7 @@ package branch import ( "context" "fmt" + "strconv" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -15,6 +16,7 @@ import ( stdErrors "github.com/flyteorg/flyte/flytestdlib/errors" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" + "github.com/flyteorg/flyte/flytestdlib/storage" ) type metrics struct { @@ -74,8 +76,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha childNodeStatus.SetParentNodeID(&i) logger.Debugf(ctx, "Recursively executing branchNode's chosen path") - nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID()) - return b.recurseDownstream(ctx, nCtx, nodeStatus, finalNode) + return b.recurseDownstream(ctx, nCtx, finalNode) } // If the branchNodestatus was already evaluated i.e, Node is in Running status @@ -99,8 +100,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha } // Recurse downstream - nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID()) - return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode) + return b.recurseDownstream(ctx, nCtx, branchTakenNode) } func (b *branchHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { @@ -123,7 +123,7 @@ func (b *branchHandler) getExecutionContextForDownstream(nCtx interfaces.NodeExe return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil } -func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { +func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. @@ -134,8 +134,16 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.N } childNodeStatus := nl.GetNodeExecutionStatus(ctx, branchTakenNode.GetID()) - childNodeStatus.SetDataDir(nodeStatus.GetDataDir()) - childNodeStatus.SetOutputDir(nodeStatus.GetOutputDir()) + childDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), branchTakenNode.GetID()) + if err != nil { + return handler.UnknownTransition, err + } + childOutputDir, err := nCtx.DataStore().ConstructReference(ctx, childDataDir, strconv.Itoa(int(childNodeStatus.GetAttempts()))) + if err != nil { + return handler.UnknownTransition, err + } + childNodeStatus.SetDataDir(childDataDir) + childNodeStatus.SetOutputDir(childOutputDir) upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID()) if err != nil { return handler.UnknownTransition, err @@ -151,9 +159,14 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.N } if downstreamStatus.IsComplete() { - // For branch node we set the output node to be the same as the child nodes output + childOutputsPath := v1alpha1.GetOutputsFile(childOutputDir) + outputsPath := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + if err := nCtx.DataStore().CopyRaw(ctx, childOutputsPath, outputsPath, storage.Options{}); err != nil { + errMsg := fmt.Sprintf("Failed to copy child node outputs from [%v] to [%v]", childOutputsPath, outputsPath) + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.OutputsNotFoundError, errMsg, nil)), nil + } phase := handler.PhaseInfoSuccess(&handler.ExecutionInfo{ - OutputInfo: &handler.OutputInfo{OutputURI: v1alpha1.GetOutputsFile(childNodeStatus.GetOutputDir())}, + OutputInfo: &handler.OutputInfo{OutputURI: outputsPath}, }) return handler.DoTransition(handler.TransitionTypeEphemeral, phase), nil diff --git a/flytepropeller/pkg/controller/nodes/branch/handler_test.go b/flytepropeller/pkg/controller/nodes/branch/handler_test.go index a48344020d..96d20b1710 100644 --- a/flytepropeller/pkg/controller/nodes/branch/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/branch/handler_test.go @@ -1,6 +1,7 @@ package branch import ( + "bytes" "context" "fmt" "testing" @@ -110,6 +111,7 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod ns := &mocks2.ExecutableNodeStatus{} ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted) ir := &mocks3.InputReader{} @@ -162,7 +164,6 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { name string ns interfaces.NodeStatus err error - nodeStatus *mocks2.ExecutableNodeStatus branchTakenNode v1alpha1.ExecutableNode isErr bool expectedPhase handler.EPhase @@ -170,17 +171,17 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { upstreamNodeID string }{ {"upstreamNodeExists", interfaces.NodeStatusPending, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"}, + bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"}, {"childNodeError", interfaces.NodeStatusUndefined, fmt.Errorf("err"), - &mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""}, + bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""}, {"childPending", interfaces.NodeStatusPending, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""}, + bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""}, {"childStillRunning", interfaces.NodeStatusRunning, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""}, + bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""}, {"childFailure", interfaces.NodeStatusFailed(expectedError), nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""}, + bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""}, {"childComplete", interfaces.NodeStatusComplete, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""}, + bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -222,16 +223,16 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { ).Return(test.ns, test.err) childNodeStatus := &mocks2.ExecutableNodeStatus{} - if mockNodeLookup != nil { - childNodeStatus.OnGetOutputDir().Return("parent-output-dir") - test.nodeStatus.OnGetDataDir().Return("parent-data-dir") - test.nodeStatus.OnGetOutputDir().Return("parent-output-dir") - mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) - childNodeStatus.On("SetDataDir", storage.DataReference("parent-data-dir")).Once() - childNodeStatus.On("SetOutputDir", storage.DataReference("parent-output-dir")).Once() + childNodeStatus.OnGetAttempts().Return(0) + childNodeStatus.On("SetDataDir", storage.DataReference("/output-dir/child")).Once() + childNodeStatus.On("SetOutputDir", storage.DataReference("/output-dir/child/0")).Once() + mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) + if test.childPhase == v1alpha1.NodePhaseSucceeded { + _ = nCtx.DataStore().WriteRaw(ctx, storage.DataReference("/output-dir/child/0/outputs.pb"), 0, storage.Options{}, bytes.NewReader([]byte{})) } + branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()).(*branchHandler) - h, err := branch.recurseDownstream(ctx, nCtx, test.nodeStatus, test.branchTakenNode) + h, err := branch.recurseDownstream(ctx, nCtx, test.branchTakenNode) if test.isErr { assert.Error(t, err) } else {