diff --git a/.changes/common/crowd-company-bead-authority.json b/.changes/common/crowd-company-bead-authority.json new file mode 100644 index 00000000..8868b7e1 --- /dev/null +++ b/.changes/common/crowd-company-bead-authority.json @@ -0,0 +1 @@ +{"type":"MINOR","changes":["Add test-only APIController constructor for Vertex AI unit tests"]} diff --git a/common/build.gradle.kts b/common/build.gradle.kts index bd02f5fa..842e9301 100644 --- a/common/build.gradle.kts +++ b/common/build.gradle.kts @@ -90,6 +90,7 @@ dependencies { implementation("com.google.guava:listenablefuture:1.0") implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02") implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02") + compileOnly("io.ktor:ktor-client-mock:$ktorVersion") testImplementation("junit:junit:4.13.2") testImplementation("io.kotest:kotest-assertions-core:4.0.7") testImplementation("io.kotest:kotest-assertions-jvm:4.0.7") diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt index d815f44a..aaf79634 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -17,12 +17,15 @@ package com.google.ai.client.generativeai.common import android.util.Log +import androidx.annotation.VisibleForTesting import com.google.ai.client.generativeai.common.server.FinishReason import com.google.ai.client.generativeai.common.util.decodeToFlow import com.google.ai.client.generativeai.common.util.fullModelName import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.engine.HttpClientEngine +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.plugins.HttpTimeout import io.ktor.client.plugins.contentnegotiation.ContentNegotiation @@ -35,9 +38,12 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.statement.bodyAsChannel import io.ktor.client.statement.bodyAsText import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode import io.ktor.http.contentType +import io.ktor.http.headersOf import io.ktor.serialization.kotlinx.json.json +import io.ktor.utils.io.ByteChannel import kotlin.time.Duration import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.TimeoutCancellationException @@ -85,6 +91,24 @@ internal constructor( headerProvider: HeaderProvider? = null, ) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider) + @VisibleForTesting(otherwise = VisibleForTesting.NONE) + constructor( + key: String, + model: String, + requestOptions: RequestOptions, + apiClient: String, + headerProvider: HeaderProvider?, + channel: ByteChannel, + status: HttpStatusCode, + ) : this( + key, + model, + requestOptions, + MockEngine { respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) }, + apiClient, + headerProvider, + ) + private val model = fullModelName(model) private val client = diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt index fa84d258..0726d959 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt @@ -28,11 +28,7 @@ import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.TextPart import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull -import io.ktor.client.engine.mock.MockEngine -import io.ktor.client.engine.mock.respond -import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode -import io.ktor.http.headersOf import io.ktor.utils.io.ByteChannel import io.ktor.utils.io.close import io.ktor.utils.io.writeFully @@ -105,17 +101,15 @@ internal fun commonTest( block: CommonTest, ) = doBlocking { val channel = ByteChannel(autoFlush = true) - val mockEngine = MockEngine { - respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) - } val apiController = APIController( "super_cool_test_key", "gemini-pro", requestOptions, - mockEngine, TEST_CLIENT_ID, null, + channel, + status, ) CommonTestScope(channel, apiController).block() }