Skip to content

Commit

Permalink
Don't encode the model name in the outgoing request
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo committed Mar 28, 2024
1 parent a2b71df commit fdfe189
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient

sealed interface Request

@Serializable
data class GenerateContentRequest(
val model: String,
@Transient val model: String? = null,
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ import com.google.ai.client.generativeai.common.util.commonTest
import com.google.ai.client.generativeai.common.util.createResponses
import com.google.ai.client.generativeai.common.util.doBlocking
import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
import io.kotest.assertions.json.shouldContainJsonKey
import io.kotest.assertions.json.shouldNotContainJsonKey
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.content.TextContent
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
Expand Down Expand Up @@ -69,7 +72,7 @@ internal class APIControllerTests {
}
}

internal class EndpointTests {
internal class RequestFormatTests {
@Test
fun `using default endpoint`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
Expand Down Expand Up @@ -114,6 +117,28 @@ internal class EndpointTests {

mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
}

@Test
fun `request doesn't include the model name`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
requestBodyAsText shouldContainJsonKey "contents"
requestBodyAsText shouldNotContainJsonKey "model"
}
}

@RunWith(Parameterized::class)
Expand Down

0 comments on commit fdfe189

Please sign in to comment.