diff --git a/.swiftpm/BrowserServicesKit-Package.xctestplan b/.swiftpm/BrowserServicesKit-Package.xctestplan new file mode 100644 index 000000000..c5e26f99c --- /dev/null +++ b/.swiftpm/BrowserServicesKit-Package.xctestplan @@ -0,0 +1,196 @@ +{ + "configurations" : [ + { + "id" : "2EA622A1-B72B-456A-A84F-B3979C987FE3", + "name" : "Test Scheme Action", + "options" : { + + } + } + ], + "defaultOptions" : { + "targetForVariableExpansion" : { + "containerPath" : "container:", + "identifier" : "BookmarksTestDBBuilder", + "name" : "BookmarksTestDBBuilder" + } + }, + "testTargets" : [ + { + "target" : { + "containerPath" : "container:", + "identifier" : "BookmarksTests", + "name" : "BookmarksTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "BrowserServicesKitTests", + "name" : "BrowserServicesKitTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "CommonTests", + "name" : "CommonTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "ConfigurationTests", + "name" : "ConfigurationTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "CrashesTests", + "name" : "CrashesTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "DDGSyncCryptoTests", + "name" : "DDGSyncCryptoTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "DDGSyncTests", + "name" : "DDGSyncTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "HistoryTests", + "name" : "HistoryTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "NavigationTests", + "name" : "NavigationTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "NetworkProtectionTests", + "name" : "NetworkProtectionTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "NetworkingTests", + "name" : "NetworkingTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PersistenceTests", + "name" : "PersistenceTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PixelKitTests", + "name" : "PixelKitTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PrivacyDashboardTests", + "name" : "PrivacyDashboardTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "RemoteMessagingTests", + "name" : "RemoteMessagingTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SecureStorageTests", + "name" : "SecureStorageTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SubscriptionTests", + "name" : "SubscriptionTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SuggestionsTests", + "name" : "SuggestionsTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SyncDataProvidersTests", + "name" : "SyncDataProvidersTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "UserScriptTests", + "name" : "UserScriptTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "DuckPlayerTests", + "name" : "DuckPlayerTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "OnboardingTests", + "name" : "OnboardingTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SpecialErrorPagesTests", + "name" : "SpecialErrorPagesTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PageRefreshMonitorTests", + "name" : "PageRefreshMonitorTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "BrokenSitePromptTests", + "name" : "BrokenSitePromptTests" + } + } + ], + "version" : 1 +} diff --git a/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDETemplateMacros.plist b/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDETemplateMacros.plist index 6bebd560c..c4fc4eaa6 100644 --- a/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDETemplateMacros.plist +++ b/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDETemplateMacros.plist @@ -2,21 +2,20 @@ - FILEHEADER - + FILEHEADER + // ___FILENAME___ -// DuckDuckGo // // Copyright © ___YEAR___ DuckDuckGo. All rights reserved. // -// Licensed under the Apache License, Version 2.0 (the "License"); +// Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, +// distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/BookmarksTestDBBuilder.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/BookmarksTestDBBuilder.xcscheme index 903ba019f..f23bed1fa 100644 --- a/.swiftpm/xcode/xcshareddata/xcschemes/BookmarksTestDBBuilder.xcscheme +++ b/.swiftpm/xcode/xcshareddata/xcschemes/BookmarksTestDBBuilder.xcscheme @@ -1,6 +1,6 @@ + shouldUseLaunchSchemeArgsEnv = "YES"> + + + + + skipped = "NO" + parallelizable = "YES"> + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/NetworkingTests.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/NetworkingTests.xcscheme new file mode 100644 index 000000000..d5063487f --- /dev/null +++ b/.swiftpm/xcode/xcshareddata/xcschemes/NetworkingTests.xcscheme @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/SubscriptionTests.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/SubscriptionTests.xcscheme index 63c498679..e00698aec 100644 --- a/.swiftpm/xcode/xcshareddata/xcschemes/SubscriptionTests.xcscheme +++ b/.swiftpm/xcode/xcshareddata/xcschemes/SubscriptionTests.xcscheme @@ -1,6 +1,6 @@ (from input: Input) -> T? { + do { + let json = try JSONSerialization.data(withJSONObject: input) + return try JSONDecoder().decode(T.self, from: json) + } catch { + Logger.general.error("Error decoding input: \(error.localizedDescription, privacy: .public)") + return nil + } + } + + public static func decode(jsonData: Data) -> T? { + do { + return try JSONDecoder().decode(T.self, from: jsonData) + } catch { + Logger.general.error("Error decoding input: \(error.localizedDescription, privacy: .public)") + } + return nil + } + + public static func encode(_ object: T) -> Data? { + do { + let encoder = JSONEncoder() + encoder.outputFormatting = .prettyPrinted + return try encoder.encode(object) + } catch let error { + Logger.general.error("Error encoding input: \(error.localizedDescription, privacy: .public)") + } + return nil + } +} + +public typealias DecodableHelper = CodableHelper diff --git a/Sources/Common/Extensions/DateExtension.swift b/Sources/Common/Extensions/DateExtension.swift index 4ae7be7f6..c66842351 100644 --- a/Sources/Common/Extensions/DateExtension.swift +++ b/Sources/Common/Extensions/DateExtension.swift @@ -25,122 +25,190 @@ public extension Date { public let index: Int } + /// Extracts day, month, and year components from the date. var components: DateComponents { - return Calendar.current.dateComponents([.day, .year, .month], from: self) + Calendar.current.dateComponents([.day, .year, .month], from: self) } + /// Returns the date exactly one week ago. static var weekAgo: Date { - return Calendar.current.date(byAdding: .weekOfMonth, value: -1, to: Date())! + guard let date = Calendar.current.date(byAdding: .weekOfMonth, value: -1, to: Date()) else { + fatalError("Unable to calculate a week ago date.") + } + return date } - static var monthAgo: Date! { - return Calendar.current.date(byAdding: .month, value: -1, to: Date())! + /// Returns the date exactly one month ago. + static var monthAgo: Date { + guard let date = Calendar.current.date(byAdding: .month, value: -1, to: Date()) else { + fatalError("Unable to calculate a month ago date.") + } + return date } - static var yearAgo: Date! { - return Calendar.current.date(byAdding: .year, value: -1, to: Date())! + /// Returns the date exactly one year ago. + static var yearAgo: Date { + guard let date = Calendar.current.date(byAdding: .year, value: -1, to: Date()) else { + fatalError("Unable to calculate a year ago date.") + } + return date } - static var aYearFromNow: Date! { - return Calendar.current.date(byAdding: .year, value: 1, to: Date())! + /// Returns the date exactly one year from now. + static var aYearFromNow: Date { + guard let date = Calendar.current.date(byAdding: .year, value: 1, to: Date()) else { + fatalError("Unable to calculate a year from now date.") + } + return date } - static func daysAgo(_ days: Int) -> Date! { - return Calendar.current.date(byAdding: .day, value: -days, to: Date())! + /// Returns the date a specific number of days ago. + static func daysAgo(_ days: Int) -> Date { + guard let date = Calendar.current.date(byAdding: .day, value: -days, to: Date()) else { + fatalError("Unable to calculate \(days) days ago date.") + } + return date } + /// Checks if two dates fall on the same calendar day. static func isSameDay(_ date1: Date, _ date2: Date?) -> Bool { guard let date2 = date2 else { return false } return Calendar.current.isDate(date1, inSameDayAs: date2) } + /// Returns the start of tomorrow's day. static var startOfDayTomorrow: Date { let tomorrow = Calendar.current.date(byAdding: .day, value: 1, to: Date())! return Calendar.current.startOfDay(for: tomorrow) } + /// Returns the start of today's day. static var startOfDayToday: Date { - return Calendar.current.startOfDay(for: Date()) + Calendar.current.startOfDay(for: Date()) } + /// Returns the start of the day for this date instance. var startOfDay: Date { - return Calendar.current.startOfDay(for: self) + Calendar.current.startOfDay(for: self) } + /// Returns the date a specific number of days ago from this date instance. func daysAgo(_ days: Int) -> Date { - Calendar.current.date(byAdding: .day, value: -days, to: self)! + guard let date = Calendar.current.date(byAdding: .day, value: -days, to: self) else { + fatalError("Unable to calculate \(days) days ago date from this instance.") + } + return date } + /// Returns the start of the current minute. static var startOfMinuteNow: Date { - let date = Calendar.current.date(bySetting: .second, value: 0, of: Date())! - let start = Calendar.current.date(byAdding: .minute, value: -1, to: date)! + guard let date = Calendar.current.date(bySetting: .second, value: 0, of: Date()), + let start = Calendar.current.date(byAdding: .minute, value: -1, to: date) else { + fatalError("Unable to calculate the start of the current minute.") + } return start } + /// Provides a list of months with their names and indices. static var monthsWithIndex: [IndexedMonth] { - let months = Calendar.current.monthSymbols - - return months.enumerated().map { index, month in - return IndexedMonth(name: month, index: index + 1) + Calendar.current.monthSymbols.enumerated().map { index, month in + IndexedMonth(name: month, index: index + 1) } } - static var daysInMonth: [Int] = { - return Array(1...31) - }() - - static var nextTenYears: [Int] = { - let offsetComponents = DateComponents(year: 1) - - var years = [Int]() - var currentDate = Date() - - for _ in 0...10 { - let currentYear = Calendar.current.component(.year, from: currentDate) - years.append(currentYear) + /// Provides a list of days in a month (1 through 31). + static let daysInMonth = Array(1...31) - currentDate = Calendar.current.date(byAdding: offsetComponents, to: currentDate)! - } - - return years - }() - - static var lastHundredYears: [Int] = { - let offsetComponents = DateComponents(year: -1) - - var years = [Int]() - var currentDate = Date() - - for _ in 0...100 { - let currentYear = Calendar.current.component(.year, from: currentDate) - years.append(currentYear) - - currentDate = Calendar.current.date(byAdding: offsetComponents, to: currentDate)! - } + /// Provides a list of the next ten years including the current year. + static var nextTenYears: [Int] { + let currentYear = Calendar.current.component(.year, from: Date()) + return (0...10).map { currentYear + $0 } + } - return years - }() + /// Provides a list of the last hundred years including the current year. + static var lastHundredYears: [Int] { + let currentYear = Calendar.current.component(.year, from: Date()) + return (0...100).map { currentYear - $0 } + } + /// Returns the number of whole days since the reference date (January 1, 2001). var daySinceReferenceDate: Int { Int(self.timeIntervalSinceReferenceDate / TimeInterval.day) } - @inlinable + /// Adds a specific time interval to this date. func adding(_ timeInterval: TimeInterval) -> Date { addingTimeInterval(timeInterval) } + /// Checks if this date falls on the same calendar day as another date. func isSameDay(_ otherDate: Date?) -> Bool { guard let otherDate = otherDate else { return false } return Calendar.current.isDate(self, inSameDayAs: otherDate) } + /// Checks if this date is within a certain number of days ago. func isLessThan(daysAgo days: Int) -> Bool { - self > Date().addingTimeInterval(Double(-days) * 24 * 60 * 60) + self > Date().addingTimeInterval(Double(-days) * TimeInterval.day) } + /// Checks if this date is within a certain number of minutes ago. func isLessThan(minutesAgo minutes: Int) -> Bool { self > Date().addingTimeInterval(Double(-minutes) * 60) } + /// Returns a new date a specific number of seconds from now. + static func secondsFromNow(_ seconds: Int) -> Date { + Calendar.current.date(byAdding: .second, value: -seconds, to: Date())! + } + + /// Returns a new date a specific number of minutes from now. + static func minutesFromNow(_ minutes: Int) -> Date { + Calendar.current.date(byAdding: .minute, value: -minutes, to: Date())! + } + + /// Returns a new date a specific number of hours from now. + static func hoursFromNow(_ hours: Int) -> Date { + Calendar.current.date(byAdding: .hour, value: -hours, to: Date())! + } + + /// Returns a new date a specific number of days from now. + static func daysFromNow(_ days: Int) -> Date { + Calendar.current.date(byAdding: .day, value: -days, to: Date())! + } + + /// Returns a new date a specific number of months from now. + static func monthsFromNow(_ months: Int) -> Date { + Calendar.current.date(byAdding: .month, value: -months, to: Date())! + } + + /// Returns the number of seconds since this date until now. + func secondsSinceNow() -> Int { + Int(Date().timeIntervalSince(self)) + } + + /// Returns the number of minutes since this date until now. + func minutesSinceNow() -> Int { + secondsSinceNow() / 60 + } + + /// Returns the number of hours since this date until now. + func hoursSinceNow() -> Int { + minutesSinceNow() / 60 + } + + /// Returns the number of days since this date until now. + func daysSinceNow() -> Int { + hoursSinceNow() / 24 + } + + /// Returns the number of months since this date until now. + func monthsSinceNow() -> Int { + Calendar.current.dateComponents([.month], from: self, to: Date()).month ?? 0 + } + + /// Returns the number of years since this date until now. + func yearsSinceNow() -> Int { + Calendar.current.dateComponents([.year], from: self, to: Date()).year ?? 0 + } } diff --git a/Sources/NetworkProtection/Keychain/KeychainType.swift b/Sources/Common/KeychainType.swift similarity index 87% rename from Sources/NetworkProtection/Keychain/KeychainType.swift rename to Sources/Common/KeychainType.swift index 0890501e3..caade87ac 100644 --- a/Sources/NetworkProtection/Keychain/KeychainType.swift +++ b/Sources/Common/KeychainType.swift @@ -1,7 +1,7 @@ // // KeychainType.swift // -// Copyright © 2023 DuckDuckGo. All rights reserved. +// Copyright © 2024 DuckDuckGo. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,20 +19,18 @@ import Foundation /// A convenience enum to unify the logic for selecting the right keychain through the query attributes. -/// public enum KeychainType { case dataProtection(_ accessGroup: AccessGroup) - /// Uses the system keychain. - /// case system + case fileBased public enum AccessGroup { case unspecified case named(_ name: String) } - func queryAttributes() -> [CFString: Any] { + public func queryAttributes() -> [CFString: Any] { switch self { case .dataProtection(let accessGroup): switch accessGroup { @@ -46,6 +44,8 @@ public enum KeychainType { } case .system: return [kSecUseDataProtectionKeychain: false] + case .fileBased: + return [kSecUseDataProtectionKeychain: false] } } } diff --git a/Sources/Common/UserDefaultsCache.swift b/Sources/Common/UserDefaultsCache.swift index aba17b027..27561f9c2 100644 --- a/Sources/Common/UserDefaultsCache.swift +++ b/Sources/Common/UserDefaultsCache.swift @@ -44,7 +44,7 @@ public class UserDefaultsCache { let object: ObjectType } - let logger = { Logger(subsystem: Bundle.main.bundleIdentifier ?? "DuckDuckGo", category: "UserDefaultsCache") }() + let logger = { Logger(subsystem: "UserDefaultsCache", category: "") }() private var userDefaults: UserDefaults public private(set) var settings: UserDefaultsCacheSettings @@ -65,8 +65,9 @@ public class UserDefaultsCache { do { let data = try encoder.encode(cacheObject) userDefaults.set(data, forKey: key.rawValue) - logger.debug("Cache Set: \(String(describing: cacheObject))") + logger.debug("Cache Set: \(String(describing: cacheObject), privacy: .public)") } catch { + logger.fault("Failed to encode CacheObject: \(error, privacy: .public)") assertionFailure("Failed to encode CacheObject: \(error)") } } @@ -77,21 +78,21 @@ public class UserDefaultsCache { do { let cacheObject = try decoder.decode(CacheObject.self, from: data) if cacheObject.expires > Date() { - logger.debug("Cache Hit: \(ObjectType.self)") + logger.debug("Cache Hit: \(ObjectType.self, privacy: .public)") return cacheObject.object } else { - logger.debug("Cache Miss: \(ObjectType.self)") + logger.debug("Cache Miss: \(ObjectType.self, privacy: .public)") reset() // Clear expired data return nil } } catch let error { - logger.error("Cache Decode Error: \(error)") + logger.fault("Cache Decode Error: \(error, privacy: .public)") return nil } } public func reset() { - logger.debug("Cache Clean: \(ObjectType.self)") + logger.debug("Cache Clean: \(ObjectType.self, privacy: .public)") userDefaults.removeObject(forKey: key.rawValue) } } diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index 6bf0319f4..8c7fed556 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -29,6 +29,7 @@ extension APIClient { extension APIClient: APIClient.Mockable {} public protocol APIClientEnvironment { + func queryItems(for requestType: APIRequestType) -> QueryItems func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2 func url(for requestType: APIRequestType) -> URL func timeout(for requestType: APIRequestType) -> TimeInterval? @@ -67,20 +68,27 @@ public extension MaliciousSiteDetector { static let hashPrefix = "hashPrefix" } - public func url(for requestType: APIRequestType) -> URL { + public func queryItems(for requestType: APIRequestType) -> QueryItems { switch requestType { case .hashPrefixSet(let configuration): - endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([ - QueryParameter.category: configuration.threatKind.rawValue, - QueryParameter.revision: (configuration.revision ?? 0).description, - ]) + return [QueryParameter.category: configuration.threatKind.rawValue, + QueryParameter.revision: (configuration.revision ?? 0).description] case .filterSet(let configuration): - endpoint.appendingPathComponent(APIPath.filterSet).appendingParameters([ - QueryParameter.category: configuration.threatKind.rawValue, - QueryParameter.revision: (configuration.revision ?? 0).description, - ]) + return [QueryParameter.category: configuration.threatKind.rawValue, + QueryParameter.revision: (configuration.revision ?? 0).description] case .matches(let configuration): - endpoint.appendingPathComponent(APIPath.matches).appendingParameter(name: QueryParameter.hashPrefix, value: configuration.hashPrefix) + return [QueryParameter.hashPrefix: configuration.hashPrefix] + } + } + + public func url(for requestType: APIRequestType) -> URL { + switch requestType { + case .hashPrefixSet: + endpoint.appendingPathComponent(APIPath.hashPrefix) + case .filterSet: + endpoint.appendingPathComponent(APIPath.filterSet) + case .matches: + endpoint.appendingPathComponent(APIPath.matches) } } @@ -105,9 +113,9 @@ struct APIClient { let requestType = requestConfig.requestType let headers = environment.headers(for: requestType) let url = environment.url(for: requestType) + let queryItems = environment.queryItems(for: requestType) let timeout = environment.timeout(for: requestType) ?? requestConfig.defaultTimeout ?? 60 - - let apiRequest = APIRequestV2(url: url, method: .get, headers: headers, timeoutInterval: timeout) + let apiRequest = APIRequestV2(url: url, queryItems: queryItems, headers: headers, timeoutInterval: timeout)! let response = try await service.fetch(request: apiRequest) let result: R.Response = try response.decodeBody() diff --git a/Sources/NetworkProtection/Diagnostics/NetworkProtectionConnectionTester.swift b/Sources/NetworkProtection/Diagnostics/NetworkProtectionConnectionTester.swift index 382532893..a5637bd08 100644 --- a/Sources/NetworkProtection/Diagnostics/NetworkProtectionConnectionTester.swift +++ b/Sources/NetworkProtection/Diagnostics/NetworkProtectionConnectionTester.swift @@ -123,7 +123,7 @@ final class NetworkProtectionConnectionTester { } func stop() { - Logger.networkProtectionConnectionTester.log("🔴 Stopping connection tester") + Logger.networkProtectionConnectionTester.log("🟢 Stopping connection tester") stopScheduledTimer() isRunning = false } @@ -216,7 +216,7 @@ final class NetworkProtectionConnectionTester { Logger.networkProtectionConnectionTester.log("👎 VPN is DOWN") handleDisconnected() } else { - Logger.networkProtectionConnectionTester.log("👍 VPN: \(vpnIsConnected ? "UP" : "DOWN") local: \(localIsConnected ? "UP" : "DOWN")") + Logger.networkProtectionConnectionTester.log("👍 VPN: \(vpnIsConnected ? "UP" : "DOWN", privacy: .public) local: \(localIsConnected ? "UP" : "DOWN", privacy: .public)") handleConnected() } } diff --git a/Sources/NetworkProtection/Diagnostics/NetworkProtectionServerStatusMonitor.swift b/Sources/NetworkProtection/Diagnostics/NetworkProtectionServerStatusMonitor.swift index d9898b0ed..26872650f 100644 --- a/Sources/NetworkProtection/Diagnostics/NetworkProtectionServerStatusMonitor.swift +++ b/Sources/NetworkProtection/Diagnostics/NetworkProtectionServerStatusMonitor.swift @@ -21,6 +21,7 @@ import Network import Common import Combine import os.log +import Subscription public actor NetworkProtectionServerStatusMonitor { @@ -49,13 +50,14 @@ public actor NetworkProtectionServerStatusMonitor { } private let networkClient: NetworkProtectionClient - private let tokenStore: NetworkProtectionTokenStore + private let tokenProvider: any SubscriptionTokenProvider // MARK: - Init & deinit - init(networkClient: NetworkProtectionClient, tokenStore: NetworkProtectionTokenStore) { + init(networkClient: NetworkProtectionClient, + tokenProvider: any SubscriptionTokenProvider) { self.networkClient = networkClient - self.tokenStore = tokenStore + self.tokenProvider = tokenProvider Logger.networkProtectionMemory.debug("[+] \(String(describing: self), privacy: .public)") } @@ -99,11 +101,11 @@ public actor NetworkProtectionServerStatusMonitor { // MARK: - Server Status Check private func checkServerStatus(for serverName: String) async -> Result { - guard let accessToken = try? tokenStore.fetchToken() else { + guard let accessToken = try? await VPNAuthTokenBuilder.getVPNAuthToken(from: tokenProvider, policy: .local) else { + Logger.networkProtection.fault("Failed to check server status due to lack of access token") assertionFailure("Failed to check server status due to lack of access token") return .failure(.invalidAuthToken) } - return await networkClient.getServerStatus(authToken: accessToken, serverName: serverName) } diff --git a/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeyStore.swift b/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeyStore.swift index 0df61e950..9091961f5 100644 --- a/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeyStore.swift +++ b/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeyStore.swift @@ -218,6 +218,8 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore // MARK: - EventMapping private func handle(_ error: Error) { + Logger.networkProtectionKeyManagement.error("Failed to perform operation: \(error, privacy: .public)") + guard let error = error as? NetworkProtectionKeychainStoreError else { assertionFailure("Failed to cast Network Protection Keychain store error") errorEvents?.fire(NetworkProtectionError.unhandledError(function: #function, line: #line, error: error)) diff --git a/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeychainStore.swift b/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeychainStore.swift index e3abda105..68a6bddc6 100644 --- a/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeychainStore.swift +++ b/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeychainStore.swift @@ -20,7 +20,7 @@ import Foundation import Common import os.log -enum NetworkProtectionKeychainStoreError: Error, NetworkProtectionErrorConvertible { +public enum NetworkProtectionKeychainStoreError: Error, NetworkProtectionErrorConvertible { case failedToCastKeychainValueToData(field: String) case keychainReadError(field: String, status: Int32) case keychainWriteError(field: String, status: Int32) @@ -39,14 +39,14 @@ enum NetworkProtectionKeychainStoreError: Error, NetworkProtectionErrorConvertib } /// General Keychain access helper class for the NetworkProtection module. Should be used for specific KeychainStore types. -final class NetworkProtectionKeychainStore { +public final class NetworkProtectionKeychainStore { private let label: String private let serviceName: String private let keychainType: KeychainType - init(label: String, - serviceName: String, - keychainType: KeychainType) { + public init(label: String, + serviceName: String, + keychainType: KeychainType) { self.label = label self.serviceName = serviceName @@ -55,7 +55,8 @@ final class NetworkProtectionKeychainStore { // MARK: - Keychain Interaction - func readData(named name: String) throws -> Data? { + public func readData(named name: String) throws -> Data? { + Logger.networkProtectionKeyManagement.debug("Reading key \(name, privacy: .public) from keychain") var query = defaultAttributes() query[kSecAttrAccount] = name query[kSecReturnData] = true @@ -78,7 +79,8 @@ final class NetworkProtectionKeychainStore { } } - func writeData(_ data: Data, named name: String) throws { + public func writeData(_ data: Data, named name: String) throws { + Logger.networkProtectionKeyManagement.debug("Writing key \(name, privacy: .public) to keychain") var query = defaultAttributes() query[kSecAttrAccount] = name query[kSecAttrAccessible] = kSecAttrAccessibleAfterFirstUnlock @@ -101,18 +103,20 @@ final class NetworkProtectionKeychainStore { } private func updateData(_ data: Data, named name: String) -> OSStatus { + Logger.networkProtectionKeyManagement.debug("Updating key \(name, privacy: .public) in keychain") var query = defaultAttributes() query[kSecAttrAccount] = name let newAttributes = [ - kSecValueData: data, - kSecAttrAccessible: kSecAttrAccessibleAfterFirstUnlock + kSecValueData: data, + kSecAttrAccessible: kSecAttrAccessibleAfterFirstUnlock ] as [CFString: Any] return SecItemUpdate(query as CFDictionary, newAttributes as CFDictionary) } - func deleteAll() throws { + public func deleteAll() throws { + Logger.networkProtectionKeyManagement.debug("Deleting all keys from keychain") var query = defaultAttributes() #if os(macOS) // This line causes the delete to error with status -50 on iOS. Needs investigation but, for now, just delete the first item @@ -125,6 +129,7 @@ final class NetworkProtectionKeychainStore { case errSecItemNotFound, errSecSuccess: break default: + Logger.networkProtectionKeyManagement.error("🔴 Failed to delete all keys, SecItemDelete status \(String(describing: status), privacy: .public)") throw NetworkProtectionKeychainStoreError.keychainDeleteError(status: status) } } diff --git a/Sources/NetworkProtection/KeyManagement/NetworkProtectionTokenStore.swift b/Sources/NetworkProtection/KeyManagement/NetworkProtectionTokenStore.swift deleted file mode 100644 index 4510dc1b6..000000000 --- a/Sources/NetworkProtection/KeyManagement/NetworkProtectionTokenStore.swift +++ /dev/null @@ -1,151 +0,0 @@ -// -// NetworkProtectionTokenStore.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common - -public protocol NetworkProtectionTokenStore { - /// Store an auth token. - /// - @available(iOS, deprecated, message: "[NetP Subscription] Use subscription access token instead") - func store(_ token: String) throws - - /// Obtain the current auth token. - /// - func fetchToken() throws -> String? - - /// Delete the stored auth token. - /// - @available(iOS, deprecated, message: "[NetP Subscription] Use subscription access token instead") - func deleteToken() throws -} - -#if os(macOS) - -/// Store an auth token for NetworkProtection on behalf of the user. This key is then used to authenticate requests for registration and server fetches from the Network Protection backend servers. -/// Writing a new auth token will replace the old one. -public final class NetworkProtectionKeychainTokenStore: NetworkProtectionTokenStore { - private let keychainStore: NetworkProtectionKeychainStore - private let errorEvents: EventMapping? - private let useAccessTokenProvider: Bool - public typealias AccessTokenProvider = () -> String? - private let accessTokenProvider: AccessTokenProvider - - public static var authTokenPrefix: String { "ddg:" } - - public struct Defaults { - static let tokenStoreEntryLabel = "DuckDuckGo Network Protection Auth Token" - public static let tokenStoreService = "com.duckduckgo.networkprotection.authToken" - static let tokenStoreName = "com.duckduckgo.networkprotection.token" - } - - /// - isSubscriptionEnabled: Controls whether the subscription access token is used to authenticate with the NetP backend - /// - accessTokenProvider: Defines how to actually retrieve the subscription access token - public init(keychainType: KeychainType, - serviceName: String = Defaults.tokenStoreService, - errorEvents: EventMapping?, - useAccessTokenProvider: Bool, - accessTokenProvider: @escaping AccessTokenProvider) { - keychainStore = NetworkProtectionKeychainStore(label: Defaults.tokenStoreEntryLabel, - serviceName: serviceName, - keychainType: keychainType) - self.errorEvents = errorEvents - self.useAccessTokenProvider = useAccessTokenProvider - self.accessTokenProvider = accessTokenProvider - } - - public func store(_ token: String) throws { - let data = token.data(using: .utf8)! - do { - try keychainStore.writeData(data, named: Defaults.tokenStoreName) - } catch { - handle(error) - throw error - } - } - - private func makeToken(from subscriptionAccessToken: String) -> String { - Self.authTokenPrefix + subscriptionAccessToken - } - - public func fetchToken() throws -> String? { - if useAccessTokenProvider { - return accessTokenProvider().map { makeToken(from: $0) } - } - - do { - return try keychainStore.readData(named: Defaults.tokenStoreName).flatMap { - String(data: $0, encoding: .utf8) - } - } catch { - handle(error) - throw error - } - } - - public func deleteToken() throws { - do { - try keychainStore.deleteAll() - } catch { - handle(error) - throw error - } - } - - // MARK: - EventMapping - - private func handle(_ error: Error) { - guard let error = error as? NetworkProtectionKeychainStoreError else { - assertionFailure("Failed to cast Network Protection Token store error") - errorEvents?.fire(NetworkProtectionError.unhandledError(function: #function, line: #line, error: error)) - return - } - - errorEvents?.fire(error.networkProtectionError) - } -} - -#else - -public final class NetworkProtectionKeychainTokenStore: NetworkProtectionTokenStore { - private let accessTokenProvider: () -> String? - - public static var authTokenPrefix: String { "ddg:" } - - public init(accessTokenProvider: @escaping () -> String?) { - self.accessTokenProvider = accessTokenProvider - } - - public func store(_ token: String) throws { - assertionFailure("Unsupported operation") - } - - public func fetchToken() throws -> String? { - accessTokenProvider().map { makeToken(from: $0) } - } - - public func deleteToken() throws { - assertionFailure("Unsupported operation") - } - - private func makeToken(from subscriptionAccessToken: String) -> String { - Self.authTokenPrefix + subscriptionAccessToken - } -} - -#endif diff --git a/Sources/NetworkProtection/NetworkProtectionDeviceManager.swift b/Sources/NetworkProtection/NetworkProtectionDeviceManager.swift index a599e150a..36ff5b4f1 100644 --- a/Sources/NetworkProtection/NetworkProtectionDeviceManager.swift +++ b/Sources/NetworkProtection/NetworkProtectionDeviceManager.swift @@ -20,6 +20,7 @@ import Foundation import Common import NetworkExtension import os.log +import Subscription public enum NetworkProtectionServerSelectionMethod: CustomDebugStringConvertible { public var debugDescription: String { @@ -73,27 +74,27 @@ public protocol NetworkProtectionDeviceManagement { public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement { private let networkClient: NetworkProtectionClient - private let tokenStore: NetworkProtectionTokenStore + private let tokenProvider: any SubscriptionTokenProvider private let keyStore: NetworkProtectionKeyStore private let errorEvents: EventMapping? public init(environment: VPNSettings.SelectedEnvironment, - tokenStore: NetworkProtectionTokenStore, + tokenProvider: any SubscriptionTokenProvider, keyStore: NetworkProtectionKeyStore, errorEvents: EventMapping?) { self.init(networkClient: NetworkProtectionBackendClient(environment: environment), - tokenStore: tokenStore, + tokenProvider: tokenProvider, keyStore: keyStore, errorEvents: errorEvents) } init(networkClient: NetworkProtectionClient, - tokenStore: NetworkProtectionTokenStore, + tokenProvider: any SubscriptionTokenProvider, keyStore: NetworkProtectionKeyStore, errorEvents: EventMapping?) { self.networkClient = networkClient - self.tokenStore = tokenStore + self.tokenProvider = tokenProvider self.keyStore = keyStore self.errorEvents = errorEvents } @@ -102,9 +103,7 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement { /// This method will return the remote server list if available, or the local server list if there was a problem with the service call. /// public func refreshServerList() async throws -> [NetworkProtectionServer] { - guard let token = try? tokenStore.fetchToken() else { - throw NetworkProtectionError.noAuthTokenFound - } + let token = try await VPNAuthTokenBuilder.getVPNAuthToken(from: tokenProvider, policy: .localValid) let result = await networkClient.getServers(authToken: token) let completeServerList: [NetworkProtectionServer] @@ -189,7 +188,7 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement { selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer, newExpiration: Date?) { - guard let token = try? tokenStore.fetchToken() else { throw NetworkProtectionError.noAuthTokenFound } + let token = try await VPNAuthTokenBuilder.getVPNAuthToken(from: tokenProvider, policy: .localValid) let serverSelection: RegisterServerSelection let excludedServerName: String? @@ -313,11 +312,11 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement { } private func handle(clientError: NetworkProtectionClientError) { -#if os(macOS) + #if os(macOS) if case .invalidAuthToken = clientError { - try? tokenStore.deleteToken() + tokenProvider.removeTokenContainer() } -#endif + #endif errorEvents?.fire(clientError.networkProtectionError) } diff --git a/Sources/NetworkProtection/NetworkProtectionOptionKey.swift b/Sources/NetworkProtection/NetworkProtectionOptionKey.swift index 660b6368f..9f837ca45 100644 --- a/Sources/NetworkProtection/NetworkProtectionOptionKey.swift +++ b/Sources/NetworkProtection/NetworkProtectionOptionKey.swift @@ -25,7 +25,7 @@ public enum NetworkProtectionOptionKey { public static let selectedLocation = "selectedLocation" public static let dnsSettings = "dnsSettings" public static let excludeLocalNetworks = "excludeLocalNetworks" - public static let authToken = "authToken" + public static let tokenContainer = "tokenContainer" public static let isOnDemand = "is-on-demand" public static let activationAttemptId = "activationAttemptId" public static let tunnelFailureSimulation = "tunnelFailureSimulation" diff --git a/Sources/NetworkProtection/Networking/NetworkProtectionClient.swift b/Sources/NetworkProtection/Networking/NetworkProtectionClient.swift index 2e59d67e9..5116258ee 100644 --- a/Sources/NetworkProtection/Networking/NetworkProtectionClient.swift +++ b/Sources/NetworkProtection/Networking/NetworkProtectionClient.swift @@ -89,6 +89,26 @@ public enum NetworkProtectionClientError: CustomNSError, NetworkProtectionErrorC return [:] } } + + public var errorDescription: String? { + switch self { + case .failedToFetchLocationList: return "Failed to fetch location list" + case .failedToParseLocationListResponse: return "Failed to parse location list response" + case .failedToFetchServerList: return "Failed to fetch server list" + case .failedToParseServerListResponse: return "Failed to parse server list response" + case .failedToEncodeRegisterKeyRequest: return "Failed to encode register key request" + case .failedToFetchServerStatus(let error): + return "Failed to fetch server status: \(error)" + case .failedToParseServerStatusResponse(let error): + return "Failed to parse server status response: \(error)" + case .failedToFetchRegisteredServers(let error): + return "Failed to fetch registered servers: \(error)" + case .failedToParseRegisteredServersResponse(let error): + return "Failed to parse registered servers response: \(error)" + case .invalidAuthToken: return "Invalid auth token" + case .accessDenied: return "Access denied" + } + } } struct RegisterKeyRequestBody: Encodable { @@ -175,14 +195,12 @@ final class NetworkProtectionBackendClient: NetworkProtectionClient { } private let decoder: JSONDecoder = { - let formatter = ISO8601DateFormatter() - formatter.formatOptions = [.withFullDate, .withFullTime, .withFractionalSeconds] - let decoder = JSONDecoder() decoder.dateDecodingStrategy = .custom({ decoder in let container = try decoder.singleValueContainer() let dateString = try container.decode(String.self) - + let formatter = ISO8601DateFormatter() + formatter.formatOptions = [.withFullDate, .withFullTime, .withFractionalSeconds] guard let date = formatter.date(from: dateString) else { throw DecoderError.failedToDecode(key: container.codingPath.last?.stringValue ?? String(describing: container.codingPath)) } diff --git a/Sources/NetworkProtection/PacketTunnelProvider.swift b/Sources/NetworkProtection/PacketTunnelProvider.swift index 7e82291d7..596c3afeb 100644 --- a/Sources/NetworkProtection/PacketTunnelProvider.swift +++ b/Sources/NetworkProtection/PacketTunnelProvider.swift @@ -25,6 +25,7 @@ import Foundation import NetworkExtension import UserNotifications import os.log +import Subscription open class PacketTunnelProvider: NEPacketTunnelProvider { @@ -232,7 +233,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { private lazy var serverSelectionResolver: VPNServerSelectionResolving = { let locationRepository = NetworkProtectionLocationListCompositeRepository( environment: settings.selectedEnvironment, - tokenStore: tokenStore, + tokenProvider: tokenProvider, errorEvents: debugEvents ) return VPNServerSelectionResolver(locationListRepository: locationRepository, vpnSettings: settings) @@ -261,7 +262,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { private lazy var keyStore = NetworkProtectionKeychainKeyStore(keychainType: keychainType, errorEvents: debugEvents) - private let tokenStore: NetworkProtectionTokenStore + private let tokenProvider: any SubscriptionTokenProvider private func resetRegistrationKey() { Logger.networkProtectionKeyManagement.log("Resetting the current registration key") @@ -415,7 +416,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { private lazy var deviceManager: NetworkProtectionDeviceManagement = NetworkProtectionDeviceManager( environment: self.settings.selectedEnvironment, - tokenStore: self.tokenStore, + tokenProvider: self.tokenProvider, keyStore: self.keyStore, errorEvents: self.debugEvents ) @@ -426,7 +427,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { public lazy var entitlementMonitor = NetworkProtectionEntitlementMonitor() public lazy var serverStatusMonitor = NetworkProtectionServerStatusMonitor( networkClient: NetworkProtectionBackendClient(environment: self.settings.selectedEnvironment), - tokenStore: self.tokenStore + tokenProvider: self.tokenProvider ) private var lastTestFailed = false @@ -455,7 +456,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { snoozeTimingStore: NetworkProtectionSnoozeTimingStore, wireGuardInterface: WireGuardInterface, keychainType: KeychainType, - tokenStore: NetworkProtectionTokenStore, + tokenProvider: any SubscriptionTokenProvider, debugEvents: EventMapping, providerEvents: EventMapping, settings: VPNSettings, @@ -465,7 +466,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { self.notificationsPresenter = notificationsPresenter self.keychainType = keychainType - self.tokenStore = tokenStore + self.tokenProvider = tokenProvider self.debugEvents = debugEvents self.providerEvents = providerEvents self.tunnelHealth = tunnelHealthStore @@ -515,7 +516,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { } } - open func load(options: StartupOptions) throws { + open func load(options: StartupOptions) async throws { loadKeyValidity(from: options) loadSelectedEnvironment(from: options) loadSelectedServer(from: options) @@ -523,7 +524,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { loadDNSSettings(from: options) loadTesterEnabled(from: options) #if os(macOS) - try loadAuthToken(from: options) + try await loadAuthToken(from: options) #endif } @@ -598,22 +599,24 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { } #if os(macOS) - private func loadAuthToken(from options: StartupOptions) throws { - switch options.authToken { - case .set(let newAuthToken): - if let currentAuthToken = try? tokenStore.fetchToken(), currentAuthToken == newAuthToken { - return - } - - try tokenStore.store(newAuthToken) + private func loadAuthToken(from options: StartupOptions) async throws { + Logger.networkProtection.log("Loading token \(options.tokenContainer.description, privacy: .public)") + switch options.tokenContainer { + case .set(let newTokenContainer): + try await tokenProvider.adopt(tokenContainer: newTokenContainer) + // Important: Here we force the token refresh in order to immediately branch the system extension token from the main app one. + // See discussion https://app.asana.com/0/1199230911884351/1208785842165508/f + try await tokenProvider.getTokenContainer(policy: .localForceRefresh) case .useExisting: - guard try tokenStore.fetchToken() != nil else { + do { + try await tokenProvider.getTokenContainer(policy: .local) + } catch { throw TunnelError.startingTunnelWithoutAuthToken } case .reset: // This case should in theory not be possible, but it's ideal to have this in place // in case an error in the controller on the client side allows it. - try tokenStore.deleteToken() + tokenProvider.removeTokenContainer() throw TunnelError.startingTunnelWithoutAuthToken } } @@ -677,11 +680,8 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { self.snoozeTimingStore.reset() do { - try load(options: startupOptions) - - if (try? tokenStore.fetchToken()) == nil { - throw TunnelError.startingTunnelWithoutAuthToken - } + try await load(options: startupOptions) + Logger.networkProtection.log("Startup options loaded correctly") } catch { if startupOptions.startupMethod == .automaticOnDemand { // If the VPN was started by on-demand without the basic prerequisites for @@ -698,7 +698,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { providerEvents.fire(.tunnelStartAttempt(.failure(error))) } - Logger.networkProtection.log("🔴 Stopping VPN due to no auth token") + Logger.networkProtection.error("🔴 Stopping VPN due to no auth token") await cancelTunnel(with: TunnelError.startingTunnelWithoutAuthToken) // Check that the error is valid and able to be re-thrown to the OS before shutting the tunnel down @@ -723,6 +723,8 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { providerEvents.fire(.tunnelStartAttempt(.success)) } catch { + Logger.networkProtection.error("🔴 Failed to start tunnel \(error.localizedDescription, privacy: .public)") + if startupOptions.startupMethod == .automaticOnDemand { // We add a delay when the VPN is started by // on-demand and there's an error, to avoid frenetic ON/OFF @@ -786,6 +788,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { try await startTunnel(with: tunnelConfiguration, onDemand: onDemand) Logger.networkProtection.log("Done generating tunnel config") } catch { + Logger.networkProtection.error("Failed to start tunnel on demand: \(error.localizedDescription, privacy: .public)") controllerErrorStore.lastErrorMessage = error.localizedDescription throw error } @@ -1207,7 +1210,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { resetRegistrationKey() #if os(macOS) - try? tokenStore.deleteToken() + tokenProvider.removeTokenContainer() #endif Task { @@ -1570,9 +1573,9 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { @MainActor private func attemptShutdownDueToRevokedAccess() async { let cancelTunnel = { -#if os(macOS) - try? self.tokenStore.deleteToken() -#endif + #if os(macOS) + self.tokenProvider.removeTokenContainer() + #endif self.cancelTunnelWithError(TunnelError.vpnAccessRevoked) } diff --git a/Sources/NetworkProtection/Repositories/NetworkProtectionLocationListRepository.swift b/Sources/NetworkProtection/Repositories/NetworkProtectionLocationListRepository.swift index e4939a6ee..41f2faee3 100644 --- a/Sources/NetworkProtection/Repositories/NetworkProtectionLocationListRepository.swift +++ b/Sources/NetworkProtection/Repositories/NetworkProtectionLocationListRepository.swift @@ -18,6 +18,8 @@ import Foundation import Common +import Subscription +import Networking public enum NetworkProtectionLocationListCachePolicy { case returnCacheElseLoad @@ -36,24 +38,24 @@ final public class NetworkProtectionLocationListCompositeRepository: NetworkProt @MainActor private static var cacheTimestamp = Date() private static let cacheValidity = TimeInterval(60) // Refreshes at most once per minute private let client: NetworkProtectionClient - private let tokenStore: NetworkProtectionTokenStore + private let tokenProvider: any SubscriptionTokenProvider private let errorEvents: EventMapping convenience public init(environment: VPNSettings.SelectedEnvironment, - tokenStore: NetworkProtectionTokenStore, + tokenProvider: any SubscriptionTokenProvider, errorEvents: EventMapping) { self.init( client: NetworkProtectionBackendClient(environment: environment), - tokenStore: tokenStore, + tokenProvider: tokenProvider, errorEvents: errorEvents ) } init(client: NetworkProtectionClient, - tokenStore: NetworkProtectionTokenStore, + tokenProvider: any SubscriptionTokenProvider, errorEvents: EventMapping) { self.client = client - self.tokenStore = tokenStore + self.tokenProvider = tokenProvider self.errorEvents = errorEvents } @@ -87,9 +89,7 @@ final public class NetworkProtectionLocationListCompositeRepository: NetworkProt @discardableResult func fetchLocationListFromRemote() async throws -> [NetworkProtectionLocation] { do { - guard let authToken = try tokenStore.fetchToken() else { - throw NetworkProtectionError.noAuthTokenFound - } + let authToken = try await VPNAuthTokenBuilder.getVPNAuthToken(from: tokenProvider, policy: .localValid) Self.locationList = try await client.getLocations(authToken: authToken).get() Self.cacheTimestamp = Date() } catch let error as NetworkProtectionErrorConvertible { @@ -98,6 +98,10 @@ final public class NetworkProtectionLocationListCompositeRepository: NetworkProt } catch let error as NetworkProtectionError { errorEvents.fire(error) throw error + } catch Networking.OAuthClientError.missingTokens { + let newError = NetworkProtectionError.noAuthTokenFound + errorEvents.fire(newError) + throw newError } catch { let unhandledError = NetworkProtectionError.unhandledError(function: #function, line: #line, error: error) errorEvents.fire(unhandledError) diff --git a/Sources/NetworkProtection/Settings/Extensions/UserDefaults+subscriptionOverrideEnabled.swift b/Sources/NetworkProtection/Settings/Extensions/UserDefaults+subscriptionOverrideEnabled.swift index 0e2d86ffc..cc123bfc9 100644 --- a/Sources/NetworkProtection/Settings/Extensions/UserDefaults+subscriptionOverrideEnabled.swift +++ b/Sources/NetworkProtection/Settings/Extensions/UserDefaults+subscriptionOverrideEnabled.swift @@ -34,7 +34,7 @@ extension UserDefaults { } } - public func resetsubscriptionOverrideEnabled() { + public func resetSubscriptionOverrideEnabled() { removeObject(forKey: subscriptionOverrideEnabledKey) } } diff --git a/Sources/NetworkProtection/StartupOptions.swift b/Sources/NetworkProtection/StartupOptions.swift index dcc9ef4c6..c9c2baa28 100644 --- a/Sources/NetworkProtection/StartupOptions.swift +++ b/Sources/NetworkProtection/StartupOptions.swift @@ -18,6 +18,8 @@ import Foundation import Common +import Networking +import os.log /// This class handles the proper parsing of the startup options for our tunnel. /// @@ -110,7 +112,7 @@ public struct StartupOptions { let dnsSettings: StoredOption public let excludeLocalNetworks: StoredOption #if os(macOS) - let authToken: StoredOption + let tokenContainer: StoredOption #endif let enableTester: StoredOption @@ -133,7 +135,7 @@ public struct StartupOptions { let resetStoredOptionsIfNil = startupMethod == .manualByMainApp #if os(macOS) - authToken = Self.readAuthToken(from: options, resetIfNil: resetStoredOptionsIfNil) + tokenContainer = Self.readAuthToken(from: options, resetIfNil: resetStoredOptionsIfNil) #endif enableTester = Self.readEnableTester(from: options, resetIfNil: resetStoredOptionsIfNil) keyValidity = Self.readKeyValidity(from: options, resetIfNil: resetStoredOptionsIfNil) @@ -165,14 +167,14 @@ public struct StartupOptions { // MARK: - Helpers for reading stored options #if os(macOS) - private static func readAuthToken(from options: [String: Any], resetIfNil: Bool) -> StoredOption { + private static func readAuthToken(from options: [String: Any], resetIfNil: Bool) -> StoredOption { StoredOption(resetIfNil: resetIfNil) { - guard let authToken = options[NetworkProtectionOptionKey.authToken] as? String, - !authToken.isEmpty else { + guard let data = options[NetworkProtectionOptionKey.tokenContainer] as? NSData, + let tokenContainer = try? TokenContainer(with: data) else { + Logger.networkProtection.error("`tokenContainer` is missing or invalid") return nil } - - return authToken + return tokenContainer } } #endif diff --git a/Sources/NetworkProtection/Status/ControlllerErrorMessageObserver/ControllerErrorMesssageObserverThroughDistributedNotifications.swift b/Sources/NetworkProtection/Status/ControlllerErrorMessageObserver/ControllerErrorMesssageObserverThroughDistributedNotifications.swift index 4263bfdee..2f60e8795 100644 --- a/Sources/NetworkProtection/Status/ControlllerErrorMessageObserver/ControllerErrorMesssageObserverThroughDistributedNotifications.swift +++ b/Sources/NetworkProtection/Status/ControlllerErrorMessageObserver/ControllerErrorMesssageObserverThroughDistributedNotifications.swift @@ -61,6 +61,8 @@ public class ControllerErrorMesssageObserverThroughDistributedNotifications: Con let errorMessage = notification.object as? String logErrorChanged(isShowingError: errorMessage != nil) + Logger.networkProtectionStatusReporter.debug("Received error message") + subject.send(errorMessage) } diff --git a/Sources/NetworkProtection/VPNAuthTokenBuilder.swift b/Sources/NetworkProtection/VPNAuthTokenBuilder.swift new file mode 100644 index 000000000..3de3d0a3e --- /dev/null +++ b/Sources/NetworkProtection/VPNAuthTokenBuilder.swift @@ -0,0 +1,33 @@ +// +// VPNAuthTokenBuilder.swift +// +// Copyright © 2023 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Subscription +import Networking + +public struct VPNAuthTokenBuilder { + + public static func getVPNAuthToken(from tokenProvider: SubscriptionTokenProvider, policy: TokensCachePolicy) async throws -> String { + let token = try await tokenProvider.getTokenContainer(policy: policy).accessToken + return "ddg:\(token)" + } + + public static func getVPNAuthToken(from originalToken: String) -> String{ + return "ddg:\(originalToken)" + } +} diff --git a/Sources/NetworkProtectionTestUtils/KeyManagement/MockNetworkProtectionTokenStore.swift b/Sources/NetworkProtectionTestUtils/KeyManagement/MockNetworkProtectionTokenStore.swift deleted file mode 100644 index 3022b83c4..000000000 --- a/Sources/NetworkProtectionTestUtils/KeyManagement/MockNetworkProtectionTokenStore.swift +++ /dev/null @@ -1,54 +0,0 @@ -// -// MockNetworkProtectionTokenStore.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import NetworkProtection - -public final class MockNetworkProtectionTokenStorage: NetworkProtectionTokenStore { - - public init() { - - } - - var spyToken: String? - var storeError: Error? - - public func store(_ token: String) throws { - if let storeError { - throw storeError - } - spyToken = token - } - - var stubFetchToken: String? - - public func fetchToken() throws -> String? { - return stubFetchToken - } - - var didCallDeleteToken: Bool = false - - public func deleteToken() throws { - didCallDeleteToken = true - } - - public func fetchSubscriptionToken() throws -> String? { - try fetchToken() - } - -} diff --git a/Sources/Networking/OAuth/Logger+OAuth.swift b/Sources/Networking/OAuth/Logger+OAuth.swift new file mode 100644 index 000000000..9d1248ab9 --- /dev/null +++ b/Sources/Networking/OAuth/Logger+OAuth.swift @@ -0,0 +1,25 @@ +// +// Logger+OAuth.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os.log + +public extension Logger { + static var OAuth = { Logger(subsystem: "Networking", category: "OAuth") }() + static var OAuthClient = { Logger(subsystem: "Networking", category: "OAuthClient") }() +} diff --git a/Sources/Networking/OAuth/OAuthClient.swift b/Sources/Networking/OAuth/OAuthClient.swift new file mode 100644 index 000000000..0b107c968 --- /dev/null +++ b/Sources/Networking/OAuth/OAuthClient.swift @@ -0,0 +1,454 @@ +// +// OAuthClient.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os.log + +public enum OAuthClientError: Error, LocalizedError, Equatable { + case internalError(String) + case missingTokens + case missingRefreshToken + case unauthenticated + /// When both access token and refresh token are expired + case deadToken + + public var errorDescription: String? { + switch self { + case .internalError(let error): + return "Internal error: \(error)" + case .missingTokens: + return "No token available" + case .missingRefreshToken: + return "No refresh token available, please re-authenticate" + case .unauthenticated: + return "The account is not authenticated, please re-authenticate" + case .deadToken: + return "The token can't be refreshed" + } + } +} + +/// Provides the locally stored tokens container +public protocol TokenStoring { + var tokenContainer: TokenContainer? { get set } +} + +/// Provides the legacy AuthToken V1 +public protocol LegacyTokenStoring { + var token: String? { get set } +} + +public enum TokensCachePolicy { + /// The locally stored one as it is, valid or not + case local + /// The locally stored one refreshed + case localValid + + /// The locally stored one and force the refresh + case localForceRefresh + + /// Local refreshed, if doesn't exist create a new one + case createIfNeeded + + public var description: String { + switch self { + case .local: + return "Local" + case .localValid: + return "Local valid" + case .localForceRefresh: + return "Local force refresh" + case .createIfNeeded: + return "Create if needed" + } + } +} + +public protocol OAuthClient { + + // MARK: - Public + + var isUserAuthenticated: Bool { get } + + var currentTokenContainer: TokenContainer? { get set } + + /// Returns a tokens container based on the policy + /// - `.local`: Returns what's in the storage, as it is, throws an error if no token is available + /// - `.localValid`: Returns what's in the storage, refreshes it if needed. throws an error if no token is available + /// - `.localForceRefresh`: Returns what's in the storage but forces a refresh first. throws an error if no refresh token is available. + /// - `.createIfNeeded`: Returns what's in the storage, if the stored token is expired refreshes it, if not token is available creates a new account/token + /// All options store new or refreshed tokens via the tokensStorage + func getTokens(policy: TokensCachePolicy) async throws -> TokenContainer + + // MARK: Activate + + /// Activate the account with a platform signature + /// - Parameter signature: The platform signature + /// - Returns: A container of tokens + func activate(withPlatformSignature signature: String) async throws -> TokenContainer + + // MARK: Exchange + + /// Exchange token v1 for tokens v2 + /// - Parameter accessTokenV1: The legacy auth token + /// - Returns: A TokenContainer with access and refresh tokens + func exchange(accessTokenV1: String) async throws -> TokenContainer + + // MARK: Logout + + /// Logout by invalidating the current access token + func logout() async throws + + /// Remove the tokens container stored locally + func removeLocalAccount() +} + +final public class DefaultOAuthClient: OAuthClient { + + private struct Constants { + /// https://app.asana.com/0/1205784033024509/1207979495854201/f + static let clientID = "f4311287-0121-40e6-8bbd-85c36daf1837" + static let redirectURI = "com.duckduckgo:/authcb" + static let availableScopes = [ "privacypro" ] + } + + // MARK: - + + private let authService: any OAuthService + private var tokenStorage: any TokenStoring + public var legacyTokenStorage: (any LegacyTokenStoring)? + + public init(tokensStorage: any TokenStoring, + legacyTokenStorage: (any LegacyTokenStoring)? = nil, + authService: OAuthService) { + self.tokenStorage = tokensStorage + self.authService = authService + } + + // MARK: - Internal + + @discardableResult + private func getTokens(authCode: String, codeVerifier: String) async throws -> TokenContainer { + Logger.OAuthClient.log("Getting tokens") + let getTokensResponse = try await authService.getAccessToken(clientID: Constants.clientID, + codeVerifier: codeVerifier, + code: authCode, + redirectURI: Constants.redirectURI) + return try await decode(accessToken: getTokensResponse.accessToken, refreshToken: getTokensResponse.refreshToken) + } + + private func getVerificationCodes() async throws -> (codeVerifier: String, codeChallenge: String) { + Logger.OAuthClient.log("Getting verification codes") + let codeVerifier = OAuthCodesGenerator.codeVerifier + guard let codeChallenge = OAuthCodesGenerator.codeChallenge(codeVerifier: codeVerifier) else { + Logger.OAuthClient.error("Failed to get verification codes") + throw OAuthClientError.internalError("Failed to generate code challenge") + } + return (codeVerifier, codeChallenge) + } + +#if DEBUG + var testingDecodedTokenContainer: TokenContainer? +#endif + private func decode(accessToken: String, refreshToken: String) async throws -> TokenContainer { + Logger.OAuthClient.log("Decoding tokens") + +#if DEBUG + if let testingDecodedTokenContainer { + return testingDecodedTokenContainer + } +#endif + + let jwtSigners = try await authService.getJWTSigners() + let decodedAccessToken = try jwtSigners.verify(accessToken, as: JWTAccessToken.self) + let decodedRefreshToken = try jwtSigners.verify(refreshToken, as: JWTRefreshToken.self) + + return TokenContainer(accessToken: accessToken, + refreshToken: refreshToken, + decodedAccessToken: decodedAccessToken, + decodedRefreshToken: decodedRefreshToken) + } + + // MARK: - Public + + public var isUserAuthenticated: Bool { + tokenStorage.tokenContainer != nil + } + + public var currentTokenContainer: TokenContainer? { + get { + tokenStorage.tokenContainer + } + set { + tokenStorage.tokenContainer = newValue + } + } + + public func getTokens(policy: TokensCachePolicy) async throws -> TokenContainer { + let localTokenContainer: TokenContainer? + // V1 to V2 tokens migration + if let migratedTokenContainer = await migrateLegacyTokenIfNeeded() { + localTokenContainer = migratedTokenContainer + } else { + localTokenContainer = tokenStorage.tokenContainer + } + + switch policy { + case .local: + if let localTokenContainer { + Logger.OAuthClient.debug("Local tokens found, expiry: \(localTokenContainer.decodedAccessToken.exp.value)") + return localTokenContainer + } else { + Logger.OAuthClient.debug("Tokens not found") + throw OAuthClientError.missingTokens + } + case .localValid: + if let localTokenContainer { + Logger.OAuthClient.debug("Local tokens found, expiry: \(localTokenContainer.decodedAccessToken.exp.value)") + if localTokenContainer.decodedAccessToken.isExpired() { + Logger.OAuthClient.debug("Local access token is expired, refreshing it") + return try await getTokens(policy: .localForceRefresh) + } else { + return localTokenContainer + } + } else { + Logger.OAuthClient.debug("Tokens not found") + throw OAuthClientError.missingTokens + } + case .localForceRefresh: + guard let refreshToken = localTokenContainer?.refreshToken else { + Logger.OAuthClient.debug("Refresh token not found") + throw OAuthClientError.missingRefreshToken + } + do { + let refreshTokenResponse = try await authService.refreshAccessToken(clientID: Constants.clientID, refreshToken: refreshToken) + let refreshedTokens = try await decode(accessToken: refreshTokenResponse.accessToken, refreshToken: refreshTokenResponse.refreshToken) + Logger.OAuthClient.debug("Tokens refreshed: \(refreshedTokens.debugDescription)") + tokenStorage.tokenContainer = refreshedTokens + return refreshedTokens + } catch OAuthServiceError.authAPIError(let code) where code == OAuthRequest.BodyErrorCode.invalidTokenRequest { + Logger.OAuthClient.error("Failed to refresh token") + throw OAuthClientError.deadToken + } catch OAuthServiceError.authAPIError(let code) { + Logger.OAuthClient.error("Failed to refresh token: \(code.rawValue, privacy: .public), \(code.description, privacy: .public)") + throw OAuthServiceError.authAPIError(code: code) + } + case .createIfNeeded: + do { + return try await getTokens(policy: .localValid) + } catch { + Logger.OAuthClient.debug("Local token not found, creating a new account") + let tokens = try await createAccount() + tokenStorage.tokenContainer = tokens + return tokens + } + } + } + + /// Tries to retrieve the v1 auth token stored locally, if present performs a migration to v2 and removes the old token + private func migrateLegacyTokenIfNeeded() async -> TokenContainer? { + guard var legacyTokenStorage, + let legacyToken = legacyTokenStorage.token else { + return nil + } + + Logger.OAuthClient.log("Migrating legacy token") + do { + let tokenContainer = try await exchange(accessTokenV1: legacyToken) + Logger.OAuthClient.log("Tokens migrated successfully, removing legacy token") + + // Remove old token + legacyTokenStorage.token = nil + + // Store new tokens + tokenStorage.tokenContainer = tokenContainer + + return tokenContainer + } catch { + Logger.OAuthClient.error("Failed to migrate legacy token: \(error, privacy: .public)") + return nil + } + } + + // MARK: Create + + /// Create an accounts, stores all tokens and returns them + private func createAccount() async throws -> TokenContainer { + Logger.OAuthClient.log("Creating new account") + let (codeVerifier, codeChallenge) = try await getVerificationCodes() + let authSessionID = try await authService.authorize(codeChallenge: codeChallenge) + let authCode = try await authService.createAccount(authSessionID: authSessionID) + let tokenContainer = try await getTokens(authCode: authCode, codeVerifier: codeVerifier) + Logger.OAuthClient.log("New account created successfully") + return tokenContainer + } + + // MARK: Activate + + /* + /// Helper, single use + public class EmailAccountActivator { + + private let oAuthClient: any OAuthClient + private var email: String? + private var authSessionID: String? + private var codeVerifier: String? + + public init(oAuthClient: any OAuthClient) { + self.oAuthClient = oAuthClient + } + + public func activateWith(email: String) async throws { + self.email = email + let (authSessionID, codeVerifier) = try await oAuthClient.requestOTP(email: email) + self.authSessionID = authSessionID + self.codeVerifier = codeVerifier + } + + public func confirm(otp: String) async throws { + guard let codeVerifier, let authSessionID, let email else { return } + try await oAuthClient.activate(withOTP: otp, email: email, codeVerifier: codeVerifier, authSessionID: authSessionID) + } + } + + public func requestOTP(email: String) async throws -> (authSessionID: String, codeVerifier: String) { + Logger.OAuthClient.log("Requesting OTP") + let (codeVerifier, codeChallenge) = try await getVerificationCodes() + let authSessionID = try await authService.authorize(codeChallenge: codeChallenge) + try await authService.requestOTP(authSessionID: authSessionID, emailAddress: email) + return (authSessionID, codeVerifier) // to be used in activate(withOTP or activate(withPlatformSignature + } + + public func activate(withOTP otp: String, email: String, codeVerifier: String, authSessionID: String) async throws { + Logger.OAuthClient.log("Activating with OTP") + let authCode = try await authService.login(withOTP: otp, authSessionID: authSessionID, email: email) + try await getTokens(authCode: authCode, codeVerifier: codeVerifier) + } + */ + + public func activate(withPlatformSignature signature: String) async throws -> TokenContainer { + Logger.OAuthClient.log("Activating with platform signature") + let (codeVerifier, codeChallenge) = try await getVerificationCodes() + let authSessionID = try await authService.authorize(codeChallenge: codeChallenge) + let authCode = try await authService.login(withSignature: signature, authSessionID: authSessionID) + let tokens = try await getTokens(authCode: authCode, codeVerifier: codeVerifier) + tokenStorage.tokenContainer = tokens + Logger.OAuthClient.log("Activation completed") + return tokens + } + + // MARK: Refresh + +// private func refreshTokens() async throws -> TokenContainer { +// Logger.OAuthClient.log("Refreshing tokens") +// guard let refreshToken = tokenStorage.tokenContainer?.refreshToken else { +// throw OAuthClientError.missingRefreshToken +// } +// +// do { +// let refreshTokenResponse = try await authService.refreshAccessToken(clientID: Constants.clientID, refreshToken: refreshToken) +// let refreshedTokens = try await decode(accessToken: refreshTokenResponse.accessToken, refreshToken: refreshTokenResponse.refreshToken) +// Logger.OAuthClient.log("Tokens refreshed: \(refreshedTokens.debugDescription)") +// tokenStorage.tokenContainer = refreshedTokens +// return refreshedTokens +// } catch OAuthServiceError.authAPIError(let code) { +// if code == OAuthRequest.BodyErrorCode.invalidTokenRequest { +// Logger.OAuthClient.error("Failed to refresh token") +// throw OAuthClientError.deadToken +// } else { +// Logger.OAuthClient.error("Failed to refresh token: \(code.rawValue, privacy: .public), \(code.description, privacy: .public)") +// throw OAuthServiceError.authAPIError(code: code) +// } +// } catch { +// Logger.OAuthClient.error("Failed to refresh token: \(error, privacy: .public)") +// throw error +// } +// } + + // MARK: Exchange V1 to V2 token + + public func exchange(accessTokenV1: String) async throws -> TokenContainer { + Logger.OAuthClient.log("Exchanging access token V1 to V2") + let (codeVerifier, codeChallenge) = try await getVerificationCodes() + let authSessionID = try await authService.authorize(codeChallenge: codeChallenge) + let authCode = try await authService.exchangeToken(accessTokenV1: accessTokenV1, authSessionID: authSessionID) + let tokenContainer = try await getTokens(authCode: authCode, codeVerifier: codeVerifier) + tokenStorage.tokenContainer = tokenContainer + return tokenContainer + } + + // MARK: Logout + + public func logout() async throws { + let existingToken = tokenStorage.tokenContainer?.accessToken + removeLocalAccount() + + if let existingToken { + Logger.OAuthClient.log("Logging out") + try await authService.logout(accessToken: existingToken) + } + } + + public func removeLocalAccount() { + Logger.OAuthClient.log("Removing local account") + tokenStorage.tokenContainer = nil + legacyTokenStorage?.token = nil + } + + /* MARK: Edit account + + /// Helper, single use + public class AccountEditor { + + private let oAuthClient: any OAuthClient + private var hashString: String? + private var email: String? + + public init(oAuthClient: any OAuthClient) { + self.oAuthClient = oAuthClient + } + + public func change(email: String?) async throws { + self.hashString = try await self.oAuthClient.changeAccount(email: email) + } + + public func send(otp: String) async throws { + guard let email, let hashString else { + throw OAuthClientError.internalError("Missing email or hashString") + } + try await oAuthClient.confirmChangeAccount(email: email, otp: otp, hash: hashString) + try await oAuthClient.refreshTokens() + } + } + + public func changeAccount(email: String?) async throws -> String { + guard let token = tokensStorage.tokenContainer?.accessToken else { + throw OAuthClientError.unauthenticated + } + let editAccountResponse = try await authService.editAccount(clientID: Constants.clientID, accessToken: token, email: email) + return editAccountResponse.hash + } + + public func confirmChangeAccount(email: String, otp: String, hash: String) async throws { + guard let token = tokensStorage.tokenContainer?.accessToken else { + throw OAuthClientError.unauthenticated + } + _ = try await authService.confirmEditAccount(accessToken: token, email: email, hash: hash, otp: otp) + } + */ +} diff --git a/Sources/Networking/OAuth/OAuthCodesGenerator.swift b/Sources/Networking/OAuth/OAuthCodesGenerator.swift new file mode 100644 index 000000000..5210f9387 --- /dev/null +++ b/Sources/Networking/OAuth/OAuthCodesGenerator.swift @@ -0,0 +1,55 @@ +// +// OAuthCodesGenerator.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import CommonCrypto + +/// Helper that generates codes used in the OAuth2 authentication process +struct OAuthCodesGenerator { + + /// https://auth0.com/docs/get-started/authentication-and-authorization-flow/authorization-code-flow-with-pkce/add-login-using-the-authorization-code-flow-with-pkce#create-code-verifier + static var codeVerifier: String { + var buffer = [UInt8](repeating: 0, count: 128) + _ = SecRandomCopyBytes(kSecRandomDefault, buffer.count, &buffer) + return Data(buffer).base64EncodedString().replacingInvalidCharacters() + } + + /// https://auth0.com/docs/get-started/authentication-and-authorization-flow/authorization-code-flow-with-pkce/add-login-using-the-authorization-code-flow-with-pkce#create-code-challenge + static func codeChallenge(codeVerifier: String) -> String? { + + guard let data = codeVerifier.data(using: .utf8) else { + assertionFailure("Failed to generate OAuth2 code challenge") + return nil + } + var buffer = [UInt8](repeating: 0, count: Int(CC_SHA256_DIGEST_LENGTH)) + _ = data.withUnsafeBytes { + CC_SHA256($0.baseAddress, CC_LONG(data.count), &buffer) + } + let hash = Data(buffer) + return hash.base64EncodedString().replacingInvalidCharacters() + } +} + +fileprivate extension String { + + func replacingInvalidCharacters() -> String { + self.replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + } +} diff --git a/Sources/SubscriptionTestingUtilities/AccountStorage/SubscriptionTokenKeychainStorageMock.swift b/Sources/Networking/OAuth/OAuthEnvironment.swift similarity index 51% rename from Sources/SubscriptionTestingUtilities/AccountStorage/SubscriptionTokenKeychainStorageMock.swift rename to Sources/Networking/OAuth/OAuthEnvironment.swift index 7bb5b77a5..878b974ed 100644 --- a/Sources/SubscriptionTestingUtilities/AccountStorage/SubscriptionTokenKeychainStorageMock.swift +++ b/Sources/Networking/OAuth/OAuthEnvironment.swift @@ -1,5 +1,5 @@ // -// SubscriptionTokenKeychainStorageMock.swift +// OAuthEnvironment.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -17,28 +17,25 @@ // import Foundation -import Subscription -public final class SubscriptionTokenKeychainStorageMock: SubscriptionTokenStoring { +public enum OAuthEnvironment: String, Codable, CustomStringConvertible { + case production, staging - public var accessToken: String? - - public var removeAccessTokenCalled: Bool = false - - public init(accessToken: String? = nil) { - self.accessToken = accessToken - } - - public func getAccessToken() throws -> String? { - accessToken - } - - public func store(accessToken: String) throws { - self.accessToken = accessToken + public var description: String { + switch self { + case .production: + "Production" + case .staging: + "Staging" + } } - public func removeAccessToken() throws { - removeAccessTokenCalled = true - accessToken = nil + public var url: URL { + switch self { + case .production: + URL(string: "https://quack.duckduckgo.com")! + case .staging: + URL(string: "https://quackdev.duckduckgo.com")! + } } } diff --git a/Sources/Networking/OAuth/OAuthRequest.swift b/Sources/Networking/OAuth/OAuthRequest.swift new file mode 100644 index 000000000..849ab64aa --- /dev/null +++ b/Sources/Networking/OAuth/OAuthRequest.swift @@ -0,0 +1,380 @@ +// +// OAuthRequest.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os.log +import Common + +/// Auth API v2 Endpoints: https://dub.duckduckgo.com/duckduckgo/ddg/blob/main/components/auth/docs/AuthAPIV2Documentation.md#auth-api-v2-endpoints +public struct OAuthRequest { + + public let apiRequest: APIRequestV2 + public let httpSuccessCode: HTTPStatusCode + public let httpErrorCodes: [HTTPStatusCode] + public var url: URL { + apiRequest.urlRequest.url! + } + + public enum BodyErrorCode: String, Decodable { + case invalidAuthorizationRequest = "invalid_authorization_request" + case authorizeFailed = "authorize_failed" + case invalidRequest = "invalid_request" + case accountCreateFailed = "account_create_failed" + case invalidEmailAddress = "invalid_email_address" + case invalidSessionId = "invalid_session_id" + case suspendedAccount = "suspended_account" + case emailSendingError = "email_sending_error" + case invalidLoginCredentials = "invalid_login_credentials" + case unknownAccount = "unknown_account" + case invalidTokenRequest = "invalid_token_request" + case unverifiedAccount = "unverified_account" + case emailAddressNotChanged = "email_address_not_changed" + case failedMxCheck = "failed_mx_check" + case accountEditFailed = "account_edit_failed" + case invalidLinkSignature = "invalid_link_signature" + case accountChangeEmailAddressFailed = "account_change_email_address_failed" + case invalidToken = "invalid_token" + case expiredToken = "expired_token" + + public var description: String { + switch self { + case .invalidAuthorizationRequest: + return "One or more of the required parameters are missing or any provided parameters have invalid values" + case .authorizeFailed: + return "Failed to create the authorization session, either because of a reused code challenge or internal server error" + case .invalidRequest: + return "The ddg_auth_session_id is missing or has already been used to log in to a different account" + case .accountCreateFailed: + return "Failed to create the account because of an internal server error" + case .invalidEmailAddress: + return "Provided email address is missing or of an invalid format" + case .invalidSessionId: + return "The session id is missing, invalid or has already been used for logging in" + case .suspendedAccount: + return "The account you are logging in to is suspended" + case .emailSendingError: + return "Failed to send the OTP to the email address provided" + case .invalidLoginCredentials: + return "One or more of the provided parameters is invalid" + case .unknownAccount: + return "The login credentials appear valid but do not link to a known account" + case .invalidTokenRequest: + return "One or more of the required parameters are missing or any provided parameters have invalid values" + case .unverifiedAccount: + return "The token is valid but is for an unverified account" + case .emailAddressNotChanged: + return "New email address is the same as the old email address" + case .failedMxCheck: + return "DNS check to see if email address domain is valid failed" + case .accountEditFailed: + return "Something went wrong and the edit was aborted" + case .invalidLinkSignature: + return "The hash is invalid or does not match the provided email address and account" + case .accountChangeEmailAddressFailed: + return "Something went wrong and the edit was aborted" + case .invalidToken: + return "Provided access token is missing or invalid" + case .expiredToken: + return "Provided access token is expired" + } + } + } + + struct BodyError: Decodable { + let error: BodyErrorCode + } + + static func ddgAuthSessionCookie(domain: String, path: String, authSessionID: String) -> HTTPCookie? { + return HTTPCookie(properties: [ + .domain: domain, + .path: path, + .name: "ddg_auth_session_id", + .value: authSessionID + ]) + } + + // MARK: - + + init(apiRequest: APIRequestV2, + httpSuccessCode: HTTPStatusCode = HTTPStatusCode.ok, + httpErrorCodes: [HTTPStatusCode] = [HTTPStatusCode.badRequest, HTTPStatusCode.internalServerError]) { + self.apiRequest = apiRequest + self.httpSuccessCode = httpSuccessCode + self.httpErrorCodes = httpErrorCodes + } + + // MARK: Authorize + + static func authorize(baseURL: URL, codeChallenge: String) -> OAuthRequest? { + guard codeChallenge.isEmpty == false else { return nil } + + let path = "/api/auth/v2/authorize" + let queryItems = [ + "response_type": "code", + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "client_id": "f4311287-0121-40e6-8bbd-85c36daf1837", + "redirect_uri": "com.duckduckgo:/authcb", + "scope": "privacypro" + ] + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .get, + queryItems: queryItems) else { + return nil + } + return OAuthRequest(apiRequest: request, httpSuccessCode: HTTPStatusCode.found) + } + + // MARK: Create account + + static func createAccount(baseURL: URL, authSessionID: String) -> OAuthRequest? { + guard authSessionID.isEmpty == false else { return nil } + + let path = "/api/auth/v2/account/create" + guard let domain = baseURL.host, + let cookie = Self.ddgAuthSessionCookie(domain: domain, path: path, authSessionID: authSessionID) + else { return nil } + + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .post, + headers: APIRequestV2.HeadersV2(cookies: [cookie])) else { + return nil + } + return OAuthRequest(apiRequest: request, httpSuccessCode: HTTPStatusCode.found) + } + + // MARK: Sent OTP + + /// Unused in the current implementation + static func requestOTP(baseURL: URL, authSessionID: String, emailAddress: String) -> OAuthRequest? { + guard authSessionID.isEmpty == false, + emailAddress.isEmpty == false else { return nil } + + let path = "/api/auth/v2/otp" + let queryItems = [ "email": emailAddress ] + guard let domain = baseURL.host, + let cookie = Self.ddgAuthSessionCookie(domain: domain, path: path, authSessionID: authSessionID) + else { return nil } + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .post, + queryItems: queryItems, + headers: APIRequestV2.HeadersV2(cookies: [cookie])) else { + return nil + } + return OAuthRequest(apiRequest: request) + } + + // MARK: Login + + static func login(baseURL: URL, authSessionID: String, method: OAuthLoginMethod) -> OAuthRequest? { + guard authSessionID.isEmpty == false else { return nil } + + let path = "/api/auth/v2/login" + var body: [String: String] + + guard let domain = baseURL.host, + let cookie = Self.ddgAuthSessionCookie(domain: domain, path: path, authSessionID: authSessionID) + else { + Logger.OAuth.fault("Failed to create cookie") + assertionFailure("Failed to create cookie") + return nil + } + + switch method.self { + case is OAuthLoginMethodOTP: + guard let otpMethod = method as? OAuthLoginMethodOTP else { + return nil + } + body = [ + "method": otpMethod.name, + "email": otpMethod.email, + "otp": otpMethod.otp + ] + case is OAuthLoginMethodSignature: + guard let signatureMethod = method as? OAuthLoginMethodSignature else { + return nil + } + body = [ + "method": signatureMethod.name, + "signature": signatureMethod.signature, + "source": signatureMethod.source + ] + default: + Logger.OAuth.fault("Unknown login method: \(String(describing: method))") + assertionFailure("Unknown login method: \(String(describing: method))") + return nil + } + + guard let jsonBody = CodableHelper.encode(body) else { + assertionFailure("Failed to encode body: \(body)") + return nil + } + + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .post, + headers: APIRequestV2.HeadersV2(cookies: [cookie], + contentType: .json), + body: jsonBody, + retryPolicy: APIRequestV2.RetryPolicy(maxRetries: 3, delay: 2)) else { + return nil + } + return OAuthRequest(apiRequest: request, httpSuccessCode: HTTPStatusCode.found) + } + + // MARK: Access Token + // Note: The API has a single endpoint for both getting a new token and refreshing an old one, but here I'll split the endpoint in 2 different calls for clarity + // https://dub.duckduckgo.com/duckduckgo/ddg/blob/main/components/auth/docs/AuthAPIV2Documentation.md#access-token + + static func getAccessToken(baseURL: URL, clientID: String, codeVerifier: String, code: String, redirectURI: String) -> OAuthRequest? { + guard clientID.isEmpty == false, + codeVerifier.isEmpty == false, + code.isEmpty == false, + redirectURI.isEmpty == false else { return nil } + + let path = "/api/auth/v2/token" + let queryItems = [ + "grant_type": "authorization_code", + "client_id": clientID, + "code_verifier": codeVerifier, + "code": code, + "redirect_uri": redirectURI + ] + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .get, + queryItems: queryItems) else { + return nil + } + + return OAuthRequest(apiRequest: request) + } + + static func refreshAccessToken(baseURL: URL, clientID: String, refreshToken: String) -> OAuthRequest? { + guard clientID.isEmpty == false, + refreshToken.isEmpty == false else { return nil } + + let path = "/api/auth/v2/token" + let queryItems = [ + "grant_type": "refresh_token", + "client_id": clientID, + "refresh_token": refreshToken, + ] + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .get, + queryItems: queryItems) else { + return nil + } + return OAuthRequest(apiRequest: request) + } + + // MARK: Edit Account + + /// Unused in the current implementation + static func editAccount(baseURL: URL, accessToken: String, email: String?) -> OAuthRequest? { + guard accessToken.isEmpty == false else { return nil } + + let path = "/api/auth/v2/account/edit" + var queryItems: [String: String] = [:] + if let email { + queryItems["email"] = email + } + + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .post, + queryItems: queryItems, + headers: APIRequestV2.HeadersV2( + authToken: accessToken)) else { + return nil + } + return OAuthRequest(apiRequest: request, httpErrorCodes: [.unauthorized, .internalServerError]) + } + + /// Unused in the current implementation + static func confirmEditAccount(baseURL: URL, accessToken: String, email: String, hash: String, otp: String) -> OAuthRequest? { + guard accessToken.isEmpty == false, + email.isEmpty == false, + hash.isEmpty == false, + otp.isEmpty == false else { return nil } + + let path = "/account/edit/confirm" + let queryItems = [ + "email": email, + "hash": hash, + "otp": otp, + ] + + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .get, + queryItems: queryItems, + headers: APIRequestV2.HeadersV2(authToken: accessToken)) else { + return nil + } + return OAuthRequest(apiRequest: request, httpErrorCodes: [.unauthorized, .internalServerError]) + } + + // MARK: Logout + + static func logout(baseURL: URL, accessToken: String) -> OAuthRequest? { + guard accessToken.isEmpty == false else { return nil } + + let path = "/api/auth/v2/logout" + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .post, + headers: APIRequestV2.HeadersV2(authToken: accessToken)) else { + return nil + } + return OAuthRequest(apiRequest: request, httpErrorCodes: [.unauthorized, .internalServerError]) + } + + // MARK: Exchange token + + static func exchangeToken(baseURL: URL, accessTokenV1: String, authSessionID: String) -> OAuthRequest? { + guard accessTokenV1.isEmpty == false, + authSessionID.isEmpty == false else { return nil } + + let path = "/api/auth/v2/exchange" + guard let domain = baseURL.host, + let cookie = Self.ddgAuthSessionCookie(domain: domain, path: path, authSessionID: authSessionID) + else { return nil } + + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .post, + headers: APIRequestV2.HeadersV2(cookies: [cookie], + authToken: accessTokenV1)) else { + return nil + } + return OAuthRequest(apiRequest: request, + httpSuccessCode: .found, + httpErrorCodes: [.badRequest, .internalServerError]) + } + + // MARK: JWKs + + /// This endpoint is where the Auth service will publish public keys for consuming services and clients to use to independently verify access tokens. Tokens should be downloaded and cached for an hour upon first use. When rotating private keys for signing JWTs, the Auth service will publish new public keys 24 hours in advance of starting to sign new JWTs with them. This should provide consuming services with plenty of time to invalidate their public key cache and have the new key available before they can expect to start receiving JWTs signed with the old key. The old key will remain published until the next key rotation, so there should generally be two public keys available through this endpoint. The response format is a standard JWKS response, as documented in RFC 7517. + static func jwks(baseURL: URL) -> OAuthRequest? { + let path = "/api/auth/v2/.well-known/jwks.json" + + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .get, + retryPolicy: APIRequestV2.RetryPolicy(maxRetries: 2, delay: 1)) else { + return nil + } + return OAuthRequest(apiRequest: request, + httpSuccessCode: .ok, + httpErrorCodes: [.internalServerError]) + } +} diff --git a/Sources/Networking/OAuth/OAuthService.swift b/Sources/Networking/OAuth/OAuthService.swift new file mode 100644 index 000000000..9cc7e8888 --- /dev/null +++ b/Sources/Networking/OAuth/OAuthService.swift @@ -0,0 +1,431 @@ +// +// OAuthService.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import JWTKit + +public protocol OAuthService { + + /// Authorizes a user with a given code challenge. + /// - Parameter codeChallenge: The code challenge for authorization. + /// - Returns: An OAuthSessionID. + /// - Throws: An error if the authorization fails. + func authorize(codeChallenge: String) async throws -> OAuthSessionID + + /// Creates a new account using the provided auth session ID. + /// - Parameter authSessionID: The authentication session ID. + /// - Returns: The authorization code needed for the Access Token request. + /// - Throws: An error if account creation fails. + func createAccount(authSessionID: String) async throws -> AuthorisationCode + + /// Logs in a user with a signature and auth session ID. + /// - Parameters: + /// - signature: The platform signature + /// - authSessionID: The authentication session ID. + /// - Returns: An OAuthRedirectionURI. + /// - Throws: An error if login fails. + func login(withSignature signature: String, authSessionID: String) async throws -> AuthorisationCode + + /// Retrieves an access token using the provided parameters. + /// - Parameters: + /// - clientID: The client ID. + /// - codeVerifier: The code verifier. + /// - code: The authorization code. + /// - redirectURI: The redirect URI. + /// - Returns: An OAuthTokenResponse. + /// - Throws: An error if token retrieval fails. + func getAccessToken(clientID: String, codeVerifier: String, code: String, redirectURI: String) async throws -> OAuthTokenResponse + + /// Refreshes an access token using the provided client ID and refresh token. + /// - Parameters: + /// - clientID: The client ID. + /// - refreshToken: The refresh token. + /// - Returns: An OAuthTokenResponse. + /// - Throws: An error if token refresh fails. + func refreshAccessToken(clientID: String, refreshToken: String) async throws -> OAuthTokenResponse + + /// Logs out the user using the provided access token. + /// - Parameter accessToken: The access token. + /// - Throws: An error if logout fails. + func logout(accessToken: String) async throws + + /// Exchanges an access token for a new one. + /// - Parameters: + /// - accessTokenV1: The old access token. + /// - authSessionID: The authentication session ID. + /// - Returns: An OAuthRedirectionURI. + /// - Throws: An error if the exchange fails. + func exchangeToken(accessTokenV1: String, authSessionID: String) async throws -> AuthorisationCode + + /// Retrieves JWT signers using JWKs from the endpoint. + /// - Returns: A JWTSigners instance. + /// - Throws: An error if retrieval fails. + func getJWTSigners() async throws -> JWTSigners +} + +public struct DefaultOAuthService: OAuthService { + + private let baseURL: URL + private let apiService: any APIService + + /// Default initialiser + /// - Parameters: + /// - baseURL: The API protocol + host url, used for building all API requests' URL + public init(baseURL: URL, apiService: any APIService) { + self.baseURL = baseURL + self.apiService = apiService + } + + /// Extract an header from the HTTP response + /// - Parameters: + /// - header: The header key + /// - httpResponse: The HTTP URL Response + /// - Returns: The header value, throws an error if not present + internal func extract(header: String, from httpResponse: HTTPURLResponse) throws -> String { + let headers = httpResponse.allHeaderFields + guard let result = headers[header] as? String else { + throw OAuthServiceError.missingResponseValue(header) + } + return result + } + + /// Extract an API error from the HTTP response body. + /// The Auth API can answer with errors in the HTTP response body, format: `{ "error": "$error_code" }`, this function decodes the body in `AuthRequest.BodyError`and generates an AuthServiceError containing the error info + /// - Parameter responseBody: The HTTP response body Data + /// - Returns: and AuthServiceError.authAPIError containing the error code and description, nil if the body + internal func extractError(from response: APIResponseV2) -> OAuthServiceError? { + if let bodyError: OAuthRequest.BodyError = try? response.decodeBody() { + return OAuthServiceError.authAPIError(code: bodyError.error) + } + return nil + } + + internal func throwError(forResponse response: APIResponseV2) throws { + if let error = extractError(from: response) { + throw error + } else { + throw OAuthServiceError.missingResponseValue("Body error") + } + } + + internal func fetch(request: OAuthRequest) async throws -> T { + try Task.checkCancellation() + let response = try await apiService.fetch(request: request.apiRequest) + try Task.checkCancellation() + + let statusCode = response.httpResponse.httpStatus + if statusCode == request.httpSuccessCode { + return try response.decodeBody() + } else if request.httpErrorCodes.contains(statusCode) { + try throwError(forResponse: response) + } + throw OAuthServiceError.invalidResponseCode(statusCode) + } + + // MARK: - API requests + + // MARK: Authorize + + public func authorize(codeChallenge: String) async throws -> OAuthSessionID { + try Task.checkCancellation() + guard let request = OAuthRequest.authorize(baseURL: baseURL, codeChallenge: codeChallenge) else { + throw OAuthServiceError.invalidRequest + } + + let response = try await apiService.fetch(request: request.apiRequest) + try Task.checkCancellation() + + let statusCode = response.httpResponse.httpStatus + if statusCode == request.httpSuccessCode { + // let location = try extract(header: HTTPHeaderKey.location, from: response.httpResponse) + guard let cookieValue = response.httpResponse.getCookie(withName: "ddg_auth_session_id")?.value else { + throw OAuthServiceError.missingResponseValue("ddg_auth_session_id cookie") + } + return cookieValue + } else if request.httpErrorCodes.contains(statusCode) { + try throwError(forResponse: response) + } + throw OAuthServiceError.invalidResponseCode(statusCode) + } + + // MARK: Create Account + + public func createAccount(authSessionID: String) async throws -> AuthorisationCode { + try Task.checkCancellation() + guard let request = OAuthRequest.createAccount(baseURL: baseURL, authSessionID: authSessionID) else { + throw OAuthServiceError.invalidRequest + } + + let response = try await apiService.fetch(request: request.apiRequest) + try Task.checkCancellation() + + let statusCode = response.httpResponse.httpStatus + if statusCode == request.httpSuccessCode { + // The redirect URI from the original Authorization request indicated by the ddg_auth_session_id in the provided Cookie header, with the authorization code needed for the Access Token request appended as a query param. The intention is that the client will intercept this redirect and extract the authorization code to make the Access Token request in the background. + let redirectURI = try extract(header: HTTPHeaderKey.location, from: response.httpResponse) + + // Extract the code from the URL query params, example: com.duckduckgo:/authcb?code=NgNjnlLaqUomt9b5LDbzAtTyeW9cBNhCGtLB3vpcctluSZI51M9tb2ZDIZdijSPTYBr4w8dtVZl85zNSemxozv + guard let authCode = URLComponents(string: redirectURI)?.queryItems?.first(where: { queryItem in + queryItem.name == "code" + })?.value else { + throw OAuthServiceError.missingResponseValue("Authorization Code in redirect URI") + } + return authCode + } else if request.httpErrorCodes.contains(statusCode) { + try throwError(forResponse: response) + } + throw OAuthServiceError.invalidResponseCode(statusCode) + } + + /* MARK: Request OTP + + public func requestOTP(authSessionID: String, emailAddress: String) async throws { + try Task.checkCancellation() + guard let request = OAuthRequest.requestOTP(baseURL: baseURL, authSessionID: authSessionID, emailAddress: emailAddress) else { + throw OAuthServiceError.invalidRequest + } + + let response = try await apiService.fetch(request: request.apiRequest) + try Task.checkCancellation() + + let statusCode = response.httpResponse.httpStatus + if statusCode == request.httpSuccessCode { + } else if request.httpErrorCodes.contains(statusCode) { + try throwError(forResponse: response, request: request) + } + throw OAuthServiceError.invalidResponseCode(statusCode) + } + + // MARK: Login + + public func login(withOTP otp: String, authSessionID: String, email: String) async throws -> AuthorisationCode { + try Task.checkCancellation() + let method = OAuthLoginMethodOTP(email: email, otp: otp) + guard let request = OAuthRequest.login(baseURL: baseURL, authSessionID: authSessionID, method: method) else { + throw OAuthServiceError.invalidRequest + } + + let response = try await apiService.fetch(request: request.apiRequest) + try Task.checkCancellation() + + let statusCode = response.httpResponse.httpStatus + if statusCode == request.httpSuccessCode { + return try extract(header: HTTPHeaderKey.location, from: response.httpResponse) + } else if request.httpErrorCodes.contains(statusCode) { + try throwError(forResponse: response, request: request) + } + throw OAuthServiceError.invalidResponseCode(statusCode) + } + */ + + public func login(withSignature signature: String, authSessionID: String) async throws -> AuthorisationCode { + try Task.checkCancellation() + let method = OAuthLoginMethodSignature(signature: signature) + guard let request = OAuthRequest.login(baseURL: baseURL, authSessionID: authSessionID, method: method) else { + throw OAuthServiceError.invalidRequest + } + + let response = try await apiService.fetch(request: request.apiRequest) + try Task.checkCancellation() + + let statusCode = response.httpResponse.httpStatus + if statusCode == request.httpSuccessCode { + // "com.duckduckgo:/authcb?code=eud8rNxyq2lhN4VFwQ7CAcir80dFBRIE4YpPY0gqeunTw4j6SoWkN4AA2c0TNO1sohqe84zubUtERkLLl94Qam" + guard let locationHeaderValue = try? extract(header: HTTPHeaderKey.location, from: response.httpResponse), + let redirectURL = URL(string: locationHeaderValue), + let authCode = redirectURL.queryParameters()?["code"] else { + throw OAuthServiceError.missingResponseValue("Auth code") + } + return authCode + } else if request.httpErrorCodes.contains(statusCode) { + try throwError(forResponse: response) + } + throw OAuthServiceError.invalidResponseCode(statusCode) + } + + // MARK: Access token + + public func getAccessToken(clientID: String, codeVerifier: String, code: String, redirectURI: String) async throws -> OAuthTokenResponse { + guard let request = OAuthRequest.getAccessToken(baseURL: baseURL, clientID: clientID, codeVerifier: codeVerifier, code: code, redirectURI: redirectURI) else { + throw OAuthServiceError.invalidRequest + } + return try await fetch(request: request) + } + + public func refreshAccessToken(clientID: String, refreshToken: String) async throws -> OAuthTokenResponse { + guard let request = OAuthRequest.refreshAccessToken(baseURL: baseURL, clientID: clientID, refreshToken: refreshToken) else { + throw OAuthServiceError.invalidRequest + } + return try await fetch(request: request) + } + + /* MARK: Edit account + + /// Edit an account email address + /// - Parameters: + /// - email: The email address to change to. If omitted, the account email address will be removed. + /// - Returns: EditAccountResponse containing a status, always "confirmed" and an hash used in the `confirm edit account` API call + public func editAccount(clientID: String, accessToken: String, email: String?) async throws -> EditAccountResponse { + guard let request = OAuthRequest.editAccount(baseURL: baseURL, accessToken: accessToken, email: email) else { + throw OAuthServiceError.invalidRequest + } + return try await fetch(request: request) + } + + public func confirmEditAccount(accessToken: String, email: String, hash: String, otp: String) async throws -> ConfirmEditAccountResponse { + guard let request = OAuthRequest.confirmEditAccount(baseURL: baseURL, accessToken: accessToken, email: email, hash: hash, otp: otp) else { + throw OAuthServiceError.invalidRequest + } + return try await fetch(request: request) + } + */ + + // MARK: Logout + + public func logout(accessToken: String) async throws { + guard let request = OAuthRequest.logout(baseURL: baseURL, accessToken: accessToken) else { + throw OAuthServiceError.invalidRequest + } + let response: LogoutResponse = try await fetch(request: request) + guard response.status == "logged_out" else { + throw OAuthServiceError.missingResponseValue("LogoutResponse.status") + } + } + + // MARK: Access token exchange + + public func exchangeToken(accessTokenV1: String, authSessionID: String) async throws -> AuthorisationCode { + try Task.checkCancellation() + guard let request = OAuthRequest.exchangeToken(baseURL: baseURL, accessTokenV1: accessTokenV1, authSessionID: authSessionID) else { + throw OAuthServiceError.invalidRequest + } + let response = try await apiService.fetch(request: request.apiRequest) + try Task.checkCancellation() + + let statusCode = response.httpResponse.httpStatus + if statusCode == request.httpSuccessCode { + let redirectURI = try extract(header: HTTPHeaderKey.location, from: response.httpResponse) + // Extract the code from the URL query params, example: com.duckduckgo:/authcb?code=NgNj...ozv + guard let authCode = URLComponents(string: redirectURI)?.queryItems?.first(where: { queryItem in + queryItem.name == "code" + })?.value else { + throw OAuthServiceError.missingResponseValue("Authorization Code in redirect URI") + } + return authCode + } else if request.httpErrorCodes.contains(statusCode) { + try throwError(forResponse: response) + } + throw OAuthServiceError.invalidResponseCode(statusCode) + } + + // MARK: JWKs + + /// Create a JWTSigners with the JWKs provided by the endpoint + /// - Returns: A JWTSigners that can be used to verify JWTs + public func getJWTSigners() async throws -> JWTSigners { + try Task.checkCancellation() + guard let request = OAuthRequest.jwks(baseURL: baseURL) else { + throw OAuthServiceError.invalidRequest + } + try Task.checkCancellation() + let response: String = try await fetch(request: request) + let signers = JWTSigners() + try signers.use(jwksJSON: response) + return signers + } +} + +// MARK: - Requests' support models and types + +public typealias OAuthSessionID = String + +public protocol OAuthLoginMethod { + var name: String { get } +} + +public struct OAuthLoginMethodOTP: OAuthLoginMethod { + public let name = "otp" + public let email: String + public let otp: String +} + +public struct OAuthLoginMethodSignature: OAuthLoginMethod { + public let name = "signature" + public let signature: String + public let source = "apple_app_store" +} + +/// The redirect URI from the original Authorization request indicated by the ddg_auth_session_id in the provided Cookie header, with the authorization code needed for the Access Token request appended as a query param. The intention is that the client will intercept this redirect and extract the authorization code to make the Access Token request in the background. +public typealias AuthorisationCode = String + +/// https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2 +public struct OAuthTokenResponse: Decodable { + /// JWT with encoded account details and entitlements. Can be verified using tokens published on the /api/auth/v2/.well-known/jwks.json endpoint. Used to gain access to Privacy Pro BE service resources (VPN, PIR, ITR). Expires after 4 hours, but can be refreshed with a refresh token. + let accessToken: String + /// JWT which can be used to get a new access token after the access token expires. Expires after 30 days. Can only be used once. Re-using a refresh token will invalidate any access tokens already issued from that refresh token. + let refreshToken: String + /// **ignored** access token expiry date in seconds. The real expiry date will be decoded from the JWT token itself + let expiresIn: Double + /// Fix as `Bearer` https://www.rfc-editor.org/rfc/rfc6749#section-7.1 + let tokenType: String + + enum CodingKeys: CodingKey { + case accessToken + case refreshToken + case expiresIn + case tokenType + + var stringValue: String { + switch self { + case .accessToken: return "access_token" + case .refreshToken: return "refresh_token" + case .expiresIn: return "expires_in" + case .tokenType: return "token_type" + } + } + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.accessToken = try container.decode(String.self, forKey: .accessToken) + self.refreshToken = try container.decode(String.self, forKey: .refreshToken) + self.expiresIn = try container.decode(Double.self, forKey: .expiresIn) + self.tokenType = try container.decode(String.self, forKey: .tokenType) + } + + init(accessToken: String, refreshToken: String) { + self.accessToken = accessToken + self.refreshToken = refreshToken + self.expiresIn = 14400 + self.tokenType = "Bearer" + } +} + +public struct EditAccountResponse: Decodable { + let status: String // Always "confirm" + let hash: String // Edit hash for edit confirmation +} + +public struct ConfirmEditAccountResponse: Decodable { + let status: String // Always "confirmed" + let email: String // The new email address +} + +public struct LogoutResponse: Decodable { + let status: String // Always "logged_out" +} diff --git a/Sources/Networking/OAuth/OAuthServiceError.swift b/Sources/Networking/OAuth/OAuthServiceError.swift new file mode 100644 index 000000000..5d39db557 --- /dev/null +++ b/Sources/Networking/OAuth/OAuthServiceError.swift @@ -0,0 +1,59 @@ +// +// OAuthServiceError.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public enum OAuthServiceError: Error, LocalizedError, Equatable { + case authAPIError(code: OAuthRequest.BodyErrorCode) + case apiServiceError(Error) + case invalidRequest + case invalidResponseCode(HTTPStatusCode) + case missingResponseValue(String) + + public var errorDescription: String? { + switch self { + case .authAPIError(let code): + "Auth API responded with error \(code.rawValue) - \(code.description)" + case .apiServiceError(let error): + "API service error - \(error.localizedDescription)" + case .invalidRequest: + "Failed to generate the API request" + case .invalidResponseCode(let code): + "Invalid API request response code: \(code.rawValue) - \(code.description)" + case .missingResponseValue(let value): + "The API response is missing \(value)" + } + } + + public static func == (lhs: OAuthServiceError, rhs: OAuthServiceError) -> Bool { + switch (lhs, rhs) { + case (.authAPIError(let lhsCode), .authAPIError(let rhsCode)): + return lhsCode == rhsCode + case (.apiServiceError(let lhsError), .apiServiceError(let rhsError)): + return lhsError.localizedDescription == rhsError.localizedDescription + case (.invalidRequest, .invalidRequest): + return true + case (.invalidResponseCode(let lhsCode), .invalidResponseCode(let rhsCode)): + return lhsCode == rhsCode + case (.missingResponseValue(let lhsValue), .missingResponseValue(let rhsValue)): + return lhsValue == rhsValue + default: + return false + } + } +} diff --git a/Sources/Networking/OAuth/OAuthTokens.swift b/Sources/Networking/OAuth/OAuthTokens.swift new file mode 100644 index 000000000..b8219bf4d --- /dev/null +++ b/Sources/Networking/OAuth/OAuthTokens.swift @@ -0,0 +1,169 @@ +// +// OAuthTokens.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import JWTKit + +/// Container for both access and refresh tokens +/// +/// WARNING: Specialised for Privacy Pro Subscription, abstract for other use cases. +/// +/// This is the object that should be stored in the keychain and used to make authenticated requests +/// The decoded tokens are used to determine the user's entitlements +/// The access token is used to make authenticated requests +/// The refresh token is used to get a new access token when the current one expires +public struct TokenContainer: Codable { + public let accessToken: String + public let refreshToken: String + public let decodedAccessToken: JWTAccessToken + public let decodedRefreshToken: JWTRefreshToken +} + +extension TokenContainer: Equatable { + + public static func == (lhs: TokenContainer, rhs: TokenContainer) -> Bool { + lhs.accessToken == rhs.accessToken && lhs.refreshToken == rhs.refreshToken + } +} + +extension TokenContainer: CustomDebugStringConvertible { + + public var debugDescription: String { + """ + Access Token: \(decodedAccessToken) + Refresh Token: \(decodedRefreshToken) + """ + } +} + +extension TokenContainer { + + public var data: NSData? { + return try? JSONEncoder().encode(self) as NSData + } + + public init(with data: NSData) throws { + self = try JSONDecoder().decode(TokenContainer.self, from: data as Data) + } +} + +public enum TokenPayloadError: Error { + case invalidTokenScope +} + +public struct JWTAccessToken: JWTPayload, Equatable { + let exp: ExpirationClaim + let iat: IssuedAtClaim + let sub: SubjectClaim + let aud: AudienceClaim + let iss: IssuerClaim + let jti: IDClaim + let scope: String + let api: String // always v2 + public let email: String? + let entitlements: [EntitlementPayload] + + public func verify(using signer: JWTKit.JWTSigner) throws { + try self.exp.verifyNotExpired() + if self.scope != "privacypro" { + throw TokenPayloadError.invalidTokenScope + } + } + + public func isExpired() -> Bool { + do { + try self.exp.verifyNotExpired() + } catch { + return true + } + return false + } + + public var externalID: String { + sub.value + } + + public var expirationDate: Date { + exp.value + } +} + +public struct JWTRefreshToken: JWTPayload, Equatable { + let exp: ExpirationClaim + let iat: IssuedAtClaim + let sub: SubjectClaim + let aud: AudienceClaim + let iss: IssuerClaim + let jti: IDClaim + let scope: String + let api: String + + public func verify(using signer: JWTKit.JWTSigner) throws { + try self.exp.verifyNotExpired() + if self.scope != "refresh" { + throw TokenPayloadError.invalidTokenScope + } + } + + public func isExpired() -> Bool { + do { + try self.exp.verifyNotExpired() + } catch { + return true + } + return false + } + + public var expirationDate: Date { + exp.value + } +} + +public enum SubscriptionEntitlement: String, Codable, Equatable, CustomDebugStringConvertible { + case networkProtection = "Network Protection" + case dataBrokerProtection = "Data Broker Protection" + case identityTheftRestoration = "Identity Theft Restoration" + case identityTheftRestorationGlobal = "Global Identity Theft Restoration" + case unknown + + public init(from decoder: Decoder) throws { + self = try Self(rawValue: decoder.singleValueContainer().decode(RawValue.self)) ?? .unknown + } + + public var debugDescription: String { + return self.rawValue + } +} + +public struct EntitlementPayload: Codable, Equatable { + public let product: SubscriptionEntitlement // Can expand in future + public let name: String // always `subscriber` +} + +public extension JWTAccessToken { + + var subscriptionEntitlements: [SubscriptionEntitlement] { + return entitlements.map({ entPayload in + entPayload.product + }) + } + + func hasEntitlement(_ entitlement: SubscriptionEntitlement) -> Bool { + return subscriptionEntitlements.contains(entitlement) + } +} diff --git a/Sources/Networking/OAuth/SessionDelegate.swift b/Sources/Networking/OAuth/SessionDelegate.swift new file mode 100644 index 000000000..d5052d1e8 --- /dev/null +++ b/Sources/Networking/OAuth/SessionDelegate.swift @@ -0,0 +1,28 @@ +// +// SessionDelegate.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os.log + +public final class SessionDelegate: NSObject, URLSessionTaskDelegate { + + /// Disable automatic redirection, in our specific OAuth implementation we manage the redirection, not the user + public func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest) async -> URLRequest? { + return nil + } +} diff --git a/Sources/Networking/v1/APIRequestConfiguration.swift b/Sources/Networking/v1/APIRequestConfiguration.swift index e158ddfc4..3dc29c323 100644 --- a/Sources/Networking/v1/APIRequestConfiguration.swift +++ b/Sources/Networking/v1/APIRequestConfiguration.swift @@ -17,7 +17,6 @@ // import Foundation -import Common extension APIRequest { diff --git a/Sources/Networking/v1/HTTPURLResponseExtension.swift b/Sources/Networking/v1/HTTPURLResponseExtension.swift index 5b00fe308..7b97c5ab1 100644 --- a/Sources/Networking/v1/HTTPURLResponseExtension.swift +++ b/Sources/Networking/v1/HTTPURLResponseExtension.swift @@ -17,7 +17,6 @@ // import Foundation -import Common public extension HTTPURLResponse { diff --git a/Sources/Networking/v2/APIRequestV2.swift b/Sources/Networking/v2/APIRequestV2.swift index a61604861..c8a84d714 100644 --- a/Sources/Networking/v2/APIRequestV2.swift +++ b/Sources/Networking/v2/APIRequestV2.swift @@ -16,14 +16,42 @@ // limitations under the License. // -import Common import Foundation -public struct APIRequestV2: CustomDebugStringConvertible { +public typealias QueryItems = [String: String] + +public class APIRequestV2: Hashable, CustomDebugStringConvertible { + + private(set) var urlRequest: URLRequest + + public struct RetryPolicy: Hashable, CustomDebugStringConvertible { + public let maxRetries: Int + public let delay: TimeInterval + + public init(maxRetries: Int, delay: TimeInterval = 0) { + self.maxRetries = maxRetries + self.delay = delay + } + + public var debugDescription: String { + "MaxRetries: \(maxRetries), delay: \(delay)" + } + + public static func == (lhs: APIRequestV2.RetryPolicy, rhs: APIRequestV2.RetryPolicy) -> Bool { + lhs.maxRetries == rhs.maxRetries && lhs.delay == rhs.delay + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(maxRetries) + hasher.combine(delay) + } + } let timeoutInterval: TimeInterval let responseConstraints: [APIResponseConstraints]? - public let urlRequest: URLRequest + let retryPolicy: RetryPolicy? + var authRefreshRetryCount: Int = 0 + var failureRetryCount: Int = 0 /// Designated initialiser /// - Parameters: @@ -36,26 +64,26 @@ public struct APIRequestV2: CustomDebugStringConvertible { /// - cachePolicy: The request cache policy, default is `.useProtocolCachePolicy` /// - responseRequirements: The response requirements /// - allowedQueryReservedCharacters: The characters in this character set will not be URL encoded in the query parameters - public init( - url: URL, - method: HTTPRequestMethod = .get, - queryItems: QueryParams?, - headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), - body: Data? = nil, - timeoutInterval: TimeInterval = 60.0, - cachePolicy: URLRequest.CachePolicy? = nil, - responseConstraints: [APIResponseConstraints]? = nil, - allowedQueryReservedCharacters: CharacterSet? = nil - ) where QueryParams.Element == (key: String, value: String) { + public init?(url: URL, + method: HTTPRequestMethod = .get, + queryItems: QueryItems? = nil, + headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), + body: Data? = nil, + timeoutInterval: TimeInterval = 60.0, + retryPolicy: RetryPolicy? = nil, + cachePolicy: URLRequest.CachePolicy? = nil, + responseConstraints: [APIResponseConstraints]? = nil, + allowedQueryReservedCharacters: CharacterSet? = nil) { self.timeoutInterval = timeoutInterval self.responseConstraints = responseConstraints - let finalURL = if let queryItems { - url.appendingParameters(queryItems, allowedReservedCharacters: allowedQueryReservedCharacters) - } else { - url + // Generate URL request + guard var urlComps = URLComponents(url: url, resolvingAgainstBaseURL: true) else { + return nil } + urlComps.queryItems = queryItems?.toURLQueryItems(allowedReservedCharacters: allowedQueryReservedCharacters) + guard let finalURL = urlComps.url else { return nil } var request = URLRequest(url: finalURL, timeoutInterval: timeoutInterval) request.allHTTPHeaderFields = headers?.httpHeaders request.httpMethod = method.rawValue @@ -64,19 +92,7 @@ public struct APIRequestV2: CustomDebugStringConvertible { request.cachePolicy = cachePolicy } self.urlRequest = request - } - - public init( - url: URL, - method: HTTPRequestMethod = .get, - headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), - body: Data? = nil, - timeoutInterval: TimeInterval = 60.0, - cachePolicy: URLRequest.CachePolicy? = nil, - responseConstraints: [APIResponseConstraints]? = nil, - allowedQueryReservedCharacters: CharacterSet? = nil - ) { - self.init(url: url, method: method, queryItems: [String: String]?.none, headers: headers, body: body, timeoutInterval: timeoutInterval, cachePolicy: cachePolicy, responseConstraints: responseConstraints, allowedQueryReservedCharacters: allowedQueryReservedCharacters) + self.retryPolicy = retryPolicy } public var debugDescription: String { @@ -89,6 +105,40 @@ public struct APIRequestV2: CustomDebugStringConvertible { Timeout Interval: \(timeoutInterval)s Cache Policy: \(urlRequest.cachePolicy) Response Constraints: \(responseConstraints?.map { $0.rawValue } ?? []) + Retry Policy: \(retryPolicy?.debugDescription ?? "None") + Retries counts: Refresh \(authRefreshRetryCount), Failure \(failureRetryCount) """ } + + public func updateAuthorizationHeader(_ token: String) { + self.urlRequest.allHTTPHeaderFields?[HTTPHeaderKey.authorization] = "Bearer \(token)" + } + + public var isAuthenticated: Bool { + return urlRequest.allHTTPHeaderFields?[HTTPHeaderKey.authorization] != nil + } + + // MARK: Hashable Conformance + + public static func == (lhs: APIRequestV2, rhs: APIRequestV2) -> Bool { + let urlLhs = lhs.urlRequest.url?.pathComponents.joined(separator: "/") + let urlRhs = rhs.urlRequest.url?.pathComponents.joined(separator: "/") + + return urlLhs == urlRhs && + lhs.timeoutInterval == rhs.timeoutInterval && + lhs.responseConstraints == rhs.responseConstraints && + lhs.retryPolicy == rhs.retryPolicy && + lhs.authRefreshRetryCount == rhs.authRefreshRetryCount && + lhs.failureRetryCount == rhs.failureRetryCount + } + + public func hash(into hasher: inout Hasher) { + let urlPath = urlRequest.url?.pathComponents.joined(separator: "/") + hasher.combine(urlPath) + hasher.combine(timeoutInterval) + hasher.combine(responseConstraints) + hasher.combine(retryPolicy) + hasher.combine(authRefreshRetryCount) + hasher.combine(failureRetryCount) + } } diff --git a/Sources/Networking/v2/APIRequestErrorV2.swift b/Sources/Networking/v2/APIRequestV2Error.swift similarity index 61% rename from Sources/Networking/v2/APIRequestErrorV2.swift rename to Sources/Networking/v2/APIRequestV2Error.swift index f371b4fb6..c2f1a3729 100644 --- a/Sources/Networking/v2/APIRequestErrorV2.swift +++ b/Sources/Networking/v2/APIRequestV2Error.swift @@ -1,5 +1,5 @@ // -// APIRequestErrorV2.swift +// APIRequestV2Error.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -20,7 +20,8 @@ import Foundation extension APIRequestV2 { - public enum Error: Swift.Error, LocalizedError { + public enum Error: Swift.Error, LocalizedError, Equatable { + case urlSession(Swift.Error) case invalidResponse case unsatisfiedRequirement(APIResponseConstraints) @@ -44,6 +45,26 @@ extension APIRequestV2 { return "The response body is nil" } } + + // MARK: - Equatable Conformance + public static func == (lhs: Error, rhs: Error) -> Bool { + switch (lhs, rhs) { + case (.urlSession(let lhsError), .urlSession(let rhsError)): + return lhsError.localizedDescription == rhsError.localizedDescription + case (.invalidResponse, .invalidResponse): + return true + case (.unsatisfiedRequirement(let lhsRequirement), .unsatisfiedRequirement(let rhsRequirement)): + return lhsRequirement == rhsRequirement + case (.invalidStatusCode(let lhsStatusCode), .invalidStatusCode(let rhsStatusCode)): + return lhsStatusCode == rhsStatusCode + case (.invalidDataType, .invalidDataType): + return true + case (.emptyResponseBody, .emptyResponseBody): + return true + default: + return false + } + } } } diff --git a/Sources/Networking/v2/APIResponseV2.swift b/Sources/Networking/v2/APIResponseV2.swift index 8987e377b..87abdb51c 100644 --- a/Sources/Networking/v2/APIResponseV2.swift +++ b/Sources/Networking/v2/APIResponseV2.swift @@ -33,13 +33,19 @@ public extension APIResponseV2 { /// Decode the APIResponseV2 into the inferred `Decodable` type /// - Parameter decoder: A custom JSONDecoder, if not provided the default JSONDecoder() is used - /// - Returns: An instance of a Decodable model of the type inferred + /// - Returns: An instance of a Decodable model of the type inferred, throws an error if the body is empty or the decoding fails func decodeBody(decoder: JSONDecoder = JSONDecoder()) throws -> T { + decoder.dateDecodingStrategy = .millisecondsSince1970 guard let data = self.data else { throw APIRequestV2.Error.emptyResponseBody } +#if DEBUG + let resultString = String(data: data, encoding: .utf8) + Logger.networking.debug("APIResponse body: \(resultString ?? "")") +#endif + Logger.networking.debug("Decoding APIResponse body as \(T.self)") switch T.self { case is String.Type: diff --git a/Sources/Networking/v2/APIService.swift b/Sources/Networking/v2/APIService.swift index 79eed52d5..979e6094c 100644 --- a/Sources/Networking/v2/APIService.swift +++ b/Sources/Networking/v2/APIService.swift @@ -20,15 +20,18 @@ import Foundation import os.log public protocol APIService { + typealias AuthorizationRefresherCallback = ((_: APIRequestV2) async throws -> String) + var authorizationRefresherCallback: AuthorizationRefresherCallback? { get set } func fetch(request: APIRequestV2) async throws -> APIResponseV2 } -public struct DefaultAPIService: APIService { +public class DefaultAPIService: APIService { private let urlSession: URLSession + public var authorizationRefresherCallback: AuthorizationRefresherCallback? - public init(urlSession: URLSession = .shared) { + public init(urlSession: URLSession = .shared, authorizationRefresherCallback: AuthorizationRefresherCallback? = nil) { self.urlSession = urlSession - + self.authorizationRefresherCallback = authorizationRefresherCallback } /// Fetch an API Request @@ -45,12 +48,41 @@ public struct DefaultAPIService: APIService { // Check response code let httpResponse = try response.asHTTPURLResponse() let responseHTTPStatus = httpResponse.httpStatus - if responseHTTPStatus.isFailure { - return APIResponseV2(data: data, httpResponse: httpResponse) + + // First time the request is executed and the response is `.unauthorized` we try to refresh the authentication token + if responseHTTPStatus == .unauthorized, + request.isAuthenticated == true, + request.authRefreshRetryCount == 0, + let authorizationRefresherCallback { + request.authRefreshRetryCount += 1 + + // Ask to refresh the token + let refreshedToken = try await authorizationRefresherCallback(request) + request.updateAuthorizationHeader(refreshedToken) + + // Try again + return try await fetch(request: request) } - try checkConstraints(in: httpResponse, for: request) + // It's a failure and the request must be retried + if let retryPolicy = request.retryPolicy, + responseHTTPStatus.isFailure, + responseHTTPStatus != .unauthorized, // No retries needed is unuathorised + request.failureRetryCount < retryPolicy.maxRetries { + request.failureRetryCount += 1 + if retryPolicy.delay > 0 { + try? await Task.sleep(interval: retryPolicy.delay) + } + + // Try again + return try await fetch(request: request) + } + + // It's not a failure, we check the constraints + if !responseHTTPStatus.isFailure { + try checkConstraints(in: httpResponse, for: request) + } return APIResponseV2(data: data, httpResponse: httpResponse) } diff --git a/Sources/SubscriptionTestingUtilities/Flows/AppStoreAccountManagementFlowMock.swift b/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift similarity index 51% rename from Sources/SubscriptionTestingUtilities/Flows/AppStoreAccountManagementFlowMock.swift rename to Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift index cff7d88e6..004daf7ff 100644 --- a/Sources/SubscriptionTestingUtilities/Flows/AppStoreAccountManagementFlowMock.swift +++ b/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift @@ -1,5 +1,5 @@ // -// AppStoreAccountManagementFlowMock.swift +// Dictionary+URLQueryItem.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -17,18 +17,19 @@ // import Foundation -import Subscription +import Common -public final class AppStoreAccountManagementFlowMock: AppStoreAccountManagementFlow { - public var refreshAuthTokenIfNeededResult: Result? - public var onRefreshAuthTokenIfNeeded: (() -> Void)? - public var refreshAuthTokenIfNeededCalled: Bool = false +extension Dictionary where Key == String, Value == String { - public init() { } - - public func refreshAuthTokenIfNeeded() async -> Result { - refreshAuthTokenIfNeededCalled = true - onRefreshAuthTokenIfNeeded?() - return refreshAuthTokenIfNeededResult! + public func toURLQueryItems(allowedReservedCharacters: CharacterSet? = nil) -> [URLQueryItem] { + return self.sorted(by: <).map { + if let allowedReservedCharacters { + URLQueryItem(percentEncodingName: $0.key, + value: $0.value, + withAllowedCharacters: allowedReservedCharacters) + } else { + URLQueryItem(name: $0.key, value: $0.value) + } + } } } diff --git a/Sources/Common/DecodableHelper.swift b/Sources/Networking/v2/Extensions/HTTPURLResponse+Cookie.swift similarity index 55% rename from Sources/Common/DecodableHelper.swift rename to Sources/Networking/v2/Extensions/HTTPURLResponse+Cookie.swift index 44491301c..aff26ee0f 100644 --- a/Sources/Common/DecodableHelper.swift +++ b/Sources/Networking/v2/Extensions/HTTPURLResponse+Cookie.swift @@ -1,7 +1,7 @@ // -// DecodableHelper.swift +// HTTPURLResponse+Cookie.swift // -// Copyright © 2021 DuckDuckGo. All rights reserved. +// Copyright © 2024 DuckDuckGo. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,16 +17,20 @@ // import Foundation -import os.log -public struct DecodableHelper { - public static func decode(from input: Input) -> Target? { - do { - let json = try JSONSerialization.data(withJSONObject: input) - return try JSONDecoder().decode(Target.self, from: json) - } catch { - Logger.general.error("Error decoding message body: \(error.localizedDescription, privacy: .public)") +public extension HTTPURLResponse { + + var cookies: [HTTPCookie]? { + guard let fields = allHeaderFields as? [String: String], let url else { return nil } + return HTTPCookie.cookies(withResponseHeaderFields: fields, for: url) + } + + func getCookie(withName name: String) -> HTTPCookie? { + if let cookie = cookies?.first(where: { $0.name == name }) { + return cookie + } + return nil } } diff --git a/Sources/Networking/v2/Extensions/HTTPURLResponse+Utilities.swift b/Sources/Networking/v2/Extensions/HTTPURLResponse+Etag.swift similarity index 82% rename from Sources/Networking/v2/Extensions/HTTPURLResponse+Utilities.swift rename to Sources/Networking/v2/Extensions/HTTPURLResponse+Etag.swift index 10e7b8028..b7889abf7 100644 --- a/Sources/Networking/v2/Extensions/HTTPURLResponse+Utilities.swift +++ b/Sources/Networking/v2/Extensions/HTTPURLResponse+Etag.swift @@ -1,7 +1,7 @@ // -// HTTPURLResponse+Utilities.swift +// HTTPURLResponse+Etag.swift // -// Copyright © 2023 DuckDuckGo. All rights reserved. +// Copyright © 2024 DuckDuckGo. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,15 +17,10 @@ // import Foundation -import Common public extension HTTPURLResponse { - var httpStatus: HTTPStatusCode { - HTTPStatusCode(rawValue: statusCode) ?? .unknown - } var etag: String? { etag(droppingWeakPrefix: true) } - private static let weakEtagPrefix = "W/" func etag(droppingWeakPrefix: Bool) -> String? { diff --git a/Sources/NetworkProtection/FeatureActivation/NetworkProtectionFeatureActivation.swift b/Sources/Networking/v2/Extensions/HTTPURLResponse+HTTPStatusCode.swift similarity index 58% rename from Sources/NetworkProtection/FeatureActivation/NetworkProtectionFeatureActivation.swift rename to Sources/Networking/v2/Extensions/HTTPURLResponse+HTTPStatusCode.swift index f8aafe67e..196d188c7 100644 --- a/Sources/NetworkProtection/FeatureActivation/NetworkProtectionFeatureActivation.swift +++ b/Sources/Networking/v2/Extensions/HTTPURLResponse+HTTPStatusCode.swift @@ -1,5 +1,5 @@ // -// NetworkProtectionFeatureActivation.swift +// HTTPURLResponse+HTTPStatusCode.swift // // Copyright © 2023 DuckDuckGo. All rights reserved. // @@ -18,19 +18,9 @@ import Foundation -public protocol NetworkProtectionFeatureActivation { +public extension HTTPURLResponse { - /// Has the invite code flow been completed and an oAuth token stored? - /// - var isFeatureActivated: Bool { get } -} - -extension NetworkProtectionKeychainTokenStore: NetworkProtectionFeatureActivation { - public var isFeatureActivated: Bool { - do { - return try fetchToken() != nil - } catch { - return false - } + var httpStatus: HTTPStatusCode { + HTTPStatusCode(rawValue: statusCode) ?? .unknown } } diff --git a/Sources/Subscription/API/Model/Entitlement.swift b/Sources/Networking/v2/Extensions/URL+QueryParamExtraction.swift similarity index 52% rename from Sources/Subscription/API/Model/Entitlement.swift rename to Sources/Networking/v2/Extensions/URL+QueryParamExtraction.swift index 1d8eb645a..989e25439 100644 --- a/Sources/Subscription/API/Model/Entitlement.swift +++ b/Sources/Networking/v2/Extensions/URL+QueryParamExtraction.swift @@ -1,5 +1,5 @@ // -// Entitlement.swift +// URL+QueryParamExtraction.swift // // Copyright © 2023 DuckDuckGo. All rights reserved. // @@ -18,18 +18,20 @@ import Foundation -public struct Entitlement: Codable, Equatable { - public let product: ProductName +public extension URL { - public enum ProductName: String, Codable { - case networkProtection = "Network Protection" - case dataBrokerProtection = "Data Broker Protection" - case identityTheftRestoration = "Identity Theft Restoration" - case identityTheftRestorationGlobal = "Global Identity Theft Restoration" - case unknown - - public init(from decoder: Decoder) throws { - self = try Self(rawValue: decoder.singleValueContainer().decode(RawValue.self)) ?? .unknown + /// Extract the query parameters from the URL + func queryParameters() -> [String: String]? { + guard let urlComponents = URLComponents(url: self, resolvingAgainstBaseURL: false), + let queryItems = urlComponents.queryItems else { + return nil + } + // Convert the query items into a dictionary + var parameters: [String: String] = [:] + for item in queryItems { + parameters[item.name] = item.value } + return parameters } + } diff --git a/Sources/Networking/v2/HTTP Components/HTTPStatusCode.swift b/Sources/Networking/v2/HTTP Components/HTTPStatusCode.swift index 633b92322..1ff8610b4 100644 --- a/Sources/Networking/v2/HTTP Components/HTTPStatusCode.swift +++ b/Sources/Networking/v2/HTTP Components/HTTPStatusCode.swift @@ -95,27 +95,27 @@ public enum HTTPStatusCode: Int, CustomDebugStringConvertible { case networkAuthenticationRequired = 511 // Utility functions - var isInformational: Bool { + public var isInformational: Bool { return (100...199).contains(self.rawValue) } - var isSuccess: Bool { + public var isSuccess: Bool { return (200...299).contains(self.rawValue) } - var isRedirection: Bool { + public var isRedirection: Bool { return (300...399).contains(self.rawValue) } - var isClientError: Bool { + public var isClientError: Bool { return (400...499).contains(self.rawValue) } - var isServerError: Bool { + public var isServerError: Bool { return (500...599).contains(self.rawValue) } - var isFailure: Bool { + public var isFailure: Bool { return isClientError || isServerError } @@ -123,7 +123,7 @@ public enum HTTPStatusCode: Int, CustomDebugStringConvertible { "\(self.rawValue) - \(description)" } - var description: String { + public var description: String { switch self { case .unknown: return "Unknown" diff --git a/Sources/Networking/v2/HeadersV2.swift b/Sources/Networking/v2/HeadersV2.swift index 8a1b91e20..74cbc9d49 100644 --- a/Sources/Networking/v2/HeadersV2.swift +++ b/Sources/Networking/v2/HeadersV2.swift @@ -20,6 +20,30 @@ import Foundation public extension APIRequestV2 { + /// All possible request content types + enum ContentType: String, Codable { + case json = "application/json" + case xml = "application/xml" + case formURLEncoded = "application/x-www-form-urlencoded" + case multipartFormData = "multipart/form-data" + case html = "text/html" + case plainText = "text/plain" + case css = "text/css" + case javascript = "application/javascript" + case octetStream = "application/octet-stream" + case png = "image/png" + case jpeg = "image/jpeg" + case gif = "image/gif" + case svg = "image/svg+xml" + case pdf = "application/pdf" + case zip = "application/zip" + case csv = "text/csv" + case rtf = "application/rtf" + case mp4 = "video/mp4" + case webm = "video/webm" + case ogg = "application/ogg" + } + struct HeadersV2 { private var userAgent: String? @@ -32,17 +56,26 @@ public extension APIRequestV2 { }.joined(separator: ", ") }() let etag: String? - let additionalHeaders: HTTPHeaders? + let cookies: [HTTPCookie]? + let authToken: String? + let additionalHeaders: [String: String]? + let contentType: ContentType? public init(userAgent: String? = nil, etag: String? = nil, - additionalHeaders: HTTPHeaders? = nil) { + cookies: [HTTPCookie]? = nil, + authToken: String? = nil, + contentType: ContentType? = nil, + additionalHeaders: [String: String]? = nil) { self.userAgent = userAgent self.etag = etag + self.cookies = cookies + self.authToken = authToken + self.contentType = contentType self.additionalHeaders = additionalHeaders } - public var httpHeaders: HTTPHeaders { + public var httpHeaders: [String: String] { var headers = [ HTTPHeaderKey.acceptEncoding: acceptEncoding, HTTPHeaderKey.acceptLanguage: acceptLanguage @@ -53,6 +86,19 @@ public extension APIRequestV2 { if let etag { headers[HTTPHeaderKey.ifNoneMatch] = etag } + if let cookies, cookies.isEmpty == false { + let cookieHeaders = HTTPCookie.requestHeaderFields(with: cookies) + headers.merge(cookieHeaders) { lx, _ in + assertionFailure("Duplicated values in HTTPHeaders") + return lx + } + } + if let authToken { + headers[HTTPHeaderKey.authorization] = "Bearer \(authToken)" + } + if let contentType { + headers[HTTPHeaderKey.contentType] = contentType.rawValue + } if let additionalHeaders { headers.merge(additionalHeaders) { old, _ in old } } diff --git a/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift b/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift index 3fc35c46d..0c619e58d 100644 --- a/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift +++ b/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift @@ -190,7 +190,7 @@ final class PrivacyDashboardUserScript: NSObject, StaticUserScript { } private func getProtectionState(from message: WKScriptMessage) -> ProtectionState? { - guard let protectionState: ProtectionState = DecodableHelper.decode(from: message.messageBody) else { + guard let protectionState: ProtectionState = CodableHelper.decode(from: message.messageBody) else { assertionFailure("privacyDashboardSetProtection: expected ProtectionState") return nil } @@ -315,7 +315,7 @@ final class PrivacyDashboardUserScript: NSObject, StaticUserScript { } private func handleTelemetrySpan(message: WKScriptMessage) { - guard let telemetrySpan: TelemetrySpan = DecodableHelper.decode(from: message.messageBody) else { + guard let telemetrySpan: TelemetrySpan = CodableHelper.decode(from: message.messageBody) else { assertionFailure("privacyDashboardTelemetrySpan: expected TelemetrySpan") return } diff --git a/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift b/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift index 37e53e352..4a96c9d38 100644 --- a/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift +++ b/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift @@ -30,12 +30,12 @@ public struct DefaultRemoteMessagingSurveyURLBuilder: RemoteMessagingSurveyActio private let statisticsStore: StatisticsStore private let vpnActivationDateStore: VPNActivationDateProviding - private let subscription: Subscription? + private let subscription: PrivacyProSubscription? private let localeIdentifier: String public init(statisticsStore: StatisticsStore, vpnActivationDateStore: VPNActivationDateProviding, - subscription: Subscription?, + subscription: PrivacyProSubscription?, localeIdentifier: String = Locale.current.identifier) { self.statisticsStore = statisticsStore self.vpnActivationDateStore = vpnActivationDateStore @@ -134,7 +134,7 @@ public struct DefaultRemoteMessagingSurveyURLBuilder: RemoteMessagingSurveyActio } -extension Subscription { +extension PrivacyProSubscription { var privacyProStatusSurveyParameter: String { switch status { case .autoRenewable: diff --git a/Sources/Subscription/API/APIService.swift b/Sources/Subscription/API/APIService.swift deleted file mode 100644 index 41c634706..000000000 --- a/Sources/Subscription/API/APIService.swift +++ /dev/null @@ -1,129 +0,0 @@ -// -// APIService.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common -import os.log - -public enum APIServiceError: Swift.Error { - case decodingError - case encodingError - case serverError(statusCode: Int, error: String?) - case unknownServerError - case connectionError -} - -struct ErrorResponse: Decodable { - let error: String -} - -public protocol APIService { - func executeAPICall(method: String, endpoint: String, headers: [String: String]?, body: Data?) async -> Result where T: Decodable - func makeAuthorizationHeader(for token: String) -> [String: String] -} - -public enum APICachePolicy { - case reloadIgnoringLocalCacheData - case returnCacheDataElseLoad - case returnCacheDataDontLoad -} - -public struct DefaultAPIService: APIService { - private let baseURL: URL - private let session: URLSession - - public init(baseURL: URL, session: URLSession) { - self.baseURL = baseURL - self.session = session - } - - public func executeAPICall(method: String, endpoint: String, headers: [String: String]? = nil, body: Data? = nil) async -> Result where T: Decodable { - let request = makeAPIRequest(method: method, endpoint: endpoint, headers: headers, body: body) - - do { - let (data, urlResponse) = try await session.data(for: request) - - printDebugInfo(method: method, endpoint: endpoint, data: data, response: urlResponse) - - guard let httpResponse = urlResponse as? HTTPURLResponse else { return .failure(.unknownServerError) } - - if (200..<300).contains(httpResponse.statusCode) { - if let decodedResponse = decode(T.self, from: data) { - return .success(decodedResponse) - } else { - Logger.subscription.error("Service error: APIServiceError.decodingError") - return .failure(.decodingError) - } - } else { - var errorString: String? - - if let decodedResponse = decode(ErrorResponse.self, from: data) { - errorString = decodedResponse.error - } - - let errorLogMessage = "/\(endpoint) \(httpResponse.statusCode): \(errorString ?? "")" - Logger.subscription.error("Service error: \(errorLogMessage, privacy: .public)") - return .failure(.serverError(statusCode: httpResponse.statusCode, error: errorString)) - } - } catch { - Logger.subscription.error("Service error: \(error.localizedDescription, privacy: .public)") - return .failure(.connectionError) - } - } - - private func makeAPIRequest(method: String, endpoint: String, headers: [String: String]?, body: Data?) -> URLRequest { - let url = baseURL.appendingPathComponent(endpoint) - var request = URLRequest(url: url) - request.httpMethod = method - if let headers = headers { - request.allHTTPHeaderFields = headers - } - if let body = body { - request.httpBody = body - } - - return request - } - - private func decode(_: T.Type, from data: Data) -> T? where T: Decodable { - let decoder = JSONDecoder() - decoder.keyDecodingStrategy = .convertFromSnakeCase - decoder.dateDecodingStrategy = .millisecondsSince1970 - - return try? decoder.decode(T.self, from: data) - } - - private func printDebugInfo(method: String, endpoint: String, data: Data, response: URLResponse) { - let statusCode = (response as? HTTPURLResponse)!.statusCode - let stringData = String(data: data, encoding: .utf8) ?? "" - - Logger.subscription.info("[API] \(statusCode) \(method, privacy: .public) \(endpoint, privacy: .public) :: \(stringData, privacy: .public)") - } - - public func makeAuthorizationHeader(for token: String) -> [String: String] { - ["Authorization": "Bearer " + token] - } -} - -fileprivate extension URLResponse { - - var httpStatusCodeAsString: String? { - guard let httpStatusCode = (self as? HTTPURLResponse)?.statusCode else { return nil } - return String(httpStatusCode) - } -} diff --git a/Sources/Subscription/API/AuthEndpointService.swift b/Sources/Subscription/API/AuthEndpointService.swift deleted file mode 100644 index 31972404a..000000000 --- a/Sources/Subscription/API/AuthEndpointService.swift +++ /dev/null @@ -1,110 +0,0 @@ -// -// AuthEndpointService.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common - -public struct AccessTokenResponse: Decodable { - public let accessToken: String -} - -public struct ValidateTokenResponse: Decodable { - public let account: Account - - public struct Account: Decodable { - public let email: String? - public let entitlements: [Entitlement] - public let externalID: String - - enum CodingKeys: String, CodingKey { - case email, entitlements, externalID = "externalId" // no underscores due to keyDecodingStrategy = .convertFromSnakeCase - } - } -} - -public struct CreateAccountResponse: Decodable { - public let authToken: String - public let externalID: String - public let status: String - - enum CodingKeys: String, CodingKey { - case authToken = "authToken", externalID = "externalId", status // no underscores due to keyDecodingStrategy = .convertFromSnakeCase - } -} - -public struct StoreLoginResponse: Decodable { - public let authToken: String - public let email: String - public let externalID: String - public let id: Int - public let status: String - - enum CodingKeys: String, CodingKey { - case authToken = "authToken", email, externalID = "externalId", id, status // no underscores due to keyDecodingStrategy = .convertFromSnakeCase - } -} - -public protocol AuthEndpointService { - func getAccessToken(token: String) async -> Result - func validateToken(accessToken: String) async -> Result - func createAccount(emailAccessToken: String?) async -> Result - func storeLogin(signature: String) async -> Result -} - -public struct DefaultAuthEndpointService: AuthEndpointService { - private let currentServiceEnvironment: SubscriptionEnvironment.ServiceEnvironment - private let apiService: APIService - - public init(currentServiceEnvironment: SubscriptionEnvironment.ServiceEnvironment, apiService: APIService) { - self.currentServiceEnvironment = currentServiceEnvironment - self.apiService = apiService - } - - public init(currentServiceEnvironment: SubscriptionEnvironment.ServiceEnvironment) { - self.currentServiceEnvironment = currentServiceEnvironment - let baseURL = currentServiceEnvironment == .production ? URL(string: "https://quack.duckduckgo.com/api/auth")! : URL(string: "https://quackdev.duckduckgo.com/api/auth")! - let session = URLSession(configuration: URLSessionConfiguration.ephemeral) - self.apiService = DefaultAPIService(baseURL: baseURL, session: session) - } - - public func getAccessToken(token: String) async -> Result { - await apiService.executeAPICall(method: "GET", endpoint: "access-token", headers: apiService.makeAuthorizationHeader(for: token), body: nil) - } - - public func validateToken(accessToken: String) async -> Result { - await apiService.executeAPICall(method: "GET", endpoint: "validate-token", headers: apiService.makeAuthorizationHeader(for: accessToken), body: nil) - } - - public func createAccount(emailAccessToken: String?) async -> Result { - var headers: [String: String]? - - if let emailAccessToken { - headers = apiService.makeAuthorizationHeader(for: emailAccessToken) - } - - return await apiService.executeAPICall(method: "POST", endpoint: "account/create", headers: headers, body: nil) - } - - public func storeLogin(signature: String) async -> Result { - let bodyDict = ["signature": signature, - "store": "apple_app_store"] - - guard let bodyData = try? JSONEncoder().encode(bodyDict) else { return .failure(.encodingError) } - return await apiService.executeAPICall(method: "POST", endpoint: "store-login", headers: nil, body: bodyData) - } -} diff --git a/Sources/Subscription/API/Model/Subscription.swift b/Sources/Subscription/API/Model/PrivacyProSubscription.swift similarity index 56% rename from Sources/Subscription/API/Model/Subscription.swift rename to Sources/Subscription/API/Model/PrivacyProSubscription.swift index 3dc9f8807..257e47969 100644 --- a/Sources/Subscription/API/Model/Subscription.swift +++ b/Sources/Subscription/API/Model/PrivacyProSubscription.swift @@ -1,5 +1,5 @@ // -// Subscription.swift +// PrivacyProSubscription.swift // // Copyright © 2023 DuckDuckGo. All rights reserved. // @@ -17,18 +17,20 @@ // import Foundation +import Networking -public typealias DDGSubscription = Subscription // to avoid conflicts when Combine is imported - -public struct Subscription: Codable, Equatable { +public struct PrivacyProSubscription: Codable, Equatable, CustomDebugStringConvertible { public let productId: String public let name: String - public let billingPeriod: Subscription.BillingPeriod + public let billingPeriod: BillingPeriod public let startedAt: Date public let expiresOrRenewsAt: Date - public let platform: Subscription.Platform + public let platform: Platform public let status: Status + /// Not parsed from + public var features: [SubscriptionEntitlement]? + public enum BillingPeriod: String, Codable { case monthly = "Monthly" case yearly = "Yearly" @@ -64,4 +66,37 @@ public struct Subscription: Codable, Equatable { public var isActive: Bool { status != .expired && status != .inactive } + + public var debugDescription: String { + return """ + Subscription: + - Product ID: \(productId) + - Name: \(name) + - Billing Period: \(billingPeriod.rawValue) + - Started At: \(formatDate(startedAt)) + - Expires/Renews At: \(formatDate(expiresOrRenewsAt)) + - Platform: \(platform.rawValue) + - Status: \(status.rawValue) + - Features: \(features?.map { $0.rawValue } ?? []) + """ + } + + private func formatDate(_ date: Date) -> String { + let dateFormatter = DateFormatter() + dateFormatter.dateStyle = .medium + dateFormatter.timeStyle = .short + dateFormatter.timeZone = TimeZone.current + return dateFormatter.string(from: date) + } + + public static func == (lhs: PrivacyProSubscription, rhs: PrivacyProSubscription) -> Bool { + return lhs.productId == rhs.productId && + lhs.name == rhs.name && + lhs.billingPeriod == rhs.billingPeriod && + lhs.startedAt == rhs.startedAt && + lhs.expiresOrRenewsAt == rhs.expiresOrRenewsAt && + lhs.platform == rhs.platform && + lhs.status == rhs.status + // Ignore the features + } } diff --git a/Sources/Subscription/API/SubscriptionEndpointService.swift b/Sources/Subscription/API/SubscriptionEndpointService.swift index 552c28d4a..549a5214b 100644 --- a/Sources/Subscription/API/SubscriptionEndpointService.swift +++ b/Sources/Subscription/API/SubscriptionEndpointService.swift @@ -18,8 +18,10 @@ import Common import Foundation +import Networking +import os.log -public struct GetProductsItem: Decodable { +public struct GetProductsItem: Codable, Equatable { public let productId: String public let productLabel: String public let billingPeriod: String @@ -27,140 +29,222 @@ public struct GetProductsItem: Decodable { public let currency: String } -public struct GetSubscriptionFeaturesResponse: Decodable { - public let features: [Entitlement.ProductName] -} - -public struct GetCustomerPortalURLResponse: Decodable { +public struct GetCustomerPortalURLResponse: Codable, Equatable { public let customerPortalUrl: String } -public struct ConfirmPurchaseResponse: Decodable { +public struct ConfirmPurchaseResponse: Codable, Equatable { public let email: String? - public let entitlements: [Entitlement] - public let subscription: Subscription + public let subscription: PrivacyProSubscription } -public enum SubscriptionServiceError: Error { - case noCachedData - case apiError(APIServiceError) +public struct GetSubscriptionFeaturesResponse: Decodable { + public let features: [SubscriptionEntitlement] +} + +public enum SubscriptionEndpointServiceError: Error { + case noData + case invalidRequest + case invalidResponseCode(HTTPStatusCode) +} + +public enum SubscriptionCachePolicy { + case reloadIgnoringLocalCacheData + case returnCacheDataElseLoad + case returnCacheDataDontLoad } public protocol SubscriptionEndpointService { - func updateCache(with subscription: Subscription) - func getSubscription(accessToken: String, cachePolicy: APICachePolicy) async -> Result - func signOut() - func getProducts() async -> Result<[GetProductsItem], APIServiceError> - func getSubscriptionFeatures(for subscriptionID: String) async -> Result - func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result - func confirmPurchase(accessToken: String, signature: String) async -> Result + func ingestSubscription(_ subscription: PrivacyProSubscription) async throws + func getSubscription(accessToken: String, cachePolicy: SubscriptionCachePolicy) async throws -> PrivacyProSubscription + func clearSubscription() + func getProducts() async throws -> [GetProductsItem] + func getSubscriptionFeatures(for subscriptionID: String) async throws -> GetSubscriptionFeaturesResponse + func getCustomerPortalURL(accessToken: String, externalID: String) async throws -> GetCustomerPortalURLResponse + func confirmPurchase(accessToken: String, signature: String) async throws -> ConfirmPurchaseResponse } extension SubscriptionEndpointService { - public func getSubscription(accessToken: String) async -> Result { - await getSubscription(accessToken: accessToken, cachePolicy: .returnCacheDataElseLoad) + public func getSubscription(accessToken: String) async throws -> PrivacyProSubscription { + try await getSubscription(accessToken: accessToken, cachePolicy: SubscriptionCachePolicy.returnCacheDataElseLoad) } } /// Communicates with our backend public struct DefaultSubscriptionEndpointService: SubscriptionEndpointService { - private let currentServiceEnvironment: SubscriptionEnvironment.ServiceEnvironment + private let apiService: APIService - private let subscriptionCache = UserDefaultsCache(key: UserDefaultsCacheKey.subscription, - settings: UserDefaultsCacheSettings(defaultExpirationInterval: .minutes(20))) + private let baseURL: URL + private let subscriptionCache: UserDefaultsCache + private let cacheSerialQueue = DispatchQueue(label: "com.duckduckgo.subscriptionEndpointService.cache", qos: .background) - public init(currentServiceEnvironment: SubscriptionEnvironment.ServiceEnvironment, apiService: APIService) { - self.currentServiceEnvironment = currentServiceEnvironment + public init(apiService: APIService, + baseURL: URL, + subscriptionCache: UserDefaultsCache = UserDefaultsCache(key: UserDefaultsCacheKey.subscription, settings: UserDefaultsCacheSettings(defaultExpirationInterval: .minutes(20)))) { self.apiService = apiService - } - - public init(currentServiceEnvironment: SubscriptionEnvironment.ServiceEnvironment) { - self.currentServiceEnvironment = currentServiceEnvironment - let baseURL = currentServiceEnvironment == .production ? URL(string: "https://subscriptions.duckduckgo.com/api")! : URL(string: "https://subscriptions-dev.duckduckgo.com/api")! - let session = URLSession(configuration: URLSessionConfiguration.ephemeral) - self.apiService = DefaultAPIService(baseURL: baseURL, session: session) + self.baseURL = baseURL + self.subscriptionCache = subscriptionCache } // MARK: - Subscription fetching with caching - private func getRemoteSubscription(accessToken: String) async -> Result { + private func getRemoteSubscription(accessToken: String) async throws -> PrivacyProSubscription { - let result: Result = await apiService.executeAPICall(method: "GET", endpoint: "subscription", headers: apiService.makeAuthorizationHeader(for: accessToken), body: nil) - switch result { - case .success(let subscriptionResponse): - updateCache(with: subscriptionResponse) - return .success(subscriptionResponse) - case .failure(let error): - return .failure(.apiError(error)) + Logger.subscriptionEndpointService.log("Requesting subscription details") + guard let request = SubscriptionRequest.getSubscription(baseURL: baseURL, accessToken: accessToken) else { + throw SubscriptionEndpointServiceError.invalidRequest } - } + let response = try await apiService.fetch(request: request.apiRequest) + let statusCode = response.httpResponse.httpStatus + + if statusCode.isSuccess { + let subscription: PrivacyProSubscription = try response.decodeBody() + Logger.subscriptionEndpointService.log("Subscription details retrieved successfully: \(String(describing: subscription))") - public func updateCache(with subscription: Subscription) { + try await storeAndAddFeaturesIfNeededTo(subscription: subscription) + + return subscription + } else { + if statusCode == .badRequest { + Logger.subscriptionEndpointService.log("No subscription found") + clearSubscription() + throw SubscriptionEndpointServiceError.noData + } else { + let bodyString: String = try response.decodeBody() + Logger.subscriptionEndpointService.log("(\(statusCode.description) Failed to retrieve Subscription details: \(bodyString)") + throw SubscriptionEndpointServiceError.invalidResponseCode(statusCode) + } + } + } - let cachedSubscription: Subscription? = subscriptionCache.get() + private func storeAndAddFeaturesIfNeededTo(subscription: PrivacyProSubscription) async throws { + let cachedSubscription: PrivacyProSubscription? = subscriptionCache.get() if subscription != cachedSubscription { - let defaultExpiryDate = Date().addingTimeInterval(subscriptionCache.settings.defaultExpirationInterval) - let expiryDate = min(defaultExpiryDate, subscription.expiresOrRenewsAt) + var subscription = subscription + // fetch remote features + subscription.features = try await getSubscriptionFeatures(for: subscription.productId).features + + updateCache(with: subscription) + + Logger.subscriptionEndpointService.debug(""" +Subscription changed, updating cache and notifying observers. +Old: \(cachedSubscription?.debugDescription ?? "nil") +New: \(subscription.debugDescription) +""") + } else { + Logger.subscriptionEndpointService.debug("No subscription update required") + } + } - subscriptionCache.set(subscription, expires: expiryDate) + func updateCache(with subscription: PrivacyProSubscription) { + cacheSerialQueue.sync { + subscriptionCache.set(subscription) NotificationCenter.default.post(name: .subscriptionDidChange, object: self, userInfo: [UserDefaultsCacheKey.subscription: subscription]) } } - public func getSubscription(accessToken: String, cachePolicy: APICachePolicy = .returnCacheDataElseLoad) async -> Result { + public func ingestSubscription(_ subscription: PrivacyProSubscription) async throws { + try await storeAndAddFeaturesIfNeededTo(subscription: subscription) + } + public func getSubscription(accessToken: String, cachePolicy: SubscriptionCachePolicy = .returnCacheDataElseLoad) async throws -> PrivacyProSubscription { switch cachePolicy { case .reloadIgnoringLocalCacheData: - return await getRemoteSubscription(accessToken: accessToken) + return try await getRemoteSubscription(accessToken: accessToken) case .returnCacheDataElseLoad: - if let cachedSubscription = subscriptionCache.get() { - return .success(cachedSubscription) + if let cachedSubscription = getCachedSubscription() { + return cachedSubscription } else { - return await getRemoteSubscription(accessToken: accessToken) + return try await getRemoteSubscription(accessToken: accessToken) } case .returnCacheDataDontLoad: - if let cachedSubscription = subscriptionCache.get() { - return .success(cachedSubscription) + if let cachedSubscription = getCachedSubscription() { + return cachedSubscription } else { - return .failure(.noCachedData) + throw SubscriptionEndpointServiceError.noData } } } - public func signOut() { - subscriptionCache.reset() + private func getCachedSubscription() -> PrivacyProSubscription? { + var result: PrivacyProSubscription? + cacheSerialQueue.sync { + result = subscriptionCache.get() + } + return result } - // MARK: - - - public func getProducts() async -> Result<[GetProductsItem], APIServiceError> { - await apiService.executeAPICall(method: "GET", endpoint: "products", headers: nil, body: nil) + public func clearSubscription() { + cacheSerialQueue.sync { + subscriptionCache.reset() + } +// NotificationCenter.default.post(name: .subscriptionDidChange, object: self, userInfo: nil) } // MARK: - - public func getSubscriptionFeatures(for subscriptionID: String) async -> Result { - await apiService.executeAPICall(method: "GET", endpoint: "products/\(subscriptionID)/features", headers: nil, body: nil) + public func getProducts() async throws -> [GetProductsItem] { + guard let request = SubscriptionRequest.getProducts(baseURL: baseURL) else { + throw SubscriptionEndpointServiceError.invalidRequest + } + let response = try await apiService.fetch(request: request.apiRequest) + let statusCode = response.httpResponse.httpStatus + + if statusCode.isSuccess { + Logger.subscriptionEndpointService.log("\(#function) request completed") + return try response.decodeBody() + } else { + throw SubscriptionEndpointServiceError.invalidResponseCode(statusCode) + } } // MARK: - - public func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result { - var headers = apiService.makeAuthorizationHeader(for: accessToken) - headers["externalAccountId"] = externalID - return await apiService.executeAPICall(method: "GET", endpoint: "checkout/portal", headers: headers, body: nil) + public func getCustomerPortalURL(accessToken: String, externalID: String) async throws -> GetCustomerPortalURLResponse { + guard let request = SubscriptionRequest.getCustomerPortalURL(baseURL: baseURL, accessToken: accessToken, externalID: externalID) else { + throw SubscriptionEndpointServiceError.invalidRequest + } + let response = try await apiService.fetch(request: request.apiRequest) + let statusCode = response.httpResponse.httpStatus + if statusCode.isSuccess { + Logger.subscriptionEndpointService.log("\(#function) request completed") + return try response.decodeBody() + } else { + throw SubscriptionEndpointServiceError.invalidResponseCode(statusCode) + } } // MARK: - - public func confirmPurchase(accessToken: String, signature: String) async -> Result { - let headers = apiService.makeAuthorizationHeader(for: accessToken) - let bodyDict = ["signedTransactionInfo": signature] + public func confirmPurchase(accessToken: String, signature: String) async throws -> ConfirmPurchaseResponse { + guard let request = SubscriptionRequest.confirmPurchase(baseURL: baseURL, accessToken: accessToken, signature: signature) else { + throw SubscriptionEndpointServiceError.invalidRequest + } + let response = try await apiService.fetch(request: request.apiRequest) + let statusCode = response.httpResponse.httpStatus + if statusCode.isSuccess { + Logger.subscriptionEndpointService.log("\(#function) request completed") + return try response.decodeBody() + } else { + throw SubscriptionEndpointServiceError.invalidResponseCode(statusCode) + } + } - guard let bodyData = try? JSONEncoder().encode(bodyDict) else { return .failure(.encodingError) } - return await apiService.executeAPICall(method: "POST", endpoint: "purchase/confirm/apple", headers: headers, body: bodyData) + public func getSubscriptionFeatures(for subscriptionID: String) async throws -> GetSubscriptionFeaturesResponse { + Logger.subscriptionEndpointService.log("Getting subscription features") + guard let request = SubscriptionRequest.subscriptionFeatures(baseURL: baseURL, subscriptionID: subscriptionID) else { + throw SubscriptionEndpointServiceError.invalidRequest + } + let response = try await apiService.fetch(request: request.apiRequest) + let statusCode = response.httpResponse.httpStatus + if statusCode.isSuccess { + Logger.subscriptionEndpointService.log("\(#function) request completed") + return try response.decodeBody() + } else { + throw SubscriptionEndpointServiceError.invalidResponseCode(statusCode) + } } } diff --git a/Sources/Subscription/API/SubscriptionRequest.swift b/Sources/Subscription/API/SubscriptionRequest.swift new file mode 100644 index 000000000..f34d6441f --- /dev/null +++ b/Sources/Subscription/API/SubscriptionRequest.swift @@ -0,0 +1,83 @@ +// +// SubscriptionRequest.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Networking +import Common + +struct SubscriptionRequest { + let apiRequest: APIRequestV2 + + // MARK: Get subscription + + static func getSubscription(baseURL: URL, accessToken: String) -> SubscriptionRequest? { + let path = "/subscription" + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + headers: APIRequestV2.HeadersV2(authToken: accessToken), + timeoutInterval: 20) else { + return nil + } + return SubscriptionRequest(apiRequest: request) + } + + static func getProducts(baseURL: URL) -> SubscriptionRequest? { + let path = "/products" + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .get) else { + return nil + } + return SubscriptionRequest(apiRequest: request) + } + + static func getCustomerPortalURL(baseURL: URL, accessToken: String, externalID: String) -> SubscriptionRequest? { + let path = "/checkout/portal" + let headers = [ + "externalAccountId": externalID + ] + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .get, + headers: APIRequestV2.HeadersV2(authToken: accessToken, + additionalHeaders: headers)) else { + return nil + } + return SubscriptionRequest(apiRequest: request) + } + + static func confirmPurchase(baseURL: URL, accessToken: String, signature: String) -> SubscriptionRequest? { + let path = "/purchase/confirm/apple" + let bodyDict = ["signedTransactionInfo": signature] + guard let bodyData = CodableHelper.encode(bodyDict) else { return nil } + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + method: .post, + headers: APIRequestV2.HeadersV2(authToken: accessToken), + body: bodyData, + retryPolicy: APIRequestV2.RetryPolicy(maxRetries: 3, delay: 4.0)) else { + return nil + } + return SubscriptionRequest(apiRequest: request) + } + + static func subscriptionFeatures(baseURL: URL, subscriptionID: String) -> SubscriptionRequest? { + let path = "/products/\(subscriptionID)/features" + guard let request = APIRequestV2(url: baseURL.appendingPathComponent(path), + cachePolicy: .returnCacheDataElseLoad) else { // Cached on purpose, the response never changes + return nil + } + return SubscriptionRequest(apiRequest: request) + } +} diff --git a/Sources/Subscription/DefaultSubscriptionEndpointService+SubscriptionFeatureMappingCache.swift b/Sources/Subscription/DefaultSubscriptionEndpointService+SubscriptionFeatureMappingCache.swift new file mode 100644 index 000000000..fb27bc72f --- /dev/null +++ b/Sources/Subscription/DefaultSubscriptionEndpointService+SubscriptionFeatureMappingCache.swift @@ -0,0 +1,34 @@ +// +// DefaultSubscriptionEndpointService+SubscriptionFeatureMappingCache.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Networking +import os.log + +extension DefaultSubscriptionEndpointService: SubscriptionFeatureMappingCache { + + public func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Networking.SubscriptionEntitlement] { + do { + let response = try await getSubscriptionFeatures(for: subscriptionIdentifier) + return response.features + } catch { + Logger.subscription.error("Failed to get subscription features: \(error)") + return [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] + } + } +} diff --git a/Sources/Subscription/Flows/AppStore/AppStoreAccountManagementFlow.swift b/Sources/Subscription/Flows/AppStore/AppStoreAccountManagementFlow.swift deleted file mode 100644 index ff75ecf4b..000000000 --- a/Sources/Subscription/Flows/AppStore/AppStoreAccountManagementFlow.swift +++ /dev/null @@ -1,75 +0,0 @@ -// -// AppStoreAccountManagementFlow.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import StoreKit -import os.log - -public enum AppStoreAccountManagementFlowError: Swift.Error { - case noPastTransaction - case authenticatingWithTransactionFailed - case missingAuthTokenOnRefresh -} - -@available(macOS 12.0, iOS 15.0, *) -public protocol AppStoreAccountManagementFlow { - @discardableResult func refreshAuthTokenIfNeeded() async -> Result -} - -@available(macOS 12.0, iOS 15.0, *) -public final class DefaultAppStoreAccountManagementFlow: AppStoreAccountManagementFlow { - - private let authEndpointService: AuthEndpointService - private let storePurchaseManager: StorePurchaseManager - private let accountManager: AccountManager - - public init(authEndpointService: any AuthEndpointService, storePurchaseManager: any StorePurchaseManager, accountManager: any AccountManager) { - self.authEndpointService = authEndpointService - self.storePurchaseManager = storePurchaseManager - self.accountManager = accountManager - } - - @discardableResult - public func refreshAuthTokenIfNeeded() async -> Result { - Logger.subscription.info("[AppStoreAccountManagementFlow] refreshAuthTokenIfNeeded") - - guard let authToken = accountManager.authToken else { return .failure(.missingAuthTokenOnRefresh) } - - // Check if auth token if still valid - if case let .failure(validateTokenError) = await authEndpointService.validateToken(accessToken: authToken) { - Logger.subscription.error("[AppStoreAccountManagementFlow] validateToken error: \(String(reflecting: validateTokenError), privacy: .public)") - - // In case of invalid token attempt store based authentication to obtain a new one - guard let lastTransactionJWSRepresentation = await storePurchaseManager.mostRecentTransaction() else { return .failure(.noPastTransaction) } - - switch await authEndpointService.storeLogin(signature: lastTransactionJWSRepresentation) { - case .success(let response): - if response.externalID == accountManager.externalID { - let refreshedAuthToken = response.authToken - accountManager.storeAuthToken(token: refreshedAuthToken) - return .success(refreshedAuthToken) - } - case .failure(let storeLoginError): - Logger.subscription.error("[AppStoreAccountManagementFlow] storeLogin error: \(String(reflecting: storeLoginError), privacy: .public)") - return .failure(.authenticatingWithTransactionFailed) - } - } - - return .success(authToken) - } -} diff --git a/Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift b/Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift index 8e2e5c3f0..2c25e59f9 100644 --- a/Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift +++ b/Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift @@ -19,159 +19,209 @@ import Foundation import StoreKit import os.log +import Networking -public enum AppStorePurchaseFlowError: Swift.Error { +public enum AppStorePurchaseFlowError: Swift.Error, Equatable, LocalizedError { case noProductsFound case activeSubscriptionAlreadyPresent case authenticatingWithTransactionFailed - case accountCreationFailed - case purchaseFailed + case accountCreationFailed(Swift.Error) + case purchaseFailed(Swift.Error) case cancelledByUser case missingEntitlements case internalError + + public var errorDescription: String? { + switch self { + case .noProductsFound: + "No products found" + case .activeSubscriptionAlreadyPresent: + "An active subscription is already present" + case .authenticatingWithTransactionFailed: + "Authenticating with transaction failed" + case .accountCreationFailed(let subError): + "Account creation failed: \(subError.localizedDescription)" + case .purchaseFailed(let subError): + "Purchase failed: \(subError.localizedDescription)" + case .cancelledByUser: + "Purchase cancelled by user" + case .missingEntitlements: + "Missing entitlements" + case .internalError: + "Internal error" + } + } + + public static func == (lhs: AppStorePurchaseFlowError, rhs: AppStorePurchaseFlowError) -> Bool { + switch (lhs, rhs) { + case (.noProductsFound, .noProductsFound), + (.activeSubscriptionAlreadyPresent, .activeSubscriptionAlreadyPresent), + (.authenticatingWithTransactionFailed, .authenticatingWithTransactionFailed), + (.cancelledByUser, .cancelledByUser), + (.missingEntitlements, .missingEntitlements), + (.internalError, .internalError): + return true + case let (.accountCreationFailed(lhsError), .accountCreationFailed(rhsError)): + return lhsError.localizedDescription == rhsError.localizedDescription + case let (.purchaseFailed(lhsError), .purchaseFailed(rhsError)): + return lhsError.localizedDescription == rhsError.localizedDescription + default: + return false + } + } } @available(macOS 12.0, iOS 15.0, *) public protocol AppStorePurchaseFlow { typealias TransactionJWS = String - func purchaseSubscription(with subscriptionIdentifier: String, emailAccessToken: String?) async -> Result - @discardableResult - func completeSubscriptionPurchase(with transactionJWS: AppStorePurchaseFlow.TransactionJWS) async -> Result + func purchaseSubscription(with subscriptionIdentifier: String) async -> Result + @discardableResult func completeSubscriptionPurchase(with transactionJWS: TransactionJWS) async -> Result } @available(macOS 12.0, iOS 15.0, *) public final class DefaultAppStorePurchaseFlow: AppStorePurchaseFlow { - private let subscriptionEndpointService: SubscriptionEndpointService + private let subscriptionManager: any SubscriptionManager private let storePurchaseManager: StorePurchaseManager - private let accountManager: AccountManager private let appStoreRestoreFlow: AppStoreRestoreFlow - private let authEndpointService: AuthEndpointService - public init(subscriptionEndpointService: any SubscriptionEndpointService, + public init(subscriptionManager: any SubscriptionManager, storePurchaseManager: any StorePurchaseManager, - accountManager: any AccountManager, - appStoreRestoreFlow: any AppStoreRestoreFlow, - authEndpointService: any AuthEndpointService) { - self.subscriptionEndpointService = subscriptionEndpointService + appStoreRestoreFlow: any AppStoreRestoreFlow + ) { + self.subscriptionManager = subscriptionManager self.storePurchaseManager = storePurchaseManager - self.accountManager = accountManager self.appStoreRestoreFlow = appStoreRestoreFlow - self.authEndpointService = authEndpointService } - public func purchaseSubscription(with subscriptionIdentifier: String, emailAccessToken: String?) async -> Result { - Logger.subscription.info("[AppStorePurchaseFlow] purchaseSubscription") - let externalID: String + public func purchaseSubscription(with subscriptionIdentifier: String) async -> Result { + Logger.subscriptionAppStorePurchaseFlow.log("Purchasing Subscription") - // If the current account is a third party expired account, we want to purchase and attach subs to it + var externalID: String? if let existingExternalID = await getExpiredSubscriptionID() { + Logger.subscriptionAppStorePurchaseFlow.log("External ID retrieved from expired subscription") externalID = existingExternalID - - // Otherwise, try to retrieve an expired Apple subscription or create a new one } else { - // Check for past transactions most recent + Logger.subscriptionAppStorePurchaseFlow.log("Try to retrieve an expired Apple subscription or create a new one") + + // Try to restore an account from a past purchase switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { case .success: - Logger.subscription.info("[AppStorePurchaseFlow] purchaseSubscription -> restoreAccountFromPastPurchase: activeSubscriptionAlreadyPresent") + Logger.subscriptionAppStorePurchaseFlow.log("An active subscription is already present") return .failure(.activeSubscriptionAlreadyPresent) case .failure(let error): - Logger.subscription.info("[AppStorePurchaseFlow] purchaseSubscription -> restoreAccountFromPastPurchase: \(String(reflecting: error), privacy: .public)") - switch error { - case .subscriptionExpired(let expiredAccountDetails): - externalID = expiredAccountDetails.externalID - accountManager.storeAuthToken(token: expiredAccountDetails.authToken) - accountManager.storeAccount(token: expiredAccountDetails.accessToken, email: expiredAccountDetails.email, externalID: expiredAccountDetails.externalID) - default: - switch await authEndpointService.createAccount(emailAccessToken: emailAccessToken) { - case .success(let response): - externalID = response.externalID - - if case let .success(accessToken) = await accountManager.exchangeAuthTokenToAccessToken(response.authToken), - case let .success(accountDetails) = await accountManager.fetchAccountDetails(with: accessToken) { - accountManager.storeAuthToken(token: response.authToken) - accountManager.storeAccount(token: accessToken, email: accountDetails.email, externalID: accountDetails.externalID) - } - case .failure(let error): - Logger.subscription.error("[AppStorePurchaseFlow] createAccount error: \(String(reflecting: error), privacy: .public)") - return .failure(.accountCreationFailed) + Logger.subscriptionAppStorePurchaseFlow.log("Failed to restore an account from a past purchase: \(error.localizedDescription, privacy: .public)") + do { + externalID = try await subscriptionManager.getTokenContainer(policy: .createIfNeeded).decodedAccessToken.externalID + } catch OAuthClientError.deadToken { + do { + let transactionJWS = try await recoverSubscriptionFromDeadToken() + return .success(transactionJWS) + } catch { + return .failure(.purchaseFailed(OAuthClientError.deadToken)) } + } catch Networking.OAuthClientError.missingTokens { + Logger.subscriptionStripePurchaseFlow.error("Failed to create a new account: \(error.localizedDescription, privacy: .public)") + return .failure(.accountCreationFailed(error)) + } catch { + Logger.subscriptionStripePurchaseFlow.error("Failed to create a new account: \(error.localizedDescription, privacy: .public), the operation is unrecoverable") + return .failure(.internalError) } } } + guard let externalID else { + Logger.subscriptionAppStorePurchaseFlow.fault("Missing external ID, subscription purchase failed") + return .failure(.internalError) + } + // Make the purchase switch await storePurchaseManager.purchaseSubscription(with: subscriptionIdentifier, externalID: externalID) { case .success(let transactionJWS): return .success(transactionJWS) case .failure(let error): - Logger.subscription.error("[AppStorePurchaseFlow] purchaseSubscription error: \(String(reflecting: error), privacy: .public)") - accountManager.signOut(skipNotification: true) + Logger.subscriptionAppStorePurchaseFlow.error("purchaseSubscription error: \(String(reflecting: error), privacy: .public)") + + await subscriptionManager.signOut() + switch error { case .purchaseCancelledByUser: return .failure(.cancelledByUser) default: - return .failure(.purchaseFailed) + return .failure(.purchaseFailed(error)) } } } @discardableResult public func completeSubscriptionPurchase(with transactionJWS: TransactionJWS) async -> Result { - - // Clear subscription Cache - subscriptionEndpointService.signOut() - - Logger.subscription.info("[AppStorePurchaseFlow] completeSubscriptionPurchase") - - guard let accessToken = accountManager.accessToken else { return .failure(.missingEntitlements) } - - let result = await callWithRetries(retry: 5, wait: 2.0) { - switch await subscriptionEndpointService.confirmPurchase(accessToken: accessToken, signature: transactionJWS) { - case .success(let confirmation): - subscriptionEndpointService.updateCache(with: confirmation.subscription) - accountManager.updateCache(with: confirmation.entitlements) - return true - case .failure: - return false + Logger.subscriptionAppStorePurchaseFlow.log("Completing Subscription Purchase") + + subscriptionManager.clearSubscriptionCache() + + do { + let subscription = try await subscriptionManager.confirmPurchase(signature: transactionJWS) + if subscription.isActive { + let refreshedToken = try await subscriptionManager.getTokenContainer(policy: .localForceRefresh) + if refreshedToken.decodedAccessToken.subscriptionEntitlements.isEmpty { + Logger.subscriptionAppStorePurchaseFlow.error("Missing entitlements") + return .failure(.missingEntitlements) + } else { + return .success(.completed) + } + } else { + Logger.subscriptionAppStorePurchaseFlow.error("Subscription expired") + // Removing all traces of the subscription and the account + return .failure(.purchaseFailed(AppStoreRestoreFlowError.subscriptionExpired)) + } + } catch OAuthClientError.deadToken { + do { + try await recoverSubscriptionFromDeadToken() + return .success(.completed) + } catch { + return .failure(.purchaseFailed(OAuthClientError.deadToken)) } + } catch { + Logger.subscriptionAppStorePurchaseFlow.error("Purchase Failed: \(error)") + return .failure(.purchaseFailed(error)) } - - return result ? .success(PurchaseUpdate.completed) : .failure(.missingEntitlements) } - private func callWithRetries(retry retryCount: Int, wait waitTime: Double, conditionToCheck: () async -> Bool) async -> Bool { - var count = 0 - var successful = false - - repeat { - successful = await conditionToCheck() - - if successful { - break - } else { - count += 1 - try? await Task.sleep(seconds: waitTime) + private func getExpiredSubscriptionID() async -> String? { + do { + let subscription = try await subscriptionManager.getSubscription(cachePolicy: .reloadIgnoringLocalCacheData) + // Only return an externalID if the subscription is expired so to prevent creating multiple subscriptions in the same account + if !subscription.isActive, + subscription.platform != .apple { + return try await subscriptionManager.getTokenContainer(policy: .localValid).decodedAccessToken.externalID } - } while !successful && count < retryCount - - return successful + return nil + } catch OAuthClientError.deadToken { + do { + try await recoverSubscriptionFromDeadToken() + return try? await subscriptionManager.getTokenContainer(policy: .localValid).decodedAccessToken.externalID + } catch { + Logger.subscription.error("Failed to retrieve the current subscription: Missing transaction JWS") + return nil + } + } catch { + return nil + } } - private func getExpiredSubscriptionID() async -> String? { - guard accountManager.isUserAuthenticated, - let externalID = accountManager.externalID, - let token = accountManager.accessToken - else { return nil } - - let subscriptionInfo = await subscriptionEndpointService.getSubscription(accessToken: token, cachePolicy: .reloadIgnoringLocalCacheData) - - // Only return an externalID if the subscription is expired - // To prevent creating multiple subscriptions in the same account - if case .success(let subscription) = subscriptionInfo, - !subscription.isActive, - subscription.platform != .apple { - return externalID + @discardableResult + private func recoverSubscriptionFromDeadToken() async throws -> String { + Logger.subscriptionAppStorePurchaseFlow.log("Recovering Subscription From Dead Token") + + // Clear everything, the token is unrecoverable + await subscriptionManager.signOut() + + switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { + case .success(let transactionJWS): + Logger.subscriptionAppStorePurchaseFlow.log("Subscription recovered") + return transactionJWS + case .failure(let error): + Logger.subscriptionAppStorePurchaseFlow.log("Failed to recover Apple subscription: \(error.localizedDescription, privacy: .public)") + throw error } - return nil } } diff --git a/Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift b/Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift index 004b77f8f..37eb8ad24 100644 --- a/Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift +++ b/Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift @@ -19,108 +19,75 @@ import Foundation import StoreKit import os.log +import Networking -public enum AppStoreRestoreFlowError: Swift.Error, Equatable { +public enum AppStoreRestoreFlowError: LocalizedError, Equatable { case missingAccountOrTransactions case pastTransactionAuthenticationError case failedToObtainAccessToken case failedToFetchAccountDetails case failedToFetchSubscriptionDetails - case subscriptionExpired(accountDetails: RestoredAccountDetails) -} - -public struct RestoredAccountDetails: Equatable { - let authToken: String - let accessToken: String - let externalID: String - let email: String? + case subscriptionExpired + + public var errorDescription: String? { + switch self { + case .missingAccountOrTransactions: + return "Missing account or transactions." + case .pastTransactionAuthenticationError: + return "Past transaction authentication error." + case .failedToObtainAccessToken: + return "Failed to obtain access token." + case .failedToFetchAccountDetails: + return "Failed to fetch account details." + case .failedToFetchSubscriptionDetails: + return "Failed to fetch subscription details." + case .subscriptionExpired: + return "Subscription expired." + } + } } @available(macOS 12.0, iOS 15.0, *) public protocol AppStoreRestoreFlow { - @discardableResult func restoreAccountFromPastPurchase() async -> Result + @discardableResult func restoreAccountFromPastPurchase() async -> Result } @available(macOS 12.0, iOS 15.0, *) public final class DefaultAppStoreRestoreFlow: AppStoreRestoreFlow { - private let accountManager: AccountManager + private let subscriptionManager: SubscriptionManager private let storePurchaseManager: StorePurchaseManager - private let subscriptionEndpointService: SubscriptionEndpointService - private let authEndpointService: AuthEndpointService - public init(accountManager: any AccountManager, - storePurchaseManager: any StorePurchaseManager, - subscriptionEndpointService: any SubscriptionEndpointService, - authEndpointService: any AuthEndpointService) { - self.accountManager = accountManager + public init(subscriptionManager: SubscriptionManager, + storePurchaseManager: any StorePurchaseManager) { + self.subscriptionManager = subscriptionManager self.storePurchaseManager = storePurchaseManager - self.subscriptionEndpointService = subscriptionEndpointService - self.authEndpointService = authEndpointService } @discardableResult - public func restoreAccountFromPastPurchase() async -> Result { + public func restoreAccountFromPastPurchase() async -> Result { + Logger.subscriptionAppStoreRestoreFlow.log("Restoring account from past purchase") // Clear subscription Cache - subscriptionEndpointService.signOut() - - Logger.subscription.info("[AppStoreRestoreFlow] restoreAccountFromPastPurchase") + subscriptionManager.clearSubscriptionCache() guard let lastTransactionJWSRepresentation = await storePurchaseManager.mostRecentTransaction() else { - Logger.subscription.error("[AppStoreRestoreFlow] Error: missingAccountOrTransactions") + Logger.subscriptionAppStoreRestoreFlow.error("Missing last transaction") return .failure(.missingAccountOrTransactions) } - // Do the store login to get short-lived token - let authToken: String - - switch await authEndpointService.storeLogin(signature: lastTransactionJWSRepresentation) { - case .success(let response): - authToken = response.authToken - case .failure: - Logger.subscription.error("[AppStoreRestoreFlow] Error: pastTransactionAuthenticationError") + do { + if let subscription = try await subscriptionManager.getSubscriptionFrom(lastTransactionJWSRepresentation: lastTransactionJWSRepresentation), + subscription.isActive { + return .success(lastTransactionJWSRepresentation) + } else { + Logger.subscriptionAppStoreRestoreFlow.error("Subscription expired") + // Removing all traces of the subscription and the account + await subscriptionManager.signOut() + return .failure(.subscriptionExpired) + } + } catch { + Logger.subscriptionAppStoreRestoreFlow.error("Error activating past transaction: \(error, privacy: .public)") return .failure(.pastTransactionAuthenticationError) } - - let accessToken: String - let email: String? - let externalID: String - - switch await accountManager.exchangeAuthTokenToAccessToken(authToken) { - case .success(let exchangedAccessToken): - accessToken = exchangedAccessToken - case .failure: - Logger.subscription.error("[AppStoreRestoreFlow] Error: failedToObtainAccessToken") - return .failure(.failedToObtainAccessToken) - } - - switch await accountManager.fetchAccountDetails(with: accessToken) { - case .success(let accountDetails): - email = accountDetails.email - externalID = accountDetails.externalID - case .failure: - Logger.subscription.error("[AppStoreRestoreFlow] Error: failedToFetchAccountDetails") - return .failure(.failedToFetchAccountDetails) - } - - var isSubscriptionActive = false - - switch await subscriptionEndpointService.getSubscription(accessToken: accessToken, cachePolicy: .reloadIgnoringLocalCacheData) { - case .success(let subscription): - isSubscriptionActive = subscription.isActive - case .failure: - Logger.subscription.error("[AppStoreRestoreFlow] Error: failedToFetchSubscriptionDetails") - return .failure(.failedToFetchSubscriptionDetails) - } - - if isSubscriptionActive { - accountManager.storeAuthToken(token: authToken) - accountManager.storeAccount(token: accessToken, email: email, externalID: externalID) - return .success(()) - } else { - let details = RestoredAccountDetails(authToken: authToken, accessToken: accessToken, externalID: externalID, email: email) - Logger.subscription.error("[AppStoreRestoreFlow] Error: subscriptionExpired") - return .failure(.subscriptionExpired(accountDetails: details)) - } } } diff --git a/Sources/Subscription/Flows/Models/PurchaseUpdate.swift b/Sources/Subscription/Flows/Models/PurchaseUpdate.swift index 027fa5f7d..27f60fc80 100644 --- a/Sources/Subscription/Flows/Models/PurchaseUpdate.swift +++ b/Sources/Subscription/Flows/Models/PurchaseUpdate.swift @@ -18,7 +18,7 @@ import Foundation -public struct PurchaseUpdate: Codable { +public struct PurchaseUpdate: Codable, Equatable { let type: String let token: String? diff --git a/Sources/Subscription/Flows/Models/SubscriptionOptions.swift b/Sources/Subscription/Flows/Models/SubscriptionOptions.swift index 2c69e31f6..0a628b27e 100644 --- a/Sources/Subscription/Flows/Models/SubscriptionOptions.swift +++ b/Sources/Subscription/Flows/Models/SubscriptionOptions.swift @@ -17,27 +17,41 @@ // import Foundation +import Networking public struct SubscriptionOptions: Encodable, Equatable { + struct Feature: Encodable, Equatable { + let name: SubscriptionEntitlement + } + let platform: SubscriptionPlatformName let options: [SubscriptionOption] - let features: [SubscriptionFeature] + /// The available features in the subscription based on the country and feature flags. Not based on user entitlements + let features: [SubscriptionOptions.Feature] + + public init(platform: SubscriptionPlatformName, options: [SubscriptionOption], availableEntitlements: [SubscriptionEntitlement]) { + self.platform = platform + self.options = options + self.features = availableEntitlements.map({ entitlement in + Feature(name: entitlement) + }) + } public static var empty: SubscriptionOptions { - let features = [SubscriptionFeature(name: .networkProtection), - SubscriptionFeature(name: .dataBrokerProtection), - SubscriptionFeature(name: .identityTheftRestoration)] + let features: [SubscriptionEntitlement] = [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] let platform: SubscriptionPlatformName #if os(iOS) platform = .ios #else platform = .macos #endif - return SubscriptionOptions(platform: platform, options: [], features: features) + return SubscriptionOptions(platform: platform, options: [], availableEntitlements: features) } public func withoutPurchaseOptions() -> Self { - SubscriptionOptions(platform: platform, options: [], features: features) + SubscriptionOptions(platform: platform, options: [], availableEntitlements: features.map({ feature in + feature.name + })) } } @@ -64,10 +78,6 @@ struct SubscriptionOptionCost: Encodable, Equatable { let recurrence: String } -public struct SubscriptionFeature: Encodable, Equatable { - let name: Entitlement.ProductName -} - /// A `SubscriptionOptionOffer` represents an offer (e.g Free Trials) associated with a Subscription public struct SubscriptionOptionOffer: Encodable, Equatable { diff --git a/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift b/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift index 43e0448e7..bdde33910 100644 --- a/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift +++ b/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift @@ -19,6 +19,7 @@ import Foundation import StoreKit import os.log +import Networking public enum StripePurchaseFlowError: Swift.Error { case noProductsFound @@ -32,23 +33,18 @@ public protocol StripePurchaseFlow { } public final class DefaultStripePurchaseFlow: StripePurchaseFlow { - private let subscriptionEndpointService: SubscriptionEndpointService - private let authEndpointService: AuthEndpointService - private let accountManager: AccountManager - - public init(subscriptionEndpointService: any SubscriptionEndpointService, - authEndpointService: any AuthEndpointService, - accountManager: any AccountManager) { - self.subscriptionEndpointService = subscriptionEndpointService - self.authEndpointService = authEndpointService - self.accountManager = accountManager + private let subscriptionManager: SubscriptionManager + + public init(subscriptionManager: SubscriptionManager) { + self.subscriptionManager = subscriptionManager } public func subscriptionOptions() async -> Result { - Logger.subscription.info("[StripePurchaseFlow] subscriptionOptions") + Logger.subscriptionStripePurchaseFlow.log("Getting subscription options for Stripe") - guard case let .success(products) = await subscriptionEndpointService.getProducts(), !products.isEmpty else { - Logger.subscription.error("[StripePurchaseFlow] Error: noProductsFound") + guard let products = try? await subscriptionManager.getProducts(), + !products.isEmpty else { + Logger.subscriptionStripePurchaseFlow.error("Failed to obtain products") return .failure(.noProductsFound) } @@ -64,69 +60,53 @@ public final class DefaultStripePurchaseFlow: StripePurchaseFlow { if let price = Float($0.price), let formattedPrice = formatter.string(from: price as NSNumber) { displayPrice = formattedPrice } - let cost = SubscriptionOptionCost(displayPrice: displayPrice, recurrence: $0.billingPeriod.lowercased()) - - return SubscriptionOption(id: $0.productId, - cost: cost) + return SubscriptionOption(id: $0.productId, cost: cost) } - let features = [SubscriptionFeature(name: .networkProtection), - SubscriptionFeature(name: .dataBrokerProtection), - SubscriptionFeature(name: .identityTheftRestoration)] - + let features: [SubscriptionEntitlement] = [.networkProtection, + .dataBrokerProtection, + .identityTheftRestoration] return .success(SubscriptionOptions(platform: SubscriptionPlatformName.stripe, options: options, - features: features)) + availableEntitlements: features)) } public func prepareSubscriptionPurchase(emailAccessToken: String?) async -> Result { - Logger.subscription.info("[StripePurchaseFlow] prepareSubscriptionPurchase") + Logger.subscription.log("Preparing subscription purchase") - // Clear subscription Cache - subscriptionEndpointService.signOut() - var token: String = "" + subscriptionManager.clearSubscriptionCache() - if let accessToken = accountManager.accessToken { - if await isSubscriptionExpired(accessToken: accessToken) { - token = accessToken + if subscriptionManager.isUserAuthenticated { + if let subscriptionExpired = await isSubscriptionExpired(), + subscriptionExpired == true, + let tokenContainer = try? await subscriptionManager.getTokenContainer(policy: .localValid) { + return .success(PurchaseUpdate.redirect(withToken: tokenContainer.accessToken)) + } else { + return .success(PurchaseUpdate.redirect(withToken: "")) } } else { - switch await authEndpointService.createAccount(emailAccessToken: emailAccessToken) { - case .success(let response): - token = response.authToken - accountManager.storeAuthToken(token: token) - case .failure: - Logger.subscription.error("[StripePurchaseFlow] Error: accountCreationFailed") + do { + // Create account + let tokenContainer = try await subscriptionManager.getTokenContainer(policy: .createIfNeeded) + return .success(PurchaseUpdate.redirect(withToken: tokenContainer.accessToken)) + } catch { + Logger.subscriptionStripePurchaseFlow.error("Account creation failed: \(error.localizedDescription, privacy: .public)") return .failure(.accountCreationFailed) } } - - return .success(PurchaseUpdate.redirect(withToken: token)) } - private func isSubscriptionExpired(accessToken: String) async -> Bool { - if case .success(let subscription) = await subscriptionEndpointService.getSubscription(accessToken: accessToken) { - return !subscription.isActive + private func isSubscriptionExpired() async -> Bool? { + guard let subscription = try? await subscriptionManager.getSubscription(cachePolicy: .reloadIgnoringLocalCacheData) else { + return nil } - - return false + return !subscription.isActive } public func completeSubscriptionPurchase() async { - // Clear subscription Cache - subscriptionEndpointService.signOut() - - Logger.subscription.info("[StripePurchaseFlow] completeSubscriptionPurchase") - if !accountManager.isUserAuthenticated, - let authToken = accountManager.authToken { - if case let .success(accessToken) = await accountManager.exchangeAuthTokenToAccessToken(authToken), - case let .success(accountDetails) = await accountManager.fetchAccountDetails(with: accessToken) { - accountManager.storeAuthToken(token: authToken) - accountManager.storeAccount(token: accessToken, email: accountDetails.email, externalID: accountDetails.externalID) - } - } - - await accountManager.checkForEntitlements(wait: 2.0, retry: 5) + Logger.subscriptionStripePurchaseFlow.log("Completing subscription purchase") + subscriptionManager.clearSubscriptionCache() + _ = try? await subscriptionManager.getTokenContainer(policy: .localForceRefresh) } } diff --git a/Sources/Subscription/Logger+Subscription.swift b/Sources/Subscription/Logger+Subscription.swift index a09bd370d..0242b2a30 100644 --- a/Sources/Subscription/Logger+Subscription.swift +++ b/Sources/Subscription/Logger+Subscription.swift @@ -20,5 +20,13 @@ import Foundation import os.log public extension Logger { - static var subscription = { Logger(subsystem: "Subscription", category: "") }() + private static var subscriptionSubsystem = "Subscription" + static var subscription = { Logger(subsystem: Self.subscriptionSubsystem, category: "") }() + static var subscriptionAppStorePurchaseFlow = { Logger(subsystem: Self.subscriptionSubsystem, category: "AppStorePurchaseFlow") }() + static var subscriptionAppStoreRestoreFlow = { Logger(subsystem: Self.subscriptionSubsystem, category: "AppStoreRestoreFlow") }() + static var subscriptionStripePurchaseFlow = { Logger(subsystem: Self.subscriptionSubsystem, category: "StripePurchaseFlow") }() + static var subscriptionEndpointService = { Logger(subsystem: Self.subscriptionSubsystem, category: "EndpointService") }() + static var subscriptionStorePurchaseManager = { Logger(subsystem: Self.subscriptionSubsystem, category: "StorePurchaseManager") }() + static var subscriptionKeychain = { Logger(subsystem: Self.subscriptionSubsystem, category: "KeyChain") }() + static var subscriptionCookieManager = { Logger(subsystem: Self.subscriptionSubsystem, category: "CookieManager") }() } diff --git a/Sources/Subscription/Managers/AccountManager.swift b/Sources/Subscription/Managers/AccountManager.swift deleted file mode 100644 index bc769bbc6..000000000 --- a/Sources/Subscription/Managers/AccountManager.swift +++ /dev/null @@ -1,342 +0,0 @@ -// -// AccountManager.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common -import os.log - -public protocol AccountManagerKeychainAccessDelegate: AnyObject { - func accountManagerKeychainAccessFailed(accessType: AccountKeychainAccessType, error: AccountKeychainAccessError) -} - -public protocol AccountManager { - - var delegate: AccountManagerKeychainAccessDelegate? { get set } - var accessToken: String? { get } - var authToken: String? { get } - var email: String? { get } - var externalID: String? { get } - - func storeAuthToken(token: String) - func storeAccount(token: String, email: String?, externalID: String?) - func signOut(skipNotification: Bool) - func signOut() - - // Entitlements - func hasEntitlement(forProductName productName: Entitlement.ProductName, cachePolicy: APICachePolicy) async -> Result - - func updateCache(with entitlements: [Entitlement]) - @discardableResult func fetchEntitlements(cachePolicy: APICachePolicy) async -> Result<[Entitlement], Error> - func exchangeAuthTokenToAccessToken(_ authToken: String) async -> Result - - typealias AccountDetails = (email: String?, externalID: String) - func fetchAccountDetails(with accessToken: String) async -> Result - - @discardableResult func checkForEntitlements(wait waitTime: Double, retry retryCount: Int) async -> Bool -} - -extension AccountManager { - - public func hasEntitlement(forProductName productName: Entitlement.ProductName) async -> Result { - await hasEntitlement(forProductName: productName, cachePolicy: .returnCacheDataElseLoad) - } - - public func fetchEntitlements() async -> Result<[Entitlement], Error> { - await fetchEntitlements(cachePolicy: .returnCacheDataElseLoad) - } - - public var isUserAuthenticated: Bool { accessToken != nil } -} - -public final class DefaultAccountManager: AccountManager { - - private let storage: AccountStoring - private let entitlementsCache: UserDefaultsCache<[Entitlement]> - private let accessTokenStorage: SubscriptionTokenStoring - private let subscriptionEndpointService: SubscriptionEndpointService - private let authEndpointService: AuthEndpointService - - public weak var delegate: AccountManagerKeychainAccessDelegate? - - // MARK: - Initialisers - - public init(storage: AccountStoring = AccountKeychainStorage(), - accessTokenStorage: SubscriptionTokenStoring, - entitlementsCache: UserDefaultsCache<[Entitlement]>, - subscriptionEndpointService: SubscriptionEndpointService, - authEndpointService: AuthEndpointService) { - self.storage = storage - self.entitlementsCache = entitlementsCache - self.accessTokenStorage = accessTokenStorage - self.subscriptionEndpointService = subscriptionEndpointService - self.authEndpointService = authEndpointService - } - - // MARK: - - - public var authToken: String? { - do { - return try storage.getAuthToken() - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .getAuthToken, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - - return nil - } - } - - public var accessToken: String? { - do { - return try accessTokenStorage.getAccessToken() - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .getAccessToken, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - - return nil - } - } - - public var email: String? { - do { - return try storage.getEmail() - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .getEmail, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - - return nil - } - } - - public var externalID: String? { - do { - return try storage.getExternalID() - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .getExternalID, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - - return nil - } - } - - public func storeAuthToken(token: String) { - Logger.subscription.info("[AccountManager] storeAuthToken") - - do { - try storage.store(authToken: token) - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .storeAuthToken, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - } - } - - public func storeAccount(token: String, email: String?, externalID: String?) { - Logger.subscription.info("[AccountManager] storeAccount") - - do { - try accessTokenStorage.store(accessToken: token) - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .storeAccessToken, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - } - - do { - try storage.store(email: email) - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .storeEmail, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - } - - do { - try storage.store(externalID: externalID) - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .storeExternalID, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - } - NotificationCenter.default.post(name: .accountDidSignIn, object: self, userInfo: nil) - } - - public func signOut() { - signOut(skipNotification: false) - } - - public func signOut(skipNotification: Bool = false) { - Logger.subscription.info("[AccountManager] signOut") - - do { - try storage.clearAuthenticationState() - try accessTokenStorage.removeAccessToken() - subscriptionEndpointService.signOut() - entitlementsCache.reset() - } catch { - if let error = error as? AccountKeychainAccessError { - delegate?.accountManagerKeychainAccessFailed(accessType: .clearAuthenticationData, error: error) - } else { - assertionFailure("Expected AccountKeychainAccessError") - } - } - - if !skipNotification { - NotificationCenter.default.post(name: .accountDidSignOut, object: self, userInfo: nil) - } - } - - // MARK: - - public func hasEntitlement(forProductName productName: Entitlement.ProductName, cachePolicy: APICachePolicy) async -> Result { - switch await fetchEntitlements(cachePolicy: cachePolicy) { - case .success(let entitlements): - return .success(entitlements.compactMap { $0.product }.contains(productName)) - case .failure(let error): - return .failure(error) - } - } - - private func fetchRemoteEntitlements() async -> Result<[Entitlement], Error> { - guard let accessToken else { - entitlementsCache.reset() - return .failure(EntitlementsError.noAccessToken) - } - - switch await authEndpointService.validateToken(accessToken: accessToken) { - case .success(let response): - let entitlements = response.account.entitlements - updateCache(with: entitlements) - return .success(entitlements) - - case .failure(let error): - Logger.subscription.error("[AccountManager] fetchEntitlements error: \(error.localizedDescription, privacy: .public)") - return .failure(error) - } - } - - public func updateCache(with entitlements: [Entitlement]) { - let cachedEntitlements: [Entitlement] = entitlementsCache.get() ?? [] - - if entitlements != cachedEntitlements { - if entitlements.isEmpty { - entitlementsCache.reset() - } else { - entitlementsCache.set(entitlements) - } - NotificationCenter.default.post(name: .entitlementsDidChange, object: self, userInfo: [UserDefaultsCacheKey.subscriptionEntitlements: entitlements]) - } - } - - public enum EntitlementsError: Error { - case noAccessToken - case noCachedData - } - - @discardableResult - public func fetchEntitlements(cachePolicy: APICachePolicy) async -> Result<[Entitlement], Error> { - - switch cachePolicy { - case .reloadIgnoringLocalCacheData: - return await fetchRemoteEntitlements() - - case .returnCacheDataElseLoad: - if let cachedEntitlements: [Entitlement] = entitlementsCache.get() { - return .success(cachedEntitlements) - } else { - return await fetchRemoteEntitlements() - } - - case .returnCacheDataDontLoad: - if let cachedEntitlements: [Entitlement] = entitlementsCache.get() { - return .success(cachedEntitlements) - } else { - return .failure(EntitlementsError.noCachedData) - } - } - - } - - public func exchangeAuthTokenToAccessToken(_ authToken: String) async -> Result { - switch await authEndpointService.getAccessToken(token: authToken) { - case .success(let response): - return .success(response.accessToken) - case .failure(let error): - Logger.subscription.error("[AccountManager] exchangeAuthTokenToAccessToken error: \(error.localizedDescription, privacy: .public)") - return .failure(error) - } - } - - public func fetchAccountDetails(with accessToken: String) async -> Result { - switch await authEndpointService.validateToken(accessToken: accessToken) { - case .success(let response): - return .success(AccountDetails(email: response.account.email, externalID: response.account.externalID)) - case .failure(let error): - Logger.subscription.error("[AccountManager] fetchAccountDetails error: \(error.localizedDescription, privacy: .public)") - return .failure(error) - } - } - - @discardableResult - public func checkForEntitlements(wait waitTime: Double, retry retryCount: Int) async -> Bool { - var count = 0 - var hasEntitlements = false - - repeat { - switch await fetchEntitlements() { - case .success(let entitlements): - hasEntitlements = !entitlements.isEmpty - case .failure: - hasEntitlements = false - } - - if hasEntitlements { - break - } else { - count += 1 - try? await Task.sleep(seconds: waitTime) - } - } while !hasEntitlements && count < retryCount - - return hasEntitlements - } -} - -extension Task where Success == Never, Failure == Never { - static func sleep(seconds: Double) async throws { - let duration = UInt64(seconds * 1_000_000_000) - try await Task.sleep(nanoseconds: duration) - } -} diff --git a/Sources/Subscription/Managers/StorePurchaseManager/StorePurchaseManager.swift b/Sources/Subscription/Managers/StorePurchaseManager/StorePurchaseManager.swift index b23f24548..62f15c205 100644 --- a/Sources/Subscription/Managers/StorePurchaseManager/StorePurchaseManager.swift +++ b/Sources/Subscription/Managers/StorePurchaseManager/StorePurchaseManager.swift @@ -19,6 +19,7 @@ import Foundation import StoreKit import os.log +import Networking public enum StoreError: Error { case failedVerification @@ -66,8 +67,8 @@ public protocol StorePurchaseManager { @available(macOS 12.0, iOS 15.0, *) public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseManager { - private let storeSubscriptionConfiguration: StoreSubscriptionConfiguration - private let subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache + private let storeSubscriptionConfiguration: any StoreSubscriptionConfiguration + private let subscriptionFeatureMappingCache: any SubscriptionFeatureMappingCache private let subscriptionFeatureFlagger: FeatureFlaggerMapping? @Published public private(set) var availableProducts: [any SubscriptionProduct] = [] @@ -78,10 +79,9 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM public private(set) var currentStorefrontRegion: SubscriptionRegion = .usa private var transactionUpdates: Task? private var storefrontChanges: Task? - private var productFetcher: ProductFetching - public init(subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache, + public init(subscriptionFeatureMappingCache: any SubscriptionFeatureMappingCache, subscriptionFeatureFlagger: FeatureFlaggerMapping? = nil, productFetcher: ProductFetching = DefaultProductFetcher()) { self.storeSubscriptionConfiguration = DefaultStoreSubscriptionConfiguration() @@ -102,16 +102,16 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM do { purchaseQueue.removeAll() - Logger.subscription.info("[StorePurchaseManager] Before AppStore.sync()") + Logger.subscriptionStorePurchaseManager.log("Before AppStore.sync()") try await AppStore.sync() - Logger.subscription.info("[StorePurchaseManager] After AppStore.sync()") + Logger.subscriptionStorePurchaseManager.log("After AppStore.sync()") await updatePurchasedProducts() await updateAvailableProducts() } catch { - Logger.subscription.error("[StorePurchaseManager] Error: \(String(reflecting: error), privacy: .public) (\(error.localizedDescription, privacy: .public))") + Logger.subscriptionStorePurchaseManager.error("[StorePurchaseManager] Error: \(String(reflecting: error), privacy: .public) (\(error.localizedDescription, privacy: .public))") throw error } } @@ -128,13 +128,14 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM @MainActor public func updateAvailableProducts() async { - Logger.subscription.info("[StorePurchaseManager] updateAvailableProducts") + Logger.subscriptionStorePurchaseManager.log("Update available products") do { let storefrontCountryCode: String? let storefrontRegion: SubscriptionRegion - if let featureFlagger = subscriptionFeatureFlagger, featureFlagger.isFeatureOn(.isLaunchedROW) || featureFlagger.isFeatureOn(.isLaunchedROWOverride) { + if let featureFlagger = subscriptionFeatureFlagger, + featureFlagger.isFeatureOn(.isLaunchedROW) || featureFlagger.isFeatureOn(.isLaunchedROWOverride) { if let subscriptionFeatureFlagger, subscriptionFeatureFlagger.isFeatureOn(.usePrivacyProUSARegionOverride) { storefrontCountryCode = "USA" } else if let subscriptionFeatureFlagger, subscriptionFeatureFlagger.isFeatureOn(.usePrivacyProROWRegionOverride) { @@ -163,13 +164,13 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM } } } catch { - Logger.subscription.error("[StorePurchaseManager] Error: \(String(reflecting: error), privacy: .public)") + Logger.subscriptionStorePurchaseManager.error("Failed to fetch available products: \(String(reflecting: error), privacy: .public)") } } @MainActor public func updatePurchasedProducts() async { - Logger.subscription.info("[StorePurchaseManager] updatePurchasedProducts") + Logger.subscriptionStorePurchaseManager.log("Update purchased products") var purchasedSubscriptions: [String] = [] @@ -185,10 +186,10 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM } } } catch { - Logger.subscription.error("[StorePurchaseManager] Error: \(String(reflecting: error), privacy: .public)") + Logger.subscriptionStorePurchaseManager.error("Failed to update purchased products: \(String(reflecting: error), privacy: .public)") } - Logger.subscription.info("[StorePurchaseManager] updatePurchasedProducts fetched \(purchasedSubscriptions.count) active subscriptions") + Logger.subscriptionStorePurchaseManager.log("UpdatePurchasedProducts fetched \(purchasedSubscriptions.count) active subscriptions") if self.purchasedProductIDs != purchasedSubscriptions { self.purchasedProductIDs = purchasedSubscriptions @@ -197,31 +198,24 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM @MainActor public func mostRecentTransaction() async -> String? { - Logger.subscription.info("[StorePurchaseManager] mostRecentTransaction") + Logger.subscriptionStorePurchaseManager.log("Retrieving most recent transaction") var transactions: [VerificationResult] = [] - for await result in Transaction.all { transactions.append(result) } - - Logger.subscription.info("[StorePurchaseManager] mostRecentTransaction fetched \(transactions.count) transactions") - + let lastTransaction = transactions.first + Logger.subscriptionStorePurchaseManager.log("Most recent transaction fetched: \(lastTransaction?.debugDescription ?? "?") (tot: \(transactions.count) transactions)") return transactions.first?.jwsRepresentation } @MainActor public func hasActiveSubscription() async -> Bool { - Logger.subscription.info("[StorePurchaseManager] hasActiveSubscription") - var transactions: [VerificationResult] = [] - for await result in Transaction.currentEntitlements { transactions.append(result) } - - Logger.subscription.info("[StorePurchaseManager] hasActiveSubscription fetched \(transactions.count) transactions") - + Logger.subscriptionStorePurchaseManager.log("hasActiveSubscription fetched \(transactions.count) transactions") return !transactions.isEmpty } @@ -230,7 +224,7 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM guard let product = availableProducts.first(where: { $0.id == identifier }) else { return .failure(StorePurchaseManagerError.productNotFound) } - Logger.subscription.info("[StorePurchaseManager] purchaseSubscription \(product.displayName, privacy: .public) (\(externalID, privacy: .public))") + Logger.subscriptionStorePurchaseManager.log("Purchasing Subscription: \(product.displayName, privacy: .public) (\(externalID, privacy: .public))") purchaseQueue.append(product.id) @@ -239,7 +233,7 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM if let token = UUID(uuidString: externalID) { options.insert(.appAccountToken(token)) } else { - Logger.subscription.error("[StorePurchaseManager] Error: Failed to create UUID") + Logger.subscriptionStorePurchaseManager.error("Failed to create UUID from \(externalID, privacy: .public)") return .failure(StorePurchaseManagerError.externalIDisNotAValidUUID) } @@ -247,11 +241,11 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM do { purchaseResult = try await product.purchase(options: options) } catch { - Logger.subscription.error("[StorePurchaseManager] Error: \(String(reflecting: error), privacy: .public)") + Logger.subscriptionStorePurchaseManager.error("Error: \(String(reflecting: error), privacy: .public)") return .failure(StorePurchaseManagerError.purchaseFailed) } - Logger.subscription.info("[StorePurchaseManager] purchaseSubscription complete") + Logger.subscriptionStorePurchaseManager.log("PurchaseSubscription complete") purchaseQueue.removeAll() @@ -259,27 +253,27 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM case let .success(verificationResult): switch verificationResult { case let .verified(transaction): - Logger.subscription.info("[StorePurchaseManager] purchaseSubscription result: success") + Logger.subscriptionStorePurchaseManager.log("PurchaseSubscription result: success") // Successful purchase await transaction.finish() await self.updatePurchasedProducts() return .success(verificationResult.jwsRepresentation) case let .unverified(_, error): - Logger.subscription.info("[StorePurchaseManager] purchaseSubscription result: success /unverified/ - \(String(reflecting: error), privacy: .public)") + Logger.subscriptionStorePurchaseManager.log("purchaseSubscription result: success /unverified/ - \(String(reflecting: error), privacy: .public)") // Successful purchase but transaction/receipt can't be verified // Could be a jailbroken phone return .failure(StorePurchaseManagerError.transactionCannotBeVerified) } case .pending: - Logger.subscription.info("[StorePurchaseManager] purchaseSubscription result: pending") + Logger.subscriptionStorePurchaseManager.log("purchaseSubscription result: pending") // Transaction waiting on SCA (Strong Customer Authentication) or // approval from Ask to Buy return .failure(StorePurchaseManagerError.transactionPendingAuthentication) case .userCancelled: - Logger.subscription.info("[StorePurchaseManager] purchaseSubscription result: user cancelled") + Logger.subscriptionStorePurchaseManager.log("purchaseSubscription result: user cancelled") return .failure(StorePurchaseManagerError.purchaseCancelledByUser) @unknown default: - Logger.subscription.info("[StorePurchaseManager] purchaseSubscription result: unknown") + Logger.subscriptionStorePurchaseManager.log("purchaseSubscription result: unknown") return .failure(StorePurchaseManagerError.unknownError) } } @@ -302,20 +296,17 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM }() let options: [SubscriptionOption] = await [.init(from: monthly, withRecurrence: "monthly"), - .init(from: yearly, withRecurrence: "yearly")] - - let features: [SubscriptionFeature] - - if let featureFlagger = subscriptionFeatureFlagger, featureFlagger.isFeatureOn(.isLaunchedROW) || featureFlagger.isFeatureOn(.isLaunchedROWOverride) { - features = await subscriptionFeatureMappingCache.subscriptionFeatures(for: monthly.id).compactMap { SubscriptionFeature(name: $0) } + .init(from: yearly, withRecurrence: "yearly")] + let features: [SubscriptionEntitlement] + if let featureFlagger = subscriptionFeatureFlagger, + featureFlagger.isFeatureOn(.isLaunchedROW) || featureFlagger.isFeatureOn(.isLaunchedROWOverride) { + features = await subscriptionFeatureMappingCache.subscriptionFeatures(for: monthly.id) } else { - let allFeatures: [Entitlement.ProductName] = [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] - features = allFeatures.compactMap { SubscriptionFeature(name: $0) } + features = [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] } - return SubscriptionOptions(platform: platform, options: options, - features: features) + availableEntitlements: features) } private func checkVerified(_ result: VerificationResult) throws -> T { @@ -334,7 +325,7 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM Task.detached { [weak self] in for await result in Transaction.updates { - Logger.subscription.info("[StorePurchaseManager] observeTransactionUpdates") + Logger.subscriptionStorePurchaseManager.log("observeTransactionUpdates") if case .verified(let transaction) = result { await transaction.finish() @@ -349,7 +340,7 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM Task.detached { [weak self] in for await result in Storefront.updates { - Logger.subscription.info("[StorePurchaseManager] observeStorefrontChanges: \(result.countryCode)") + Logger.subscriptionStorePurchaseManager.log("observeStorefrontChanges: \(result.countryCode)") await self?.updatePurchasedProducts() await self?.updateAvailableProducts() } diff --git a/Sources/Subscription/Managers/SubscriptionManager.swift b/Sources/Subscription/Managers/SubscriptionManager.swift index cac861106..23fb726ba 100644 --- a/Sources/Subscription/Managers/SubscriptionManager.swift +++ b/Sources/Subscription/Managers/SubscriptionManager.swift @@ -18,53 +18,148 @@ import Foundation import Common +import os.log +import Networking -public protocol SubscriptionManager { - // Dependencies - var accountManager: AccountManager { get } - var subscriptionEndpointService: SubscriptionEndpointService { get } - var authEndpointService: AuthEndpointService { get } - var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache { get } +public enum SubscriptionManagerError: Error, Equatable { + case tokenUnavailable(error: Error?) + case confirmationHasInvalidSubscription + case noProductsFound + + public static func == (lhs: SubscriptionManagerError, rhs: SubscriptionManagerError) -> Bool { + switch (lhs, rhs) { + case (.tokenUnavailable(let lhsError), .tokenUnavailable(let rhsError)): + return lhsError?.localizedDescription == rhsError?.localizedDescription + case (.confirmationHasInvalidSubscription, .confirmationHasInvalidSubscription), + (.noProductsFound, .noProductsFound): + return true + default: + return false + } + } +} + +public enum SubscriptionPixelType { + case deadToken +} + +/// A `SubscriptionFeature` is **available** if the specific feature is `on` for the specific subscription. Feature availability if decided based on the country and the local and remote feature flags. +/// A `SubscriptionFeature` is **enabled** if the logged in user has the required entitlements. +public struct SubscriptionFeature: Equatable, CustomDebugStringConvertible { + public var entitlement: SubscriptionEntitlement + public var enabled: Bool + + public var debugDescription: String { + "\(entitlement.rawValue) is \(enabled ? "enabled" : "disabled")" + } +} + +/// The sole entity responsible of obtaining, storing and refreshing an OAuth Token +public protocol SubscriptionTokenProvider { + + /// Get a token container accordingly to the policy + /// - Parameter policy: The policy that will be used to get the token, it effects the tokens source and validity + /// - Returns: The TokenContainer + /// - Throws: OAuthClientError.deadToken if the token is unrecoverable. SubscriptionEndpointServiceError.noData if the token is not available. + @discardableResult + func getTokenContainer(policy: TokensCachePolicy) async throws -> TokenContainer + + /// Get a token container synchronously accordingly to the policy + /// - Parameter policy: The policy that will be used to get the token, it effects the tokens source and validity + /// - Returns: The TokenContainer, nil in case of error + func getTokenContainerSynchronously(policy: TokensCachePolicy) -> TokenContainer? + + /// Exchange access token v1 for a access token v2 + /// - Parameter tokenV1: The Auth v1 access token + /// - Returns: An auth v2 TokenContainer + func exchange(tokenV1: String) async throws -> TokenContainer + + /// Used only from the Mac Packet Tunnel Provider when a token is received during configuration + func adopt(tokenContainer: TokenContainer) async throws + + /// Remove the stored token container + func removeTokenContainer() +} + +public protocol SubscriptionManager: SubscriptionTokenProvider { // Environment static func loadEnvironmentFrom(userDefaults: UserDefaults) -> SubscriptionEnvironment? static func save(subscriptionEnvironment: SubscriptionEnvironment, userDefaults: UserDefaults) var currentEnvironment: SubscriptionEnvironment { get } + /// Tries to get an authentication token and request the subscription + func loadInitialData() + + // Subscription + func refreshCachedSubscription(completion: @escaping (_ isSubscriptionActive: Bool) -> Void) + @discardableResult func getSubscription(cachePolicy: SubscriptionCachePolicy) async throws -> PrivacyProSubscription + + /// Tries to activate a subscription using a platform signature + /// - Parameter lastTransactionJWSRepresentation: A platform signature coming from the AppStore + /// - Returns: A subscription if found + /// - Throws: An error if the access token is not available or something goes wrong in the api requests + func getSubscriptionFrom(lastTransactionJWSRepresentation: String) async throws -> PrivacyProSubscription? + var canPurchase: Bool { get } + func getProducts() async throws -> [GetProductsItem] + @available(macOS 12.0, iOS 15.0, *) func storePurchaseManager() -> StorePurchaseManager - func loadInitialData() - func refreshCachedSubscriptionAndEntitlements(completion: @escaping (_ isSubscriptionActive: Bool) -> Void) func url(for type: SubscriptionURL) -> URL - func currentSubscriptionFeatures() async -> [Entitlement.ProductName] + + func getCustomerPortalURL() async throws -> URL + + // User + var isUserAuthenticated: Bool { get } + var userEmail: String? { get } + + /// Sign out the user and clear all the tokens and subscription cache + func signOut() async + + func clearSubscriptionCache() + + /// Confirm a purchase with a platform signature + func confirmPurchase(signature: String) async throws -> PrivacyProSubscription + + /// Pixels handler + typealias PixelHandler = (SubscriptionPixelType) -> Void + + // MARK: - Features + + /// Get the current subscription features + /// A feature is based on an entitlement and can be enabled or disabled + /// A user cant have an entitlement without the feature, if a user is missing an entitlement the feature is disabled + func currentSubscriptionFeatures(forceRefresh: Bool) async -> [SubscriptionFeature] + + /// True if the feature can be used, false otherwise + func isFeatureActive(_ entitlement: SubscriptionEntitlement) async -> Bool } /// Single entry point for everything related to Subscription. This manager is disposable, every time something related to the environment changes this need to be recreated. public final class DefaultSubscriptionManager: SubscriptionManager { + + var oAuthClient: any OAuthClient private let _storePurchaseManager: StorePurchaseManager? - public let accountManager: AccountManager - public let subscriptionEndpointService: SubscriptionEndpointService - public let authEndpointService: AuthEndpointService - public let subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache + private let subscriptionEndpointService: SubscriptionEndpointService + private let pixelHandler: PixelHandler public let currentEnvironment: SubscriptionEnvironment - private let subscriptionFeatureFlagger: FeatureFlaggerMapping + private let subscriptionFeatureFlagger: FeatureFlaggerMapping? public init(storePurchaseManager: StorePurchaseManager? = nil, - accountManager: AccountManager, + oAuthClient: any OAuthClient, subscriptionEndpointService: SubscriptionEndpointService, - authEndpointService: AuthEndpointService, - subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache, subscriptionEnvironment: SubscriptionEnvironment, - subscriptionFeatureFlagger: FeatureFlaggerMapping) { + subscriptionFeatureFlagger: FeatureFlaggerMapping?, + pixelHandler: @escaping PixelHandler) { self._storePurchaseManager = storePurchaseManager - self.accountManager = accountManager + self.oAuthClient = oAuthClient self.subscriptionEndpointService = subscriptionEndpointService - self.authEndpointService = authEndpointService - self.subscriptionFeatureMappingCache = subscriptionFeatureMappingCache self.currentEnvironment = subscriptionEnvironment + self.pixelHandler = pixelHandler self.subscriptionFeatureFlagger = subscriptionFeatureFlagger +#if !NETP_SYSTEM_EXTENSION switch currentEnvironment.purchasePlatform { case .appStore: if #available(macOS 12.0, iOS 15.0, *) { @@ -75,6 +170,7 @@ public final class DefaultSubscriptionManager: SubscriptionManager { case .stripe: break } +#endif } public var canPurchase: Bool { @@ -84,7 +180,8 @@ public final class DefaultSubscriptionManager: SubscriptionManager { case .usa: return storePurchaseManager.areProductsAvailable case .restOfWorld: - if subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROW) || subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROWOverride) { + if let subscriptionFeatureFlagger, + subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROW) || subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROWOverride) { return storePurchaseManager.areProductsAvailable } else { return false @@ -117,7 +214,7 @@ public final class DefaultSubscriptionManager: SubscriptionManager { } } - // MARK: - Environment, ex SubscriptionPurchaseEnvironment + // MARK: - Environment @available(macOS 12.0, iOS 15.0, *) private func setupForAppStore() { Task { @@ -125,44 +222,58 @@ public final class DefaultSubscriptionManager: SubscriptionManager { } } - // MARK: - + // MARK: - Subscription public func loadInitialData() { + refreshCachedSubscription { isSubscriptionActive in + Logger.subscription.log("Subscription is \(isSubscriptionActive ? "active" : "not active")") + } + } + + public func refreshCachedSubscription(completion: @escaping (_ isSubscriptionActive: Bool) -> Void) { Task { - if let token = accountManager.accessToken { - _ = await subscriptionEndpointService.getSubscription(accessToken: token, cachePolicy: .reloadIgnoringLocalCacheData) - _ = await accountManager.fetchEntitlements(cachePolicy: .reloadIgnoringLocalCacheData) + guard let tokenContainer = try? await getTokenContainer(policy: .localForceRefresh) else { + completion(false) + return } + // Refetch and cache subscription + let subscription = try? await subscriptionEndpointService.getSubscription(accessToken: tokenContainer.accessToken, cachePolicy: .reloadIgnoringLocalCacheData) + completion(subscription?.isActive ?? false) } } - public func refreshCachedSubscriptionAndEntitlements(completion: @escaping (_ isSubscriptionActive: Bool) -> Void) { - Task { - guard let token = accountManager.accessToken else { return } + @discardableResult + public func getSubscription(cachePolicy: SubscriptionCachePolicy) async throws -> PrivacyProSubscription { + if !isUserAuthenticated { + throw SubscriptionEndpointServiceError.noData + } - var isSubscriptionActive = false + do { + let tokenContainer = try await getTokenContainer(policy: .localValid) + return try await subscriptionEndpointService.getSubscription(accessToken: tokenContainer.accessToken, cachePolicy: cachePolicy) + } catch SubscriptionEndpointServiceError.noData { +// await signOut() + throw SubscriptionEndpointServiceError.noData + } + } - defer { - completion(isSubscriptionActive) - } + public func getSubscriptionFrom(lastTransactionJWSRepresentation: String) async throws -> PrivacyProSubscription? { + do { + let tokenContainer = try await oAuthClient.activate(withPlatformSignature: lastTransactionJWSRepresentation) + return try await subscriptionEndpointService.getSubscription(accessToken: tokenContainer.accessToken, cachePolicy: .reloadIgnoringLocalCacheData) + } catch SubscriptionEndpointServiceError.noData { + return nil + } catch { + throw error + } + } - // Refetch and cache subscription - switch await subscriptionEndpointService.getSubscription(accessToken: token, cachePolicy: .reloadIgnoringLocalCacheData) { - case .success(let subscription): - isSubscriptionActive = subscription.isActive - case .failure(let error): - if case let .apiError(serviceError) = error, case let .serverError(statusCode, _) = serviceError { - if statusCode == 401 { - // Token is no longer valid - accountManager.signOut() - return - } - } - } + public func getProducts() async throws -> [GetProductsItem] { + try await subscriptionEndpointService.getProducts() + } - // Refetch and cache entitlements - _ = await accountManager.fetchEntitlements(cachePolicy: .reloadIgnoringLocalCacheData) - } + public func clearSubscriptionCache() { + subscriptionEndpointService.clearSubscription() } // MARK: - URLs @@ -171,20 +282,166 @@ public final class DefaultSubscriptionManager: SubscriptionManager { type.subscriptionURL(environment: currentEnvironment.serviceEnvironment) } - // MARK: - Current subscription's features + public func getCustomerPortalURL() async throws -> URL { + let tokenContainer = try await getTokenContainer(policy: .localValid) + // Get Stripe Customer Portal URL and update the model + let serviceResponse = try await subscriptionEndpointService.getCustomerPortalURL(accessToken: tokenContainer.accessToken, externalID: tokenContainer.decodedAccessToken.externalID) + guard let url = URL(string: serviceResponse.customerPortalUrl) else { + throw SubscriptionEndpointServiceError.noData + } + return url + } + + // MARK: - User + public var isUserAuthenticated: Bool { + oAuthClient.isUserAuthenticated + } + + public var userEmail: String? { + return oAuthClient.currentTokenContainer?.decodedAccessToken.email + } + + // MARK: - + + private func refreshAccount() async { + do { + try await getTokenContainer(policy: .localForceRefresh) + } catch { + Logger.subscription.error("Failed to refresh account: \(error.localizedDescription, privacy: .public)") + } + } + + @discardableResult public func getTokenContainer(policy: TokensCachePolicy) async throws -> TokenContainer { + do { + Logger.subscription.debug("Get tokens \(policy.description, privacy: .public)") + + let referenceCachedTokenContainer = try? await oAuthClient.getTokens(policy: .local) + let referenceCachedEntitlements = referenceCachedTokenContainer?.decodedAccessToken.subscriptionEntitlements + let resultTokenContainer = try await oAuthClient.getTokens(policy: policy) + let newEntitlements = resultTokenContainer.decodedAccessToken.subscriptionEntitlements + + // Send notification when entitlements change + if referenceCachedEntitlements != newEntitlements { + Logger.subscription.debug("Entitlements changed - New \(newEntitlements) Old \(String(describing: referenceCachedEntitlements))") + NotificationCenter.default.post(name: .entitlementsDidChange, object: self, userInfo: [UserDefaultsCacheKey.subscriptionEntitlements: newEntitlements]) + } - public func currentSubscriptionFeatures() async -> [Entitlement.ProductName] { - guard let token = accountManager.accessToken else { return [] } + if referenceCachedTokenContainer == nil { // new login + Logger.subscription.debug("New login detected") + NotificationCenter.default.post(name: .accountDidSignIn, object: self, userInfo: nil) + } + return resultTokenContainer + } catch OAuthClientError.deadToken { + return try await throwAppropriateDeadTokenError() + } catch { + throw SubscriptionManagerError.tokenUnavailable(error: error) + } + } - if subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROW) || subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROWOverride) { - switch await subscriptionEndpointService.getSubscription(accessToken: token, cachePolicy: .returnCacheDataElseLoad) { - case .success(let subscription): - return await subscriptionFeatureMappingCache.subscriptionFeatures(for: subscription.productId) - case .failure: + /// If the client succeeds in making a refresh request but does not get the response, then the second refresh request will fail with `invalidTokenRequest` and the stored token will become unusable and un-refreshable. + private func throwAppropriateDeadTokenError() async throws -> TokenContainer { + Logger.subscription.warning("Dead token detected") + do { + let subscription = try await subscriptionEndpointService.getSubscription(accessToken: "", // Token is unused + cachePolicy: .returnCacheDataDontLoad) + switch subscription.platform { + case .apple: + pixelHandler(.deadToken) + throw OAuthClientError.deadToken + default: + throw SubscriptionManagerError.tokenUnavailable(error: nil) + } + } catch { + throw SubscriptionManagerError.tokenUnavailable(error: error) + } + } + + public func getTokenContainerSynchronously(policy: TokensCachePolicy) -> TokenContainer? { + Logger.subscription.debug("Fetching tokens synchronously") + let semaphore = DispatchSemaphore(value: 0) + + Task(priority: .high) { + defer { semaphore.signal() } + return try? await getTokenContainer(policy: policy) + } + + semaphore.wait() + return nil + } + + public func exchange(tokenV1: String) async throws -> TokenContainer { + let tokenContainer = try await oAuthClient.exchange(accessTokenV1: tokenV1) + NotificationCenter.default.post(name: .accountDidSignIn, object: self, userInfo: nil) + return tokenContainer + } + + public func adopt(tokenContainer: TokenContainer) async throws { + oAuthClient.currentTokenContainer = tokenContainer + } + + public func removeTokenContainer() { + oAuthClient.removeLocalAccount() + } + + public func signOut() async { + Logger.subscription.log("Removing all traces of the subscription and auth tokens") + try? await oAuthClient.logout() + subscriptionEndpointService.clearSubscription() + NotificationCenter.default.post(name: .accountDidSignOut, object: self, userInfo: nil) + } + + public func confirmPurchase(signature: String) async throws -> PrivacyProSubscription { + Logger.subscription.log("Confirming Purchase...") + let accessToken = try await getTokenContainer(policy: .localValid).accessToken + let confirmation = try await subscriptionEndpointService.confirmPurchase(accessToken: accessToken, signature: signature) + try await subscriptionEndpointService.ingestSubscription(confirmation.subscription) + Logger.subscription.log("Purchase confirmed!") + return confirmation.subscription + } + + // MARK: - Features + + /// Returns the features available for the current subscription, a feature is enabled only if the user has the corresponding entitlement + /// - Parameter forceRefresh: ignore subscription and token cache and re-download everything + /// - Returns: An Array of SubscriptionFeature where each feature is enabled or disabled based on the user entitlements + public func currentSubscriptionFeatures(forceRefresh: Bool) async -> [SubscriptionFeature] { + guard isUserAuthenticated else { return [] } + + if let subscriptionFeatureFlagger, + subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROW) || subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROWOverride) { + do { + let currentSubscription = try await getSubscription(cachePolicy: .returnCacheDataDontLoad) + let tokenContainer = try await getTokenContainer(policy: forceRefresh ? .localForceRefresh : .local) + let userEntitlements = tokenContainer.decodedAccessToken.subscriptionEntitlements + let availableFeatures = currentSubscription.features ?? [] // await subscriptionFeatureMappingCache.subscriptionFeatures(for: subscription.productId) + + // Filter out the features that are not available because the user doesn't have the right entitlements + let result = availableFeatures.map({ featureEntitlement in + let enabled = userEntitlements.contains(featureEntitlement) + return SubscriptionFeature(entitlement: featureEntitlement, enabled: enabled) + }) + Logger.subscription.log(""" +User entitlements: \(userEntitlements) +Available Features: \(availableFeatures) +Subscription features: \(result) +""") + return result + } catch { return [] } } else { - return [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] + let result = [SubscriptionFeature(entitlement: .networkProtection, enabled: true), + SubscriptionFeature(entitlement: .dataBrokerProtection, enabled: true), + SubscriptionFeature(entitlement: .identityTheftRestoration, enabled: true)] + Logger.subscription.debug("Default Subscription features: \(result)") + return result + } + } + + public func isFeatureActive(_ entitlement: SubscriptionEntitlement) async -> Bool { + let currentFeatures = await currentSubscriptionFeatures(forceRefresh: false) + return currentFeatures.contains { feature in + feature.entitlement == entitlement && feature.enabled } } } diff --git a/Sources/Subscription/AccountStorage/AccountKeychainStorage.swift b/Sources/Subscription/Storage/V1/AccountKeychainStorage.swift similarity index 99% rename from Sources/Subscription/AccountStorage/AccountKeychainStorage.swift rename to Sources/Subscription/Storage/V1/AccountKeychainStorage.swift index 13e83d3ea..31205ae8e 100644 --- a/Sources/Subscription/AccountStorage/AccountKeychainStorage.swift +++ b/Sources/Subscription/Storage/V1/AccountKeychainStorage.swift @@ -100,7 +100,7 @@ public final class AccountKeychainStorage: AccountStoring { } } -private extension AccountKeychainStorage { +extension AccountKeychainStorage { /* Uses just kSecAttrService as the primary key, since we don't want to store diff --git a/Sources/Subscription/AccountStorage/AccountStoring.swift b/Sources/Subscription/Storage/V1/AccountStoring.swift similarity index 100% rename from Sources/Subscription/AccountStorage/AccountStoring.swift rename to Sources/Subscription/Storage/V1/AccountStoring.swift diff --git a/Sources/Subscription/AccountStorage/SubscriptionTokenKeychainStorage.swift b/Sources/Subscription/Storage/V1/SubscriptionTokenKeychainStorage.swift similarity index 85% rename from Sources/Subscription/AccountStorage/SubscriptionTokenKeychainStorage.swift rename to Sources/Subscription/Storage/V1/SubscriptionTokenKeychainStorage.swift index 3e5772a44..89221f988 100644 --- a/Sources/Subscription/AccountStorage/SubscriptionTokenKeychainStorage.swift +++ b/Sources/Subscription/Storage/V1/SubscriptionTokenKeychainStorage.swift @@ -17,6 +17,7 @@ // import Foundation +import Common public final class SubscriptionTokenKeychainStorage: SubscriptionTokenStoring { @@ -145,40 +146,7 @@ private extension SubscriptionTokenKeychainStorage { kSecClass: kSecClassGenericPassword, kSecAttrSynchronizable: false ] - attributes.merge(keychainType.queryAttributes()) { $1 } - return attributes } } - -public enum KeychainType { - case dataProtection(_ accessGroup: AccessGroup) - /// Uses the system keychain. - case system - case fileBased - - public enum AccessGroup { - case unspecified - case named(_ name: String) - } - - func queryAttributes() -> [CFString: Any] { - switch self { - case .dataProtection(let accessGroup): - switch accessGroup { - case .unspecified: - return [kSecUseDataProtectionKeychain: true] - case .named(let accessGroup): - return [ - kSecUseDataProtectionKeychain: true, - kSecAttrAccessGroup: accessGroup - ] - } - case .system: - return [kSecUseDataProtectionKeychain: false] - case .fileBased: - return [kSecUseDataProtectionKeychain: false] - } - } -} diff --git a/Sources/Subscription/AccountStorage/SubscriptionTokenStoring.swift b/Sources/Subscription/Storage/V1/SubscriptionTokenStoring.swift similarity index 100% rename from Sources/Subscription/AccountStorage/SubscriptionTokenStoring.swift rename to Sources/Subscription/Storage/V1/SubscriptionTokenStoring.swift diff --git a/Sources/Subscription/Storage/V2/SubscriptionTokenKeychainStorage+LegacyTokenStoring.swift b/Sources/Subscription/Storage/V2/SubscriptionTokenKeychainStorage+LegacyTokenStoring.swift new file mode 100644 index 000000000..7445496e0 --- /dev/null +++ b/Sources/Subscription/Storage/V2/SubscriptionTokenKeychainStorage+LegacyTokenStoring.swift @@ -0,0 +1,45 @@ +// +// SubscriptionTokenKeychainStorage+LegacyTokenStoring.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Networking + +extension SubscriptionTokenKeychainStorage: LegacyTokenStoring { + + public var token: String? { + get { + do { + return try getAccessToken() + } catch { + assertionFailure("Failed to retrieve auth token: \(error)") + } + return nil + } + set(newValue) { + do { + guard let newValue else { + try removeAccessToken() + return + } + try store(accessToken: newValue) + } catch { + assertionFailure("Failed set token: \(error)") + } + } + } +} diff --git a/Sources/Subscription/Storage/V2/SubscriptionTokenKeychainStorageV2.swift b/Sources/Subscription/Storage/V2/SubscriptionTokenKeychainStorageV2.swift new file mode 100644 index 000000000..764068c80 --- /dev/null +++ b/Sources/Subscription/Storage/V2/SubscriptionTokenKeychainStorageV2.swift @@ -0,0 +1,170 @@ +// +// SubscriptionTokenKeychainStorageV2.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os.log +import Networking +import Common + +public final class SubscriptionTokenKeychainStorageV2: TokenStoring { + + private let keychainType: KeychainType + + public init(keychainType: KeychainType = .dataProtection(.unspecified)) { + self.keychainType = keychainType + } + + public var tokenContainer: TokenContainer? { + get { + guard let data = try? retrieveData(forField: .tokens) else { + Logger.subscriptionKeychain.debug("TokenContainer not found") + return nil + } + return CodableHelper.decode(jsonData: data) + } + set { + do { + guard let newValue else { + Logger.subscriptionKeychain.debug("remove TokenContainer") + try self.deleteItem(forField: .tokens) + return + } + + if let data = CodableHelper.encode(newValue) { + try self.store(data: data, forField: .tokens) + } else { + Logger.subscriptionKeychain.fault("Failed to encode TokenContainer") + assertionFailure("Failed to encode TokenContainer") + } + } catch { + Logger.subscriptionKeychain.fault("Failed to set TokenContainer: \(error, privacy: .public)") + assertionFailure("Failed to set TokenContainer") + } + } + } +} + +extension SubscriptionTokenKeychainStorageV2 { + + /* + Uses just kSecAttrService as the primary key, since we don't want to store + multiple accounts/tokens at the same time + */ + enum SubscriptionKeychainField: String, CaseIterable { + case tokens = "subscription.v2.tokens" + + var keyValue: String { + "com.duckduckgo" + "." + rawValue + } + } + + func getString(forField field: SubscriptionKeychainField) throws -> String? { + guard let data = try retrieveData(forField: field) else { + return nil + } + + if let decodedString = String(data: data, encoding: String.Encoding.utf8) { + return decodedString + } else { + throw AccountKeychainAccessError.failedToDecodeKeychainDataAsString + } + } + + func retrieveData(forField field: SubscriptionKeychainField) throws -> Data? { + var query = defaultAttributes() + query[kSecAttrService] = field.keyValue + query[kSecMatchLimit] = kSecMatchLimitOne + query[kSecReturnData] = true + + var item: CFTypeRef? + let status = SecItemCopyMatching(query as CFDictionary, &item) + + if status == errSecSuccess { + if let existingItem = item as? Data { + return existingItem + } else { + throw AccountKeychainAccessError.failedToDecodeKeychainValueAsData + } + } else if status == errSecItemNotFound { + return nil + } else { + throw AccountKeychainAccessError.keychainLookupFailure(status) + } + } + + func set(string: String, forField field: SubscriptionKeychainField) throws { + guard let stringData = string.data(using: .utf8) else { + return + } + + try store(data: stringData, forField: field) + } + + func store(data: Data, forField field: SubscriptionKeychainField) throws { + var query = defaultAttributes() + query[kSecAttrService] = field.keyValue + query[kSecAttrAccessible] = kSecAttrAccessibleAfterFirstUnlock + query[kSecValueData] = data + + let status = SecItemAdd(query as CFDictionary, nil) + + switch status { + case errSecSuccess: + return + case errSecDuplicateItem: + let updateStatus = updateData(data, forField: field) + + if updateStatus != errSecSuccess { + throw AccountKeychainAccessError.keychainSaveFailure(status) + } + default: + throw AccountKeychainAccessError.keychainSaveFailure(status) + } + } + + private func updateData(_ data: Data, forField field: SubscriptionKeychainField) -> OSStatus { + var query = defaultAttributes() + query[kSecAttrService] = field.keyValue + + let newAttributes = [ + kSecValueData: data, + kSecAttrAccessible: kSecAttrAccessibleAfterFirstUnlock + ] as [CFString: Any] + + return SecItemUpdate(query as CFDictionary, newAttributes as CFDictionary) + } + + func deleteItem(forField field: SubscriptionKeychainField, useDataProtectionKeychain: Bool = true) throws { + let query = defaultAttributes() + + let status = SecItemDelete(query as CFDictionary) + + if status != errSecSuccess && status != errSecItemNotFound { + throw AccountKeychainAccessError.keychainDeleteFailure(status) + } + } + + private func defaultAttributes() -> [CFString: Any] { + var attributes: [CFString: Any] = [ + kSecClass: kSecClassGenericPassword, + kSecAttrSynchronizable: false + ] + attributes.merge(keychainType.queryAttributes()) { $1 } + return attributes + } +} diff --git a/Sources/Subscription/SubscriptionCookie/SubscriptionCookieManager.swift b/Sources/Subscription/SubscriptionCookie/SubscriptionCookieManager.swift index eff9f2e69..5fe425eff 100644 --- a/Sources/Subscription/SubscriptionCookie/SubscriptionCookieManager.swift +++ b/Sources/Subscription/SubscriptionCookie/SubscriptionCookieManager.swift @@ -21,7 +21,6 @@ import Common import os.log public protocol SubscriptionCookieManaging { - init(subscriptionManager: SubscriptionManager, currentCookieStore: @MainActor @escaping () -> HTTPCookieStore?, eventMapping: EventMapping) func enableSettingSubscriptionCookie() func disableSettingSubscriptionCookie() async @@ -86,18 +85,17 @@ public final class SubscriptionCookieManager: SubscriptionCookieManaging { let cookieStore = await currentCookieStore() else { return } - guard let accessToken = subscriptionManager.accountManager.accessToken else { - Logger.subscription.error("[SubscriptionCookieManager] Handle .accountDidSignIn - can't set the cookie, token is missing") - eventMapping.fire(.errorHandlingAccountDidSignInTokenIsMissing) - return - } - Logger.subscription.info("[SubscriptionCookieManager] Handle .accountDidSignIn - setting cookie") - do { + let accessToken = try await subscriptionManager.getTokenContainer(policy: .local).accessToken + Logger.subscriptionCookieManager.info("Handle .accountDidSignIn - setting cookie") try await cookieStore.setSubscriptionCookie(for: accessToken) updateLastRefreshDateToNow() - } catch { + } catch SubscriptionCookieManagerError.failedToCreateSubscriptionCookie { eventMapping.fire(.failedToSetSubscriptionCookie) + } catch { + Logger.subscriptionCookieManager.error("Handle .accountDidSignIn - can't set the cookie, token is missing") + eventMapping.fire(.errorHandlingAccountDidSignInTokenIsMissing) + return } } } @@ -107,7 +105,7 @@ public final class SubscriptionCookieManager: SubscriptionCookieManaging { guard isSettingSubscriptionCookieEnabled, let cookieStore = await currentCookieStore() else { return } - Logger.subscription.info("[SubscriptionCookieManager] Handle .accountDidSignOut - deleting cookie") + Logger.subscriptionCookieManager.info("Handle .accountDidSignOut - deleting cookie") do { try await cookieStore.setEmptySubscriptionCookie() @@ -123,17 +121,17 @@ public final class SubscriptionCookieManager: SubscriptionCookieManaging { shouldRefreshSubscriptionCookie(), let cookieStore = await currentCookieStore() else { return } - Logger.subscription.info("[SubscriptionCookieManager] Refresh subscription cookie") + Logger.subscriptionCookieManager.info("Refresh subscription cookie") updateLastRefreshDateToNow() - let accessToken: String? = subscriptionManager.accountManager.accessToken + let accessToken: String? = try? await subscriptionManager.getTokenContainer(policy: .local).accessToken let subscriptionCookie = await cookieStore.fetchCurrentSubscriptionCookie() let noCookieOrWithUnexpectedValue = (accessToken ?? "") != subscriptionCookie?.value do { if noCookieOrWithUnexpectedValue { - Logger.subscription.info("[SubscriptionCookieManager] Refresh: No cookie or one with unexpected value") + Logger.subscriptionCookieManager.info("Refresh: No cookie or one with unexpected value") if let accessToken { try await cookieStore.setSubscriptionCookie(for: accessToken) @@ -190,12 +188,12 @@ private extension HTTPCookieStore { .secure: true, .init(rawValue: "HttpOnly"): true ]) else { - Logger.subscription.error("[HTTPCookieStore] Subscription cookie could not be created") + Logger.subscriptionCookieManager.error("Subscription cookie could not be created") assertionFailure("Subscription cookie could not be created") throw SubscriptionCookieManagerError.failedToCreateSubscriptionCookie } - Logger.subscription.info("[HTTPCookieStore] Setting subscription cookie") + Logger.subscriptionCookieManager.info("Setting subscription cookie") await setCookie(cookie) } } diff --git a/Sources/Subscription/SubscriptionEnvironment.swift b/Sources/Subscription/SubscriptionEnvironment.swift index 3f5ed3bf2..c84b1264e 100644 --- a/Sources/Subscription/SubscriptionEnvironment.swift +++ b/Sources/Subscription/SubscriptionEnvironment.swift @@ -20,13 +20,15 @@ import Foundation public struct SubscriptionEnvironment: Codable { - public enum ServiceEnvironment: Codable { + public enum ServiceEnvironment: String, Codable { case production, staging - public var description: String { + public var url: URL { switch self { - case .production: return "Production" - case .staging: return "Staging" + case .production: + URL(string: "https://subscriptions.duckduckgo.com/api")! + case .staging: + URL(string: "https://subscriptions-dev.duckduckgo.com/api")! } } } @@ -42,4 +44,8 @@ public struct SubscriptionEnvironment: Codable { self.serviceEnvironment = serviceEnvironment self.purchasePlatform = purchasePlatform } + + public var description: String { + "ServiceEnvironment: \(serviceEnvironment.rawValue), PurchasePlatform: \(purchasePlatform.rawValue)" + } } diff --git a/Sources/Subscription/SubscriptionFeatureMappingCache.swift b/Sources/Subscription/SubscriptionFeatureMappingCache.swift index 9bfa3cee1..1fe4cf924 100644 --- a/Sources/Subscription/SubscriptionFeatureMappingCache.swift +++ b/Sources/Subscription/SubscriptionFeatureMappingCache.swift @@ -18,117 +18,8 @@ import Foundation import os.log +import Networking public protocol SubscriptionFeatureMappingCache { - func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName] + func subscriptionFeatures(for subscriptionIdentifier: String) async -> [SubscriptionEntitlement] } - -public actor DefaultSubscriptionFeatureMappingCache: SubscriptionFeatureMappingCache { - - private let subscriptionEndpointService: SubscriptionEndpointService - private let userDefaults: UserDefaults - - public init(subscriptionEndpointService: SubscriptionEndpointService, userDefaults: UserDefaults) { - self.subscriptionEndpointService = subscriptionEndpointService - self.userDefaults = userDefaults - } - - public func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName] { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] \(#function) \(subscriptionIdentifier)") - let features: [Entitlement.ProductName] - - if let subscriptionFeatures = currentSubscriptionFeatureMapping[subscriptionIdentifier] { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] - got cached features") - features = subscriptionFeatures - } else if let subscriptionFeatures = await fetchRemoteFeatures(for: subscriptionIdentifier) { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] - fetching features from BE API") - features = subscriptionFeatures - updateCachedFeatureMapping(with: subscriptionFeatures, for: subscriptionIdentifier) - } else { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] - Error: using fallback") - features = fallbackFeatures - } - - return features - } - - // MARK: - Current feature mapping - - private var currentSubscriptionFeatureMapping: SubscriptionFeatureMapping { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] - \(#function)") - let featureMapping: SubscriptionFeatureMapping - - if let cachedFeatureMapping { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- got cachedFeatureMapping") - featureMapping = cachedFeatureMapping - } else if let storedFeatureMapping { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- have to fetchStoredFeatureMapping") - featureMapping = storedFeatureMapping - updateCachedFeatureMapping(to: featureMapping) - } else { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- so creating a new one!") - featureMapping = SubscriptionFeatureMapping() - updateCachedFeatureMapping(to: featureMapping) - } - - return featureMapping - } - - // MARK: - Cached subscription feature mapping - - private var cachedFeatureMapping: SubscriptionFeatureMapping? - - private func updateCachedFeatureMapping(to featureMapping: SubscriptionFeatureMapping) { - cachedFeatureMapping = featureMapping - } - - private func updateCachedFeatureMapping(with features: [Entitlement.ProductName], for subscriptionIdentifier: String) { - var updatedFeatureMapping = cachedFeatureMapping ?? SubscriptionFeatureMapping() - updatedFeatureMapping[subscriptionIdentifier] = features - - self.cachedFeatureMapping = updatedFeatureMapping - self.storedFeatureMapping = updatedFeatureMapping - } - - // MARK: - Stored subscription feature mapping - - static private let subscriptionFeatureMappingKey = "com.duckduckgo.subscription.featuremapping" - - var storedFeatureMapping: SubscriptionFeatureMapping? { - get { - guard let data = userDefaults.data(forKey: Self.subscriptionFeatureMappingKey) else { return nil } - do { - return try JSONDecoder().decode(SubscriptionFeatureMapping?.self, from: data) - } catch { - assertionFailure("Errored while decoding feature mapping") - return nil - } - } - - set { - do { - let data = try JSONEncoder().encode(newValue) - userDefaults.set(data, forKey: Self.subscriptionFeatureMappingKey) - } catch { - assertionFailure("Errored while encoding feature mapping") - } - } - } - - // MARK: - Remote subscription feature mapping - - private func fetchRemoteFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName]? { - if case let .success(response) = await subscriptionEndpointService.getSubscriptionFeatures(for: subscriptionIdentifier) { - Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- Fetched features for `\(subscriptionIdentifier)`: \(response.features)") - return response.features - } - - return nil - } - - // MARK: - Fallback subscription feature mapping - - private let fallbackFeatures: [Entitlement.ProductName] = [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] -} - -typealias SubscriptionFeatureMapping = [String: [Entitlement.ProductName]] diff --git a/Sources/SubscriptionTestingUtilities/APIs/APIServiceMock.swift b/Sources/SubscriptionTestingUtilities/APIs/APIServiceMock.swift deleted file mode 100644 index 97c923279..000000000 --- a/Sources/SubscriptionTestingUtilities/APIs/APIServiceMock.swift +++ /dev/null @@ -1,63 +0,0 @@ -// -// APIServiceMock.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Subscription - -public final class APIServiceMock: APIService { - public var mockAuthHeaders: [String: String] = [String: String]() - - public var mockResponseJSONData: Data? - public var mockAPICallSuccessResult: Any? - public var mockAPICallError: APIServiceError? - - public var onExecuteAPICall: ((ExecuteAPICallParameters) -> Void)? - - public typealias ExecuteAPICallParameters = (method: String, endpoint: String, headers: [String: String]?, body: Data?) - - public init() { } - - // swiftlint:disable force_cast - public func executeAPICall(method: String, endpoint: String, headers: [String: String]?, body: Data?) async -> Result where T: Decodable { - - onExecuteAPICall?(ExecuteAPICallParameters(method, endpoint, headers, body)) - - if let data = mockResponseJSONData { - let decoder = JSONDecoder() - decoder.keyDecodingStrategy = .convertFromSnakeCase - decoder.dateDecodingStrategy = .millisecondsSince1970 - - if let decodedResponse = try? decoder.decode(T.self, from: data) { - return .success(decodedResponse) - } else { - return .failure(.decodingError) - } - } else if let success = mockAPICallSuccessResult { - return .success(success as! T) - } else if let error = mockAPICallError { - return .failure(error) - } - - return .failure(.unknownServerError) - } - // swiftlint:enable force_cast - - public func makeAuthorizationHeader(for token: String) -> [String: String] { - return mockAuthHeaders - } -} diff --git a/Sources/SubscriptionTestingUtilities/APIs/AuthEndpointServiceMock.swift b/Sources/SubscriptionTestingUtilities/APIs/AuthEndpointServiceMock.swift deleted file mode 100644 index e36f32fee..000000000 --- a/Sources/SubscriptionTestingUtilities/APIs/AuthEndpointServiceMock.swift +++ /dev/null @@ -1,57 +0,0 @@ -// -// AuthEndpointServiceMock.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Subscription - -public final class AuthEndpointServiceMock: AuthEndpointService { - public var getAccessTokenResult: Result? - public var validateTokenResult: Result? - public var createAccountResult: Result? - public var storeLoginResult: Result? - - public var onValidateToken: ((String) -> Void)? - - public var getAccessTokenCalled: Bool = false - public var validateTokenCalled: Bool = false - public var createAccountCalled: Bool = false - public var storeLoginCalled: Bool = false - - public init() { } - - public func getAccessToken(token: String) async -> Result { - getAccessTokenCalled = true - return getAccessTokenResult! - } - - public func validateToken(accessToken: String) async -> Result { - validateTokenCalled = true - onValidateToken?(accessToken) - return validateTokenResult! - } - - public func createAccount(emailAccessToken: String?) async -> Result { - createAccountCalled = true - return createAccountResult! - } - - public func storeLogin(signature: String) async -> Result { - storeLoginCalled = true - return storeLoginResult! - } -} diff --git a/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift b/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift index b17d585d0..6cc072158 100644 --- a/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift +++ b/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift @@ -18,53 +18,68 @@ import Foundation import Subscription +import Networking public final class SubscriptionEndpointServiceMock: SubscriptionEndpointService { - public var getSubscriptionResult: Result? - public var getProductsResult: Result<[GetProductsItem], APIServiceError>? - public var getSubscriptionFeaturesResult: Result? - public var getCustomerPortalURLResult: Result? - public var confirmPurchaseResult: Result? - - public var onUpdateCache: ((Subscription) -> Void)? - public var onGetSubscription: ((String, APICachePolicy) -> Void)? public var onSignOut: (() -> Void)? - - public var updateCacheWithSubscriptionCalled: Bool = false - public var getSubscriptionCalled: Bool = false public var signOutCalled: Bool = false public init() { } - public func updateCache(with subscription: Subscription) { + public var updateCacheWithSubscriptionCalled: Bool = false + public var onUpdateCache: ((PrivacyProSubscription) -> Void)? + public func updateCache(with subscription: Subscription.PrivacyProSubscription) { onUpdateCache?(subscription) updateCacheWithSubscriptionCalled = true } - public func getSubscription(accessToken: String, cachePolicy: APICachePolicy) async -> Result { - getSubscriptionCalled = true - onGetSubscription?(accessToken, cachePolicy) - return getSubscriptionResult! + public func clearSubscription() {} + + public var getProductsResult: Result<[GetProductsItem], APIRequestV2.Error>? + public func getProducts() async throws -> [Subscription.GetProductsItem] { + switch getProductsResult! { + case .success(let result): return result + case .failure(let error): throw error + } } - public func signOut() { - signOutCalled = true - onSignOut?() + public var getSubscriptionCalled: Bool = false + public var onGetSubscription: ((String, SubscriptionCachePolicy) -> Void)? + public var getSubscriptionResult: Result? + public func getSubscription(accessToken: String, cachePolicy: Subscription.SubscriptionCachePolicy) async throws -> Subscription.PrivacyProSubscription { + getSubscriptionCalled = true + onGetSubscription?(accessToken, cachePolicy) + switch getSubscriptionResult! { + case .success(let subscription): return subscription + case .failure(let error): throw error + } } - public func getProducts() async -> Result<[GetProductsItem], APIServiceError> { - getProductsResult! + public var getCustomerPortalURLResult: Result? + public func getCustomerPortalURL(accessToken: String, externalID: String) async throws -> Subscription.GetCustomerPortalURLResponse { + switch getCustomerPortalURLResult! { + case .success(let result): return result + case .failure(let error): throw error + } } - public func getSubscriptionFeatures(for subscriptionID: String) async -> Result { - getSubscriptionFeaturesResult! + public var confirmPurchaseResult: Result? + public func confirmPurchase(accessToken: String, signature: String) async throws -> Subscription.ConfirmPurchaseResponse { + switch confirmPurchaseResult! { + case .success(let result): return result + case .failure(let error): throw error + } } - public func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result { - getCustomerPortalURLResult! + public var getSubscriptionFeaturesResult: Result? + public func getSubscriptionFeatures(for subscriptionID: String) async throws -> Subscription.GetSubscriptionFeaturesResponse { + switch getSubscriptionFeaturesResult! { + case .success(let result): return result + case .failure(let error): throw error + } } - public func confirmPurchase(accessToken: String, signature: String) async -> Result { - confirmPurchaseResult! + public func ingestSubscription(_ subscription: Subscription.PrivacyProSubscription) async throws { + getSubscriptionResult = .success(subscription) } } diff --git a/Sources/SubscriptionTestingUtilities/AccountStorage/AccountManagerKeychainAccessDelegateMock.swift b/Sources/SubscriptionTestingUtilities/AccountStorage/AccountManagerKeychainAccessDelegateMock.swift deleted file mode 100644 index 0c15398b3..000000000 --- a/Sources/SubscriptionTestingUtilities/AccountStorage/AccountManagerKeychainAccessDelegateMock.swift +++ /dev/null @@ -1,33 +0,0 @@ -// -// AccountManagerKeychainAccessDelegateMock.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Subscription - -public final class AccountManagerKeychainAccessDelegateMock: AccountManagerKeychainAccessDelegate { - - public var onAccountManagerKeychainAccessFailed: ((AccountKeychainAccessType, AccountKeychainAccessError) -> Void)? - - public init(onAccountManagerKeychainAccessFailed: ( (AccountKeychainAccessType, AccountKeychainAccessError) -> Void)? = nil) { - self.onAccountManagerKeychainAccessFailed = onAccountManagerKeychainAccessFailed - } - - public func accountManagerKeychainAccessFailed(accessType: AccountKeychainAccessType, error: AccountKeychainAccessError) { - onAccountManagerKeychainAccessFailed?(accessType, error) - } -} diff --git a/Sources/SubscriptionTestingUtilities/Flows/AppStorePurchaseFlowMock.swift b/Sources/SubscriptionTestingUtilities/Flows/AppStorePurchaseFlowMock.swift index 493ec562f..91587e2dd 100644 --- a/Sources/SubscriptionTestingUtilities/Flows/AppStorePurchaseFlowMock.swift +++ b/Sources/SubscriptionTestingUtilities/Flows/AppStorePurchaseFlowMock.swift @@ -25,10 +25,11 @@ public final class AppStorePurchaseFlowMock: AppStorePurchaseFlow { public init() { } - public func purchaseSubscription(with subscriptionIdentifier: String, emailAccessToken: String?) async -> Result { + public func purchaseSubscription(with subscriptionIdentifier: String) async -> Result { purchaseSubscriptionResult! } + @discardableResult public func completeSubscriptionPurchase(with transactionJWS: TransactionJWS) async -> Result { completeSubscriptionPurchaseResult! } diff --git a/Sources/SubscriptionTestingUtilities/Flows/AppStoreRestoreFlowMock.swift b/Sources/SubscriptionTestingUtilities/Flows/AppStoreRestoreFlowMock.swift index 6daea9c44..99402c8be 100644 --- a/Sources/SubscriptionTestingUtilities/Flows/AppStoreRestoreFlowMock.swift +++ b/Sources/SubscriptionTestingUtilities/Flows/AppStoreRestoreFlowMock.swift @@ -20,12 +20,12 @@ import Foundation import Subscription public final class AppStoreRestoreFlowMock: AppStoreRestoreFlow { - public var restoreAccountFromPastPurchaseResult: Result? + public var restoreAccountFromPastPurchaseResult: Result? public var restoreAccountFromPastPurchaseCalled: Bool = false public init() { } - public func restoreAccountFromPastPurchase() async -> Result { + @discardableResult public func restoreAccountFromPastPurchase() async -> Result { restoreAccountFromPastPurchaseCalled = true return restoreAccountFromPastPurchaseResult! } diff --git a/Sources/SubscriptionTestingUtilities/Managers/AccountManagerMock.swift b/Sources/SubscriptionTestingUtilities/Managers/AccountManagerMock.swift deleted file mode 100644 index 2111700c8..000000000 --- a/Sources/SubscriptionTestingUtilities/Managers/AccountManagerMock.swift +++ /dev/null @@ -1,110 +0,0 @@ -// -// AccountManagerMock.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Subscription - -public final class AccountManagerMock: AccountManager { - public var delegate: AccountManagerKeychainAccessDelegate? - public var accessToken: String? - public var authToken: String? - public var email: String? - public var externalID: String? - - public var exchangeAuthTokenToAccessTokenResult: Result? - public var fetchAccountDetailsResult: Result? - - public var onStoreAuthToken: ((String) -> Void)? - public var onStoreAccount: ((String, String?, String?) -> Void)? - public var onFetchEntitlements: ((APICachePolicy) -> Void)? - public var onExchangeAuthTokenToAccessToken: ((String) -> Void)? - public var onFetchAccountDetails: ((String) -> Void)? - public var onCheckForEntitlements: ((Double, Int) -> Bool)? - - public var storeAuthTokenCalled: Bool = false - public var storeAccountCalled: Bool = false - public var signOutCalled: Bool = false - public var updateCacheWithEntitlementsCalled: Bool = false - public var fetchEntitlementsCalled: Bool = false - public var exchangeAuthTokenToAccessTokenCalled: Bool = false - public var fetchAccountDetailsCalled: Bool = false - public var checkForEntitlementsCalled: Bool = false - - public init() { } - - public func storeAuthToken(token: String) { - storeAuthTokenCalled = true - onStoreAuthToken?(token) - self.authToken = token - } - - public func storeAccount(token: String, email: String?, externalID: String?) { - storeAccountCalled = true - onStoreAccount?(token, email, externalID) - self.accessToken = token - self.email = email - self.externalID = externalID - } - - public func signOut(skipNotification: Bool) { - signOutCalled = true - self.authToken = nil - self.accessToken = nil - self.email = nil - self.externalID = nil - } - - public func signOut() { - signOutCalled = true - self.authToken = nil - self.accessToken = nil - self.email = nil - self.externalID = nil - } - - public func hasEntitlement(forProductName productName: Entitlement.ProductName, cachePolicy: APICachePolicy) async -> Result { - return .success(true) - } - - public func updateCache(with entitlements: [Entitlement]) { - updateCacheWithEntitlementsCalled = true - } - - public func fetchEntitlements(cachePolicy: APICachePolicy) async -> Result<[Entitlement], Error> { - fetchEntitlementsCalled = true - onFetchEntitlements?(cachePolicy) - return .success([]) - } - - public func exchangeAuthTokenToAccessToken(_ authToken: String) async -> Result { - exchangeAuthTokenToAccessTokenCalled = true - onExchangeAuthTokenToAccessToken?(authToken) - return exchangeAuthTokenToAccessTokenResult! - } - - public func fetchAccountDetails(with accessToken: String) async -> Result { - fetchAccountDetailsCalled = true - onFetchAccountDetails?(accessToken) - return fetchAccountDetailsResult! - } - - public func checkForEntitlements(wait waitTime: Double, retry retryCount: Int) async -> Bool { - checkForEntitlementsCalled = true - return onCheckForEntitlements!(waitTime, retryCount) - } -} diff --git a/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift b/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift index 3217eedbb..da52ba1dc 100644 --- a/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift +++ b/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift @@ -17,63 +17,157 @@ // import Foundation +@testable import Networking @testable import Subscription public final class SubscriptionManagerMock: SubscriptionManager { - public var accountManager: AccountManager - public var subscriptionEndpointService: SubscriptionEndpointService - public var authEndpointService: AuthEndpointService - public var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache + public init() {} - public static var storedEnvironment: SubscriptionEnvironment? - public static func loadEnvironmentFrom(userDefaults: UserDefaults) -> SubscriptionEnvironment? { - return storedEnvironment + public static var environment: Subscription.SubscriptionEnvironment? + public static func loadEnvironmentFrom(userDefaults: UserDefaults) -> Subscription.SubscriptionEnvironment? { + return environment } - public static func save(subscriptionEnvironment: SubscriptionEnvironment, userDefaults: UserDefaults) { - storedEnvironment = subscriptionEnvironment + public static func save(subscriptionEnvironment: Subscription.SubscriptionEnvironment, userDefaults: UserDefaults) { + environment = subscriptionEnvironment } - public var currentEnvironment: SubscriptionEnvironment - public var canPurchase: Bool + public var currentEnvironment: Subscription.SubscriptionEnvironment = .init(serviceEnvironment: .staging, purchasePlatform: .appStore) - public func storePurchaseManager() -> StorePurchaseManager { - internalStorePurchaseManager + public func loadInitialData() {} + + public func refreshCachedSubscription(completion: @escaping (Bool) -> Void) {} + + public var resultSubscription: Subscription.PrivacyProSubscription? + + public func getSubscriptionFrom(lastTransactionJWSRepresentation: String) async throws -> Subscription.PrivacyProSubscription? { + guard let resultSubscription else { + throw OAuthClientError.missingTokens + } + return resultSubscription + } + + public var canPurchase: Bool = true + + public var resultStorePurchaseManager: (any Subscription.StorePurchaseManager)? + public func storePurchaseManager() -> any Subscription.StorePurchaseManager { + return resultStorePurchaseManager! + } + + public var resultURL: URL! + public func url(for type: Subscription.SubscriptionURL) -> URL { + return resultURL + } + + public var customerPortalURL: URL? + public func getCustomerPortalURL() async throws -> URL { + guard let customerPortalURL else { + throw SubscriptionEndpointServiceError.noData + } + return customerPortalURL + } + + public var isUserAuthenticated: Bool { + resultTokenContainer != nil + } + + public var userEmail: String? { + resultTokenContainer?.decodedAccessToken.email + } + + public var resultTokenContainer: Networking.TokenContainer? + public var resultCreateAccountTokenContainer: Networking.TokenContainer? + public func getTokenContainer(policy: Networking.TokensCachePolicy) async throws -> Networking.TokenContainer { + switch policy { + case .local, .localValid, .localForceRefresh: + guard let resultTokenContainer else { + throw OAuthClientError.missingTokens + } + return resultTokenContainer + case .createIfNeeded: + guard let resultCreateAccountTokenContainer else { + throw OAuthClientError.missingTokens + } + resultTokenContainer = resultCreateAccountTokenContainer + return resultCreateAccountTokenContainer + } } - public func loadInitialData() { + public func getTokenContainerSynchronously(policy: Networking.TokensCachePolicy) -> Networking.TokenContainer? { + return resultTokenContainer + } + + public var resultExchangeTokenContainer: Networking.TokenContainer? + public func exchange(tokenV1: String) async throws -> Networking.TokenContainer { + guard let resultExchangeTokenContainer else { + throw OAuthClientError.missingTokens + } + resultTokenContainer = resultExchangeTokenContainer + return resultExchangeTokenContainer + } + + public func signOut(skipNotification: Bool) { + + } + + public func signOut() async { + resultTokenContainer = nil + } + public func removeTokenContainer() { + resultTokenContainer = nil } - public func refreshCachedSubscriptionAndEntitlements(completion: @escaping (Bool) -> Void) { - completion(true) + public func clearSubscriptionCache() { + + } + + public var confirmPurchaseResponse: Result? + public func confirmPurchase(signature: String) async throws -> Subscription.PrivacyProSubscription { + switch confirmPurchaseResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public func refreshAccount() async {} + + public var confirmPurchaseError: Error? + public func confirmPurchase(signature: String) async throws { + if let confirmPurchaseError { + throw confirmPurchaseError + } } - public func url(for type: SubscriptionURL) -> URL { - type.subscriptionURL(environment: currentEnvironment.serviceEnvironment) + public func getSubscription(cachePolicy: Subscription.SubscriptionCachePolicy) async throws -> Subscription.PrivacyProSubscription { + guard let resultSubscription else { + throw SubscriptionEndpointServiceError.noData + } + return resultSubscription } - public func currentSubscriptionFeatures() async -> [Entitlement.ProductName] { - return [] + public var productsResponse: Result<[Subscription.GetProductsItem], Error>? + public func getProducts() async throws -> [Subscription.GetProductsItem] { + switch productsResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } } - public init(accountManager: AccountManager, - subscriptionEndpointService: SubscriptionEndpointService, - authEndpointService: AuthEndpointService, - storePurchaseManager: StorePurchaseManager, - currentEnvironment: SubscriptionEnvironment, - canPurchase: Bool, - subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache) { - self.accountManager = accountManager - self.subscriptionEndpointService = subscriptionEndpointService - self.authEndpointService = authEndpointService - self.internalStorePurchaseManager = storePurchaseManager - self.currentEnvironment = currentEnvironment - self.canPurchase = canPurchase - self.subscriptionFeatureMappingCache = subscriptionFeatureMappingCache + public func adopt(tokenContainer: Networking.TokenContainer) async throws { + self.resultTokenContainer = tokenContainer } - // MARK: - + public var resultFeatures: [Subscription.SubscriptionFeature] = [] + public func currentSubscriptionFeatures(forceRefresh: Bool) async -> [Subscription.SubscriptionFeature] { + resultFeatures + } - let internalStorePurchaseManager: StorePurchaseManager + public func isFeatureActive(_ entitlement: Networking.SubscriptionEntitlement) async -> Bool { + resultFeatures.contains { $0.entitlement == entitlement } + } } diff --git a/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift b/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift index b2a5b8133..8887b69aa 100644 --- a/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift +++ b/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift @@ -18,38 +18,13 @@ import Foundation import Common -import Subscription +@testable import Subscription +import TestUtils public final class SubscriptionCookieManagerMock: SubscriptionCookieManaging { public var lastRefreshDate: Date? - - public convenience init() { - let accountManager = AccountManagerMock() - let subscriptionService = DefaultSubscriptionEndpointService(currentServiceEnvironment: .production) - let authService = DefaultAuthEndpointService(currentServiceEnvironment: .production) - let storePurchaseManager = StorePurchaseManagerMock() - let subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock() - let subscriptionManager = SubscriptionManagerMock(accountManager: accountManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService, - storePurchaseManager: storePurchaseManager, - currentEnvironment: SubscriptionEnvironment(serviceEnvironment: .production, - purchasePlatform: .appStore), - canPurchase: true, - subscriptionFeatureMappingCache: subscriptionFeatureMappingCache) - - self.init(subscriptionManager: subscriptionManager, - currentCookieStore: { return nil }, - eventMapping: MockSubscriptionCookieManagerEventPixelMapping()) - } - - public init(subscriptionManager: SubscriptionManager, - currentCookieStore: @MainActor @escaping () -> HTTPCookieStore?, - eventMapping: EventMapping) { - - } - + public init() {} public func enableSettingSubscriptionCookie() { } public func disableSettingSubscriptionCookie() async { } public func refreshSubscriptionCookie() async { } diff --git a/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift b/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift index ef39c4d04..474c0a9c7 100644 --- a/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift +++ b/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift @@ -18,17 +18,18 @@ import Foundation import Subscription +import Networking public final class SubscriptionFeatureMappingCacheMock: SubscriptionFeatureMappingCache { public var didCallSubscriptionFeatures = false public var lastCalledSubscriptionId: String? - public var mapping: [String: [Entitlement.ProductName]] = [:] + public var mapping: [String: [SubscriptionEntitlement]] = [:] public init() { } - public func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName] { + public func subscriptionFeatures(for subscriptionIdentifier: String) async -> [SubscriptionEntitlement] { didCallSubscriptionFeatures = true lastCalledSubscriptionId = subscriptionIdentifier return mapping[subscriptionIdentifier] ?? [] diff --git a/Sources/SubscriptionTestingUtilities/SubscriptionMockFactory.swift b/Sources/SubscriptionTestingUtilities/SubscriptionMockFactory.swift index 99776fa1e..f508ee0f4 100644 --- a/Sources/SubscriptionTestingUtilities/SubscriptionMockFactory.swift +++ b/Sources/SubscriptionTestingUtilities/SubscriptionMockFactory.swift @@ -22,14 +22,14 @@ import Foundation /// Provides all mocks needed for testing subscription initialised with positive outcomes and basic configurations. All mocks can be partially reconfigured with failures or incorrect data public struct SubscriptionMockFactory { - public static let subscription = Subscription(productId: UUID().uuidString, + public static let subscription = PrivacyProSubscription(productId: UUID().uuidString, name: "Subscription test #1", billingPeriod: .monthly, startedAt: Date(), expiresOrRenewsAt: Date().addingTimeInterval(TimeInterval.days(+30)), platform: .apple, status: .autoRenewable) - public static let expiredSubscription = Subscription(productId: UUID().uuidString, + public static let expiredSubscription = PrivacyProSubscription(productId: UUID().uuidString, name: "Subscription test #2", billingPeriod: .monthly, startedAt: Date().addingTimeInterval(TimeInterval.days(-31)), @@ -37,7 +37,7 @@ public struct SubscriptionMockFactory { platform: .apple, status: .expired) - public static let expiredStripeSubscription = Subscription(productId: UUID().uuidString, + public static let expiredStripeSubscription = PrivacyProSubscription(productId: UUID().uuidString, name: "Subscription test #2", billingPeriod: .monthly, startedAt: Date().addingTimeInterval(TimeInterval.days(-31)), diff --git a/Sources/TestUtils/API/APIMockResponseFactory.swift b/Sources/TestUtils/API/APIMockResponseFactory.swift new file mode 100644 index 000000000..b7713420c --- /dev/null +++ b/Sources/TestUtils/API/APIMockResponseFactory.swift @@ -0,0 +1,152 @@ +// +// APIMockResponseFactory.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +@testable import Subscription +@testable import Networking +import Common + +public struct APIMockResponseFactory { + + static let authCookieHeaders = [ HTTPHeaderKey.setCookie: "ddg_auth_session_id=kADeCPMmCIHIV5uD6AFoB7Fk7pRiXFzlmQE4gW9r7FRKV8OGC1rRnZcTXoa7iIa8qgjiQCqZYq6Caww6k5HJl3; domain=duckduckgo.com; path=/api/auth/v2/; max-age=600; SameSite=Strict; secure; HttpOnly"] + + public static func mockAuthoriseResponse(destinationMockAPIService apiService: MockAPIService, success: Bool) { + let request = OAuthRequest.authorize(baseURL: OAuthEnvironment.staging.url, codeChallenge: "codeChallenge")! + if success { + let httpResponse = HTTPURLResponse(url: request.apiRequest.urlRequest.url!, + statusCode: request.httpSuccessCode.rawValue, + httpVersion: nil, + headerFields: authCookieHeaders)! + let response = APIResponseV2(data: nil, httpResponse: httpResponse) + apiService.set(response: response, forRequest: request.apiRequest) + } else { + let httpResponse = HTTPURLResponse(url: request.apiRequest.urlRequest.url!, + statusCode: request.httpErrorCodes.first!.rawValue, + httpVersion: nil, + headerFields: [:])! + let response = APIResponseV2(data: nil, httpResponse: httpResponse) + apiService.set(response: response, forRequest: request.apiRequest) + } + } + + public static func mockCreateAccountResponse(destinationMockAPIService apiService: MockAPIService, success: Bool) { + let request = OAuthRequest.createAccount(baseURL: OAuthEnvironment.staging.url, authSessionID: "someAuthSessionID")! + if success { + let httpResponse = HTTPURLResponse(url: request.apiRequest.urlRequest.url!, + statusCode: request.httpSuccessCode.rawValue, + httpVersion: nil, + headerFields: [HTTPHeaderKey.location: "com.duckduckgo:/authcb?code=NgNjnlLaqUomt9b5LDbzAtTyeW9cBNhCGtLB3vpcctluSZI51M9tb2ZDIZdijSPTYBr4w8dtVZl85zNSemxozv"])! + let response = APIResponseV2(data: nil, httpResponse: httpResponse) + apiService.set(response: response, forRequest: request.apiRequest) + } else { + assertionFailure("TODO: implement") + } + } + + public static func mockGetAccessTokenResponse(destinationMockAPIService apiService: MockAPIService, success: Bool) { + let request = OAuthRequest.getAccessToken(baseURL: OAuthEnvironment.staging.url, + clientID: "clientID", + codeVerifier: "codeVerifier", + code: "code", + redirectURI: "redirectURI")! + if success { + let jsonString = """ +{"access_token":"eyJraWQiOiIzODJiNzQ5Yy1hNTc3LTRkOTMtOTU0My04NTI5MWZiYTM3MmEiLCJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiJxWHk2TlRjeEI2UkQ0UUtSU05RYkNSM3ZxYU1SQU1RM1Q1UzVtTWdOWWtCOVZTVnR5SHdlb1R4bzcxVG1DYkJKZG1GWmlhUDVWbFVRQnd5V1dYMGNGUjo3ZjM4MTljZi0xNTBmLTRjYjEtOGNjNy1iNDkyMThiMDA2ZTgiLCJzY29wZSI6InByaXZhY3lwcm8iLCJhdWQiOiJQcml2YWN5UHJvIiwic3ViIjoiZTM3NmQ4YzQtY2FhOS00ZmNkLThlODYtMTlhNmQ2M2VlMzcxIiwiZXhwIjoxNzMwMzAxNTcyLCJlbWFpbCI6bnVsbCwiaWF0IjoxNzMwMjg3MTcyLCJpc3MiOiJodHRwczovL3F1YWNrZGV2LmR1Y2tkdWNrZ28uY29tIiwiZW50aXRsZW1lbnRzIjpbXSwiYXBpIjoidjIifQ.wOYgz02TXPJjDcEsp-889Xe1zh6qJG0P1UNHUnFBBELmiWGa91VQpqdl41EOOW3aE89KGvrD8YphRoZKiA3nHg", + "refresh_token":"eyJraWQiOiIzODJiNzQ5Yy1hNTc3LTRkOTMtOTU0My04NTI5MWZiYTM3MmEiLCJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJhcGkiOiJ2MiIsImlzcyI6Imh0dHBzOi8vcXVhY2tkZXYuZHVja2R1Y2tnby5jb20iLCJleHAiOjE3MzI4NzkxNzIsInN1YiI6ImUzNzZkOGM0LWNhYTktNGZjZC04ZTg2LTE5YTZkNjNlZTM3MSIsImF1ZCI6IkF1dGgiLCJpYXQiOjE3MzAyODcxNzIsInNjb3BlIjoicmVmcmVzaCIsImp0aSI6InFYeTZOVGN4QjZSRDRRS1JTTlFiQ1IzdnFhTVJBTVEzVDVTNW1NZ05Za0I5VlNWdHlId2VvVHhvNzFUbUNiQkpkbUZaaWFQNVZsVVFCd3lXV1gwY0ZSOmU2ODkwMDE5LWJmMDUtNGQxZC04OGFhLThlM2UyMDdjOGNkOSJ9.OQaGCmDBbDMM5XIpyY-WCmCLkZxt5Obp4YAmtFP8CerBSRexbUUp6SNwGDjlvCF0-an2REBsrX92ZmQe5ewqyQ","expires_in": 14400,"token_type": "Bearer"} +""" + let httpResponse = HTTPURLResponse(url: request.apiRequest.urlRequest.url!, + statusCode: request.httpSuccessCode.rawValue, + httpVersion: nil, + headerFields: [:])! + let response = APIResponseV2(data: jsonString.data(using: .utf8), httpResponse: httpResponse) + apiService.set(response: response, forRequest: request.apiRequest) + } else { + assertionFailure("TODO: implement") + } + } + + public static func mockGetJWKS(destinationMockAPIService apiService: MockAPIService, success: Bool) { + let request = OAuthRequest.jwks(baseURL: OAuthEnvironment.staging.url)! + if success { + let jsonString = """ +{"keys":[{"alg":"ES256","crv":"P-256","kid":"382b749c-a577-4d93-9543-85291fba372a","kty":"EC","ts":1727109704,"x":"e-WcWXtyf0mzVuc8lzAErb0EYq0kiOj7u8Ia4qsB4z4","y":"2WYzD5-POgIx2_3B_J6u84giGwSwgrYMTj83djMSWxM"},{"crv":"P-256","kid":"aa4c0019-9da9-4143-9866-3f7b54224a46","kty":"EC","ts":1722282670,"x":"kN2BXRyRbylNSaw3CrZKiKdATXjF1RIp2FpOxYMeuWg","y":"wovX-ifQuoKKAi-ZPYFcZ9YBhCxN_Fng3qKSW2wKpdg"}]} +""" + let httpResponse = HTTPURLResponse(url: request.apiRequest.urlRequest.url!, + statusCode: request.httpSuccessCode.rawValue, + httpVersion: nil, + headerFields: [:])! + let response = APIResponseV2(data: jsonString.data(using: .utf8), httpResponse: httpResponse) + apiService.set(response: response, forRequest: request.apiRequest) + } else { + assertionFailure("TODO: implement") + } + } + + public static func mockConfirmPurchase(destinationMockAPIService apiService: MockAPIService, success: Bool) { + let request = SubscriptionRequest.confirmPurchase(baseURL: SubscriptionEnvironment.ServiceEnvironment.staging.url, + accessToken: "somAccessToken", + signature: "someSignature")! + if success { + let jsonString = """ +{"email":"","entitlements":[{"product":"Data Broker Protection","name":"subscriber"},{"product":"Identity Theft Restoration","name":"subscriber"},{"product":"Network Protection","name":"subscriber"}],"subscription":{"productId":"ios.subscription.1month","name":"Monthly Subscription","billingPeriod":"Monthly","startedAt":1730991734000,"expiresOrRenewsAt":1730992034000,"platform":"apple","status":"Auto-Renewable"}} +""" + let httpResponse = HTTPURLResponse(url: request.apiRequest.urlRequest.url!, + statusCode: HTTPStatusCode.ok.rawValue, + httpVersion: nil, + headerFields: [:])! + let response = APIResponseV2(data: jsonString.data(using: .utf8), httpResponse: httpResponse) + apiService.set(response: response, forRequest: request.apiRequest) + } else { + assertionFailure("TODO: implement") + } + } + + public static func mockGetProducts(destinationMockAPIService apiService: MockAPIService, success: Bool) { + let request = SubscriptionRequest.getProducts(baseURL: SubscriptionEnvironment.ServiceEnvironment.staging.url)! + if success { + let jsonString = """ +[{"productId":"ddg-privacy-pro-sandbox-monthly-renews-us","productLabel":"Monthly Subscription","billingPeriod":"Monthly","price":"9.99","currency":"USD"},{"productId":"ddg-privacy-pro-sandbox-yearly-renews-us","productLabel":"Yearly Subscription","billingPeriod":"Yearly","price":"99.99","currency":"USD"}] +""" + let httpResponse = HTTPURLResponse(url: request.apiRequest.urlRequest.url!, + statusCode: HTTPStatusCode.ok.rawValue, + httpVersion: nil, + headerFields: [:])! + let response = APIResponseV2(data: jsonString.data(using: .utf8), httpResponse: httpResponse) + apiService.set(response: response, forRequest: request.apiRequest) + } else { + assertionFailure("TODO: implement") + } + } + + public static func mockGetFeatures(destinationMockAPIService apiService: MockAPIService, success: Bool, subscriptionID: String) { + let request = SubscriptionRequest.subscriptionFeatures(baseURL: SubscriptionEnvironment.ServiceEnvironment.staging.url, subscriptionID: subscriptionID)! + if success { + let jsonString = """ +{"features":["Data Broker Protection","Identity Theft Restoration","Network Protection"]} +""" + let httpResponse = HTTPURLResponse(url: request.apiRequest.urlRequest.url!, + statusCode: HTTPStatusCode.ok.rawValue, + httpVersion: nil, + headerFields: [:])! + let response = APIResponseV2(data: jsonString.data(using: .utf8), httpResponse: httpResponse) + apiService.set(response: response, forRequest: request.apiRequest) + } else { + assertionFailure("TODO: implement") + } + } +} diff --git a/Sources/TestUtils/API/MockAPIService.swift b/Sources/TestUtils/API/MockAPIService.swift new file mode 100644 index 000000000..dfd7a5316 --- /dev/null +++ b/Sources/TestUtils/API/MockAPIService.swift @@ -0,0 +1,66 @@ +// +// MockAPIService.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +@testable import Networking + +public class MockAPIService: APIService { + + public var authorizationRefresherCallback: AuthorizationRefresherCallback? + + /// Dictionary to store mocked responses for specific requests + private var mockResponses: [APIRequestV2: APIResponseV2] = [:] + /// Dictionary to store mocked responses for specific requests by URL + private var mockResponsesByURL: [URL: APIResponseV2] = [:] + /// Request handler + public var requestHandler: ((APIRequestV2) -> Result)? + + public init(requestHandler: ((APIRequestV2) -> Result)? = nil) { + self.requestHandler = requestHandler + } + + public func set(response: APIResponseV2, forRequest request: APIRequestV2) { + mockResponses[request] = response + } + + public func set(response: APIResponseV2, forRequestURL url: URL) { + mockResponsesByURL[url] = response + } + + // Function to fetch response for a given request + public func fetch(request: APIRequestV2) async throws -> APIResponseV2 { + if let requestHandler { + switch requestHandler(request) { + case .success(let result): + return result + case .failure(let error): + throw error + } + } else if let response = mockResponses[request] { + return response + } else { + return mockResponsesByURL[request.urlRequest.url!]! // Intentionally crash if the mock is not available + } + } +} + +public extension APIRequestV2 { + var host: String { + return urlRequest.url!.host! + } +} diff --git a/Sources/TestUtils/MockKeyValueStore.swift b/Sources/TestUtils/MockKeyValueStore.swift index b13963eba..ea4664422 100644 --- a/Sources/TestUtils/MockKeyValueStore.swift +++ b/Sources/TestUtils/MockKeyValueStore.swift @@ -40,7 +40,6 @@ public class MockKeyValueStore: KeyValueStoring { public func clearAll() { store.removeAll() } - } extension MockKeyValueStore: DictionaryRepresentable { diff --git a/Tests/NetworkProtectionTests/Mocks/NetworkProtectionTokenStoreMocks.swift b/Sources/TestUtils/MockLegacyTokenStorage.swift similarity index 57% rename from Tests/NetworkProtectionTests/Mocks/NetworkProtectionTokenStoreMocks.swift rename to Sources/TestUtils/MockLegacyTokenStorage.swift index 4e1228682..e5bb9ae64 100644 --- a/Tests/NetworkProtectionTests/Mocks/NetworkProtectionTokenStoreMocks.swift +++ b/Sources/TestUtils/MockLegacyTokenStorage.swift @@ -1,7 +1,7 @@ // -// NetworkProtectionTokenStoreMocks.swift +// MockLegacyTokenStorage.swift // -// Copyright © 2021 DuckDuckGo. All rights reserved. +// Copyright © 2024 DuckDuckGo. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,25 +17,13 @@ // import Foundation -@testable import NetworkProtection +import Networking -final class NetworkProtectionTokenStoreMock: NetworkProtectionTokenStore { +public class MockLegacyTokenStorage: LegacyTokenStoring { - var token: String? - - func store(_ token: String) { + public init(token: String? = nil) { self.token = token } - func fetchToken() -> String? { - token - } - - func deleteToken() { - self.token = nil - } - - func fetchSubscriptionToken() throws -> String? { - "ddg:accessToken" - } + public var token: String? } diff --git a/Sources/TestUtils/MockOAuthClient.swift b/Sources/TestUtils/MockOAuthClient.swift new file mode 100644 index 000000000..363b1a542 --- /dev/null +++ b/Sources/TestUtils/MockOAuthClient.swift @@ -0,0 +1,139 @@ +// +// MockOAuthClient.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Networking + +public class MockOAuthClient: OAuthClient { + + public init() {} + public var isUserAuthenticated: Bool = false + public var currentTokenContainer: Networking.TokenContainer? + + func missingResponseError(request: String) -> Error { + return Networking.OAuthClientError.internalError("Missing mocked response for \(request)") + } + + public var getTokensResponse: Result! + public func getTokens(policy: Networking.TokensCachePolicy) async throws -> Networking.TokenContainer { + switch getTokensResponse { + case .success(let success): + return success + case .failure(let failure): + throw failure + case .none: + throw missingResponseError(request: #function) + } + } + + public var createAccountResponse: Result! + public func createAccount() async throws -> Networking.TokenContainer { + switch createAccountResponse { + case .success(let success): + return success + case .failure(let failure): + throw failure + case .none: + throw missingResponseError(request: #function) + } + } + + public var requestOTPResponse: Result<(authSessionID: String, codeVerifier: String), Error>! + public func requestOTP(email: String) async throws -> (authSessionID: String, codeVerifier: String) { + switch requestOTPResponse { + case .success(let success): + return success + case .failure(let failure): + throw failure + case .none: + throw missingResponseError(request: #function) + } + } + + public var activateWithOTPError: Error? + public func activate(withOTP otp: String, email: String, codeVerifier: String, authSessionID: String) async throws { + if let activateWithOTPError { + throw activateWithOTPError + } + } + + public var activateWithPlatformSignatureResponse: Result! + public func activate(withPlatformSignature signature: String) async throws -> Networking.TokenContainer { + switch activateWithPlatformSignatureResponse { + case .success(let success): + return success + case .failure(let failure): + throw failure + case .none: + throw missingResponseError(request: #function) + } + } + + public var refreshTokensResponse: Result! + public func refreshTokens() async throws -> Networking.TokenContainer { + switch refreshTokensResponse { + case .success(let success): + return success + case .failure(let failure): + throw failure + case .none: + throw missingResponseError(request: #function) + } + } + + public var exchangeAccessTokenV1Response: Result! + public func exchange(accessTokenV1: String) async throws -> Networking.TokenContainer { + switch exchangeAccessTokenV1Response { + case .success(let success): + return success + case .failure(let failure): + throw failure + case .none: + throw missingResponseError(request: #function) + } + } + + public var logoutError: Error? + public func logout() async throws { + if let logoutError { + throw logoutError + } + } + + public func removeLocalAccount() {} + + public var changeAccountEmailResponse: Result! + public func changeAccount(email: String?) async throws -> String { + switch changeAccountEmailResponse { + case .success(let success): + return success + case .failure(let failure): + throw failure + case .none: + throw missingResponseError(request: #function) + } + } + + public var confirmChangeAccountEmailError: Error? + public func confirmChangeAccount(email: String, otp: String, hash: String) async throws { + if let confirmChangeAccountEmailError { + throw confirmChangeAccountEmailError + } + } + +} diff --git a/Sources/TestUtils/MockOAuthService.swift b/Sources/TestUtils/MockOAuthService.swift new file mode 100644 index 000000000..a14298960 --- /dev/null +++ b/Sources/TestUtils/MockOAuthService.swift @@ -0,0 +1,103 @@ +// +// MockOAuthService.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Networking +import JWTKit + +public final class MockOAuthService: OAuthService { + + public init() {} + + public var authorizeResponse: Result? + public func authorize(codeChallenge: String) async throws -> Networking.OAuthSessionID { + switch authorizeResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public var createAccountResponse: Result? + public func createAccount(authSessionID: String) async throws -> Networking.AuthorisationCode { + switch createAccountResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public var loginWithSignatureResponse: Result? + public func login(withSignature signature: String, authSessionID: String) async throws -> Networking.AuthorisationCode { + switch loginWithSignatureResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public var getAccessTokenResponse: Result? + public func getAccessToken(clientID: String, codeVerifier: String, code: String, redirectURI: String) async throws -> Networking.OAuthTokenResponse { + switch getAccessTokenResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public var refreshAccessTokenResponse: Result? + public func refreshAccessToken(clientID: String, refreshToken: String) async throws -> Networking.OAuthTokenResponse { + switch refreshAccessTokenResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public var logoutError: Error? + public func logout(accessToken: String) async throws { + if let logoutError { + throw logoutError + } + } + + public var exchangeTokenResponse: Result? + public func exchangeToken(accessTokenV1: String, authSessionID: String) async throws -> Networking.AuthorisationCode { + switch exchangeTokenResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public var getJWTSignersResponse: Result? + public func getJWTSigners() async throws -> JWTKit.JWTSigners { + switch getJWTSignersResponse! { + case .success(let result): + return result + case .failure(let error): + throw error + } + } +} diff --git a/Sources/TestUtils/MockAPIService.swift b/Sources/TestUtils/MockTokenStorage.swift similarity index 55% rename from Sources/TestUtils/MockAPIService.swift rename to Sources/TestUtils/MockTokenStorage.swift index f4d35b4b6..58efde776 100644 --- a/Sources/TestUtils/MockAPIService.swift +++ b/Sources/TestUtils/MockTokenStorage.swift @@ -1,5 +1,5 @@ // -// MockAPIService.swift +// MockTokenStorage.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -19,20 +19,11 @@ import Foundation import Networking -public class MockAPIService: APIService { +public class MockTokenStorage: TokenStoring { - public var requestHandler: ((APIRequestV2) -> Result)! - - public init(requestHandler: ((APIRequestV2) -> Result)? = nil) { - self.requestHandler = requestHandler + public init(tokenContainer: Networking.TokenContainer? = nil) { + self.tokenContainer = tokenContainer } - public func fetch(request: APIRequestV2) async throws -> APIResponseV2 { - switch requestHandler!(request) { - case .success(let result): - return result - case .failure(let error): - throw error - } - } + public var tokenContainer: Networking.TokenContainer? } diff --git a/Sources/TestUtils/OAuthTokensFactory.swift b/Sources/TestUtils/OAuthTokensFactory.swift new file mode 100644 index 000000000..8fcf4dc40 --- /dev/null +++ b/Sources/TestUtils/OAuthTokensFactory.swift @@ -0,0 +1,134 @@ +// +// OAuthTokensFactory.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +@testable import Networking +@testable import JWTKit + +public struct OAuthTokensFactory { + + // Helper function to create an expired JWTAccessToken + public static func makeExpiredAccessToken() -> JWTAccessToken { + return JWTAccessToken( + exp: ExpirationClaim(value: Date().addingTimeInterval(-3600)), // Expired 1 hour ago + iat: IssuedAtClaim(value: Date().addingTimeInterval(-7200)), + sub: SubjectClaim(value: "test-subject"), + aud: AudienceClaim(value: ["test-audience"]), + iss: IssuerClaim(value: "test-issuer"), + jti: IDClaim(value: "test-id"), + scope: "privacypro", + api: "v2", + email: "test@example.com", + entitlements: [] + ) + } + + // Helper function to create a valid JWTAccessToken with customizable scope + public static func makeAccessToken(scope: String, email: String = "test@example.com") -> JWTAccessToken { + return JWTAccessToken( + exp: ExpirationClaim(value: Date().addingTimeInterval(3600)), // 1 hour from now + iat: IssuedAtClaim(value: Date()), + sub: SubjectClaim(value: "test-subject"), + aud: AudienceClaim(value: ["test-audience"]), + iss: IssuerClaim(value: "test-issuer"), + jti: IDClaim(value: "test-id"), + scope: scope, + api: "v2", + email: email, + entitlements: [] + ) + } + + // Helper function to create a valid JWTRefreshToken with customizable scope + public static func makeRefreshToken(scope: String) -> JWTRefreshToken { + return JWTRefreshToken( + exp: ExpirationClaim(value: Date().addingTimeInterval(3600)), + iat: IssuedAtClaim(value: Date()), + sub: SubjectClaim(value: "test-subject"), + aud: AudienceClaim(value: ["test-audience"]), + iss: IssuerClaim(value: "test-issuer"), + jti: IDClaim(value: "test-id"), + scope: scope, + api: "v2" + ) + } + + public static func makeValidTokenContainer() -> TokenContainer { + return TokenContainer(accessToken: "accessToken", + refreshToken: "refreshToken", + decodedAccessToken: OAuthTokensFactory.makeAccessToken(scope: "privacypro"), + decodedRefreshToken: OAuthTokensFactory.makeRefreshToken(scope: "refresh")) + } + + public static func makeValidTokenContainerWithEntitlements() -> TokenContainer { + return TokenContainer(accessToken: "accessToken", + refreshToken: "refreshToken", + decodedAccessToken: JWTAccessToken.mock, + decodedRefreshToken: JWTRefreshToken.mock) + } + + public static func makeExpiredTokenContainer() -> TokenContainer { + return TokenContainer(accessToken: "accessToken", + refreshToken: "refreshToken", + decodedAccessToken: OAuthTokensFactory.makeExpiredAccessToken(), + decodedRefreshToken: OAuthTokensFactory.makeRefreshToken(scope: "refresh")) + } + + public static func makeExpiredOAuthTokenResponse() -> OAuthTokenResponse { + return OAuthTokenResponse(accessToken: "eyJraWQiOiIzODJiNzQ5Yy1hNTc3LTRkOTMtOTU0My04NTI5MWZiYTM3MmEiLCJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiJxWHk2TlRjeEI2UkQ0UUtSU05RYkNSM3ZxYU1SQU1RM1Q1UzVtTWdOWWtCOVZTVnR5SHdlb1R4bzcxVG1DYkJKZG1GWmlhUDVWbFVRQnd5V1dYMGNGUjo3ZjM4MTljZi0xNTBmLTRjYjEtOGNjNy1iNDkyMThiMDA2ZTgiLCJzY29wZSI6InByaXZhY3lwcm8iLCJhdWQiOiJQcml2YWN5UHJvIiwic3ViIjoiZTM3NmQ4YzQtY2FhOS00ZmNkLThlODYtMTlhNmQ2M2VlMzcxIiwiZXhwIjoxNzMwMzAxNTcyLCJlbWFpbCI6bnVsbCwiaWF0IjoxNzMwMjg3MTcyLCJpc3MiOiJodHRwczovL3F1YWNrZGV2LmR1Y2tkdWNrZ28uY29tIiwiZW50aXRsZW1lbnRzIjpbXSwiYXBpIjoidjIifQ.wOYgz02TXPJjDcEsp-889Xe1zh6qJG0P1UNHUnFBBELmiWGa91VQpqdl41EOOW3aE89KGvrD8YphRoZKiA3nHg", + refreshToken: "eyJraWQiOiIzODJiNzQ5Yy1hNTc3LTRkOTMtOTU0My04NTI5MWZiYTM3MmEiLCJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJhcGkiOiJ2MiIsImlzcyI6Imh0dHBzOi8vcXVhY2tkZXYuZHVja2R1Y2tnby5jb20iLCJleHAiOjE3MzI4NzkxNzIsInN1YiI6ImUzNzZkOGM0LWNhYTktNGZjZC04ZTg2LTE5YTZkNjNlZTM3MSIsImF1ZCI6IkF1dGgiLCJpYXQiOjE3MzAyODcxNzIsInNjb3BlIjoicmVmcmVzaCIsImp0aSI6InFYeTZOVGN4QjZSRDRRS1JTTlFiQ1IzdnFhTVJBTVEzVDVTNW1NZ05Za0I5VlNWdHlId2VvVHhvNzFUbUNiQkpkbUZaaWFQNVZsVVFCd3lXV1gwY0ZSOmU2ODkwMDE5LWJmMDUtNGQxZC04OGFhLThlM2UyMDdjOGNkOSJ9.OQaGCmDBbDMM5XIpyY-WCmCLkZxt5Obp4YAmtFP8CerBSRexbUUp6SNwGDjlvCF0-an2REBsrX92ZmQe5ewqyQ") + } + + public static func makeValidOAuthTokenResponse() -> OAuthTokenResponse { + return OAuthTokenResponse(accessToken: "**validaccesstoken**", refreshToken: "**validrefreshtoken**") + } +} + +public extension JWTAccessToken { + + static var mock: Self { + let now = Date() + return JWTAccessToken(exp: ExpirationClaim(value: now.addingTimeInterval(3600)), + iat: IssuedAtClaim(value: now), + sub: SubjectClaim(value: "test-subject"), + aud: AudienceClaim(value: ["PrivacyPro"]), + iss: IssuerClaim(value: "test-issuer"), + jti: IDClaim(value: "test-id"), + scope: "privacypro", + api: "v2", + email: nil, + entitlements: [EntitlementPayload(product: .networkProtection, name: "subscriber"), + EntitlementPayload(product: .dataBrokerProtection, name: "subscriber"), + EntitlementPayload(product: .identityTheftRestoration, name: "subscriber")]) + } +} + +public extension JWTRefreshToken { + + static var mock: Self { + let now = Date() + return JWTRefreshToken(exp: ExpirationClaim(value: now.addingTimeInterval(3600)), + iat: IssuedAtClaim(value: now), + sub: SubjectClaim(value: "test-subject"), + aud: AudienceClaim(value: ["PrivacyPro"]), + iss: IssuerClaim(value: "test-issuer"), + jti: IDClaim(value: "test-id"), + scope: "privacypro", + api: "v2") + } +} diff --git a/Sources/UserScript/UserScriptMessaging.swift b/Sources/UserScript/UserScriptMessaging.swift index 4eb7dcade..2a08bf164 100644 --- a/Sources/UserScript/UserScriptMessaging.swift +++ b/Sources/UserScript/UserScriptMessaging.swift @@ -218,7 +218,7 @@ public final class UserScriptMessageBroker: NSObject { /// As far as the client is concerned, a `notification` is fire-and-forget case .notify(let handler, let notification): do { - _=try await handler(notification.params, original) + _ = try await handler(notification.params, original) } catch { Logger.general.error("UserScriptMessaging: unhandled exception \(error.localizedDescription, privacy: .public)") } diff --git a/Tests/BrowserServicesKit-Package.xctestplan b/Tests/BrowserServicesKit-Package.xctestplan new file mode 100644 index 000000000..14179517b --- /dev/null +++ b/Tests/BrowserServicesKit-Package.xctestplan @@ -0,0 +1,213 @@ +{ + "configurations" : [ + { + "id" : "CEDD46E5-DAEC-407E-B790-8A23D5B18D80", + "name" : "Configuration 1", + "options" : { + + } + } + ], + "defaultOptions" : { + + }, + "testTargets" : [ + { + "target" : { + "containerPath" : "container:", + "identifier" : "ConfigurationTests", + "name" : "ConfigurationTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PageRefreshMonitorTests", + "name" : "PageRefreshMonitorTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "BookmarksTests", + "name" : "BookmarksTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "BrokenSitePromptTests", + "name" : "BrokenSitePromptTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "NetworkProtectionTests", + "name" : "NetworkProtectionTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "DDGSyncCryptoTests", + "name" : "DDGSyncCryptoTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "NavigationTests", + "name" : "NavigationTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "RemoteMessagingTests", + "name" : "RemoteMessagingTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "CrashesTests", + "name" : "CrashesTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "DDGSyncTests", + "name" : "DDGSyncTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "CommonTests", + "name" : "CommonTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SecureStorageTests", + "name" : "SecureStorageTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SyncDataProvidersTests", + "name" : "SyncDataProvidersTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "DuckPlayerTests", + "name" : "DuckPlayerTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "MaliciousSiteProtectionTests", + "name" : "MaliciousSiteProtectionTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SubscriptionTests", + "name" : "SubscriptionTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "NetworkingTests", + "name" : "NetworkingTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PrivacyStatsTests", + "name" : "PrivacyStatsTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PixelExperimentKitTests", + "name" : "PixelExperimentKitTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PersistenceTests", + "name" : "PersistenceTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "HistoryTests", + "name" : "HistoryTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "UserScriptTests", + "name" : "UserScriptTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PrivacyDashboardTests", + "name" : "PrivacyDashboardTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PixelKitTests", + "name" : "PixelKitTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "BrowserServicesKitTests", + "name" : "BrowserServicesKitTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "OnboardingTests", + "name" : "OnboardingTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SuggestionsTests", + "name" : "SuggestionsTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SpecialErrorPagesTests", + "name" : "SpecialErrorPagesTests" + } + } + ], + "version" : 1 +} diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift b/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift index 9388b14e9..ae03c6bf9 100644 --- a/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift +++ b/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift @@ -434,31 +434,38 @@ class AppPrivacyConfigurationTests: XCTestCase { // When valid number of installed days (less than or equal to 21): // 0 days - let installDate0DaysAgo = Date().addingTimeInterval(-60 * 60 * 24 * 0) + let installDate0DaysAgo = Date() config = createPrivacyConfigWithInstallDate(mockEmbeddedData, mockProtectionStore, installDate: installDate0DaysAgo) XCTAssertTrue(config.isEnabled(featureKey: .incontextSignup, versionProvider: appVersion)) // 1 day - let installDate1DayAgo = Date().addingTimeInterval(-60 * 60 * 24 * 1) + let installDate1DayAgo = Date().addingTimeInterval(TimeInterval.days(-1)) config = createPrivacyConfigWithInstallDate(mockEmbeddedData, mockProtectionStore, installDate: installDate1DayAgo) XCTAssertTrue(config.isEnabled(featureKey: .incontextSignup, versionProvider: appVersion)) // 20 days (1 day less than config) - let installDate20DaysAgo = Date().addingTimeInterval(-60 * 60 * 24 * 20) + let installDate20DaysAgo = Date().addingTimeInterval(TimeInterval.days(-20)) config = createPrivacyConfigWithInstallDate(mockEmbeddedData, mockProtectionStore, installDate: installDate20DaysAgo) XCTAssertTrue(config.isEnabled(featureKey: .incontextSignup, versionProvider: appVersion)) // 21 days (same as config) - let installDate21DaysAgo = Date().addingTimeInterval(-60 * 60 * 24 * 21) + let installDate21DaysAgo = Date().addingTimeInterval(TimeInterval.days(-21)) config = createPrivacyConfigWithInstallDate(mockEmbeddedData, mockProtectionStore, installDate: installDate21DaysAgo) XCTAssertTrue(config.isEnabled(featureKey: .incontextSignup, versionProvider: appVersion)) // When invalid number of installed days (> 21 days): - // 22 days (1 day more than config) - let installDate22DaysAgo = Date().addingTimeInterval(-60 * 60 * 24 * 22) - config = createPrivacyConfigWithInstallDate(mockEmbeddedData, mockProtectionStore, installDate: installDate22DaysAgo) +// // 22 days (1 day more than config) ! not working in different timezones + may have some issues with daytime saving +// let installDate22DaysAgo = Date().addingTimeInterval(TimeInterval.days(-22)) +// config = createPrivacyConfigWithInstallDate(mockEmbeddedData, mockProtectionStore, installDate: installDate22DaysAgo) +// XCTAssertFalse(config.isEnabled(featureKey: .incontextSignup, versionProvider: appVersion)) +// XCTAssertEqual(config.stateFor(featureKey: .incontextSignup, versionProvider: appVersion), .disabled(.tooOldInstallation), "22 days ago should be too old") + + // 23 days (1 day more than config) + let installDate23DaysAgo = Date().addingTimeInterval(TimeInterval.days(-23)) + config = createPrivacyConfigWithInstallDate(mockEmbeddedData, mockProtectionStore, installDate: installDate23DaysAgo) XCTAssertFalse(config.isEnabled(featureKey: .incontextSignup, versionProvider: appVersion)) - XCTAssertEqual(config.stateFor(featureKey: .incontextSignup, versionProvider: appVersion), .disabled(.tooOldInstallation)) + XCTAssertEqual(config.stateFor(featureKey: .incontextSignup, versionProvider: appVersion), .disabled(.tooOldInstallation), "23 days ago should be too old") + // 444 days (many days more than config) - let installDate444DaysAgo = Date().addingTimeInterval(-60 * 60 * 24 * 444) + let installDate444DaysAgo = Date().addingTimeInterval(TimeInterval.days(-444)) config = createPrivacyConfigWithInstallDate(mockEmbeddedData, mockProtectionStore, installDate: installDate444DaysAgo) XCTAssertFalse(config.isEnabled(featureKey: .incontextSignup, versionProvider: appVersion)) XCTAssertEqual(config.stateFor(featureKey: .incontextSignup, versionProvider: appVersion), .disabled(.tooOldInstallation)) diff --git a/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests b/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests index 6133e7d9d..a603ff9af 160000 --- a/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests +++ b/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests @@ -1 +1 @@ -Subproject commit 6133e7d9d9cd5f1b925cab1971b4d785dc639df7 +Subproject commit a603ff9af22ca3ff7ce2e7ffbfe18c447d9f23e8 diff --git a/Tests/CommonTests/DecodableHelperTests.swift b/Tests/CommonTests/DecodableHelperTests.swift index b17f5a21d..0b6fdb637 100644 --- a/Tests/CommonTests/DecodableHelperTests.swift +++ b/Tests/CommonTests/DecodableHelperTests.swift @@ -26,19 +26,19 @@ final class DecodableHelperTests: XCTestCase { func testWhenDecodingDictionary_ThenValueIsReturned() { let dictionary = ["name": "dax"] - let person: Person? = DecodableHelper.decode(from: dictionary) + let person: Person? = CodableHelper.decode(from: dictionary) XCTAssertEqual("dax", person?.name) } func testWhenDecodingAny_ThenValueIsReturned() { let data = ["name": "dax"] as Any - let person: Person? = DecodableHelper.decode(from: data) + let person: Person? = CodableHelper.decode(from: data) XCTAssertEqual("dax", person?.name) } func testWhenDecodingFails_ThenNilIsReturned() { let data = ["oops_name": "dax"] as Any - let person: Person? = DecodableHelper.decode(from: data) + let person: Person? = CodableHelper.decode(from: data) XCTAssertNil(person) } } diff --git a/Tests/CommonTests/Extensions/DateExtensionTest.swift b/Tests/CommonTests/Extensions/DateExtensionTest.swift new file mode 100644 index 000000000..5c6b00601 --- /dev/null +++ b/Tests/CommonTests/Extensions/DateExtensionTest.swift @@ -0,0 +1,213 @@ +// +// DateExtensionTest.swift +// +// Copyright © 2022 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +@testable import Common + +final class DateExtensionTests: XCTestCase { + + func testComponents() { + let date = Date() + let components = date.components + + XCTAssertNotNil(components.day) + XCTAssertNotNil(components.month) + XCTAssertNotNil(components.year) + } + + func testWeekAgo() { + let weekAgo = Date.weekAgo + let expectedDate = Calendar.current.date(byAdding: .weekOfMonth, value: -1, to: Date())! + + XCTAssertEqual(weekAgo.startOfDay, expectedDate.startOfDay) + } + + func testMonthAgo() { + let monthAgo = Date.monthAgo + let expectedDate = Calendar.current.date(byAdding: .month, value: -1, to: Date())! + + XCTAssertEqual(monthAgo.startOfDay, expectedDate.startOfDay) + } + + func testYearAgo() { + let yearAgo = Date.yearAgo + let expectedDate = Calendar.current.date(byAdding: .year, value: -1, to: Date())! + + XCTAssertEqual(yearAgo.startOfDay, expectedDate.startOfDay) + } + + func testAYearFromNow() { + let aYearFromNow = Date.aYearFromNow + let expectedDate = Calendar.current.date(byAdding: .year, value: 1, to: Date())! + + XCTAssertEqual(aYearFromNow.startOfDay, expectedDate.startOfDay) + } + + func testDaysAgo() { + let daysAgo = Date.daysAgo(5) + let expectedDate = Calendar.current.date(byAdding: .day, value: -5, to: Date())! + + XCTAssertEqual(daysAgo.startOfDay, expectedDate.startOfDay) + } + + func testIsSameDay() { + let today = Date() + let sameDay = today + let differentDay = Calendar.current.date(byAdding: .day, value: -1, to: today)! + + XCTAssertTrue(Date.isSameDay(today, sameDay)) + XCTAssertFalse(Date.isSameDay(today, differentDay)) + XCTAssertFalse(Date.isSameDay(today, nil)) + } + + func testStartOfDayTomorrow() { + let startOfDayTomorrow = Date.startOfDayTomorrow + let tomorrow = Calendar.current.date(byAdding: .day, value: 1, to: Date())! + + XCTAssertEqual(startOfDayTomorrow, Calendar.current.startOfDay(for: tomorrow)) + } + + func testStartOfDayToday() { + let startOfDayToday = Date.startOfDayToday + XCTAssertEqual(startOfDayToday, Calendar.current.startOfDay(for: Date())) + } + + func testStartOfDay() { + let date = Date() + let startOfDay = date.startOfDay + + XCTAssertEqual(startOfDay, Calendar.current.startOfDay(for: date)) + } + + func testDaysAgoInstanceMethod() { + let date = Date() + let daysAgo = date.daysAgo(3) + let expectedDate = Calendar.current.date(byAdding: .day, value: -3, to: date)! + + XCTAssertEqual(daysAgo.startOfDay, expectedDate.startOfDay) + } + + func testStartOfMinuteNow() { + let startOfMinuteNow = Date.startOfMinuteNow + let now = Calendar.current.date(bySetting: .second, value: 0, of: Date())! + let expectedStart = Calendar.current.date(byAdding: .minute, value: -1, to: now)! + + XCTAssertEqual(startOfMinuteNow, expectedStart) + } + + func testMonthsWithIndex() { + let monthsWithIndex = Date.monthsWithIndex + let monthSymbols = Calendar.current.monthSymbols + + XCTAssertEqual(monthsWithIndex.count, 12) + XCTAssertEqual(monthsWithIndex.first?.name, monthSymbols.first) + XCTAssertEqual(monthsWithIndex.first?.index, 1) + } + + func testDaysInMonth() { + XCTAssertEqual(Date.daysInMonth, Array(1...31)) + } + + func testNextTenYears() { + let nextTenYears = Date.nextTenYears + let currentYear = Calendar.current.component(.year, from: Date()) + + XCTAssertEqual(nextTenYears.count, 11) + XCTAssertEqual(nextTenYears.first, currentYear) + XCTAssertEqual(nextTenYears.last, currentYear + 10) + } + + func testLastHundredYears() { + let lastHundredYears = Date.lastHundredYears + let currentYear = Calendar.current.component(.year, from: Date()) + + XCTAssertEqual(lastHundredYears.count, 101) + XCTAssertEqual(lastHundredYears.first, currentYear) + XCTAssertEqual(lastHundredYears.last, currentYear - 100) + } + + func testDaySinceReferenceDate() { + let date = Date() + let daysSinceReference = Int(date.timeIntervalSinceReferenceDate / TimeInterval.day) + + XCTAssertEqual(date.daySinceReferenceDate, daysSinceReference) + } + + func testAdding() { + let date = Date() + let addedDate = date.adding(60) + + XCTAssertEqual(addedDate.timeIntervalSince(date), 60) + } + + func testIsSameDayInstanceMethod() { + let today = Date() + let sameDay = today + let differentDay = Calendar.current.date(byAdding: .day, value: -1, to: today)! + + XCTAssertTrue(today.isSameDay(sameDay)) + XCTAssertFalse(today.isSameDay(differentDay)) + XCTAssertFalse(today.isSameDay(nil)) + } + + func testIsLessThanDaysAgo() { + let recentDate = Calendar.current.date(byAdding: .day, value: -2, to: Date())! + let olderDate = Calendar.current.date(byAdding: .day, value: -5, to: Date())! + + XCTAssertTrue(recentDate.isLessThan(daysAgo: 3)) + XCTAssertFalse(olderDate.isLessThan(daysAgo: 3)) + } + + func testIsLessThanMinutesAgo() { + let recentDate = Calendar.current.date(byAdding: .minute, value: -10, to: Date())! + let olderDate = Calendar.current.date(byAdding: .minute, value: -30, to: Date())! + + XCTAssertTrue(recentDate.isLessThan(minutesAgo: 15)) + XCTAssertFalse(olderDate.isLessThan(minutesAgo: 15)) + } + + func testSecondsSinceNow() { + let date = Calendar.current.date(byAdding: .second, value: -30, to: Date())! + XCTAssertEqual(date.secondsSinceNow(), 30) + } + + func testMinutesSinceNow() { + let date = Calendar.current.date(byAdding: .minute, value: -10, to: Date())! + XCTAssertEqual(date.minutesSinceNow(), 10) + } + + func testHoursSinceNow() { + let date = Calendar.current.date(byAdding: .hour, value: -5, to: Date())! + XCTAssertEqual(date.hoursSinceNow(), 5) + } + + func testDaysSinceNow() { + let date = Calendar.current.date(byAdding: .day, value: -7, to: Date())! + XCTAssertEqual(date.daysSinceNow(), 7) + } + + func testMonthsSinceNow() { + let date = Calendar.current.date(byAdding: .month, value: -3, to: Date())! + XCTAssertEqual(date.monthsSinceNow(), 3) + } + + func testYearsSinceNow() { + let date = Calendar.current.date(byAdding: .year, value: -2, to: Date())! + XCTAssertEqual(date.yearsSinceNow(), 2) + } +} diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift index fcea80939..17756b0e9 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift @@ -16,7 +16,7 @@ // limitations under the License. // import Foundation -import Networking +@testable import Networking import TestUtils import XCTest @@ -30,7 +30,7 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { override func setUp() { super.setUp() mockService = MockAPIService() - client = .init(environment: MaliciousSiteDetector.APIEnvironment.staging, service: mockService) + client = MaliciousSiteProtection.APIClient(environment: MaliciousSiteDetector.APIEnvironment.staging, service: mockService) } override func tearDown() { @@ -45,7 +45,10 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { let deleteFilter = Filter(hash: "6a929cd0b3ba4677eaedf1b2bdaf3ff89281cca94f688c83103bc9a676aea46d", regex: "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?") let expectedResponse = APIClient.Response.FiltersChangeSet(insert: [insertFilter], delete: [deleteFilter], revision: 666, replace: false) mockService.requestHandler = { [unowned self] in - XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .filterSet(.init(threatKind: .phishing, revision: 666)))) + let resultURL = $0.urlRequest.url! + let expectedQueryItems = client.environment.queryItems(for: .filterSet(APIRequestType.FilterSet(threatKind: .phishing, revision: 666))) + let expectedURL = client.environment.url(for: .filterSet(APIRequestType.FilterSet(threatKind: .phishing, revision: 666))).appending(queryItems: expectedQueryItems.toURLQueryItems()) + XCTAssertEqual(resultURL, expectedURL) let data = try? JSONEncoder().encode(expectedResponse) let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! return .success(.init(data: data, httpResponse: response)) @@ -62,7 +65,10 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { // Given let expectedResponse = APIClient.Response.HashPrefixesChangeSet(insert: ["abc"], delete: ["def"], revision: 1, replace: false) mockService.requestHandler = { [unowned self] in - XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .hashPrefixSet(.init(threatKind: .phishing, revision: 1)))) + let resultURL = $0.urlRequest.url + let expectedQueryItems = client.environment.queryItems(for: .hashPrefixSet(.init(threatKind: .phishing, revision: 1))) + let expectedURL = client.environment.url(for: .hashPrefixSet(.init(threatKind: .phishing, revision: 1))).appending(queryItems: expectedQueryItems.toURLQueryItems()) + XCTAssertEqual(resultURL, expectedURL) let data = try? JSONEncoder().encode(expectedResponse) let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! return .success(.init(data: data, httpResponse: response)) @@ -79,7 +85,10 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { // Given let expectedResponse = APIClient.Response.Matches(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil)]) mockService.requestHandler = { [unowned self] in - XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .matches(.init(hashPrefix: "abc")))) + let resultURL = $0.urlRequest.url + let expectedQueryItems = client.environment.queryItems(for: .matches(.init(hashPrefix: "abc"))) + let expectedURL = client.environment.url(for: .matches(.init(hashPrefix: "abc"))).appending(queryItems: expectedQueryItems.toURLQueryItems()) + XCTAssertEqual(resultURL, expectedURL) let data = try? JSONEncoder().encode(expectedResponse) let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! return .success(.init(data: data, httpResponse: response)) diff --git a/Tests/NetworkProtectionTests/Mocks/MockSubscriptionTokenProvider.swift b/Tests/NetworkProtectionTests/Mocks/MockSubscriptionTokenProvider.swift new file mode 100644 index 000000000..63ee1afec --- /dev/null +++ b/Tests/NetworkProtectionTests/Mocks/MockSubscriptionTokenProvider.swift @@ -0,0 +1,77 @@ +// +// MockSubscriptionTokenProvider.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Networking +import Subscription + +public class MockSubscriptionTokenProvider: SubscriptionTokenProvider { + public var tokenResult: Result? + + public func getTokenContainer(policy: Networking.TokensCachePolicy) async throws -> Networking.TokenContainer { + guard let tokenResult = tokenResult else { + throw OAuthClientError.missingTokens + } + switch tokenResult { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public func getTokenContainerSynchronously(policy: Networking.TokensCachePolicy) -> Networking.TokenContainer? { + guard let tokenResult = tokenResult else { + return nil + } + switch tokenResult { + case .success(let result): + return result + case .failure: + return nil + } + } + + public func exchange(tokenV1: String) async throws -> Networking.TokenContainer { + guard let tokenResult = tokenResult else { + throw OAuthClientError.missingTokens + } + switch tokenResult { + case .success(let result): + return result + case .failure(let error): + throw error + } + } + + public func adopt(tokenContainer: Networking.TokenContainer) async throws { + guard let tokenResult = tokenResult else { + throw OAuthClientError.missingTokens + } + switch tokenResult { + case .success: + return + case .failure(let error): + throw error + } + } + + public func removeTokenContainer() { + tokenResult = nil + } +} diff --git a/Tests/NetworkProtectionTests/NetworkProtectionDeviceManagerTests.swift b/Tests/NetworkProtectionTests/NetworkProtectionDeviceManagerTests.swift index 21b28e346..745f714db 100644 --- a/Tests/NetworkProtectionTests/NetworkProtectionDeviceManagerTests.swift +++ b/Tests/NetworkProtectionTests/NetworkProtectionDeviceManagerTests.swift @@ -20,9 +20,12 @@ import Foundation import XCTest @testable import NetworkProtection @testable import NetworkProtectionTestUtils +@testable import Networking +@testable import Subscription +import TestUtils final class NetworkProtectionDeviceManagerTests: XCTestCase { - var tokenStore: NetworkProtectionTokenStoreMock! + var tokenProvider: MockSubscriptionTokenProvider! var keyStore: NetworkProtectionKeyStoreMock! var networkClient: MockNetworkProtectionClient! var temporaryURL: URL! @@ -30,22 +33,22 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { override func setUp() { super.setUp() - tokenStore = NetworkProtectionTokenStoreMock() - tokenStore.token = "initialtoken" + tokenProvider = MockSubscriptionTokenProvider() + tokenProvider.tokenResult = .success(OAuthTokensFactory.makeValidTokenContainer()) keyStore = NetworkProtectionKeyStoreMock() networkClient = MockNetworkProtectionClient() temporaryURL = temporaryFileURL() manager = NetworkProtectionDeviceManager( networkClient: networkClient, - tokenStore: tokenStore, + tokenProvider: tokenProvider, keyStore: keyStore, errorEvents: nil ) } override func tearDown() { - tokenStore = nil + tokenProvider = nil keyStore = nil temporaryURL = nil manager = nil @@ -108,25 +111,16 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { XCTAssertEqual(networkClient.spyRegister?.requestBody.server, server.serverName) } - func testWhenGeneratingTunnelConfig_storedAuthTokenIsInvalidOnGettingServers_deletesToken() async { + func testWhenGeneratingTunnelConfig_storedAuthTokenIsInvalidOnGettingServers_deletesToken() async throws { _ = NetworkProtectionServer.mockRegisteredServer networkClient.stubRegister = .failure(.invalidAuthToken) - XCTAssertNotNil(tokenStore.token) + tokenProvider.tokenResult = .success(OAuthTokensFactory.makeValidTokenContainerWithEntitlements()) _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) - XCTAssertNil(tokenStore.token) - } - - func testWhenGeneratingTunnelConfig_storedAuthTokenIsInvalidOnRegisteringServer_deletesToken() async { - networkClient.stubRegister = .failure(.invalidAuthToken) - - XCTAssertNotNil(tokenStore.token) - - _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) - - XCTAssertNil(tokenStore.token) + let tokens = try? await tokenProvider.getTokenContainer(policy: .local) + XCTAssertNil(tokens) } func testDecodingServers() throws { @@ -210,12 +204,10 @@ extension NetworkProtectionDeviceManager { func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod, regenerateKey: Bool) async throws -> NetworkProtectionDeviceManager.GenerateTunnelConfigurationResult { - try await generateTunnelConfiguration( - resolvedSelectionMethod: selectionMethod, - excludeLocalNetworks: false, - dnsSettings: .default, - regenerateKey: regenerateKey - ) + try await generateTunnelConfiguration(resolvedSelectionMethod: selectionMethod, + excludeLocalNetworks: false, + dnsSettings: .default, + regenerateKey: regenerateKey) } } diff --git a/Tests/NetworkProtectionTests/Repositories/NetworkProtectionLocationListCompositeRepositoryTests.swift b/Tests/NetworkProtectionTests/Repositories/NetworkProtectionLocationListCompositeRepositoryTests.swift index 57f992da4..80c1a1f62 100644 --- a/Tests/NetworkProtectionTests/Repositories/NetworkProtectionLocationListCompositeRepositoryTests.swift +++ b/Tests/NetworkProtectionTests/Repositories/NetworkProtectionLocationListCompositeRepositoryTests.swift @@ -21,20 +21,23 @@ import XCTest @testable import NetworkProtection @testable import NetworkProtectionTestUtils import Common +@testable import Subscription +@testable import Networking +import TestUtils class NetworkProtectionLocationListCompositeRepositoryTests: XCTestCase { var repository: NetworkProtectionLocationListCompositeRepository! var client: MockNetworkProtectionClient! - var tokenStore: MockNetworkProtectionTokenStorage! + var tokenProvider: MockSubscriptionTokenProvider! var verifyErrorEvent: ((NetworkProtectionError) -> Void)? override func setUp() { super.setUp() client = MockNetworkProtectionClient() - tokenStore = MockNetworkProtectionTokenStorage() + tokenProvider = MockSubscriptionTokenProvider() repository = NetworkProtectionLocationListCompositeRepository( client: client, - tokenStore: tokenStore, + tokenProvider: tokenProvider, errorEvents: .init { [weak self] event, _, _, _ in self?.verifyErrorEvent?(event) }) @@ -44,13 +47,12 @@ class NetworkProtectionLocationListCompositeRepositoryTests: XCTestCase { override func tearDown() { NetworkProtectionLocationListCompositeRepository.clearCache() client = nil - tokenStore = nil + tokenProvider = nil repository = nil super.tearDown() } func testFetchLocationList_firstCall_fetchesAndReturnsList() async throws { - let expectedToken = "aToken" let expectedList: [NetworkProtectionLocation] = [ .testData(country: "US", cities: [ .testData(name: "New York"), @@ -58,21 +60,22 @@ class NetworkProtectionLocationListCompositeRepositoryTests: XCTestCase { ]) ] client.stubGetLocations = .success(expectedList) - tokenStore.stubFetchToken = expectedToken + let tokenContainer = OAuthTokensFactory.makeValidTokenContainerWithEntitlements() + tokenProvider.tokenResult = .success(tokenContainer) let locations = try await repository.fetchLocationList() - XCTAssertEqual(expectedToken, client.spyGetLocationsAuthToken) + XCTAssertEqual("ddg:\(tokenContainer.accessToken)", client.spyGetLocationsAuthToken) XCTAssertEqual(expectedList, locations) } func testFetchLocationList_secondCall_returnsCachedList() async throws { - let expectedToken = "aToken" let expectedList: [NetworkProtectionLocation] = [ .testData(country: "DE", cities: [ .testData(name: "Berlin") ]) ] client.stubGetLocations = .success(expectedList) - tokenStore.stubFetchToken = expectedToken + let tokenContainer = OAuthTokensFactory.makeValidTokenContainer() + tokenProvider.tokenResult = .success(tokenContainer) _ = try await repository.fetchLocationList() client.spyGetLocationsAuthToken = nil let locations = try await repository.fetchLocationList() @@ -83,7 +86,7 @@ class NetworkProtectionLocationListCompositeRepositoryTests: XCTestCase { func testFetchLocationList_noAuthToken_throwsError() async throws { client.stubGetLocations = .success([.testData()]) - tokenStore.stubFetchToken = nil + tokenProvider.tokenResult = .failure(OAuthClientError.missingTokens) var errorResult: NetworkProtectionError? do { _ = try await repository.fetchLocationList() @@ -101,7 +104,7 @@ class NetworkProtectionLocationListCompositeRepositoryTests: XCTestCase { func testFetchLocationList_noAuthToken_sendsErrorEvent() async { client.stubGetLocations = .success([.testData()]) - tokenStore.stubFetchToken = nil + tokenProvider.tokenResult = .failure(OAuthClientError.missingTokens) var didReceiveError: Bool = false verifyErrorEvent = { error in didReceiveError = true diff --git a/Tests/NetworkProtectionTests/StartupOptionTests.swift b/Tests/NetworkProtectionTests/StartupOptionTests.swift index 5211305dd..909872d2c 100644 --- a/Tests/NetworkProtectionTests/StartupOptionTests.swift +++ b/Tests/NetworkProtectionTests/StartupOptionTests.swift @@ -32,7 +32,7 @@ final class StartupOptionsTests: XCTestCase { let rawOptions = [String: Any]() let options = StartupOptions(options: rawOptions) - XCTAssertEqual(options.authToken, .useExisting) + XCTAssertEqual(options.tokenContainer, .useExisting) XCTAssertEqual(options.enableTester, .useExisting) XCTAssertEqual(options.keyValidity, .useExisting) XCTAssertFalse(options.simulateCrash) @@ -54,7 +54,7 @@ final class StartupOptionsTests: XCTestCase { ] let options = StartupOptions(options: rawOptions) - XCTAssertEqual(options.authToken, .reset) + XCTAssertEqual(options.tokenContainer, .reset) XCTAssertEqual(options.enableTester, .reset) XCTAssertEqual(options.keyValidity, .reset) XCTAssertFalse(options.simulateCrash) @@ -75,7 +75,7 @@ final class StartupOptionsTests: XCTestCase { ] let options = StartupOptions(options: rawOptions) - XCTAssertEqual(options.authToken, .useExisting) + XCTAssertEqual(options.tokenContainer, .useExisting) XCTAssertEqual(options.enableTester, .useExisting) XCTAssertEqual(options.keyValidity, .useExisting) XCTAssertFalse(options.simulateCrash) diff --git a/Tests/NetworkingTests/OAuth/.swift b/Tests/NetworkingTests/OAuth/.swift new file mode 100644 index 000000000..72cf8d685 --- /dev/null +++ b/Tests/NetworkingTests/OAuth/.swift @@ -0,0 +1,18 @@ +// +// Untitled.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + diff --git a/Tests/NetworkingTests/OAuth/OAuthClientTests.swift b/Tests/NetworkingTests/OAuth/OAuthClientTests.swift new file mode 100644 index 000000000..01a50cc2b --- /dev/null +++ b/Tests/NetworkingTests/OAuth/OAuthClientTests.swift @@ -0,0 +1,251 @@ +// +// OAuthClientTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +import TestUtils +@testable import Networking +import JWTKit + +final class OAuthClientTests: XCTestCase { + + var oAuthClient: DefaultOAuthClient! + var mockOAuthService: MockOAuthService! + var tokenStorage: MockTokenStorage! + var legacyTokenStorage: MockLegacyTokenStorage! + + override func setUp() async throws { + mockOAuthService = MockOAuthService() + tokenStorage = MockTokenStorage() + legacyTokenStorage = MockLegacyTokenStorage() + oAuthClient = DefaultOAuthClient(tokensStorage: tokenStorage, + legacyTokenStorage: legacyTokenStorage, + authService: mockOAuthService) + } + + override func tearDown() async throws { + mockOAuthService = nil + oAuthClient = nil + tokenStorage = nil + legacyTokenStorage = nil + } + + // MARK: - + + func testUserNotAuthenticated() async throws { + XCTAssertFalse(oAuthClient.isUserAuthenticated) + } + + func testUserAuthenticated() async throws { + tokenStorage.tokenContainer = OAuthTokensFactory.makeValidTokenContainer() + XCTAssertTrue(oAuthClient.isUserAuthenticated) + } + + func testCurrentTokenContainer() async throws { + XCTAssertNil(oAuthClient.currentTokenContainer) + tokenStorage.tokenContainer = OAuthTokensFactory.makeValidTokenContainer() + XCTAssertNotNil(oAuthClient.currentTokenContainer) + } + + // MARK: - Get tokens + + // MARK: Local + + func testGetToken_Local_Fail() async throws { + let localContainer = try? await oAuthClient.getTokens(policy: .local) + XCTAssertNil(localContainer) + } + + func testGetToken_Local_Success() async throws { + tokenStorage.tokenContainer = OAuthTokensFactory.makeValidTokenContainer() + + let localContainer = try? await oAuthClient.getTokens(policy: .local) + XCTAssertNotNil(localContainer) + XCTAssertFalse(localContainer!.decodedAccessToken.isExpired()) + } + + func testGetToken_Local_SuccessExpired() async throws { + tokenStorage.tokenContainer = OAuthTokensFactory.makeExpiredTokenContainer() + + let localContainer = try? await oAuthClient.getTokens(policy: .local) + XCTAssertNotNil(localContainer) + XCTAssertTrue(localContainer!.decodedAccessToken.isExpired()) + } + + // MARK: Local Valid + + /// A valid local token exists + func testGetToken_localValid_local() async throws { + + tokenStorage.tokenContainer = OAuthTokensFactory.makeValidTokenContainer() + + let localContainer = try await oAuthClient.getTokens(policy: .localValid) + XCTAssertNotNil(localContainer.accessToken) + XCTAssertNotNil(localContainer.refreshToken) + XCTAssertNotNil(localContainer.decodedAccessToken) + XCTAssertNotNil(localContainer.decodedRefreshToken) + XCTAssertFalse(localContainer.decodedAccessToken.isExpired()) + } + + /// An expired local token exists and is refreshed successfully + func testGetToken_localValid_refreshSuccess() async throws { + + mockOAuthService.getJWTSignersResponse = .success(JWTSigners()) + mockOAuthService.refreshAccessTokenResponse = .success( OAuthTokensFactory.makeValidOAuthTokenResponse()) + tokenStorage.tokenContainer = OAuthTokensFactory.makeExpiredTokenContainer() + + oAuthClient.testingDecodedTokenContainer = TokenContainer(accessToken: "accessToken", + refreshToken: "refreshToken", + decodedAccessToken: JWTAccessToken.mock, + decodedRefreshToken: JWTRefreshToken.mock) + + let localContainer = try await oAuthClient.getTokens(policy: .localValid) + XCTAssertNotNil(localContainer.accessToken) + XCTAssertNotNil(localContainer.refreshToken) + XCTAssertNotNil(localContainer.decodedAccessToken) + XCTAssertNotNil(localContainer.decodedRefreshToken) + XCTAssertFalse(localContainer.decodedAccessToken.isExpired()) + } + + /// An expired local token exists but refresh fails + func testGetToken_localValid_refreshFail() async throws { + + mockOAuthService.getJWTSignersResponse = .success(JWTSigners()) + mockOAuthService.refreshAccessTokenResponse = .failure(OAuthServiceError.invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + tokenStorage.tokenContainer = OAuthTokensFactory.makeExpiredTokenContainer() + + do { + _ = try await oAuthClient.getTokens(policy: .localValid) + XCTFail("Error expected") + } catch { + XCTAssertEqual(error as? OAuthServiceError, .invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + } + } + + // MARK: Force Refresh + + /// Local token is missing, refresh fails + func testGetToken_localForceRefresh_missingLocal() async throws { + do { + _ = try await oAuthClient.getTokens(policy: .localForceRefresh) + XCTFail("Error expected") + } catch { + XCTAssertEqual(error as? Networking.OAuthClientError, .missingRefreshToken) + } + } + + /// An expired local token exists and is refreshed successfully + func testGetToken_localForceRefresh_success() async throws { + + mockOAuthService.getJWTSignersResponse = .success(JWTSigners()) + mockOAuthService.refreshAccessTokenResponse = .success( OAuthTokensFactory.makeValidOAuthTokenResponse()) + tokenStorage.tokenContainer = OAuthTokensFactory.makeExpiredTokenContainer() + + oAuthClient.testingDecodedTokenContainer = TokenContainer(accessToken: "accessToken", + refreshToken: "refreshToken", + decodedAccessToken: JWTAccessToken.mock, + decodedRefreshToken: JWTRefreshToken.mock) + + let localContainer = try await oAuthClient.getTokens(policy: .localForceRefresh) + XCTAssertNotNil(localContainer.accessToken) + XCTAssertNotNil(localContainer.refreshToken) + XCTAssertNotNil(localContainer.decodedAccessToken) + XCTAssertNotNil(localContainer.decodedRefreshToken) + XCTAssertFalse(localContainer.decodedAccessToken.isExpired()) + } + + func testGetToken_localForceRefresh_refreshFail() async throws { + + mockOAuthService.getJWTSignersResponse = .success(JWTSigners()) + mockOAuthService.refreshAccessTokenResponse = .failure(OAuthServiceError.invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + tokenStorage.tokenContainer = OAuthTokensFactory.makeExpiredTokenContainer() + + do { + _ = try await oAuthClient.getTokens(policy: .localForceRefresh) + XCTFail("Error expected") + } catch { + XCTAssertEqual(error as? OAuthServiceError, .invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + } + } + + // MARK: Create if needed + + func testGetToken_createIfNeeded_foundLocal() async throws { + tokenStorage.tokenContainer = OAuthTokensFactory.makeValidTokenContainer() + + let tokenContainer = try await oAuthClient.getTokens(policy: .createIfNeeded) + XCTAssertNotNil(tokenContainer.accessToken) + XCTAssertNotNil(tokenContainer.refreshToken) + XCTAssertNotNil(tokenContainer.decodedAccessToken) + XCTAssertNotNil(tokenContainer.decodedRefreshToken) + XCTAssertFalse(tokenContainer.decodedAccessToken.isExpired()) + } + + func testGetToken_createIfNeeded_missingLocal_createSuccess() async throws { + mockOAuthService.authorizeResponse = .success("auth_session_id") + mockOAuthService.createAccountResponse = .success("auth_code") + mockOAuthService.getAccessTokenResponse = .success(OAuthTokensFactory.makeValidOAuthTokenResponse()) + + oAuthClient.testingDecodedTokenContainer = TokenContainer(accessToken: "accessToken", + refreshToken: "refreshToken", + decodedAccessToken: JWTAccessToken.mock, + decodedRefreshToken: JWTRefreshToken.mock) + + let tokenContainer = try await oAuthClient.getTokens(policy: .createIfNeeded) + XCTAssertNotNil(tokenContainer.accessToken) + XCTAssertNotNil(tokenContainer.refreshToken) + XCTAssertNotNil(tokenContainer.decodedAccessToken) + XCTAssertNotNil(tokenContainer.decodedRefreshToken) + XCTAssertFalse(tokenContainer.decodedAccessToken.isExpired()) + } + + func testGetToken_createIfNeeded_missingLocal_createFail() async throws { + mockOAuthService.authorizeResponse = .failure(OAuthServiceError.invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + + do { + _ = try await oAuthClient.getTokens(policy: .createIfNeeded) + XCTFail("Error expected") + } catch { + XCTAssertEqual(error as? OAuthServiceError, .invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + } + } + + func testGetToken_createIfNeeded_missingLocal_createFail2() async throws { + mockOAuthService.authorizeResponse = .success("auth_session_id") + mockOAuthService.createAccountResponse = .failure(OAuthServiceError.invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + + do { + _ = try await oAuthClient.getTokens(policy: .createIfNeeded) + XCTFail("Error expected") + } catch { + XCTAssertEqual(error as? OAuthServiceError, .invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + } + } + + func testGetToken_createIfNeeded_missingLocal_createFail3() async throws { + mockOAuthService.authorizeResponse = .success("auth_session_id") + mockOAuthService.createAccountResponse = .success("auth_code") + mockOAuthService.getAccessTokenResponse = .failure(OAuthServiceError.invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + + do { + _ = try await oAuthClient.getTokens(policy: .createIfNeeded) + XCTFail("Error expected") + } catch { + XCTAssertEqual(error as? OAuthServiceError, .invalidResponseCode(HTTPStatusCode.gatewayTimeout)) + } + } +} diff --git a/Tests/NetworkingTests/OAuth/OAuthServiceTests.swift b/Tests/NetworkingTests/OAuth/OAuthServiceTests.swift new file mode 100644 index 000000000..eab540c4a --- /dev/null +++ b/Tests/NetworkingTests/OAuth/OAuthServiceTests.swift @@ -0,0 +1,82 @@ +// +// OAuthServiceTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +import TestUtils +@testable import Networking + +final class AuthServiceTests: XCTestCase { + + let baseURL = OAuthEnvironment.staging.url + + override func setUpWithError() throws { + /* + var mockedApiService = MockAPIService(decodableResponse: <#T##Result#>, + apiResponse: <#T##Result<(data: Data?, httpResponse: HTTPURLResponse), any Error>#>) + */ + } + + override func tearDownWithError() throws { + // Put teardown code here. This method is called after the invocation of each test method in the class. + } + + var realAPISService: APIService { + let configuration = URLSessionConfiguration.default + configuration.httpCookieStorage = nil + configuration.requestCachePolicy = .reloadIgnoringLocalCacheData + let urlSession = URLSession(configuration: configuration, + delegate: SessionDelegate(), + delegateQueue: nil) + return DefaultAPIService(urlSession: urlSession) + } + + // MARK: - REAL tests, useful for development and debugging but disabled for normal testing + + func disabled_test_real_AuthoriseSuccess() async throws { + let authService = DefaultOAuthService(baseURL: baseURL, apiService: realAPISService) + let codeChallenge = OAuthCodesGenerator.codeChallenge(codeVerifier: OAuthCodesGenerator.codeVerifier)! + let result = try await authService.authorize(codeChallenge: codeChallenge) + XCTAssertNotNil(result) + } + + func disabled_test_real_AuthoriseFailure() async throws { + let authService = DefaultOAuthService(baseURL: baseURL, apiService: realAPISService) + do { + _ = try await authService.authorize(codeChallenge: "") + } catch { + switch error { + case OAuthServiceError.authAPIError(let code): + XCTAssertEqual(code.rawValue, "invalid_authorization_request") + XCTAssertEqual(code.description, "One or more of the required parameters are missing or any provided parameters have invalid values") + default: + XCTFail("Wrong error") + } + } + } + + func disabled_test_real_GetJWTSigner() async throws { + let authService = DefaultOAuthService(baseURL: baseURL, apiService: realAPISService) + let signer = try await authService.getJWTSigners() + do { + let _: JWTAccessToken = try signer.verify("sdfgdsdzfgsdf") + XCTFail("Should have thrown an error") + } catch { + XCTAssertNotNil(error) + } + } +} diff --git a/Tests/NetworkingTests/OAuth/TokenContainerTests.swift b/Tests/NetworkingTests/OAuth/TokenContainerTests.swift new file mode 100644 index 000000000..a3375a72e --- /dev/null +++ b/Tests/NetworkingTests/OAuth/TokenContainerTests.swift @@ -0,0 +1,139 @@ +// +// TokenContainerTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +import JWTKit +@testable import Networking +import TestUtils + +final class TokenContainerTests: XCTestCase { + + // Test expired access token + func testExpiredAccessToken() { + let token = OAuthTokensFactory.makeExpiredAccessToken() + XCTAssertTrue(token.isExpired(), "Expected token to be expired.") + } + + // Test invalid scope in access token + func testAccessTokenInvalidScope() { + let token = OAuthTokensFactory.makeAccessToken(scope: "invalid-scope") + XCTAssertThrowsError(try token.verify(using: .hs256(key: "secret"))) { error in + XCTAssertEqual(error as? TokenPayloadError, .invalidTokenScope, "Expected invalidTokenScope error.") + } + } + + // Test invalid scope in refresh token + func testRefreshTokenInvalidScope() { + let token = OAuthTokensFactory.makeRefreshToken(scope: "invalid-scope") + XCTAssertThrowsError(try token.verify(using: .hs256(key: "secret"))) { error in + XCTAssertEqual(error as? TokenPayloadError, .invalidTokenScope, "Expected invalidTokenScope error.") + } + } + + // Test valid scope in access token + func testAccessTokenValidScope() { + let token = OAuthTokensFactory.makeAccessToken(scope: "privacypro") + XCTAssertNoThrow(try token.verify(using: .hs256(key: "secret")), "Expected no error for valid scope.") + } + + // Test valid scope in refresh token + func testRefreshTokenValidScope() { + let token = OAuthTokensFactory.makeRefreshToken(scope: "refresh") + XCTAssertNoThrow(try token.verify(using: .hs256(key: "secret")), "Expected no error for valid scope.") + } + + // Test entitlements with multiple types, including unsupported + func testSubscriptionEntitlements() { + let entitlements = [ + EntitlementPayload(product: .networkProtection, name: "subscriber"), + EntitlementPayload(product: .unknown, name: "subscriber") + ] + let token = JWTAccessToken( + exp: ExpirationClaim(value: Date().addingTimeInterval(3600)), + iat: IssuedAtClaim(value: Date()), + sub: SubjectClaim(value: "test-subject"), + aud: AudienceClaim(value: ["test-audience"]), + iss: IssuerClaim(value: "test-issuer"), + jti: IDClaim(value: "test-id"), + scope: "privacypro", + api: "v2", + email: "test@example.com", + entitlements: entitlements + ) + + XCTAssertEqual(token.subscriptionEntitlements, [.networkProtection, .unknown], "Expected mixed entitlements including unknown.") + XCTAssertTrue(token.hasEntitlement(.networkProtection), "Expected entitlement for networkProtection.") + XCTAssertFalse(token.hasEntitlement(.identityTheftRestoration), "Expected no entitlement for identityTheftRestoration.") + } + + // Test equatability of TokenContainer with same tokens but different fields + func testTokenContainerEquatabilitySameTokens() { + let accessToken = "same-access-token" + let refreshToken = "same-refresh-token" + + let container1 = TokenContainer( + accessToken: accessToken, + refreshToken: refreshToken, + decodedAccessToken: OAuthTokensFactory.makeAccessToken(scope: "privacypro"), + decodedRefreshToken: OAuthTokensFactory.makeRefreshToken(scope: "refresh") + ) + + let container2 = TokenContainer( + accessToken: accessToken, + refreshToken: refreshToken, + decodedAccessToken: OAuthTokensFactory.makeAccessToken(scope: "privacypro"), + decodedRefreshToken: OAuthTokensFactory.makeRefreshToken(scope: "refresh") + ) + + XCTAssertEqual(container1, container2, "Expected containers with identical tokens to be equal.") + } + + // Test equatability of TokenContainer with same token values but different decoded content + func testTokenContainerEquatabilityDifferentContent() { + let accessToken = "same-access-token" + let refreshToken = "same-refresh-token" + + let container1 = TokenContainer( + accessToken: accessToken, + refreshToken: refreshToken, + decodedAccessToken: OAuthTokensFactory.makeAccessToken(scope: "privacypro"), + decodedRefreshToken: OAuthTokensFactory.makeRefreshToken(scope: "refresh") + ) + + let modifiedAccessToken = OAuthTokensFactory.makeAccessToken(scope: "privacypro", email: "modified@example.com") // Changing a field in decoded token + + let container2 = TokenContainer( + accessToken: accessToken, + refreshToken: refreshToken, + decodedAccessToken: modifiedAccessToken, + decodedRefreshToken: OAuthTokensFactory.makeRefreshToken(scope: "refresh") + ) + + XCTAssertEqual(container1, container2, "Expected containers with identical tokens but different decoded content to be equal.") + } + + func testEncodeDecodeData() throws { + let container = OAuthTokensFactory.makeValidTokenContainer() + let tokenContainer = try TokenContainer(with: container.data!) + XCTAssertEqual(container, tokenContainer, "Expected decoded token container to be equal to original.") + XCTAssertEqual(container.accessToken, tokenContainer.accessToken) + XCTAssertEqual(container.refreshToken, tokenContainer.refreshToken) + XCTAssertEqual(container.decodedAccessToken, tokenContainer.decodedAccessToken) + XCTAssertEqual(container.decodedRefreshToken, tokenContainer.decodedRefreshToken) + } +} diff --git a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift index 4ec1b8b59..59eeadebb 100644 --- a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift +++ b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift @@ -41,18 +41,21 @@ final class APIRequestV2Tests: XCTestCase { cachePolicy: cachePolicy, responseConstraints: constraints) - let urlRequest = apiRequest.urlRequest + guard let urlRequest = apiRequest?.urlRequest else { + XCTFail("Nil URLRequest") + return + } XCTAssertEqual(urlRequest.url?.host(), url.host()) XCTAssertEqual(urlRequest.httpMethod, method.rawValue) let urlComponents = URLComponents(string: urlRequest.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value"))) + XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems())) XCTAssertEqual(urlRequest.allHTTPHeaderFields, headers.httpHeaders) XCTAssertEqual(urlRequest.httpBody, body) - XCTAssertEqual(apiRequest.timeoutInterval, timeoutInterval) + XCTAssertEqual(apiRequest?.timeoutInterval, timeoutInterval) XCTAssertEqual(urlRequest.cachePolicy, cachePolicy) - XCTAssertEqual(apiRequest.responseConstraints, constraints) + XCTAssertEqual(apiRequest?.responseConstraints, constraints) } func testURLRequestGeneration() { @@ -72,16 +75,16 @@ final class APIRequestV2Tests: XCTestCase { timeoutInterval: timeoutInterval, cachePolicy: cachePolicy) - let urlComponents = URLComponents(string: apiRequest.urlRequest.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value"))) + let urlComponents = URLComponents(string: apiRequest!.urlRequest.url!.absoluteString)! + XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems())) XCTAssertNotNil(apiRequest) - XCTAssertEqual(apiRequest.urlRequest.url?.absoluteString, "https://www.example.com?key=value") - XCTAssertEqual(apiRequest.urlRequest.httpMethod, method.rawValue) - XCTAssertEqual(apiRequest.urlRequest.allHTTPHeaderFields, headers.httpHeaders) - XCTAssertEqual(apiRequest.urlRequest.httpBody, body) - XCTAssertEqual(apiRequest.urlRequest.timeoutInterval, timeoutInterval) - XCTAssertEqual(apiRequest.urlRequest.cachePolicy, cachePolicy) + XCTAssertEqual(apiRequest?.urlRequest.url?.absoluteString, "https://www.example.com?key=value") + XCTAssertEqual(apiRequest?.urlRequest.httpMethod, method.rawValue) + XCTAssertEqual(apiRequest?.urlRequest.allHTTPHeaderFields, headers.httpHeaders) + XCTAssertEqual(apiRequest?.urlRequest.httpBody, body) + XCTAssertEqual(apiRequest?.urlRequest.timeoutInterval, timeoutInterval) + XCTAssertEqual(apiRequest?.urlRequest.cachePolicy, cachePolicy) } func testDefaultValues() { @@ -89,13 +92,16 @@ final class APIRequestV2Tests: XCTestCase { let apiRequest = APIRequestV2(url: url) let headers = APIRequestV2.HeadersV2() - let urlRequest = apiRequest.urlRequest + guard let urlRequest = apiRequest?.urlRequest else { + XCTFail("Nil URLRequest") + return + } XCTAssertEqual(urlRequest.httpMethod, HTTPRequestMethod.get.rawValue) XCTAssertEqual(urlRequest.timeoutInterval, 60.0) XCTAssertEqual(headers.httpHeaders, urlRequest.allHTTPHeaderFields) XCTAssertNil(urlRequest.httpBody) XCTAssertEqual(urlRequest.cachePolicy.rawValue, 0) - XCTAssertNil(apiRequest.responseConstraints) + XCTAssertNil(apiRequest?.responseConstraints) } func testAllowedQueryReservedCharacters() { @@ -106,10 +112,9 @@ final class APIRequestV2Tests: XCTestCase { queryItems: queryItems, allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) - let urlString = apiRequest.urlRequest.url!.absoluteString - XCTAssertEqual(urlString, "https://www.example.com?k%23e,y=val%23ue") - + let urlString = apiRequest!.urlRequest.url!.absoluteString + XCTAssertTrue(urlString == "https://www.example.com?k%2523e,y=val%2523ue") let urlComponents = URLComponents(string: urlString)! - XCTAssertEqual(urlComponents.queryItems?.count, 1) + XCTAssertTrue(urlComponents.queryItems?.count == 1) } } diff --git a/Tests/NetworkingTests/v2/APIServiceTests.swift b/Tests/NetworkingTests/v2/APIServiceTests.swift index 730d6afbb..394ec2949 100644 --- a/Tests/NetworkingTests/v2/APIServiceTests.swift +++ b/Tests/NetworkingTests/v2/APIServiceTests.swift @@ -31,7 +31,6 @@ final class APIServiceTests: XCTestCase { // MARK: - Real API calls, do not enable func disabled_testRealFull() async throws { -// func testRealFull() async throws { let request = APIRequestV2(url: HTTPURLResponse.testUrl, method: .post, queryItems: ["Query,Item1%Name": "Query,Item1%Value"], @@ -41,7 +40,7 @@ final class APIServiceTests: XCTestCase { cachePolicy: .reloadIgnoringLocalAndRemoteCacheData, responseConstraints: [APIResponseConstraints.allowHTTPNotModified, APIResponseConstraints.requireETagHeader], - allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) + allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))! let apiService = DefaultAPIService() let response = try await apiService.fetch(request: request) let responseHTML: String = try response.decodeBody() @@ -50,7 +49,7 @@ final class APIServiceTests: XCTestCase { func disabled_testRealCallJSON() async throws { // func testRealCallJSON() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl) + let request = APIRequestV2(url: HTTPURLResponse.testUrl)! let apiService = DefaultAPIService() let result = try await apiService.fetch(request: request) @@ -63,28 +62,30 @@ final class APIServiceTests: XCTestCase { func disabled_testRealCallString() async throws { // func testRealCallString() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl) + let request = APIRequestV2(url: HTTPURLResponse.testUrl)! let apiService = DefaultAPIService() let result = try await apiService.fetch(request: request) XCTAssertNotNil(result) } + // MARK: - + func testQueryItems() async throws { let qItems = ["qName1": "qValue1", "qName2": "qValue2"] MockURLProtocol.requestHandler = { request in let urlComponents = URLComponents(string: request.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(qItems.map { URLQueryItem(name: $0.key, value: $0.value) })) + XCTAssertTrue(urlComponents.queryItems!.contains(qItems.toURLQueryItems())) return (HTTPURLResponse.ok, nil) } - let request = APIRequestV2(url: HTTPURLResponse.testUrl, queryItems: qItems) + let request = APIRequestV2(url: HTTPURLResponse.testUrl, queryItems: qItems)! let apiService = DefaultAPIService(urlSession: mockURLSession) _ = try await apiService.fetch(request: request) } func testURLRequestError() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl) + let request = APIRequestV2(url: HTTPURLResponse.testUrl)! enum TestError: Error { case anError @@ -110,7 +111,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementAllowHTTPNotModifiedSuccess() async throws { let requirements = [APIResponseConstraints.allowHTTPNotModified ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) } @@ -121,7 +122,7 @@ final class APIServiceTests: XCTestCase { } func testResponseRequirementAllowHTTPNotModifiedFailure() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl) + let request = APIRequestV2(url: HTTPURLResponse.testUrl)! MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) } @@ -146,7 +147,7 @@ final class APIServiceTests: XCTestCase { let requirements: [APIResponseConstraints] = [ APIResponseConstraints.requireETagHeader ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } // HTTPURLResponse.ok contains etag let apiService = DefaultAPIService(urlSession: mockURLSession) @@ -157,7 +158,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireETagHeaderFailure() async throws { let requirements = [ APIResponseConstraints.requireETagHeader ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okNoEtag, nil) } @@ -180,7 +181,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireUserAgentSuccess() async throws { let requirements = [ APIResponseConstraints.requireUserAgent ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okUserAgent, nil) @@ -193,7 +194,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireUserAgentFailure() async throws { let requirements = [ APIResponseConstraints.requireUserAgent ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } @@ -212,4 +213,39 @@ final class APIServiceTests: XCTestCase { } } + // MARK: - Retry + + func testRetry() async throws { + let request = APIRequestV2(url: HTTPURLResponse.testUrl, retryPolicy: APIRequestV2.RetryPolicy(maxRetries: 3))! + let requestCountExpectation = expectation(description: "Request performed count") + requestCountExpectation.expectedFulfillmentCount = 4 + + MockURLProtocol.requestHandler = { request in + requestCountExpectation.fulfill() + return ( HTTPURLResponse.internalServerError, nil) + } + + let apiService = DefaultAPIService(urlSession: mockURLSession) + _ = try? await apiService.fetch(request: request) + + await fulfillment(of: [requestCountExpectation], timeout: 1.0) + } + + func testNoRetry() async throws { + let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let requestCountExpectation = expectation(description: "Request performed count") + requestCountExpectation.expectedFulfillmentCount = 1 + + MockURLProtocol.requestHandler = { request in + requestCountExpectation.fulfill() + return ( HTTPURLResponse.internalServerError, nil) + } + + let apiService = DefaultAPIService(urlSession: mockURLSession) + do { + _ = try await apiService.fetch(request: request) + } + + await fulfillment(of: [requestCountExpectation], timeout: 1.0) + } } diff --git a/Tests/NetworkingTests/v2/Extensions/DictionaryURLQueryItemsTests.swift b/Tests/NetworkingTests/v2/Extensions/DictionaryURLQueryItemsTests.swift new file mode 100644 index 000000000..4d99babce --- /dev/null +++ b/Tests/NetworkingTests/v2/Extensions/DictionaryURLQueryItemsTests.swift @@ -0,0 +1,113 @@ +// +// DictionaryURLQueryItemsTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +@testable import Networking + +final class DictionaryURLQueryItemsTests: XCTestCase { + + func queryParam(withName name: String, from queryItems: [URLQueryItem]) -> URLQueryItem { + return queryItems.compactMap({ queryItem in + if queryItem.name == name { + return queryItem + } else { + return nil + } + }).last! + } + + func testBasicKeyValuePairsConversion() { + let dict: [String: String] = ["key1": "value1", + "key2": "value2"] + let queryItems = dict.toURLQueryItems() + + XCTAssertEqual(queryItems.count, 2) + let q0 = queryParam(withName: "key1", from: queryItems) + XCTAssertEqual(q0.name, "key1") + XCTAssertEqual(q0.value, "value1") + + let q1 = queryParam(withName: "key2", from: queryItems) + XCTAssertEqual(q1.name, "key2") + XCTAssertEqual(q1.value, "value2") + } + + func testReservedCharactersAreEncoded() { + let dict: [String: String] = ["query": "value with spaces", + "special": "value/with/slash"] + let queryItems = dict.toURLQueryItems() + + XCTAssertEqual(queryItems.count, 2) + let q1 = queryParam(withName: "query", from: queryItems) + XCTAssertEqual(q1.name, "query") + XCTAssertEqual(q1.value, "value with spaces") + + let q2 = queryParam(withName: "special", from: queryItems) + XCTAssertEqual(q2.name, "special") + XCTAssertEqual(q2.value, "value/with/slash") + } + + func testReservedCharactersNotEncodedWhenAllowedCharacterSetProvided() { + let dict: [String: String] = ["specialKey": "value/with/slash"] + let allowedCharacters = CharacterSet.urlPathAllowed + let queryItems = dict.toURLQueryItems(allowedReservedCharacters: allowedCharacters) + + XCTAssertEqual(queryItems.count, 1) + XCTAssertEqual(queryItems[0].name, "specialKey") + XCTAssertEqual(queryItems[0].value, "value/with/slash") // '/' should be preserved + } + + func testEmptyDictionaryReturnsEmptyQueryItems() { + let dict: [String: String] = [:] + let queryItems = dict.toURLQueryItems() + + XCTAssertEqual(queryItems.count, 0) + } + + func testPercentEncodingWithCustomCharacterSet() { + let dict: [String: String] = ["key": "value with spaces & symbols!"] + let allowedCharacters = CharacterSet.punctuationCharacters.union(.whitespaces) + let queryItems = dict.toURLQueryItems(allowedReservedCharacters: allowedCharacters) + + XCTAssertEqual(queryItems.count, 1) + XCTAssertEqual(queryItems[0].name, "key") + XCTAssertEqual(queryItems[0].value, "value with spaces & symbols!") + } + + func testMultipleItemsWithReservedCharacters() { + let dict: [String: String] = [ + "path": "part/with/slashes", + "query": "value with spaces", + "fragment": "with#fragment" + ] + let allowedCharacters = CharacterSet.urlPathAllowed.union(.whitespaces).union(.punctuationCharacters) + let queryItems = dict.toURLQueryItems(allowedReservedCharacters: allowedCharacters) + + XCTAssertEqual(queryItems.count, 3) + let q0 = queryParam(withName: "path", from: queryItems) + XCTAssertEqual(q0.name, "path") + XCTAssertEqual(q0.value, "part/with/slashes") + + let q1 = queryParam(withName: "query", from: queryItems) + XCTAssertEqual(q1.name, "query") + XCTAssertEqual(q1.value, "value with spaces") + + let q2 = queryParam(withName: "fragment", from: queryItems) + XCTAssertEqual(q2.name, "fragment") + XCTAssertEqual(q2.value, "with#fragment") + } +} diff --git a/Tests/NetworkingTests/v2/Extensions/HTTPURLResponseCookiesTests.swift b/Tests/NetworkingTests/v2/Extensions/HTTPURLResponseCookiesTests.swift new file mode 100644 index 000000000..c3b3be9a2 --- /dev/null +++ b/Tests/NetworkingTests/v2/Extensions/HTTPURLResponseCookiesTests.swift @@ -0,0 +1,84 @@ +// +// HTTPURLResponseCookiesTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest + +final class HTTPURLResponseCookiesTests: XCTestCase { + + func getCookie(withName name: String, from cookies: [HTTPCookie]?) -> HTTPCookie? { + return cookies?.compactMap({ cookie in + if cookie.name == name { + return cookie + } else { + return nil + } + }).last + } + + func testCookiesRetrievesAllCookies() { + let url = URL(string: "https://example.com")! + let cookieHeader = "Set-Cookie" + let cookieValue1 = "name1=value1; Path=/; HttpOnly" + let cookieValue2 = "name2=value2; Path=/; Secure" + let headers = [cookieHeader: "\(cookieValue1), \(cookieValue2)"] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + + let cookies = response?.cookies + XCTAssertEqual(cookies?.count, 2) + + let c0 = getCookie(withName: "name1", from: cookies) + XCTAssertEqual(c0?.name, "name1") + XCTAssertEqual(c0?.value, "value1") + + let c1 = getCookie(withName: "name2", from: cookies) + XCTAssertEqual(c1?.name, "name2") + XCTAssertEqual(c1?.value, "value2") + } + + func testGetCookieWithNameReturnsCorrectCookie() { + let url = URL(string: "https://example.com")! + let cookieHeader = "Set-Cookie" + let cookieValue1 = "name1=value1; Path=/; HttpOnly" + let cookieValue2 = "name2=value2; Path=/; Secure" + let headers = [cookieHeader: "\(cookieValue1), \(cookieValue2)"] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + + let cookie = response?.getCookie(withName: "name2") + XCTAssertNotNil(cookie) + XCTAssertEqual(cookie?.name, "name2") + XCTAssertEqual(cookie?.value, "value2") + } + + func testGetCookieWithNameReturnsNilForNonExistentCookie() { + let url = URL(string: "https://example.com")! + let cookieHeader = "Set-Cookie" + let cookieValue1 = "name1=value1; Path=/; HttpOnly" + let headers = [cookieHeader: cookieValue1] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + + let cookie = response?.getCookie(withName: "nonexistent") + XCTAssertNil(cookie) + } + + func testCookiesReturnsNilWhenNoCookieHeaderFields() { + let url = URL(string: "https://example.com")! + let headers: [String: String] = [:] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + XCTAssertTrue(response!.cookies!.isEmpty) + } +} diff --git a/Tests/NetworkingTests/v2/Extensions/HTTPURLResponseETagTests.swift b/Tests/NetworkingTests/v2/Extensions/HTTPURLResponseETagTests.swift new file mode 100644 index 000000000..80f2a0483 --- /dev/null +++ b/Tests/NetworkingTests/v2/Extensions/HTTPURLResponseETagTests.swift @@ -0,0 +1,67 @@ +// +// HTTPURLResponseETagTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest + +final class HTTPURLResponseETagTests: XCTestCase { + + func testEtagReturnsStrongEtag() { + let url = URL(string: "https://example.com")! + let headers = ["Etag": "\"12345\""] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + + let etag = response?.etag + XCTAssertEqual(etag, "\"12345\"") + } + + func testEtagReturnsWeakEtagWithoutPrefix() { + let url = URL(string: "https://example.com")! + let headers = ["Etag": "W/\"12345\""] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + + let etag = response?.etag + XCTAssertEqual(etag, "\"12345\"") // Weak prefix "W/" should be dropped + } + + func testEtagRetainsWeakPrefixWhenDroppingWeakPrefixIsFalse() { + let url = URL(string: "https://example.com")! + let headers = ["Etag": "W/\"12345\""] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + + let etag = response?.etag(droppingWeakPrefix: false) + XCTAssertEqual(etag, "W/\"12345\"") // Weak prefix "W/" should be retained + } + + func testEtagReturnsNilWhenNoEtagHeaderPresent() { + let url = URL(string: "https://example.com")! + let headers: [String: String] = [:] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + + let etag = response?.etag + XCTAssertNil(etag) + } + + func testEtagReturnsEmptyStringForEmptyEtagHeader() { + let url = URL(string: "https://example.com")! + let headers = ["Etag": ""] + let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: headers) + + let etag = response?.etag + XCTAssertEqual(etag, "") + } +} diff --git a/Tests/NetworkingTests/v2/Extensions/URL+QueryParametersTests.swift b/Tests/NetworkingTests/v2/Extensions/URL+QueryParametersTests.swift new file mode 100644 index 000000000..96901ffb7 --- /dev/null +++ b/Tests/NetworkingTests/v2/Extensions/URL+QueryParametersTests.swift @@ -0,0 +1,95 @@ +// +// URL+QueryParametersTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest + +class URLExtensionTests: XCTestCase { + + func testQueryParametersWithValidURL() { + // Given + let url = URL(string: "https://example.com?param1=value1¶m2=value2")! + + // When + let parameters = url.queryParameters() + + // Then + XCTAssertNotNil(parameters) + XCTAssertEqual(parameters?["param1"], "value1") + XCTAssertEqual(parameters?["param2"], "value2") + } + + func testQueryParametersWithEmptyQuery() { + // Given + let url = URL(string: "https://example.com")! + + // When + let parameters = url.queryParameters() + + // Then + XCTAssertNil(parameters) + } + + func testQueryParametersWithNoValue() { + // Given + let url = URL(string: "https://example.com?param1=¶m2=value2")! + + // When + let parameters = url.queryParameters() + + // Then + XCTAssertNotNil(parameters) + XCTAssertEqual(parameters?["param1"], "") + XCTAssertEqual(parameters?["param2"], "value2") + } + + func testQueryParametersWithSpecialCharacters() { + // Given + let url = URL(string: "https://example.com?param1=value%201¶m2=value%202")! + + // When + let parameters = url.queryParameters() + + // Then + XCTAssertNotNil(parameters) + XCTAssertEqual(parameters?["param1"], "value 1") + XCTAssertEqual(parameters?["param2"], "value 2") + } + + func testQueryParametersWithMultipleSameKeys() { + // Given + let url = URL(string: "https://example.com?param=value1¶m=value2")! + + // When + let parameters = url.queryParameters() + + // Then + XCTAssertNotNil(parameters) + XCTAssertEqual(parameters?["param"], "value2") // Last value should overwrite the first + } + + func testQueryParametersWithInvalidURL() { + // Given + let url = URL(string: "invalid-url")! + + // When + let parameters = url.queryParameters() + + // Then + XCTAssertNil(parameters) + } +} diff --git a/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift b/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift index c6a490eae..0431f9c30 100644 --- a/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift +++ b/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift @@ -89,14 +89,13 @@ class DefaultRemoteMessagingSurveyURLBuilderTests: XCTestCase { daysSinceLastActive: vpnDaysSinceLastActive ) - let subscription = DDGSubscription(productId: "product-id", - name: "product-name", - billingPeriod: .monthly, - startedAt: Date(timeIntervalSince1970: 1000), - expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), - platform: .apple, - status: .autoRenewable) - + let subscription = PrivacyProSubscription(productId: "product-id", + name: "product-name", + billingPeriod: .monthly, + startedAt: Date(timeIntervalSince1970: 1000), + expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), + platform: .apple, + status: .autoRenewable) return DefaultRemoteMessagingSurveyURLBuilder( statisticsStore: mockStatisticsStore, vpnActivationDateStore: vpnActivationDateStore, diff --git a/Tests/SubscriptionTests/API/AuthEndpointServiceTests.swift b/Tests/SubscriptionTests/API/AuthEndpointServiceTests.swift deleted file mode 100644 index 3fdae38a4..000000000 --- a/Tests/SubscriptionTests/API/AuthEndpointServiceTests.swift +++ /dev/null @@ -1,319 +0,0 @@ -// -// AuthEndpointServiceTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import XCTest -@testable import Subscription -import SubscriptionTestingUtilities - -final class AuthEndpointServiceTests: XCTestCase { - - private struct Constants { - static let authToken = UUID().uuidString - static let accessToken = UUID().uuidString - static let externalID = UUID().uuidString - static let email = "dax@duck.com" - - static let mostRecentTransactionJWS = "dGhpcyBpcyBub3QgYSByZWFsIEFw(...)cCBTdG9yZSB0cmFuc2FjdGlvbiBKV1M=" - - static let authorizationHeader = ["Authorization": "Bearer TOKEN"] - - static let invalidTokenError = APIServiceError.serverError(statusCode: 401, error: "invalid_token") - } - - var apiService: APIServiceMock! - var authService: AuthEndpointService! - - override func setUpWithError() throws { - apiService = APIServiceMock() - authService = DefaultAuthEndpointService(currentServiceEnvironment: .staging, apiService: apiService) - } - - override func tearDownWithError() throws { - apiService = nil - authService = nil - } - - // MARK: - Tests for getAccessToken - - func testGetAccessTokenCall() async throws { - // Given - let apiServiceCalledExpectation = expectation(description: "apiService") - - apiService.mockAuthHeaders = Constants.authorizationHeader - apiService.onExecuteAPICall = { parameters in - let (method, endpoint, headers, _) = parameters - - apiServiceCalledExpectation.fulfill() - XCTAssertEqual(method, "GET") - XCTAssertEqual(endpoint, "access-token") - XCTAssertEqual(headers, Constants.authorizationHeader) - } - - // When - _ = await authService.getAccessToken(token: Constants.authToken) - - // Then - await fulfillment(of: [apiServiceCalledExpectation], timeout: 0.1) - } - - func testGetAccessTokenSuccess() async throws { - // Given - apiService.mockAuthHeaders = Constants.authorizationHeader - apiService.mockResponseJSONData = """ - { - "accessToken": "\(Constants.accessToken)", - } - """.data(using: .utf8)! - - // When - let result = await authService.getAccessToken(token: Constants.authToken) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success.accessToken, Constants.accessToken) - case .failure: - XCTFail("Unexpected failure") - } - } - - func testGetAccessTokenError() async throws { - // Given - apiService.mockAuthHeaders = Constants.authorizationHeader - apiService.mockAPICallError = Constants.invalidTokenError - - // When - let result = await authService.getAccessToken(token: Constants.authToken) - - // Then - switch result { - case .success: - XCTFail("Unexpected success") - case .failure: - break - } - } - - // MARK: - Tests for validateToken - - func testValidateTokenCall() async throws { - // Given - let apiServiceCalledExpectation = expectation(description: "apiService") - - apiService.mockAuthHeaders = Constants.authorizationHeader - apiService.onExecuteAPICall = { parameters in - let (method, endpoint, headers, _) = parameters - - apiServiceCalledExpectation.fulfill() - XCTAssertEqual(method, "GET") - XCTAssertEqual(endpoint, "validate-token") - XCTAssertEqual(headers, Constants.authorizationHeader) - } - - // When - _ = await authService.validateToken(accessToken: Constants.accessToken) - - // Then - await fulfillment(of: [apiServiceCalledExpectation], timeout: 0.1) - } - - func testValidateTokenSuccess() async throws { - // Given - apiService.mockAuthHeaders = Constants.authorizationHeader - apiService.mockResponseJSONData = """ - { - "account": { - "id": 149718, - "external_id": "\(Constants.externalID)", - "email": "\(Constants.email)", - "entitlements": [ - {"id":24, "name":"subscriber", "product":"Network Protection"}, - {"id":25, "name":"subscriber", "product":"Data Broker Protection"}, - {"id":26, "name":"subscriber", "product":"Identity Theft Restoration"} - ] - } - } - """.data(using: .utf8)! - - // When - let result = await authService.validateToken(accessToken: Constants.accessToken) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success.account.externalID, Constants.externalID) - XCTAssertEqual(success.account.email, Constants.email) - XCTAssertEqual(success.account.entitlements.count, 3) - case .failure: - XCTFail("Unexpected failure") - } - } - - func testValidateTokenError() async throws { - // Given - apiService.mockAuthHeaders = Constants.authorizationHeader - apiService.mockAPICallError = Constants.invalidTokenError - - // When - let result = await authService.validateToken(accessToken: Constants.accessToken) - - // Then - switch result { - case .success: - XCTFail("Unexpected success") - case .failure: - break - } - } - - // MARK: - Tests for createAccount - - func testCreateAccountCall() async throws { - // Given - let apiServiceCalledExpectation = expectation(description: "apiService") - - apiService.onExecuteAPICall = { parameters in - let (method, endpoint, headers, _) = parameters - - apiServiceCalledExpectation.fulfill() - XCTAssertEqual(method, "POST") - XCTAssertEqual(endpoint, "account/create") - XCTAssertNil(headers) - } - - // When - _ = await authService.createAccount(emailAccessToken: nil) - - // Then - await fulfillment(of: [apiServiceCalledExpectation], timeout: 0.1) - } - - func testCreateAccountSuccess() async throws { - // Given - apiService.mockResponseJSONData = """ - { - "auth_token": "\(Constants.authToken)", - "external_id": "\(Constants.externalID)", - "status": "created" - } - """.data(using: .utf8)! - - // When - let result = await authService.createAccount(emailAccessToken: nil) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success.authToken, Constants.authToken) - XCTAssertEqual(success.externalID, Constants.externalID) - XCTAssertEqual(success.status, "created") - case .failure: - XCTFail("Unexpected failure") - } - } - - func testCreateAccountError() async throws { - // Given - apiService.mockAuthHeaders = Constants.authorizationHeader - apiService.mockAPICallError = Constants.invalidTokenError - - // When - let result = await authService.createAccount(emailAccessToken: nil) - - // Then - switch result { - case .success: - XCTFail("Unexpected success") - case .failure: - break - } - } - - // MARK: - Tests for storeLogin - - func testStoreLoginCall() async throws { - // Given - let apiServiceCalledExpectation = expectation(description: "apiService") - - apiService.onExecuteAPICall = { parameters in - let (method, endpoint, headers, body) = parameters - - apiServiceCalledExpectation.fulfill() - XCTAssertEqual(method, "POST") - XCTAssertEqual(endpoint, "store-login") - XCTAssertNil(headers) - - if let bodyDict = try? JSONDecoder().decode([String: String].self, from: body!) { - XCTAssertEqual(bodyDict["signature"], Constants.mostRecentTransactionJWS) - XCTAssertEqual(bodyDict["store"], "apple_app_store") - } else { - XCTFail("Failed to decode body") - } - } - - // When - _ = await authService.storeLogin(signature: Constants.mostRecentTransactionJWS) - - // Then - await fulfillment(of: [apiServiceCalledExpectation], timeout: 0.1) - } - - func testStoreLoginSuccess() async throws { - // Given - apiService.mockResponseJSONData = """ - { - "auth_token": "\(Constants.authToken)", - "email": "\(Constants.email)", - "external_id": "\(Constants.externalID)", - "id": 1, - "status": "ok" - } - """.data(using: .utf8)! - - // When - let result = await authService.storeLogin(signature: Constants.mostRecentTransactionJWS) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success.authToken, Constants.authToken) - XCTAssertEqual(success.email, Constants.email) - XCTAssertEqual(success.externalID, Constants.externalID) - XCTAssertEqual(success.id, 1) - XCTAssertEqual(success.status, "ok") - case .failure: - XCTFail("Unexpected failure") - } - } - - func testStoreLoginError() async throws { - // Given - apiService.mockAPICallError = Constants.invalidTokenError - - // When - let result = await authService.storeLogin(signature: Constants.mostRecentTransactionJWS) - - // Then - switch result { - case .success: - XCTFail("Unexpected success") - case .failure: - break - } - } -} diff --git a/Tests/SubscriptionTests/API/Models/EntitlementTests.swift b/Tests/SubscriptionTests/API/Models/EntitlementTests.swift deleted file mode 100644 index 25409abce..000000000 --- a/Tests/SubscriptionTests/API/Models/EntitlementTests.swift +++ /dev/null @@ -1,47 +0,0 @@ -// -// EntitlementTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import XCTest -@testable import Subscription -import SubscriptionTestingUtilities - -final class EntitlementTests: XCTestCase { - - func testEquality() throws { - XCTAssertEqual(Entitlement(product: .dataBrokerProtection), Entitlement(product: .dataBrokerProtection)) - XCTAssertNotEqual(Entitlement(product: .dataBrokerProtection), Entitlement(product: .networkProtection)) - } - - func testDecoding() throws { - let rawNetPEntitlement = "{\"id\":24,\"name\":\"subscriber\",\"product\":\"Network Protection\"}" - let netPEntitlement = try JSONDecoder().decode(Entitlement.self, from: Data(rawNetPEntitlement.utf8)) - XCTAssertEqual(netPEntitlement, Entitlement(product: .networkProtection)) - - let rawDBPEntitlement = "{\"id\":25,\"name\":\"subscriber\",\"product\":\"Data Broker Protection\"}" - let dbpEntitlement = try JSONDecoder().decode(Entitlement.self, from: Data(rawDBPEntitlement.utf8)) - XCTAssertEqual(dbpEntitlement, Entitlement(product: .dataBrokerProtection)) - - let rawITREntitlement = "{\"id\":26,\"name\":\"subscriber\",\"product\":\"Identity Theft Restoration\"}" - let itrEntitlement = try JSONDecoder().decode(Entitlement.self, from: Data(rawITREntitlement.utf8)) - XCTAssertEqual(itrEntitlement, Entitlement(product: .identityTheftRestoration)) - - let rawUnexpectedEntitlement = "{\"id\":27,\"name\":\"subscriber\",\"product\":\"something unexpected\"}" - let unexpectedEntitlement = try JSONDecoder().decode(Entitlement.self, from: Data(rawUnexpectedEntitlement.utf8)) - XCTAssertEqual(unexpectedEntitlement, Entitlement(product: .unknown)) - } -} diff --git a/Tests/SubscriptionTests/API/Models/SubscriptionEntitlementTests.swift b/Tests/SubscriptionTests/API/Models/SubscriptionEntitlementTests.swift new file mode 100644 index 000000000..0903c84a3 --- /dev/null +++ b/Tests/SubscriptionTests/API/Models/SubscriptionEntitlementTests.swift @@ -0,0 +1,48 @@ +// +// SubscriptionEntitlementTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +@testable import Subscription +@testable import Networking +import SubscriptionTestingUtilities + +final class SubscriptionEntitlementTests: XCTestCase { + + func testEquality() throws { + XCTAssertEqual(SubscriptionEntitlement.dataBrokerProtection, SubscriptionEntitlement.dataBrokerProtection) + XCTAssertNotEqual(SubscriptionEntitlement.dataBrokerProtection, SubscriptionEntitlement.networkProtection) + } + + func testDecoding() throws { + let rawNetPEntitlement = "Network Protection" + let netPEntitlement = SubscriptionEntitlement(rawValue: rawNetPEntitlement) + XCTAssertEqual(netPEntitlement, SubscriptionEntitlement.networkProtection) + + let rawDBPEntitlement = "Data Broker Protection" + let dbpEntitlement = SubscriptionEntitlement(rawValue: rawDBPEntitlement) + XCTAssertEqual(dbpEntitlement, SubscriptionEntitlement.dataBrokerProtection) + + let rawITREntitlement = "Identity Theft Restoration" + let itrEntitlement = SubscriptionEntitlement(rawValue: rawITREntitlement) + XCTAssertEqual(itrEntitlement, SubscriptionEntitlement.identityTheftRestoration) + + let rawUnexpectedEntitlement = "something unexpected" + let unexpectedEntitlement = SubscriptionEntitlement(rawValue: rawUnexpectedEntitlement) + XCTAssertNil(unexpectedEntitlement) + } +} diff --git a/Tests/SubscriptionTests/API/Models/SubscriptionTests.swift b/Tests/SubscriptionTests/API/Models/SubscriptionTests.swift index 59106b277..fc2d2e874 100644 --- a/Tests/SubscriptionTests/API/Models/SubscriptionTests.swift +++ b/Tests/SubscriptionTests/API/Models/SubscriptionTests.swift @@ -23,48 +23,48 @@ import SubscriptionTestingUtilities final class SubscriptionTests: XCTestCase { func testEquality() throws { - let a = DDGSubscription(productId: "1", - name: "a", - billingPeriod: .monthly, - startedAt: Date(timeIntervalSince1970: 1000), - expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), - platform: .apple, - status: .autoRenewable) - let b = DDGSubscription(productId: "1", - name: "a", - billingPeriod: .monthly, - startedAt: Date(timeIntervalSince1970: 1000), - expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), - platform: .apple, - status: .autoRenewable) - let c = DDGSubscription(productId: "2", - name: "a", - billingPeriod: .monthly, - startedAt: Date(timeIntervalSince1970: 1000), - expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), - platform: .apple, - status: .autoRenewable) + let a = PrivacyProSubscription(productId: "1", + name: "a", + billingPeriod: .monthly, + startedAt: Date(timeIntervalSince1970: 1000), + expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), + platform: .apple, + status: .autoRenewable) + let b = PrivacyProSubscription(productId: "1", + name: "a", + billingPeriod: .monthly, + startedAt: Date(timeIntervalSince1970: 1000), + expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), + platform: .apple, + status: .autoRenewable) + let c = PrivacyProSubscription(productId: "2", + name: "a", + billingPeriod: .monthly, + startedAt: Date(timeIntervalSince1970: 1000), + expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), + platform: .apple, + status: .autoRenewable) XCTAssertEqual(a, b) XCTAssertNotEqual(a, c) } func testIfSubscriptionWithGivenStatusIsActive() throws { - let autoRenewableSubscription = Subscription.make(withStatus: .autoRenewable) + let autoRenewableSubscription = PrivacyProSubscription.make(withStatus: .autoRenewable) XCTAssertTrue(autoRenewableSubscription.isActive) - let notAutoRenewableSubscription = Subscription.make(withStatus: .notAutoRenewable) + let notAutoRenewableSubscription = PrivacyProSubscription.make(withStatus: .notAutoRenewable) XCTAssertTrue(notAutoRenewableSubscription.isActive) - let gracePeriodSubscription = Subscription.make(withStatus: .gracePeriod) + let gracePeriodSubscription = PrivacyProSubscription.make(withStatus: .gracePeriod) XCTAssertTrue(gracePeriodSubscription.isActive) - let inactiveSubscription = Subscription.make(withStatus: .inactive) + let inactiveSubscription = PrivacyProSubscription.make(withStatus: .inactive) XCTAssertFalse(inactiveSubscription.isActive) - let expiredSubscription = Subscription.make(withStatus: .expired) + let expiredSubscription = PrivacyProSubscription.make(withStatus: .expired) XCTAssertFalse(expiredSubscription.isActive) - let unknownSubscription = Subscription.make(withStatus: .unknown) + let unknownSubscription = PrivacyProSubscription.make(withStatus: .unknown) XCTAssertTrue(unknownSubscription.isActive) } @@ -74,7 +74,7 @@ final class SubscriptionTests: XCTestCase { let decoder = JSONDecoder() decoder.keyDecodingStrategy = .convertFromSnakeCase decoder.dateDecodingStrategy = .millisecondsSince1970 - let subscription = try decoder.decode(Subscription.self, from: Data(rawSubscription.utf8)) + let subscription = try decoder.decode(PrivacyProSubscription.self, from: Data(rawSubscription.utf8)) XCTAssertEqual(subscription.productId, "ddg-privacy-pro-sandbox-monthly-renews-us") XCTAssertEqual(subscription.name, "Monthly Subscription") @@ -85,60 +85,60 @@ final class SubscriptionTests: XCTestCase { } func testBillingPeriodDecoding() throws { - let monthly = try JSONDecoder().decode(Subscription.BillingPeriod.self, from: Data("\"Monthly\"".utf8)) - XCTAssertEqual(monthly, Subscription.BillingPeriod.monthly) + let monthly = try JSONDecoder().decode(PrivacyProSubscription.BillingPeriod.self, from: Data("\"Monthly\"".utf8)) + XCTAssertEqual(monthly, PrivacyProSubscription.BillingPeriod.monthly) - let yearly = try JSONDecoder().decode(Subscription.BillingPeriod.self, from: Data("\"Yearly\"".utf8)) - XCTAssertEqual(yearly, Subscription.BillingPeriod.yearly) + let yearly = try JSONDecoder().decode(PrivacyProSubscription.BillingPeriod.self, from: Data("\"Yearly\"".utf8)) + XCTAssertEqual(yearly, PrivacyProSubscription.BillingPeriod.yearly) - let unknown = try JSONDecoder().decode(Subscription.BillingPeriod.self, from: Data("\"something unexpected\"".utf8)) - XCTAssertEqual(unknown, Subscription.BillingPeriod.unknown) + let unknown = try JSONDecoder().decode(PrivacyProSubscription.BillingPeriod.self, from: Data("\"something unexpected\"".utf8)) + XCTAssertEqual(unknown, PrivacyProSubscription.BillingPeriod.unknown) } func testPlatformDecoding() throws { - let apple = try JSONDecoder().decode(Subscription.Platform.self, from: Data("\"apple\"".utf8)) - XCTAssertEqual(apple, Subscription.Platform.apple) + let apple = try JSONDecoder().decode(PrivacyProSubscription.Platform.self, from: Data("\"apple\"".utf8)) + XCTAssertEqual(apple, PrivacyProSubscription.Platform.apple) - let google = try JSONDecoder().decode(Subscription.Platform.self, from: Data("\"google\"".utf8)) - XCTAssertEqual(google, Subscription.Platform.google) + let google = try JSONDecoder().decode(PrivacyProSubscription.Platform.self, from: Data("\"google\"".utf8)) + XCTAssertEqual(google, PrivacyProSubscription.Platform.google) - let stripe = try JSONDecoder().decode(Subscription.Platform.self, from: Data("\"stripe\"".utf8)) - XCTAssertEqual(stripe, Subscription.Platform.stripe) + let stripe = try JSONDecoder().decode(PrivacyProSubscription.Platform.self, from: Data("\"stripe\"".utf8)) + XCTAssertEqual(stripe, PrivacyProSubscription.Platform.stripe) - let unknown = try JSONDecoder().decode(Subscription.Platform.self, from: Data("\"something unexpected\"".utf8)) - XCTAssertEqual(unknown, Subscription.Platform.unknown) + let unknown = try JSONDecoder().decode(PrivacyProSubscription.Platform.self, from: Data("\"something unexpected\"".utf8)) + XCTAssertEqual(unknown, PrivacyProSubscription.Platform.unknown) } func testStatusDecoding() throws { - let autoRenewable = try JSONDecoder().decode(Subscription.Status.self, from: Data("\"Auto-Renewable\"".utf8)) - XCTAssertEqual(autoRenewable, Subscription.Status.autoRenewable) + let autoRenewable = try JSONDecoder().decode(PrivacyProSubscription.Status.self, from: Data("\"Auto-Renewable\"".utf8)) + XCTAssertEqual(autoRenewable, PrivacyProSubscription.Status.autoRenewable) - let notAutoRenewable = try JSONDecoder().decode(Subscription.Status.self, from: Data("\"Not Auto-Renewable\"".utf8)) - XCTAssertEqual(notAutoRenewable, Subscription.Status.notAutoRenewable) + let notAutoRenewable = try JSONDecoder().decode(PrivacyProSubscription.Status.self, from: Data("\"Not Auto-Renewable\"".utf8)) + XCTAssertEqual(notAutoRenewable, PrivacyProSubscription.Status.notAutoRenewable) - let gracePeriod = try JSONDecoder().decode(Subscription.Status.self, from: Data("\"Grace Period\"".utf8)) - XCTAssertEqual(gracePeriod, Subscription.Status.gracePeriod) + let gracePeriod = try JSONDecoder().decode(PrivacyProSubscription.Status.self, from: Data("\"Grace Period\"".utf8)) + XCTAssertEqual(gracePeriod, PrivacyProSubscription.Status.gracePeriod) - let inactive = try JSONDecoder().decode(Subscription.Status.self, from: Data("\"Inactive\"".utf8)) - XCTAssertEqual(inactive, Subscription.Status.inactive) + let inactive = try JSONDecoder().decode(PrivacyProSubscription.Status.self, from: Data("\"Inactive\"".utf8)) + XCTAssertEqual(inactive, PrivacyProSubscription.Status.inactive) - let expired = try JSONDecoder().decode(Subscription.Status.self, from: Data("\"Expired\"".utf8)) - XCTAssertEqual(expired, Subscription.Status.expired) + let expired = try JSONDecoder().decode(PrivacyProSubscription.Status.self, from: Data("\"Expired\"".utf8)) + XCTAssertEqual(expired, PrivacyProSubscription.Status.expired) - let unknown = try JSONDecoder().decode(Subscription.Status.self, from: Data("\"something unexpected\"".utf8)) - XCTAssertEqual(unknown, Subscription.Status.unknown) + let unknown = try JSONDecoder().decode(PrivacyProSubscription.Status.self, from: Data("\"something unexpected\"".utf8)) + XCTAssertEqual(unknown, PrivacyProSubscription.Status.unknown) } } -extension Subscription { +extension PrivacyProSubscription { - static func make(withStatus status: Subscription.Status) -> Subscription { - Subscription(productId: UUID().uuidString, - name: "Subscription test #1", - billingPeriod: .monthly, - startedAt: Date(), - expiresOrRenewsAt: Date().addingTimeInterval(TimeInterval.days(+30)), - platform: .apple, - status: status) + static func make(withStatus status: PrivacyProSubscription.Status) -> PrivacyProSubscription { + PrivacyProSubscription(productId: UUID().uuidString, + name: "Subscription test #1", + billingPeriod: .monthly, + startedAt: Date(), + expiresOrRenewsAt: Date().addingTimeInterval(TimeInterval.days(+30)), + platform: .apple, + status: status) } } diff --git a/Tests/SubscriptionTests/API/SubscriptionEndpointServiceTests.swift b/Tests/SubscriptionTests/API/SubscriptionEndpointServiceTests.swift index 1d459e06d..0decc313e 100644 --- a/Tests/SubscriptionTests/API/SubscriptionEndpointServiceTests.swift +++ b/Tests/SubscriptionTests/API/SubscriptionEndpointServiceTests.swift @@ -18,15 +18,249 @@ import XCTest @testable import Subscription +@testable import Networking import SubscriptionTestingUtilities +import TestUtils +import Common +final class SubscriptionEndpointServiceTests: XCTestCase { + private var apiService: MockAPIService! + private var endpointService: DefaultSubscriptionEndpointService! + private let baseURL = SubscriptionEnvironment.ServiceEnvironment.staging.url + private let disposableCache = UserDefaultsCache(key: UserDefaultsCacheKeyKest.subscriptionTest, + settings: UserDefaultsCacheSettings(defaultExpirationInterval: .minutes(20))) + private enum UserDefaultsCacheKeyKest: String, UserDefaultsCacheKeyStore { + case subscriptionTest = "com.duckduckgo.bsk.subscription.info.testing" + } + private var encoder: JSONEncoder! + + override func setUp() { + super.setUp() + encoder = JSONEncoder() + encoder.dateEncodingStrategy = .millisecondsSince1970 + apiService = MockAPIService() + endpointService = DefaultSubscriptionEndpointService(apiService: apiService, + baseURL: baseURL, + subscriptionCache: disposableCache) + } + + override func tearDown() { + disposableCache.reset() + apiService = nil + endpointService = nil + super.tearDown() + } + + // MARK: - Helpers + + private func createSubscriptionResponseData() -> Data { + let date = Date(timeIntervalSince1970: 123456789) + let subscription = PrivacyProSubscription( + productId: "prod123", + name: "Pro Plan", + billingPeriod: .yearly, + startedAt: date, + expiresOrRenewsAt: date.addingTimeInterval(30 * 24 * 60 * 60), + platform: .apple, + status: .autoRenewable + ) + return try! encoder.encode(subscription) + } + + private func createAPIResponse(statusCode: Int, data: Data?) -> APIResponseV2 { + let response = HTTPURLResponse( + url: baseURL, + statusCode: statusCode, + httpVersion: nil, + headerFields: nil + )! + return APIResponseV2(data: data, httpResponse: response) + } + + // MARK: - getSubscription Tests + + func testGetSubscriptionReturnsCachedSubscription() async throws { + let date = Date(timeIntervalSince1970: 123456789) + let cachedSubscription = PrivacyProSubscription( + productId: "prod123", + name: "Pro Plan", + billingPeriod: .monthly, + startedAt: date, + expiresOrRenewsAt: date.addingTimeInterval(30 * 24 * 60 * 60), + platform: .google, + status: .autoRenewable + ) + endpointService.updateCache(with: cachedSubscription) + + let subscription = try await endpointService.getSubscription(accessToken: "token", cachePolicy: .returnCacheDataDontLoad) + XCTAssertEqual(subscription, cachedSubscription) + } + + func testGetSubscriptionFetchesRemoteSubscriptionWhenNoCache() async throws { + // mock subscription response + let subscriptionData = createSubscriptionResponseData() + let apiResponse = createAPIResponse(statusCode: 200, data: subscriptionData) + let request = SubscriptionRequest.getSubscription(baseURL: baseURL, accessToken: "token")!.apiRequest + + // mock features + APIMockResponseFactory.mockGetFeatures(destinationMockAPIService: apiService, success: true, subscriptionID: "prod123") + + apiService.set(response: apiResponse, forRequest: request) + + let subscription = try await endpointService.getSubscription(accessToken: "token", cachePolicy: .returnCacheDataElseLoad) + XCTAssertEqual(subscription.productId, "prod123") + XCTAssertEqual(subscription.name, "Pro Plan") + XCTAssertEqual(subscription.billingPeriod, .yearly) + XCTAssertEqual(subscription.platform, .apple) + XCTAssertEqual(subscription.status, .autoRenewable) + } + + func testGetSubscriptionThrowsNoDataWhenNoCacheAndFetchFails() async { + do { + _ = try await endpointService.getSubscription(accessToken: "token", cachePolicy: .returnCacheDataDontLoad) + XCTFail("Expected noData error") + } catch SubscriptionEndpointServiceError.noData { + // Success + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + // MARK: - getProducts Tests + + func testGetProductsReturnsListOfProducts() async throws { + let productItems = [ + GetProductsItem( + productId: "prod1", + productLabel: "Product 1", + billingPeriod: "Monthly", + price: "9.99", + currency: "USD" + ), + GetProductsItem( + productId: "prod2", + productLabel: "Product 2", + billingPeriod: "Yearly", + price: "99.99", + currency: "USD" + ) + ] + let productData = try encoder.encode(productItems) + let apiResponse = createAPIResponse(statusCode: 200, data: productData) + let request = SubscriptionRequest.getProducts(baseURL: baseURL)!.apiRequest + + apiService.set(response: apiResponse, forRequest: request) + + let products = try await endpointService.getProducts() + XCTAssertEqual(products, productItems) + } + + func testGetProductsThrowsInvalidResponse() async { + let request = SubscriptionRequest.getProducts(baseURL: baseURL)!.apiRequest + let apiResponse = createAPIResponse(statusCode: 200, data: nil) + apiService.set(response: apiResponse, forRequest: request) + do { + _ = try await endpointService.getProducts() + XCTFail("Expected invalidResponse error") + } catch Networking.APIRequestV2.Error.emptyResponseBody { + // Success + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + // MARK: - getCustomerPortalURL Tests + + func testGetCustomerPortalURLReturnsCorrectURL() async throws { + let portalResponse = GetCustomerPortalURLResponse(customerPortalUrl: "https://portal.example.com") + let portalData = try encoder.encode(portalResponse) + let apiResponse = createAPIResponse(statusCode: 200, data: portalData) + let request = SubscriptionRequest.getCustomerPortalURL(baseURL: baseURL, accessToken: "token", externalID: "id")!.apiRequest + + apiService.set(response: apiResponse, forRequest: request) + + let customerPortalURL = try await endpointService.getCustomerPortalURL(accessToken: "token", externalID: "id") + XCTAssertEqual(customerPortalURL, portalResponse) + } + + // MARK: - confirmPurchase Tests + + func testConfirmPurchaseReturnsCorrectResponse() async throws { + let date = Date(timeIntervalSince1970: 123456789) + let confirmResponse = ConfirmPurchaseResponse( + email: "user@example.com", + subscription: PrivacyProSubscription( + productId: "prod123", + name: "Pro Plan", + billingPeriod: .monthly, + startedAt: date, + expiresOrRenewsAt: date.addingTimeInterval(30 * 24 * 60 * 60), + platform: .stripe, + status: .gracePeriod + ) + ) + let confirmData = try encoder.encode(confirmResponse) + let apiResponse = createAPIResponse(statusCode: 200, data: confirmData) + let request = SubscriptionRequest.confirmPurchase(baseURL: baseURL, accessToken: "token", signature: "signature")!.apiRequest + + apiService.set(response: apiResponse, forRequest: request) + + let purchaseResponse = try await endpointService.confirmPurchase(accessToken: "token", signature: "signature") + XCTAssertEqual(purchaseResponse.email, confirmResponse.email) + XCTAssertEqual(purchaseResponse.subscription, confirmResponse.subscription) + } + + // MARK: - Cache Tests + + func testUpdateCacheStoresSubscription() async throws { + let date = Date(timeIntervalSince1970: 123456789) + let subscription = PrivacyProSubscription( + productId: "prod123", + name: "Pro Plan", + billingPeriod: .monthly, + startedAt: date, + expiresOrRenewsAt: date.addingTimeInterval(30 * 24 * 60 * 60), + platform: .google, + status: .autoRenewable + ) + endpointService.updateCache(with: subscription) + + let cachedSubscription = try await endpointService.getSubscription(accessToken: "token", cachePolicy: .returnCacheDataDontLoad) + XCTAssertEqual(cachedSubscription, subscription) + } + + func testClearSubscriptionRemovesCachedSubscription() async throws { + let date = Date(timeIntervalSince1970: 123456789) + let subscription = PrivacyProSubscription( + productId: "prod123", + name: "Pro Plan", + billingPeriod: .monthly, + startedAt: date, + expiresOrRenewsAt: date.addingTimeInterval(30 * 24 * 60 * 60), + platform: .apple, + status: .autoRenewable + ) + endpointService.updateCache(with: subscription) + + endpointService.clearSubscription() + do { + _ = try await endpointService.getSubscription(accessToken: "token", cachePolicy: .returnCacheDataDontLoad) + } catch SubscriptionEndpointServiceError.noData { + // Success + } catch { + XCTFail("Wrong error: \(error)") + } + } +} + +/* final class SubscriptionEndpointServiceTests: XCTestCase { private struct Constants { - static let authToken = UUID().uuidString - static let accessToken = UUID().uuidString - static let externalID = UUID().uuidString - static let email = "dax@duck.com" +// static let tokenContainer = OAuthTokensFactory.makeValidTokenContainer() +// static let accessToken = UUID().uuidString +// static let externalID = UUID().uuidString +// static let email = "dax@duck.com" static let mostRecentTransactionJWS = "dGhpcyBpcyBub3QgYSByZWFsIEFw(...)cCBTdG9yZSB0cmFuc2FjdGlvbiBKV1M=" @@ -36,15 +270,15 @@ final class SubscriptionEndpointServiceTests: XCTestCase { static let authorizationHeader = ["Authorization": "Bearer TOKEN"] - static let unknownServerError = APIServiceError.serverError(statusCode: 401, error: "unknown_error") +// static let unknownServerError = APIServiceError.serverError(statusCode: 401, error: "unknown_error") } - var apiService: APIServiceMock! + var apiService: MockAPIService! var subscriptionService: SubscriptionEndpointService! override func setUpWithError() throws { - apiService = APIServiceMock() - subscriptionService = DefaultSubscriptionEndpointService(currentServiceEnvironment: .staging, apiService: apiService) + apiService = MockAPIService() + subscriptionService = DefaultSubscriptionEndpointService(apiService: apiService, baseURL: URL(string: "https://something_tests.com")!) } override func tearDownWithError() throws { @@ -362,3 +596,4 @@ final class SubscriptionEndpointServiceTests: XCTestCase { } } } +*/ diff --git a/Tests/SubscriptionTests/Flows/AppStoreAccountManagementFlowTests.swift b/Tests/SubscriptionTests/Flows/AppStoreAccountManagementFlowTests.swift deleted file mode 100644 index e2c7f95c7..000000000 --- a/Tests/SubscriptionTests/Flows/AppStoreAccountManagementFlowTests.swift +++ /dev/null @@ -1,184 +0,0 @@ -// -// AppStoreAccountManagementFlowTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import XCTest -@testable import Subscription -import SubscriptionTestingUtilities - -final class AppStoreAccountManagementFlowTests: XCTestCase { - - private struct Constants { - static let oldAuthToken = UUID().uuidString - static let newAuthToken = UUID().uuidString - - static let externalID = UUID ().uuidString - static let otherExternalID = UUID().uuidString - - static let email = "dax@duck.com" - - static let mostRecentTransactionJWS = "dGhpcyBpcyBub3QgYSByZWFsIEFw(...)cCBTdG9yZSB0cmFuc2FjdGlvbiBKV1M=" - - static let invalidTokenError = APIServiceError.serverError(statusCode: 401, error: "invalid_token") - - static let entitlements = [Entitlement(product: .dataBrokerProtection), - Entitlement(product: .identityTheftRestoration), - Entitlement(product: .networkProtection)] - } - - var accountManager: AccountManagerMock! - var authEndpointService: AuthEndpointServiceMock! - var storePurchaseManager: StorePurchaseManagerMock! - - var appStoreAccountManagementFlow: AppStoreAccountManagementFlow! - - override func setUpWithError() throws { - accountManager = AccountManagerMock() - authEndpointService = AuthEndpointServiceMock() - storePurchaseManager = StorePurchaseManagerMock() - - appStoreAccountManagementFlow = DefaultAppStoreAccountManagementFlow(authEndpointService: authEndpointService, - storePurchaseManager: storePurchaseManager, - accountManager: accountManager) - } - - override func tearDownWithError() throws { - accountManager = nil - authEndpointService = nil - storePurchaseManager = nil - - appStoreAccountManagementFlow = nil - } - - // MARK: - Tests for refreshAuthTokenIfNeeded - - func testRefreshAuthTokenIfNeededSuccess() async throws { - // Given - accountManager.authToken = Constants.oldAuthToken - accountManager.externalID = Constants.externalID - - authEndpointService.validateTokenResult = .failure(Constants.invalidTokenError) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS - - authEndpointService.storeLoginResult = .success(StoreLoginResponse(authToken: Constants.newAuthToken, - email: "", - externalID: Constants.externalID, - id: 1, - status: "authenticated")) - - // When - switch await appStoreAccountManagementFlow.refreshAuthTokenIfNeeded() { - case .success(let success): - // Then - XCTAssertTrue(storePurchaseManager.mostRecentTransactionCalled) - XCTAssertEqual(success, Constants.newAuthToken) - XCTAssertEqual(accountManager.authToken, Constants.newAuthToken) - case .failure(let error): - XCTFail("Unexpected failure: \(String(reflecting: error))") - } - } - - func testRefreshAuthTokenIfNeededSuccessButNotRefreshedIfStillValid() async throws { - // Given - accountManager.authToken = Constants.oldAuthToken - - authEndpointService.validateTokenResult = .success(ValidateTokenResponse(account: .init(email: Constants.email, - entitlements: Constants.entitlements, - externalID: Constants.externalID))) - - // When - switch await appStoreAccountManagementFlow.refreshAuthTokenIfNeeded() { - case .success(let success): - // Then - XCTAssertEqual(success, Constants.oldAuthToken) - XCTAssertEqual(accountManager.authToken, Constants.oldAuthToken) - case .failure(let error): - XCTFail("Unexpected failure: \(String(reflecting: error))") - } - } - - func testRefreshAuthTokenIfNeededSuccessButNotRefreshedIfStoreLoginRetrievedDifferentAccount() async throws { - // Given - accountManager.authToken = Constants.oldAuthToken - accountManager.externalID = Constants.externalID - accountManager.email = Constants.email - - authEndpointService.validateTokenResult = .failure(Constants.invalidTokenError) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS - - authEndpointService.storeLoginResult = .success(StoreLoginResponse(authToken: Constants.newAuthToken, - email: "", - externalID: Constants.otherExternalID, - id: 1, - status: "authenticated")) - - // When - switch await appStoreAccountManagementFlow.refreshAuthTokenIfNeeded() { - case .success(let success): - // Then - XCTAssertTrue(storePurchaseManager.mostRecentTransactionCalled) - XCTAssertEqual(success, Constants.oldAuthToken) - XCTAssertEqual(accountManager.authToken, Constants.oldAuthToken) - XCTAssertEqual(accountManager.externalID, Constants.externalID) - XCTAssertEqual(accountManager.email, Constants.email) - case .failure(let error): - XCTFail("Unexpected failure: \(String(reflecting: error))") - } - } - - func testRefreshAuthTokenIfNeededErrorDueToNoPastTransactions() async throws { - // Given - accountManager.authToken = Constants.oldAuthToken - - authEndpointService.validateTokenResult = .failure(Constants.invalidTokenError) - - storePurchaseManager.mostRecentTransactionResult = nil - - // When - switch await appStoreAccountManagementFlow.refreshAuthTokenIfNeeded() { - case .success: - XCTFail("Unexpected success") - case .failure(let error): - // Then - XCTAssertTrue(storePurchaseManager.mostRecentTransactionCalled) - XCTAssertEqual(error, .noPastTransaction) - } - } - - func testRefreshAuthTokenIfNeededErrorDueToStoreLoginFailure() async throws { - // Given - accountManager.authToken = Constants.oldAuthToken - - authEndpointService.validateTokenResult = .failure(Constants.invalidTokenError) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS - - authEndpointService.storeLoginResult = .failure(.unknownServerError) - - // When - switch await appStoreAccountManagementFlow.refreshAuthTokenIfNeeded() { - case .success: - XCTFail("Unexpected success") - case .failure(let error): - // Then - XCTAssertTrue(storePurchaseManager.mostRecentTransactionCalled) - XCTAssertEqual(error, .authenticatingWithTransactionFailed) - } - } -} diff --git a/Tests/SubscriptionTests/Flows/AppStorePurchaseFlowTests.swift b/Tests/SubscriptionTests/Flows/AppStorePurchaseFlowTests.swift index 94c1fe500..a0a0df21f 100644 --- a/Tests/SubscriptionTests/Flows/AppStorePurchaseFlowTests.swift +++ b/Tests/SubscriptionTests/Flows/AppStorePurchaseFlowTests.swift @@ -18,51 +18,184 @@ import XCTest @testable import Subscription +@testable import Networking import SubscriptionTestingUtilities +import TestUtils + +@available(macOS 12.0, iOS 15.0, *) +final class DefaultAppStorePurchaseFlowTests: XCTestCase { + + private var sut: DefaultAppStorePurchaseFlow! + private var subscriptionManagerMock: SubscriptionManagerMock! + private var storePurchaseManagerMock: StorePurchaseManagerMock! + private var appStoreRestoreFlowMock: AppStoreRestoreFlowMock! + + override func setUp() { + super.setUp() + subscriptionManagerMock = SubscriptionManagerMock() + storePurchaseManagerMock = StorePurchaseManagerMock() + appStoreRestoreFlowMock = AppStoreRestoreFlowMock() + sut = DefaultAppStorePurchaseFlow( + subscriptionManager: subscriptionManagerMock, + storePurchaseManager: storePurchaseManagerMock, + appStoreRestoreFlow: appStoreRestoreFlowMock + ) + } + + override func tearDown() { + sut = nil + subscriptionManagerMock = nil + storePurchaseManagerMock = nil + appStoreRestoreFlowMock = nil + super.tearDown() + } + + // MARK: - purchaseSubscription Tests + + func test_purchaseSubscription_withActiveSubscriptionAlreadyPresent_returnsError() async { + appStoreRestoreFlowMock.restoreAccountFromPastPurchaseResult = .success("someTransactionJWS") + + let result = await sut.purchaseSubscription(with: "testSubscriptionID") + + XCTAssertTrue(appStoreRestoreFlowMock.restoreAccountFromPastPurchaseCalled) + XCTAssertEqual(result, .failure(.activeSubscriptionAlreadyPresent)) + } + + func test_purchaseSubscription_withNoProductsFound_returnsError() async { + appStoreRestoreFlowMock.restoreAccountFromPastPurchaseResult = .failure(AppStoreRestoreFlowError.missingAccountOrTransactions) + + let result = await sut.purchaseSubscription(with: "testSubscriptionID") + + XCTAssertTrue(appStoreRestoreFlowMock.restoreAccountFromPastPurchaseCalled) + switch result { + case .success: + XCTFail("Unexpected success") + case .failure(let error): + switch error { + case AppStorePurchaseFlowError.accountCreationFailed: + break + default: + XCTFail("Unexpected error: \(error)") + } + } + } + + func test_purchaseSubscription_successfulPurchase_returnsTransactionJWS() async { + appStoreRestoreFlowMock.restoreAccountFromPastPurchaseResult = .failure(AppStoreRestoreFlowError.missingAccountOrTransactions) + subscriptionManagerMock.resultCreateAccountTokenContainer = OAuthTokensFactory.makeValidTokenContainerWithEntitlements() + storePurchaseManagerMock.purchaseSubscriptionResult = .success("transactionJWS") + + let result = await sut.purchaseSubscription(with: "testSubscriptionID") + + XCTAssertTrue(storePurchaseManagerMock.purchaseSubscriptionCalled) + XCTAssertEqual(result, .success("transactionJWS")) + } + + func test_purchaseSubscription_purchaseCancelledByUser_returnsCancelledError() async { + appStoreRestoreFlowMock.restoreAccountFromPastPurchaseResult = .failure(AppStoreRestoreFlowError.missingAccountOrTransactions) + storePurchaseManagerMock.purchaseSubscriptionResult = .failure(StorePurchaseManagerError.purchaseCancelledByUser) + subscriptionManagerMock.resultCreateAccountTokenContainer = OAuthTokensFactory.makeValidTokenContainerWithEntitlements() + subscriptionManagerMock.resultSubscription = SubscriptionMockFactory.subscription + + let result = await sut.purchaseSubscription(with: "testSubscriptionID") + + XCTAssertEqual(result, .failure(.cancelledByUser)) + } + + func test_purchaseSubscription_purchaseFailed_returnsPurchaseFailedError() async { + appStoreRestoreFlowMock.restoreAccountFromPastPurchaseResult = .failure(AppStoreRestoreFlowError.missingAccountOrTransactions) + storePurchaseManagerMock.purchaseSubscriptionResult = .failure(StorePurchaseManagerError.purchaseFailed) + subscriptionManagerMock.resultCreateAccountTokenContainer = OAuthTokensFactory.makeValidTokenContainerWithEntitlements() + subscriptionManagerMock.resultSubscription = SubscriptionMockFactory.subscription + + let result = await sut.purchaseSubscription(with: "testSubscriptionID") + + XCTAssertEqual(result, .failure(.purchaseFailed(StorePurchaseManagerError.purchaseFailed))) + } + + // MARK: - completeSubscriptionPurchase Tests + + func test_completeSubscriptionPurchase_withActiveSubscription_returnsSuccess() async { + subscriptionManagerMock.resultTokenContainer = OAuthTokensFactory.makeValidTokenContainerWithEntitlements() + subscriptionManagerMock.resultSubscription = SubscriptionMockFactory.subscription + subscriptionManagerMock.confirmPurchaseResponse = .success(subscriptionManagerMock.resultSubscription!) + + let result = await sut.completeSubscriptionPurchase(with: "transactionJWS") + + XCTAssertEqual(result, .success(.completed)) + } + func test_completeSubscriptionPurchase_withMissingEntitlements_returnsMissingEntitlementsError() async { + subscriptionManagerMock.resultTokenContainer = OAuthTokensFactory.makeValidTokenContainer() + subscriptionManagerMock.resultSubscription = SubscriptionMockFactory.subscription + subscriptionManagerMock.confirmPurchaseResponse = .success(subscriptionManagerMock.resultSubscription!) + + let result = await sut.completeSubscriptionPurchase(with: "transactionJWS") + + XCTAssertEqual(result, .failure(.missingEntitlements)) + } + + func test_completeSubscriptionPurchase_withExpiredSubscription_returnsPurchaseFailedError() async { + subscriptionManagerMock.resultTokenContainer = OAuthTokensFactory.makeValidTokenContainer() + subscriptionManagerMock.resultSubscription = SubscriptionMockFactory.expiredSubscription + subscriptionManagerMock.confirmPurchaseResponse = .success(subscriptionManagerMock.resultSubscription!) + + let result = await sut.completeSubscriptionPurchase(with: "transactionJWS") + + XCTAssertEqual(result, .failure(.purchaseFailed(AppStoreRestoreFlowError.subscriptionExpired))) + } + + func test_completeSubscriptionPurchase_withConfirmPurchaseError_returnsPurchaseFailedError() async { + subscriptionManagerMock.resultSubscription = SubscriptionMockFactory.subscription + subscriptionManagerMock.resultTokenContainer = OAuthTokensFactory.makeValidTokenContainerWithEntitlements() + subscriptionManagerMock.confirmPurchaseResponse = .failure(OAuthServiceError.invalidResponseCode(HTTPStatusCode.badRequest)) + + let result = await sut.completeSubscriptionPurchase(with: "transactionJWS") + switch result { + case .success: + XCTFail("Unexpected success") + case .failure(let error): + switch error { + case .purchaseFailed: + break + default: + XCTFail("Unexpected error: \(error)") + } + } + } +} + +/* final class AppStorePurchaseFlowTests: XCTestCase { private struct Constants { - static let authToken = UUID().uuidString - static let accessToken = UUID().uuidString static let externalID = UUID().uuidString static let email = "dax@duck.com" static let productID = UUID().uuidString static let transactionJWS = "dGhpcyBpcyBub3QgYSByZWFsIEFw(...)cCBTdG9yZSB0cmFuc2FjdGlvbiBKV1M=" - - static let unknownServerError = APIServiceError.serverError(statusCode: 401, error: "unknown_error") } - var accountManager: AccountManagerMock! - var subscriptionService: SubscriptionEndpointServiceMock! - var authService: AuthEndpointServiceMock! - var storePurchaseManager: StorePurchaseManagerMock! - var appStoreRestoreFlow: AppStoreRestoreFlowMock! + var mockSubscriptionManager: SubscriptionManagerMock! + var mockStorePurchaseManager: StorePurchaseManagerMock! + var mockAppStoreRestoreFlow: AppStoreRestoreFlowMock! var appStorePurchaseFlow: AppStorePurchaseFlow! override func setUpWithError() throws { - subscriptionService = SubscriptionEndpointServiceMock() - storePurchaseManager = StorePurchaseManagerMock() - accountManager = AccountManagerMock() - appStoreRestoreFlow = AppStoreRestoreFlowMock() - authService = AuthEndpointServiceMock() - - appStorePurchaseFlow = DefaultAppStorePurchaseFlow(subscriptionEndpointService: subscriptionService, - storePurchaseManager: storePurchaseManager, - accountManager: accountManager, - appStoreRestoreFlow: appStoreRestoreFlow, - authEndpointService: authService) + mockSubscriptionManager = SubscriptionManagerMock() + mockStorePurchaseManager = StorePurchaseManagerMock() + mockAppStoreRestoreFlow = AppStoreRestoreFlowMock() + + appStorePurchaseFlow = DefaultAppStorePurchaseFlow(subscriptionManager: mockSubscriptionManager, + storePurchaseManager: mockStorePurchaseManager, + appStoreRestoreFlow: mockAppStoreRestoreFlow) } override func tearDownWithError() throws { - subscriptionService = nil - storePurchaseManager = nil - accountManager = nil - appStoreRestoreFlow = nil - authService = nil - + mockSubscriptionManager = nil + mockStorePurchaseManager = nil + mockAppStoreRestoreFlow = nil appStorePurchaseFlow = nil } @@ -70,27 +203,27 @@ final class AppStorePurchaseFlowTests: XCTestCase { func testPurchaseSubscriptionSuccess() async throws { // Given - XCTAssertFalse(accountManager.isUserAuthenticated) - appStoreRestoreFlow.restoreAccountFromPastPurchaseResult = .failure(.missingAccountOrTransactions) - authService.createAccountResult = .success(CreateAccountResponse(authToken: Constants.authToken, - externalID: Constants.externalID, - status: "created")) - accountManager.exchangeAuthTokenToAccessTokenResult = .success(Constants.accessToken) - accountManager.fetchAccountDetailsResult = .success((email: "", externalID: Constants.externalID)) - storePurchaseManager.purchaseSubscriptionResult = .success(Constants.transactionJWS) + mockAppStoreRestoreFlow.restoreAccountFromPastPurchaseResult = .failure(.missingAccountOrTransactions) +// authService.createAccountResult = .success(CreateAccountResponse(authToken: Constants.authToken, +// externalID: Constants.externalID, +// status: "created")) +// accountManager.exchangeAuthTokenToAccessTokenResult = .success(Constants.accessToken) +// accountManager.fetchAccountDetailsResult = .success((email: "", externalID: Constants.externalID)) +// storePurchaseManager.purchaseSubscriptionResult = .success(Constants.transactionJWS) // When - switch await appStorePurchaseFlow.purchaseSubscription(with: Constants.productID, emailAccessToken: nil) { + switch await appStorePurchaseFlow.purchaseSubscription(with: Constants.productID) { case .success(let success): - // Then - XCTAssertTrue(appStoreRestoreFlow.restoreAccountFromPastPurchaseCalled) - XCTAssertTrue(authService.createAccountCalled) - XCTAssertTrue(accountManager.exchangeAuthTokenToAccessTokenCalled) - XCTAssertTrue(accountManager.storeAuthTokenCalled) - XCTAssertTrue(accountManager.storeAccountCalled) - XCTAssertTrue(storePurchaseManager.purchaseSubscriptionCalled) - XCTAssertEqual(success, Constants.transactionJWS) +// // Then +// XCTAssertTrue(appStoreRestoreFlow.restoreAccountFromPastPurchaseCalled) +// XCTAssertTrue(authService.createAccountCalled) +// XCTAssertTrue(accountManager.exchangeAuthTokenToAccessTokenCalled) +// XCTAssertTrue(accountManager.storeAuthTokenCalled) +// XCTAssertTrue(accountManager.storeAccountCalled) +// XCTAssertTrue(storePurchaseManager.purchaseSubscriptionCalled) +// XCTAssertEqual(success, Constants.transactionJWS) + break case .failure(let error): XCTFail("Unexpected failure: \(String(reflecting: error))") } @@ -289,4 +422,5 @@ final class AppStorePurchaseFlowTests: XCTestCase { XCTAssertEqual(error, .missingEntitlements) } } -} + } +*/ diff --git a/Tests/SubscriptionTests/Flows/AppStoreRestoreFlowTests.swift b/Tests/SubscriptionTests/Flows/AppStoreRestoreFlowTests.swift index 3d065d1d7..39c2a7f80 100644 --- a/Tests/SubscriptionTests/Flows/AppStoreRestoreFlowTests.swift +++ b/Tests/SubscriptionTests/Flows/AppStoreRestoreFlowTests.swift @@ -18,315 +18,92 @@ import XCTest @testable import Subscription +@testable import Networking import SubscriptionTestingUtilities - -final class AppStoreRestoreFlowTests: XCTestCase { - - private struct Constants { - static let authToken = UUID().uuidString - static let accessToken = UUID().uuidString - static let externalID = UUID().uuidString - static let email = "dax@duck.com" - - static let mostRecentTransactionJWS = "dGhpcyBpcyBub3QgYSByZWFsIEFw(...)cCBTdG9yZSB0cmFuc2FjdGlvbiBKV1M=" - static let storeLoginResponse = StoreLoginResponse(authToken: Constants.authToken, - email: Constants.email, - externalID: Constants.externalID, - id: 1, - status: "authenticated") - - static let unknownServerError = APIServiceError.serverError(statusCode: 401, error: "unknown_error") - } - - var accountManager: AccountManagerMock! - var storePurchaseManager: StorePurchaseManagerMock! - var subscriptionService: SubscriptionEndpointServiceMock! - var authService: AuthEndpointServiceMock! - - var appStoreRestoreFlow: AppStoreRestoreFlow! - - override func setUpWithError() throws { - accountManager = AccountManagerMock() - storePurchaseManager = StorePurchaseManagerMock() - subscriptionService = SubscriptionEndpointServiceMock() - authService = AuthEndpointServiceMock() - - appStoreRestoreFlow = DefaultAppStoreRestoreFlow(accountManager: accountManager, - storePurchaseManager: storePurchaseManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService) +import TestUtils + +@available(macOS 12.0, iOS 15.0, *) +final class DefaultAppStoreRestoreFlowTests: XCTestCase { + + private var sut: DefaultAppStoreRestoreFlow! + private var subscriptionManagerMock: SubscriptionManagerMock! + private var storePurchaseManagerMock: StorePurchaseManagerMock! + + override func setUp() { + super.setUp() + subscriptionManagerMock = SubscriptionManagerMock() + storePurchaseManagerMock = StorePurchaseManagerMock() + sut = DefaultAppStoreRestoreFlow( + subscriptionManager: subscriptionManagerMock, + storePurchaseManager: storePurchaseManagerMock + ) } - override func tearDownWithError() throws { - accountManager = nil - subscriptionService = nil - authService = nil - storePurchaseManager = nil - - appStoreRestoreFlow = nil + override func tearDown() { + sut = nil + subscriptionManagerMock = nil + storePurchaseManagerMock = nil + super.tearDown() } - // MARK: - Tests for restoreAccountFromPastPurchase - - func testRestoreAccountFromPastPurchaseSuccess() async throws { - // Given - XCTAssertFalse(accountManager.isUserAuthenticated) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS - - authService.storeLoginResult = .success(Constants.storeLoginResponse) - - accountManager.exchangeAuthTokenToAccessTokenResult = .success(Constants.accessToken) - - accountManager.fetchAccountDetailsResult = .success(AccountManager.AccountDetails(email: Constants.email, - externalID: Constants.externalID)) - accountManager.onFetchAccountDetails = { accessToken in - XCTAssertEqual(accessToken, Constants.accessToken) - } - - let subscription = SubscriptionMockFactory.subscription - subscriptionService.getSubscriptionResult = .success(subscription) - - XCTAssertTrue(subscription.isActive) - - accountManager.onStoreAuthToken = { authToken in - XCTAssertEqual(authToken, Constants.authToken) - } + // MARK: - restoreAccountFromPastPurchase Tests - accountManager.onStoreAccount = { accessToken, email, externalID in - XCTAssertEqual(accessToken, Constants.accessToken) - XCTAssertEqual(externalID, Constants.externalID) - } + func test_restoreAccountFromPastPurchase_withNoTransaction_returnsMissingAccountOrTransactionsError() async { + storePurchaseManagerMock.mostRecentTransactionResult = nil - // When - switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { - case .success: - // Then - XCTAssertTrue(accountManager.exchangeAuthTokenToAccessTokenCalled) - XCTAssertTrue(accountManager.fetchAccountDetailsCalled) - XCTAssertTrue(accountManager.storeAuthTokenCalled) - XCTAssertTrue(accountManager.storeAccountCalled) + let result = await sut.restoreAccountFromPastPurchase() - XCTAssertTrue(accountManager.isUserAuthenticated) - XCTAssertEqual(accountManager.authToken, Constants.authToken) - XCTAssertEqual(accountManager.accessToken, Constants.accessToken) - XCTAssertEqual(accountManager.externalID, Constants.externalID) - XCTAssertEqual(accountManager.email, Constants.email) + XCTAssertTrue(storePurchaseManagerMock.mostRecentTransactionCalled) + switch result { case .failure(let error): - XCTFail("Unexpected failure: \(error)") - } - } - - func testRestoreAccountFromPastPurchaseErrorDueToSubscriptionBeingExpired() async throws { - // Given - XCTAssertFalse(accountManager.isUserAuthenticated) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS - - authService.storeLoginResult = .success(Constants.storeLoginResponse) - - accountManager.exchangeAuthTokenToAccessTokenResult = .success(Constants.accessToken) - - accountManager.fetchAccountDetailsResult = .success(AccountManager.AccountDetails(email: nil, externalID: Constants.externalID)) - accountManager.onFetchAccountDetails = { accessToken in - XCTAssertEqual(accessToken, Constants.accessToken) - } - - let subscription = SubscriptionMockFactory.expiredSubscription - subscriptionService.getSubscriptionResult = .success(subscription) - - XCTAssertFalse(subscription.isActive) - - accountManager.onStoreAuthToken = { authToken in - XCTAssertEqual(authToken, Constants.authToken) - } - - accountManager.onStoreAccount = { accessToken, email, externalID in - XCTAssertEqual(accessToken, Constants.accessToken) - XCTAssertEqual(externalID, Constants.externalID) - } - - let appStoreRestoreFlow = DefaultAppStoreRestoreFlow(accountManager: accountManager, - storePurchaseManager: storePurchaseManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService) - // When - switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { + XCTAssertEqual(error, .missingAccountOrTransactions) case .success: XCTFail("Unexpected success") - case .failure(let error): - // Then - XCTAssertTrue(accountManager.exchangeAuthTokenToAccessTokenCalled) - XCTAssertTrue(accountManager.fetchAccountDetailsCalled) - XCTAssertFalse(accountManager.storeAuthTokenCalled) - XCTAssertFalse(accountManager.storeAccountCalled) - - guard case .subscriptionExpired(let accountDetails) = error else { - XCTFail("Expected .subscriptionExpired error") - return - } - - XCTAssertEqual(accountDetails.authToken, Constants.authToken) - XCTAssertEqual(accountDetails.accessToken, Constants.accessToken) - XCTAssertEqual(accountDetails.externalID, Constants.externalID) - - XCTAssertFalse(accountManager.isUserAuthenticated) } } - func testRestoreAccountFromPastPurchaseErrorWhenNoRecentTransaction() async throws { - // Given - XCTAssertFalse(accountManager.isUserAuthenticated) + func test_restoreAccountFromPastPurchase_withExpiredSubscription_returnsSubscriptionExpiredError() async { + storePurchaseManagerMock.mostRecentTransactionResult = "lastTransactionJWS" + subscriptionManagerMock.resultSubscription = SubscriptionMockFactory.expiredSubscription - storePurchaseManager.mostRecentTransactionResult = nil + let result = await sut.restoreAccountFromPastPurchase() - let appStoreRestoreFlow = DefaultAppStoreRestoreFlow(accountManager: accountManager, - storePurchaseManager: storePurchaseManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService) - // When - switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { + XCTAssertTrue(storePurchaseManagerMock.mostRecentTransactionCalled) + switch result { + case .failure(let error): + XCTAssertEqual(error, .subscriptionExpired) case .success: XCTFail("Unexpected success") - case .failure(let error): - // Then - XCTAssertFalse(accountManager.exchangeAuthTokenToAccessTokenCalled) - XCTAssertFalse(accountManager.fetchAccountDetailsCalled) - XCTAssertFalse(accountManager.storeAuthTokenCalled) - XCTAssertFalse(accountManager.storeAccountCalled) - XCTAssertEqual(error, .missingAccountOrTransactions) - - XCTAssertFalse(accountManager.isUserAuthenticated) } } - func testRestoreAccountFromPastPurchaseErrorDueToStoreLoginFailure() async throws { - // Given - XCTAssertFalse(accountManager.isUserAuthenticated) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS + func test_restoreAccountFromPastPurchase_withPastTransactionAuthenticationError_returnsAuthenticationError() async { + storePurchaseManagerMock.mostRecentTransactionResult = "lastTransactionJWS" + subscriptionManagerMock.resultSubscription = nil // Triggers an error when calling getSubscriptionFrom() - authService.storeLoginResult = .failure(Constants.unknownServerError) + let result = await sut.restoreAccountFromPastPurchase() - let appStoreRestoreFlow = DefaultAppStoreRestoreFlow(accountManager: accountManager, - storePurchaseManager: storePurchaseManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService) - // When - switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { - case .success: - XCTFail("Unexpected success") + XCTAssertTrue(storePurchaseManagerMock.mostRecentTransactionCalled) + switch result { case .failure(let error): - // Then - XCTAssertFalse(accountManager.exchangeAuthTokenToAccessTokenCalled) - XCTAssertFalse(accountManager.fetchAccountDetailsCalled) - XCTAssertFalse(accountManager.storeAuthTokenCalled) - XCTAssertFalse(accountManager.storeAccountCalled) XCTAssertEqual(error, .pastTransactionAuthenticationError) - - XCTAssertFalse(accountManager.isUserAuthenticated) - } - } - - func testRestoreAccountFromPastPurchaseErrorDueToStoreAuthTokenExchangeFailure() async throws { - // Given - XCTAssertFalse(accountManager.isUserAuthenticated) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS - - authService.storeLoginResult = .success(Constants.storeLoginResponse) - - accountManager.exchangeAuthTokenToAccessTokenResult = .failure(Constants.unknownServerError) - - let appStoreRestoreFlow = DefaultAppStoreRestoreFlow(accountManager: accountManager, - storePurchaseManager: storePurchaseManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService) - // When - switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { case .success: XCTFail("Unexpected success") - case .failure(let error): - // Then - XCTAssertTrue(accountManager.exchangeAuthTokenToAccessTokenCalled) - XCTAssertFalse(accountManager.fetchAccountDetailsCalled) - XCTAssertFalse(accountManager.storeAuthTokenCalled) - XCTAssertFalse(accountManager.storeAccountCalled) - XCTAssertEqual(error, .failedToObtainAccessToken) - - XCTAssertFalse(accountManager.isUserAuthenticated) } } - func testRestoreAccountFromPastPurchaseErrorDueToAccountDetailsFetchFailure() async throws { - // Given - XCTAssertFalse(accountManager.isUserAuthenticated) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS + func test_restoreAccountFromPastPurchase_withActiveSubscription_returnsSuccess() async { + storePurchaseManagerMock.mostRecentTransactionResult = "lastTransactionJWS" + subscriptionManagerMock.resultSubscription = SubscriptionMockFactory.subscription - authService.storeLoginResult = .success(Constants.storeLoginResponse) + let result = await sut.restoreAccountFromPastPurchase() - accountManager.exchangeAuthTokenToAccessTokenResult = .success(Constants.accessToken) - - accountManager.fetchAccountDetailsResult = .failure(Constants.unknownServerError) - accountManager.onFetchAccountDetails = { accessToken in - XCTAssertEqual(accessToken, Constants.accessToken) - } - - let appStoreRestoreFlow = DefaultAppStoreRestoreFlow(accountManager: accountManager, - storePurchaseManager: storePurchaseManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService) - // When - switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { - case .success: - XCTFail("Unexpected success") + XCTAssertTrue(storePurchaseManagerMock.mostRecentTransactionCalled) + switch result { case .failure(let error): - // Then - XCTAssertTrue(accountManager.exchangeAuthTokenToAccessTokenCalled) - XCTAssertTrue(accountManager.fetchAccountDetailsCalled) - XCTAssertFalse(accountManager.storeAuthTokenCalled) - XCTAssertFalse(accountManager.storeAccountCalled) - XCTAssertEqual(error, .failedToFetchAccountDetails) - - XCTAssertFalse(accountManager.isUserAuthenticated) - } - } - - func testRestoreAccountFromPastPurchaseErrorDueToSubscriptionFetchFailure() async throws { - // Given - XCTAssertFalse(accountManager.isUserAuthenticated) - - storePurchaseManager.mostRecentTransactionResult = Constants.mostRecentTransactionJWS - - authService.storeLoginResult = .success(Constants.storeLoginResponse) - - accountManager.exchangeAuthTokenToAccessTokenResult = .success(Constants.accessToken) - - accountManager.fetchAccountDetailsResult = .success(AccountManager.AccountDetails(email: nil, externalID: Constants.externalID)) - accountManager.onFetchAccountDetails = { accessToken in - XCTAssertEqual(accessToken, Constants.accessToken) - } - - subscriptionService.getSubscriptionResult = .failure(.apiError(Constants.unknownServerError)) - - let appStoreRestoreFlow = DefaultAppStoreRestoreFlow(accountManager: accountManager, - storePurchaseManager: storePurchaseManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService) - // When - switch await appStoreRestoreFlow.restoreAccountFromPastPurchase() { + XCTFail("Unexpected error: \(error)") case .success: - XCTFail("Unexpected success") - case .failure(let error): - // Then - XCTAssertTrue(accountManager.exchangeAuthTokenToAccessTokenCalled) - XCTAssertTrue(accountManager.fetchAccountDetailsCalled) - XCTAssertFalse(accountManager.storeAuthTokenCalled) - XCTAssertFalse(accountManager.storeAccountCalled) - XCTAssertEqual(error, .failedToFetchSubscriptionDetails) - - XCTAssertFalse(accountManager.isUserAuthenticated) + break } } } diff --git a/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift b/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift index 8ea41fa66..6daa31025 100644 --- a/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift +++ b/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift @@ -19,6 +19,7 @@ import XCTest @testable import Subscription import SubscriptionTestingUtilities +import Networking final class SubscriptionOptionsTests: XCTestCase { @@ -32,10 +33,10 @@ final class SubscriptionOptionsTests: XCTestCase { SubscriptionOption(id: "2", cost: SubscriptionOptionCost(displayPrice: "99 USD", recurrence: "yearly"), offer: yearlySubscriptionOffer) ], - features: [ - SubscriptionFeature(name: .networkProtection), - SubscriptionFeature(name: .dataBrokerProtection), - SubscriptionFeature(name: .identityTheftRestoration) + availableEntitlements: [ + .networkProtection, + .dataBrokerProtection, + .identityTheftRestoration ]) let jsonEncoder = JSONEncoder() @@ -103,12 +104,12 @@ final class SubscriptionOptionsTests: XCTestCase { } func testSubscriptionFeatureEncoding() throws { - let subscriptionFeature = SubscriptionFeature(name: .identityTheftRestoration) + let subscriptionFeature: SubscriptionEntitlement = .identityTheftRestoration let data = try? JSONEncoder().encode(subscriptionFeature) let subscriptionFeatureString = String(data: data!, encoding: .utf8)! - XCTAssertEqual(subscriptionFeatureString, "{\"name\":\"Identity Theft Restoration\"}") + XCTAssertEqual(subscriptionFeatureString, "\"Identity Theft Restoration\"") } func testEmptySubscriptionOptions() throws { diff --git a/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift b/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift index e397805db..0c9202210 100644 --- a/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift +++ b/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift @@ -15,12 +15,12 @@ // See the License for the specific language governing permissions and // limitations under the License. // +/* + import XCTest + @testable import Subscription + import SubscriptionTestingUtilities -import XCTest -@testable import Subscription -import SubscriptionTestingUtilities - -final class StripePurchaseFlowTests: XCTestCase { + final class StripePurchaseFlowTests: XCTestCase { private struct Constants { static let authToken = UUID().uuidString @@ -252,4 +252,5 @@ final class StripePurchaseFlowTests: XCTestCase { XCTAssertEqual(accountManager.accessToken, Constants.accessToken) XCTAssertEqual(accountManager.externalID, Constants.externalID) } -} + } +*/ diff --git a/Tests/SubscriptionTests/Managers/AccountManagerTests.swift b/Tests/SubscriptionTests/Managers/AccountManagerTests.swift deleted file mode 100644 index 0a04a4cde..000000000 --- a/Tests/SubscriptionTests/Managers/AccountManagerTests.swift +++ /dev/null @@ -1,508 +0,0 @@ -// -// AccountManagerTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import XCTest -@testable import Subscription -import SubscriptionTestingUtilities -import Common - -final class AccountManagerTests: XCTestCase { - - private struct Constants { - static let userDefaultsSuiteName = "AccountManagerTests" - - static let authToken = UUID().uuidString - static let accessToken = UUID().uuidString - static let externalID = UUID().uuidString - - static let email = "dax@duck.com" - - static let entitlements = [Entitlement(product: .dataBrokerProtection), - Entitlement(product: .identityTheftRestoration), - Entitlement(product: .networkProtection)] - - static let keychainError = AccountKeychainAccessError.keychainSaveFailure(1) - static let invalidTokenError = APIServiceError.serverError(statusCode: 401, error: "invalid_token") - static let unknownServerError = APIServiceError.serverError(statusCode: 401, error: "unknown_error") - } - - var userDefaults: UserDefaults! - var accountStorage: AccountKeychainStorageMock! - var accessTokenStorage: SubscriptionTokenKeychainStorageMock! - var entitlementsCache: UserDefaultsCache<[Entitlement]>! - var subscriptionService: SubscriptionEndpointServiceMock! - var authService: AuthEndpointServiceMock! - - var accountManager: AccountManager! - - override func setUpWithError() throws { - userDefaults = UserDefaults(suiteName: Constants.userDefaultsSuiteName)! - userDefaults.removePersistentDomain(forName: Constants.userDefaultsSuiteName) - - accountStorage = AccountKeychainStorageMock() - accessTokenStorage = SubscriptionTokenKeychainStorageMock() - entitlementsCache = UserDefaultsCache<[Entitlement]>(userDefaults: userDefaults, - key: UserDefaultsCacheKey.subscriptionEntitlements, - settings: UserDefaultsCacheSettings(defaultExpirationInterval: .minutes(20))) - subscriptionService = SubscriptionEndpointServiceMock() - authService = AuthEndpointServiceMock() - - accountManager = DefaultAccountManager(storage: accountStorage, - accessTokenStorage: accessTokenStorage, - entitlementsCache: entitlementsCache, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService) - } - - override func tearDownWithError() throws { - accountStorage = nil - accessTokenStorage = nil - entitlementsCache = nil - subscriptionService = nil - authService = nil - - accountManager = nil - } - - // MARK: - Tests for storeAuthToken - - func testStoreAuthToken() throws { - // When - accountManager.storeAuthToken(token: Constants.authToken) - - XCTAssertEqual(accountManager.authToken, Constants.authToken) - XCTAssertEqual(accountStorage.authToken, Constants.authToken) - } - - func testStoreAuthTokenFailure() async throws { - // Given - let delegateCalled = expectation(description: "AccountManagerKeychainAccessDelegate called") - let keychainAccessDelegateMock = AccountManagerKeychainAccessDelegateMock { type, error in - delegateCalled.fulfill() - XCTAssertEqual(type, .storeAuthToken) - XCTAssertEqual(error, Constants.keychainError) - } - - accountStorage.mockedAccessError = Constants.keychainError - accountManager.delegate = keychainAccessDelegateMock - - // When - accountManager.storeAuthToken(token: Constants.authToken) - - // Then - await fulfillment(of: [delegateCalled], timeout: 0.5) - } - - // MARK: - Tests for storeAccount - - func testStoreAccount() async throws { - // Given - - let notificationExpectation = expectation(forNotification: .accountDidSignIn, object: accountManager, handler: nil) - - // When - accountManager.storeAccount(token: Constants.accessToken, email: Constants.email, externalID: Constants.externalID) - - // Then - XCTAssertEqual(accountManager.accessToken, Constants.accessToken) - XCTAssertEqual(accountManager.email, Constants.email) - XCTAssertEqual(accountManager.externalID, Constants.externalID) - - XCTAssertEqual(accessTokenStorage.accessToken, Constants.accessToken) - XCTAssertEqual(accountStorage.email, Constants.email) - XCTAssertEqual(accountStorage.externalID, Constants.externalID) - - await fulfillment(of: [notificationExpectation], timeout: 0.5) - } - - func testStoreAccountUpdatingEmailToNil() throws { - // When - accountManager.storeAccount(token: Constants.accessToken, email: Constants.email, externalID: Constants.externalID) - accountManager.storeAccount(token: Constants.accessToken, email: nil, externalID: Constants.externalID) - - // Then - XCTAssertEqual(accountManager.accessToken, Constants.accessToken) - XCTAssertEqual(accountManager.email, nil) - XCTAssertEqual(accountManager.externalID, Constants.externalID) - - XCTAssertEqual(accessTokenStorage.accessToken, Constants.accessToken) - XCTAssertEqual(accountStorage.email, nil) - XCTAssertEqual(accountStorage.externalID, Constants.externalID) - } - - // MARK: - Tests for signOut - - func testSignOut() async throws { - // Given - accountManager.storeAuthToken(token: Constants.authToken) - accountManager.storeAccount(token: Constants.accessToken, email: Constants.email, externalID: Constants.externalID) - - XCTAssertTrue(accountManager.isUserAuthenticated) - - let notificationExpectation = expectation(forNotification: .accountDidSignOut, object: accountManager, handler: nil) - - // When - accountManager.signOut() - - // Then - XCTAssertFalse(accountManager.isUserAuthenticated) - - XCTAssertTrue(accountStorage.clearAuthenticationStateCalled) - XCTAssertTrue(accessTokenStorage.removeAccessTokenCalled) - XCTAssertTrue(subscriptionService.signOutCalled) - XCTAssertNil(entitlementsCache.get()) - - await fulfillment(of: [notificationExpectation], timeout: 0.5) - } - - func testSignOutWithoutSendingNotification() async throws { - // Given - accountManager.storeAuthToken(token: Constants.authToken) - accountManager.storeAccount(token: Constants.accessToken, email: Constants.email, externalID: Constants.externalID) - - XCTAssertTrue(accountManager.isUserAuthenticated) - - let notificationExpectation = expectation(forNotification: .accountDidSignOut, object: accountManager, handler: nil) - notificationExpectation.isInverted = true - - // When - accountManager.signOut(skipNotification: true) - - // Then - XCTAssertFalse(accountManager.isUserAuthenticated) - await fulfillment(of: [notificationExpectation], timeout: 0.5) - } - - // MARK: - Tests for hasEntitlement - - func testHasEntitlementIgnoringLocalCacheData() async throws { - // Given - let productName = Entitlement.ProductName.networkProtection - - accessTokenStorage.accessToken = Constants.accessToken - entitlementsCache.set([]) - authService.validateTokenResult = .success(ValidateTokenResponse(account: ValidateTokenResponse.Account(email: Constants.email, - entitlements: Constants.entitlements, - externalID: Constants.externalID))) - XCTAssertTrue(Constants.entitlements.compactMap { $0.product }.contains(productName)) - - // When - let result = await accountManager.hasEntitlement(forProductName: productName, cachePolicy: .reloadIgnoringLocalCacheData) - - // Then - switch result { - case .success(let success): - XCTAssertTrue(success) - XCTAssertTrue(authService.validateTokenCalled) - XCTAssertEqual(entitlementsCache.get(), Constants.entitlements) - case .failure: - XCTFail("Unexpected failure") - } - } - - func testHasEntitlementWithoutParameterUseCacheData() async throws { - // Given - let productName = Entitlement.ProductName.networkProtection - - accessTokenStorage.accessToken = Constants.accessToken - entitlementsCache.set(Constants.entitlements) - - XCTAssertTrue(Constants.entitlements.compactMap { $0.product }.contains(productName)) - - // When - let result = await accountManager.hasEntitlement(forProductName: productName) - - // Then - switch result { - case .success(let success): - XCTAssertTrue(success) - XCTAssertFalse(authService.validateTokenCalled) - case .failure: - XCTFail("Unexpected failure") - } - } - - // MARK: - Tests for updateCache - - func testUpdateEntitlementsCache() async throws { - // Given - let updatedEntitlements = [Entitlement(product: .networkProtection)] - XCTAssertNotEqual(Constants.entitlements, updatedEntitlements) - - entitlementsCache.set(Constants.entitlements) - - let notificationExpectation = expectation(forNotification: .entitlementsDidChange, object: accountManager, handler: nil) - - // When - accountManager.updateCache(with: updatedEntitlements) - - // Then - XCTAssertEqual(entitlementsCache.get(), updatedEntitlements) - await fulfillment(of: [notificationExpectation], timeout: 0.5) - } - - func testUpdateEntitlementsCacheWithEmptyArray() async throws { - // Given - entitlementsCache.set(Constants.entitlements) - - let notificationExpectation = expectation(forNotification: .entitlementsDidChange, object: accountManager, handler: nil) - - // When - accountManager.updateCache(with: []) - - // Then - XCTAssertNil(entitlementsCache.get()) - await fulfillment(of: [notificationExpectation], timeout: 0.5) - } - - func testUpdateEntitlementsCacheWithSameEntitlements() async throws { - // Given - entitlementsCache.set(Constants.entitlements) - - let notificationNotFiredExpectation = expectation(forNotification: .entitlementsDidChange, object: accountManager, handler: nil) - notificationNotFiredExpectation.isInverted = true - - // When - accountManager.updateCache(with: Constants.entitlements) - - // Then - XCTAssertEqual(entitlementsCache.get(), Constants.entitlements) - await fulfillment(of: [notificationNotFiredExpectation], timeout: 0.5) - } - - // MARK: - Tests for fetchEntitlements - - func testFetchEntitlementsIgnoringLocalCacheData() async throws { - // Given - accessTokenStorage.accessToken = Constants.accessToken - entitlementsCache.set([]) - authService.validateTokenResult = .success(ValidateTokenResponse(account: ValidateTokenResponse.Account(email: Constants.email, - entitlements: Constants.entitlements, - externalID: Constants.externalID))) - - // When - let result = await accountManager.fetchEntitlements(cachePolicy: .reloadIgnoringLocalCacheData) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success, Constants.entitlements) - XCTAssertTrue(authService.validateTokenCalled) - XCTAssertEqual(entitlementsCache.get(), Constants.entitlements) - case .failure: - XCTFail("Unexpected failure") - } - } - - func testFetchEntitlementsReturnCachedData() async throws { - // Given - accessTokenStorage.accessToken = Constants.accessToken - entitlementsCache.set(Constants.entitlements) - - // When - let result = await accountManager.fetchEntitlements(cachePolicy: .returnCacheDataElseLoad) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success, Constants.entitlements) - XCTAssertFalse(authService.validateTokenCalled) - XCTAssertEqual(entitlementsCache.get(), Constants.entitlements) - case .failure: - XCTFail("Unexpected failure") - } - } - - func testFetchEntitlementsReturnCachedDataWhenCacheIsExpired() async throws { - // Given - let updatedEntitlements = [Entitlement(product: .networkProtection)] - - accessTokenStorage.accessToken = Constants.accessToken - entitlementsCache.set(Constants.entitlements, expires: Date.distantPast) - authService.validateTokenResult = .success(ValidateTokenResponse(account: ValidateTokenResponse.Account(email: Constants.email, - entitlements: updatedEntitlements, - externalID: Constants.externalID))) - - XCTAssertNotEqual(Constants.entitlements, updatedEntitlements) - - // When - let result = await accountManager.fetchEntitlements(cachePolicy: .returnCacheDataElseLoad) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success, updatedEntitlements) - XCTAssertTrue(authService.validateTokenCalled) - XCTAssertEqual(entitlementsCache.get(), updatedEntitlements) - case .failure: - XCTFail("Unexpected failure") - } - } - - func testFetchEntitlementsReturnCacheDataDontLoad() async throws { - // Given - accessTokenStorage.accessToken = Constants.accessToken - entitlementsCache.set(Constants.entitlements) - - // When - let result = await accountManager.fetchEntitlements(cachePolicy: .returnCacheDataDontLoad) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success, Constants.entitlements) - XCTAssertFalse(authService.validateTokenCalled) - XCTAssertEqual(entitlementsCache.get(), Constants.entitlements) - case .failure: - XCTFail("Unexpected failure") - } - } - - func testFetchEntitlementsReturnCacheDataDontLoadWhenCacheIsExpired() async throws { - // Given - accessTokenStorage.accessToken = Constants.accessToken - entitlementsCache.set(Constants.entitlements, expires: Date.distantPast) - - // When - let result = await accountManager.fetchEntitlements(cachePolicy: .returnCacheDataDontLoad) - - // Then - switch result { - case .success: - XCTFail("Unexpected success") - case .failure(let error): - guard let entitlementsError = error as? DefaultAccountManager.EntitlementsError else { - XCTFail("Incorrect error type") - return - } - - XCTAssertEqual(entitlementsError, .noCachedData) - } - } - - // MARK: - Tests for exchangeAuthTokenToAccessToken - - func testExchangeAuthTokenToAccessToken() async throws { - // Given - authService.getAccessTokenResult = .success(.init(accessToken: Constants.accessToken)) - - // When - let result = await accountManager.exchangeAuthTokenToAccessToken(Constants.authToken) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success, Constants.accessToken) - XCTAssertTrue(authService.getAccessTokenCalled) - case .failure: - XCTFail("Unexpected failure") - } - } - - // MARK: - Tests for fetchAccountDetails - - func testFetchAccountDetails() async throws { - // Given - authService.validateTokenResult = .success(ValidateTokenResponse(account: ValidateTokenResponse.Account(email: Constants.email, - entitlements: Constants.entitlements, - externalID: Constants.externalID))) - - // When - let result = await accountManager.fetchAccountDetails(with: Constants.accessToken) - - // Then - switch result { - case .success(let success): - XCTAssertEqual(success.email, Constants.email) - XCTAssertEqual(success.externalID, Constants.externalID) - XCTAssertTrue(authService.validateTokenCalled) - case .failure: - XCTFail("Unexpected failure") - } - } - - // MARK: - Tests for checkForEntitlements - - func testCheckForEntitlementsSuccess() async throws { - // Given - var callCount = 0 - - accessTokenStorage.accessToken = Constants.accessToken - - authService.validateTokenResult = .success(ValidateTokenResponse(account: ValidateTokenResponse.Account(email: Constants.email, - entitlements: Constants.entitlements, - externalID: Constants.externalID))) - authService.onValidateToken = { _ in - callCount += 1 - } - - // When - let result = await accountManager.checkForEntitlements(wait: 0.1, retry: 5) - - // Then - XCTAssertTrue(result) - XCTAssertTrue(authService.validateTokenCalled) - XCTAssertEqual(callCount, 1) - } - - func testCheckForEntitlementsFailure() async throws { - // Given - var callCount = 0 - - accessTokenStorage.accessToken = Constants.accessToken - - authService.validateTokenResult = .failure(Constants.unknownServerError) - authService.onValidateToken = { _ in - callCount += 1 - } - - // When - let result = await accountManager.checkForEntitlements(wait: 0.1, retry: 5) - - // Then - XCTAssertFalse(result) - XCTAssertTrue(authService.validateTokenCalled) - XCTAssertEqual(callCount, 5) - } - - func testCheckForEntitlementsSuccessAfterRetries() async throws { - // Given - var callCount = 0 - - accessTokenStorage.accessToken = Constants.accessToken - - authService.validateTokenResult = .failure(Constants.unknownServerError) - authService.onValidateToken = { _ in - callCount += 1 - - if callCount == 3 { - self.authService.validateTokenResult = .success(ValidateTokenResponse(account: ValidateTokenResponse.Account(email: Constants.email, - entitlements: Constants.entitlements, - externalID: Constants.externalID))) - } - } - - // When - let result = await accountManager.checkForEntitlements(wait: 0.1, retry: 5) - - // Then - XCTAssertTrue(result) - XCTAssertTrue(authService.validateTokenCalled) - XCTAssertEqual(callCount, 3) - } -} diff --git a/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift b/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift index 26ce6d89c..56dba0324 100644 --- a/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift +++ b/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift @@ -18,184 +18,198 @@ import XCTest @testable import Subscription +@testable import Networking import SubscriptionTestingUtilities +import TestUtils -final class SubscriptionManagerTests: XCTestCase { +class SubscriptionManagerTests: XCTestCase { - private struct Constants { - static let userDefaultsSuiteName = "SubscriptionManagerTests" - - static let accessToken = UUID().uuidString - - static let invalidTokenError = APIServiceError.serverError(statusCode: 401, error: "invalid_token") - } - - var storePurchaseManager: StorePurchaseManagerMock! - var accountManager: AccountManagerMock! - var subscriptionService: SubscriptionEndpointServiceMock! - var authService: AuthEndpointServiceMock! - var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCacheMock! - var subscriptionEnvironment: SubscriptionEnvironment! + var subscriptionManager: DefaultSubscriptionManager! + var mockOAuthClient: MockOAuthClient! + var mockSubscriptionEndpointService: SubscriptionEndpointServiceMock! + var mockStorePurchaseManager: StorePurchaseManagerMock! var subscriptionFeatureFlagger: FeatureFlaggerMapping! - var subscriptionManager: SubscriptionManager! + override func setUp() { + super.setUp() - override func setUpWithError() throws { - storePurchaseManager = StorePurchaseManagerMock() - accountManager = AccountManagerMock() - subscriptionService = SubscriptionEndpointServiceMock() - authService = AuthEndpointServiceMock() - subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock() - subscriptionEnvironment = SubscriptionEnvironment(serviceEnvironment: .production, - purchasePlatform: .appStore) + mockOAuthClient = MockOAuthClient() + mockSubscriptionEndpointService = SubscriptionEndpointServiceMock() + mockStorePurchaseManager = StorePurchaseManagerMock() subscriptionFeatureFlagger = FeatureFlaggerMapping(mapping: { $0.defaultState }) - subscriptionManager = DefaultSubscriptionManager(storePurchaseManager: storePurchaseManager, - accountManager: accountManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService, - subscriptionFeatureMappingCache: subscriptionFeatureMappingCache, - subscriptionEnvironment: subscriptionEnvironment, - subscriptionFeatureFlagger: subscriptionFeatureFlagger) - + subscriptionManager = DefaultSubscriptionManager( + storePurchaseManager: mockStorePurchaseManager, + oAuthClient: mockOAuthClient, + subscriptionEndpointService: mockSubscriptionEndpointService, + subscriptionEnvironment: SubscriptionEnvironment(serviceEnvironment: .staging, purchasePlatform: .stripe), + subscriptionFeatureFlagger: subscriptionFeatureFlagger, + pixelHandler: { _ in } + ) } - override func tearDownWithError() throws { - storePurchaseManager = nil - accountManager = nil - subscriptionService = nil - authService = nil - subscriptionEnvironment = nil - + override func tearDown() { subscriptionManager = nil + mockOAuthClient = nil + mockSubscriptionEndpointService = nil + mockStorePurchaseManager = nil + super.tearDown() } - // MARK: - Tests for save and loadEnvironmentFrom - - func testLoadEnvironmentFromUserDefaults() async throws { - // Given - let userDefaults = UserDefaults(suiteName: Constants.userDefaultsSuiteName)! - userDefaults.removePersistentDomain(forName: Constants.userDefaultsSuiteName) + // MARK: - Token Retrieval Tests - var loadedEnvironment = DefaultSubscriptionManager.loadEnvironmentFrom(userDefaults: userDefaults) - XCTAssertNil(loadedEnvironment) + func testGetTokenContainer_Success() async throws { + let expectedTokenContainer = OAuthTokensFactory.makeValidTokenContainer() + mockOAuthClient.getTokensResponse = .success(expectedTokenContainer) - // When - DefaultSubscriptionManager.save(subscriptionEnvironment: subscriptionEnvironment, - userDefaults: userDefaults) - loadedEnvironment = DefaultSubscriptionManager.loadEnvironmentFrom(userDefaults: userDefaults) - - // Then - XCTAssertEqual(loadedEnvironment?.serviceEnvironment, subscriptionEnvironment.serviceEnvironment) - XCTAssertEqual(loadedEnvironment?.purchasePlatform, subscriptionEnvironment.purchasePlatform) + let result = try await subscriptionManager.getTokenContainer(policy: .localValid) + XCTAssertEqual(result, expectedTokenContainer) } - // MARK: - Tests for setup for App Store - - func testSetupForAppStore() async throws { - // Given - storePurchaseManager.onUpdateAvailableProducts = { - self.storePurchaseManager.areProductsAvailable = true + func testGetTokenContainer_ErrorHandlingDeadToken() async throws { + // Set up dead token error to trigger recovery attempt + mockOAuthClient.getTokensResponse = .failure(OAuthClientError.deadToken) + let date = Date() + let expiredSubscription = PrivacyProSubscription( + productId: "testProduct", + name: "Test Subscription", + billingPeriod: .monthly, + startedAt: date.addingTimeInterval(-30 * 24 * 60 * 60), // 30 days ago + expiresOrRenewsAt: date.addingTimeInterval(-1), // expired + platform: .apple, + status: .expired + ) + mockSubscriptionEndpointService.getSubscriptionResult = .success(expiredSubscription) + let expectation = self.expectation(description: "Dead token pixel called") + subscriptionManager = DefaultSubscriptionManager( + storePurchaseManager: mockStorePurchaseManager, + oAuthClient: mockOAuthClient, + subscriptionEndpointService: mockSubscriptionEndpointService, + subscriptionEnvironment: SubscriptionEnvironment(serviceEnvironment: .staging, purchasePlatform: .stripe), + subscriptionFeatureFlagger: subscriptionFeatureFlagger, + pixelHandler: { type in + XCTAssertEqual(type, .deadToken) + expectation.fulfill() + } + ) + + do { + _ = try await subscriptionManager.getTokenContainer(policy: .localValid) + XCTFail("Error expected") + } catch SubscriptionManagerError.tokenUnavailable { + // Expected error + } catch { + XCTFail("Unexpected error: \(error)") } - // When - // triggered on DefaultSubscriptionManager's init - try await Task.sleep(seconds: 0.5) - - // Then - XCTAssertTrue(storePurchaseManager.updateAvailableProductsCalled) - XCTAssertTrue(subscriptionManager.canPurchase) + await fulfillment(of: [expectation], timeout: 0.1) } - // MARK: - Tests for loadInitialData - - func testLoadInitialData() async throws { - // Given - accountManager.accessToken = Constants.accessToken - - subscriptionService.onGetSubscription = { _, cachePolicy in - XCTAssertEqual(cachePolicy, .reloadIgnoringLocalCacheData) + // MARK: - Subscription Status Tests + + func testRefreshCachedSubscription_ActiveSubscription() { + let expectation = self.expectation(description: "Active subscription callback") + let activeSubscription = PrivacyProSubscription( + productId: "testProduct", + name: "Test Subscription", + billingPeriod: .monthly, + startedAt: Date(), + expiresOrRenewsAt: Date().addingTimeInterval(30 * 24 * 60 * 60), // 30 days from now + platform: .stripe, + status: .autoRenewable + ) + mockSubscriptionEndpointService.getSubscriptionResult = .success(activeSubscription) + mockOAuthClient.getTokensResponse = .success(OAuthTokensFactory.makeValidTokenContainer()) + subscriptionManager.refreshCachedSubscription { isActive in + XCTAssertTrue(isActive) + expectation.fulfill() } - subscriptionService.getSubscriptionResult = .success(SubscriptionMockFactory.subscription) + wait(for: [expectation], timeout: 0.1) + } - accountManager.onFetchEntitlements = { cachePolicy in - XCTAssertEqual(cachePolicy, .reloadIgnoringLocalCacheData) + func testRefreshCachedSubscription_ExpiredSubscription() { + let expectation = self.expectation(description: "Expired subscription callback") + let expiredSubscription = PrivacyProSubscription( + productId: "testProduct", + name: "Test Subscription", + billingPeriod: .monthly, + startedAt: Date().addingTimeInterval(-30 * 24 * 60 * 60), // 30 days ago + expiresOrRenewsAt: Date().addingTimeInterval(-1), // expired + platform: .apple, + status: .expired + ) + mockSubscriptionEndpointService.getSubscriptionResult = .success(expiredSubscription) + + subscriptionManager.refreshCachedSubscription { isActive in + XCTAssertFalse(isActive) + expectation.fulfill() } - - // When - subscriptionManager.loadInitialData() - - try await Task.sleep(seconds: 0.5) - - // Then - XCTAssertTrue(subscriptionService.getSubscriptionCalled) - XCTAssertTrue(accountManager.fetchEntitlementsCalled) + wait(for: [expectation], timeout: 0.1) } - func testLoadInitialDataNotCalledWhenUnauthenticated() async throws { - // Given - XCTAssertNil(accountManager.accessToken) - XCTAssertFalse(accountManager.isUserAuthenticated) + // MARK: - URL Generation Tests - // When - subscriptionManager.loadInitialData() + func testURLGeneration_ForCustomerPortal() async throws { + mockOAuthClient.getTokensResponse = .success(OAuthTokensFactory.makeValidTokenContainer()) + let customerPortalURLString = "https://example.com/customer-portal" + mockSubscriptionEndpointService.getCustomerPortalURLResult = .success(GetCustomerPortalURLResponse(customerPortalUrl: customerPortalURLString)) - // Then - XCTAssertFalse(subscriptionService.getSubscriptionCalled) - XCTAssertFalse(accountManager.fetchEntitlementsCalled) + let url = try await subscriptionManager.getCustomerPortalURL() + XCTAssertEqual(url.absoluteString, customerPortalURLString) } - // MARK: - Tests for refreshCachedSubscriptionAndEntitlements - - func testForRefreshCachedSubscriptionAndEntitlements() async throws { - // Given - let subscription = SubscriptionMockFactory.subscription - - accountManager.accessToken = Constants.accessToken - - subscriptionService.onGetSubscription = { _, cachePolicy in - XCTAssertEqual(cachePolicy, .reloadIgnoringLocalCacheData) - } - subscriptionService.getSubscriptionResult = .success(subscription) + func testURLGeneration_ForSubscriptionTypes() { + let environment = SubscriptionEnvironment(serviceEnvironment: .production, purchasePlatform: .appStore) + subscriptionManager = DefaultSubscriptionManager( + storePurchaseManager: mockStorePurchaseManager, + oAuthClient: mockOAuthClient, + subscriptionEndpointService: mockSubscriptionEndpointService, + subscriptionEnvironment: environment, + subscriptionFeatureFlagger: subscriptionFeatureFlagger, + pixelHandler: { _ in } + ) + + let helpURL = subscriptionManager.url(for: .purchase) + XCTAssertEqual(helpURL.absoluteString, "https://duckduckgo.com/subscriptions/welcome") + } - accountManager.onFetchEntitlements = { cachePolicy in - XCTAssertEqual(cachePolicy, .reloadIgnoringLocalCacheData) + // MARK: - Purchase Confirmation Tests + + func testConfirmPurchase_ErrorHandling() async throws { + let testSignature = "invalidSignature" + mockSubscriptionEndpointService.confirmPurchaseResult = .failure(APIRequestV2.Error.invalidResponse) + mockOAuthClient.getTokensResponse = .success(OAuthTokensFactory.makeValidTokenContainer()) + do { + _ = try await subscriptionManager.confirmPurchase(signature: testSignature) + XCTFail("Error expected") + } catch { + XCTAssertEqual(error as? APIRequestV2.Error, APIRequestV2.Error.invalidResponse) } + } - // When - let completionCalled = expectation(description: "completion called") - subscriptionManager.refreshCachedSubscriptionAndEntitlements { isSubscriptionActive in - completionCalled.fulfill() - XCTAssertEqual(isSubscriptionActive, subscription.isActive) - } + // MARK: - Tests for save and loadEnvironmentFrom - // Then - await fulfillment(of: [completionCalled], timeout: 0.5) - XCTAssertTrue(subscriptionService.getSubscriptionCalled) - XCTAssertTrue(accountManager.fetchEntitlementsCalled) - } + var subscriptionEnvironment: SubscriptionEnvironment! - func testForRefreshCachedSubscriptionAndEntitlementsSignOutUserOn401() async throws { + func testLoadEnvironmentFromUserDefaults() async throws { + subscriptionEnvironment = SubscriptionEnvironment(serviceEnvironment: .production, + purchasePlatform: .appStore) + let userDefaultsSuiteName = "SubscriptionManagerTests" // Given - accountManager.accessToken = Constants.accessToken + let userDefaults = UserDefaults(suiteName: userDefaultsSuiteName)! + userDefaults.removePersistentDomain(forName: userDefaultsSuiteName) - subscriptionService.onGetSubscription = { _, cachePolicy in - XCTAssertEqual(cachePolicy, .reloadIgnoringLocalCacheData) - } - subscriptionService.getSubscriptionResult = .failure(.apiError(Constants.invalidTokenError)) + var loadedEnvironment = DefaultSubscriptionManager.loadEnvironmentFrom(userDefaults: userDefaults) + XCTAssertNil(loadedEnvironment) // When - let completionCalled = expectation(description: "completion called") - subscriptionManager.refreshCachedSubscriptionAndEntitlements { isSubscriptionActive in - completionCalled.fulfill() - XCTAssertFalse(isSubscriptionActive) - } + DefaultSubscriptionManager.save(subscriptionEnvironment: subscriptionEnvironment, + userDefaults: userDefaults) + loadedEnvironment = DefaultSubscriptionManager.loadEnvironmentFrom(userDefaults: userDefaults) // Then - await fulfillment(of: [completionCalled], timeout: 0.5) - XCTAssertTrue(accountManager.signOutCalled) - XCTAssertTrue(subscriptionService.getSubscriptionCalled) - XCTAssertFalse(accountManager.fetchEntitlementsCalled) + XCTAssertEqual(loadedEnvironment?.serviceEnvironment, subscriptionEnvironment.serviceEnvironment) + XCTAssertEqual(loadedEnvironment?.purchasePlatform, subscriptionEnvironment.purchasePlatform) } // MARK: - Tests for url @@ -204,13 +218,14 @@ final class SubscriptionManagerTests: XCTestCase { // Given let productionEnvironment = SubscriptionEnvironment(serviceEnvironment: .production, purchasePlatform: .appStore) - let productionSubscriptionManager = DefaultSubscriptionManager(storePurchaseManager: storePurchaseManager, - accountManager: accountManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService, - subscriptionFeatureMappingCache: subscriptionFeatureMappingCache, - subscriptionEnvironment: productionEnvironment, - subscriptionFeatureFlagger: subscriptionFeatureFlagger) + let productionSubscriptionManager = DefaultSubscriptionManager( + storePurchaseManager: mockStorePurchaseManager, + oAuthClient: mockOAuthClient, + subscriptionEndpointService: mockSubscriptionEndpointService, + subscriptionEnvironment: productionEnvironment, + subscriptionFeatureFlagger: subscriptionFeatureFlagger, + pixelHandler: { _ in } + ) // When let productionPurchaseURL = productionSubscriptionManager.url(for: .purchase) @@ -223,13 +238,14 @@ final class SubscriptionManagerTests: XCTestCase { // Given let stagingEnvironment = SubscriptionEnvironment(serviceEnvironment: .staging, purchasePlatform: .appStore) - let stagingSubscriptionManager = DefaultSubscriptionManager(storePurchaseManager: storePurchaseManager, - accountManager: accountManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService, - subscriptionFeatureMappingCache: subscriptionFeatureMappingCache, - subscriptionEnvironment: stagingEnvironment, - subscriptionFeatureFlagger: subscriptionFeatureFlagger) + let stagingSubscriptionManager = DefaultSubscriptionManager( + storePurchaseManager: mockStorePurchaseManager, + oAuthClient: mockOAuthClient, + subscriptionEndpointService: mockSubscriptionEndpointService, + subscriptionEnvironment: stagingEnvironment, + subscriptionFeatureFlagger: subscriptionFeatureFlagger, + pixelHandler: { _ in } + ) // When let stagingPurchaseURL = stagingSubscriptionManager.url(for: .purchase) diff --git a/Tests/SubscriptionTests/PrivacyProSubscriptionIntegrationTests.swift b/Tests/SubscriptionTests/PrivacyProSubscriptionIntegrationTests.swift new file mode 100644 index 000000000..33f139e82 --- /dev/null +++ b/Tests/SubscriptionTests/PrivacyProSubscriptionIntegrationTests.swift @@ -0,0 +1,147 @@ +// +// PrivacyProSubscriptionIntegrationTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +@testable import Subscription +@testable import Networking +import TestUtils +import SubscriptionTestingUtilities + +final class PrivacyProSubscriptionIntegrationTests: XCTestCase { + + var apiService: MockAPIService! + var tokenStorage: MockTokenStorage! + var legacyAccountStorage: MockLegacyTokenStorage! + var subscriptionManager: DefaultSubscriptionManager! + var appStorePurchaseFlow: DefaultAppStorePurchaseFlow! + var appStoreRestoreFlow: DefaultAppStoreRestoreFlow! + var stripePurchaseFlow: DefaultStripePurchaseFlow! + var storePurchaseManager: StorePurchaseManagerMock! + var subscriptionFeatureFlagger: FeatureFlaggerMapping! + + let subscriptionSelectionID = "ios.subscription.1month" + + override func setUpWithError() throws { + + let subscriptionEnvironment = SubscriptionEnvironment(serviceEnvironment: .staging, purchasePlatform: .appStore) + apiService = MockAPIService() + let authService = DefaultOAuthService(baseURL: OAuthEnvironment.staging.url, apiService: apiService) + + // keychain storage + tokenStorage = MockTokenStorage() + legacyAccountStorage = MockLegacyTokenStorage() + + let authClient = DefaultOAuthClient(tokensStorage: tokenStorage, + legacyTokenStorage: legacyAccountStorage, + authService: authService) + apiService.authorizationRefresherCallback = { _ in + return OAuthTokensFactory.makeValidTokenContainer().accessToken + } + storePurchaseManager = StorePurchaseManagerMock() + let subscriptionEndpointService = DefaultSubscriptionEndpointService(apiService: apiService, + baseURL: subscriptionEnvironment.serviceEnvironment.url) + let pixelHandler: SubscriptionManager.PixelHandler = { type in + print("Pixel fired: \(type)") + } + subscriptionFeatureFlagger = FeatureFlaggerMapping(mapping: { $0.defaultState }) + + subscriptionManager = DefaultSubscriptionManager(storePurchaseManager: storePurchaseManager, + oAuthClient: authClient, + subscriptionEndpointService: subscriptionEndpointService, + subscriptionEnvironment: subscriptionEnvironment, + subscriptionFeatureFlagger: subscriptionFeatureFlagger, + pixelHandler: pixelHandler) + + appStoreRestoreFlow = DefaultAppStoreRestoreFlow(subscriptionManager: subscriptionManager, + storePurchaseManager: storePurchaseManager) + appStorePurchaseFlow = DefaultAppStorePurchaseFlow(subscriptionManager: subscriptionManager, + storePurchaseManager: storePurchaseManager, + appStoreRestoreFlow: appStoreRestoreFlow) + stripePurchaseFlow = DefaultStripePurchaseFlow(subscriptionManager: subscriptionManager) + } + + override func tearDownWithError() throws { + apiService = nil + tokenStorage = nil + legacyAccountStorage = nil + subscriptionManager = nil + appStorePurchaseFlow = nil + appStoreRestoreFlow = nil + stripePurchaseFlow = nil + } + + // MARK: - Apple store + + func testAppStorePurchaseSuccess() async throws { + + // configure mock API responses + APIMockResponseFactory.mockAuthoriseResponse(destinationMockAPIService: apiService, success: true) + APIMockResponseFactory.mockCreateAccountResponse(destinationMockAPIService: apiService, success: true) + APIMockResponseFactory.mockGetAccessTokenResponse(destinationMockAPIService: apiService, success: true) + APIMockResponseFactory.mockGetJWKS(destinationMockAPIService: apiService, success: true) + APIMockResponseFactory.mockConfirmPurchase(destinationMockAPIService: apiService, success: true) + APIMockResponseFactory.mockGetProducts(destinationMockAPIService: apiService, success: true) + APIMockResponseFactory.mockGetFeatures(destinationMockAPIService: apiService, success: true, subscriptionID: "ios.subscription.1month") + + (subscriptionManager.oAuthClient as! DefaultOAuthClient).testingDecodedTokenContainer = OAuthTokensFactory.makeValidTokenContainerWithEntitlements() + + // configure mock store purchase manager responses + storePurchaseManager.purchaseSubscriptionResult = .success("purchaseTransactionJWS") + + // Buy subscription + + var purchaseTransactionJWS: String? + switch await appStorePurchaseFlow.purchaseSubscription(with: subscriptionSelectionID) { + case .success(let transactionJWS): + purchaseTransactionJWS = transactionJWS + case .failure(let error): + XCTFail("Purchase failed with error: \(error)") + } + XCTAssertNotNil(purchaseTransactionJWS) + + switch await appStorePurchaseFlow.completeSubscriptionPurchase(with: purchaseTransactionJWS!) { + case .success: + break + case .failure(let error): + XCTFail("Purchase failed with error: \(error)") + } + } + + // MARK: - Stripe + + func testStripePurchaseSuccess() async throws { + + // configure mock API responses + APIMockResponseFactory.mockAuthoriseResponse(destinationMockAPIService: apiService, success: true) + APIMockResponseFactory.mockCreateAccountResponse(destinationMockAPIService: apiService, success: true) + APIMockResponseFactory.mockGetAccessTokenResponse(destinationMockAPIService: apiService, success: true) + + (subscriptionManager.oAuthClient as! DefaultOAuthClient).testingDecodedTokenContainer = OAuthTokensFactory.makeValidTokenContainerWithEntitlements() + + // Buy subscription + let email = "test@duck.com" + let result = await stripePurchaseFlow.prepareSubscriptionPurchase(emailAccessToken: email) + switch result { + case .success(let success): + XCTAssertNotNil(success.type) + XCTAssertNotNil(success.token) + case .failure(let error): + XCTFail("Purchase failed with error: \(error)") + } + } +} diff --git a/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift b/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift index 2a6a9d3d8..05f1a29e1 100644 --- a/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift +++ b/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift @@ -20,41 +20,19 @@ import XCTest import Common @testable import Subscription import SubscriptionTestingUtilities +import TestUtils final class SubscriptionCookieManagerTests: XCTestCase { - - private struct Constants { - static let authToken = UUID().uuidString - static let accessToken = UUID().uuidString - } - - var accountManager: AccountManagerMock! - var subscriptionService: SubscriptionEndpointServiceMock! - var authService: AuthEndpointServiceMock! - var storePurchaseManager: StorePurchaseManagerMock! - var subscriptionEnvironment: SubscriptionEnvironment! - var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCacheMock! +// var subscriptionService: SubscriptionEndpointServiceMock! +// var storePurchaseManager: StorePurchaseManagerMock! +// var subscriptionEnvironment: SubscriptionEnvironment! var subscriptionManager: SubscriptionManagerMock! var cookieStore: HTTPCookieStore! var subscriptionCookieManager: SubscriptionCookieManager! override func setUp() async throws { - accountManager = AccountManagerMock() - subscriptionService = SubscriptionEndpointServiceMock() - authService = AuthEndpointServiceMock() - storePurchaseManager = StorePurchaseManagerMock() - subscriptionEnvironment = SubscriptionEnvironment(serviceEnvironment: .production, - purchasePlatform: .appStore) - subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock() - - subscriptionManager = SubscriptionManagerMock(accountManager: accountManager, - subscriptionEndpointService: subscriptionService, - authEndpointService: authService, - storePurchaseManager: storePurchaseManager, - currentEnvironment: subscriptionEnvironment, - canPurchase: true, - subscriptionFeatureMappingCache: subscriptionFeatureMappingCache) + subscriptionManager = SubscriptionManagerMock() cookieStore = MockHTTPCookieStore() subscriptionCookieManager = SubscriptionCookieManager(subscriptionManager: subscriptionManager, @@ -64,27 +42,22 @@ final class SubscriptionCookieManagerTests: XCTestCase { } override func tearDown() async throws { - accountManager = nil - subscriptionService = nil - authService = nil - storePurchaseManager = nil - subscriptionEnvironment = nil - subscriptionManager = nil + subscriptionCookieManager = nil } func testSubscriptionCookieIsAddedWhenSigningInToSubscription() async throws { // Given await ensureNoSubscriptionCookieInTheCookieStore() - accountManager.accessToken = Constants.accessToken + subscriptionManager.resultTokenContainer = OAuthTokensFactory.makeValidTokenContainer() // When subscriptionCookieManager.enableSettingSubscriptionCookie() NotificationCenter.default.post(name: .accountDidSignIn, object: self, userInfo: nil) - try await Task.sleep(seconds: 0.1) + try await Task.sleep(interval: 0.1) // Then - await checkSubscriptionCookieIsPresent() + await checkSubscriptionCookieIsPresent(token: subscriptionManager.resultTokenContainer!.accessToken) } func testSubscriptionCookieIsDeletedWhenSigningInToSubscription() async throws { @@ -94,7 +67,7 @@ final class SubscriptionCookieManagerTests: XCTestCase { // When subscriptionCookieManager.enableSettingSubscriptionCookie() NotificationCenter.default.post(name: .accountDidSignOut, object: self, userInfo: nil) - try await Task.sleep(seconds: 0.1) + try await Task.sleep(interval: 0.1) // Then await checkSubscriptionCookieIsHasEmptyValue() @@ -102,27 +75,27 @@ final class SubscriptionCookieManagerTests: XCTestCase { func testRefreshWhenSignedInButCookieIsMissing() async throws { // Given - accountManager.accessToken = Constants.accessToken + subscriptionManager.resultTokenContainer = OAuthTokensFactory.makeValidTokenContainer() await ensureNoSubscriptionCookieInTheCookieStore() // When subscriptionCookieManager.enableSettingSubscriptionCookie() await subscriptionCookieManager.refreshSubscriptionCookie() - try await Task.sleep(seconds: 0.1) + try await Task.sleep(interval: 0.1) // Then - await checkSubscriptionCookieIsPresent() + await checkSubscriptionCookieIsPresent(token: subscriptionManager.resultTokenContainer!.accessToken) } func testRefreshWhenSignedOutButCookieIsPresent() async throws { // Given - accountManager.accessToken = nil + subscriptionManager.resultTokenContainer = nil await ensureSubscriptionCookieIsInTheCookieStore() // When subscriptionCookieManager.enableSettingSubscriptionCookie() await subscriptionCookieManager.refreshSubscriptionCookie() - try await Task.sleep(seconds: 0.1) + try await Task.sleep(interval: 0.1) // Then await checkSubscriptionCookieIsHasEmptyValue() @@ -138,7 +111,7 @@ final class SubscriptionCookieManagerTests: XCTestCase { await subscriptionCookieManager.refreshSubscriptionCookie() firstRefreshDate = subscriptionCookieManager.lastRefreshDate - try await Task.sleep(seconds: 0.5) + try await Task.sleep(interval: 0.5) await subscriptionCookieManager.refreshSubscriptionCookie() secondRefreshDate = subscriptionCookieManager.lastRefreshDate @@ -157,7 +130,7 @@ final class SubscriptionCookieManagerTests: XCTestCase { await subscriptionCookieManager.refreshSubscriptionCookie() firstRefreshDate = subscriptionCookieManager.lastRefreshDate - try await Task.sleep(seconds: 1.1) + try await Task.sleep(interval: 1.1) await subscriptionCookieManager.refreshSubscriptionCookie() secondRefreshDate = subscriptionCookieManager.lastRefreshDate @@ -167,12 +140,13 @@ final class SubscriptionCookieManagerTests: XCTestCase { } private func ensureSubscriptionCookieIsInTheCookieStore() async { + let validTokenContainer = OAuthTokensFactory.makeValidTokenContainer() let subscriptionCookie = HTTPCookie(properties: [ .domain: SubscriptionCookieManager.cookieDomain, .path: "/", .expires: Date().addingTimeInterval(.days(365)), .name: SubscriptionCookieManager.cookieName, - .value: Constants.accessToken, + .value: validTokenContainer.accessToken, .secure: true, .init(rawValue: "HttpOnly"): true ])! @@ -187,12 +161,12 @@ final class SubscriptionCookieManagerTests: XCTestCase { XCTAssertTrue(cookieStoreCookies.isEmpty) } - private func checkSubscriptionCookieIsPresent() async { + private func checkSubscriptionCookieIsPresent(token: String) async { guard let subscriptionCookie = await cookieStore.fetchSubscriptionCookie() else { XCTFail("No subscription cookie in the store") return } - XCTAssertEqual(subscriptionCookie.value, Constants.accessToken) + XCTAssertEqual(subscriptionCookie.value, token) } private func checkSubscriptionCookieIsHasEmptyValue() async {