Skip to content

Commit

Permalink
Pass the requestOptions object to APIController as is.
Browse files Browse the repository at this point in the history
This will make code changes that add fields to the `requestOptions`
object easier to do in the future.
  • Loading branch information
rlazo committed Feb 28, 2024
1 parent a8ac99e commit 69ab2ab
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ internal constructor(
generationConfig,
safetySettings,
requestOptions,
APIController(apiKey, modelName, requestOptions.apiVersion, requestOptions.timeout)
APIController(apiKey, modelName, requestOptions)
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.internal.api

import com.google.ai.client.generativeai.BuildConfig
import com.google.ai.client.generativeai.internal.util.decodeToFlow
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ServerException
import io.ktor.client.HttpClient
import io.ktor.client.call.body
Expand All @@ -37,7 +38,6 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
Expand Down Expand Up @@ -67,38 +67,41 @@ internal val JSON = Json {
internal class APIController(
private val key: String,
model: String,
private val apiVersion: String,
private val timeout: Duration,
private val requestOptions: RequestOptions,
httpEngine: HttpClientEngine = OkHttp.create(),
) {
private val model = fullModelName(model)

private val client =
HttpClient(httpEngine) {
install(HttpTimeout) {
requestTimeoutMillis = timeout.inWholeMilliseconds
requestTimeoutMillis = requestOptions.timeout.inWholeMilliseconds
socketTimeoutMillis = 80_000
}
install(ContentNegotiation) { json(JSON) }
}

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
client
.post("$DOMAIN/$apiVersion/$model:generateContent") { applyCommonConfiguration(request) }
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
return client.postStream<GenerateContentResponse>(
"$DOMAIN/$apiVersion/$model:streamGenerateContent?alt=sse"
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
}

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
client
.post("$DOMAIN/$apiVersion/$model:countTokens") { applyCommonConfiguration(request) }
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,7 @@ internal fun createGenerativeModel(
GenerativeModel(
name,
apikey,
controller =
APIController(
"super_cool_test_key",
name,
requestOptions.apiVersion,
requestOptions.timeout,
engine
)
controller = APIController("super_cool_test_key", name, requestOptions, engine)
)

/**
Expand Down

0 comments on commit 69ab2ab

Please sign in to comment.