diff --git a/.changes/common/carpenter-beggar-creator-celery.json b/.changes/common/carpenter-beggar-creator-celery.json new file mode 100644 index 00000000..2c60c795 --- /dev/null +++ b/.changes/common/carpenter-beggar-creator-celery.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]} diff --git a/.changes/generativeai/breath-brush-achiever-boat.json b/.changes/generativeai/breath-brush-achiever-boat.json new file mode 100644 index 00000000..2c60c795 --- /dev/null +++ b/.changes/generativeai/breath-brush-achiever-boat.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]} diff --git a/common/README.md b/common/README.md new file mode 100644 index 00000000..3566cb2a --- /dev/null +++ b/common/README.md @@ -0,0 +1,71 @@ +# Protos + +> [!NOTE] +> Some code and documentation may refer to "Generative AI" as "labs". These two names are used +> interchangeably, and you should just register them as different names for the same service. + +Protos are derived from a combination of the [Generative AI proto files](https://github.com/googleapis/googleapis/tree/master/google/ai/generativelanguage/v1beta) +and the [Vertex AI proto files](https://github.com/googleapis/googleapis/tree/master/google/cloud/aiplatform/v1beta1). + +The goal is to maintain a sort of overlap between the two protos- representing their "common" +definitions. + +## Organization + +Within this SDK, the protos are defined under the following three categories. + +### [Client](#client-protos) + +You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt). + +These are types that can only be sent _to_ the server; meaning the server will never respond +with them. + +You can classify them as "client" only types, or "request" types. + +### [Server](#server-protos) + +You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt). + +These are types that can only be sent _from_ the server; meaning the client will never create them +on their own. + +You can classify them as "server" only types, or "response" types. + +### [Shared](#shared-protos) + +You can find these types [here](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt). + +These are types that can both be sent _to_ and received _from_ the server; meaning the client can +create them, and the server can also respond with them. + +You can classify them as "shared" types, or "common" types. + +## Alignment efforts + +In aligning with the proto, you should be mindful of the following practices: + +### Field presence + +Additional Context: [Presence in Proto3 APIs](https://github.com/google-gemini/generative-ai-android/blob/main/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt) + +- `optional` types should be nullable. +- non `optional` primitive types (including enums) should default to their [respective default](https://protobuf.dev/programming-guides/proto3/#default). +- `repeated` fields that are not marked with a `google.api.field_behavior` of `REQUIRED` should +default to an empty list or map. +- message fields that are marked with a `google.api.field_behavior` of `OPTIONAL` should be nullable. +- fields that are marked with a `google.api.field_behavior` of `REQUIRED` should *NOT* have a +default value, but *ONLY* when it's a [client](#client-protos) or [shared](#shared-protos) type. +- if a field is marked with both `optional` and a `google.api.field_behavior` of `REQUIRED`, then it +should be a nullable field that does _not_ default to null (ie; it needs to be explicitly set). + +### Serial names + +> [!NOTE] +> The exception to this rule is ENUM fields, which DO use `snake_case` serial names. + +While the proto is defined in `snake_case`, it will respect and respond in `camelCase` if you send +the request in `camelCase`. As such, our protos do not have `@SerialName` annotations denoting their +`snake_case` alternative. + +So all your fields should be defined in `camelCase` format. diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt index d815f44a..adaf00f8 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -235,19 +235,19 @@ private suspend fun validateResponse(response: HttpResponse) { if (message.contains("quota")) { throw QuotaExceededException(message) } - if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) { + if (error.details.any { "SERVICE_DISABLED" == it.reason }) { throw ServiceDisabledException(message) } throw ServerException(message) } private fun GenerateContentResponse.validate() = apply { - if ((candidates?.isEmpty() != false) && promptFeedback == null) { + if (candidates.isEmpty() && promptFeedback == null) { throw SerializationException("Error deserializing response, found no valid fields") } promptFeedback?.blockReason?.let { throw PromptBlockedException(this) } candidates - ?.mapNotNull { it.finishReason } - ?.firstOrNull { it != FinishReason.STOP } + .map { it.finishReason } + .firstOrNull { it != FinishReason.STOP } ?.let { throw ResponseStoppedException(this) } } diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt index 15cd25d9..05960d2f 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt @@ -96,7 +96,7 @@ class InvalidStateException(message: String, cause: Throwable? = null) : */ class ResponseStoppedException(val response: GenerateContentResponse, cause: Throwable? = null) : GoogleGenerativeAIException( - "Content generation stopped. Reason: ${response.candidates?.first()?.finishReason?.name}", + "Content generation stopped. Reason: ${response.candidates.first().finishReason?.name}", cause, ) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index b252b2e1..b1a4f36a 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.GenerationConfig @@ -22,45 +24,41 @@ import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting import com.google.ai.client.generativeai.common.util.fullModelName -import kotlinx.serialization.SerialName +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable sealed interface Request @Serializable data class GenerateContentRequest( - val model: String? = null, + val model: String, val contents: List, - @SerialName("safety_settings") val safetySettings: List? = null, - @SerialName("generation_config") val generationConfig: GenerationConfig? = null, - val tools: List? = null, - @SerialName("tool_config") var toolConfig: ToolConfig? = null, - @SerialName("system_instruction") val systemInstruction: Content? = null, + val safetySettings: List = emptyList(), + val generationConfig: GenerationConfig? = null, + val tools: List = emptyList(), + val toolConfig: ToolConfig? = null, + val systemInstruction: Content? = null, ) : Request @Serializable data class CountTokensRequest( + val model: String, + val contents: List = emptyList(), + val tools: List = emptyList(), val generateContentRequest: GenerateContentRequest? = null, - val model: String? = null, - val contents: List? = null, - val tools: List? = null, - @SerialName("system_instruction") val systemInstruction: Content? = null, + val systemInstruction: Content? = null, ) : Request { companion object { - fun forGenAI(generateContentRequest: GenerateContentRequest) = - CountTokensRequest( - generateContentRequest = - generateContentRequest.model?.let { - generateContentRequest.copy(model = fullModelName(it)) - } ?: generateContentRequest - ) + fun forGenAI(request: GenerateContentRequest) = + CountTokensRequest(fullModelName(request.model), request.contents, emptyList(), request) - fun forVertexAI(generateContentRequest: GenerateContentRequest) = + fun forVertexAI(request: GenerateContentRequest) = CountTokensRequest( - model = generateContentRequest.model?.let { fullModelName(it) }, - contents = generateContentRequest.contents, - tools = generateContentRequest.tools, - systemInstruction = generateContentRequest.systemInstruction, + fullModelName(request.model), + request.contents, + request.tools, + null, + request.systemInstruction, ) } } diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt index 8b5f1ba7..4194ae14 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt @@ -25,20 +25,20 @@ sealed interface Response @Serializable data class GenerateContentResponse( - val candidates: List? = null, + val candidates: List = emptyList(), val promptFeedback: PromptFeedback? = null, val usageMetadata: UsageMetadata? = null, ) : Response @Serializable -data class CountTokensResponse(val totalTokens: Int, val totalBillableCharacters: Int? = null) : +data class CountTokensResponse(val totalTokens: Int = 0, val totalBillableCharacters: Int = 0) : Response @Serializable data class GRpcErrorResponse(val error: GRpcError) : Response @Serializable data class UsageMetadata( - val promptTokenCount: Int? = null, - val candidatesTokenCount: Int? = null, - val totalTokenCount: Int? = null, + val promptTokenCount: Int = 0, + val candidatesTokenCount: Int = 0, + val totalTokenCount: Int = 0, ) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index e32b9196..256140e8 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -22,32 +22,29 @@ import kotlinx.serialization.json.JsonObject @Serializable data class GenerationConfig( - val temperature: Float?, - @SerialName("top_p") val topP: Float?, - @SerialName("top_k") val topK: Int?, - @SerialName("candidate_count") val candidateCount: Int?, - @SerialName("max_output_tokens") val maxOutputTokens: Int?, - @SerialName("stop_sequences") val stopSequences: List?, - @SerialName("response_mime_type") val responseMimeType: String? = null, - @SerialName("presence_penalty") val presencePenalty: Float? = null, - @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, - @SerialName("response_schema") val responseSchema: Schema? = null, + val temperature: Float? = null, + val topP: Float? = null, + val topK: Int? = null, + val candidateCount: Int? = null, + val maxOutputTokens: Int? = null, + val stopSequences: List = emptyList(), + val responseMimeType: String? = null, + val presencePenalty: Float? = null, + val frequencyPenalty: Float? = null, + val responseSchema: Schema? = null, ) @Serializable data class Tool( - val functionDeclarations: List? = null, + val functionDeclarations: List = emptyList(), // This is a json object because it is not possible to make a data class with no parameters. val codeExecution: JsonObject? = null, ) -@Serializable -data class ToolConfig( - @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig -) +@Serializable data class ToolConfig(val functionCallingConfig: FunctionCallingConfig? = null) @Serializable -data class FunctionCallingConfig(val mode: Mode) { +data class FunctionCallingConfig(val mode: Mode = Mode.UNSPECIFIED) { @Serializable enum class Mode { @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, @@ -58,16 +55,20 @@ data class FunctionCallingConfig(val mode: Mode) { } @Serializable -data class FunctionDeclaration(val name: String, val description: String, val parameters: Schema) +data class FunctionDeclaration( + val name: String, + val description: String, + val parameters: Schema? = null, +) @Serializable data class Schema( val type: String, - val description: String? = null, - val format: String? = null, - val nullable: Boolean? = false, - val enum: List? = null, - val properties: Map? = null, - val required: List? = null, + val description: String = "", + val format: String = "", + val nullable: Boolean = false, + val enum: List = emptyList(), + val properties: Map = emptyMap(), + val required: List = emptyList(), val items: Schema? = null, ) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt index 16b25a78..310c047d 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common.server import com.google.ai.client.generativeai.common.shared.Content @@ -36,8 +38,10 @@ object FinishReasonSerializer : @Serializable data class PromptFeedback( + // TODO() should default to UNSPECIFIED, but that would be an unexpected change for consumers null + // checking block reason to see if their prompt was blocked val blockReason: BlockReason? = null, - val safetyRatings: List? = null, + val safetyRatings: List = emptyList(), ) @Serializable(BlockReasonSerializer::class) @@ -51,60 +55,54 @@ enum class BlockReason { @Serializable data class Candidate( val content: Content? = null, + // TODO() should default to UNSPECIFIED, but that would be an unexpected change for consumers + // checking if their finish reason is anything other than STOP val finishReason: FinishReason? = null, - val safetyRatings: List? = null, + val safetyRatings: List = emptyList(), val citationMetadata: CitationMetadata? = null, val groundingMetadata: GroundingMetadata? = null, ) @Serializable -data class CitationMetadata -@OptIn(ExperimentalSerializationApi::class) -constructor(@JsonNames("citations") val citationSources: List) +data class CitationMetadata( + @JsonNames("citations") val citationSources: List = emptyList() +) @Serializable data class CitationSources( - val startIndex: Int = 0, - val endIndex: Int, - val uri: String, + val startIndex: Int? = null, + val endIndex: Int? = null, + val uri: String? = null, val license: String? = null, ) @Serializable data class SafetyRating( - val category: HarmCategory, - val probability: HarmProbability, - val blocked: Boolean? = null, // TODO(): any reason not to default to false? - val probabilityScore: Float? = null, - val severity: HarmSeverity? = null, - val severityScore: Float? = null, + val category: HarmCategory = HarmCategory.UNSPECIFIED, + val probability: HarmProbability = HarmProbability.UNSPECIFIED, + val blocked: Boolean = false, + val probabilityScore: Float = 0f, + val severity: HarmSeverity = HarmSeverity.UNSPECIFIED, + val severityScore: Float = 0f, ) @Serializable data class GroundingMetadata( - @SerialName("web_search_queries") val webSearchQueries: List?, - @SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?, - @SerialName("retrieval_queries") val retrievalQueries: List?, - @SerialName("grounding_attribution") val groundingAttribution: List?, + val webSearchQueries: List = emptyList(), + val searchEntryPoint: SearchEntryPoint? = null, + val retrievalQueries: List = emptyList(), + val groundingAttribution: List = emptyList(), ) @Serializable -data class SearchEntryPoint( - @SerialName("rendered_content") val renderedContent: String?, - @SerialName("sdk_blob") val sdkBlob: String?, -) +data class SearchEntryPoint(val renderedContent: String = "", val sdkBlob: String = "") +// TODO() Has a different definition for labs vs vertex. May need to split into diff types in future +// (when labs supports it) @Serializable -data class GroundingAttribution( - val segment: Segment, - @SerialName("confidence_score") val confidenceScore: Float?, -) +data class GroundingAttribution(val segment: Segment? = null, val confidenceScore: Float? = null) -@Serializable -data class Segment( - @SerialName("start_index") val startIndex: Int, - @SerialName("end_index") val endIndex: Int, -) +@Serializable data class Segment(val startIndex: Int = 0, val endIndex: Int = 0) @Serializable(HarmProbabilitySerializer::class) enum class HarmProbability { diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index 7dc31858..c3c3d5ea 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -35,6 +35,7 @@ object HarmCategorySerializer : @Serializable(HarmCategorySerializer::class) enum class HarmCategory { UNKNOWN, + @SerialName("HARM_CATEGORY_UNSPECIFIED") UNSPECIFIED, @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, @SerialName("HARM_CATEGORY_SEXUALLY_EXPLICIT") SEXUALLY_EXPLICIT, @@ -45,13 +46,13 @@ typealias Base64 = String @ExperimentalSerializationApi @Serializable -data class Content(@EncodeDefault val role: String? = "user", val parts: List) +data class Content(@EncodeDefault val role: String = "user", val parts: List) @Serializable(PartSerializer::class) sealed interface Part -@Serializable data class TextPart(val text: String) : Part +@Serializable data class TextPart(val text: String = "") : Part -@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part +@Serializable data class BlobPart(val inlineData: Blob) : Part @Serializable data class FunctionCallPart(val functionCall: FunctionCall) : Part @@ -64,21 +65,18 @@ data class CodeExecutionResultPart(val codeExecutionResult: CodeExecutionResult) @Serializable data class FunctionResponse(val name: String, val response: JsonObject) -@Serializable data class FunctionCall(val name: String, val args: Map) +@Serializable +data class FunctionCall(val name: String, val args: Map = emptyMap()) -@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part +@Serializable data class FileDataPart(val fileData: FileData) : Part -@Serializable -data class FileData( - @SerialName("mime_type") val mimeType: String, - @SerialName("file_uri") val fileUri: String, -) +@Serializable data class FileData(val mimeType: String, val fileUri: String) -@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64) +@Serializable data class Blob(val mimeType: String, val data: Base64) @Serializable data class ExecutableCode(val language: String, val code: String) -@Serializable data class CodeExecutionResult(val outcome: Outcome, val output: String) +@Serializable data class CodeExecutionResult(val outcome: Outcome, val output: String = "") @Serializable enum class Outcome { @@ -88,11 +86,13 @@ enum class Outcome { OUTCOME_DEADLINE_EXCEEDED, } +// TODO() Move SafetySettings, HarmBlockThreshold and HarmBlockMethod to `client` as they're client +// only types @Serializable data class SafetySetting( val category: HarmCategory, val threshold: HarmBlockThreshold, - val method: HarmBlockMethod? = null, + val method: HarmBlockMethod = HarmBlockMethod.UNSPECIFIED, ) @Serializable diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index aa3c402b..d748ff51 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class, ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.FunctionCallingConfig @@ -43,6 +45,7 @@ import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.delay import kotlinx.coroutines.withTimeout +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString import kotlinx.serialization.json.JsonObject import org.junit.Test @@ -64,7 +67,7 @@ internal class APIControllerTests { withTimeout(testTimeout) { responses.collect { - it.candidates?.isEmpty() shouldBe false + it.candidates.isEmpty() shouldBe false channel.close() } } @@ -101,7 +104,7 @@ internal class RequestFormatTests { withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { - it.candidates?.isEmpty() shouldBe false + it.candidates.isEmpty() shouldBe false channel.close() } } @@ -128,7 +131,7 @@ internal class RequestFormatTests { withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { - it.candidates?.isEmpty() shouldBe false + it.candidates.isEmpty() shouldBe false channel.close() } } @@ -193,8 +196,7 @@ internal class RequestFormatTests { } val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text - - requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode" + requestBodyAsText shouldContainJsonKey "toolConfig.functionCallingConfig.mode" } @Test @@ -320,7 +322,7 @@ internal class ModelNamingTests(private val modelName: String, private val actua withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { - it.candidates?.isEmpty() shouldBe false + it.candidates.isEmpty() shouldBe false channel.close() } } @@ -349,4 +351,4 @@ fun textGenerateContentRequest(prompt: String) = ) fun textCountTokenRequest(prompt: String) = - CountTokensRequest(generateContentRequest = textGenerateContentRequest(prompt)) + CountTokensRequest("", generateContentRequest = textGenerateContentRequest(prompt)) diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt index 7f151320..b56bd098 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.server.BlockReason @@ -31,6 +33,7 @@ import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withTimeout +import kotlinx.serialization.ExperimentalSerializationApi import org.junit.Test internal class StreamingSnapshotTests { @@ -44,9 +47,9 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.isEmpty() shouldBe false - responseList.first().candidates?.first()?.finishReason shouldBe FinishReason.STOP - responseList.first().candidates?.first()?.content?.parts?.isEmpty() shouldBe false - responseList.first().candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + responseList.first().candidates.first().finishReason shouldBe FinishReason.STOP + responseList.first().candidates.first().content?.parts?.isEmpty() shouldBe false + responseList.first().candidates.first().safetyRatings.isEmpty() shouldBe false } } @@ -59,9 +62,9 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false responseList.forEach { - it.candidates?.first()?.finishReason shouldBe FinishReason.STOP - it.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - it.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + it.candidates.first().finishReason shouldBe FinishReason.STOP + it.candidates.first().content?.parts?.isEmpty() shouldBe false + it.candidates.first().safetyRatings.isEmpty() shouldBe false } } } @@ -73,9 +76,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { responses.first { - it.candidates?.any { - it.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } ?: false - } ?: false + it.candidates.any { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } } } } } @@ -89,7 +90,7 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false - val part = responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart + val part = responseList.first().candidates.first().content?.parts?.first() as? TextPart part.shouldNotBeNull() part.text shouldContain "\"" } @@ -129,7 +130,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY } } @@ -141,8 +142,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.any { - it.candidates?.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } - ?: false + it.candidates.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } } shouldBe true } } @@ -155,8 +155,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.any { - it.candidates?.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } - ?: false + it.candidates.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } } shouldBe true } } @@ -168,7 +167,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.RECITATION + exception.response.candidates.first().finishReason shouldBe FinishReason.RECITATION } } diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt index dcb37a93..c00d053a 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.server.BlockReason @@ -32,6 +34,7 @@ import com.google.ai.client.generativeai.common.util.goldenUnaryFile import com.google.ai.client.generativeai.common.util.shouldNotBeNullOrEmpty import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.collections.shouldNotBeEmpty +import io.kotest.matchers.nulls.shouldBeNull import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.should import io.kotest.matchers.shouldBe @@ -40,6 +43,7 @@ import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.http.HttpStatusCode import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.withTimeout +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable import org.junit.Test @@ -54,10 +58,10 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP - response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.candidates.first().content?.parts?.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false } } @@ -67,10 +71,10 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP - response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.candidates.first().content?.parts?.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false } } @@ -80,9 +84,13 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) +<<<<<<< daymon-align-proto + response.candidates.first { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } } +======= response.candidates?.isNullOrEmpty() shouldBe false val candidate = response.candidates?.first() candidate?.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } shouldBe true +>>>>>>> main } } @@ -92,17 +100,16 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.all { + response.candidates.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false + response.candidates.first().safetyRatings.all { it.probability == HarmProbability.NEGLIGIBLE } shouldBe true - response.candidates?.first()?.safetyRatings?.all { it.probabilityScore != null } shouldBe - true - response.candidates?.first()?.safetyRatings?.all { + response.candidates.first().safetyRatings.all { it.probabilityScore != 0f } shouldBe true + response.candidates.first().safetyRatings.all { it.severity == HarmSeverity.NEGLIGIBLE } shouldBe true - response.candidates?.first()?.safetyRatings?.all { it.severityScore != null } shouldBe true + response.candidates.first().safetyRatings.all { it.severityScore != 0f } shouldBe true } } @@ -154,7 +161,7 @@ internal class UnarySnapshotTests { shouldThrow { apiController.generateContent(textGenerateContentRequest("prompt")) } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY } } @@ -164,8 +171,8 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true + response.candidates.isEmpty() shouldBe false + response.candidates.first().citationMetadata?.citationSources?.isNotEmpty() shouldBe true } } @@ -175,12 +182,12 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true + response.candidates.isEmpty() shouldBe false + response.candidates.first().citationMetadata?.citationSources?.isNotEmpty() shouldBe true // Verify the values in the citation source - with(response.candidates?.first()?.citationMetadata?.citationSources?.first()!!) { - license shouldBe null - startIndex shouldBe 0 + response.candidates.first().citationMetadata?.citationSources?.first().let { + it.shouldNotBeNull() + it.startIndex.shouldBeNull() } } } @@ -191,8 +198,8 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.totalTokenCount shouldBe 363 } @@ -204,11 +211,11 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.promptTokenCount shouldBe 6 - response.usageMetadata?.totalTokenCount shouldBe null + response.usageMetadata?.totalTokenCount shouldBe 0 } } @@ -218,8 +225,8 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true + response.candidates.isEmpty() shouldBe false + response.candidates.first().citationMetadata?.citationSources?.isNotEmpty() shouldBe true } } @@ -229,10 +236,8 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - response.candidates?.isEmpty() shouldBe false - with( - response.candidates?.first()?.content?.parts?.first()?.shouldBeInstanceOf() - ) { + response.candidates.isEmpty() shouldBe false + with(response.candidates.first().content?.parts?.first()?.shouldBeInstanceOf()) { shouldNotBeNull() JSON.decodeFromString>(text).shouldNotBeEmpty() } @@ -314,7 +319,7 @@ internal class UnarySnapshotTests { goldenUnaryFile("success-function-call-null.json") { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart) + val callPart = (response.candidates.first().content!!.parts.first() as FunctionCallPart) callPart.functionCall.args["season"] shouldBe null } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index c68b49a5..c16c7875 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -195,9 +195,9 @@ internal constructor( GenerateContentRequest( modelName, prompt.map { it.toInternal() }, - safetySettings?.map { it.toInternal() }, + safetySettings?.map { it.toInternal() }.orEmpty(), generationConfig?.toInternal(), - tools?.map { it.toInternal() }, + tools?.map { it.toInternal() }.orEmpty(), toolConfig?.toInternal(), systemInstruction?.toInternal(), ) @@ -206,12 +206,12 @@ internal constructor( CountTokensRequest.forGenAI(constructRequest(*prompt)) private fun GenerateContentResponse.validate() = apply { - if (candidates.isEmpty() && promptFeedback == null) { + if (candidates.isEmpty() && promptFeedback?.blockReason == null) { throw SerializationException("Error deserializing response, found no valid fields") } promptFeedback?.blockReason?.let { throw PromptBlockedException(this) } candidates - .mapNotNull { it.finishReason } + .map { it.finishReason } .firstOrNull { it != FinishReason.STOP } ?.let { throw ResponseStoppedException(this) } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index c886a6dc..0008dcce 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai.internal.util import android.graphics.Bitmap @@ -62,6 +64,7 @@ import com.google.ai.client.generativeai.type.ToolConfig import com.google.ai.client.generativeai.type.UsageMetadata import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject import org.json.JSONObject @@ -114,6 +117,7 @@ internal fun com.google.ai.client.generativeai.type.GenerationConfig.toInternal( internal fun com.google.ai.client.generativeai.type.HarmCategory.toInternal() = when (this) { + com.google.ai.client.generativeai.type.HarmCategory.UNSPECIFIED -> HarmCategory.UNSPECIFIED com.google.ai.client.generativeai.type.HarmCategory.HARASSMENT -> HarmCategory.HARASSMENT com.google.ai.client.generativeai.type.HarmCategory.HATE_SPEECH -> HarmCategory.HATE_SPEECH com.google.ai.client.generativeai.type.HarmCategory.SEXUALLY_EXPLICIT -> @@ -142,7 +146,7 @@ internal fun ExecutionOutcome.toInternal() = internal fun Tool.toInternal() = com.google.ai.client.generativeai.common.client.Tool( - functionDeclarations?.map { it.toInternal() }, + functionDeclarations?.map { it.toInternal() }.orEmpty(), codeExecution = codeExecution?.toInternal(), ) @@ -161,7 +165,7 @@ internal fun ToolConfig.toInternal() = ) internal fun com.google.ai.client.generativeai.common.UsageMetadata.toPublic(): UsageMetadata = - UsageMetadata(promptTokenCount ?: 0, candidatesTokenCount ?: 0, totalTokenCount ?: 0) + UsageMetadata(promptTokenCount, candidatesTokenCount, totalTokenCount) internal fun FunctionDeclaration.toInternal() = com.google.ai.client.generativeai.common.client.FunctionDeclaration( @@ -182,7 +186,7 @@ internal fun com.google.ai.client.generativeai.type.Schema.toInternal(): format, nullable, enum, - properties?.mapValues { it.value.toInternal() }, + properties.mapValues { it.value.toInternal() }, required, items?.toInternal(), ) @@ -190,9 +194,9 @@ internal fun com.google.ai.client.generativeai.type.Schema.toInternal(): internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + val safetyRatings = safetyRatings.map { it.toPublic() } val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty() - val finishReason = finishReason.toPublic() + val finishReason = finishReason?.toPublic() return com.google.ai.client.generativeai.type.Candidate( this.content?.toPublic() ?: content("model") {}, @@ -243,22 +247,26 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { } internal fun CitationSources.toPublic() = - CitationMetadata(startIndex = startIndex, endIndex = endIndex, uri = uri, license = license) + CitationMetadata( + startIndex = startIndex ?: 0, + endIndex = endIndex ?: 0, + uri = uri ?: "", + license = license ?: "", + ) internal fun SafetyRating.toPublic() = com.google.ai.client.generativeai.type.SafetyRating(category.toPublic(), probability.toPublic()) internal fun PromptFeedback.toPublic(): com.google.ai.client.generativeai.type.PromptFeedback { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + val safetyRatings = safetyRatings.map { it.toPublic() } return com.google.ai.client.generativeai.type.PromptFeedback( blockReason?.toPublic(), safetyRatings, ) } -internal fun FinishReason?.toPublic() = +internal fun FinishReason.toPublic() = when (this) { - null -> null FinishReason.MAX_TOKENS -> com.google.ai.client.generativeai.type.FinishReason.MAX_TOKENS FinishReason.RECITATION -> com.google.ai.client.generativeai.type.FinishReason.RECITATION FinishReason.SAFETY -> com.google.ai.client.generativeai.type.FinishReason.SAFETY @@ -270,6 +278,7 @@ internal fun FinishReason?.toPublic() = internal fun HarmCategory.toPublic() = when (this) { + HarmCategory.UNSPECIFIED -> com.google.ai.client.generativeai.type.HarmCategory.UNSPECIFIED HarmCategory.HARASSMENT -> com.google.ai.client.generativeai.type.HarmCategory.HARASSMENT HarmCategory.HATE_SPEECH -> com.google.ai.client.generativeai.type.HarmCategory.HATE_SPEECH HarmCategory.SEXUALLY_EXPLICIT -> @@ -308,9 +317,9 @@ internal fun Outcome.toPublic() = internal fun GenerateContentResponse.toPublic() = com.google.ai.client.generativeai.type.GenerateContentResponse( - candidates?.map { it.toPublic() }.orEmpty(), - promptFeedback?.toPublic(), - usageMetadata?.toPublic(), + candidates.map { it.toPublic() }, + promptFeedback?.toPublic() ?: com.google.ai.client.generativeai.type.PromptFeedback(), + usageMetadata?.toPublic() ?: UsageMetadata(0, 0, 0), ) internal fun CountTokensResponse.toPublic() = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Candidate.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Candidate.kt index 1f63a734..243fb6b3 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Candidate.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Candidate.kt @@ -44,7 +44,7 @@ class CitationMetadata( val startIndex: Int = 0, val endIndex: Int, val uri: String, - val license: String? = null, + val license: String, ) /** The reason for content finishing. */ diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Content.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Content.kt index 885df96e..e39ee18e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Content.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Content.kt @@ -25,12 +25,11 @@ import android.graphics.Bitmap * * @see content */ -class Content @JvmOverloads constructor(val role: String? = "user", val parts: List) { +class Content @JvmOverloads constructor(val role: String = "user", val parts: List) { class Builder { - var role: String? = "user" - - var parts: MutableList = arrayListOf() + var role: String = "user" + val parts: MutableList = arrayListOf() @JvmName("addPart") fun part(data: T) = apply { parts.add(data) } @@ -59,7 +58,7 @@ class Content @JvmOverloads constructor(val role: String? = "user", val parts: L * ) * ``` */ -fun content(role: String? = "user", init: Content.Builder.() -> Unit): Content { +fun content(role: String = "user", init: Content.Builder.() -> Unit): Content { val builder = Content.Builder() builder.role = role builder.init() diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 88247f3d..f3c5a2bb 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -52,11 +52,11 @@ class FunctionDeclaration( class Schema( val name: String, val description: String, - val format: String? = null, - val nullable: Boolean? = null, - val enum: List? = null, - val properties: Map>? = null, - val required: List? = null, + val format: String = "", + val nullable: Boolean = false, + val enum: List = emptyList(), + val properties: Map> = emptyMap(), + val required: List = emptyList(), val items: Schema? = null, val type: FunctionType, ) { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt index cb9ccbe1..25a9ba4f 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt @@ -16,4 +16,4 @@ package com.google.ai.client.generativeai.type -class FunctionParameter(val name: String, val description: String, val type: FunctionType) {} +class FunctionParameter(val name: String, val description: String, val type: FunctionType) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt index 79c2ef11..e27c5d40 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt @@ -31,24 +31,24 @@ package com.google.ai.client.generativeai.type */ class GenerationConfig private constructor( - val temperature: Float?, - val topK: Int?, - val topP: Float?, - val candidateCount: Int?, - val maxOutputTokens: Int?, - val stopSequences: List?, - val responseMimeType: String?, + val temperature: Float, + val topK: Int, + val topP: Float, + val candidateCount: Int, + val maxOutputTokens: Int, + val stopSequences: List = emptyList(), + val responseMimeType: String, val responseSchema: Schema<*>?, ) { class Builder { - @JvmField var temperature: Float? = null - @JvmField var topK: Int? = null - @JvmField var topP: Float? = null - @JvmField var candidateCount: Int? = null - @JvmField var maxOutputTokens: Int? = null - @JvmField var stopSequences: List? = null - @JvmField var responseMimeType: String? = null + @JvmField var temperature: Float = 0f + @JvmField var topK: Int = 0 + @JvmField var topP: Float = 0f + @JvmField var candidateCount: Int = 1 + @JvmField var maxOutputTokens: Int = 0 + @JvmField var stopSequences: List = emptyList() + @JvmField var responseMimeType: String = "" @JvmField var responseSchema: Schema<*>? = null fun build() = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/HarmCategory.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/HarmCategory.kt index d4e9781f..78411903 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/HarmCategory.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/HarmCategory.kt @@ -21,6 +21,9 @@ enum class HarmCategory { /** A new and not yet supported value. */ UNKNOWN, + /** A HarmCategory was not specified. */ + UNSPECIFIED, + /** Harassment content. */ HARASSMENT, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/PromptFeedback.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/PromptFeedback.kt index d03026a7..bd171337 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/PromptFeedback.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/PromptFeedback.kt @@ -22,7 +22,10 @@ package com.google.ai.client.generativeai.type * @param blockReason The reason that content was blocked, if at all. * @param safetyRatings A list of relevant [SafetyRating]s. */ -class PromptFeedback(val blockReason: BlockReason?, val safetyRatings: List) +class PromptFeedback( + val blockReason: BlockReason? = null, + val safetyRatings: List = emptyList(), +) /** Describes why content was blocked. */ enum class BlockReason { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/SafetySetting.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/SafetySetting.kt index 8f4b62c9..9a4d1026 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/SafetySetting.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/SafetySetting.kt @@ -22,4 +22,4 @@ package com.google.ai.client.generativeai.type * @param harmCategory The relevant [HarmCategory]. * @param threshold The threshold form harm allowable. */ -class SafetySetting(val harmCategory: HarmCategory, val threshold: BlockThreshold) {} +class SafetySetting(val harmCategory: HarmCategory, val threshold: BlockThreshold) diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index 904130c0..0e743cdf 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -14,6 +14,8 @@ * limitations under the License. */ +@file:OptIn(ExperimentalSerializationApi::class) + package com.google.ai.client.generativeai import com.google.ai.client.generativeai.common.APIController @@ -54,6 +56,7 @@ import io.mockk.mockk import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.flow import kotlinx.coroutines.runBlocking +import kotlinx.serialization.ExperimentalSerializationApi import org.json.JSONObject import org.junit.Test @@ -69,7 +72,10 @@ internal class GenerativeModelTests { mockApiController.generateContent( GenerateContentRequest_Common( "gemini-pro-1.0", - contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))), + contents = + listOf( + Content_Common(role = "user", parts = listOf(TextPart_Common("Why's the sky blue?"))) + ), ) ) } returns @@ -78,7 +84,8 @@ internal class GenerativeModelTests { Candidate_Common( content = Content_Common( - parts = listOf(TextPart_Common("I'm still learning how to answer this question")) + role = "user", + parts = listOf(TextPart_Common("I'm still learning how to answer this question")), ), finishReason = null, safetyRatings = listOf(), @@ -103,7 +110,7 @@ internal class GenerativeModelTests { startIndex = 0, endIndex = 100, uri = "http://www.example.com", - license = null, + license = "", ) ), finishReason = null, @@ -139,8 +146,11 @@ internal class GenerativeModelTests { coEvery { mockApiController.generateContent( GenerateContentRequest_Common( - "gemini-pro-1.0", - contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))), + model = "gemini-pro-1.0", + contents = + listOf( + Content_Common(role = "user", parts = listOf(TextPart_Common("Why's the sky blue?"))) + ), ) ) } throws InvalidAPIKeyException_Common("exception message") @@ -155,7 +165,10 @@ internal class GenerativeModelTests { mockApiController.generateContentStream( GenerateContentRequest_Common( "gemini-pro-1.0", - contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))), + contents = + listOf( + Content_Common(role = "user", parts = listOf(TextPart_Common("Why's the sky blue?"))) + ), ) ) } returns flow { throw UnsupportedUserLocationException_Common() }