split the auto function calling changes between common and generativeai #90

Merged
11 commits merged on Apr 2, 2024
1 change: 1 addition & 0 deletions .changes/cloud-camp-bait-calculator.json
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Add function calling"]}
1 change: 1 addition & 0 deletions generativeai/build.gradle.kts
@@ -79,6 +79,7 @@ dependencies {
implementation("org.slf4j:slf4j-nop:2.0.9")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("org.reactivestreams:reactive-streams:1.0.3")

implementation("com.google.guava:listenablefuture:1.0")
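The kotlinx-serialization-json dependency added above is what bridges the public org.json types to the common module's kotlinx JSON types. A minimal sketch of that round trip, mirroring the JSONObject.toInternal() / JsonObject.toPublic() helpers added later in this PR (the weather payload is illustrative):

```kotlin
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import org.json.JSONObject

fun main() {
    val publicSide = JSONObject(mapOf("city" to "Lisbon", "tempC" to 21))
    // org.json -> kotlinx.serialization, as JSONObject.toInternal() does
    val commonSide: JsonObject = Json.decodeFromString(publicSide.toString())
    // kotlinx.serialization -> org.json, as JsonObject.toPublic() does
    val roundTripped = JSONObject(commonSide.toString())
    println(roundTripped)
}
```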
@@ -167,8 +167,8 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}

private fun Content.assertComesFromUser() {
if (role != "user") {
throw InvalidStateException("Chat prompts should come from the 'user' role.")
if (role !in listOf("user", "function")) {
throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.")
}
}

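With the relaxed role check above, function results can be sent back through the same Chat under the "function" role. A hedged sketch: the get_weather name and payload are hypothetical, and the part() builder on Content.Builder is assumed rather than shown in this diff.

```kotlin
import com.google.ai.client.generativeai.Chat
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.content
import org.json.JSONObject

@OptIn(GenerativeBeta::class)
suspend fun replyWithFunctionResult(chat: Chat) {
    // Content authored under the "function" role now passes assertComesFromUser().
    val functionResult = content(role = "function") {
        part(FunctionResponsePart("get_weather", JSONObject(mapOf("tempC" to 21))))
    }
    val followUp = chat.sendMessage(functionResult)
    println(followUp.text)
}
```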
@@ -25,18 +25,29 @@ import com.google.ai.client.generativeai.internal.util.toPublic
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.CountTokensResponse
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.FourParameterFunction
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.NoParameterFunction
import com.google.ai.client.generativeai.type.OneParameterFunction
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.ThreeParameterFunction
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.TwoParameterFunction
import com.google.ai.client.generativeai.type.content
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map
import kotlinx.serialization.ExperimentalSerializationApi
import org.json.JSONObject

/**
* A facilitator for a given multimodal model (e.g., Gemini).
@@ -48,14 +59,16 @@ import kotlinx.coroutines.flow.map
* generation
* @property requestOptions configuration options to utilize during backend communication
*/
@OptIn(ExperimentalSerializationApi::class)
class GenerativeModel
internal constructor(
val modelName: String,
val apiKey: String,
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val tools: List<Tool>? = null,
val requestOptions: RequestOptions = RequestOptions(),
private val controller: APIController
private val controller: APIController,
) {

@JvmOverloads
@@ -64,14 +77,16 @@ internal constructor(
apiKey: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
tools: List<Tool>? = null,
requestOptions: RequestOptions = RequestOptions(),
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
tools,
requestOptions,
APIController(apiKey, modelName, requestOptions.toInternal())
APIController(apiKey, modelName, requestOptions.toInternal()),
)

/**
@@ -171,12 +186,45 @@ internal constructor(
return countTokens(content { image(prompt) })
}

/**
* Executes a function requested by the model.
*
* @param functionCallPart A [FunctionCallPart] from the model, containing a function call and
* parameters
* @return The output of the requested function call
*/
@OptIn(GenerativeBeta::class)
suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject {
if (tools == null) {
throw InvalidStateException("No registered tools")
}
val callable =
tools.flatMap { it.functionDeclarations }.firstOrNull { it.name == functionCallPart.name }
?: throw InvalidStateException("No registered function named ${functionCallPart.name}")
return when (callable) {
is NoParameterFunction -> callable.execute()
is OneParameterFunction<*> ->
(callable as OneParameterFunction<Any?>).execute(functionCallPart)
is TwoParameterFunction<*, *> ->
(callable as TwoParameterFunction<Any?, Any?>).execute(functionCallPart)
is ThreeParameterFunction<*, *, *> ->
(callable as ThreeParameterFunction<Any?, Any?, Any?>).execute(functionCallPart)
is FourParameterFunction<*, *, *, *> ->
(callable as FourParameterFunction<Any?, Any?, Any?, Any?>).execute(functionCallPart)
else -> {
throw RuntimeException("UNREACHABLE")
}
}
}

@OptIn(GenerativeBeta::class)
private fun constructRequest(vararg prompt: Content) =
GenerateContentRequest(
modelName,
prompt.map { it.toInternal() },
safetySettings?.map { it.toInternal() },
generationConfig?.toInternal()
generationConfig?.toInternal(),
tools?.map { it.toInternal() },
)

private fun constructCountTokensRequest(vararg prompt: Content) =
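Putting the pieces together, executeFunction enables a manual function-calling round trip: detect a FunctionCallPart in the response, dispatch it to the registered declaration, and return the result under the "function" role. A hedged sketch; the prompt, model wiring, and response handling are illustrative rather than part of this diff.

```kotlin
import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.content
import org.json.JSONObject

@OptIn(GenerativeBeta::class)
suspend fun askWithTools(model: GenerativeModel): String? {
    val chat = model.startChat()
    var response = chat.sendMessage("What's the weather in Lisbon?")

    // Did the model ask us to call one of the registered functions?
    val call = response.candidates
        .flatMap { it.content.parts }
        .filterIsInstance<FunctionCallPart>()
        .firstOrNull()

    if (call != null) {
        // Dispatch to the matching FunctionDeclaration held in `tools`.
        val result: JSONObject = model.executeFunction(call)
        // Hand the result back to the model under the "function" role.
        response = chat.sendMessage(
            content(role = "function") { part(FunctionResponsePart(call.name, result)) }
        )
    }
    return response.text
}
```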
@@ -23,6 +23,7 @@ import com.google.ai.client.generativeai.common.CountTokensResponse
import com.google.ai.client.generativeai.common.GenerateContentResponse
import com.google.ai.client.generativeai.common.RequestOptions
import com.google.ai.client.generativeai.common.client.GenerationConfig
import com.google.ai.client.generativeai.common.client.Schema
import com.google.ai.client.generativeai.common.server.BlockReason
import com.google.ai.client.generativeai.common.server.Candidate
import com.google.ai.client.generativeai.common.server.CitationSources
@@ -33,17 +34,27 @@ import com.google.ai.client.generativeai.common.server.SafetyRating
import com.google.ai.client.generativeai.common.shared.Blob
import com.google.ai.client.generativeai.common.shared.BlobPart
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.FunctionCall
import com.google.ai.client.generativeai.common.shared.FunctionCallPart
import com.google.ai.client.generativeai.common.shared.FunctionResponse
import com.google.ai.client.generativeai.common.shared.FunctionResponsePart
import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold
import com.google.ai.client.generativeai.common.shared.HarmCategory
import com.google.ai.client.generativeai.common.shared.Part
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.CitationMetadata
import com.google.ai.client.generativeai.type.FunctionDeclaration
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.content
import java.io.ByteArrayOutputStream
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import org.json.JSONObject

private const val BASE_64_FLAGS = Base64.NO_WRAP

@@ -59,6 +70,10 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part {
is ImagePart -> BlobPart(Blob("image/jpeg", encodeBitmapToBase64Png(image)))
is com.google.ai.client.generativeai.type.BlobPart ->
BlobPart(Blob(mimeType, Base64.encodeToString(blob, BASE_64_FLAGS)))
is com.google.ai.client.generativeai.type.FunctionCallPart ->
FunctionCallPart(FunctionCall(name, args.orEmpty()))
is com.google.ai.client.generativeai.type.FunctionResponsePart ->
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
else ->
throw SerializationException(
"The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet."
@@ -76,7 +91,7 @@ internal fun com.google.ai.client.generativeai.type.GenerationConfig.toInternal(
topK = topK,
candidateCount = candidateCount,
maxOutputTokens = maxOutputTokens,
stopSequences = stopSequences
stopSequences = stopSequences,
)

internal fun com.google.ai.client.generativeai.type.HarmCategory.toInternal() =
@@ -99,6 +114,35 @@ internal fun BlockThreshold.toInternal() =
BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED
}

@GenerativeBeta
internal fun Tool.toInternal() =
com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() })

@GenerativeBeta
internal fun FunctionDeclaration.toInternal() =
com.google.ai.client.generativeai.common.client.FunctionDeclaration(
name,
description,
Schema(
properties = getParameters().associate { it.name to it.toInternal() },
required = getParameters().map { it.name },
type = "OBJECT",
),
)

internal fun <T> com.google.ai.client.generativeai.type.Schema<T>.toInternal(): Schema =
Schema(
type.name,
description,
format,
enum,
properties?.mapValues { it.value.toInternal() },
required,
items?.toInternal(),
)

internal fun JSONObject.toInternal() = Json.decodeFromString<JsonObject>(toString())

internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate {
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty()
@@ -108,7 +152,7 @@ internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candid
this.content?.toPublic() ?: content("model") {},
safetyRatings,
citations,
finishReason
finishReason,
)
}

@@ -126,6 +170,16 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part {
com.google.ai.client.generativeai.type.BlobPart(inlineData.mimeType, data)
}
}
is FunctionCallPart ->
com.google.ai.client.generativeai.type.FunctionCallPart(
functionCall.name,
functionCall.args.orEmpty(),
)
is FunctionResponsePart ->
com.google.ai.client.generativeai.type.FunctionResponsePart(
functionResponse.name,
functionResponse.response.toPublic(),
)
else ->
throw SerializationException(
"Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK."
@@ -192,12 +246,14 @@ internal fun BlockReason.toPublic() =
internal fun GenerateContentResponse.toPublic() =
com.google.ai.client.generativeai.type.GenerateContentResponse(
candidates?.map { it.toPublic() }.orEmpty(),
promptFeedback?.toPublic()
promptFeedback?.toPublic(),
)

internal fun CountTokensResponse.toPublic() =
com.google.ai.client.generativeai.type.CountTokensResponse(totalTokens)

internal fun JsonObject.toPublic() = JSONObject(toString())

private fun encodeBitmapToBase64Png(input: Bitmap): String {
ByteArrayOutputStream().let {
input.compress(Bitmap.CompressFormat.JPEG, 80, it)
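For completeness, a sketch of the registration side these converters serve: a one-parameter function attached to the model as a Tool. The defineFunction and Schema.str helpers and the model name are assumptions about the accompanying public API, not shown in this diff.

```kotlin
import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.Schema
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.defineFunction
import org.json.JSONObject

@OptIn(GenerativeBeta::class)
fun buildModelWithTools(apiKey: String): GenerativeModel {
    // A single-parameter function the model may call; it is serialized for the backend by
    // FunctionDeclaration.toInternal() and the Schema conversion shown above.
    val getWeather = defineFunction(
        name = "get_weather",
        description = "Returns the current weather for a city",
        Schema.str("city", "The city to look up"),
    ) { city ->
        JSONObject(mapOf("city" to city, "tempC" to 21))
    }

    return GenerativeModel(
        modelName = "gemini-1.0-pro",
        apiKey = apiKey,
        tools = listOf(Tool(listOf(getWeather))),
    )
}
```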