From 1a88a235f2711394874239ac18a131046c63351d Mon Sep 17 00:00:00 2001 From: Daymon Date: Tue, 2 Jul 2024 14:37:03 -0500 Subject: [PATCH 1/4] Align proto primitives --- .../common/breath-committee-dad-curtain.json | 1 + .../carpenter-beggar-creator-celery.json | 1 + .../breath-brush-achiever-boat.json | 1 + .../generativeai/common/APIController.kt | 8 +- .../client/generativeai/common/Exceptions.kt | 2 +- .../ai/client/generativeai/common/Request.kt | 44 ++++++----- .../ai/client/generativeai/common/Response.kt | 10 +-- .../generativeai/common/client/Types.kt | 46 ++++++------ .../generativeai/common/server/Types.kt | 50 ++++++------- .../generativeai/common/shared/Types.kt | 19 ++--- .../generativeai/common/APIControllerTests.kt | 16 ++-- .../common/StreamingSnapshotTests.kt | 31 ++++---- .../generativeai/common/UnarySnapshotTests.kt | 73 +++++++++---------- .../ai/client/generativeai/GenerativeModel.kt | 8 +- .../generativeai/internal/util/conversions.kt | 24 +++--- .../ai/client/generativeai/type/Candidate.kt | 2 +- .../ai/client/generativeai/type/Content.kt | 9 +-- .../generativeai/type/FunctionDeclarations.kt | 10 +-- .../generativeai/type/FunctionParameter.kt | 2 +- .../generativeai/type/GenerationConfig.kt | 28 +++---- .../generativeai/type/PromptFeedback.kt | 5 +- .../client/generativeai/type/SafetySetting.kt | 2 +- .../generativeai/GenerativeModelTests.kt | 25 +++++-- 23 files changed, 214 insertions(+), 203 deletions(-) create mode 100644 .changes/common/breath-committee-dad-curtain.json create mode 100644 .changes/common/carpenter-beggar-creator-celery.json create mode 100644 .changes/generativeai/breath-brush-achiever-boat.json diff --git a/.changes/common/breath-committee-dad-curtain.json b/.changes/common/breath-committee-dad-curtain.json new file mode 100644 index 00000000..2c60c795 --- /dev/null +++ b/.changes/common/breath-committee-dad-curtain.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]} diff --git a/.changes/common/carpenter-beggar-creator-celery.json b/.changes/common/carpenter-beggar-creator-celery.json new file mode 100644 index 00000000..2c60c795 --- /dev/null +++ b/.changes/common/carpenter-beggar-creator-celery.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]} diff --git a/.changes/generativeai/breath-brush-achiever-boat.json b/.changes/generativeai/breath-brush-achiever-boat.json new file mode 100644 index 00000000..2c60c795 --- /dev/null +++ b/.changes/generativeai/breath-brush-achiever-boat.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt index d815f44a..adaf00f8 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -235,19 +235,19 @@ private suspend fun validateResponse(response: HttpResponse) { if (message.contains("quota")) { throw QuotaExceededException(message) } - if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) { + if (error.details.any { "SERVICE_DISABLED" == it.reason }) { throw ServiceDisabledException(message) } throw ServerException(message) } private fun GenerateContentResponse.validate() = apply { - if ((candidates?.isEmpty() != false) && promptFeedback == null) { + if (candidates.isEmpty() && promptFeedback == null) { throw SerializationException("Error deserializing response, found no valid fields") } promptFeedback?.blockReason?.let { throw PromptBlockedException(this) } candidates - ?.mapNotNull { it.finishReason } - ?.firstOrNull { it != FinishReason.STOP } + .map { it.finishReason } + .firstOrNull { it != FinishReason.STOP } ?.let { throw ResponseStoppedException(this) } } diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt index 15cd25d9..05960d2f 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt @@ -96,7 +96,7 @@ class InvalidStateException(message: String, cause: Throwable? = null) : */ class ResponseStoppedException(val response: GenerateContentResponse, cause: Throwable? = null) : GoogleGenerativeAIException( - "Content generation stopped. Reason: ${response.candidates?.first()?.finishReason?.name}", + "Content generation stopped. Reason: ${response.candidates.first().finishReason?.name}", cause, ) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index b252b2e1..b1a4f36a 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.GenerationConfig @@ -22,45 +24,41 @@ import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting import com.google.ai.client.generativeai.common.util.fullModelName -import kotlinx.serialization.SerialName +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable sealed interface Request @Serializable data class GenerateContentRequest( - val model: String? = null, + val model: String, val contents: List, - @SerialName("safety_settings") val safetySettings: List? = null, - @SerialName("generation_config") val generationConfig: GenerationConfig? = null, - val tools: List? = null, - @SerialName("tool_config") var toolConfig: ToolConfig? = null, - @SerialName("system_instruction") val systemInstruction: Content? = null, + val safetySettings: List = emptyList(), + val generationConfig: GenerationConfig? = null, + val tools: List = emptyList(), + val toolConfig: ToolConfig? = null, + val systemInstruction: Content? = null, ) : Request @Serializable data class CountTokensRequest( + val model: String, + val contents: List = emptyList(), + val tools: List = emptyList(), val generateContentRequest: GenerateContentRequest? = null, - val model: String? = null, - val contents: List? = null, - val tools: List? = null, - @SerialName("system_instruction") val systemInstruction: Content? = null, + val systemInstruction: Content? = null, ) : Request { companion object { - fun forGenAI(generateContentRequest: GenerateContentRequest) = - CountTokensRequest( - generateContentRequest = - generateContentRequest.model?.let { - generateContentRequest.copy(model = fullModelName(it)) - } ?: generateContentRequest - ) + fun forGenAI(request: GenerateContentRequest) = + CountTokensRequest(fullModelName(request.model), request.contents, emptyList(), request) - fun forVertexAI(generateContentRequest: GenerateContentRequest) = + fun forVertexAI(request: GenerateContentRequest) = CountTokensRequest( - model = generateContentRequest.model?.let { fullModelName(it) }, - contents = generateContentRequest.contents, - tools = generateContentRequest.tools, - systemInstruction = generateContentRequest.systemInstruction, + fullModelName(request.model), + request.contents, + request.tools, + null, + request.systemInstruction, ) } } diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt index 8b5f1ba7..4194ae14 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt @@ -25,20 +25,20 @@ sealed interface Response @Serializable data class GenerateContentResponse( - val candidates: List? = null, + val candidates: List = emptyList(), val promptFeedback: PromptFeedback? = null, val usageMetadata: UsageMetadata? = null, ) : Response @Serializable -data class CountTokensResponse(val totalTokens: Int, val totalBillableCharacters: Int? = null) : +data class CountTokensResponse(val totalTokens: Int = 0, val totalBillableCharacters: Int = 0) : Response @Serializable data class GRpcErrorResponse(val error: GRpcError) : Response @Serializable data class UsageMetadata( - val promptTokenCount: Int? = null, - val candidatesTokenCount: Int? = null, - val totalTokenCount: Int? = null, + val promptTokenCount: Int = 0, + val candidatesTokenCount: Int = 0, + val totalTokenCount: Int = 0, ) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index e32b9196..2763702c 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -22,32 +22,30 @@ import kotlinx.serialization.json.JsonObject @Serializable data class GenerationConfig( - val temperature: Float?, - @SerialName("top_p") val topP: Float?, - @SerialName("top_k") val topK: Int?, - @SerialName("candidate_count") val candidateCount: Int?, - @SerialName("max_output_tokens") val maxOutputTokens: Int?, - @SerialName("stop_sequences") val stopSequences: List?, - @SerialName("response_mime_type") val responseMimeType: String? = null, - @SerialName("presence_penalty") val presencePenalty: Float? = null, - @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, - @SerialName("response_schema") val responseSchema: Schema? = null, + val temperature: Float = 0f, + val topP: Float = 0f, + val topK: Int = 0, + val candidateCount: Int = 0, + val maxOutputTokens: Int = 0, + val stopSequences: List = emptyList(), + val responseMimeType: String = "", + val presencePenalty: Float = 0f, + val frequencyPenalty: Float = 0f, + val responseSchema: Schema? = null, ) @Serializable data class Tool( - val functionDeclarations: List? = null, + val functionDeclarations: List = emptyList(), // This is a json object because it is not possible to make a data class with no parameters. val codeExecution: JsonObject? = null, ) @Serializable -data class ToolConfig( - @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig -) +data class ToolConfig(val functionCallingConfig: FunctionCallingConfig = FunctionCallingConfig()) @Serializable -data class FunctionCallingConfig(val mode: Mode) { +data class FunctionCallingConfig(val mode: Mode? = null) { @Serializable enum class Mode { @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, @@ -58,16 +56,20 @@ data class FunctionCallingConfig(val mode: Mode) { } @Serializable -data class FunctionDeclaration(val name: String, val description: String, val parameters: Schema) +data class FunctionDeclaration( + val name: String, + val description: String, + val parameters: Schema? = null, +) @Serializable data class Schema( val type: String, - val description: String? = null, - val format: String? = null, - val nullable: Boolean? = false, - val enum: List? = null, - val properties: Map? = null, - val required: List? = null, + val description: String = "", + val format: String = "", + val nullable: Boolean = false, + val enum: List = emptyList(), + val properties: Map = emptyMap(), + val required: List = emptyList(), val items: Schema? = null, ) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt index 16b25a78..db672fd3 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common.server import com.google.ai.client.generativeai.common.shared.Content @@ -37,7 +39,7 @@ object FinishReasonSerializer : @Serializable data class PromptFeedback( val blockReason: BlockReason? = null, - val safetyRatings: List? = null, + val safetyRatings: List = emptyList(), ) @Serializable(BlockReasonSerializer::class) @@ -52,59 +54,51 @@ enum class BlockReason { data class Candidate( val content: Content? = null, val finishReason: FinishReason? = null, - val safetyRatings: List? = null, + val safetyRatings: List = emptyList(), val citationMetadata: CitationMetadata? = null, val groundingMetadata: GroundingMetadata? = null, ) @Serializable -data class CitationMetadata -@OptIn(ExperimentalSerializationApi::class) -constructor(@JsonNames("citations") val citationSources: List) +data class CitationMetadata( + @JsonNames("citations") val citationSources: List = emptyList() +) @Serializable data class CitationSources( val startIndex: Int = 0, - val endIndex: Int, - val uri: String, - val license: String? = null, + val endIndex: Int = 0, + val uri: String = "", + val license: String = "", ) @Serializable data class SafetyRating( val category: HarmCategory, val probability: HarmProbability, - val blocked: Boolean? = null, // TODO(): any reason not to default to false? - val probabilityScore: Float? = null, + val blocked: Boolean = false, + val probabilityScore: Float = 0f, val severity: HarmSeverity? = null, - val severityScore: Float? = null, + val severityScore: Float = 0f, ) @Serializable data class GroundingMetadata( - @SerialName("web_search_queries") val webSearchQueries: List?, - @SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?, - @SerialName("retrieval_queries") val retrievalQueries: List?, - @SerialName("grounding_attribution") val groundingAttribution: List?, + val webSearchQueries: List = emptyList(), + val searchEntryPoint: SearchEntryPoint? = null, + val retrievalQueries: List = emptyList(), + val groundingAttribution: List = emptyList(), ) @Serializable -data class SearchEntryPoint( - @SerialName("rendered_content") val renderedContent: String?, - @SerialName("sdk_blob") val sdkBlob: String?, -) +data class SearchEntryPoint(val renderedContent: String = "", val sdkBlob: String = "") +// TODO() Has a different definition for labs vs vertex. May need to split into diff types in future +// (when labs supports it) @Serializable -data class GroundingAttribution( - val segment: Segment, - @SerialName("confidence_score") val confidenceScore: Float?, -) +data class GroundingAttribution(val segment: Segment, val confidenceScore: Float = 0f) -@Serializable -data class Segment( - @SerialName("start_index") val startIndex: Int, - @SerialName("end_index") val endIndex: Int, -) +@Serializable data class Segment(val startIndex: Int = 0, val endIndex: Int = 0) @Serializable(HarmProbabilitySerializer::class) enum class HarmProbability { diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index 7dc31858..ac01acb8 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -45,13 +45,13 @@ typealias Base64 = String @ExperimentalSerializationApi @Serializable -data class Content(@EncodeDefault val role: String? = "user", val parts: List) +data class Content(@EncodeDefault val role: String = "", val parts: List) @Serializable(PartSerializer::class) sealed interface Part -@Serializable data class TextPart(val text: String) : Part +@Serializable data class TextPart(val text: String = "") : Part -@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part +@Serializable data class BlobPart(val inlineData: Blob) : Part @Serializable data class FunctionCallPart(val functionCall: FunctionCall) : Part @@ -64,17 +64,14 @@ data class CodeExecutionResultPart(val codeExecutionResult: CodeExecutionResult) @Serializable data class FunctionResponse(val name: String, val response: JsonObject) -@Serializable data class FunctionCall(val name: String, val args: Map) +@Serializable +data class FunctionCall(val name: String, val args: Map = emptyMap()) -@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part +@Serializable data class FileDataPart(val fileData: FileData) : Part -@Serializable -data class FileData( - @SerialName("mime_type") val mimeType: String, - @SerialName("file_uri") val fileUri: String, -) +@Serializable data class FileData(val mimeType: String, val fileUri: String) -@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64) +@Serializable data class Blob(val mimeType: String, val data: Base64) @Serializable data class ExecutableCode(val language: String, val code: String) diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index aa3c402b..d748ff51 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class, ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.FunctionCallingConfig @@ -43,6 +45,7 @@ import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.delay import kotlinx.coroutines.withTimeout +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString import kotlinx.serialization.json.JsonObject import org.junit.Test @@ -64,7 +67,7 @@ internal class APIControllerTests { withTimeout(testTimeout) { responses.collect { - it.candidates?.isEmpty() shouldBe false + it.candidates.isEmpty() shouldBe false channel.close() } } @@ -101,7 +104,7 @@ internal class RequestFormatTests { withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { - it.candidates?.isEmpty() shouldBe false + it.candidates.isEmpty() shouldBe false channel.close() } } @@ -128,7 +131,7 @@ internal class RequestFormatTests { withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { - it.candidates?.isEmpty() shouldBe false + it.candidates.isEmpty() shouldBe false channel.close() } } @@ -193,8 +196,7 @@ internal class RequestFormatTests { } val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text - - requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode" + requestBodyAsText shouldContainJsonKey "toolConfig.functionCallingConfig.mode" } @Test @@ -320,7 +322,7 @@ internal class ModelNamingTests(private val modelName: String, private val actua withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { - it.candidates?.isEmpty() shouldBe false + it.candidates.isEmpty() shouldBe false channel.close() } } @@ -349,4 +351,4 @@ fun textGenerateContentRequest(prompt: String) = ) fun textCountTokenRequest(prompt: String) = - CountTokensRequest(generateContentRequest = textGenerateContentRequest(prompt)) + CountTokensRequest("", generateContentRequest = textGenerateContentRequest(prompt)) diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt index 7f151320..b56bd098 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.server.BlockReason @@ -31,6 +33,7 @@ import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withTimeout +import kotlinx.serialization.ExperimentalSerializationApi import org.junit.Test internal class StreamingSnapshotTests { @@ -44,9 +47,9 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.isEmpty() shouldBe false - responseList.first().candidates?.first()?.finishReason shouldBe FinishReason.STOP - responseList.first().candidates?.first()?.content?.parts?.isEmpty() shouldBe false - responseList.first().candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + responseList.first().candidates.first().finishReason shouldBe FinishReason.STOP + responseList.first().candidates.first().content?.parts?.isEmpty() shouldBe false + responseList.first().candidates.first().safetyRatings.isEmpty() shouldBe false } } @@ -59,9 +62,9 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false responseList.forEach { - it.candidates?.first()?.finishReason shouldBe FinishReason.STOP - it.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - it.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + it.candidates.first().finishReason shouldBe FinishReason.STOP + it.candidates.first().content?.parts?.isEmpty() shouldBe false + it.candidates.first().safetyRatings.isEmpty() shouldBe false } } } @@ -73,9 +76,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { responses.first { - it.candidates?.any { - it.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } ?: false - } ?: false + it.candidates.any { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } } } } } @@ -89,7 +90,7 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false - val part = responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart + val part = responseList.first().candidates.first().content?.parts?.first() as? TextPart part.shouldNotBeNull() part.text shouldContain "\"" } @@ -129,7 +130,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY } } @@ -141,8 +142,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.any { - it.candidates?.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } - ?: false + it.candidates.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } } shouldBe true } } @@ -155,8 +155,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.any { - it.candidates?.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } - ?: false + it.candidates.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } } shouldBe true } } @@ -168,7 +167,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.RECITATION + exception.response.candidates.first().finishReason shouldBe FinishReason.RECITATION } } diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt index ad5c634a..cc40256e 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.server.BlockReason @@ -36,10 +38,12 @@ import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.should import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.string.shouldBeEmpty import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.http.HttpStatusCode import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.withTimeout +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable import org.junit.Test @@ -54,10 +58,10 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP - response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.candidates.first().content?.parts?.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false } } @@ -67,10 +71,10 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP - response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.candidates.first().content?.parts?.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false } } @@ -80,9 +84,7 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.first { - it.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } ?: false - } + response.candidates.first { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } } } } @@ -92,17 +94,16 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.all { + response.candidates.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false + response.candidates.first().safetyRatings.all { it.probability == HarmProbability.NEGLIGIBLE } shouldBe true - response.candidates?.first()?.safetyRatings?.all { it.probabilityScore != null } shouldBe - true - response.candidates?.first()?.safetyRatings?.all { + response.candidates.first().safetyRatings.all { it.probabilityScore != 0f } shouldBe true + response.candidates.first().safetyRatings.all { it.severity == HarmSeverity.NEGLIGIBLE } shouldBe true - response.candidates?.first()?.safetyRatings?.all { it.severityScore != null } shouldBe true + response.candidates.first().safetyRatings.all { it.severityScore != 0f } shouldBe true } } @@ -154,7 +155,7 @@ internal class UnarySnapshotTests { shouldThrow { apiController.generateContent(textGenerateContentRequest("prompt")) } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY } } @@ -164,8 +165,8 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true + response.candidates.isEmpty() shouldBe false + response.candidates.first().citationMetadata?.citationSources?.isNotEmpty() shouldBe true } } @@ -175,11 +176,11 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true + response.candidates.isEmpty() shouldBe false + response.candidates.first().citationMetadata?.citationSources?.isNotEmpty() shouldBe true // Verify the values in the citation source - with(response.candidates?.first()?.citationMetadata?.citationSources?.first()!!) { - license shouldBe null + with(response.candidates.first().citationMetadata?.citationSources?.first()!!) { + license.shouldBeEmpty() startIndex shouldBe 0 } } @@ -191,8 +192,8 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.totalTokenCount shouldBe 363 } @@ -204,11 +205,11 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.promptTokenCount shouldBe 6 - response.usageMetadata?.totalTokenCount shouldBe null + response.usageMetadata?.totalTokenCount shouldBe 0 } } @@ -218,8 +219,8 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true + response.candidates.isEmpty() shouldBe false + response.candidates.first().citationMetadata?.citationSources?.isNotEmpty() shouldBe true } } @@ -229,10 +230,8 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - with( - response.candidates?.first()?.content?.parts?.first()?.shouldBeInstanceOf() - ) { + response.candidates.isEmpty() shouldBe false + with(response.candidates.first().content?.parts?.first()?.shouldBeInstanceOf()) { shouldNotBeNull() JSON.decodeFromString>(text).shouldNotBeEmpty() } @@ -314,7 +313,7 @@ internal class UnarySnapshotTests { goldenUnaryFile("success-function-call-null.json") { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart) + val callPart = (response.candidates.first().content!!.parts.first() as FunctionCallPart) callPart.functionCall.args["season"] shouldBe null } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index c68b49a5..c16c7875 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -195,9 +195,9 @@ internal constructor( GenerateContentRequest( modelName, prompt.map { it.toInternal() }, - safetySettings?.map { it.toInternal() }, + safetySettings?.map { it.toInternal() }.orEmpty(), generationConfig?.toInternal(), - tools?.map { it.toInternal() }, + tools?.map { it.toInternal() }.orEmpty(), toolConfig?.toInternal(), systemInstruction?.toInternal(), ) @@ -206,12 +206,12 @@ internal constructor( CountTokensRequest.forGenAI(constructRequest(*prompt)) private fun GenerateContentResponse.validate() = apply { - if (candidates.isEmpty() && promptFeedback == null) { + if (candidates.isEmpty() && promptFeedback?.blockReason == null) { throw SerializationException("Error deserializing response, found no valid fields") } promptFeedback?.blockReason?.let { throw PromptBlockedException(this) } candidates - .mapNotNull { it.finishReason } + .map { it.finishReason } .firstOrNull { it != FinishReason.STOP } ?.let { throw ResponseStoppedException(this) } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index c886a6dc..d38647eb 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.internal.util import android.graphics.Bitmap @@ -62,6 +64,7 @@ import com.google.ai.client.generativeai.type.ToolConfig import com.google.ai.client.generativeai.type.UsageMetadata import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject import org.json.JSONObject @@ -142,7 +145,7 @@ internal fun ExecutionOutcome.toInternal() = internal fun Tool.toInternal() = com.google.ai.client.generativeai.common.client.Tool( - functionDeclarations?.map { it.toInternal() }, + functionDeclarations?.map { it.toInternal() }.orEmpty(), codeExecution = codeExecution?.toInternal(), ) @@ -161,7 +164,7 @@ internal fun ToolConfig.toInternal() = ) internal fun com.google.ai.client.generativeai.common.UsageMetadata.toPublic(): UsageMetadata = - UsageMetadata(promptTokenCount ?: 0, candidatesTokenCount ?: 0, totalTokenCount ?: 0) + UsageMetadata(promptTokenCount, candidatesTokenCount, totalTokenCount) internal fun FunctionDeclaration.toInternal() = com.google.ai.client.generativeai.common.client.FunctionDeclaration( @@ -182,7 +185,7 @@ internal fun com.google.ai.client.generativeai.type.Schema.toInternal(): format, nullable, enum, - properties?.mapValues { it.value.toInternal() }, + properties.mapValues { it.value.toInternal() }, required, items?.toInternal(), ) @@ -190,9 +193,9 @@ internal fun com.google.ai.client.generativeai.type.Schema.toInternal(): internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + val safetyRatings = safetyRatings.map { it.toPublic() } val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty() - val finishReason = finishReason.toPublic() + val finishReason = finishReason?.toPublic() return com.google.ai.client.generativeai.type.Candidate( this.content?.toPublic() ?: content("model") {}, @@ -249,16 +252,15 @@ internal fun SafetyRating.toPublic() = com.google.ai.client.generativeai.type.SafetyRating(category.toPublic(), probability.toPublic()) internal fun PromptFeedback.toPublic(): com.google.ai.client.generativeai.type.PromptFeedback { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + val safetyRatings = safetyRatings.map { it.toPublic() } return com.google.ai.client.generativeai.type.PromptFeedback( blockReason?.toPublic(), safetyRatings, ) } -internal fun FinishReason?.toPublic() = +internal fun FinishReason.toPublic() = when (this) { - null -> null FinishReason.MAX_TOKENS -> com.google.ai.client.generativeai.type.FinishReason.MAX_TOKENS FinishReason.RECITATION -> com.google.ai.client.generativeai.type.FinishReason.RECITATION FinishReason.SAFETY -> com.google.ai.client.generativeai.type.FinishReason.SAFETY @@ -308,9 +310,9 @@ internal fun Outcome.toPublic() = internal fun GenerateContentResponse.toPublic() = com.google.ai.client.generativeai.type.GenerateContentResponse( - candidates?.map { it.toPublic() }.orEmpty(), - promptFeedback?.toPublic(), - usageMetadata?.toPublic(), + candidates.map { it.toPublic() }, + promptFeedback?.toPublic() ?: com.google.ai.client.generativeai.type.PromptFeedback(), + usageMetadata?.toPublic() ?: UsageMetadata(0, 0, 0), ) internal fun CountTokensResponse.toPublic() = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Candidate.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Candidate.kt index 1f63a734..243fb6b3 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Candidate.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Candidate.kt @@ -44,7 +44,7 @@ class CitationMetadata( val startIndex: Int = 0, val endIndex: Int, val uri: String, - val license: String? = null, + val license: String, ) /** The reason for content finishing. */ diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Content.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Content.kt index 885df96e..e39ee18e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Content.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Content.kt @@ -25,12 +25,11 @@ import android.graphics.Bitmap * * @see content */ -class Content @JvmOverloads constructor(val role: String? = "user", val parts: List) { +class Content @JvmOverloads constructor(val role: String = "user", val parts: List) { class Builder { - var role: String? = "user" - - var parts: MutableList = arrayListOf() + var role: String = "user" + val parts: MutableList = arrayListOf() @JvmName("addPart") fun part(data: T) = apply { parts.add(data) } @@ -59,7 +58,7 @@ class Content @JvmOverloads constructor(val role: String? = "user", val parts: L * ) * ``` */ -fun content(role: String? = "user", init: Content.Builder.() -> Unit): Content { +fun content(role: String = "user", init: Content.Builder.() -> Unit): Content { val builder = Content.Builder() builder.role = role builder.init() diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 88247f3d..f3c5a2bb 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -52,11 +52,11 @@ class FunctionDeclaration( class Schema( val name: String, val description: String, - val format: String? = null, - val nullable: Boolean? = null, - val enum: List? = null, - val properties: Map>? = null, - val required: List? = null, + val format: String = "", + val nullable: Boolean = false, + val enum: List = emptyList(), + val properties: Map> = emptyMap(), + val required: List = emptyList(), val items: Schema? = null, val type: FunctionType, ) { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt index cb9ccbe1..25a9ba4f 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt @@ -16,4 +16,4 @@ package com.google.ai.client.generativeai.type -class FunctionParameter(val name: String, val description: String, val type: FunctionType) {} +class FunctionParameter(val name: String, val description: String, val type: FunctionType) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt index 79c2ef11..e27c5d40 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt @@ -31,24 +31,24 @@ package com.google.ai.client.generativeai.type */ class GenerationConfig private constructor( - val temperature: Float?, - val topK: Int?, - val topP: Float?, - val candidateCount: Int?, - val maxOutputTokens: Int?, - val stopSequences: List?, - val responseMimeType: String?, + val temperature: Float, + val topK: Int, + val topP: Float, + val candidateCount: Int, + val maxOutputTokens: Int, + val stopSequences: List = emptyList(), + val responseMimeType: String, val responseSchema: Schema<*>?, ) { class Builder { - @JvmField var temperature: Float? = null - @JvmField var topK: Int? = null - @JvmField var topP: Float? = null - @JvmField var candidateCount: Int? = null - @JvmField var maxOutputTokens: Int? = null - @JvmField var stopSequences: List? = null - @JvmField var responseMimeType: String? = null + @JvmField var temperature: Float = 0f + @JvmField var topK: Int = 0 + @JvmField var topP: Float = 0f + @JvmField var candidateCount: Int = 1 + @JvmField var maxOutputTokens: Int = 0 + @JvmField var stopSequences: List = emptyList() + @JvmField var responseMimeType: String = "" @JvmField var responseSchema: Schema<*>? = null fun build() = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/PromptFeedback.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/PromptFeedback.kt index d03026a7..bd171337 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/PromptFeedback.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/PromptFeedback.kt @@ -22,7 +22,10 @@ package com.google.ai.client.generativeai.type * @param blockReason The reason that content was blocked, if at all. * @param safetyRatings A list of relevant [SafetyRating]s. */ -class PromptFeedback(val blockReason: BlockReason?, val safetyRatings: List) +class PromptFeedback( + val blockReason: BlockReason? = null, + val safetyRatings: List = emptyList(), +) /** Describes why content was blocked. */ enum class BlockReason { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/SafetySetting.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/SafetySetting.kt index 8f4b62c9..9a4d1026 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/SafetySetting.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/SafetySetting.kt @@ -22,4 +22,4 @@ package com.google.ai.client.generativeai.type * @param harmCategory The relevant [HarmCategory]. * @param threshold The threshold form harm allowable. */ -class SafetySetting(val harmCategory: HarmCategory, val threshold: BlockThreshold) {} +class SafetySetting(val harmCategory: HarmCategory, val threshold: BlockThreshold) diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index 904130c0..0e743cdf 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai import com.google.ai.client.generativeai.common.APIController @@ -54,6 +56,7 @@ import io.mockk.mockk import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.flow import kotlinx.coroutines.runBlocking +import kotlinx.serialization.ExperimentalSerializationApi import org.json.JSONObject import org.junit.Test @@ -69,7 +72,10 @@ internal class GenerativeModelTests { mockApiController.generateContent( GenerateContentRequest_Common( "gemini-pro-1.0", - contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))), + contents = + listOf( + Content_Common(role = "user", parts = listOf(TextPart_Common("Why's the sky blue?"))) + ), ) ) } returns @@ -78,7 +84,8 @@ internal class GenerativeModelTests { Candidate_Common( content = Content_Common( - parts = listOf(TextPart_Common("I'm still learning how to answer this question")) + role = "user", + parts = listOf(TextPart_Common("I'm still learning how to answer this question")), ), finishReason = null, safetyRatings = listOf(), @@ -103,7 +110,7 @@ internal class GenerativeModelTests { startIndex = 0, endIndex = 100, uri = "http://www.example.com", - license = null, + license = "", ) ), finishReason = null, @@ -139,8 +146,11 @@ internal class GenerativeModelTests { coEvery { mockApiController.generateContent( GenerateContentRequest_Common( - "gemini-pro-1.0", - contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))), + model = "gemini-pro-1.0", + contents = + listOf( + Content_Common(role = "user", parts = listOf(TextPart_Common("Why's the sky blue?"))) + ), ) ) } throws InvalidAPIKeyException_Common("exception message") @@ -155,7 +165,10 @@ internal class GenerativeModelTests { mockApiController.generateContentStream( GenerateContentRequest_Common( "gemini-pro-1.0", - contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))), + contents = + listOf( + Content_Common(role = "user", parts = listOf(TextPart_Common("Why's the sky blue?"))) + ), ) ) } returns flow { throw UnsupportedUserLocationException_Common() } From dcb55dbb85be8f12c2596cb0083d101ca4cfe7f3 Mon Sep 17 00:00:00 2001 From: Daymon Date: Tue, 2 Jul 2024 14:43:34 -0500 Subject: [PATCH 2/4] Remove duplicate change entry --- .changes/common/breath-committee-dad-curtain.json | 1 - 1 file changed, 1 deletion(-) delete mode 100644 .changes/common/breath-committee-dad-curtain.json diff --git a/.changes/common/breath-committee-dad-curtain.json b/.changes/common/breath-committee-dad-curtain.json deleted file mode 100644 index 2c60c795..00000000 --- a/.changes/common/breath-committee-dad-curtain.json +++ /dev/null @@ -1 +0,0 @@ -{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]} From 84b2b21ef2fb1cc2f48fe5ab72b67ab2e2c58d35 Mon Sep 17 00:00:00 2001 From: Daymon Date: Wed, 3 Jul 2024 16:59:09 -0500 Subject: [PATCH 3/4] Add readme concerning common protos --- common/README.md | 71 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 common/README.md diff --git a/common/README.md b/common/README.md new file mode 100644 index 00000000..3566cb2a --- /dev/null +++ b/common/README.md @@ -0,0 +1,71 @@ +# Protos + +> [!NOTE] +> Some code and documentation may refer to "Generative AI" as "labs". These two names are used +> interchangeably, and you should just register them as different names for the same service. + +Protos are derived from a combination of the [Generative AI proto files](https://github.com/googleapis/googleapis/tree/master/google/ai/generativelanguage/v1beta) +and the [Vertex AI proto files](https://github.com/googleapis/googleapis/tree/master/google/cloud/aiplatform/v1beta1). + +The goal is to maintain a sort of overlap between the two protos- representing their "common" +definitions. + +## Organization + +Within this SDK, the protos are defined under the following three categories. + +### [Client](#client-protos) + +You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt). + +These are types that can only be sent _to_ the server; meaning the server will never respond +with them. + +You can classify them as "client" only types, or "request" types. + +### [Server](#server-protos) + +You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt). + +These are types that can only be sent _from_ the server; meaning the client will never create them +on their own. + +You can classify them as "server" only types, or "response" types. + +### [Shared](#shared-protos) + +You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt). + +These are types that can both be sent _to_ and received _from_ the server; meaning the client can +create them, and the server can also respond with them. + +You can classify them as "shared" types, or "common" types. + +## Alignment efforts + +In aligning with the proto, you should be mindful of the following practices: + +### Field presence + +Additional Context: [Presence in Proto3 APIs](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt) + +- `optional` types should be nullable. +- non `optional` primitive types (including enums) should default to their [respective default](https://protobuf.dev/programming-guides/proto3/#default). +- `repeated` fields that are not marked with a `google.api.field_behavior` of `REQUIRED` should +default to an empty list or map. +- message fields that are marked with a `google.api.field_behavior` of `OPTIONAL` should be nullable. +- fields that are marked with a `google.api.field_behavior` of `REQUIRED` should *NOT* have a +default value, but *ONLY* when it's a [client](#client-protos) or [shared](#shared-protos) type. +- if a field is marked with both `optional` and a `google.api.field_behavior` of `REQUIRED`, then it +should be a nullable field that does _not_ default to null (ie; it needs to be explicitly set). + +### Serial names + +> [!NOTE] +> The exception to this rule is ENUM fields, which DO use `snake_case` serial names. + +While the proto is defined in `snake_case`, it will respect and respond in `camelCase` if you send +the request in `camelCase`. As such, our protos do not have `@SerialName` annotations denoting their +`snake_case` alternative. + +So all your fields should be defined in `camelCase` format. From 97ee570d6e982e1fd7423ad43ee01ae8d8ea4152 Mon Sep 17 00:00:00 2001 From: Daymon Date: Wed, 3 Jul 2024 16:59:45 -0500 Subject: [PATCH 4/4] Fix types according to field presence --- .../generativeai/common/client/Types.kt | 21 +++++++++---------- .../generativeai/common/server/Types.kt | 20 +++++++++++------- .../generativeai/common/shared/Types.kt | 9 +++++--- .../generativeai/common/UnarySnapshotTests.kt | 8 +++---- .../generativeai/internal/util/conversions.kt | 9 +++++++- .../client/generativeai/type/HarmCategory.kt | 3 +++ 6 files changed, 43 insertions(+), 27 deletions(-) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index 2763702c..256140e8 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -22,15 +22,15 @@ import kotlinx.serialization.json.JsonObject @Serializable data class GenerationConfig( - val temperature: Float = 0f, - val topP: Float = 0f, - val topK: Int = 0, - val candidateCount: Int = 0, - val maxOutputTokens: Int = 0, + val temperature: Float? = null, + val topP: Float? = null, + val topK: Int? = null, + val candidateCount: Int? = null, + val maxOutputTokens: Int? = null, val stopSequences: List = emptyList(), - val responseMimeType: String = "", - val presencePenalty: Float = 0f, - val frequencyPenalty: Float = 0f, + val responseMimeType: String? = null, + val presencePenalty: Float? = null, + val frequencyPenalty: Float? = null, val responseSchema: Schema? = null, ) @@ -41,11 +41,10 @@ data class Tool( val codeExecution: JsonObject? = null, ) -@Serializable -data class ToolConfig(val functionCallingConfig: FunctionCallingConfig = FunctionCallingConfig()) +@Serializable data class ToolConfig(val functionCallingConfig: FunctionCallingConfig? = null) @Serializable -data class FunctionCallingConfig(val mode: Mode? = null) { +data class FunctionCallingConfig(val mode: Mode = Mode.UNSPECIFIED) { @Serializable enum class Mode { @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt index db672fd3..310c047d 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt @@ -38,6 +38,8 @@ object FinishReasonSerializer : @Serializable data class PromptFeedback( + // TODO() should default to UNSPECIFIED, but that would be an unexpected change for consumers null + // checking block reason to see if their prompt was blocked val blockReason: BlockReason? = null, val safetyRatings: List = emptyList(), ) @@ -53,6 +55,8 @@ enum class BlockReason { @Serializable data class Candidate( val content: Content? = null, + // TODO() should default to UNSPECIFIED, but that would be an unexpected change for consumers + // checking if their finish reason is anything other than STOP val finishReason: FinishReason? = null, val safetyRatings: List = emptyList(), val citationMetadata: CitationMetadata? = null, @@ -66,19 +70,19 @@ data class CitationMetadata( @Serializable data class CitationSources( - val startIndex: Int = 0, - val endIndex: Int = 0, - val uri: String = "", - val license: String = "", + val startIndex: Int? = null, + val endIndex: Int? = null, + val uri: String? = null, + val license: String? = null, ) @Serializable data class SafetyRating( - val category: HarmCategory, - val probability: HarmProbability, + val category: HarmCategory = HarmCategory.UNSPECIFIED, + val probability: HarmProbability = HarmProbability.UNSPECIFIED, val blocked: Boolean = false, val probabilityScore: Float = 0f, - val severity: HarmSeverity? = null, + val severity: HarmSeverity = HarmSeverity.UNSPECIFIED, val severityScore: Float = 0f, ) @@ -96,7 +100,7 @@ data class SearchEntryPoint(val renderedContent: String = "", val sdkBlob: Strin // TODO() Has a different definition for labs vs vertex. May need to split into diff types in future // (when labs supports it) @Serializable -data class GroundingAttribution(val segment: Segment, val confidenceScore: Float = 0f) +data class GroundingAttribution(val segment: Segment? = null, val confidenceScore: Float? = null) @Serializable data class Segment(val startIndex: Int = 0, val endIndex: Int = 0) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index ac01acb8..c3c3d5ea 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -35,6 +35,7 @@ object HarmCategorySerializer : @Serializable(HarmCategorySerializer::class) enum class HarmCategory { UNKNOWN, + @SerialName("HARM_CATEGORY_UNSPECIFIED") UNSPECIFIED, @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, @SerialName("HARM_CATEGORY_SEXUALLY_EXPLICIT") SEXUALLY_EXPLICIT, @@ -45,7 +46,7 @@ typealias Base64 = String @ExperimentalSerializationApi @Serializable -data class Content(@EncodeDefault val role: String = "", val parts: List) +data class Content(@EncodeDefault val role: String = "user", val parts: List) @Serializable(PartSerializer::class) sealed interface Part @@ -75,7 +76,7 @@ data class FunctionCall(val name: String, val args: Map = empty @Serializable data class ExecutableCode(val language: String, val code: String) -@Serializable data class CodeExecutionResult(val outcome: Outcome, val output: String) +@Serializable data class CodeExecutionResult(val outcome: Outcome, val output: String = "") @Serializable enum class Outcome { @@ -85,11 +86,13 @@ enum class Outcome { OUTCOME_DEADLINE_EXCEEDED, } +// TODO() Move SafetySettings, HarmBlockThreshold and HarmBlockMethod to `client` as they're client +// only types @Serializable data class SafetySetting( val category: HarmCategory, val threshold: HarmBlockThreshold, - val method: HarmBlockMethod? = null, + val method: HarmBlockMethod = HarmBlockMethod.UNSPECIFIED, ) @Serializable diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt index cc40256e..9dd28580 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt @@ -34,11 +34,11 @@ import com.google.ai.client.generativeai.common.util.goldenUnaryFile import com.google.ai.client.generativeai.common.util.shouldNotBeNullOrEmpty import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.collections.shouldNotBeEmpty +import io.kotest.matchers.nulls.shouldBeNull import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.should import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe -import io.kotest.matchers.string.shouldBeEmpty import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.http.HttpStatusCode import kotlin.time.Duration.Companion.seconds @@ -179,9 +179,9 @@ internal class UnarySnapshotTests { response.candidates.isEmpty() shouldBe false response.candidates.first().citationMetadata?.citationSources?.isNotEmpty() shouldBe true // Verify the values in the citation source - with(response.candidates.first().citationMetadata?.citationSources?.first()!!) { - license.shouldBeEmpty() - startIndex shouldBe 0 + response.candidates.first().citationMetadata?.citationSources?.first().let { + it.shouldNotBeNull() + it.startIndex.shouldBeNull() } } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index d38647eb..0008dcce 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -117,6 +117,7 @@ internal fun com.google.ai.client.generativeai.type.GenerationConfig.toInternal( internal fun com.google.ai.client.generativeai.type.HarmCategory.toInternal() = when (this) { + com.google.ai.client.generativeai.type.HarmCategory.UNSPECIFIED -> HarmCategory.UNSPECIFIED com.google.ai.client.generativeai.type.HarmCategory.HARASSMENT -> HarmCategory.HARASSMENT com.google.ai.client.generativeai.type.HarmCategory.HATE_SPEECH -> HarmCategory.HATE_SPEECH com.google.ai.client.generativeai.type.HarmCategory.SEXUALLY_EXPLICIT -> @@ -246,7 +247,12 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { } internal fun CitationSources.toPublic() = - CitationMetadata(startIndex = startIndex, endIndex = endIndex, uri = uri, license = license) + CitationMetadata( + startIndex = startIndex ?: 0, + endIndex = endIndex ?: 0, + uri = uri ?: "", + license = license ?: "", + ) internal fun SafetyRating.toPublic() = com.google.ai.client.generativeai.type.SafetyRating(category.toPublic(), probability.toPublic()) @@ -272,6 +278,7 @@ internal fun FinishReason.toPublic() = internal fun HarmCategory.toPublic() = when (this) { + HarmCategory.UNSPECIFIED -> com.google.ai.client.generativeai.type.HarmCategory.UNSPECIFIED HarmCategory.HARASSMENT -> com.google.ai.client.generativeai.type.HarmCategory.HARASSMENT HarmCategory.HATE_SPEECH -> com.google.ai.client.generativeai.type.HarmCategory.HATE_SPEECH HarmCategory.SEXUALLY_EXPLICIT -> diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/HarmCategory.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/HarmCategory.kt index d4e9781f..78411903 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/HarmCategory.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/HarmCategory.kt @@ -21,6 +21,9 @@ enum class HarmCategory { /** A new and not yet supported value. */ UNKNOWN, + /** A HarmCategory was not specified. */ + UNSPECIFIED, + /** Harassment content. */ HARASSMENT,