
Davidmotson.auto functions #50

Closed · wants to merge 40 commits
f95f85d
Intial implementation of automated function calling
Jan 30, 2024
34287a8
Add opt in annotation for v1beta endpoint
Feb 1, 2024
1d8ebf4
ktfmt and add licenses
Feb 1, 2024
61e3140
fix behavior for non-streaming chat and fix opt in propagation
Feb 1, 2024
8c5c798
ktfmt
Feb 1, 2024
bb2bed0
Merge branch 'main' into davidmotson.auto_functions
davidmotson Feb 1, 2024
7cad4ca
remove logging
Feb 1, 2024
1136610
added convinience methods for getting function call and response type…
Feb 7, 2024
f8d60e3
Alternative implementation (#68)
davidmotson Feb 22, 2024
b09c6ef
rename annotation and cleanup rlazo@'s merge
Feb 22, 2024
08309a7
re-add suspend functions
Feb 22, 2024
0516150
add flag to disable automated function calling
Feb 22, 2024
6bddff8
Merge branch 'main' into davidmotson.auto_functions
davidmotson Feb 22, 2024
393f918
move autofunction to request options and other merge conflicts
Feb 22, 2024
931cf00
add license to Type.kt
Feb 22, 2024
3ff8105
swap autoFunctionFlag to the negative
Feb 23, 2024
e9c3adc
fix dependant if statements
Feb 23, 2024
6c4eb81
swap output to JsonObject
Mar 5, 2024
6a50c76
swap public implementation to org.json
Mar 5, 2024
6ad8869
fix request options default
Mar 5, 2024
9362b10
Merge branch 'main' into davidmotson.auto_functions
davidmotson Mar 11, 2024
266b335
Update generativeai/src/main/java/com/google/ai/client/generativeai/C…
davidmotson Mar 11, 2024
6684478
Update generativeai/src/main/java/com/google/ai/client/generativeai/C…
davidmotson Mar 11, 2024
7c57e12
Add extra datatypes and fix serialization
Mar 12, 2024
571308f
update copyright
Mar 12, 2024
b752cda
added some kdocs
Mar 12, 2024
470594a
ktfmt
Mar 12, 2024
90f2863
clean up serialization
Mar 15, 2024
cddfedf
fix serialization & type check for role
Mar 18, 2024
0d567dd
Merge branch 'main' into davidmotson.auto_functions
davidmotson Mar 18, 2024
9d7a3eb
format
Mar 18, 2024
9d30869
add all fields including full recursive description
Mar 18, 2024
24f3a80
fix enum type
Mar 19, 2024
c57c46e
revert gradle.properties change
Mar 19, 2024
b1da12c
fix type declaration
Mar 19, 2024
dc83aea
revert copyright headers
Mar 20, 2024
b21f852
revert copyright headers
Mar 20, 2024
b53423e
update copyright header only on new files
Mar 20, 2024
c6a46f4
last cleanup push
Mar 20, 2024
e1abc62
remove unintended diffs
Mar 20, 2024
@@ -20,11 +20,21 @@ import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import androidx.lifecycle.viewmodel.CreationExtras
import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.FunctionType
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.ParameterDeclaration
import com.google.ai.client.generativeai.type.ParameterDeclaration.Companion.int
import com.google.ai.client.generativeai.type.ParameterDeclaration.Companion.str
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.defineFunction
import com.google.ai.client.generativeai.type.generationConfig
import com.google.ai.sample.feature.chat.ChatViewModel
import com.google.ai.sample.feature.multimodal.PhotoReasoningViewModel
import com.google.ai.sample.feature.text.SummarizeViewModel
import org.json.JSONObject

@OptIn(GenerativeBeta::class)
val GenerativeViewModelFactory = object : ViewModelProvider.Factory {
override fun <T : ViewModel> create(
viewModelClass: Class<T>,
@@ -63,7 +73,19 @@ val GenerativeViewModelFactory = object : ViewModelProvider.Factory {
val generativeModel = GenerativeModel(
modelName = "gemini-1.0-pro",
apiKey = BuildConfig.apiKey,
generationConfig = config
generationConfig = config,

[Review comment from a Collaborator]
Given the inconsistent behaviour handling the functions, IMO we should have a separate entry in the sample app for function calling. This will prevent the regular chat example from failing due to an unrelated issue.

[Reply from the Author]
oh whoops, this was not an intended part of the commit. I was doing this for my own testing, please ignore.

requestOptions = RequestOptions(apiVersion = "v1beta"),
tools = listOf(
Tool(
listOf(
defineFunction(
"getWeather",
"gets the weather at a given city",
str("city", "the city to get the weather in")
) { city -> JSONObject(mapOf("result" to "Sunny")) }
)
)
)
)
ChatViewModel(generativeModel)
}
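
For readers skimming the diff, the sample wiring above can be lifted into a standalone snippet. This is a sketch only — it assumes the `Tool`/`defineFunction`/`str` builders and the `v1beta` `RequestOptions` exactly as introduced in this PR, plus the SDK's existing `startChat()`:

```kotlin
// Sketch of registering a tool and letting Chat auto-execute it.
// Assumes the builder API introduced in this PR; not a tested sample.
val weatherTool = Tool(
    listOf(
        defineFunction(
            "getWeather",
            "gets the weather at a given city",
            str("city", "the city to get the weather in"),
        ) { city ->
            // A real implementation would call a weather service for `city`.
            JSONObject(mapOf("result" to "Sunny"))
        }
    )
)

val model = GenerativeModel(
    modelName = "gemini-1.0-pro",
    apiKey = BuildConfig.apiKey,
    requestOptions = RequestOptions(apiVersion = "v1beta"),
    tools = listOf(weatherTool),
)

// If the model replies with a FunctionCallPart, sendMessage executes the
// registered function and feeds the FunctionResponsePart back automatically.
suspend fun askWeather(): String? =
    model.startChat().sendMessage("What is the weather in Lima?").text
```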
132 changes: 96 additions & 36 deletions generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt
@@ -19,16 +19,20 @@ package com.google.ai.client.generativeai
import android.graphics.Bitmap
import com.google.ai.client.generativeai.type.BlobPart
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.TextPart
import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.transform

/**
* Representation of a back and forth interaction with a model.
@@ -42,8 +46,10 @@ import kotlinx.coroutines.flow.onEach
* @param model the model to use for the interaction
* @property history the previous interactions with the model
*/
@OptIn(GenerativeBeta::class)
class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
private var lock = Semaphore(1)
private val VALID_ROLES = listOf("user", "function")

/**
* Generates a response from the backend with the provided [Content], and any previous ones
@@ -53,13 +59,27 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()
suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse {
inputPrompt.assertComesFromUser()
attemptLock()
var response: GenerateContentResponse
var prompt = inputPrompt
val tempHistory = LinkedList<Content>()
try {
val response = model.generateContent(*history.toTypedArray(), prompt)
history.add(prompt)
history.add(response.candidates.first().content)
while (true) {
response =
model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt)
val responsePart = response.candidates.first().content.parts.first()

tempHistory.add(prompt)
tempHistory.add(response.candidates.first().content)
if (responsePart !is FunctionCallPart) break
if (model.requestOptions.disableAutoFunction) break

val output = model.executeFunction(responsePart)
prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output)))
}
history.addAll(tempHistory)
return response
} finally {
lock.release()
@@ -101,43 +121,19 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
attemptLock()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val bitmaps = LinkedList<Bitmap>()
val blobs = LinkedList<BlobPart>()
val text = StringBuilder()

val tempHistory = mutableListOf<Content>()
tempHistory.add(prompt)
/**
* TODO: revisit when images and blobs are returned. This will cause issues with how things are
* structured in the response. eg; a text/image/text response will be (incorrectly)
* represented as image/text
*/
return flow
.onEach {
for (part in it.candidates.first().content.parts) {
when (part) {
is TextPart -> text.append(part.text)
is ImagePart -> bitmaps.add(part.image)
is BlobPart -> blobs.add(part)
}
}
}
.transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) }
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
for (bitmap in bitmaps) {
image(bitmap)
}
for (blob in blobs) {
blob(blob.mimeType, blob.blob)
}
if (text.isNotBlank()) {
text(text.toString())
}
}

history.add(prompt)
history.add(content)
history.addAll(tempHistory)
}
}
}
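
The streaming path delegates to `automaticFunctionExecutingTransform`, which mirrors the non-streaming loop in `sendMessage`. As a reading aid, the control flow boils down to roughly the following — plain Kotlin with the SDK types stubbed out, an illustration rather than the library's actual code:

```kotlin
// Distilled control flow of the auto-function loop, with SDK types stubbed
// out for illustration; the real code works on Content/GenerateContentResponse.
sealed interface Part
data class TextPart(val text: String) : Part
data class FunctionCallPart(val name: String) : Part
data class FunctionResponsePart(val name: String, val output: String) : Part

fun autoFunctionLoop(
    callModel: (List<Part>) -> Part,       // one model round trip
    execute: (FunctionCallPart) -> String, // runs the registered function
    disableAutoFunction: Boolean,
    firstPrompt: Part,
): List<Part> {
    val history = mutableListOf<Part>()
    var prompt = firstPrompt
    while (true) {
        val reply = callModel(history + prompt)
        history += prompt
        history += reply
        // Stop on a plain (non function-call) reply, or when auto-execution is off.
        if (reply !is FunctionCallPart || disableAutoFunction) return history
        prompt = FunctionResponsePart(reply.name, execute(reply))
    }
}
```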
@@ -167,9 +163,73 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}

private fun Content.assertComesFromUser() {
if (role != "user") {
throw InvalidStateException("Chat prompts should come from the 'user' role.")
if (!VALID_ROLES.contains(role)) {
throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.")
}
}

private suspend fun automaticFunctionExecutingTransform(
transformer: FlowCollector<GenerateContentResponse>,
tempHistory: MutableList<Content>,
response: GenerateContentResponse
) {
for (part in response.candidates.first().content.parts) {
when (part) {
is TextPart -> {
transformer.emit(response)
addTextToHistory(tempHistory, part)
}
is ImagePart -> {
transformer.emit(response)
tempHistory.add(Content("model", listOf(part)))
}
is BlobPart -> {
transformer.emit(response)
tempHistory.add(Content("model", listOf(part)))
}
is FunctionCallPart -> {
if (model.requestOptions.disableAutoFunction) {
tempHistory.add(response.candidates.first().content)
continue
}
val functionCall =
response.candidates.first().content.parts.first { it is FunctionCallPart }
as FunctionCallPart
val output = model.executeFunction(functionCall)
val functionResponse =
Content("function", listOf(FunctionResponsePart(functionCall.name, output)))
tempHistory.add(response.candidates.first().content)
tempHistory.add(functionResponse)
model
.generateContentStream(*history.toTypedArray(), *tempHistory.toTypedArray())
.collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) }
}
}
}
}

private fun addTextToHistory(tempHistory: MutableList<Content>, textPart: TextPart) {
val lastContent = tempHistory.lastOrNull()
if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) {
tempHistory.removeLast()
val editedContent =
Content(
"model",
lastContent.parts.map {
when (it) {
is TextPart -> {
TextPart(it.text + textPart.text)
}
else -> {
it
}
}
}
)
tempHistory.add(editedContent)
return
}
tempHistory.add(Content("model", listOf(textPart)))
}

private fun attemptLock() {
@@ -25,18 +25,27 @@ import com.google.ai.client.generativeai.internal.util.toPublic
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.CountTokensResponse
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.FourParameterFunction
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.NoParameterFunction
import com.google.ai.client.generativeai.type.OneParameterFunction
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.ThreeParameterFunction
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.TwoParameterFunction
import com.google.ai.client.generativeai.type.content
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map
import org.json.JSONObject

/**
* A facilitator for a given multimodal model (eg; Gemini).
@@ -48,12 +57,14 @@ import kotlinx.coroutines.flow.map
* generation
* @property requestOptions configuration options to utilize during backend communication
*/
@OptIn(GenerativeBeta::class)
class GenerativeModel
internal constructor(
val modelName: String,
val apiKey: String,
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val tools: List<Tool>? = null,
val requestOptions: RequestOptions = RequestOptions(),
private val controller: APIController
) {
@@ -64,12 +75,14 @@ internal constructor(
apiKey: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
tools: List<Tool>? = null,
requestOptions: RequestOptions = RequestOptions(),
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
tools,
requestOptions,
APIController(apiKey, modelName, requestOptions)
)
@@ -171,12 +184,45 @@ internal constructor(
return countTokens(content { image(prompt) })
}

/**
* Executes a function request by the model.
*
* @param functionCallPart A [FunctionCallPart] from the model, containing a function call and
* parameters
* @return The output of the requested function call
*/
@GenerativeBeta
suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject {
if (tools == null) {
throw RuntimeException("No registered tools")
}
val tool = tools.first { it.functionDeclarations.any { it.name == functionCallPart.name } }
val callable =
tool.functionDeclarations.firstOrNull { it.name == functionCallPart.name }
?: throw RuntimeException("No registered function named ${functionCallPart.name}")
return when (callable) {
is NoParameterFunction -> callable.execute()
is OneParameterFunction<*> ->
(callable as OneParameterFunction<Any?>).execute(functionCallPart)
is TwoParameterFunction<*, *> ->
(callable as TwoParameterFunction<Any?, Any?>).execute(functionCallPart)
is ThreeParameterFunction<*, *, *> ->
(callable as ThreeParameterFunction<Any?, Any?, Any?>).execute(functionCallPart)
is FourParameterFunction<*, *, *, *> ->
(callable as FourParameterFunction<Any?, Any?, Any?, Any?>).execute(functionCallPart)
else -> {
throw RuntimeException("UNREACHABLE")
}
}
}
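
When automatic execution is turned off via `disableAutoFunction`, `executeFunction` can still be driven by hand. A hedged sketch — the `Content` constructor, part types, and flag name are taken from this diff, while the helper itself is illustrative:

```kotlin
// Manually completing one function-call turn when auto-execution is disabled.
// Sketch only — assumes the API surface introduced in this PR.
suspend fun manualFunctionTurn(model: GenerativeModel, chat: Chat, userText: String): String? {
    var response = chat.sendMessage(userText)
    val call = response.candidates.first().content.parts
        .filterIsInstance<FunctionCallPart>()
        .firstOrNull()
    if (call != null) {
        val output = model.executeFunction(call) // dispatches to the matching tool
        response = chat.sendMessage(
            Content("function", listOf(FunctionResponsePart(call.name, output)))
        )
    }
    return response.text
}
```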

private fun constructRequest(vararg prompt: Content) =
GenerateContentRequest(
modelName,
prompt.map { it.toInternal() },
safetySettings?.map { it.toInternal() },
generationConfig?.toInternal()
generationConfig?.toInternal(),
tools?.map { it.toInternal() }
)

private fun constructCountTokensRequest(vararg prompt: Content) =
@@ -17,6 +17,7 @@
package com.google.ai.client.generativeai.internal.api

import com.google.ai.client.generativeai.internal.api.client.GenerationConfig
import com.google.ai.client.generativeai.internal.api.client.Tool
import com.google.ai.client.generativeai.internal.api.shared.Content
import com.google.ai.client.generativeai.internal.api.shared.SafetySetting
import kotlinx.serialization.SerialName
@@ -30,6 +31,7 @@ internal data class GenerateContentRequest(
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null
) : Request

@Serializable
@@ -26,5 +26,25 @@ internal data class GenerationConfig(
@SerialName("top_k") val topK: Int?,
@SerialName("candidate_count") val candidateCount: Int?,
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?
@SerialName("stop_sequences") val stopSequences: List<String>?,
)

@Serializable internal data class Tool(val functionDeclarations: List<FunctionDeclaration>)

@Serializable
internal data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: FunctionParameterProperties,
)

@Serializable
internal data class FunctionParameterProperties(
val type: String,
val description: String? = null,
val format: String? = null,
val enum: List<String>? = null,
val properties: Map<String, FunctionParameterProperties>? = null,
val required: List<String>? = null,
val items: FunctionParameterProperties? = null
)
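
For reference, a `Tool` declared with these data classes would serialize into the request body roughly as follows. The field casing and the uppercase `type` values are assumptions based on this diff's lack of `@SerialName` annotations on `Tool`/`FunctionDeclaration` and on the `v1beta` REST surface — they are not confirmed by the diff itself:

```json
{
  "tools": [
    {
      "functionDeclarations": [
        {
          "name": "getWeather",
          "description": "gets the weather at a given city",
          "parameters": {
            "type": "OBJECT",
            "properties": {
              "city": {
                "type": "STRING",
                "description": "the city to get the weather in"
              }
            },
            "required": ["city"]
          }
        }
      ]
    }
  ]
}
```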