Skip to content

Commit

Permalink
Intial implementation of automated function calling
Browse files Browse the repository at this point in the history
includes a builder to make it easier to add functions
DO NOT MERGE
  • Loading branch information
David Motsonashvili committed Jan 30, 2024
1 parent e70a134 commit f95f85d
Show file tree
Hide file tree
Showing 11 changed files with 453 additions and 35 deletions.
118 changes: 86 additions & 32 deletions generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package com.google.ai.client.generativeai
import android.graphics.Bitmap
import com.google.ai.client.generativeai.type.BlobPart
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.InvalidStateException
Expand All @@ -27,8 +29,10 @@ import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.transform

/**
* Representation of a back and forth interaction with a model.
Expand All @@ -53,13 +57,25 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()
suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse {
inputPrompt.assertComesFromUser()
attemptLock()
var response: GenerateContentResponse
var prompt = inputPrompt
try {
val response = model.generateContent(*history.toTypedArray(), prompt)
history.add(prompt)
history.add(response.candidates.first().content)
while (true) {
response = model.generateContent(*history.toTypedArray(), prompt)
val responsePart = response.candidates.first().content.parts.first()

history.add(prompt)
history.add(response.candidates.first().content)
if (responsePart is FunctionCallPart) {
val output = model.executeFunction(responsePart)
prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output)))
} else {
break
}
}
return response
} finally {
lock.release()
Expand Down Expand Up @@ -101,43 +117,21 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
attemptLock()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val bitmaps = LinkedList<Bitmap>()
val blobs = LinkedList<BlobPart>()
val text = StringBuilder()

val tempHistory = LinkedList<Content>()
tempHistory.add(prompt)
/**
* TODO: revisit when images and blobs are returned. This will cause issues with how things are
* structured in the response. eg; a text/image/text response will be (incorrectly)
* represented as image/text
*/
return flow
.onEach {
for (part in it.candidates.first().content.parts) {
when (part) {
is TextPart -> text.append(part.text)
is ImagePart -> bitmaps.add(part.image)
is BlobPart -> blobs.add(part)
}
}
.transform { response ->
automaticFunctionExecutingTransform(this, tempHistory, response)
}
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
for (bitmap in bitmaps) {
image(bitmap)
}
for (blob in blobs) {
blob(blob.mimeType, blob.blob)
}
if (text.isNotBlank()) {
text(text.toString())
}
}

history.add(prompt)
history.add(content)
history.addAll(tempHistory)
}
}
}
Expand Down Expand Up @@ -172,6 +166,66 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}
}

private suspend fun automaticFunctionExecutingTransform(
transformer: FlowCollector<GenerateContentResponse>,
tempHistory: LinkedList<Content>,
response: GenerateContentResponse
) {
for (part in response.candidates.first().content.parts) {
when (part) {
is TextPart -> {
transformer.emit(response)
addTextToHistory(tempHistory, part)
}
is ImagePart -> {
transformer.emit(response)
tempHistory.add(Content("model", listOf(part)))
}
is BlobPart -> {
transformer.emit(response)
tempHistory.add(Content("model", listOf(part)))
}
is FunctionCallPart -> {
val functionCall =
response.candidates.first().content.parts.first { it is FunctionCallPart }
as FunctionCallPart
val output = model.executeFunction(functionCall)
val functionResponse =
Content("function", listOf(FunctionResponsePart(functionCall.name, output)))
tempHistory.add(response.candidates.first().content)
tempHistory.add(functionResponse)
model
.generateContentStream(*history.toTypedArray(), *tempHistory.toTypedArray())
.collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) }
}
}
}
}

private fun addTextToHistory(tempHistory: LinkedList<Content>, textPart: TextPart) {
val lastContent = tempHistory.lastOrNull()
if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) {
tempHistory.removeLast()
val editedContent =
Content(
"model",
lastContent.parts.map {
when (it) {
is TextPart -> {
TextPart(it.text + textPart.text)
}
else -> {
it
}
}
}
)
tempHistory.add(editedContent)
return
}
tempHistory.add(Content("model", listOf(textPart)))
}

private fun attemptLock() {
if (!lock.tryAcquire()) {
throw InvalidStateException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,20 @@ import com.google.ai.client.generativeai.internal.util.toPublic
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.CountTokensResponse
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.FourParameterFunction
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.NoParameterFunction
import com.google.ai.client.generativeai.type.OneParameterFunction
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.ThreeParameterFunction
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.TwoParameterFunction
import com.google.ai.client.generativeai.type.content
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
Expand All @@ -52,6 +59,7 @@ internal constructor(
val apiKey: String,
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val tools: List<Tool>? = null,
private val controller: APIController
) {

Expand All @@ -61,7 +69,15 @@ internal constructor(
apiKey: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName))
tools: List<Tool>? = null
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
tools,
APIController(apiKey, modelName)
)

/**
* Generates a response from the backend with the provided [Content]s.
Expand Down Expand Up @@ -160,12 +176,64 @@ internal constructor(
return countTokens(content { image(prompt) })
}

/**
* Executes a function request by the model.
*
* @param call A [FunctionCallPart] from the model, containing a function call and parameters
* @return The output of the requested function call
*/
suspend fun executeFunction(call: FunctionCallPart): String {
if (tools == null) {
throw RuntimeException("No registered tools")
}
val tool = tools.first { it.functionDeclarations.any { it.name == call.name } }
val declaration =
tool.functionDeclarations.firstOrNull() { it.name == call.name }
?: throw RuntimeException("No registered function named ${call.name}")
return when (declaration) {
is NoParameterFunction -> {
declaration.function.invoke()
}
is OneParameterFunction -> {
val param1 = getParamOrThrow(declaration.param.name, call)
declaration.function.invoke(param1)
}
is TwoParameterFunction -> {
val param1 = getParamOrThrow(declaration.param1.name, call)
val param2 = getParamOrThrow(declaration.param2.name, call)
declaration.function.invoke(param1, param2)
}
is ThreeParameterFunction -> {
val param1 = getParamOrThrow(declaration.param1.name, call)
val param2 = getParamOrThrow(declaration.param2.name, call)
val param3 = getParamOrThrow(declaration.param3.name, call)
declaration.function.invoke(param1, param2, param3)
}
is FourParameterFunction -> {
val param1 = getParamOrThrow(declaration.param1.name, call)
val param2 = getParamOrThrow(declaration.param2.name, call)
val param3 = getParamOrThrow(declaration.param3.name, call)
val param4 = getParamOrThrow(declaration.param4.name, call)
declaration.function.invoke(param1, param2, param3, param4)
}
else -> {
throw RuntimeException("UNREACHABLE")
}
}
}

private fun getParamOrThrow(paramName: String, part: FunctionCallPart): String {
return part.args[paramName]
?: throw RuntimeException("Missing parameter named $paramName for function ${part.name}")
}

private fun constructRequest(vararg prompt: Content) =
GenerateContentRequest(
modelName,
prompt.map { it.toInternal() },
safetySettings?.map { it.toInternal() },
generationConfig?.toInternal()
generationConfig?.toInternal(),
tools?.map { it.toInternal() }
)

private fun constructCountTokensRequest(vararg prompt: Content) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.okhttp.OkHttp
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.logging.ANDROID
import io.ktor.client.plugins.logging.LogLevel
import io.ktor.client.plugins.logging.Logger
import io.ktor.client.plugins.logging.Logging
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.post
Expand All @@ -44,7 +48,7 @@ import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json

// TODO: Should these stay here or be moved elsewhere?
internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1"
internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1beta"

internal val JSON = Json {
ignoreUnknownKeys = true
Expand Down Expand Up @@ -74,6 +78,10 @@ internal class APIController(
requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS
socketTimeoutMillis = 80_000
}
install(Logging) {
logger = Logger.ANDROID
level = LogLevel.BODY
}
install(ContentNegotiation) { json(JSON) }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.ai.client.generativeai.internal.api

import com.google.ai.client.generativeai.internal.api.client.GenerationConfig
import com.google.ai.client.generativeai.internal.api.client.Tool
import com.google.ai.client.generativeai.internal.api.shared.Content
import com.google.ai.client.generativeai.internal.api.shared.SafetySetting
import kotlinx.serialization.SerialName
Expand All @@ -30,6 +31,7 @@ internal data class GenerateContentRequest(
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null
) : Request

@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.internal.api.client

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

@Serializable
internal data class GenerationConfig(
Expand All @@ -28,3 +29,19 @@ internal data class GenerationConfig(
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?
)

@Serializable internal data class Tool(val functionDeclarations: List<FunctionDeclaration>)

@Serializable
internal data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: FunctionParameters
)

@Serializable
internal data class FunctionParameters(
val properties: JsonObject,
val required: List<String>,
val type: String,
)
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,24 @@ typealias Base64 = String

@Serializable internal data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part

@Serializable internal data class FunctionCallPart(val functionCall: FunctionCall) : Part

@Serializable
internal data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part

@Serializable
internal data class Blob(
@SerialName("mime_type") val mimeType: String,
val data: Base64,
)

@Serializable
internal data class FunctionResponse(val name: String, val response: FunctionResponseData)

@Serializable internal data class FunctionCall(val name: String, val args: Map<String, String>)

@Serializable internal data class FunctionResponseData(val name: String, val content: String)

@Serializable
internal data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold)

Expand All @@ -72,6 +84,8 @@ internal object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::cl
return when {
"text" in jsonObject -> TextPart.serializer()
"inlineData" in jsonObject -> BlobPart.serializer()
"functionCall" in jsonObject -> FunctionCallPart.serializer()
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
else -> throw SerializationException("Unknown Part type")
}
}
Expand Down
Loading

0 comments on commit f95f85d

Please sign in to comment.