Skip to content

Commit

Permalink
Recreate SecureVault database on corruption (#60)
Browse files Browse the repository at this point in the history
* Recreate DB file on corruption

* Safe database recreation with backup

* db recreation test

* fix linter issues

* VaultFactory error reporting
  • Loading branch information
mallexxx authored Feb 20, 2022
1 parent b86c166 commit 56d011b
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Sources/BrowserServicesKit/Resources/duckduckgo-autofill
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import Foundation
import GRDB
import os.log

protocol SecureVaultDatabaseProvider {

Expand Down Expand Up @@ -52,22 +53,32 @@ protocol SecureVaultDatabaseProvider {
final class DefaultDatabaseProvider: SecureVaultDatabaseProvider {

enum DbError: Error {
case nonRecoverable(DatabaseError)

case unableToDetermineStorageDirectory
case unableToGetDatabaseKey

var databaseError: DatabaseError {
switch self {
case .nonRecoverable(let dbError): return dbError
}
}
}

let db: DatabaseQueue

init(key: Data) throws {
init(file: URL = DefaultDatabaseProvider.dbFile(), key: Data) throws {
var config = Configuration()
config.prepareDatabase {
try $0.usePassphrase(key)
}

let file = try Self.dbFile()
db = try DatabaseQueue(path: file.path, configuration: config)
do {
db = try DatabaseQueue(path: file.path, configuration: config)
} catch let error as DatabaseError where [.SQLITE_NOTADB, .SQLITE_CORRUPT].contains(error.resultCode) {
os_log("database corrupt: %{public}s", type: .error, error.message ?? "")
throw DbError.nonRecoverable(error)
} catch {
os_log("database initialization failed with %{public}s", type: .error, error.localizedDescription)
throw error
}

var migrator = DatabaseMigrator()
migrator.registerMigration("v1", migrate: Self.migrateV1(database:))
Expand All @@ -80,11 +91,34 @@ final class DefaultDatabaseProvider: SecureVaultDatabaseProvider {
do {
try migrator.migrate(db)
} catch {
print(error)
os_log("database migration error: %{public}s", type: .error, error.localizedDescription)
throw error
}
}

static func recreateDatabase(withKey key: Data) throws -> DefaultDatabaseProvider {
let dbFile = self.dbFile()

guard FileManager.default.fileExists(atPath: dbFile.path) else {
return try Self(file: dbFile, key: key)
}

// make sure we can create an empty db first and release it then
let newDbFile = self.nonExistingDBFile(withExtension: dbFile.pathExtension)
try autoreleasepool {
try _=Self(file: newDbFile, key: key)
}

// backup old db file
let backupFile = self.nonExistingDBFile(withExtension: dbFile.pathExtension + ".bak")
try FileManager.default.moveItem(at: dbFile, to: backupFile)

// place just created new db in place of dbFile
try FileManager.default.moveItem(at: newDbFile, to: dbFile)

return try Self(file: dbFile, key: key)
}

func accounts() throws -> [SecureVaultModels.WebsiteAccount] {
return try db.read {
return try SecureVaultModels.WebsiteAccount
Expand Down Expand Up @@ -616,7 +650,7 @@ struct MigrationUtility {

extension DefaultDatabaseProvider {

static internal func dbFile() throws -> URL {
static internal func dbFile() -> URL {

let fm = FileManager.default

Expand Down Expand Up @@ -649,6 +683,24 @@ extension DefaultDatabaseProvider {
return subDir.appendingPathComponent("Vault.db")
}

static internal func nonExistingDBFile(withExtension ext: String) -> URL {
let originalPath = Self.dbFile().deletingPathExtension().path

for i in 0... {
var path = originalPath
if i > 0 {
path += "_\(i)"
}
path += "." + ext

if !FileManager.default.fileExists(atPath: path) {
return URL(fileURLWithPath: path)
}
}

fatalError()
}

}

// MARK: - Database records
Expand Down
50 changes: 50 additions & 0 deletions Sources/BrowserServicesKit/SecureVault/SecureVaultError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//

import Foundation
import GRDB

public enum SecureVaultError: Error {

Expand All @@ -34,3 +35,52 @@ public enum SecureVaultError: Error {
case generalCryptoError

}

extension SecureVaultError: CustomNSError {

public static var errorDomain: String { "SecureVaultError" }

public var errorCode: Int {
switch self {
case .initFailed: return 1
case .authRequired: return 2
case .invalidPassword: return 3
case .noL1Key: return 4
case .noL2Key: return 5
case .authError: return 6
case .failedToOpenDatabase: return 7
case .databaseError: return 8
case .duplicateRecord: return 9
case .keystoreError: return 10
case .secError: return 11
case .generalCryptoError: return 12
}
}

public var errorUserInfo: [String : Any] {
var errorUserInfo = [String : Any]()
switch self {
case .initFailed(cause: let error), .authError(cause: let error),
.failedToOpenDatabase(cause: let error), .databaseError(cause: let error):
if let secureVaultError = error as? SecureVaultError {
return secureVaultError.errorUserInfo
}

errorUserInfo["NSUnderlyingError"] = error as NSError
if let sqliteError = error as? DatabaseError ?? (error as? DefaultDatabaseProvider.DbError)?.databaseError {
errorUserInfo["SQLiteResultCode"] = NSNumber(value: sqliteError.resultCode.rawValue)
errorUserInfo["SQLiteExtendedResultCode"] = NSNumber(value: sqliteError.extendedResultCode.rawValue)
}

case .keystoreError(status: let code):
errorUserInfo["NSUnderlyingError"] = NSError(domain: "keystoreError", code: Int(code), userInfo: nil)
case .secError(status: let code):
errorUserInfo["NSUnderlyingError"] = NSError(domain: "secError", code: Int(code), userInfo: nil)

case .authRequired, .invalidPassword, .noL1Key, .noL2Key, .duplicateRecord, .generalCryptoError:
errorUserInfo["NSUnderlyingError"] = NSError(domain: "\(self)", code: 0, userInfo: nil)
}
return errorUserInfo
}

}
35 changes: 26 additions & 9 deletions Sources/BrowserServicesKit/SecureVault/SecureVaultFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

import Foundation

public protocol SecureVaultErrorReporting: AnyObject {
func secureVaultInitFailed(_ error: SecureVaultError)
}

/// Can make a SecureVault instance with given specification. May return previously created instance if specification is unchanged.
public class SecureVaultFactory {

Expand All @@ -41,7 +45,8 @@ public class SecureVaultFactory {
/// * Generates a secret key for L2 encryption
/// * Generates a user password to encrypt the L2 key with
/// * Stores encrypted L2 key in Keychain
public func makeVault(authExpiration: TimeInterval = 60 * 60 * 24 * 72) throws -> SecureVault {
public func makeVault(errorReporter: SecureVaultErrorReporting?,
authExpiration: TimeInterval = 60 * 60 * 24 * 72) throws -> SecureVault {

if let vault = self.vault, authExpiration == vault.authExpiry {
return vault
Expand All @@ -59,28 +64,40 @@ public class SecureVaultFactory {

return vault

} catch let error as SecureVaultError {
errorReporter?.secureVaultInitFailed(error)
throw error
} catch {
errorReporter?.secureVaultInitFailed(SecureVaultError.initFailed(cause: error))
throw SecureVaultError.initFailed(cause: error)
}
}

}

internal func makeSecureVaultProviders() throws -> SecureVaultProviders {
let (cryptoProvider, keystoreProvider): (SecureVaultCryptoProvider, SecureVaultKeyStoreProvider)
do {
(cryptoProvider, keystoreProvider) = try createAndInitializeEncryptionProviders()
} catch {
throw SecureVaultError.initFailed(cause: error)
}
guard let existingL1Key = try keystoreProvider.l1Key() else { throw SecureVaultError.noL1Key }

let databaseProvider: SecureVaultDatabaseProvider
do {
let (cryptoProvider, keystoreProvider) = try createAndInitializeEncryptionProviders()
let databaseProvider: SecureVaultDatabaseProvider

if let existingL1Key = try keystoreProvider.l1Key() {
do {
databaseProvider = try DefaultDatabaseProvider(key: existingL1Key)
} else {
throw SecureVaultError.noL1Key
} catch DefaultDatabaseProvider.DbError.nonRecoverable {
databaseProvider = try DefaultDatabaseProvider.recreateDatabase(withKey: existingL1Key)
}

return SecureVaultProviders(crypto: cryptoProvider, database: databaseProvider, keystore: keystoreProvider)

} catch {
throw SecureVaultError.initFailed(cause: error)
throw SecureVaultError.failedToOpenDatabase(cause: error)
}

return SecureVaultProviders(crypto: cryptoProvider, database: databaseProvider, keystore: keystoreProvider)
}

internal func makeCryptoProvider() -> SecureVaultCryptoProvider {
Expand Down
20 changes: 13 additions & 7 deletions Sources/BrowserServicesKit/SecureVault/SecureVaultManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public enum AutofillType {
case identity
}

public protocol SecureVaultManagerDelegate: AnyObject {
public protocol SecureVaultManagerDelegate: SecureVaultErrorReporting {

func secureVaultManager(_: SecureVaultManager,
promptUserToStoreCredentials credentials: SecureVaultModels.WebsiteCredentials)
Expand All @@ -53,7 +53,7 @@ extension SecureVaultManager: AutofillSecureVaultDelegate {
[SecureVaultModels.CreditCard]) -> Void) {

do {
let vault = try SecureVaultFactory.default.makeVault()
let vault = try SecureVaultFactory.default.makeVault(errorReporter: self.delegate)
let accounts = try vault.accountsFor(domain: domain)
let identities = try vault.identities()
let cards = try vault.creditCards()
Expand All @@ -74,7 +74,9 @@ extension SecureVaultManager: AutofillSecureVaultDelegate {

do {

if let account = try SecureVaultFactory.default.makeVault().accountsFor(domain: domain).first(where: { $0.username == username }) {
if let account = try SecureVaultFactory.default.makeVault(errorReporter: self.delegate)
.accountsFor(domain: domain)
.first(where: { $0.username == username }) {

let credentials = SecureVaultModels.WebsiteCredentials(account: account, password: passwordData)
delegate?.secureVaultManager(self, promptUserToStoreCredentials: credentials)
Expand All @@ -97,7 +99,8 @@ extension SecureVaultManager: AutofillSecureVaultDelegate {
completionHandler: @escaping ([SecureVaultModels.WebsiteAccount]) -> Void) {

do {
completionHandler(try SecureVaultFactory.default.makeVault().accountsFor(domain: domain))
completionHandler(try SecureVaultFactory.default.makeVault(errorReporter: self.delegate)
.accountsFor(domain: domain))
} catch {
os_log(.error, "Error requesting accounts: %{public}@", error.localizedDescription)
completionHandler([])
Expand All @@ -110,7 +113,8 @@ extension SecureVaultManager: AutofillSecureVaultDelegate {
completionHandler: @escaping (SecureVaultModels.WebsiteCredentials?) -> Void) {

do {
completionHandler(try SecureVaultFactory.default.makeVault().websiteCredentialsFor(accountId: accountId))
completionHandler(try SecureVaultFactory.default.makeVault(errorReporter: self.delegate)
.websiteCredentialsFor(accountId: accountId))
delegate?.secureVaultManager(self, didAutofill: .password, withObjectId: accountId)
} catch {
os_log(.error, "Error requesting credentials: %{public}@", error.localizedDescription)
Expand All @@ -123,7 +127,8 @@ extension SecureVaultManager: AutofillSecureVaultDelegate {
didRequestCreditCardWithId creditCardId: Int64,
completionHandler: @escaping (SecureVaultModels.CreditCard?) -> Void) {
do {
completionHandler(try SecureVaultFactory.default.makeVault().creditCardFor(id: creditCardId))
completionHandler(try SecureVaultFactory.default.makeVault(errorReporter: self.delegate)
.creditCardFor(id: creditCardId))
delegate?.secureVaultManager(self, didAutofill: .card, withObjectId: creditCardId)
} catch {
os_log(.error, "Error requesting credit card: %{public}@", error.localizedDescription)
Expand All @@ -135,7 +140,8 @@ extension SecureVaultManager: AutofillSecureVaultDelegate {
didRequestIdentityWithId identityId: Int64,
completionHandler: @escaping (SecureVaultModels.Identity?) -> Void) {
do {
completionHandler(try SecureVaultFactory.default.makeVault().identityFor(id: identityId))
completionHandler(try SecureVaultFactory.default.makeVault(errorReporter: self.delegate)
.identityFor(id: identityId))
delegate?.secureVaultManager(self, didAutofill: .identity, withObjectId: identityId)
} catch {
os_log(.error, "Error requesting identity: %{public}@", error.localizedDescription)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ final class SuggestionProcessing {
case .historyEntry:
// If there is a historyEntry and bookmark with the same URL, suggest the bookmark
newSuggestion = findBookmarkDuplicate(to: suggestion, nakedUrl: suggestionNakedUrl, from: suggestions)
case .bookmark(title: let title, url: let url, isFavorite: let isFavorite, allowedInTopHits: _):
case .bookmark(title: _, url: _, isFavorite: _, allowedInTopHits: _):
newSuggestion = findAndMergeHistoryDuplicate(with: suggestion, nakedUrl: suggestionNakedUrl, from: suggestions)
case .phrase, .website, .unknown:
break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ class DatabaseProviderTests: XCTestCase {

private func deleteDbFile() throws {
do {
try FileManager.default.removeItem(atPath: (try DefaultDatabaseProvider.dbFile()).path)
let dbFile = DefaultDatabaseProvider.dbFile()
let dbFileContainer = dbFile.deletingLastPathComponent()
for file in try FileManager.default.contentsOfDirectory(atPath: dbFileContainer.path) {
guard ["db", "bak"].contains((file as NSString).pathExtension) else { continue }
try FileManager.default.removeItem(atPath: dbFileContainer.appendingPathComponent(file).path)
}

} catch let error as NSError {
// File not found
if error.domain != NSCocoaErrorDomain || error.code != 4 {
Expand Down Expand Up @@ -179,6 +185,24 @@ class DatabaseProviderTests: XCTestCase {
XCTAssertTrue(results.isEmpty)
}

func test_when_database_is_corrupt_then_it_can_be_recreated_with_backup() throws {
do {
try! "asdf".data(using: .utf8)!.write(to: DefaultDatabaseProvider.dbFile())
try _=DefaultDatabaseProvider(key: simpleL1Key) as SecureVaultDatabaseProvider
XCTFail("should throw an error at this point")
} catch {
let database = try DefaultDatabaseProvider.recreateDatabase(withKey: simpleL1Key)
let backupURL = DefaultDatabaseProvider.dbFile().appendingPathExtension("bak")
XCTAssertEqual(try! Data(contentsOf: backupURL), "asdf".data(using: .utf8))

let account = SecureVaultModels.WebsiteAccount(username: "brindy", domain: "example.com")
let credentials = SecureVaultModels.WebsiteCredentials(account: account, password: "password".data(using: .utf8)!)
try database.storeWebsiteCredentials(credentials)

XCTAssertEqual(1, try database.accounts().count)
}
}

func test_when_credentials_are_deleted_then_they_are_removed_from_the_database() throws {
let database = try DefaultDatabaseProvider(key: simpleL1Key) as SecureVaultDatabaseProvider

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class VaultFactoryTests: XCTestCase {
func test() throws {
let testHarness = VaultFactoryTestHarness()
testHarness.mockKeystoreProvider._l1Key = "samplekey".data(using: .utf8)
_ = try testHarness.makeVault()
_ = try testHarness.makeVault(errorReporter: nil)
}

}

0 comments on commit 56d011b

Please sign in to comment.