From fdfe18974d0f9627670ff3b22aeabde3413a0825 Mon Sep 17 00:00:00 2001
From: Rodrigo Lazo Paz <rlazo@google.com>
Date: Thu, 28 Mar 2024 10:10:36 -0400
Subject: [PATCH] Don't encode the model name in the outgoing request

---
 .../ai/client/generativeai/common/Request.kt  |  3 ++-
 .../generativeai/common/APIControllerTests.kt | 27 ++++++++++++++++++-
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt
index 4e484266..7705e971 100644
--- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt
+++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt
@@ -22,12 +22,13 @@ import com.google.ai.client.generativeai.common.shared.Content
 import com.google.ai.client.generativeai.common.shared.SafetySetting
 import kotlinx.serialization.SerialName
 import kotlinx.serialization.Serializable
+import kotlinx.serialization.Transient
 
 sealed interface Request
 
 @Serializable
 data class GenerateContentRequest(
-  val model: String,
+  @Transient val model: String? = null,
   val contents: List<Content>,
   @SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
   @SerialName("generation_config") val generationConfig: GenerationConfig? = null,
diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt
index a05d6d67..e7ca4982 100644
--- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt
+++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt
@@ -22,11 +22,14 @@ import com.google.ai.client.generativeai.common.util.commonTest
 import com.google.ai.client.generativeai.common.util.createResponses
 import com.google.ai.client.generativeai.common.util.doBlocking
 import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
+import io.kotest.assertions.json.shouldContainJsonKey
+import io.kotest.assertions.json.shouldNotContainJsonKey
 import io.kotest.assertions.throwables.shouldThrow
 import io.kotest.matchers.shouldBe
 import io.kotest.matchers.string.shouldContain
 import io.ktor.client.engine.mock.MockEngine
 import io.ktor.client.engine.mock.respond
+import io.ktor.content.TextContent
 import io.ktor.http.HttpHeaders
 import io.ktor.http.HttpStatusCode
 import io.ktor.http.headersOf
@@ -69,7 +72,7 @@ internal class APIControllerTests {
     }
 }
 
-internal class EndpointTests {
+internal class RequestFormatTests {
   @Test
   fun `using default endpoint`() = doBlocking {
     val channel = ByteChannel(autoFlush = true)
@@ -114,6 +117,28 @@ internal class EndpointTests {
 
     mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
   }
+
+  @Test
+  fun `request doesn't include the model name`() = doBlocking {
+    val channel = ByteChannel(autoFlush = true)
+    val mockEngine = MockEngine {
+      respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
+    }
+    prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
+    val controller =
+      APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)
+
+    withTimeout(5.seconds) {
+      controller.generateContentStream(textGenerateContentRequest("cats")).collect {
+        it.candidates?.isEmpty() shouldBe false
+        channel.close()
+      }
+    }
+
+    val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
+    requestBodyAsText shouldContainJsonKey "contents"
+    requestBodyAsText shouldNotContainJsonKey "model"
+  }
 }
 
 @RunWith(Parameterized::class)