Skip to content

Commit

Permalink
Use WithContext in all DB calls (#5538)
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Parraga <[email protected]>
  • Loading branch information
Sovietaced authored Jul 4, 2024
1 parent e13bfe3 commit b63ce0e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion flyteadmin/pkg/repositories/gormimpl/named_entity_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (r *NamedEntityRepo) List(ctx context.Context, input interfaces.ListNamedEn
"Cannot list entity names for resource type: %v", input.ResourceType)
}

tx := getSubQueryJoin(r.db, tableName, input)
tx := getSubQueryJoin(r.db.WithContext(ctx), tableName, input)

// Apply filters
tx, err := applyScopedFilters(tx, input.InlineFilters, input.MapFilters)
Expand Down
8 changes: 4 additions & 4 deletions flyteadmin/pkg/repositories/gormimpl/signal_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type SignalRepo struct {
func (s *SignalRepo) Get(ctx context.Context, input models.SignalKey) (models.Signal, error) {
var signal models.Signal
timer := s.metrics.GetDuration.Start()
tx := s.db.Where(&models.Signal{
tx := s.db.WithContext(ctx).Where(&models.Signal{
SignalKey: input,
}).Take(&signal)
timer.Stop()
Expand All @@ -41,7 +41,7 @@ func (s *SignalRepo) Get(ctx context.Context, input models.SignalKey) (models.Si
// GetOrCreate returns a signal if it already exists, if not it creates a new one given the input
func (s *SignalRepo) GetOrCreate(ctx context.Context, input *models.Signal) error {
timer := s.metrics.CreateDuration.Start()
tx := s.db.FirstOrCreate(&input, input)
tx := s.db.WithContext(ctx).FirstOrCreate(&input, input)
timer.Stop()
if tx.Error != nil {
return s.errorTransformer.ToFlyteAdminError(tx.Error)
Expand All @@ -56,7 +56,7 @@ func (s *SignalRepo) List(ctx context.Context, input interfaces.ListResourceInpu
return nil, err
}
var signals []models.Signal
tx := s.db.Limit(input.Limit).Offset(input.Offset)
tx := s.db.WithContext(ctx).Limit(input.Limit).Offset(input.Offset)

// Apply filters
tx, err := applyFilters(tx, input.InlineFilters, input.MapFilters)
Expand Down Expand Up @@ -85,7 +85,7 @@ func (s *SignalRepo) Update(ctx context.Context, input models.SignalKey, value [
}

timer := s.metrics.GetDuration.Start()
tx := s.db.Model(&signal).Select("value").Updates(signal)
tx := s.db.WithContext(ctx).Model(&signal).Select("value").Updates(signal)
timer.Stop()
if tx.Error != nil {
return s.errorTransformer.ToFlyteAdminError(tx.Error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (r *TaskExecutionRepo) Get(ctx context.Context, input interfaces.GetTaskExe

func (r *TaskExecutionRepo) Update(ctx context.Context, execution models.TaskExecution) error {
timer := r.metrics.UpdateDuration.Start()
tx := r.db.WithContext(ctx).WithContext(ctx).Save(&execution) // TODO @hmaersaw - need to add WithContext to all db calls to link otel spans
tx := r.db.WithContext(ctx).WithContext(ctx).Save(&execution)
timer.Stop()

if err := tx.Error; err != nil {
Expand Down

0 comments on commit b63ce0e

Please sign in to comment.