Skip to content

Commit

Permalink
Use child dir for branch taken
Browse files Browse the repository at this point in the history
We reuse the same data directory for the branch node and the branch taken (child). This change adds an additional set of directories to decouple these. It makes sure to copy the child output up to the branch node output like we do in subworkflow

Example
```
from dataclasses import dataclass
from flytekit import conditional, task, workflow

@DataClass
class MyData:
    foo: str

@workflow
def root_wf(data: MyData) -> str:
    return sub_wf(data=data)

@workflow
def sub_wf(data: MyData) -> str:
    check = always_true()
    return conditional("decision").if_(check.is_true()).then(conditional_wf(data=data)).else_().fail("not done")

@task
def always_true() -> bool:
    return True

@workflow
def conditional_wf(data: MyData) -> str:
    return done(data)

@task
def done(data: MyData) -> str:
    return f"done ({data.foo})"
```

- [x] Add unittests
- [x] Run locally using sandbox and verify behavior with working example. Inspect paths
  • Loading branch information
andrewwdye committed Dec 20, 2024
1 parent 27b2f3a commit 2132483
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 24 deletions.
31 changes: 22 additions & 9 deletions flytepropeller/pkg/controller/nodes/branch/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package branch
import (
"context"
"fmt"
"strconv"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
Expand All @@ -15,6 +16,7 @@ import (
stdErrors "github.com/flyteorg/flyte/flytestdlib/errors"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/flyteorg/flyte/flytestdlib/storage"
)

type metrics struct {
Expand Down Expand Up @@ -74,8 +76,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha
childNodeStatus.SetParentNodeID(&i)

logger.Debugf(ctx, "Recursively executing branchNode's chosen path")
nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID())
return b.recurseDownstream(ctx, nCtx, nodeStatus, finalNode)
return b.recurseDownstream(ctx, nCtx, finalNode)
}

// If the branchNodestatus was already evaluated i.e, Node is in Running status
Expand All @@ -99,8 +100,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha
}

// Recurse downstream
nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID())
return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode)
return b.recurseDownstream(ctx, nCtx, branchTakenNode)
}

func (b *branchHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) {
Expand All @@ -123,7 +123,7 @@ func (b *branchHandler) getExecutionContextForDownstream(nCtx interfaces.NodeExe
return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil
}

func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) {
func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) {
// TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time
// There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc
// The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node.
Expand All @@ -134,8 +134,16 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.N
}

childNodeStatus := nl.GetNodeExecutionStatus(ctx, branchTakenNode.GetID())
childNodeStatus.SetDataDir(nodeStatus.GetDataDir())
childNodeStatus.SetOutputDir(nodeStatus.GetOutputDir())
childDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), branchTakenNode.GetID())
if err != nil {
return handler.UnknownTransition, err
}
childOutputDir, err := nCtx.DataStore().ConstructReference(ctx, childDataDir, strconv.Itoa(int(childNodeStatus.GetAttempts())))
if err != nil {
return handler.UnknownTransition, err
}
childNodeStatus.SetDataDir(childDataDir)
childNodeStatus.SetOutputDir(childOutputDir)
upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID())
if err != nil {
return handler.UnknownTransition, err
Expand All @@ -151,9 +159,14 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.N
}

if downstreamStatus.IsComplete() {
// For branch node we set the output node to be the same as the child nodes output
childOutputsPath := v1alpha1.GetOutputsFile(childOutputDir)
outputsPath := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir())
if err := nCtx.DataStore().CopyRaw(ctx, childOutputsPath, outputsPath, storage.Options{}); err != nil {
errMsg := fmt.Sprintf("Failed to copy child node outputs from [%v] to [%v]", childOutputsPath, outputsPath)
return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.OutputsNotFoundError, errMsg, nil)), nil
}
phase := handler.PhaseInfoSuccess(&handler.ExecutionInfo{
OutputInfo: &handler.OutputInfo{OutputURI: v1alpha1.GetOutputsFile(childNodeStatus.GetOutputDir())},
OutputInfo: &handler.OutputInfo{OutputURI: outputsPath},
})

return handler.DoTransition(handler.TransitionTypeEphemeral, phase), nil
Expand Down
31 changes: 16 additions & 15 deletions flytepropeller/pkg/controller/nodes/branch/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package branch

import (
"bytes"
"context"
"fmt"
"testing"
Expand Down Expand Up @@ -110,6 +111,7 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod

ns := &mocks2.ExecutableNodeStatus{}
ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))
ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted)

ir := &mocks3.InputReader{}
Expand Down Expand Up @@ -162,25 +164,24 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
name string
ns interfaces.NodeStatus
err error
nodeStatus *mocks2.ExecutableNodeStatus
branchTakenNode v1alpha1.ExecutableNode
isErr bool
expectedPhase handler.EPhase
childPhase v1alpha1.NodePhase
upstreamNodeID string
}{
{"upstreamNodeExists", interfaces.NodeStatusPending, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"},
{"childNodeError", interfaces.NodeStatusUndefined, fmt.Errorf("err"),
&mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""},
bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""},
{"childPending", interfaces.NodeStatusPending, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""},
{"childStillRunning", interfaces.NodeStatusRunning, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""},
bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""},
{"childFailure", interfaces.NodeStatusFailed(expectedError), nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""},
bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""},
{"childComplete", interfaces.NodeStatusComplete, nil,
&mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""},
bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand Down Expand Up @@ -222,16 +223,16 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) {
).Return(test.ns, test.err)

childNodeStatus := &mocks2.ExecutableNodeStatus{}
if mockNodeLookup != nil {
childNodeStatus.OnGetOutputDir().Return("parent-output-dir")
test.nodeStatus.OnGetDataDir().Return("parent-data-dir")
test.nodeStatus.OnGetOutputDir().Return("parent-output-dir")
mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus)
childNodeStatus.On("SetDataDir", storage.DataReference("parent-data-dir")).Once()
childNodeStatus.On("SetOutputDir", storage.DataReference("parent-output-dir")).Once()
childNodeStatus.OnGetAttempts().Return(0)
childNodeStatus.On("SetDataDir", storage.DataReference("/output-dir/child")).Once()
childNodeStatus.On("SetOutputDir", storage.DataReference("/output-dir/child/0")).Once()
mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus)
if test.childPhase == v1alpha1.NodePhaseSucceeded {
_ = nCtx.DataStore().WriteRaw(ctx, storage.DataReference("/output-dir/child/0/outputs.pb"), 0, storage.Options{}, bytes.NewReader([]byte{}))
}

branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()).(*branchHandler)
h, err := branch.recurseDownstream(ctx, nCtx, test.nodeStatus, test.branchTakenNode)
h, err := branch.recurseDownstream(ctx, nCtx, test.branchTakenNode)
if test.isErr {
assert.Error(t, err)
} else {
Expand Down

0 comments on commit 2132483

Please sign in to comment.