Skip to content

Commit

Permalink
Improve model name support (#47)
Browse files Browse the repository at this point in the history
Testing is based on examining the request data.
  • Loading branch information
rlazo authored Jan 30, 2024
1 parent 2fe1dfb commit 5a24830
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 4 deletions.
1 change: 1 addition & 0 deletions generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ dependencies {
testImplementation("junit:junit:4.13.2")
testImplementation("io.kotest:kotest-assertions-core:4.0.7")
testImplementation("io.kotest:kotest-assertions-jvm:4.0.7")
testImplementation("io.kotest:kotest-assertions-json:4.0.7")
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ internal class APIController(
* Models must be prepended with the `models/` prefix when communicating with the backend.
*/
private fun fullModelName(name: String): String =
name.takeIf { it.startsWith("models/") } ?: "models/$name"
name.takeIf { it.startsWith("models/") || it.startsWith("tunedModels/") } ?: "models/$name"

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,25 @@
package com.google.ai.client.generativeai

import com.google.ai.client.generativeai.util.commonTest
import com.google.ai.client.generativeai.util.createGenerativeModel
import com.google.ai.client.generativeai.util.createResponses
import com.google.ai.client.generativeai.util.doBlocking
import com.google.ai.client.generativeai.util.prepareStreamingResponse
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.withTimeout
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

internal class GenerativeModelTests {
private val testTimeout = 5.seconds
Expand All @@ -46,3 +56,39 @@ internal class GenerativeModelTests {
}
}
}

@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {

@Test
fun `request should include right model name`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val model = createGenerativeModel(modelName, "super_cool_test_key", mockEngine)

withTimeout(5.seconds) {
model.generateContentStream().collect {
it.candidates.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
}

companion object {
@JvmStatic
@Parameterized.Parameters
fun data() =
listOf(
arrayOf("gemini-pro", "models/gemini-pro"),
arrayOf("x/gemini-pro", "models/x/gemini-pro"),
arrayOf("models/gemini-pro", "models/gemini-pro"),
arrayOf("tunedModels/mymodel", "tunedModels/mymodel"),
arrayOf("tuneModels/mymodel", "models/tuneModels/mymodel"),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,14 @@ internal fun commonTest(status: HttpStatusCode = HttpStatusCode.OK, block: Commo
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val controller = APIController("super_cool_test_key", "gemini-pro", mockEngine)
val model = GenerativeModel("gemini-pro", "super_cool_test_key", controller = controller)
val model = createGenerativeModel("gemini-pro", "super_cool_test_key", mockEngine)
CommonTestScope(channel, model).block()
}

/** Simple wrapper that guarantees the model and APIController are created using the same data */
internal fun createGenerativeModel(name: String, apikey: String, engine: MockEngine) =
GenerativeModel(name, apikey, controller = APIController(apikey, name, engine))

/**
* A variant of [commonTest] for performing *streaming-based* snapshot tests.
*
Expand Down

0 comments on commit 5a24830

Please sign in to comment.