Skip to content

Commit

Permalink
Merge pull request #34 from MacPaw/combine
Browse files Browse the repository at this point in the history
Add combine extensions
  • Loading branch information
Krivoblotsky authored Apr 4, 2023
2 parents 5f4723d + cd52799 commit e3f1a99
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ final public class OpenAI: OpenAIProtocol {
performRequest(request: JSONRequest<ChatResult>(body: query, url: buildURL(path: .chats)), completion: completion)
}

public func audioTransciptions(query: AudioTranscriptionQuery, completion: @escaping (Result<AudioTranscriptionResult, Error>) -> Void) {
public func audioTranscriptions(query: AudioTranscriptionQuery, completion: @escaping (Result<AudioTranscriptionResult, Error>) -> Void) {
performRequest(request: MultipartFormDataRequest<AudioTranscriptionResult>(body: query, url: buildURL(path: .audioTranscriptions)), completion: completion)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// OpenAIProtocol+Extensions.swift
// OpenAIProtocol+Async.swift
//
//
// Created by Maxime Maheo on 10/02/2023.
Expand Down Expand Up @@ -72,11 +72,11 @@ public extension OpenAIProtocol {
}
}

func audioTransciptions(
func audioTranscriptions(
query: AudioTranscriptionQuery
) async throws -> AudioTranscriptionResult {
try await withCheckedThrowingContinuation { continuation in
audioTransciptions(query: query) { result in
audioTranscriptions(query: query) { result in
switch result {
case let .success(success):
return continuation.resume(returning: success)
Expand Down
61 changes: 61 additions & 0 deletions Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//
// File.swift
//
//
// Created by Sergii Kryvoblotskyi on 03/04/2023.
//

#if canImport(Combine)

import Combine

@available(iOS 13.0, *)
@available(tvOS 13.0, *)
@available(macOS 10.15, *)
@available(watchOS 6.0, *)
public extension OpenAIProtocol {

func completions(query: CompletionsQuery) -> AnyPublisher<CompletionsResult, Error> {
Future<CompletionsResult, Error> {
completions(query: query, completion: $0)
}
.eraseToAnyPublisher()
}

func images(query: ImagesQuery) -> AnyPublisher<ImagesResult, Error> {
Future<ImagesResult, Error> {
images(query: query, completion: $0)
}
.eraseToAnyPublisher()
}

func embeddings(query: EmbeddingsQuery) -> AnyPublisher<EmbeddingsResult, Error> {
Future<EmbeddingsResult, Error> {
embeddings(query: query, completion: $0)
}
.eraseToAnyPublisher()
}

func chats(query: ChatQuery) -> AnyPublisher<ChatResult, Error> {
Future<ChatResult, Error> {
chats(query: query, completion: $0)
}
.eraseToAnyPublisher()
}

func audioTranscriptions(query: AudioTranscriptionQuery) -> AnyPublisher<AudioTranscriptionResult, Error> {
Future<AudioTranscriptionResult, Error> {
audioTranscriptions(query: query, completion: $0)
}
.eraseToAnyPublisher()
}

func audioTranslations(query: AudioTranslationQuery) -> AnyPublisher<AudioTranslationResult, Error> {
Future<AudioTranslationResult, Error> {
audioTranslations(query: query, completion: $0)
}
.eraseToAnyPublisher()
}
}

#endif
2 changes: 1 addition & 1 deletion Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public protocol OpenAIProtocol {
- Parameter completion: The completion handler to be executed upon completion of the transcription request.
Returns a `Result` of type `AudioTranscriptionResult` if successful, or an `Error` if an error occurs.
**/
func audioTransciptions(query: AudioTranscriptionQuery, completion: @escaping (Result<AudioTranscriptionResult, Error>) -> Void)
func audioTranscriptions(query: AudioTranscriptionQuery, completion: @escaping (Result<AudioTranscriptionResult, Error>) -> Void)

/**
Translates audio data using OpenAI's audio translation API and completes the operation asynchronously.
Expand Down
64 changes: 64 additions & 0 deletions Tests/OpenAITests/Extensions/XCTestCase+Extensions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//
// XCTestCase+Extensions.swift
//
//
// Created by Sergii Kryvoblotskyi on 04/04/2023.
//

#if canImport(Combine)

import XCTest
import Combine

//Borrowed from here: https://www.swiftbysundell.com/articles/unit-testing-combine-based-swift-code/
extension XCTestCase {

func awaitPublisher<T: Publisher>(
_ publisher: T,
timeout: TimeInterval = 10,
file: StaticString = #file,
line: UInt = #line
) throws -> T.Output {
// This time, we use Swift's Result type to keep track
// of the result of our Combine pipeline:
var result: Result<T.Output, Error>?
let expectation = self.expectation(description: "Awaiting publisher")

let cancellable = publisher.sink(
receiveCompletion: { completion in
switch completion {
case .failure(let error):
result = .failure(error)
case .finished:
break
}

expectation.fulfill()
},
receiveValue: { value in
result = .success(value)
}
)

// Just like before, we await the expectation that we
// created at the top of our test, and once done, we
// also cancel our cancellable to avoid getting any
// unused variable warnings:
waitForExpectations(timeout: timeout)
cancellable.cancel()

// Here we pass the original file and line number that
// our utility was called at, to tell XCTest to report
// any encountered errors at that original call site:
let unwrappedResult = try XCTUnwrap(
result,
"Awaited publisher did not produce any output",
file: file,
line: line
)

return try unwrappedResult.get()
}
}

#endif
8 changes: 5 additions & 3 deletions Tests/OpenAITests/OpenAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import XCTest
@available(iOS 13.0, *)
@available(watchOS 6.0, *)
@available(tvOS 13.0, *)
final class OpenAITests: XCTestCase {
class OpenAITests: XCTestCase {

var openAI: OpenAIProtocol!
var urlSession: URLSessionMock!
Expand Down Expand Up @@ -111,7 +111,7 @@ final class OpenAITests: XCTestCase {
let transcriptionResult = AudioTranscriptionResult(text: "Hello, world!")
try self.stub(result: transcriptionResult)

let result = try await openAI.audioTransciptions(query: query)
let result = try await openAI.audioTranscriptions(query: query)
XCTAssertEqual(result, transcriptionResult)
}

Expand All @@ -121,7 +121,7 @@ final class OpenAITests: XCTestCase {
let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100")
self.stub(error: inError)

let apiError: APIError = try await XCTExpectError { try await openAI.audioTransciptions(query: query) }
let apiError: APIError = try await XCTExpectError { try await openAI.audioTranscriptions(query: query) }
XCTAssertEqual(inError, apiError)
}

Expand Down Expand Up @@ -197,6 +197,7 @@ final class OpenAITests: XCTestCase {
}
}

@available(iOS 13.0, *)
extension OpenAITests {

func stub(error: Error) {
Expand All @@ -213,6 +214,7 @@ extension OpenAITests {
}
}

@available(iOS 13.0, *)
extension OpenAITests {

enum TypeError: Error {
Expand Down
105 changes: 105 additions & 0 deletions Tests/OpenAITests/OpenAITestsCombine.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//
// OpenAITestsCombine.swift
//
//
// Created by Sergii Kryvoblotskyi on 04/04/2023.
//

#if canImport(Combine)

import XCTest
@testable import OpenAI

@available(iOS 13.0, *)
@available(watchOS 6.0, *)
@available(tvOS 13.0, *)
final class OpenAITestsCombine: XCTestCase {

var openAI: OpenAIProtocol!
var urlSession: URLSessionMock!

override func setUp() {
super.setUp()
self.urlSession = URLSessionMock()
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", timeoutInterval: 14)
self.openAI = OpenAI(configuration: configuration, session: self.urlSession)
}

func testCompletions() throws {
let query = CompletionsQuery(model: .textDavinci_003, prompt: "What is 42?", temperature: 0, maxTokens: 100, topP: 1, frequencyPenalty: 0, presencePenalty: 0, stop: ["\\n"])
let expectedResult = CompletionsResult(id: "foo", object: "bar", created: 100500, model: .babbage, choices: [
.init(text: "42 is the answer to everything", index: 0)
])
try self.stub(result: expectedResult)

let result = try awaitPublisher(self.openAI.completions(query: query))
XCTAssertEqual(result, expectedResult)
}

func testChats() throws {
let query = ChatQuery(model: .gpt4, messages: [
.init(role: .system, content: "You are Librarian-GPT. You know everything about the books."),
.init(role: .user, content: "Who wrote Harry Potter?")
])
let chatResult = ChatResult(id: "id-12312", object: "foo", created: 100, model: .gpt3_5Turbo, choices: [
.init(index: 0, message: .init(role: "foo", content: "bar"), finishReason: "baz"),
.init(index: 0, message: .init(role: "foo1", content: "bar1"), finishReason: "baz1"),
.init(index: 0, message: .init(role: "foo2", content: "bar2"), finishReason: "baz2")
], usage: .init(promptTokens: 100, completionTokens: 200, totalTokens: 300))
try self.stub(result: chatResult)
let result = try awaitPublisher(openAI.chats(query: query))
XCTAssertEqual(result, chatResult)
}

func testEmbeddings() throws {
let query = EmbeddingsQuery(model: .textSearchBabbadgeDoc, input: "The food was delicious and the waiter...")
let embeddingsResult = EmbeddingsResult(data: [
.init(object: "id-sdasd", embedding: [0.1, 0.2, 0.3, 0.4], index: 0),
.init(object: "id-sdasd1", embedding: [0.4, 0.1, 0.7, 0.1], index: 1),
.init(object: "id-sdasd2", embedding: [0.8, 0.1, 0.2, 0.8], index: 2)
])
try self.stub(result: embeddingsResult)

let result = try awaitPublisher(openAI.embeddings(query: query))
XCTAssertEqual(result, embeddingsResult)
}

func testAudioTransriptions() throws {
let data = Data()
let query = AudioTranscriptionQuery(file: data, fileName: "audio.m4a", model: .whisper_1)
let transcriptionResult = AudioTranscriptionResult(text: "Hello, world!")
try self.stub(result: transcriptionResult)

let result = try awaitPublisher(openAI.audioTranscriptions(query: query))
XCTAssertEqual(result, transcriptionResult)
}

func testAudioTranslations() throws {
let data = Data()
let query = AudioTranslationQuery(file: data, fileName: "audio.m4a", model: .whisper_1)
let transcriptionResult = AudioTranslationResult(text: "Hello, world!")
try self.stub(result: transcriptionResult)

let result = try awaitPublisher(openAI.audioTranslations(query: query))
XCTAssertEqual(result, transcriptionResult)
}
}

@available(iOS 13.0, *)
extension OpenAITestsCombine {

func stub(error: Error) {
let error = APIError(message: "foo", type: "bar", param: "baz", code: "100")
let task = DataTaskMock.failed(with: error)
self.urlSession.dataTask = task
}

func stub(result: Codable) throws {
let encoder = JSONEncoder()
let data = try encoder.encode(result)
let task = DataTaskMock.successful(with: data)
self.urlSession.dataTask = task
}
}

#endif

0 comments on commit e3f1a99

Please sign in to comment.