diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt index 13cfccc2..757554ce 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -45,8 +45,6 @@ import kotlinx.coroutines.flow.timeout import kotlinx.coroutines.launch import kotlinx.serialization.json.Json -const val DOMAIN = "https://generativelanguage.googleapis.com" - val JSON = Json { ignoreUnknownKeys = true prettyPrint = false @@ -90,7 +88,7 @@ internal constructor( suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse = try { client - .post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") { + .post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") { applyCommonConfiguration(request) } .also { validateResponse(it) } @@ -103,7 +101,7 @@ internal constructor( fun generateContentStream(request: GenerateContentRequest): Flow = client .postStream( - "$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" + "${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" ) { applyCommonConfiguration(request) } @@ -113,7 +111,7 @@ internal constructor( suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = try { client - .post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") { + .post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") { applyCommonConfiguration(request) } .also { validateResponse(it) } diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt index 82ea8842..4decb5c1 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt @@ -28,13 +28,19 @@ import kotlin.time.toDuration * first response. * @property apiVersion the api endpoint to call. */ -class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { +class RequestOptions( + val timeout: Duration, + val apiVersion: String = "v1", + val endpoint: String = "https://generativelanguage.googleapis.com" +) { @JvmOverloads constructor( timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, - apiVersion: String = "v1" + apiVersion: String = "v1", + endpoint: String = "https://generativelanguage.googleapis.com" ) : this( (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), - apiVersion + apiVersion, + endpoint ) } diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/GenerativeModelTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt similarity index 70% rename from common/src/test/java/com/google/ai/client/generativeai/common/GenerativeModelTests.kt rename to common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index 6269b673..a05d6d67 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/GenerativeModelTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -39,7 +39,7 @@ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized -internal class GenerativeModelTests { +internal class APIControllerTests { private val testTimeout = 5.seconds @Test @@ -69,6 +69,53 @@ internal class GenerativeModelTests { } } +internal class EndpointTests { + @Test + fun `using default endpoint`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + val controller = + APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine) + + withTimeout(5.seconds) { + controller.generateContentStream(textGenerateContentRequest("cats")).collect { + it.candidates?.isEmpty() shouldBe false + channel.close() + } + } + + mockEngine.requestHistory.first().url.host shouldBe "generativelanguage.googleapis.com" + } + + @Test + fun `using custom endpoint`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + val controller = + APIController( + "super_cool_test_key", + "gemini-pro-1.0", + RequestOptions(endpoint = "https://my.custom.endpoint"), + mockEngine + ) + + withTimeout(5.seconds) { + controller.generateContentStream(textGenerateContentRequest("cats")).collect { + it.candidates?.isEmpty() shouldBe false + channel.close() + } + } + + mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint" + } +} + @RunWith(Parameterized::class) internal class ModelNamingTests(private val modelName: String, private val actualName: String) { diff --git a/generativeai-android-sample/app/build.gradle.kts b/generativeai-android-sample/app/build.gradle.kts index 9e9fd99f..af5e5f82 100644 --- a/generativeai-android-sample/app/build.gradle.kts +++ b/generativeai-android-sample/app/build.gradle.kts @@ -75,5 +75,5 @@ dependencies { debugImplementation("androidx.compose.ui:ui-tooling") debugImplementation("androidx.compose.ui:ui-test-manifest") - implementation("com.google.ai.client.generativeai:generativeai:0.2.2") + implementation("com.google.ai.client.generativeai:generativeai:0.2.20") } diff --git a/generativeai/gradle.properties b/generativeai/gradle.properties index f4f46191..fe4af4a0 100644 --- a/generativeai/gradle.properties +++ b/generativeai/gradle.properties @@ -1 +1 @@ -version=0.2.2 +version=0.2.20