Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Davidmotson.code execution #186

Merged
merged 7 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .changes/common/angle-carpenter-beam-clock.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["add code execution tool"]}
1 change: 1 addition & 0 deletions .changes/generativeai/direction-bee-brass-aftermath.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["add code execution tool"]}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.common.client

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

@Serializable
data class GenerationConfig(
Expand All @@ -33,7 +34,11 @@ data class GenerationConfig(
@SerialName("response_schema") val responseSchema: Schema? = null,
)

@Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)
@Serializable
data class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
val codeExecution: JsonObject? = null,
davidmotson marked this conversation as resolved.
Show resolved Hide resolved
)
davidmotson marked this conversation as resolved.
Show resolved Hide resolved

@Serializable
data class ToolConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa

@Serializable data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part

/** [Part] that wraps an [ExecutableCode] payload under the `executableCode` key. */
@Serializable data class ExecutableCodePart(val executableCode: ExecutableCode) : Part

/** [Part] that wraps a [CodeExecutionResult] payload under the `codeExecutionResult` key. */
@Serializable
data class CodeExecutionResultPart(val codeExecutionResult: CodeExecutionResult) : Part

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String?>)
Expand All @@ -71,6 +76,18 @@ data class FileData(

@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64)

/**
 * A piece of source code together with the language it is written in.
 *
 * @property language name of the language of [code], e.g. `"python"`.
 * @property code the source text itself.
 */
@Serializable data class ExecutableCode(val language: String, val code: String)

/**
 * Result of a code execution, as carried by a `codeExecutionResult` part.
 *
 * @property outcome how the execution finished.
 * @property output text produced by the execution. The field is optional on the wire, so it
 *   defaults to an empty string; a payload that omits it still decodes instead of failing with a
 *   missing-field error.
 */
@Serializable data class CodeExecutionResult(val outcome: Outcome, val output: String = "")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rlazo output is an optional field in the proto. How about we decode this as "" if the field is omitted?


/**
 * Wire-level outcome of a code execution, used by [CodeExecutionResult].
 *
 * Only [UNSPECIFIED] overrides its serial name; the remaining entries carry no [SerialName], so
 * their constant names (already prefixed with `OUTCOME_`) double as the JSON values.
 */
@Serializable
enum class Outcome {
// Serialized as "OUTCOME_UNSPECIFIED" through the explicit serial name.
@SerialName("OUTCOME_UNSPECIFIED") UNSPECIFIED,
OUTCOME_OK,
OUTCOME_FAILED,
OUTCOME_DEADLINE_EXCEEDED,
}

@Serializable
data class SafetySetting(
val category: HarmCategory,
Expand Down Expand Up @@ -101,8 +118,10 @@ object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::class) {
"text" in jsonObject -> TextPart.serializer()
"functionCall" in jsonObject -> FunctionCallPart.serializer()
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
"inline_data" in jsonObject -> BlobPart.serializer()
"file_data" in jsonObject -> FileDataPart.serializer()
"inlineData" in jsonObject -> BlobPart.serializer()
"fileData" in jsonObject -> FileDataPart.serializer()
"executableCode" in jsonObject -> ExecutableCodePart.serializer()
"codeExecutionResult" in jsonObject -> CodeExecutionResultPart.serializer()
else -> throw SerializationException("Unknown Part type")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.FunctionCallingConfig
import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.TextPart
Expand All @@ -43,6 +44,7 @@ import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonObject
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
Expand Down Expand Up @@ -259,6 +261,41 @@ internal class RequestFormatTests {

mockEngine.requestHistory.first().headers.contains("header1") shouldBe false
}

// Checks that a request carrying a code execution tool serializes a `codeExecution` key under
// the first entry of `tools`.
@Test
fun `code execution tool serialization contains correct keys`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
// The mock engine answers every request with whatever is written to `channel`.
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null,
)

withTimeout(5.seconds) {
controller
.generateContentStream(
GenerateContentRequest(
model = "unused",
contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))),
// An empty JSON object is enough to make the tool carry a `codeExecution` entry.
tools = listOf(Tool(codeExecution = JsonObject(emptyMap()))),
)
)
// Closing the channel on the first emitted item ends the streamed response.
.collect { channel.close() }
}

// Only the outgoing request body is inspected; the response content is irrelevant here.
val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text

requestBodyAsText shouldContainJsonKey "tools[0].codeExecution"
}
}

@RunWith(Parameterized::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ import com.google.ai.client.generativeai.common.server.BlockReason
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.server.HarmProbability
import com.google.ai.client.generativeai.common.server.HarmSeverity
import com.google.ai.client.generativeai.common.shared.CodeExecutionResult
import com.google.ai.client.generativeai.common.shared.CodeExecutionResultPart
import com.google.ai.client.generativeai.common.shared.ExecutableCode
import com.google.ai.client.generativeai.common.shared.ExecutableCodePart
import com.google.ai.client.generativeai.common.shared.FunctionCallPart
import com.google.ai.client.generativeai.common.shared.HarmCategory
import com.google.ai.client.generativeai.common.shared.Outcome
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.common.util.goldenUnaryFile
import com.google.ai.client.generativeai.common.util.shouldNotBeNullOrEmpty
Expand Down Expand Up @@ -331,4 +336,23 @@ internal class UnarySnapshotTests {
callPart.functionCall.args["current"] shouldBe "true"
}
}

/** Decodes the golden code-execution response and checks both of its code-execution parts. */
@Test
fun `code execution parses correctly`() =
  goldenUnaryFile("success-code-execution.json") {
    val expectedCode = ExecutableCodePart(ExecutableCode("python", "print(\"Hello World\")"))
    val expectedResult =
      CodeExecutionResultPart(CodeExecutionResult(Outcome.OUTCOME_OK, "Hello World"))

    withTimeout(testTimeout) {
      val content =
        apiController
          .generateContent(textGenerateContentRequest("prompt"))
          .candidates
          .shouldNotBeNullOrEmpty()
          .first()
          .content
      content.shouldNotBeNull()

      // Both parts are fetched up front, so a short parts list fails before any comparison runs.
      val (codePart, resultPart) = content.parts
      codePart shouldBe expectedCode
      resultPart shouldBe expectedResult
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{
"candidates": [
{
"content": {
"parts": [
{
"executableCode": {
"language": "python",
davidmotson marked this conversation as resolved.
Show resolved Hide resolved
"code": "print(\"Hello World\")"
}
},
{
"codeExecutionResult": {
"outcome": "OUTCOME_OK",
"output": "Hello World"
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
}
]
}
],
"usageMetadata": {
"promptTokenCount": 774,
"candidatesTokenCount": 4176,
"totalTokenCount": 4950
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ import com.google.ai.client.generativeai.common.server.PromptFeedback
import com.google.ai.client.generativeai.common.server.SafetyRating
import com.google.ai.client.generativeai.common.shared.Blob
import com.google.ai.client.generativeai.common.shared.BlobPart
import com.google.ai.client.generativeai.common.shared.CodeExecutionResult
import com.google.ai.client.generativeai.common.shared.CodeExecutionResultPart
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.ExecutableCode
import com.google.ai.client.generativeai.common.shared.ExecutableCodePart
import com.google.ai.client.generativeai.common.shared.FileData
import com.google.ai.client.generativeai.common.shared.FileDataPart
import com.google.ai.client.generativeai.common.shared.FunctionCall
Expand All @@ -42,11 +46,13 @@ import com.google.ai.client.generativeai.common.shared.FunctionResponse
import com.google.ai.client.generativeai.common.shared.FunctionResponsePart
import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold
import com.google.ai.client.generativeai.common.shared.HarmCategory
import com.google.ai.client.generativeai.common.shared.Outcome
import com.google.ai.client.generativeai.common.shared.Part
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.CitationMetadata
import com.google.ai.client.generativeai.type.ExecutionOutcome
import com.google.ai.client.generativeai.type.FunctionCallingConfig
import com.google.ai.client.generativeai.type.FunctionDeclaration
import com.google.ai.client.generativeai.type.ImagePart
Expand Down Expand Up @@ -80,6 +86,10 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part {
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
is com.google.ai.client.generativeai.type.FileDataPart ->
FileDataPart(FileData(fileUri = uri, mimeType = mimeType))
is com.google.ai.client.generativeai.type.ExecutableCodePart ->
ExecutableCodePart(ExecutableCode(language, code))
is com.google.ai.client.generativeai.type.CodeExecutionResultPart ->
CodeExecutionResultPart(CodeExecutionResult(outcome.toInternal(), output))
else ->
throw SerializationException(
"The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet."
Expand Down Expand Up @@ -122,8 +132,19 @@ internal fun BlockThreshold.toInternal() =
BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED
}

/** Maps this public [ExecutionOutcome] onto the matching serialization-layer [Outcome]. */
internal fun ExecutionOutcome.toInternal(): Outcome {
  // Exhaustive on purpose: a new ExecutionOutcome entry must fail compilation here.
  return when (this) {
    ExecutionOutcome.OK -> Outcome.OUTCOME_OK
    ExecutionOutcome.FAILED -> Outcome.OUTCOME_FAILED
    ExecutionOutcome.DEADLINE_EXCEEDED -> Outcome.OUTCOME_DEADLINE_EXCEEDED
    ExecutionOutcome.UNSPECIFIED -> Outcome.UNSPECIFIED
  }
}

internal fun Tool.toInternal() =
com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() })
com.google.ai.client.generativeai.common.client.Tool(
functionDeclarations?.map { it.toInternal() },
codeExecution = codeExecution?.toInternal(),
)

internal fun ToolConfig.toInternal() =
com.google.ai.client.generativeai.common.client.ToolConfig(
Expand Down Expand Up @@ -204,6 +225,16 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part {
)
is FileDataPart ->
com.google.ai.client.generativeai.type.FileDataPart(fileData.fileUri, fileData.mimeType)
is ExecutableCodePart ->
com.google.ai.client.generativeai.type.ExecutableCodePart(
executableCode.language,
executableCode.code,
)
is CodeExecutionResultPart ->
com.google.ai.client.generativeai.type.CodeExecutionResultPart(
codeExecutionResult.outcome.toPublic(),
codeExecutionResult.output,
)
else ->
throw SerializationException(
"Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK."
Expand Down Expand Up @@ -267,6 +298,14 @@ internal fun BlockReason.toPublic() =
BlockReason.UNKNOWN -> com.google.ai.client.generativeai.type.BlockReason.UNKNOWN
}

/** Maps this serialization-layer [Outcome] onto the matching public [ExecutionOutcome]. */
internal fun Outcome.toPublic(): ExecutionOutcome {
  // Exhaustive on purpose: a new Outcome entry must fail compilation here.
  return when (this) {
    Outcome.OUTCOME_OK -> ExecutionOutcome.OK
    Outcome.OUTCOME_FAILED -> ExecutionOutcome.FAILED
    Outcome.OUTCOME_DEADLINE_EXCEEDED -> ExecutionOutcome.DEADLINE_EXCEEDED
    Outcome.UNSPECIFIED -> ExecutionOutcome.UNSPECIFIED
  }
}

internal fun GenerateContentResponse.toPublic() =
com.google.ai.client.generativeai.type.GenerateContentResponse(
candidates?.map { it.toPublic() }.orEmpty(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

/** Outcome of a code execution, as exposed by [CodeExecutionResultPart.outcome]. */
enum class ExecutionOutcome {
UNSPECIFIED,
OK,
FAILED,
DEADLINE_EXCEEDED,
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,19 @@ class GenerateContentResponse(
) {
/**
 * Convenience field concatenating the textual content of the first candidate, if one exists.
 *
 * [TextPart]s contribute their text unchanged, while [ExecutableCodePart]s and
 * [CodeExecutionResultPart]s are rendered as fenced code blocks. The pieces are joined with a
 * single space. Evaluates to `null` when the response has no candidates, rather than throwing.
 */
val text: String? by lazy {
  candidates
    .firstOrNull()
    ?.content
    ?.parts
    ?.mapNotNull { part ->
      when (part) {
        is TextPart -> part.text
        is ExecutableCodePart -> "\n```${part.language.lowercase()}\n${part.code}\n```"
        is CodeExecutionResultPart -> "\n```\n${part.output}\n```"
        // Every other kind of part has no textual rendering and is left out.
        else -> null
      }
    }
    ?.joinToString(" ")
}

/** Convenience field representing the first function call part in the request, if it exists */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import org.json.JSONObject
* * [ImagePart] representing image data.
* * [BlobPart] representing MIME typed binary data.
* * [FileDataPart] representing MIME typed binary data.
* * [FunctionCallPart] representing a client-side function call requested by the model.
* * [FunctionResponsePart] representing the result of a client-side function call.
* * [ExecutableCodePart] representing code generated and executed by the model.
* * [CodeExecutionResultPart] representing the result of running code generated by the model.
*/
interface Part

Expand Down Expand Up @@ -54,6 +58,12 @@ class FunctionCallPart(val name: String, val args: Map<String, String?>) : Part
/** Represents function call output to be returned to the model when it requests a function call */
class FunctionResponsePart(val name: String, val response: JSONObject) : Part

/**
 * Represents code generated and executed by the model.
 *
 * @property language the language [code] is written in.
 * @property code the generated source text.
 */
class ExecutableCodePart(val language: String, val code: String) : Part

/**
 * Represents the result of running code generated by the model.
 *
 * @property outcome how the execution finished.
 * @property output the text produced by the execution.
 */
class CodeExecutionResultPart(val outcome: ExecutionOutcome, val output: String) : Part

/** Returns the text carried by this part when it is a [TextPart]; any other part yields null. */
fun Part.asTextOrNull(): String? = if (this is TextPart) this.text else null

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,20 @@

package com.google.ai.client.generativeai.type

import org.json.JSONObject

/**
* Contains a set of function declarations that the model has access to. These can be used to gather
* information or complete tasks.
*
* @param functionDeclarations The set of functions that this tool allows the model access to
* @param codeExecution This is a flag value to enable Code Execution. Use [CODE_EXECUTION].
*/
class Tool(val functionDeclarations: List<FunctionDeclaration>)
class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
val codeExecution: JSONObject? = null,
) {
companion object {
val CODE_EXECUTION = Tool(codeExecution = JSONObject())
}
}
Loading
Loading