diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index 4e484266..4bcd4275 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -27,12 +27,12 @@ sealed interface Request @Serializable data class GenerateContentRequest( - val model: String, val contents: List, + val model: String? = null, @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, val tools: List? = null, ) : Request @Serializable -data class CountTokensRequest(val model: String, val contents: List) : Request +data class CountTokensRequest(val contents: List, val model: String? = null) : Request diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt index d753a9bf..efa2a87e 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt @@ -47,7 +47,7 @@ internal fun prepareResponse(response: GenerateContentResponse) = internal fun createRequest(vararg text: String): GenerateContentRequest { val contents = text.map { Content(parts = listOf(TextPart(it))) } - return GenerateContentRequest("gemini", contents) + return GenerateContentRequest(contents, "gemini") } internal fun createResponse(text: String) = createResponses(text).single() diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 690f5be4..7c75347f 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -173,14 +173,13 @@ internal constructor( private fun constructRequest(vararg prompt: Content) = GenerateContentRequest( - modelName, - prompt.map { it.toInternal() }, - safetySettings?.map { it.toInternal() }, - generationConfig?.toInternal() + contents = prompt.map { it.toInternal() }, + safetySettings = safetySettings?.map { it.toInternal() }, + generationConfig = generationConfig?.toInternal() ) private fun constructCountTokensRequest(vararg prompt: Content) = - CountTokensRequest(modelName, prompt.map { it.toInternal() }) + CountTokensRequest(prompt.map { it.toInternal() }) private fun GenerateContentResponse.validate() = apply { if (candidates.isEmpty() && promptFeedback == null) { diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index bfc0b24c..b23d9c46 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -52,7 +52,6 @@ internal class GenerativeModelTests { coEvery { mockApiController.generateContent( GenerateContentRequest_Common( - "gemini-pro-1.0", contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) ) ) @@ -102,7 +101,6 @@ internal class GenerativeModelTests { coEvery { mockApiController.generateContent( GenerateContentRequest_Common( - "gemini-pro-1.0", contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) ) ) @@ -117,7 +115,6 @@ internal class GenerativeModelTests { coEvery { mockApiController.generateContentStream( GenerateContentRequest_Common( - "gemini-pro-1.0", contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) ) )