Add vision (image_url) chat messages and an optional API port to OpenAIKit.

Reviewer notes:
- The legacy completions request keeps its "/v1/completions" path; this patch
  no longer touches CompletionProvider.swift or CreateCompletionRequest.swift.
- User messages decode from either a plain string or an array of content parts.
- Whitespace was reconstructed; hunk headers match the bodies below, while the
  `index` blob ids are the original ones and must be regenerated by git if
  exact ids matter.

diff --git a/Sources/OpenAIKit/API.swift b/Sources/OpenAIKit/API.swift
index 0d0385a..90373a8 100644
--- a/Sources/OpenAIKit/API.swift
+++ b/Sources/OpenAIKit/API.swift
@@ -3,15 +3,19 @@ import Foundation
 public struct API {
     public let scheme: Scheme
     public let host: String
+    /// Optional explicit port; `nil` leaves the port out of the generated URL.
+    public let port: Int?
     public let path: String?
 
     public init(
         scheme: API.Scheme,
         host: String,
+        port: Int? = nil,
         pathPrefix path: String? = nil
     ) {
         self.scheme = scheme
         self.host = host
+        self.port = port
         self.path = path
     }
 }
diff --git a/Sources/OpenAIKit/Chat/Chat.swift b/Sources/OpenAIKit/Chat/Chat.swift
index f0dde58..b181a92 100644
--- a/Sources/OpenAIKit/Chat/Chat.swift
+++ b/Sources/OpenAIKit/Chat/Chat.swift
@@ -20,6 +20,15 @@ extension Chat {
         public let message: Message
         public let finishReason: FinishReason?
     }
+
+    /// An image reference carried by an `image_url` content part.
+    public struct ImageUrl: Codable, Equatable {
+        public let url: String
+
+        public init(url: String) {
+            self.url = url
+        }
+    }
 }
 
 extension Chat.Choice: Codable {}
@@ -30,6 +39,19 @@ extension Chat {
         case user(content: String)
         case assistant(content: String)
     }
+
+    /// A chat message whose user turns are an array of content parts.
+    public enum MessageWithImage {
+        case system(content: String)
+        case user(content: [Content])
+        case assistant(content: String)
+    }
+
+    /// One part of a user message: plain text or an image URL.
+    public enum Content {
+        case text(String)
+        case imageUrl(ImageUrl)
+    }
 }
 
 extension Chat.Message: Codable {
@@ -87,3 +109,87 @@ extension Chat.Message {
         }
     }
 }
+
+extension Chat.MessageWithImage: Codable {
+    private enum CodingKeys: String, CodingKey {
+        case role
+        case content
+    }
+
+    public init(from decoder: Decoder) throws {
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+        let role = try container.decode(String.self, forKey: .role)
+        switch role {
+        case "system":
+            let content = try container.decode(String.self, forKey: .content)
+            self = .system(content: content)
+        case "user":
+            // The API allows user content as a bare string or as an array of
+            // parts; fold the string form into a single text part.
+            if let text = try? container.decode(String.self, forKey: .content) {
+                self = .user(content: [.text(text)])
+            } else {
+                let content = try container.decode([Chat.Content].self, forKey: .content)
+                self = .user(content: content)
+            }
+        case "assistant":
+            let content = try container.decode(String.self, forKey: .content)
+            self = .assistant(content: content)
+        default:
+            throw DecodingError.dataCorruptedError(forKey: .role, in: container, debugDescription: "Invalid role")
+        }
+    }
+
+    public func encode(to encoder: Encoder) throws {
+        var container = encoder.container(keyedBy: CodingKeys.self)
+        switch self {
+        case .system(content: let content):
+            try container.encode("system", forKey: .role)
+            try container.encode(content, forKey: .content)
+        case .assistant(content: let content):
+            try container.encode("assistant", forKey: .role)
+            try container.encode(content, forKey: .content)
+        case .user(content: let content):
+            try container.encode("user", forKey: .role)
+            try container.encode(content, forKey: .content)
+        }
+    }
+}
+
+extension Chat.Content: Codable {
+    private enum CodingKeys: String, CodingKey {
+        case type
+        case text
+        case imageUrl = "image_url"
+    }
+
+    public init(from decoder: Decoder) throws {
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+        let type = try container.decode(String.self, forKey: .type)
+        switch type {
+        case "text":
+            let text = try container.decode(String.self, forKey: .text)
+            self = .text(text)
+        case "image_url":
+            let imageUrl = try container.decode(Chat.ImageUrl.self, forKey: .imageUrl)
+            self = .imageUrl(imageUrl)
+        default:
+            throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Invalid type")
+        }
+    }
+
+    public func encode(to encoder: Encoder) throws {
+        var container = encoder.container(keyedBy: CodingKeys.self)
+        switch self {
+        case .text(let text):
+            try container.encode("text", forKey: .type)
+            try container.encode(text, forKey: .text)
+        case .imageUrl(let imageUrl):
+            try container.encode("image_url", forKey: .type)
+            try container.encode(imageUrl, forKey: .imageUrl)
+        }
+    }
+}
+
+// Synthesized: both payload types (`String`, `ImageUrl`) are Equatable.
+extension Chat.Content: Equatable {}
diff --git a/Sources/OpenAIKit/Chat/ChatProvider.swift b/Sources/OpenAIKit/Chat/ChatProvider.swift
index 5a6a236..7e65c70 100644
--- a/Sources/OpenAIKit/Chat/ChatProvider.swift
+++ b/Sources/OpenAIKit/Chat/ChatProvider.swift
@@ -44,7 +44,41 @@ public struct ChatProvider {
         )
 
         return try await requestHandler.perform(request: request)
-        
+    }
+
+    /// Creates a chat completion from messages whose user turns may contain image parts.
+    ///
+    /// Always sends a non-streaming request (`stream` is fixed to `false`).
+    /// - Parameter message: The conversation to send; an empty array leaves `messages` out of the request body.
+    public func createWithImage(
+        model: ModelID,
+        message: [Chat.MessageWithImage] = [],
+        temperature: Double = 1.0,
+        topP: Double = 1.0,
+        n: Int = 1,
+        stops: [String] = [],
+        maxTokens: Int? = nil,
+        presencePenalty: Double = 0.0,
+        frequencyPenalty: Double = 0.0,
+        logitBias: [String : Int] = [:],
+        user: String? = nil
+    ) async throws -> Chat {
+        let request = try CreateChatWithImageRequest(
+            model: model.id,
+            messages: message,
+            temperature: temperature,
+            topP: topP,
+            n: n,
+            stream: false,
+            stops: stops,
+            maxTokens: maxTokens,
+            presencePenalty: presencePenalty,
+            frequencyPenalty: frequencyPenalty,
+            logitBias: logitBias,
+            user: user
+        )
+
+        return try await requestHandler.perform(request: request)
     }
 
     /**
diff --git a/Sources/OpenAIKit/Chat/CreateChatWithImageRequest.swift b/Sources/OpenAIKit/Chat/CreateChatWithImageRequest.swift
new file mode 100644
index 0000000..06b5e6c
--- /dev/null
+++ b/Sources/OpenAIKit/Chat/CreateChatWithImageRequest.swift
@@ -0,0 +1,106 @@
+import AsyncHTTPClient
+import NIOHTTP1
+import Foundation
+
+/// Chat-completion request whose messages may carry image content parts.
+struct CreateChatWithImageRequest: Request {
+    let method: HTTPMethod = .POST
+    let path: String = "/v1/chat/completions"
+    let body: Data?
+
+    init(
+        model: String,
+        messages: [Chat.MessageWithImage],
+        temperature: Double,
+        topP: Double,
+        n: Int,
+        stream: Bool,
+        stops: [String],
+        maxTokens: Int?,
+        presencePenalty: Double,
+        frequencyPenalty: Double,
+        logitBias: [String: Int],
+        user: String?
+    ) throws {
+        let body = Body(
+            model: model,
+            messages: messages,
+            temperature: temperature,
+            topP: topP,
+            n: n,
+            stream: stream,
+            stops: stops,
+            maxTokens: maxTokens,
+            presencePenalty: presencePenalty,
+            frequencyPenalty: frequencyPenalty,
+            logitBias: logitBias,
+            user: user
+        )
+
+        self.body = try Self.encoder.encode(body)
+    }
+}
+
+extension CreateChatWithImageRequest {
+    struct Body: Encodable {
+        let model: String
+        let messages: [Chat.MessageWithImage]
+        let temperature: Double
+        let topP: Double
+        let n: Int
+        let stream: Bool
+        let stops: [String]
+        let maxTokens: Int?
+        let presencePenalty: Double
+        let frequencyPenalty: Double
+        let logitBias: [String: Int]
+        let user: String?
+
+        enum CodingKeys: CodingKey {
+            case model
+            case messages
+            case temperature
+            case topP
+            case n
+            case stream
+            case stop
+            case maxTokens
+            case presencePenalty
+            case frequencyPenalty
+            case logitBias
+            case user
+        }
+
+        /// Encodes the body, omitting empty collections and `nil` optionals instead of sending empty values.
+        func encode(to encoder: Encoder) throws {
+            var container = encoder.container(keyedBy: CodingKeys.self)
+            try container.encode(model, forKey: .model)
+
+            if !messages.isEmpty {
+                try container.encode(messages, forKey: .messages)
+            }
+
+            try container.encode(temperature, forKey: .temperature)
+            try container.encode(topP, forKey: .topP)
+            try container.encode(n, forKey: .n)
+            try container.encode(stream, forKey: .stream)
+
+            if !stops.isEmpty {
+                try container.encode(stops, forKey: .stop)
+            }
+
+            if let maxTokens {
+                try container.encode(maxTokens, forKey: .maxTokens)
+            }
+
+            try container.encode(presencePenalty, forKey: .presencePenalty)
+            try container.encode(frequencyPenalty, forKey: .frequencyPenalty)
+
+            if !logitBias.isEmpty {
+                try container.encode(logitBias, forKey: .logitBias)
+            }
+
+            try container.encodeIfPresent(user, forKey: .user)
+        }
+    }
+}
diff --git a/Sources/OpenAIKit/Model/Model.swift b/Sources/OpenAIKit/Model/Model.swift
index f7cc45c..25e5e34 100644
--- a/Sources/OpenAIKit/Model/Model.swift
+++ b/Sources/OpenAIKit/Model/Model.swift
@@ -40,6 +40,8 @@ extension Model {
         case gpt40314 = "gpt-4-0314"
         case gpt4_32k = "gpt-4-32k"
         case gpt4_32k0314 = "gpt-4-32k-0314"
+        case gpt4_Turbo = "gpt-4-1106-preview"
+        case gpt4_vision = "gpt-4-vision-preview"
     }
 
     public enum GPT3: String, ModelID {
diff --git a/Sources/OpenAIKit/RequestHandler/RequestHandler.swift b/Sources/OpenAIKit/RequestHandler/RequestHandler.swift
index 7f43859..9d5f8a2 100644
--- a/Sources/OpenAIKit/RequestHandler/RequestHandler.swift
+++ b/Sources/OpenAIKit/RequestHandler/RequestHandler.swift
@@ -15,7 +15,8 @@ extension RequestHandler {
         components.path = [configuration.api?.path, request.path]
             .compactMap { $0 }
             .joined()
-        
+        components.port = configuration.api?.port
+
         guard let url = components.url else {
             throw RequestHandlerError.invalidURLGenerated
         }
diff --git a/Tests/OpenAIKitTests/MessageTests.swift b/Tests/OpenAIKitTests/MessageTests.swift
index 230c018..51fcb22 100644
--- a/Tests/OpenAIKitTests/MessageTests.swift
+++ b/Tests/OpenAIKitTests/MessageTests.swift
@@ -70,6 +70,62 @@ final class MessageTests: XCTestCase {
             XCTFail("incorrect role")
         }
     }
+
+    func testDecodingProvidedExample() throws {
+        let json = Data("""
+        [
+            {
+                "role": "system",
+                "content": "You are Malcolm Tucker from The Thick of It, an unfriendly assistant for writing mail and explaining science and history. You write text in your voice for me."
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "What’s in this image?"
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                        }
+                    }
+                ]
+            }
+        ]
+        """.utf8)
+
+        let messages = try JSONDecoder().decode([Chat.MessageWithImage].self, from: json)
+
+        XCTAssertEqual(messages.count, 2)
+
+        if case .system(let content) = messages[0] {
+            XCTAssertEqual(content, "You are Malcolm Tucker from The Thick of It, an unfriendly assistant for writing mail and explaining science and history. You write text in your voice for me.")
+        } else {
+            XCTFail("First Message is not a System Message")
+        }
+
+        if case .user(let content) = messages[1] {
+            XCTAssertEqual(content.count, 2)
+            XCTAssertEqual(content[0], .text("What’s in this image?"))
+            XCTAssertEqual(content[1], .imageUrl(Chat.ImageUrl(url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")))
+        } else {
+            XCTFail("Second Message is not a User Message")
+        }
+    }
+
+    func testDecodingUserMessageWithStringContent() throws {
+        let json = Data(#"{"role": "user", "content": "tell me a joke"}"#.utf8)
+
+        let message = try JSONDecoder().decode(Chat.MessageWithImage.self, from: json)
+
+        guard case .user(let content) = message else {
+            return XCTFail("Message is not a User Message")
+        }
+        XCTAssertEqual(content, [.text("tell me a joke")])
+    }
+
 
     func testMessageRoundtrip() throws {
         let message = Chat.Message.system(content: "You are a helpful assistant that translates English to French.")
@@ -128,26 +184,4 @@ final class MessageTests: XCTestCase {
             XCTFail()
         }
     }
-
-    func testChatRequest() throws {
-        let request = try CreateChatRequest(
-            model: "gpt-3.5-turbo", //.gpt3_5Turbo,
-            messages: [
-                .system(content: "You are Malcolm Tucker from The Thick of It, an unfriendly assistant for writing mail and explaining science and history. You write text in your voice for me."),
-                .user(content: "tell me a joke"),
-            ],
-            temperature: 1.0,
-            topP: 1.0,
-            n: 1,
-            stream: false,
-            stops: [],
-            maxTokens: nil,
-            presencePenalty: 0.0,
-            frequencyPenalty: 0.0,
-            logitBias: [:],
-            user: nil
-        )
-
-        print(request.body)
-    }
 }
diff --git a/Tests/OpenAIKitTests/RequestHandlerTests.swift b/Tests/OpenAIKitTests/RequestHandlerTests.swift
index a4803e5..a4dbbd0 100644
--- a/Tests/OpenAIKitTests/RequestHandlerTests.swift
+++ b/Tests/OpenAIKitTests/RequestHandlerTests.swift
@@ -52,6 +52,21 @@ final class RequestHandlerTests: XCTestCase {
         XCTAssertEqual(url, "http://chat.openai.com/v1/test")
     }
 
+    func test_generateURL_configWithPort() throws {
+        let api = API(scheme: .http, host: "test.com", port: 8200)
+        let configuration = Configuration(apiKey: "TEST", api: api)
+
+        let request = TestRequest(
+            scheme: .http,
+            host: "some-host",
+            path: "/v1/test"
+        )
+
+        let url = try requestHandler(configuration: configuration).generateURL(for: request)
+
+        XCTAssertEqual(url, "http://test.com:8200/v1/test")
+    }
+
 }
 
 private struct TestRequest: Request {