diff --git a/Sources/WhoopDIKit/Container/Container.swift b/Sources/WhoopDIKit/Container/Container.swift index 10dfbcc..df12093 100644 --- a/Sources/WhoopDIKit/Container/Container.swift +++ b/Sources/WhoopDIKit/Container/Container.swift @@ -1,4 +1,5 @@ import Foundation + public final class Container { private let localDependencyGraph: ThreadSafeDependencyGraph private var isLocalInjectActive: Bool = false @@ -60,6 +61,19 @@ public final class Container { public func inject(_ name: String? = nil, params: Any? = nil, _ localDefinition: (DependencyModule) -> Void) -> T { + if options.isOptionEnabled(.taskLocalInject) { + let localModule = DependencyModule() + localDefinition(localModule) + return ServiceDictionaryTaskLocal.dictionary.withDependencyModuleUpdates(dependencyModule: localModule) { + do { + return try get(name, params) + } catch { + print("Inject failed with stack trace:") + Thread.callStackSymbols.forEach { print($0) } + fatalError("WhoopDI inject failed with error: \(error)") + } + } + } return localDependencyGraph.acquireDependencyGraph { localServiceDict in // Nested local injects are not currently supported. Fail fast here. guard !isLocalInjectActive else { @@ -113,8 +127,12 @@ public final class Container { } private func getDefinition(_ serviceKey: ServiceKey) -> DependencyDefinition? { - localDependencyGraph.acquireDependencyGraph { localServiceDict in - return localServiceDict[serviceKey] ?? serviceDict[serviceKey] + if options.isOptionEnabled(.taskLocalInject) { + ServiceDictionaryTaskLocal.dictionary.getDependencyModule()?[serviceKey] ?? serviceDict[serviceKey] + } else { + localDependencyGraph.acquireDependencyGraph { localServiceDict in + return localServiceDict[serviceKey] ?? serviceDict[serviceKey] + } } } diff --git a/Sources/WhoopDIKit/Container/ServiceDictionaryTaskLocal.swift b/Sources/WhoopDIKit/Container/ServiceDictionaryTaskLocal.swift new file mode 100644 index 0000000..c751175 --- /dev/null +++ b/Sources/WhoopDIKit/Container/ServiceDictionaryTaskLocal.swift @@ -0,0 +1,26 @@ +enum ServiceDictionaryTaskLocal { + @TaskLocal + static var dictionary = ServiceDictionaryTaskLocalWrapper() +} + +// This always returns copies and mutates copies, so there is no sendability worry here +struct ServiceDictionaryTaskLocalWrapper: @unchecked Sendable { + private let serviceDictionary: ServiceDictionary? + + init(serviceDictionary: ServiceDictionary? = nil) { + self.serviceDictionary = serviceDictionary + } + + func withDependencyModuleUpdates(dependencyModule: DependencyModule, perform: () throws -> T) rethrows -> T { + let dictionaryCopy = serviceDictionary?.copy() ?? ServiceDictionary() + dependencyModule.addToServiceDictionary(serviceDict: dictionaryCopy) + return try ServiceDictionaryTaskLocal.$dictionary.withValue(ServiceDictionaryTaskLocalWrapper(serviceDictionary: dictionaryCopy)) { + return try perform() + } + + } + + func getDependencyModule() -> ServiceDictionary? { + return serviceDictionary?.copy() + } +} diff --git a/Sources/WhoopDIKit/Options/WhoopDIOption.swift b/Sources/WhoopDIKit/Options/WhoopDIOption.swift index d453141..f704461 100644 --- a/Sources/WhoopDIKit/Options/WhoopDIOption.swift +++ b/Sources/WhoopDIKit/Options/WhoopDIOption.swift @@ -1,4 +1,5 @@ /// Options for WhoopDI. These are typically experimental features which may be enabled or disabled. public enum WhoopDIOption: Sendable { case threadSafeLocalInject + case taskLocalInject } diff --git a/Sources/WhoopDIKit/Service/ServiceDictionary.swift b/Sources/WhoopDIKit/Service/ServiceDictionary.swift index 6e7f864..b33cc30 100644 --- a/Sources/WhoopDIKit/Service/ServiceDictionary.swift +++ b/Sources/WhoopDIKit/Service/ServiceDictionary.swift @@ -26,7 +26,11 @@ public final class ServiceDictionary { valuesByType[key] = newValue } } - + + public func copy() -> Self { + return Self(valuesByType: self.valuesByType) + } + public func allKeys() -> Set { Set(valuesByType.keys) } diff --git a/Tests/WhoopDIKitTests/Container/ContainerTests.swift b/Tests/WhoopDIKitTests/Container/ContainerTests.swift index 0a8e2d9..13e11a0 100644 --- a/Tests/WhoopDIKitTests/Container/ContainerTests.swift +++ b/Tests/WhoopDIKitTests/Container/ContainerTests.swift @@ -46,20 +46,22 @@ class ContainerTests: @unchecked Sendable { func inject_localDefinition_concurrency() async { container.registerModules(modules: [GoodTestModule()]) // Run many times to try and capture race condition - for _ in 0..<500 { - let taskA = Task.detached { + let taskA = Task.detached { + for _ in 0..<500 { let _: Dependency = self.container.inject("C_Factory") { module in module.factory(name: "C_Factory") { DependencyA() as Dependency } } } + } - let taskB = Task.detached { + let taskB = Task.detached { + for _ in 0..<500 { let _: DependencyA = self.container.inject() } + } - for task in [taskA, taskB] { - let _ = await task.result - } + for task in [taskA, taskB] { + let _ = await task.result } } diff --git a/Tests/WhoopDIKitTests/Container/TaskLocalContainerTests.swift b/Tests/WhoopDIKitTests/Container/TaskLocalContainerTests.swift new file mode 100644 index 0000000..51a5542 --- /dev/null +++ b/Tests/WhoopDIKitTests/Container/TaskLocalContainerTests.swift @@ -0,0 +1,120 @@ +import WhoopDIKit +import Testing + +class TaskLocalContainerTests: @unchecked Sendable { + private let container: Container + + init() { + let options = MockOptionProvider(options: [.taskLocalInject: true]) + container = Container(options: options) + } + + @Test + func inject() { + container.registerModules(modules: [GoodTestModule()]) + let dependency: Dependency = container.inject("C_Factory", "param") + #expect(dependency is DependencyC) + } + + @Test + func inject_generic_integer() { + container.registerModules(modules: [GoodTestModule()]) + let dependency: GenericDependency = container.inject() + #expect(42 == dependency.value) + } + + @Test + func inject_generic_string() { + container.registerModules(modules: [GoodTestModule()]) + let dependency: GenericDependency = container.inject() + #expect("string" == dependency.value) + } + + @Test + func inject_localDefinition() { + container.registerModules(modules: [GoodTestModule()]) + let dependency: Dependency = container.inject("C_Factory") { module in + // Typically you'd override or provide a transient dependency. I'm using the top level dependency here + // for the sake of simplicity. + module.factory(name: "C_Factory") { DependencyA() as Dependency } + } + #expect(dependency is DependencyA) + } + + @Test + func inject_localDefinition_recursive() { + container.registerModules(modules: [GoodTestModule()]) + let dependency: Dependency = container.inject("C_Factory") { module in + // Typically you'd override or provide a transient dependency. I'm using the top level dependency here + // for the sake of simplicity. + module.factory(name: "C_Factory") { self.container.inject() as DependencyA as Dependency } + } + #expect(dependency is DependencyA) + } + + @Test + func inject_localDefinition_inside_localDefinition() async throws { + let dependency: Dependency = container.inject { module in + module.factory { + DependencyB(self.container.inject { innerModule in + innerModule.factory { "test_inner_module" } + }) as Dependency + } + } + #expect(dependency is DependencyB) + } + + @Test(.bug("https://github.com/WhoopInc/WhoopDI/issues/13")) + func inject_localDefinition_concurrency() async { + container.registerModules(modules: [GoodTestModule()]) + // Run many times to try and capture race condition + + let taskA = Task.detached { + for _ in 0..<500 { + let _: Dependency = self.container.inject("C_Factory") { module in + module.factory(name: "C_Factory") { DependencyA() as Dependency } + } + } + } + + let taskB = Task.detached { + for _ in 0..<500 { + let _: DependencyA = self.container.inject() + } + } + + for task in [taskA, taskB] { + let _ = await task.result + } + } + + @Test + func inject_localDefinition_noOverride() { + container.registerModules(modules: [GoodTestModule()]) + let dependency: Dependency = container.inject("C_Factory", params: "params") { _ in } + #expect(dependency is DependencyC) + } + + @Test + func inject_localDefinition_withParams() { + container.registerModules(modules: [GoodTestModule()]) + let dependency: Dependency = container.inject("C_Factory", params: "params") { module in + module.factoryWithParams(name: "C_Factory") { params in DependencyB(params) as Dependency } + } + #expect(dependency is DependencyB) + } + + @Test + func injectableWithDependency() throws { + container.registerModules(modules: [FakeTestModuleForInjecting()]) + let testInjecting: InjectableWithDependency = container.inject() + #expect(testInjecting == InjectableWithDependency(dependency: DependencyA())) + } + + @Test + func injectableWithNamedDependency() throws { + container.registerModules(modules: [FakeTestModuleForInjecting()]) + let testInjecting: InjectableWithNamedDependency = container.inject() + #expect(testInjecting == InjectableWithNamedDependency(name: 1)) + } +}