Skip to content

Commit

Permalink
Get NIOEmbedded clean under strict concurrency
Browse files Browse the repository at this point in the history
Motivation:

NIOEmbedded is used all over NIO-land for testing various pieces
of the infrastructure, and so requires a substantial
audit for strict concurrency.

Modifications:

- Mark a few things Sendable.
- Fix the tests, which actually did have some nasty bugs

Result:

Sendable-clean NIOEmbedded
  • Loading branch information
Lukasa committed Dec 16, 2024
1 parent 1e4fde1 commit dce7351
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 105 deletions.
6 changes: 4 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ let package = Package(
"_NIODataStructures",
swiftAtomics,
swiftCollections,
]
],
swiftSettings: strictConcurrencySettings
),
.target(
name: "NIOPosix",
Expand Down Expand Up @@ -429,7 +430,8 @@ let package = Package(
"NIOConcurrencyHelpers",
"NIOCore",
"NIOEmbedded",
]
],
swiftSettings: strictConcurrencySettings
),
.testTarget(
name: "NIOPosixTests",
Expand Down
53 changes: 29 additions & 24 deletions Sources/NIOEmbedded/AsyncTestingChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -199,24 +199,23 @@ public final class NIOAsyncTestingChannel: Channel {
/// `nil` because ``NIOAsyncTestingChannel``s don't have parents.
public let parent: Channel? = nil

// This is only written once, from a single thread, and never written again, so it's _technically_ thread-safe. Most methods cannot safely
// These two variables are only written once, from a single thread, and never written again, so they're _technically_ thread-safe. Most methods cannot safely
// be used from multiple threads, but `isActive`, `isOpen`, `eventLoop`, and `closeFuture` can all safely be used from any thread. Just.
@usableFromInline
var channelcore: EmbeddedChannelCore!
nonisolated(unsafe) var channelcore: EmbeddedChannelCore!

/// Guards any of the getters/setters that can be accessed from any thread.
private let stateLock: NIOLock = NIOLock()

// Guarded by `stateLock`
private var _isWritable: Bool = true

// Guarded by `stateLock`
private var _localAddress: SocketAddress? = nil
nonisolated(unsafe) private var _pipeline: ChannelPipeline!

// Guarded by `stateLock`
private var _remoteAddress: SocketAddress? = nil
private struct State {
var isWritable: Bool
var localAddress: SocketAddress?
var remoteAddress: SocketAddress?
}

private var _pipeline: ChannelPipeline!
/// Guards any of the getters/setters that can be accessed from any thread.
private let stateLock = NIOLockedValueBox(
State(isWritable: true, localAddress: nil, remoteAddress: nil)
)

/// - see: `Channel._channelCore`
public var _channelCore: ChannelCore {
Expand All @@ -231,35 +230,35 @@ public final class NIOAsyncTestingChannel: Channel {
/// - see: `Channel.isWritable`
public var isWritable: Bool {
get {
self.stateLock.withLock { self._isWritable }
self.stateLock.withLockedValue { $0.isWritable }
}
set {
self.stateLock.withLock { () -> Void in
self._isWritable = newValue
self.stateLock.withLockedValue {
$0.isWritable = newValue
}
}
}

/// - see: `Channel.localAddress`
public var localAddress: SocketAddress? {
get {
self.stateLock.withLock { self._localAddress }
self.stateLock.withLockedValue { $0.localAddress }
}
set {
self.stateLock.withLock { () -> Void in
self._localAddress = newValue
self.stateLock.withLockedValue {
$0.localAddress = newValue
}
}
}

/// - see: `Channel.remoteAddress`
public var remoteAddress: SocketAddress? {
get {
self.stateLock.withLock { self._remoteAddress }
self.stateLock.withLockedValue { $0.remoteAddress }
}
set {
self.stateLock.withLock { () -> Void in
self._remoteAddress = newValue
self.stateLock.withLockedValue {
$0.remoteAddress = newValue
}
}
}
Expand All @@ -283,7 +282,8 @@ public final class NIOAsyncTestingChannel: Channel {
/// - Parameters:
/// - handler: The `ChannelHandler` to add to the `ChannelPipeline` before register.
/// - loop: The ``NIOAsyncTestingEventLoop`` to use.
public convenience init(handler: ChannelHandler, loop: NIOAsyncTestingEventLoop = NIOAsyncTestingEventLoop()) async
@preconcurrency
public convenience init(handler: ChannelHandler & Sendable, loop: NIOAsyncTestingEventLoop = NIOAsyncTestingEventLoop()) async
{
await self.init(handlers: [handler], loop: loop)
}
Expand All @@ -295,8 +295,9 @@ public final class NIOAsyncTestingChannel: Channel {
/// - Parameters:
/// - handlers: The `ChannelHandler`s to add to the `ChannelPipeline` before register.
/// - loop: The ``NIOAsyncTestingEventLoop`` to use.
@preconcurrency
public convenience init(
handlers: [ChannelHandler],
handlers: [ChannelHandler & Sendable],
loop: NIOAsyncTestingEventLoop = NIOAsyncTestingEventLoop()
) async {
self.init(loop: loop)
Expand Down Expand Up @@ -671,3 +672,7 @@ extension NIOAsyncTestingChannel.LeftOverState: @unchecked Sendable {}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension NIOAsyncTestingChannel.BufferState: @unchecked Sendable {}
#endif

// Synchronous options are never Sendable.
@available(*, unavailable)
extension NIOAsyncTestingChannel.SynchronousOptions: Sendable { }
14 changes: 9 additions & 5 deletions Sources/NIOEmbedded/AsyncTestingEventLoop.swift
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
self.scheduledTasks.removeFirst { $0.id == taskID }
}

private func insertTask<ReturnType>(
private func insertTask<ReturnType: Sendable>(
taskID: UInt64,
deadline: NIODeadline,
promise: EventLoopPromise<ReturnType>,
Expand All @@ -152,7 +152,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {

/// - see: `EventLoop.scheduleTask(deadline:_:)`
@discardableResult
public func scheduleTask<T>(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled<T> {
@preconcurrency
public func scheduleTask<T: Sendable>(deadline: NIODeadline, _ task: @escaping @Sendable () throws -> T) -> Scheduled<T> {
let promise: EventLoopPromise<T> = self.makePromise()
let taskID = self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed)

Expand Down Expand Up @@ -190,7 +191,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {

/// - see: `EventLoop.scheduleTask(in:_:)`
@discardableResult
public func scheduleTask<T>(in: TimeAmount, _ task: @escaping () throws -> T) -> Scheduled<T> {
@preconcurrency
public func scheduleTask<T: Sendable>(in: TimeAmount, _ task: @escaping @Sendable () throws -> T) -> Scheduled<T> {
self.scheduleTask(deadline: self.now + `in`, task)
}

Expand Down Expand Up @@ -230,7 +232,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {

/// On an `NIOAsyncTestingEventLoop`, `execute` will simply use `scheduleTask` with a deadline of _now_. Unlike with the other operations, this will
/// immediately execute, to eliminate a common class of bugs.
public func execute(_ task: @escaping () -> Void) {
@preconcurrency
public func execute(_ task: @escaping @Sendable () -> Void) {
if self.inEventLoop {
self.scheduleTask(deadline: self.now, task)
} else {
Expand Down Expand Up @@ -359,7 +362,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
}

/// - see: `EventLoop.shutdownGracefully`
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
@preconcurrency
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping @Sendable (Error?) -> Void) {
self.queue.async {
self._shutdownGracefully()
queue.async {
Expand Down
29 changes: 26 additions & 3 deletions Sources/NIOEmbedded/Embedded.swift
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public final class EmbeddedEventLoop: EventLoop, CustomStringConvertible {
insertOrder: self.nextTaskNumber(),
task: {
do {
promise.succeed(try task())
promise.assumeIsolated().succeed(try task())
} catch let err {
promise.fail(err)
}
Expand Down Expand Up @@ -365,6 +365,11 @@ public final class EmbeddedEventLoop: EventLoop, CustomStringConvertible {
}()
}

// EmbeddedEventLoop is extremely _not_ Sendable. However, the EventLoop protocol
// requires it to be. We are doing some runtime enforcement of correct use, but
// ultimately we can't have the compiler validating this usage.
extension EmbeddedEventLoop: @unchecked Sendable { }

@usableFromInline
class EmbeddedChannelCore: ChannelCore {
var isOpen: Bool {
Expand Down Expand Up @@ -484,8 +489,11 @@ class EmbeddedChannelCore: ChannelCore {
self.pipeline.syncOperations.fireChannelInactive()
self.pipeline.syncOperations.fireChannelUnregistered()

let loopBoundSelf = NIOLoopBound(self, eventLoop: self.eventLoop)

eventLoop.execute {
// ensure this is executed in a delayed fashion as the users code may still traverse the pipeline
let `self` = loopBoundSelf.value
self.removeHandlers(pipeline: self.pipeline)
self.closePromise.succeed(())
}
Expand Down Expand Up @@ -583,6 +591,10 @@ class EmbeddedChannelCore: ChannelCore {
}
}

// ChannelCores are basically never Sendable.
@available(*, unavailable)
extension EmbeddedChannelCore: Sendable { }

/// `EmbeddedChannel` is a `Channel` implementation that does neither any
/// actual IO nor has a proper eventing mechanism. The prime use-case for
/// `EmbeddedChannel` is in unit tests when you want to feed the inbound events
Expand Down Expand Up @@ -867,8 +879,8 @@ public final class EmbeddedChannel: Channel {
@inlinable
@discardableResult public func writeInbound<T>(_ data: T) throws -> BufferState {
self.embeddedEventLoop.checkCorrectThread()
self.pipeline.fireChannelRead(data)
self.pipeline.fireChannelReadComplete()
self.pipeline.syncOperations.fireChannelRead(NIOAny(data))
self.pipeline.syncOperations.fireChannelReadComplete()
try self.throwIfErrorCaught()
return self.channelcore.inboundBuffer.isEmpty ? .empty : .full(Array(self.channelcore.inboundBuffer))
}
Expand Down Expand Up @@ -1086,5 +1098,16 @@ extension EmbeddedChannel {
}
}

// EmbeddedChannel is extremely _not_ Sendable. However, the Channel protocol
// requires it to be. We are doing some runtime enforcement of correct use, but
// ultimately we can't have the compiler validating this usage.
extension EmbeddedChannel: @unchecked Sendable { }

@available(*, unavailable)
extension EmbeddedChannel.LeftOverState: @unchecked Sendable {}

@available(*, unavailable)
extension EmbeddedChannel.BufferState: @unchecked Sendable {}

@available(*, unavailable)
extension EmbeddedChannel.SynchronousOptions: Sendable {}
2 changes: 1 addition & 1 deletion Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class AsyncTestingChannelTests: XCTestCase {
}

let channel = NIOAsyncTestingChannel()
XCTAssertThrowsError(try channel.pipeline.handler(type: Handler.self).wait()) { e in
XCTAssertThrowsError(try channel.pipeline.handler(type: Handler.self).map { _ in }.wait()) { e in
XCTAssertEqual(e as? ChannelPipelineError, .notFound)
}

Expand Down
53 changes: 28 additions & 25 deletions Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import XCTest

@testable import NIOEmbedded

private class EmbeddedTestError: Error {}
private final class EmbeddedTestError: Error {}

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
final class NIOAsyncTestingEventLoopTests: XCTestCase {
Expand Down Expand Up @@ -336,10 +336,12 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase {
// advanceTime(by:) is the same as on MultiThreadedEventLoopGroup: specifically, that tasks run via
// schedule that expire "now" all run at the same time, and that any work they schedule is run
// after all such tasks expire.
struct TestState {
var firstScheduled: Scheduled<Void>?
var secondScheduled: Scheduled<Void>?
}
let loop = NIOAsyncTestingEventLoop()
let lock = NIOLock()
var firstScheduled: Scheduled<Void>? = nil
var secondScheduled: Scheduled<Void>? = nil
let lock = NIOLockedValueBox(TestState())
let orderingCounter = ManagedAtomic(0)

// Here's the setup. First, we'll set up two scheduled tasks to fire in 5 nanoseconds. Each of these
Expand All @@ -356,13 +358,13 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase {
//
// To validate the ordering, we'll use a counter.

lock.withLock { () -> Void in
firstScheduled = loop.scheduleTask(in: .nanoseconds(5)) {
let second = lock.withLock { () -> Scheduled<Void>? in
XCTAssertNotNil(firstScheduled)
firstScheduled = nil
XCTAssertNotNil(secondScheduled)
return secondScheduled
lock.withLockedValue {
$0.firstScheduled = loop.scheduleTask(in: .nanoseconds(5)) {
let second = lock.withLockedValue {
XCTAssertNotNil($0.firstScheduled)
$0.firstScheduled = nil
XCTAssertNotNil($0.secondScheduled)
return $0.secondScheduled
}

if let partner = second {
Expand All @@ -379,11 +381,11 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase {
}
}

secondScheduled = loop.scheduleTask(in: .nanoseconds(5)) {
lock.withLock { () -> Void in
secondScheduled = nil
XCTAssertNil(firstScheduled)
XCTAssertNil(secondScheduled)
$0.secondScheduled = loop.scheduleTask(in: .nanoseconds(5)) {
lock.withLockedValue {
$0.secondScheduled = nil
XCTAssertNil($0.firstScheduled)
XCTAssertNil($0.secondScheduled)
}

XCTAssertCompareAndSwapSucceeds(storage: orderingCounter, expected: 2, desired: 3)
Expand Down Expand Up @@ -482,6 +484,7 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase {
let eventLoop = NIOAsyncTestingEventLoop()
let tasksRun = ManagedAtomic(0)

@Sendable
func scheduleRecursiveTask(
at taskStartTime: NIODeadline,
andChildTaskAfter childTaskStartDelay: TimeAmount
Expand Down Expand Up @@ -514,29 +517,29 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase {

func testShutdownCancelsRemainingScheduledTasks() async {
let eventLoop = NIOAsyncTestingEventLoop()
var tasksRun = 0
let tasksRun = ManagedAtomic(0)

let a = eventLoop.scheduleTask(in: .seconds(1)) { tasksRun += 1 }
let b = eventLoop.scheduleTask(in: .seconds(2)) { tasksRun += 1 }
let a = eventLoop.scheduleTask(in: .seconds(1)) { tasksRun.wrappingIncrement(ordering: .sequentiallyConsistent) }
let b = eventLoop.scheduleTask(in: .seconds(2)) { tasksRun.wrappingIncrement(ordering: .sequentiallyConsistent) }

XCTAssertEqual(tasksRun, 0)
XCTAssertEqual(tasksRun.load(ordering: .sequentiallyConsistent), 0)

await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun, 1)
XCTAssertEqual(tasksRun.load(ordering: .sequentiallyConsistent), 1)

XCTAssertNoThrow(try eventLoop.syncShutdownGracefully())
XCTAssertEqual(tasksRun, 1)
XCTAssertEqual(tasksRun.load(ordering: .sequentiallyConsistent), 1)

await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun, 1)
XCTAssertEqual(tasksRun.load(ordering: .sequentiallyConsistent), 1)

await eventLoop.advanceTime(to: .distantFuture)
XCTAssertEqual(tasksRun, 1)
XCTAssertEqual(tasksRun.load(ordering: .sequentiallyConsistent), 1)

XCTAssertNoThrow(try a.futureResult.wait())
await XCTAssertThrowsError(try await b.futureResult.get()) { error in
XCTAssertEqual(error as? EventLoopError, .cancelled)
XCTAssertEqual(tasksRun, 1)
XCTAssertEqual(tasksRun.load(ordering: .sequentiallyConsistent), 1)
}
}

Expand Down
Loading

0 comments on commit dce7351

Please sign in to comment.