Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand RequestOptions class in common #89

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json

const val DOMAIN = "https://generativelanguage.googleapis.com"

val JSON = Json {
ignoreUnknownKeys = true
prettyPrint = false
Expand Down Expand Up @@ -90,7 +88,7 @@ internal constructor(
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
Expand All @@ -103,7 +101,7 @@ internal constructor(
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
client
.postStream<GenerateContentResponse>(
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
"${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
Expand All @@ -113,7 +111,7 @@ internal constructor(
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,19 @@ import kotlin.time.toDuration
* first response.
* @property apiVersion the api endpoint to call.
*/
class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") {
class RequestOptions(
val timeout: Duration,
val apiVersion: String = "v1",
val endpoint: String = "https://generativelanguage.googleapis.com"
) {
@JvmOverloads
constructor(
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
apiVersion: String = "v1"
apiVersion: String = "v1",
endpoint: String = "https://generativelanguage.googleapis.com"
) : this(
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
apiVersion
apiVersion,
endpoint
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

internal class GenerativeModelTests {
internal class APIControllerTests {
private val testTimeout = 5.seconds

@Test
Expand Down Expand Up @@ -69,6 +69,53 @@ internal class GenerativeModelTests {
}
}

internal class EndpointTests {
@Test
fun `using default endpoint`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.host shouldBe "generativelanguage.googleapis.com"
}

@Test
fun `using custom endpoint`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(endpoint = "https://my.custom.endpoint"),
mockEngine
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
}
}

@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {

Expand Down
Loading