diff --git a/common/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/common/Request.kt b/common/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/common/Request.kt
index 53141321..084b428c 100644
--- a/common/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/common/Request.kt
+++ b/common/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/common/Request.kt
@@ -20,23 +20,46 @@ import dev.shreyaspatil.ai.client.generativeai.common.client.Tool
 import dev.shreyaspatil.ai.client.generativeai.common.client.ToolConfig
 import dev.shreyaspatil.ai.client.generativeai.common.shared.Content
 import dev.shreyaspatil.ai.client.generativeai.common.shared.SafetySetting
+import dev.shreyaspatil.ai.client.generativeai.common.util.fullModelName
 import kotlinx.serialization.SerialName
 import kotlinx.serialization.Serializable
-import kotlinx.serialization.Transient
 
 sealed interface Request
 
 @Serializable
 data class GenerateContentRequest(
-  @Transient val model: String? = null,
-  val contents: List<Content>,
-  @SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
-  @SerialName("generation_config") val generationConfig: GenerationConfig? = null,
-  val tools: List<Tool>? = null,
-  @SerialName("tool_config") var toolConfig: ToolConfig? = null,
-  @SerialName("system_instruction") val systemInstruction: Content? = null,
+  val model: String? = null,
+  val contents: List<Content>,
+  @SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
+  @SerialName("generation_config") val generationConfig: GenerationConfig? = null,
+  val tools: List<Tool>? = null,
+  @SerialName("tool_config") var toolConfig: ToolConfig? = null,
+  @SerialName("system_instruction") val systemInstruction: Content? = null,
 ) : Request
 
 @Serializable
-data class CountTokensRequest(@Transient val model: String? = null, val contents: List<Content>) :
-  Request
+data class CountTokensRequest(
+  val generateContentRequest: GenerateContentRequest? = null,
+  val model: String? = null,
+  val contents: List<Content>? = null,
+  val tools: List<Tool>? = null,
+  @SerialName("system_instruction") val systemInstruction: Content? = null,
+) : Request {
+  companion object {
+    fun forGenAI(generateContentRequest: GenerateContentRequest) =
+      CountTokensRequest(
+        generateContentRequest =
+          generateContentRequest.model?.let {
+            generateContentRequest.copy(model = fullModelName(it))
+          } ?: generateContentRequest
+      )
+
+    fun forVertexAI(generateContentRequest: GenerateContentRequest) =
+      CountTokensRequest(
+        model = generateContentRequest.model?.let { fullModelName(it) },
+        contents = generateContentRequest.contents,
+        tools = generateContentRequest.tools,
+        systemInstruction = generateContentRequest.systemInstruction,
+      )
+  }
+}
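For context, the two factories give the token-counting request a different shape per backend: `forGenAI` nests the entire `GenerateContentRequest` (normalizing its model name), while `forVertexAI` lifts the individual fields to the top level. A minimal sketch of the difference, assuming the `Content`/`TextPart` constructors used by the tests elsewhere in this diff:

```kotlin
import dev.shreyaspatil.ai.client.generativeai.common.CountTokensRequest
import dev.shreyaspatil.ai.client.generativeai.common.GenerateContentRequest
import dev.shreyaspatil.ai.client.generativeai.common.shared.Content
import dev.shreyaspatil.ai.client.generativeai.common.shared.TextPart

fun main() {
  val request =
    GenerateContentRequest(
      model = "gemini-pro",
      contents = listOf(Content(parts = listOf(TextPart("Hello")))),
    )

  // Google AI: the whole GenerateContentRequest is nested, with its model
  // name normalized before serialization.
  val genAiRequest = CountTokensRequest.forGenAI(request)
  println(genAiRequest.generateContentRequest?.model) // models/gemini-pro

  // Vertex AI: model/contents/tools/systemInstruction are lifted to the top level.
  val vertexRequest = CountTokensRequest.forVertexAI(request)
  println(vertexRequest.model)    // models/gemini-pro
  println(vertexRequest.contents) // the original contents list
}
```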
diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt
index ebac13fd..b23faa76 100644
--- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt
+++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt
@@ -19,6 +19,7 @@ package com.google.ai.client.generativeai.common
 import android.util.Log
 import com.google.ai.client.generativeai.common.server.FinishReason
 import com.google.ai.client.generativeai.common.util.decodeToFlow
+import com.google.ai.client.generativeai.common.util.fullModelName
 import io.ktor.client.HttpClient
 import io.ktor.client.call.body
 import io.ktor.client.engine.HttpClientEngine
@@ -213,13 +214,6 @@ interface HeaderProvider {
   suspend fun generateHeaders(): Map<String, String>
 }
 
-/**
- * Ensures the model name provided has a `models/` prefix
- *
- * Models must be prepended with the `models/` prefix when communicating with the backend.
- */
-private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"
-
 private suspend fun validateResponse(response: HttpResponse) {
   if (response.status == HttpStatusCode.OK) return
   val text = response.bodyAsText()
diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/util.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/util.kt
new file mode 100644
index 00000000..56a8a97d
--- /dev/null
+++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/util.kt
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.client.generativeai.common.util
+
+/**
+ * Ensures the model name provided has a `models/` prefix
+ *
+ * Models must be prepended with the `models/` prefix when communicating with the backend.
+ */
+fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"
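The helper is now shared instead of being private to `APIController`. A quick illustration of its behavior, following directly from the one-liner above:

```kotlin
import com.google.ai.client.generativeai.common.util.fullModelName

fun main() {
  // Bare model names get the "models/" prefix.
  println(fullModelName("gemini-pro")) // models/gemini-pro

  // Anything already containing a slash passes through unchanged,
  // e.g. already-prefixed names or tuned-model resource names.
  println(fullModelName("models/gemini-pro"))    // models/gemini-pro
  println(fullModelName("tunedModels/my-model")) // tunedModels/my-model
}
```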
diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt
index 31171ece..776e0d0c 100644
--- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt
+++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt
@@ -25,7 +25,6 @@ import com.google.ai.client.generativeai.common.util.createResponses
 import com.google.ai.client.generativeai.common.util.doBlocking
 import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
 import io.kotest.assertions.json.shouldContainJsonKey
-import io.kotest.assertions.json.shouldNotContainJsonKey
 import io.kotest.assertions.throwables.shouldThrow
 import io.kotest.matchers.shouldBe
 import io.kotest.matchers.string.shouldContain
@@ -135,60 +134,6 @@ internal class RequestFormatTests {
     mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
   }
 
-  @Test
-  fun `generateContentRequest doesn't include the model name`() = doBlocking {
-    val channel = ByteChannel(autoFlush = true)
-    val mockEngine = MockEngine {
-      respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
-    }
-    prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
-    val controller =
-      APIController(
-        "super_cool_test_key",
-        "gemini-pro-1.0",
-        RequestOptions(),
-        mockEngine,
-        TEST_CLIENT_ID,
-        null,
-      )
-
-    withTimeout(5.seconds) {
-      controller.generateContentStream(textGenerateContentRequest("cats")).collect {
-        it.candidates?.isEmpty() shouldBe false
-        channel.close()
-      }
-    }
-
-    val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
-    requestBodyAsText shouldContainJsonKey "contents"
-    requestBodyAsText shouldNotContainJsonKey "model"
-  }
-
-  @Test
-  fun `countTokenRequest doesn't include the model name`() = doBlocking {
-    val response =
-      JSON.encodeToString(CountTokensResponse(totalTokens = 10, totalBillableCharacters = 10))
-    val mockEngine = MockEngine {
-      respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
-    }
-
-    val controller =
-      APIController(
-        "super_cool_test_key",
-        "gemini-pro-1.0",
-        RequestOptions(),
-        mockEngine,
-        TEST_CLIENT_ID,
-        null,
-      )
-
-    withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
-
-    val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
-    requestBodyAsText shouldContainJsonKey "contents"
-    requestBodyAsText shouldNotContainJsonKey "model"
-  }
-
   @Test
   fun `client id header is set correctly in the request`() = doBlocking {
     val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
@@ -367,4 +312,4 @@ fun textGenerateContentRequest(prompt: String) =
   )
 
 fun textCountTokenRequest(prompt: String) =
-  CountTokensRequest(model = "unused", contents = listOf(Content(parts = listOf(TextPart(prompt)))))
+  CountTokensRequest(generateContentRequest = textGenerateContentRequest(prompt))
diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt
new file mode 100644
index 00000000..6ea9df87
--- /dev/null
+++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt
@@ -0,0 +1,256 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.client.generativeai
+
+import android.graphics.Bitmap
+import com.google.ai.client.generativeai.common.APIController
+import com.google.ai.client.generativeai.common.CountTokensRequest
+import com.google.ai.client.generativeai.common.GenerateContentRequest
+import com.google.ai.client.generativeai.common.util.fullModelName
+import com.google.ai.client.generativeai.internal.util.toInternal
+import com.google.ai.client.generativeai.internal.util.toPublic
+import com.google.ai.client.generativeai.type.Content
+import com.google.ai.client.generativeai.type.CountTokensResponse
+import com.google.ai.client.generativeai.type.FinishReason
+import com.google.ai.client.generativeai.type.FourParameterFunction
+import com.google.ai.client.generativeai.type.FunctionCallPart
+import com.google.ai.client.generativeai.type.GenerateContentResponse
+import com.google.ai.client.generativeai.type.GenerationConfig
+import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
+import com.google.ai.client.generativeai.type.InvalidStateException
+import com.google.ai.client.generativeai.type.NoParameterFunction
+import com.google.ai.client.generativeai.type.OneParameterFunction
+import com.google.ai.client.generativeai.type.PromptBlockedException
+import com.google.ai.client.generativeai.type.RequestOptions
+import com.google.ai.client.generativeai.type.ResponseStoppedException
+import com.google.ai.client.generativeai.type.SafetySetting
+import com.google.ai.client.generativeai.type.SerializationException
+import com.google.ai.client.generativeai.type.ThreeParameterFunction
+import com.google.ai.client.generativeai.type.Tool
+import com.google.ai.client.generativeai.type.ToolConfig
+import com.google.ai.client.generativeai.type.TwoParameterFunction
+import com.google.ai.client.generativeai.type.content
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.catch
+import kotlinx.coroutines.flow.map
+import kotlinx.serialization.ExperimentalSerializationApi
+import org.json.JSONObject
+
+/**
+ * A facilitator for a given multimodal model (e.g. Gemini).
+ *
+ * @property modelName name of the model in the backend
+ * @property apiKey authentication key for interacting with the backend
+ * @property generationConfig configuration parameters to use for content generation
+ * @property safetySettings the safety bounds to use alongside prompts during content generation
+ * @property systemInstruction contains a [Content] that directs the model to behave a certain way
+ * @property requestOptions configuration options to utilize during backend communication
+ */
+@OptIn(ExperimentalSerializationApi::class)
+class GenerativeModel
+internal constructor(
+  val modelName: String,
+  val apiKey: String,
+  val generationConfig: GenerationConfig? = null,
+  val safetySettings: List<SafetySetting>? = null,
+  val tools: List<Tool>? = null,
+  val toolConfig: ToolConfig? = null,
+  val systemInstruction: Content? = null,
+  val requestOptions: RequestOptions = RequestOptions(),
+  private val controller: APIController,
+) {
+
+  @JvmOverloads
+  constructor(
+    modelName: String,
+    apiKey: String,
+    generationConfig: GenerationConfig? = null,
+    safetySettings: List<SafetySetting>? = null,
+    requestOptions: RequestOptions = RequestOptions(),
+    tools: List<Tool>? = null,
+    toolConfig: ToolConfig? = null,
+    systemInstruction: Content? = null,
+  ) : this(
+    fullModelName(modelName),
+    apiKey,
+    generationConfig,
+    safetySettings,
+    tools,
+    toolConfig,
+    systemInstruction?.let { Content("system", it.parts) },
+    requestOptions,
+    APIController(
+      apiKey,
+      modelName,
+      requestOptions.toInternal(),
+      "genai-android/${BuildConfig.VERSION_NAME}",
+    ),
+  )
+
+  /**
+   * Generates a response from the backend with the provided [Content]s.
+   *
+   * @param prompt A group of [Content]s to send to the model.
+   * @return A [GenerateContentResponse] after some delay. Function should be called within a
+   *   suspend context to properly manage concurrency.
+   */
+  suspend fun generateContent(vararg prompt: Content): GenerateContentResponse =
+    try {
+      controller.generateContent(constructRequest(*prompt)).toPublic().validate()
+    } catch (e: Throwable) {
+      throw GoogleGenerativeAIException.from(e)
+    }
+
+  /**
+   * Generates a streaming response from the backend with the provided [Content]s.
+   *
+   * @param prompt A group of [Content]s to send to the model.
+   * @return A [Flow] which will emit responses as they are returned from the model.
+   */
+  fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
+    controller
+      .generateContentStream(constructRequest(*prompt))
+      .catch { throw GoogleGenerativeAIException.from(it) }
+      .map { it.toPublic().validate() }
+
+  /**
+   * Generates a response from the backend with the provided text represented [Content].
+   *
+   * @param prompt The text to be converted into a single piece of [Content] to send to the model.
+   * @return A [GenerateContentResponse] after some delay. Function should be called within a
+   *   suspend context to properly manage concurrency.
+   */
+  suspend fun generateContent(prompt: String): GenerateContentResponse =
+    generateContent(content { text(prompt) })
+
+  /**
+   * Generates a streaming response from the backend with the provided text represented [Content].
+   *
+   * @param prompt The text to be converted into a single piece of [Content] to send to the model.
+   * @return A [Flow] which will emit responses as they are returned from the model.
+   */
+  fun generateContentStream(prompt: String): Flow<GenerateContentResponse> =
+    generateContentStream(content { text(prompt) })
+
+  /**
+   * Generates a response from the backend with the provided bitmap represented [Content].
+   *
+   * @param prompt The bitmap to be converted into a single piece of [Content] to send to the model.
+   * @return A [GenerateContentResponse] after some delay. Function should be called within a
+   *   suspend context to properly manage concurrency.
+   */
+  suspend fun generateContent(prompt: Bitmap): GenerateContentResponse =
+    generateContent(content { image(prompt) })
+
+  /**
+   * Generates a streaming response from the backend with the provided bitmap represented [Content].
+   *
+   * @param prompt The bitmap to be converted into a single piece of [Content] to send to the model.
+   * @return A [Flow] which will emit responses as they are returned from the model.
+   */
+  fun generateContentStream(prompt: Bitmap): Flow<GenerateContentResponse> =
+    generateContentStream(content { image(prompt) })
+
+  /** Creates a chat instance which internally tracks the ongoing conversation with the model */
+  fun startChat(history: List<Content> = emptyList()): Chat = Chat(this, history.toMutableList())
+
+  /**
+   * Counts the number of tokens used in a prompt.
+   *
+   * @param prompt A group of [Content]s to count tokens of.
+   * @return A [CountTokensResponse] containing the number of tokens in the prompt.
+   */
+  suspend fun countTokens(vararg prompt: Content): CountTokensResponse {
+    return controller.countTokens(constructCountTokensRequest(*prompt)).toPublic()
+  }
+
+  /**
+   * Counts the number of tokens used in a prompt.
+   *
+   * @param prompt The text to be converted to a single piece of [Content] to count the tokens of.
+   * @return A [CountTokensResponse] containing the number of tokens in the prompt.
+   */
+  suspend fun countTokens(prompt: String): CountTokensResponse {
+    return countTokens(content { text(prompt) })
+  }
+
+  /**
+   * Counts the number of tokens used in a prompt.
+   *
+   * @param prompt The image to be converted to a single piece of [Content] to count the tokens of.
+   * @return A [CountTokensResponse] containing the number of tokens in the prompt.
+   */
+  suspend fun countTokens(prompt: Bitmap): CountTokensResponse {
+    return countTokens(content { image(prompt) })
+  }
+
+  /**
+   * Executes a function requested by the model.
+   *
+   * @param functionCallPart A [FunctionCallPart] from the model, containing a function call and
+   *   parameters
+   * @return The output of the requested function call
+   */
+  suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject {
+    if (tools == null) {
+      throw InvalidStateException("No registered tools")
+    }
+    val callable =
+      tools.flatMap { it.functionDeclarations }.firstOrNull { it.name == functionCallPart.name }
+        ?: throw InvalidStateException("No registered function named ${functionCallPart.name}")
+    return when (callable) {
+      is NoParameterFunction -> callable.execute()
+      is OneParameterFunction<*> ->
+        (callable as OneParameterFunction<Any?>).execute(functionCallPart)
+      is TwoParameterFunction<*, *> ->
+        (callable as TwoParameterFunction<Any?, Any?>).execute(functionCallPart)
+      is ThreeParameterFunction<*, *, *> ->
+        (callable as ThreeParameterFunction<Any?, Any?, Any?>).execute(functionCallPart)
+      is FourParameterFunction<*, *, *, *> ->
+        (callable as FourParameterFunction<Any?, Any?, Any?, Any?>).execute(functionCallPart)
+      else -> {
+        throw RuntimeException("UNREACHABLE")
+      }
+    }
+  }
+
+  private fun constructRequest(vararg prompt: Content) =
+    GenerateContentRequest(
+      modelName,
+      prompt.map { it.toInternal() },
+      safetySettings?.map { it.toInternal() },
+      generationConfig?.toInternal(),
+      tools?.map { it.toInternal() },
+      toolConfig?.toInternal(),
+      systemInstruction?.toInternal(),
+    )
+
+  private fun constructCountTokensRequest(vararg prompt: Content) =
+    CountTokensRequest.forGenAI(constructRequest(*prompt))
+
+  private fun GenerateContentResponse.validate() = apply {
+    if (candidates.isEmpty() && promptFeedback == null) {
+      throw SerializationException("Error deserializing response, found no valid fields")
+    }
+    promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
+    candidates
+      .mapNotNull { it.finishReason }
+      .firstOrNull { it != FinishReason.STOP }
+      ?.let { throw ResponseStoppedException(this) }
+  }
+}
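End to end, `countTokens` on the public model now wraps the same request that `generateContent` would send, via `CountTokensRequest.forGenAI`. A hypothetical usage sketch (`countPromptTokens`, the model name, and the API key are placeholders):

```kotlin
import com.google.ai.client.generativeai.GenerativeModel

suspend fun countPromptTokens() {
  // The secondary constructor normalizes "gemini-pro" to "models/gemini-pro"
  // via fullModelName before the name reaches the backend.
  val model =
    GenerativeModel(
      modelName = "gemini-pro",
      apiKey = "YOUR_API_KEY", // placeholder
    )

  // Internally: constructRequest(...) -> CountTokensRequest.forGenAI(...),
  // so the serialized body nests the full GenerateContentRequest, model included.
  val response = model.countTokens("Why is the sky blue?")
  println(response.totalTokens)
}
```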