Skip to content

Commit

Permalink
Expand RequestOptions class in common
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo committed Mar 20, 2024
1 parent 39396f0 commit 1b5d701
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json

const val DOMAIN = "https://generativelanguage.googleapis.com"

val JSON = Json {
ignoreUnknownKeys = true
prettyPrint = false
Expand Down Expand Up @@ -90,7 +88,7 @@ internal constructor(
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
Expand All @@ -103,7 +101,7 @@ internal constructor(
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
client
.postStream<GenerateContentResponse>(
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
"${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
Expand All @@ -113,7 +111,7 @@ internal constructor(
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,19 @@ import kotlin.time.toDuration
* first response.
* @property apiVersion the api endpoint to call.
*/
class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") {
class RequestOptions(
val timeout: Duration,
val apiVersion: String = "v1",
val endpoint: String = "https://generativelanguage.googleapis.com"
) {
@JvmOverloads
constructor(
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
apiVersion: String = "v1"
apiVersion: String = "v1",
endpoint: String = "https://generativelanguage.googleapis.com"
) : this(
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
apiVersion
apiVersion,
endpoint
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

internal class GenerativeModelTests {
internal class APIControllerTests {
private val testTimeout = 5.seconds

@Test
Expand Down Expand Up @@ -69,6 +69,53 @@ internal class GenerativeModelTests {
}
}

internal class EndpointTests {
@Test
fun `using default endpoint`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.host shouldBe "generativelanguage.googleapis.com"
}

@Test
fun `using custom endpoint`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(endpoint = "https://my.custom.endpoint"),
mockEngine
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
}
}

@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {

Expand Down
2 changes: 1 addition & 1 deletion generativeai-android-sample/app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ dependencies {
debugImplementation("androidx.compose.ui:ui-tooling")
debugImplementation("androidx.compose.ui:ui-test-manifest")

implementation("com.google.ai.client.generativeai:generativeai:0.2.2")
implementation("com.google.ai.client.generativeai:generativeai:0.2.20")
}
2 changes: 1 addition & 1 deletion generativeai/gradle.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.2.2
version=0.2.20

0 comments on commit 1b5d701

Please sign in to comment.