diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 4bfffc37..27d2341d 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -49,6 +49,7 @@ import kotlinx.coroutines.flow.transform @OptIn(GenerativeBeta::class) class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) { private var lock = Semaphore(1) + private val VALID_ROLES = listOf("user","function") /** * Generates a response from the backend with the provided [Content], and any previous ones @@ -162,8 +163,8 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content> } private fun Content.assertComesFromUser() { - if (role != "user") { - throw InvalidStateException("Chat prompts should come from the 'user' role.") + if (!VALID_ROLES.contains(role)) { + throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.") } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt index 0d2b6366..05f0c9bd 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt @@ -41,7 +41,7 @@ internal enum class HarmCategory { typealias Base64 = String -@Serializable internal data class Content(val role: String? = "user", val parts: List<Part>) +@Serializable internal data class Content(val role: String?, val parts: List<Part>) @Serializable(PartSerializer::class) internal sealed interface Part