Skip to content

Commit

Permalink
factory methods and removing fullModelName duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
David Motsonashvili committed Jun 7, 2024
1 parent 12550d7 commit efdda97
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.ai.client.generativeai.common
import android.util.Log
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import com.google.ai.client.generativeai.common.util.fullModelName
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
Expand Down Expand Up @@ -213,12 +214,6 @@ interface HeaderProvider {
suspend fun generateHeaders(): Map<String, String>
}

/**
* Ensures the model name provided has a `models/` prefix
*
* Models must be prepended with the `models/` prefix when communicating with the backend.
*/
private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"

private suspend fun validateResponse(response: HttpResponse) {
if (response.status == HttpStatusCode.OK) return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.util.fullModelName
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

Expand All @@ -39,9 +40,27 @@ data class GenerateContentRequest(

@Serializable
data class CountTokensRequest(
val generateContentRequest: GenerateContentRequest,
val generateContentRequest: GenerateContentRequest? = null,
val model: String? = null,
val contents: List<Content>? = null,
@SerialName("tool_config") var toolConfig: ToolConfig? = null,
val tools: List<Tool>? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
) : Request
) : Request {
companion object {
fun forGenAI(generateContentRequest: GenerateContentRequest) =
CountTokensRequest(
generateContentRequest =
generateContentRequest.model?.let {
generateContentRequest.copy(model = fullModelName(it))
} ?: generateContentRequest
)

fun forVertexAI(generateContentRequest: GenerateContentRequest) =
CountTokensRequest(
model = generateContentRequest.model?.let { fullModelName(it) },
contents = generateContentRequest.contents,
tools = generateContentRequest.tools,
systemInstruction = generateContentRequest.systemInstruction,
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.google.ai.client.generativeai.common.util

/**
* Ensures the model name provided has a `models/` prefix
*
* Models must be prepended with the `models/` prefix when communicating with the backend.
*/
fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ internal constructor(
)

private fun constructCountTokensRequest(vararg prompt: Content) =
CountTokensRequest(generateContentRequest = constructRequest(*prompt))
CountTokensRequest.forGenAI(constructRequest(*prompt))

private fun GenerateContentResponse.validate() = apply {
if (candidates.isEmpty() && promptFeedback == null) {
Expand Down

0 comments on commit efdda97

Please sign in to comment.