Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Task Local Inject Setup behind new option #20

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions Sources/WhoopDIKit/Container/Container.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Foundation

public final class Container {
private let localDependencyGraph: ThreadSafeDependencyGraph
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be removed now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can remove it if we are certain we want task local to be the final way (then we can even remove the options in total or make it empty for now)

private var isLocalInjectActive: Bool = false
Expand Down Expand Up @@ -60,6 +61,19 @@ public final class Container {
public func inject<T>(_ name: String? = nil,
params: Any? = nil,
_ localDefinition: (DependencyModule) -> Void) -> T {
if options.isOptionEnabled(.taskLocalInject) {
let localModule = DependencyModule()
localDefinition(localModule)
return ServiceDictionaryTaskLocal.dictionary.withDependencyModuleUpdates(dependencyModule: localModule) {
do {
return try get(name, params)
} catch {
print("Inject failed with stack trace:")
Thread.callStackSymbols.forEach { print($0) }
fatalError("WhoopDI inject failed with error: \(error)")
}
}
}
return localDependencyGraph.acquireDependencyGraph { localServiceDict in
// Nested local injects are not currently supported. Fail fast here.
guard !isLocalInjectActive else {
Expand Down Expand Up @@ -113,8 +127,12 @@ public final class Container {
}

private func getDefinition(_ serviceKey: ServiceKey) -> DependencyDefinition? {
localDependencyGraph.acquireDependencyGraph { localServiceDict in
return localServiceDict[serviceKey] ?? serviceDict[serviceKey]
if options.isOptionEnabled(.taskLocalInject) {
ServiceDictionaryTaskLocal.dictionary.getDependencyModule()?[serviceKey] ?? serviceDict[serviceKey]
} else {
localDependencyGraph.acquireDependencyGraph { localServiceDict in
return localServiceDict[serviceKey] ?? serviceDict[serviceKey]
}
}
}

Expand Down
26 changes: 26 additions & 0 deletions Sources/WhoopDIKit/Container/ServiceDictionaryTaskLocal.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
enum ServiceDictionaryTaskLocal {
@TaskLocal
static var dictionary = ServiceDictionaryTaskLocalWrapper()
}

// This always returns copies and mutates copies, so there is no sendability worry here
struct ServiceDictionaryTaskLocalWrapper: @unchecked Sendable {
private let serviceDictionary: ServiceDictionary<DependencyDefinition>?

init(serviceDictionary: ServiceDictionary<DependencyDefinition>? = nil) {
self.serviceDictionary = serviceDictionary
}

func withDependencyModuleUpdates<T>(dependencyModule: DependencyModule, perform: () throws -> T) rethrows -> T {
let dictionaryCopy = serviceDictionary?.copy() ?? ServiceDictionary()
dependencyModule.addToServiceDictionary(serviceDict: dictionaryCopy)
return try ServiceDictionaryTaskLocal.$dictionary.withValue(ServiceDictionaryTaskLocalWrapper(serviceDictionary: dictionaryCopy)) {
return try perform()
}

}

func getDependencyModule() -> ServiceDictionary<DependencyDefinition>? {
return serviceDictionary?.copy()
}
}
1 change: 1 addition & 0 deletions Sources/WhoopDIKit/Options/WhoopDIOption.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/// Options for WhoopDI. These are typically experimental features which may be enabled or disabled.
public enum WhoopDIOption: Sendable {
case threadSafeLocalInject
case taskLocalInject
}
6 changes: 5 additions & 1 deletion Sources/WhoopDIKit/Service/ServiceDictionary.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ public final class ServiceDictionary<Value> {
valuesByType[key] = newValue
}
}


public func copy() -> Self {
return Self(valuesByType: self.valuesByType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we would want to create a new instance of the dictionary here which takes the values in, otherwise we have the same reference to a dict under the hood which could have issues with races I think (since this is a class not a struct)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are passing the dictionary into a new class, we should have a copy. Technically copy on write will make it the same instance, but it should be fine since those updates are atomic. What would be dangerous would be passing the ServiceDictionary itself, since that has a pointer to the actual same dictionary in multiple places and is definitely not sendable

}

public func allKeys() -> Set<ServiceKey> {
Set(valuesByType.keys)
}
Expand Down
14 changes: 8 additions & 6 deletions Tests/WhoopDIKitTests/Container/ContainerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,22 @@ class ContainerTests: @unchecked Sendable {
func inject_localDefinition_concurrency() async {
container.registerModules(modules: [GoodTestModule()])
// Run many times to try and capture race condition
for _ in 0..<500 {
let taskA = Task.detached {
let taskA = Task.detached {
for _ in 0..<500 {
let _: Dependency = self.container.inject("C_Factory") { module in
module.factory(name: "C_Factory") { DependencyA() as Dependency }
}
}
}

let taskB = Task.detached {
let taskB = Task.detached {
for _ in 0..<500 {
let _: DependencyA = self.container.inject()
}
}

for task in [taskA, taskB] {
let _ = await task.result
}
for task in [taskA, taskB] {
let _ = await task.result
}
}

Expand Down
120 changes: 120 additions & 0 deletions Tests/WhoopDIKitTests/Container/TaskLocalContainerTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import WhoopDIKit
import Testing

class TaskLocalContainerTests: @unchecked Sendable {
private let container: Container

init() {
let options = MockOptionProvider(options: [.taskLocalInject: true])
container = Container(options: options)
}

@Test
func inject() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory", "param")
#expect(dependency is DependencyC)
}

@Test
func inject_generic_integer() {
container.registerModules(modules: [GoodTestModule()])
let dependency: GenericDependency<Int> = container.inject()
#expect(42 == dependency.value)
}

@Test
func inject_generic_string() {
container.registerModules(modules: [GoodTestModule()])
let dependency: GenericDependency<String> = container.inject()
#expect("string" == dependency.value)
}

@Test
func inject_localDefinition() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory") { module in
// Typically you'd override or provide a transient dependency. I'm using the top level dependency here
// for the sake of simplicity.
module.factory(name: "C_Factory") { DependencyA() as Dependency }
}
#expect(dependency is DependencyA)
}

@Test
func inject_localDefinition_recursive() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory") { module in
// Typically you'd override or provide a transient dependency. I'm using the top level dependency here
// for the sake of simplicity.
module.factory(name: "C_Factory") { self.container.inject() as DependencyA as Dependency }
}
#expect(dependency is DependencyA)
}

@Test
func inject_localDefinition_inside_localDefinition() async throws {
let dependency: Dependency = container.inject { module in
module.factory {
DependencyB(self.container.inject { innerModule in
innerModule.factory { "test_inner_module" }
}) as Dependency
}
}
#expect(dependency is DependencyB)
}

@Test(.bug("https://github.com/WhoopInc/WhoopDI/issues/13"))
func inject_localDefinition_concurrency() async {
container.registerModules(modules: [GoodTestModule()])
// Run many times to try and capture race condition

let taskA = Task.detached {
for _ in 0..<500 {
let _: Dependency = self.container.inject("C_Factory") { module in
module.factory(name: "C_Factory") { DependencyA() as Dependency }
}
}
}

let taskB = Task.detached {
for _ in 0..<500 {
let _: DependencyA = self.container.inject()
}
}

for task in [taskA, taskB] {
let _ = await task.result
}
}

@Test
func inject_localDefinition_noOverride() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory", params: "params") { _ in }
#expect(dependency is DependencyC)
}

@Test
func inject_localDefinition_withParams() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory", params: "params") { module in
module.factoryWithParams(name: "C_Factory") { params in DependencyB(params) as Dependency }
}
#expect(dependency is DependencyB)
}

@Test
func injectableWithDependency() throws {
container.registerModules(modules: [FakeTestModuleForInjecting()])
let testInjecting: InjectableWithDependency = container.inject()
#expect(testInjecting == InjectableWithDependency(dependency: DependencyA()))
}

@Test
func injectableWithNamedDependency() throws {
container.registerModules(modules: [FakeTestModuleForInjecting()])
let testInjecting: InjectableWithNamedDependency = container.inject()
#expect(testInjecting == InjectableWithNamedDependency(name: 1))
}
}
Loading