Skip to content

Commit

Permalink
Add test-only APIController constructor for Vertex AI unit tests (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanzimfh authored Jul 11, 2024
1 parent 0d951a5 commit fa9688a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 8 deletions.
1 change: 1 addition & 0 deletions .changes/common/crowd-company-bead-authority.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MINOR","changes":["Add test-only APIController constructor for Vertex AI unit tests"]}
1 change: 1 addition & 0 deletions common/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ dependencies {
implementation("com.google.guava:listenablefuture:1.0")
implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02")
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02")
compileOnly("io.ktor:ktor-client-mock:$ktorVersion")
testImplementation("junit:junit:4.13.2")
testImplementation("io.kotest:kotest-assertions-core:4.0.7")
testImplementation("io.kotest:kotest-assertions-jvm:4.0.7")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package com.google.ai.client.generativeai.common

import android.util.Log
import androidx.annotation.VisibleForTesting
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import com.google.ai.client.generativeai.common.util.fullModelName
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.client.engine.okhttp.OkHttp
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
Expand All @@ -35,9 +38,12 @@ import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.bodyAsChannel
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.headersOf
import io.ktor.serialization.kotlinx.json.json
import io.ktor.utils.io.ByteChannel
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.TimeoutCancellationException
Expand Down Expand Up @@ -85,6 +91,24 @@ internal constructor(
headerProvider: HeaderProvider? = null,
) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider)

@VisibleForTesting(otherwise = VisibleForTesting.NONE)
constructor(
key: String,
model: String,
requestOptions: RequestOptions,
apiClient: String,
headerProvider: HeaderProvider?,
channel: ByteChannel,
status: HttpStatusCode,
) : this(
key,
model,
requestOptions,
MockEngine { respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) },
apiClient,
headerProvider,
)

private val model = fullModelName(model)

private val client =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.TextPart
import io.kotest.matchers.collections.shouldNotBeEmpty
import io.kotest.matchers.nulls.shouldNotBeNull
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
Expand Down Expand Up @@ -105,17 +101,15 @@ internal fun commonTest(
block: CommonTest,
) = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val apiController =
APIController(
"super_cool_test_key",
"gemini-pro",
requestOptions,
mockEngine,
TEST_CLIENT_ID,
null,
channel,
status,
)
CommonTestScope(channel, apiController).block()
}
Expand Down

0 comments on commit fa9688a

Please sign in to comment.