Skip to content

Commit

Permalink
#minor Intratask checkpointing in FlytePropeller (flyteorg#360)
Browse files Browse the repository at this point in the history
* Intratask checkpointing in FlytePropeller

Signed-off-by: Ketan Umare <[email protected]>

* fixed tests

Signed-off-by: Ketan Umare <[email protected]>

* addressed comments

Signed-off-by: Ketan Umare <[email protected]>

* merged from master

Signed-off-by: Ketan Umare <[email protected]>

* updated config

Signed-off-by: Ketan Umare <[email protected]>

* plugins updated

Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored Nov 24, 2021
1 parent 4d3a602 commit dc5d314
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 33 deletions.
2 changes: 1 addition & 1 deletion flytepropeller/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1
github.com/fatih/color v1.10.0
github.com/flyteorg/flyteidl v0.21.4
github.com/flyteorg/flyteplugins v0.7.5
github.com/flyteorg/flyteplugins v0.8.0
github.com/flyteorg/flytestdlib v0.4.4
github.com/ghodss/yaml v1.0.0
github.com/go-redis/redis v6.15.7+incompatible
Expand Down
4 changes: 2 additions & 2 deletions flytepropeller/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGE
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/flyteorg/flyteidl v0.21.4 h1:gtJK5rX2ydLAo2xLRHHznOSLuLHrRRdXDbpEAlxluhk=
github.com/flyteorg/flyteidl v0.21.4/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flyteplugins v0.7.5 h1:mu9agOeSRKKdZDdV0OrJ9fZzrAaKhZLXt4sbqRYWPvg=
github.com/flyteorg/flyteplugins v0.7.5/go.mod h1:kOiuXk1ddIEVSPoHcc4kBfVQcLuyf8jw3vWJT2Was90=
github.com/flyteorg/flyteplugins v0.8.0 h1:Jiy7Ugm9olGmm5OFAbbxv/VfVmYib3JqGdeytyoiwnU=
github.com/flyteorg/flyteplugins v0.8.0/go.mod h1:kOiuXk1ddIEVSPoHcc4kBfVQcLuyf8jw3vWJT2Was90=
github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220=
github.com/flyteorg/flytestdlib v0.3.36/go.mod h1:7cDWkY3v7xsoesFcDdu6DSW5Q2U2W5KlHUbUHSwBG1Q=
github.com/flyteorg/flytestdlib v0.4.4 h1:oPADei4KEjxtUqkTwrIjXB1nuH+JEKjwmwF92DSO3NM=
Expand Down
3 changes: 2 additions & 1 deletion flytepropeller/pkg/controller/nodes/dynamic/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n
logger.Infof(ctx, "dynamic workflow node has succeeded, will call on success handler for parent node [%s]", nCtx.NodeID())
// These outputPaths only reads the output metadata. So the sandbox is completely optional here and hence it is nil.
// The sandbox creation as it uses hashing can be expensive and we skip that expense.
outputPaths := ioutils.NewRemoteFileOutputPaths(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetOutputDir(), nil)
outputPaths := ioutils.NewReadOnlyOutputFilePaths(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetOutputDir())
execID := task.GetTaskExecutionIdentifier(nCtx)
outputReader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, nCtx.MaxDatasetSizeBytes())
status, ee, err := d.TaskNodeHandler.ValidateOutputAndCacheAdd(ctx, nCtx.NodeID(), nCtx.InputReader(),
Expand All @@ -171,6 +171,7 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n
return trns, newState, nil
}

// Handle method is the entry point for handling a node, which may or maynot be a Dynamic node
// The State machine for a dynamic node is as follows
// DynamicNodePhaseNone: The parent node is being handled
// DynamicNodePhaseParentFinalizing: The parent node has completes successfully and sub-nodes exist (futures file found). Parent node is being finalized.
Expand Down
65 changes: 44 additions & 21 deletions flytepropeller/pkg/controller/nodes/task/taskexec_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,22 @@ import (
"context"
"strconv"

"github.com/flyteorg/flytepropeller/pkg/utils"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/encoding"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"

"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"

"github.com/flyteorg/flytepropeller/pkg/controller/nodes/common"

"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager"

"github.com/flyteorg/flytestdlib/logger"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"

pluginCatalog "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog"
pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/encoding"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"

"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/common"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager"
"github.com/flyteorg/flytepropeller/pkg/utils"
"github.com/flyteorg/flytestdlib/logger"
"github.com/flyteorg/flytestdlib/storage"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)

var (
Expand Down Expand Up @@ -193,6 +186,36 @@ func convertTaskResourcesToRequirements(taskResources v1alpha1.TaskResources) *v

}

// ComputeRawOutputPrefix constructs the output directory, where raw outputs of a task can be stored by the task. FlytePropeller may not have
// access to this location and can be passed in per execution.
// the function also returns the uniqueID generated
func ComputeRawOutputPrefix(ctx context.Context, length int, nCtx handler.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (io.RawOutputPaths, string, error) {
uniqueID, err := encoding.FixedLengthUniqueIDForParts(length, nCtx.NodeExecutionMetadata().GetOwnerID().Name, currentNodeUniqueID, strconv.Itoa(int(currentAttempt)))
if err != nil {
// SHOULD never really happen
return nil, uniqueID, err
}

rawOutputPrefix, err := ioutils.NewShardedRawOutputPath(ctx, nCtx.OutputShardSelector(), nCtx.RawOutputPrefix(), uniqueID, nCtx.DataStore())
if err != nil {
return nil, uniqueID, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "failed to create output sandbox for node execution")
}
return rawOutputPrefix, uniqueID, nil
}

// ComputePreviousCheckpointPath returns the checkpoint path for the previous attempt, if this is the first attempt then returns an empty path
func ComputePreviousCheckpointPath(ctx context.Context, length int, nCtx handler.NodeExecutionContext, currentNodeUniqueID v1alpha1.NodeID, currentAttempt uint32) (storage.DataReference, error) {
if currentAttempt == 0 {
return "", nil
}
prevAttempt := currentAttempt - 1
prevRawOutputPrefix, _, err := ComputeRawOutputPrefix(ctx, length, nCtx, currentNodeUniqueID, prevAttempt)
if err != nil {
return "", err
}
return ioutils.ConstructCheckpointPath(nCtx.DataStore(), prevRawOutputPrefix.GetRawOutputPrefix()), nil
}

func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx handler.NodeExecutionContext, plugin pluginCore.Plugin) (*taskExecutionContext, error) {
id := GetTaskExecutionIdentifier(nCtx)

Expand All @@ -210,17 +233,17 @@ func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx handler.Node
length = *l
}

uniqueID, err := encoding.FixedLengthUniqueIDForParts(length, nCtx.NodeExecutionMetadata().GetOwnerID().Name, currentNodeUniqueID, strconv.Itoa(int(id.RetryAttempt)))
rawOutputPrefix, uniqueID, err := ComputeRawOutputPrefix(ctx, length, nCtx, currentNodeUniqueID, id.RetryAttempt)
if err != nil {
// SHOULD never really happen
return nil, err
}

outputSandbox, err := ioutils.NewShardedRawOutputPath(ctx, nCtx.OutputShardSelector(), nCtx.RawOutputPrefix(), uniqueID, nCtx.DataStore())
prevCheckpointPath, err := ComputePreviousCheckpointPath(ctx, length, nCtx, currentNodeUniqueID, id.RetryAttempt)
if err != nil {
return nil, errors.Wrapf(errors.StorageError, nCtx.NodeID(), err, "failed to create output sandbox for node execution")
return nil, err
}
ow := ioutils.NewBufferedOutputWriter(ctx, ioutils.NewRemoteFileOutputPaths(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetOutputDir(), outputSandbox))

ow := ioutils.NewBufferedOutputWriter(ctx, ioutils.NewCheckpointRemoteFilePaths(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetOutputDir(), rawOutputPrefix, prevCheckpointPath))
ts := nCtx.NodeStateReader().GetTaskNodeState()
var b *bytes.Buffer
if ts.PluginState != nil {
Expand Down
86 changes: 82 additions & 4 deletions flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func TestHandler_newTaskExecutionContext(t *testing.T) {
tr.OnGetTaskID().Return(taskID)

ns := &flyteMocks.ExecutableNodeStatus{}
ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))
ns.OnGetDataDir().Return("data-dir")
ns.OnGetOutputDir().Return("output-dir")

res := &corev1.ResourceRequirements{
Requests: make(corev1.ResourceList),
Expand Down Expand Up @@ -180,13 +180,24 @@ func TestHandler_newTaskExecutionContext(t *testing.T) {
// assert.Equal(t, got.InputReader(), ir)

anotherPlugin := &pluginCoreMocks.Plugin{}
anotherPlugin.On("GetID").Return("plugin2")
anotherPlugin.OnGetID().Return("plugin2")
maxLength := 8
anotherPlugin.OnGetProperties().Return(pluginCore.PluginProperties{
GeneratedNameMaxLength: &maxLength,
})
anotherTaskExecCtx, _ := tk.newTaskExecutionContext(context.TODO(), nCtx, anotherPlugin)
anotherTaskExecCtx, err := tk.newTaskExecutionContext(context.TODO(), nCtx, anotherPlugin)
assert.NoError(t, err)
assert.Equal(t, anotherTaskExecCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), "fpmmhh6q")
assert.NotNil(t, anotherTaskExecCtx.ow)
assert.Equal(t, storage.DataReference("s3://sandbox/x/fpmmhh6q"), anotherTaskExecCtx.ow.GetRawOutputPrefix())
assert.Equal(t, storage.DataReference("s3://sandbox/x/fpmmhh6q/_flytecheckpoints"), anotherTaskExecCtx.ow.GetCheckpointPrefix())
assert.Equal(t, storage.DataReference("s3://sandbox/x/fpqmhlei/_flytecheckpoints"), anotherTaskExecCtx.ow.GetPreviousCheckpointsPrefix())
assert.NotNil(t, anotherTaskExecCtx.psm)
assert.NotNil(t, anotherTaskExecCtx.ber)
assert.NotNil(t, anotherTaskExecCtx.rm)
assert.NotNil(t, anotherTaskExecCtx.sm)
assert.NotNil(t, anotherTaskExecCtx.tm)
assert.NotNil(t, anotherTaskExecCtx.tr)
}

func TestAssignResource(t *testing.T) {
Expand Down Expand Up @@ -340,3 +351,70 @@ func TestConvertTaskResourcesToRequirements(t *testing.T) {
},
}, resourceRequirements)
}

func TestComputeRawOutputPrefix(t *testing.T) {

nCtx := &nodeMocks.NodeExecutionContext{}
nm := &nodeMocks.NodeExecutionMetadata{}
nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"})
nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"}))
nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
ds, err := storage.NewDataStore(
&storage.Config{
Type: storage.TypeMemory,
},
promutils.NewTestScope(),
)
assert.NoError(t, err)
nCtx.OnDataStore().Return(ds)
nCtx.OnNodeExecutionMetadata().Return(nm)

pre, uid, err := ComputeRawOutputPrefix(context.TODO(), 100, nCtx, "n1", 0)
assert.NoError(t, err)
assert.Equal(t, "name-n1-0", uid)
assert.NotNil(t, pre)
assert.Equal(t, storage.DataReference("s3://sandbox/x/name-n1-0"), pre.GetRawOutputPrefix())

pre, uid, err = ComputeRawOutputPrefix(context.TODO(), 8, nCtx, "n1", 0)
assert.NoError(t, err)
assert.Equal(t, "fpqmhlei", uid)
assert.NotNil(t, pre)
assert.Equal(t, storage.DataReference("s3://sandbox/x/fpqmhlei"), pre.GetRawOutputPrefix())

_, _, err = ComputeRawOutputPrefix(context.TODO(), 5, nCtx, "n1", 0)
assert.Error(t, err)
}

func TestComputePreviousCheckpointPath(t *testing.T) {
t.Run("attempt-0", func(t *testing.T) {
c, err := ComputePreviousCheckpointPath(context.TODO(), 100, nil, "n1", 0)
assert.NoError(t, err)
assert.Equal(t, storage.DataReference(""), c)
})

nCtx := &nodeMocks.NodeExecutionContext{}
nm := &nodeMocks.NodeExecutionMetadata{}
nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"})
nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"}))
nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
ds, err := storage.NewDataStore(
&storage.Config{
Type: storage.TypeMemory,
},
promutils.NewTestScope(),
)
assert.NoError(t, err)
nCtx.OnDataStore().Return(ds)
nCtx.OnNodeExecutionMetadata().Return(nm)
t.Run("attempt-0-nCtx", func(t *testing.T) {
c, err := ComputePreviousCheckpointPath(context.TODO(), 100, nCtx, "n1", 0)
assert.NoError(t, err)
assert.Equal(t, storage.DataReference(""), c)
})

t.Run("attempt-1-nCtx", func(t *testing.T) {
c, err := ComputePreviousCheckpointPath(context.TODO(), 100, nCtx, "n1", 1)
assert.NoError(t, err)
assert.Equal(t, storage.DataReference("s3://sandbox/x/name-n1-0/_flytecheckpoints"), c)
})
}
10 changes: 6 additions & 4 deletions flytepropeller/propeller-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ plugins:
- FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage
co-pilot:
name: "flyte-copilot-"
image: "ghcr.io/flyteorg/flytecopilot:v0.5.28"
image: "ghcr.io/flyteorg/flytecopilot:v0.0.15"
start-timeout: "5s"
sagemaker:
roleArn: "arn:aws:iam::123456789012:role/test-development"
Expand All @@ -81,15 +81,17 @@ plugins:
kubernetes-enabled: true
kubernetes-url: "http://localhost:30082"
storage:
type: minio
container: "my-s3-bucket"
connection:
access-key: minio
auth-type: accesskey
secret-key: miniostorage
disable-ssl: true
endpoint: http://localhost:30084
region: us-east-1
secret-key: miniostorage
type: minio
container: "my-s3-bucket"
limits:
maxDownloadMBs: 10
event:
type: admin
rate: 500
Expand Down

0 comments on commit dc5d314

Please sign in to comment.