From 61838b47aa0eb97603842bbd875c2953d444b596 Mon Sep 17 00:00:00 2001 From: Alex Wu <66259759+popojk@users.noreply.github.com> Date: Sat, 28 Dec 2024 04:18:19 +0800 Subject: [PATCH] Flyteadmin digest comparison should rely on database semantics (#6058) * to TaskManager CreateTask in transactional way Signed-off-by: Alex Wu * amend TaskRepo Create method to create task before description to prevent TaskManager CreateTask method Task not found issue Signed-off-by: Alex Wu * add unit test Signed-off-by: Alex Wu * fix lint error Signed-off-by: Alex Wu --------- Signed-off-by: Alex Wu --- flyteadmin/pkg/manager/impl/task_manager.go | 38 +++++++----- .../pkg/manager/impl/task_manager_test.go | 61 +++++++++++++++++++ .../pkg/repositories/gormimpl/task_repo.go | 4 +- 3 files changed, 85 insertions(+), 18 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/task_manager.go b/flyteadmin/pkg/manager/impl/task_manager.go index 7d903e98fb..c5735fe406 100644 --- a/flyteadmin/pkg/manager/impl/task_manager.go +++ b/flyteadmin/pkg/manager/impl/task_manager.go @@ -88,19 +88,7 @@ func (t *TaskManager) CreateTask( logger.Errorf(ctx, "failed to compute task digest with err %v", err) return nil, err } - // See if a task exists and confirm whether it's an identical task or one that with a separate definition. 
- existingTaskModel, err := util.GetTaskModel(ctx, t.db, request.GetSpec().GetTemplate().GetId()) - if err == nil { - if bytes.Equal(taskDigest, existingTaskModel.Digest) { - return nil, errors.NewTaskExistsIdenticalStructureError(ctx, request) - } - existingTask, transformerErr := transformers.FromTaskModel(*existingTaskModel) - if transformerErr != nil { - logger.Errorf(ctx, "failed to transform task from task model") - return nil, transformerErr - } - return nil, errors.NewTaskExistsDifferentStructureError(ctx, request, existingTask.GetClosure().GetCompiledTask(), compiledTask) - } + // Create Task in DB taskModel, err := transformers.CreateTaskModel(finalizedRequest, &admin.TaskClosure{ CompiledTask: compiledTask, CreatedAt: createdAt, @@ -110,7 +98,6 @@ func (t *TaskManager) CreateTask( "Failed to transform task model [%+v] with err: %v", finalizedRequest, err) return nil, err } - descriptionModel, err := transformers.CreateDescriptionEntityModel(request.GetSpec().GetDescription(), request.GetId()) if err != nil { logger.Errorf(ctx, @@ -122,8 +109,27 @@ func (t *TaskManager) CreateTask( } err = t.db.TaskRepo().Create(ctx, taskModel, descriptionModel) if err != nil { - logger.Debugf(ctx, "Failed to create task model with id [%+v] with err %v", request.GetId(), err) - return nil, err + // See if an identical task already exists by checking the error code + flyteErr, ok := err.(errors.FlyteAdminError) + if !ok || flyteErr.Code() != codes.AlreadyExists { + logger.Errorf(ctx, "Failed to create task model with id [%+v] with err %v", request.GetId(), err) + return nil, err + } + // An identical task already exists. 
Fetch the existing task to verify if it has a different digest + existingTaskModel, fetchErr := util.GetTaskModel(ctx, t.db, request.GetSpec().GetTemplate().GetId()) + if fetchErr != nil { + logger.Errorf(ctx, "Failed to fetch existing task model for id [%+v] with err %v", request.GetId(), fetchErr) + return nil, fetchErr + } + if bytes.Equal(taskDigest, existingTaskModel.Digest) { + return nil, errors.NewTaskExistsIdenticalStructureError(ctx, request) + } + existingTask, transformerErr := transformers.FromTaskModel(*existingTaskModel) + if transformerErr != nil { + logger.Errorf(ctx, "Failed to transform task from task model for id [%+v]", request.GetId()) + return nil, transformerErr + } + return nil, errors.NewTaskExistsDifferentStructureError(ctx, request, existingTask.GetClosure().GetCompiledTask(), compiledTask) } t.metrics.ClosureSizeBytes.Observe(float64(len(taskModel.Closure))) if finalizedRequest.GetSpec().GetTemplate().GetMetadata() != nil { diff --git a/flyteadmin/pkg/manager/impl/task_manager_test.go b/flyteadmin/pkg/manager/impl/task_manager_test.go index 1301444ceb..50fc2c3234 100644 --- a/flyteadmin/pkg/manager/impl/task_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_manager_test.go @@ -99,6 +99,67 @@ func TestCreateTask(t *testing.T) { assert.NotNil(t, response) } +func TestCreateTask_DuplicateTaskRegistration(t *testing.T) { + mockRepository := getMockTaskRepository() + mockRepository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetGetCallback( + func(input interfaces.Identifier) (models.Task, error) { + return models.Task{ + TaskKey: models.TaskKey{ + Project: taskIdentifier.GetProject(), + Domain: taskIdentifier.GetDomain(), + Name: taskIdentifier.GetName(), + Version: taskIdentifier.GetVersion(), + }, + Digest: []byte{ + 0xbf, 0x79, 0x61, 0x1c, 0xf5, 0xc1, 0xfb, 0x4c, 0xf8, 0xf4, 0xc4, 0x53, 0x5f, 0x8f, 0x73, 0xe2, 0x26, 0x5a, + 0x18, 0x4a, 0xb7, 0x66, 0x98, 0x3c, 0xab, 0x2, 0x6c, 0x9, 0x9b, 0x90, 0xec, 0x8f}, + }, nil + }) + 
mockRepository.TaskRepo().(*repositoryMocks.MockTaskRepo).SetCreateCallback(func(input models.Task, descriptionEntity *models.DescriptionEntity) error { + return adminErrors.NewFlyteAdminErrorf(codes.AlreadyExists, "task already exists") + }) + taskManager := NewTaskManager(mockRepository, getMockConfigForTaskTest(), getMockTaskCompiler(), + mockScope.NewTestScope()) + request := testutils.GetValidTaskRequest() + _, err := taskManager.CreateTask(context.Background(), request) + assert.Error(t, err) + flyteErr, ok := err.(adminErrors.FlyteAdminError) + assert.True(t, ok, "Error should be of type FlyteAdminError") + assert.Equal(t, codes.AlreadyExists, flyteErr.Code(), "Error code should be AlreadyExists") + assert.Contains(t, flyteErr.Error(), "task with identical structure already exists") + differentTemplate := &core.TaskTemplate{ + Id: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", + }, + Type: "type", + Metadata: &core.TaskMetadata{ + Runtime: &core.RuntimeMetadata{ + Version: "runtime version 2", + }, + }, + Interface: &core.TypedInterface{}, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "image", + Command: []string{ + "command", + }, + }, + }, + } + request.Spec.Template = differentTemplate + _, err = taskManager.CreateTask(context.Background(), request) + assert.Error(t, err) + flyteErr, ok = err.(adminErrors.FlyteAdminError) + assert.True(t, ok, "Error should be of type FlyteAdminError") + assert.Equal(t, codes.InvalidArgument, flyteErr.Code(), "Error code should be InvalidArgument") + assert.Contains(t, flyteErr.Error(), "name task with different structure already exists.") +} + func TestCreateTask_ValidationError(t *testing.T) { mockRepository := getMockTaskRepository() taskManager := NewTaskManager(mockRepository, getMockConfigForTaskTest(), getMockTaskCompiler(), diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go 
b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 1b42756b7a..3f99172224 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -30,12 +30,12 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEnt } return nil } - tx := r.db.WithContext(ctx).Omit("id").Create(descriptionEntity) + tx := r.db.WithContext(ctx).Omit("id").Create(&input) if tx.Error != nil { return r.errorTransformer.ToFlyteAdminError(tx.Error) } - tx = r.db.WithContext(ctx).Omit("id").Create(&input) + tx = r.db.WithContext(ctx).Omit("id").Create(descriptionEntity) if tx.Error != nil { return r.errorTransformer.ToFlyteAdminError(tx.Error) }