From 5a24830fcec5f355d372fc4bd41ac37d10bd3e6b Mon Sep 17 00:00:00 2001 From: Rodrigo Lazo Date: Tue, 30 Jan 2024 12:03:07 -0500 Subject: [PATCH] Improve model name support (#47) Testing is based on examining the request data. --- generativeai/build.gradle.kts | 1 + .../internal/api/APIController.kt | 2 +- .../generativeai/GenerativeModelTests.kt | 48 ++++++++++++++++++- .../ai/client/generativeai/util/tests.kt | 7 ++- 4 files changed, 54 insertions(+), 4 deletions(-) diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts index d78d3079..1628d785 100644 --- a/generativeai/build.gradle.kts +++ b/generativeai/build.gradle.kts @@ -94,6 +94,7 @@ dependencies { testImplementation("junit:junit:4.13.2") testImplementation("io.kotest:kotest-assertions-core:4.0.7") testImplementation("io.kotest:kotest-assertions-jvm:4.0.7") + testImplementation("io.kotest:kotest-assertions-json:4.0.7") testImplementation("io.ktor:ktor-client-mock:$ktorVersion") androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt index e46016ca..c947ec44 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt @@ -114,7 +114,7 @@ internal class APIController( * Models must be prepended with the `models/` prefix when communicating with the backend. */ private fun fullModelName(name: String): String = - name.takeIf { it.startsWith("models/") } ?: "models/$name" + name.takeIf { it.startsWith("models/") || it.startsWith("tunedModels/") } ?: "models/$name" /** * Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index da0e90b6..af48bfc4 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -17,15 +17,25 @@ package com.google.ai.client.generativeai import com.google.ai.client.generativeai.util.commonTest +import com.google.ai.client.generativeai.util.createGenerativeModel import com.google.ai.client.generativeai.util.createResponses +import com.google.ai.client.generativeai.util.doBlocking import com.google.ai.client.generativeai.util.prepareStreamingResponse import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.headersOf +import io.ktor.utils.io.ByteChannel import io.ktor.utils.io.close import io.ktor.utils.io.writeFully import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.flow.collect import kotlinx.coroutines.withTimeout import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized internal class GenerativeModelTests { private val testTimeout = 5.seconds @@ -46,3 +56,39 @@ internal class GenerativeModelTests { } } } + +@RunWith(Parameterized::class) +internal class ModelNamingTests(private val modelName: String, private val actualName: String) { + + @Test + fun `request should include right model name`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + val model = createGenerativeModel(modelName, "super_cool_test_key", mockEngine) + + withTimeout(5.seconds) { + model.generateContentStream().collect { + it.candidates.isEmpty() shouldBe false + channel.close() + } + } + + mockEngine.requestHistory.first().url.encodedPath shouldContain actualName + } + + companion object { + @JvmStatic + @Parameterized.Parameters + fun data() = + listOf( + arrayOf("gemini-pro", "models/gemini-pro"), + arrayOf("x/gemini-pro", "models/x/gemini-pro"), + arrayOf("models/gemini-pro", "models/gemini-pro"), + arrayOf("tunedModels/mymodel", "tunedModels/mymodel"), + arrayOf("tuneModels/mymodel", "models/tuneModels/mymodel"), + ) + } +} diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt index fe498754..ddf272f4 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt @@ -102,11 +102,14 @@ internal fun commonTest(status: HttpStatusCode = HttpStatusCode.OK, block: Commo val mockEngine = MockEngine { respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) } - val controller = APIController("super_cool_test_key", "gemini-pro", mockEngine) - val model = GenerativeModel("gemini-pro", "super_cool_test_key", controller = controller) + val model = createGenerativeModel("gemini-pro", "super_cool_test_key", mockEngine) CommonTestScope(channel, model).block() } +/** Simple wrapper that guarantees the model and APIController are created using the same data */ +internal fun createGenerativeModel(name: String, apikey: String, engine: MockEngine) = + GenerativeModel(name, apikey, controller = APIController(apikey, name, engine)) + /** * A variant of [commonTest] for performing *streaming-based* snapshot tests. *