-
Notifications
You must be signed in to change notification settings - Fork 368
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #34 from MacPaw/combine
Add combine extensions
- Loading branch information
Showing
7 changed files
with
240 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// | ||
// File.swift | ||
// | ||
// | ||
// Created by Sergii Kryvoblotskyi on 03/04/2023. | ||
// | ||
|
||
#if canImport(Combine) | ||
|
||
import Combine | ||
|
||
@available(iOS 13.0, *) | ||
@available(tvOS 13.0, *) | ||
@available(macOS 10.15, *) | ||
@available(watchOS 6.0, *) | ||
public extension OpenAIProtocol { | ||
|
||
func completions(query: CompletionsQuery) -> AnyPublisher<CompletionsResult, Error> { | ||
Future<CompletionsResult, Error> { | ||
completions(query: query, completion: $0) | ||
} | ||
.eraseToAnyPublisher() | ||
} | ||
|
||
func images(query: ImagesQuery) -> AnyPublisher<ImagesResult, Error> { | ||
Future<ImagesResult, Error> { | ||
images(query: query, completion: $0) | ||
} | ||
.eraseToAnyPublisher() | ||
} | ||
|
||
func embeddings(query: EmbeddingsQuery) -> AnyPublisher<EmbeddingsResult, Error> { | ||
Future<EmbeddingsResult, Error> { | ||
embeddings(query: query, completion: $0) | ||
} | ||
.eraseToAnyPublisher() | ||
} | ||
|
||
func chats(query: ChatQuery) -> AnyPublisher<ChatResult, Error> { | ||
Future<ChatResult, Error> { | ||
chats(query: query, completion: $0) | ||
} | ||
.eraseToAnyPublisher() | ||
} | ||
|
||
func audioTranscriptions(query: AudioTranscriptionQuery) -> AnyPublisher<AudioTranscriptionResult, Error> { | ||
Future<AudioTranscriptionResult, Error> { | ||
audioTranscriptions(query: query, completion: $0) | ||
} | ||
.eraseToAnyPublisher() | ||
} | ||
|
||
func audioTranslations(query: AudioTranslationQuery) -> AnyPublisher<AudioTranslationResult, Error> { | ||
Future<AudioTranslationResult, Error> { | ||
audioTranslations(query: query, completion: $0) | ||
} | ||
.eraseToAnyPublisher() | ||
} | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// | ||
// XCTestCase+Extensions.swift | ||
// | ||
// | ||
// Created by Sergii Kryvoblotskyi on 04/04/2023. | ||
// | ||
|
||
#if canImport(Combine) | ||
|
||
import XCTest | ||
import Combine | ||
|
||
//Borrowed from here: https://www.swiftbysundell.com/articles/unit-testing-combine-based-swift-code/ | ||
extension XCTestCase { | ||
|
||
func awaitPublisher<T: Publisher>( | ||
_ publisher: T, | ||
timeout: TimeInterval = 10, | ||
file: StaticString = #file, | ||
line: UInt = #line | ||
) throws -> T.Output { | ||
// This time, we use Swift's Result type to keep track | ||
// of the result of our Combine pipeline: | ||
var result: Result<T.Output, Error>? | ||
let expectation = self.expectation(description: "Awaiting publisher") | ||
|
||
let cancellable = publisher.sink( | ||
receiveCompletion: { completion in | ||
switch completion { | ||
case .failure(let error): | ||
result = .failure(error) | ||
case .finished: | ||
break | ||
} | ||
|
||
expectation.fulfill() | ||
}, | ||
receiveValue: { value in | ||
result = .success(value) | ||
} | ||
) | ||
|
||
// Just like before, we await the expectation that we | ||
// created at the top of our test, and once done, we | ||
// also cancel our cancellable to avoid getting any | ||
// unused variable warnings: | ||
waitForExpectations(timeout: timeout) | ||
cancellable.cancel() | ||
|
||
// Here we pass the original file and line number that | ||
// our utility was called at, to tell XCTest to report | ||
// any encountered errors at that original call site: | ||
let unwrappedResult = try XCTUnwrap( | ||
result, | ||
"Awaited publisher did not produce any output", | ||
file: file, | ||
line: line | ||
) | ||
|
||
return try unwrappedResult.get() | ||
} | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// | ||
// OpenAITestsCombine.swift | ||
// | ||
// | ||
// Created by Sergii Kryvoblotskyi on 04/04/2023. | ||
// | ||
|
||
#if canImport(Combine) | ||
|
||
import XCTest | ||
@testable import OpenAI | ||
|
||
@available(iOS 13.0, *) | ||
@available(watchOS 6.0, *) | ||
@available(tvOS 13.0, *) | ||
final class OpenAITestsCombine: XCTestCase { | ||
|
||
var openAI: OpenAIProtocol! | ||
var urlSession: URLSessionMock! | ||
|
||
override func setUp() { | ||
super.setUp() | ||
self.urlSession = URLSessionMock() | ||
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14) | ||
self.openAI = OpenAI(configuration: configuration, session: self.urlSession) | ||
} | ||
|
||
func testCompletions() throws { | ||
let query = CompletionsQuery(model: .textDavinci_003, prompt: "What is 42?", temperature: 0, maxTokens: 100, topP: 1, frequencyPenalty: 0, presencePenalty: 0, stop: ["\\n"]) | ||
let expectedResult = CompletionsResult(id: "foo", object: "bar", created: 100500, model: .babbage, choices: [ | ||
.init(text: "42 is the answer to everything", index: 0) | ||
]) | ||
try self.stub(result: expectedResult) | ||
|
||
let result = try awaitPublisher(self.openAI.completions(query: query)) | ||
XCTAssertEqual(result, expectedResult) | ||
} | ||
|
||
func testChats() throws { | ||
let query = ChatQuery(model: .gpt4, messages: [ | ||
.init(role: .system, content: "You are Librarian-GPT. You know everything about the books."), | ||
.init(role: .user, content: "Who wrote Harry Potter?") | ||
]) | ||
let chatResult = ChatResult(id: "id-12312", object: "foo", created: 100, model: .gpt3_5Turbo, choices: [ | ||
.init(index: 0, message: .init(role: "foo", content: "bar"), finishReason: "baz"), | ||
.init(index: 0, message: .init(role: "foo1", content: "bar1"), finishReason: "baz1"), | ||
.init(index: 0, message: .init(role: "foo2", content: "bar2"), finishReason: "baz2") | ||
], usage: .init(promptTokens: 100, completionTokens: 200, totalTokens: 300)) | ||
try self.stub(result: chatResult) | ||
let result = try awaitPublisher(openAI.chats(query: query)) | ||
XCTAssertEqual(result, chatResult) | ||
} | ||
|
||
func testEmbeddings() throws { | ||
let query = EmbeddingsQuery(model: .textSearchBabbadgeDoc, input: "The food was delicious and the waiter...") | ||
let embeddingsResult = EmbeddingsResult(data: [ | ||
.init(object: "id-sdasd", embedding: [0.1, 0.2, 0.3, 0.4], index: 0), | ||
.init(object: "id-sdasd1", embedding: [0.4, 0.1, 0.7, 0.1], index: 1), | ||
.init(object: "id-sdasd2", embedding: [0.8, 0.1, 0.2, 0.8], index: 2) | ||
]) | ||
try self.stub(result: embeddingsResult) | ||
|
||
let result = try awaitPublisher(openAI.embeddings(query: query)) | ||
XCTAssertEqual(result, embeddingsResult) | ||
} | ||
|
||
func testAudioTransriptions() throws { | ||
let data = Data() | ||
let query = AudioTranscriptionQuery(file: data, fileName: "audio.m4a", model: .whisper_1) | ||
let transcriptionResult = AudioTranscriptionResult(text: "Hello, world!") | ||
try self.stub(result: transcriptionResult) | ||
|
||
let result = try awaitPublisher(openAI.audioTranscriptions(query: query)) | ||
XCTAssertEqual(result, transcriptionResult) | ||
} | ||
|
||
func testAudioTranslations() throws { | ||
let data = Data() | ||
let query = AudioTranslationQuery(file: data, fileName: "audio.m4a", model: .whisper_1) | ||
let transcriptionResult = AudioTranslationResult(text: "Hello, world!") | ||
try self.stub(result: transcriptionResult) | ||
|
||
let result = try awaitPublisher(openAI.audioTranslations(query: query)) | ||
XCTAssertEqual(result, transcriptionResult) | ||
} | ||
} | ||
|
||
@available(iOS 13.0, *) | ||
extension OpenAITestsCombine { | ||
|
||
func stub(error: Error) { | ||
let error = APIError(message: "foo", type: "bar", param: "baz", code: "100") | ||
let task = DataTaskMock.failed(with: error) | ||
self.urlSession.dataTask = task | ||
} | ||
|
||
func stub(result: Codable) throws { | ||
let encoder = JSONEncoder() | ||
let data = try encoder.encode(result) | ||
let task = DataTaskMock.successful(with: data) | ||
self.urlSession.dataTask = task | ||
} | ||
} | ||
|
||
#endif |