diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index 4e484266..7705e971 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -22,12 +22,13 @@ import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import kotlinx.serialization.Transient sealed interface Request @Serializable data class GenerateContentRequest( - val model: String, + @Transient val model: String? = null, val contents: List, @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index a05d6d67..e7ca4982 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -22,11 +22,14 @@ import com.google.ai.client.generativeai.common.util.commonTest import com.google.ai.client.generativeai.common.util.createResponses import com.google.ai.client.generativeai.common.util.doBlocking import com.google.ai.client.generativeai.common.util.prepareStreamingResponse +import io.kotest.assertions.json.shouldContainJsonKey +import io.kotest.assertions.json.shouldNotContainJsonKey import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respond +import io.ktor.content.TextContent import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode import io.ktor.http.headersOf @@ -69,7 +72,7 @@ internal class APIControllerTests { } } -internal class EndpointTests { +internal class RequestFormatTests { @Test fun `using default endpoint`() = doBlocking { val channel = ByteChannel(autoFlush = true) @@ -114,6 +117,28 @@ internal class EndpointTests { mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint" } + + @Test + fun `request doesn't include the model name`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + val controller = + APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine) + + withTimeout(5.seconds) { + controller.generateContentStream(textGenerateContentRequest("cats")).collect { + it.candidates?.isEmpty() shouldBe false + channel.close() + } + } + + val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text + requestBodyAsText shouldContainJsonKey "contents" + requestBodyAsText shouldNotContainJsonKey "model" + } } @RunWith(Parameterized::class)