Skip to content

Commit

Permalink
enable foreign keys (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanner0101 authored Jan 23, 2020
1 parent 01e9cd2 commit f098081
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
17 changes: 11 additions & 6 deletions Sources/SQLiteKit/SQLiteConnectionSource.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ public struct SQLiteConnectionSource: ConnectionPoolSource {
threadPool: self.threadPool,
logger: logger,
on: eventLoop
)
).flatMap { conn in
if self.configuration.enableForeignKeys {
return conn.query("PRAGMA foreign_keys = ON")
.map { _ in conn }
} else {
return eventLoop.makeSucceededFuture(conn)
}
}
}
}

Expand All @@ -46,13 +53,11 @@ public struct SQLiteConfiguration {
}

public var storage: Storage

public init(file: String) {
self.init(storage: .file(path: file))
}
public var enableForeignKeys: Bool

public init(storage: Storage) {
public init(storage: Storage, enableForeignKeys: Bool = true) {
self.storage = storage
self.enableForeignKeys = enableForeignKeys
}
}

Expand Down
14 changes: 9 additions & 5 deletions Tests/SQLiteKitTests/SQLKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ class SQLiteTests: XCTestCase {
.run().wait()
}

func testForeignKeysEnabled() throws {
let res = try self.connection.query("PRAGMA foreign_keys").wait()
XCTAssertEqual(res[0].column("foreign_keys"), .integer(1))
}

var db: SQLDatabase {
self.connection.sql()
}
Expand All @@ -101,11 +106,10 @@ class SQLiteTests: XCTestCase {
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
self.threadPool = NIOThreadPool(numberOfThreads: 2)
self.threadPool.start()
self.connection = try! SQLiteConnection.open(
storage: .memory,
threadPool: self.threadPool,
on: self.eventLoopGroup.next()
).wait()
self.connection = try! SQLiteConnectionSource(
configuration: .init(storage: .memory, enableForeignKeys: true),
threadPool: self.threadPool
).makeConnection(logger: .init(label: "test"), on: self.eventLoopGroup.next()).wait()
}

override func tearDown() {
Expand Down

0 comments on commit f098081

Please sign in to comment.