From 69ab2ab4969e8ca4a0fd761f7845c5816266e924 Mon Sep 17 00:00:00 2001 From: Rodrigo Lazo Paz Date: Wed, 28 Feb 2024 14:24:10 -0500 Subject: [PATCH] Pass the requestOptions object to APIController as is. This will make code changes that add fields to the `requestOptions` object easier to do in the future. --- .../ai/client/generativeai/GenerativeModel.kt | 2 +- .../generativeai/internal/api/APIController.kt | 17 ++++++++++------- .../google/ai/client/generativeai/util/tests.kt | 9 +-------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 94f4f9c5..77c86cce 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -71,7 +71,7 @@ internal constructor( generationConfig, safetySettings, requestOptions, - APIController(apiKey, modelName, requestOptions.apiVersion, requestOptions.timeout) + APIController(apiKey, modelName, requestOptions) ) /** diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt index fd750111..8358c8e6 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.internal.api import com.google.ai.client.generativeai.BuildConfig import com.google.ai.client.generativeai.internal.util.decodeToFlow +import com.google.ai.client.generativeai.type.RequestOptions import com.google.ai.client.generativeai.type.ServerException import io.ktor.client.HttpClient import io.ktor.client.call.body @@ -37,7 +38,6 @@ import io.ktor.http.ContentType import io.ktor.http.HttpStatusCode import io.ktor.http.contentType import io.ktor.serialization.kotlinx.json.json -import kotlin.time.Duration import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow @@ -67,8 +67,7 @@ internal val JSON = Json { internal class APIController( private val key: String, model: String, - private val apiVersion: String, - private val timeout: Duration, + private val requestOptions: RequestOptions, httpEngine: HttpClientEngine = OkHttp.create(), ) { private val model = fullModelName(model) @@ -76,7 +75,7 @@ internal class APIController( private val client = HttpClient(httpEngine) { install(HttpTimeout) { - requestTimeoutMillis = timeout.inWholeMilliseconds + requestTimeoutMillis = requestOptions.timeout.inWholeMilliseconds socketTimeoutMillis = 80_000 } install(ContentNegotiation) { json(JSON) } @@ -84,13 +83,15 @@ internal class APIController( suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse = client - .post("$DOMAIN/$apiVersion/$model:generateContent") { applyCommonConfiguration(request) } + .post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") { + applyCommonConfiguration(request) + } .also { validateResponse(it) } .body() fun generateContentStream(request: GenerateContentRequest): Flow { return client.postStream( - "$DOMAIN/$apiVersion/$model:streamGenerateContent?alt=sse" + "$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" ) { applyCommonConfiguration(request) } @@ -98,7 +99,9 @@ internal class APIController( suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = client - .post("$DOMAIN/$apiVersion/$model:countTokens") { applyCommonConfiguration(request) } + .post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") { + applyCommonConfiguration(request) + } .also { validateResponse(it) } .body() diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt index 092b141c..90707336 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt @@ -121,14 +121,7 @@ internal fun createGenerativeModel( GenerativeModel( name, apikey, - controller = - APIController( - "super_cool_test_key", - name, - requestOptions.apiVersion, - requestOptions.timeout, - engine - ) + controller = APIController("super_cool_test_key", name, requestOptions, engine) ) /**