Skip to content

Commit

Permalink
TUN-8748: Migrated datagram V3 flows to use migrated context
Browse files Browse the repository at this point in the history
Previously, during local flow migration the current connection context
was not part of the migration and would cause the flow to still be listening
on the connection context of the old connection (before the migration).
This meant that if a flow was migrated from connection 0 to
connection 1 and connection 0 then went away, the flow would be
terminated early, incorrectly following the context lifetime of connection 0.

The new connection context is provided during migration of a flow
and will trigger the observe loop for the flow lifetime to be rebound
to this provided context.
Closes TUN-8748
  • Loading branch information
DevinCarr committed Nov 21, 2024
1 parent c59d56c commit d779394
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 15 deletions.
2 changes: 1 addition & 1 deletion quic/v3/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func (c *datagramConn) handleSessionMigration(requestID RequestID, logger *zerol

// Migrate the session to use this edge connection instead of the currently running one.
// We also pass in this connection's logger to override the existing logger for the session.
session.Migrate(c, c.logger)
session.Migrate(c, c.conn.Context(), c.logger)

// Send another registration response since the session is already active
err = c.SendUDPSessionResponse(requestID, ResponseOk)
Expand Down
14 changes: 8 additions & 6 deletions quic/v3/muxer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,12 +619,14 @@ func newMockSession() mockSession {
}
}

func (m *mockSession) ID() v3.RequestID { return testRequestID }
func (m *mockSession) RemoteAddr() net.Addr { return testOriginAddr }
func (m *mockSession) LocalAddr() net.Addr { return testLocalAddr }
func (m *mockSession) ConnectionID() uint8 { return 0 }
func (m *mockSession) Migrate(conn v3.DatagramConn, log *zerolog.Logger) { m.migrated <- conn.ID() }
func (m *mockSession) ResetIdleTimer() {}
func (m *mockSession) ID() v3.RequestID { return testRequestID }
func (m *mockSession) RemoteAddr() net.Addr { return testOriginAddr }
func (m *mockSession) LocalAddr() net.Addr { return testLocalAddr }
func (m *mockSession) ConnectionID() uint8 { return 0 }
func (m *mockSession) Migrate(conn v3.DatagramConn, ctx context.Context, log *zerolog.Logger) {
m.migrated <- conn.ID()
}
func (m *mockSession) ResetIdleTimer() {}

func (m *mockSession) Serve(ctx context.Context) error {
close(m.served)
Expand Down
23 changes: 17 additions & 6 deletions quic/v3/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type Session interface {
RemoteAddr() net.Addr
LocalAddr() net.Addr
ResetIdleTimer()
Migrate(eyeball DatagramConn, logger *zerolog.Logger)
Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger)
// Serve starts the event loop for processing UDP packets
Serve(ctx context.Context) error
}
Expand All @@ -70,6 +70,7 @@ type session struct {
// activeAtChan is used to communicate the last read/write time
activeAtChan chan time.Time
closeChan chan error
contextChan chan context.Context
metrics Metrics
log *zerolog.Logger
}
Expand All @@ -96,8 +97,10 @@ func NewSession(
// drop instead of blocking because last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 1),
closeChan: make(chan error, 1),
metrics: metrics,
log: &logger,
// contextChan is an unbuffered channel to help enforce one active migration of a session at a time.
contextChan: make(chan context.Context),
metrics: metrics,
log: &logger,
}
session.eyeball.Store(&eyeball)
return session
Expand All @@ -120,11 +123,12 @@ func (s *session) ConnectionID() uint8 {
return eyeball.ID()
}

func (s *session) Migrate(eyeball DatagramConn, logger *zerolog.Logger) {
func (s *session) Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger) {
current := *(s.eyeball.Load())
// Only migrate if the connection ids are different.
if current.ID() != eyeball.ID() {
s.eyeball.Store(&eyeball)
s.contextChan <- ctx
log := logger.With().Str(logFlowID, s.id.String()).Logger()
s.log = &log
}
Expand Down Expand Up @@ -225,6 +229,7 @@ func (s *session) Close() error {
}

func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
connCtx := ctx
// Closing the session at the end cancels read so Serve() can return
defer s.Close()
if closeAfterIdle == 0 {
Expand All @@ -237,8 +242,14 @@ func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-connCtx.Done():
return connCtx.Err()
case newContext := <-s.contextChan:
// During migration of a session, we need to make sure that the context of the new connection is used instead
// of the old connection context. This will ensure that when the old connection goes away, this session will
// still be active on the new connection.
connCtx = newContext
continue
case reason := <-s.closeChan:
return reason
case <-checkIdleTimer.C:
Expand Down
80 changes: 78 additions & 2 deletions quic/v3/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,28 @@ func TestSessionServe_Migrate(t *testing.T) {
defer session.Close()

done := make(chan error)
eyeball1Ctx, cancel := context.WithCancelCause(context.Background())
go func() {
done <- session.Serve(context.Background())
done <- session.Serve(eyeball1Ctx)
}()

// Migrate the session to a new connection before origin sends data
eyeball2 := newMockEyeball()
eyeball2.connID = 1
session.Migrate(&eyeball2, &log)
eyeball2Ctx := context.Background()
session.Migrate(&eyeball2, eyeball2Ctx, &log)

// Cancel the original eyeball context; this should not cancel the session
contextCancelErr := errors.New("context canceled for first eyeball connection")
cancel(contextCancelErr)
select {
case <-done:
t.Fatalf("expected session to still be running")
default:
}
if context.Cause(eyeball1Ctx) != contextCancelErr {
t.Fatalf("first eyeball context should be cancelled manually: %+v", context.Cause(eyeball1Ctx))
}

// Origin sends data
payload2 := []byte{0xde}
Expand All @@ -166,6 +180,68 @@ func TestSessionServe_Migrate(t *testing.T) {
if !errors.Is(err, v3.SessionIdleErr{}) {
t.Error(err)
}
if eyeball2Ctx.Err() != nil {
t.Fatalf("second eyeball context should be not be cancelled")
}
}

// TestSessionServe_Migrate_CloseContext2 checks the context handling of a migrated session:
// after the session is migrated from the first eyeball connection to a second one,
// cancelling the first connection's context must leave Serve running, while cancelling the
// second connection's context is expected to make Serve return context.Canceled.
func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
	log := zerolog.Nop()
	eyeball := newMockEyeball()
	pipe1, pipe2 := net.Pipe()
	session := v3.NewSession(testRequestID, 2*time.Second, pipe2, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
	defer session.Close()

	// Serve runs under the first eyeball connection's context until the migration below.
	done := make(chan error)
	eyeball1Ctx, cancel := context.WithCancelCause(context.Background())
	go func() {
		done <- session.Serve(eyeball1Ctx)
	}()

	// Migrate the session to a new connection before origin sends data
	eyeball2 := newMockEyeball()
	eyeball2.connID = 1
	eyeball2Ctx, cancel2 := context.WithCancelCause(context.Background())
	session.Migrate(&eyeball2, eyeball2Ctx, &log)

	// Cancel the original eyeball context; this should not cancel the session
	contextCancelErr := errors.New("context canceled for first eyeball connection")
	cancel(contextCancelErr)
	// Non-blocking check: this only catches a Serve that has already returned by this point.
	select {
	case <-done:
		t.Fatalf("expected session to still be running")
	default:
	}
	if context.Cause(eyeball1Ctx) != contextCancelErr {
		t.Fatalf("first eyeball context should be cancelled manually: %+v", context.Cause(eyeball1Ctx))
	}

	// Origin sends data
	payload2 := []byte{0xde}
	if _, err := pipe1.Write(payload2); err != nil {
		t.Fatalf("origin write should succeed: %+v", err)
	}

	// Expect write to eyeball2
	// NOTE(review): the payload is compared from offset 17 onward; presumably the first 17
	// bytes are the datagram header — confirm against the datagram encoding.
	data := <-eyeball2.recvData
	if len(data) <= 17 || !slices.Equal(payload2, data[17:]) {
		t.Fatalf("expected data to write to eyeball2 after migration: %+v", data)
	}

	// Nothing should have been delivered to the first eyeball after the migration.
	select {
	case data := <-eyeball.recvData:
		t.Fatalf("expected no data to write to eyeball1 after migration: %+v", data)
	default:
	}

	// Close the connection2 context manually
	contextCancel2Err := errors.New("context canceled for second eyeball connection")
	cancel2(contextCancel2Err)
	err := <-done
	// errors.Is (rather than ==) still matches if the cancellation error is ever wrapped.
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("session Serve should be done: %+v", err)
	}
	if context.Cause(eyeball2Ctx) != contextCancel2Err {
		t.Fatalf("second eyeball context should have been cancelled manually: %+v", context.Cause(eyeball2Ctx))
	}
}

func TestSessionClose_Multiple(t *testing.T) {
Expand Down

0 comments on commit d779394

Please sign in to comment.