diff --git a/.changeset/pretty-rabbits-smell.md b/.changeset/pretty-rabbits-smell.md
new file mode 100644
index 000000000000..f2e54085e3ec
--- /dev/null
+++ b/.changeset/pretty-rabbits-smell.md
@@ -0,0 +1,6 @@
+---
+'@ai-sdk/google-vertex': patch
+'@ai-sdk/google': patch
+---
+
+feat (provider/google): Include safety ratings in the response provider metadata.
diff --git a/content/providers/01-ai-sdk-providers/11-google-vertex.mdx b/content/providers/01-ai-sdk-providers/11-google-vertex.mdx
index 91a7a33f6ec7..499c2084fd46 100644
--- a/content/providers/01-ai-sdk-providers/11-google-vertex.mdx
+++ b/content/providers/01-ai-sdk-providers/11-google-vertex.mdx
@@ -296,7 +296,7 @@ The following optional settings are available for Google Vertex models:
 You can use Google Vertex language models to generate text with the
 `generateText` function:
 
-```ts highlight="1,5"
+```ts highlight="1,4"
 import { vertex } from '@ai-sdk/google-vertex';
 import { generateText } from 'ai';
 
@@ -352,7 +352,7 @@ With [search grounding](https://cloud.google.com/vertex-ai/generative-ai/docs/gr
 the model has access to the latest information using Google search.
 Search grounding can be used to provide answers around current events:
 
-```ts highlight="6,14-17"
+```ts highlight="7,14-20"
 import { vertex } from '@ai-sdk/google-vertex';
 import { GoogleGenerativeAIProviderMetadata } from '@ai-sdk/google';
 import { generateText } from 'ai';
 
@@ -372,6 +372,7 @@ const metadata = experimental_providerMetadata?.google as
   | GoogleGenerativeAIProviderMetadata
   | undefined;
 const groundingMetadata = metadata?.groundingMetadata;
+const safetyRatings = metadata?.safetyRatings;
 ```
 
 The grounding metadata includes detailed information about how search results were used to ground the model's response. Here are the available fields:
 
@@ -396,7 +397,7 @@ The grounding metadata includes detailed information about how search results we
 - **`groundingChunkIndices`**: References to supporting search result chunks
 - **`confidenceScores`**: Confidence scores (0-1) for each supporting chunk
 
-Example response:
+Example response excerpt:
 
 ```json
 {
@@ -420,6 +421,46 @@
 }
 ```
 
+The safety ratings provide insight into the safety of the model's response in each harm category. See the [Google Vertex AI documentation on configuring safety filters](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters).
+
+Example response excerpt:
+
+```json
+{
+  "safetyRatings": [
+    {
+      "category": "HARM_CATEGORY_HATE_SPEECH",
+      "probability": "NEGLIGIBLE",
+      "probabilityScore": 0.11027937,
+      "severity": "HARM_SEVERITY_LOW",
+      "severityScore": 0.28487435
+    },
+    {
+      "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+      "probability": "HIGH",
+      "blocked": true,
+      "probabilityScore": 0.95422274,
+      "severity": "HARM_SEVERITY_MEDIUM",
+      "severityScore": 0.43398145
+    },
+    {
+      "category": "HARM_CATEGORY_HARASSMENT",
+      "probability": "NEGLIGIBLE",
+      "probabilityScore": 0.11085559,
+      "severity": "HARM_SEVERITY_NEGLIGIBLE",
+      "severityScore": 0.19027223
+    },
+    {
+      "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+      "probability": "NEGLIGIBLE",
+      "probabilityScore": 0.22901751,
+      "severity": "HARM_SEVERITY_NEGLIGIBLE",
+      "severityScore": 0.09089675
+    }
+  ]
+}
+```
+
 For more details, see the [Google Vertex AI documentation on grounding with Google Search](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/ground-gemini#ground-to-search).
 
 ### Troubleshooting
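Reviewer note: to make the new docs field concrete, here is a minimal sketch of how an application could consume `safetyRatings` alongside the grounding example above. The model id, prompt, and blocked-category handling are illustrative, not prescribed by the SDK.

```ts
import { vertex } from '@ai-sdk/google-vertex';
import { GoogleGenerativeAIProviderMetadata } from '@ai-sdk/google';
import { generateText } from 'ai';

const { text, experimental_providerMetadata } = await generateText({
  model: vertex('gemini-1.5-pro', { useSearchGrounding: true }),
  prompt: 'List the top 5 San Francisco news from the past week.',
});

const metadata = experimental_providerMetadata?.google as
  | GoogleGenerativeAIProviderMetadata
  | undefined;

// safetyRatings is null when the provider omits it, so default to [].
const blockedCategories = (metadata?.safetyRatings ?? [])
  .filter(rating => rating.blocked)
  .map(rating => rating.category);

if (blockedCategories.length > 0) {
  console.warn('Response flagged for categories:', blockedCategories);
} else {
  console.log(text);
}
```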
diff --git a/examples/ai-core/src/e2e/google-vertex.test.ts b/examples/ai-core/src/e2e/google-vertex.test.ts
index 6f7a17396472..2ba27c6c3375 100644
--- a/examples/ai-core/src/e2e/google-vertex.test.ts
+++ b/examples/ai-core/src/e2e/google-vertex.test.ts
@@ -116,6 +116,20 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
       expect(Array.isArray(support?.groundingChunkIndices)).toBe(true);
       expect(Array.isArray(support?.confidenceScores)).toBe(true);
 
+      // Verify safety ratings
+      const safetyRatings = metadata?.safetyRatings;
+      expect(Array.isArray(safetyRatings)).toBe(true);
+      expect(safetyRatings?.length).toBeGreaterThan(0);
+
+      // Verify each safety rating has required properties
+      safetyRatings?.forEach(rating => {
+        expect(rating.category).toBeDefined();
+        expect(rating.probability).toBeDefined();
+        expect(typeof rating.probabilityScore).toBe('number');
+        expect(rating.severity).toBeDefined();
+        expect(typeof rating.severityScore).toBe('number');
+      });
+
       // Basic response checks
       expect(result.text).toBeTruthy();
       expect(result.usage?.totalTokens).toBeGreaterThan(0);
@@ -261,6 +275,20 @@
       expect(Array.isArray(support?.groundingChunkIndices)).toBe(true);
       expect(Array.isArray(support?.confidenceScores)).toBe(true);
 
+      // Verify safety ratings
+      const safetyRatings = metadata?.safetyRatings;
+      expect(Array.isArray(safetyRatings)).toBe(true);
+      expect(safetyRatings?.length).toBeGreaterThan(0);
+
+      // Verify each safety rating has required properties
+      safetyRatings?.forEach(rating => {
+        expect(rating.category).toBeDefined();
+        expect(rating.probability).toBeDefined();
+        expect(typeof rating.probabilityScore).toBe('number');
+        expect(rating.severity).toBeDefined();
+        expect(typeof rating.severityScore).toBe('number');
+      });
+
       // Basic response checks
       expect(chunks.join('')).toBeTruthy();
       expect((await result.usage)?.totalTokens).toBeGreaterThan(0);
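Aside: the streaming and non-streaming e2e variants above repeat the same five assertions line for line. A shared helper along these lines (hypothetical, not part of this diff; it assumes the repo's vitest `expect`) would keep the two blocks in sync:

```ts
import { expect } from 'vitest';

// Hypothetical helper, not part of this diff: both e2e variants above run
// exactly these checks, so a shared function would remove the duplication.
function expectValidSafetyRatings(safetyRatings?: unknown[]) {
  expect(Array.isArray(safetyRatings)).toBe(true);
  expect(safetyRatings?.length).toBeGreaterThan(0);

  safetyRatings?.forEach(value => {
    const rating = value as Record<string, unknown>;
    expect(rating.category).toBeDefined();
    expect(rating.probability).toBeDefined();
    expect(typeof rating.probabilityScore).toBe('number');
    expect(rating.severity).toBeDefined();
    expect(typeof rating.severityScore).toBe('number');
  });
}
```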
diff --git a/packages/google/src/google-generative-ai-language-model.test.ts b/packages/google/src/google-generative-ai-language-model.test.ts
index 7857d3fb3d10..fbe9794ed4e5 100644
--- a/packages/google/src/google-generative-ai-language-model.test.ts
+++ b/packages/google/src/google-generative-ai-language-model.test.ts
@@ -632,6 +632,57 @@ describe('doGenerate', () => {
       }),
     );
   });
+
+  it(
+    'should expose safety ratings in provider metadata',
+    withTestServer(
+      {
+        url: 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent',
+        type: 'json-value',
+        content: {
+          candidates: [
+            {
+              content: {
+                parts: [{ text: 'test response' }],
+                role: 'model',
+              },
+              finishReason: 'STOP',
+              index: 0,
+              safetyRatings: [
+                {
+                  category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
+                  probability: 'NEGLIGIBLE',
+                  probabilityScore: 0.1,
+                  severity: 'LOW',
+                  severityScore: 0.2,
+                  blocked: false,
+                },
+              ],
+            },
+          ],
+          promptFeedback: { safetyRatings: SAFETY_RATINGS },
+        },
+      },
+      async () => {
+        const { providerMetadata } = await model.doGenerate({
+          inputFormat: 'prompt',
+          mode: { type: 'regular' },
+          prompt: TEST_PROMPT,
+        });
+
+        expect(providerMetadata?.google.safetyRatings).toStrictEqual([
+          {
+            category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
+            probability: 'NEGLIGIBLE',
+            probabilityScore: 0.1,
+            severity: 'LOW',
+            severityScore: 0.2,
+            blocked: false,
+          },
+        ]);
+      },
+    ),
+  );
 });
 
 describe('doStream', () => {
@@ -678,6 +729,24 @@ describe('doStream', () => {
         providerMetadata: {
           google: {
             groundingMetadata: null,
+            safetyRatings: [
+              {
+                category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+                probability: 'NEGLIGIBLE',
+              },
+              {
+                category: 'HARM_CATEGORY_HATE_SPEECH',
+                probability: 'NEGLIGIBLE',
+              },
+              {
+                category: 'HARM_CATEGORY_HARASSMENT',
+                probability: 'NEGLIGIBLE',
+              },
+              {
+                category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
+                probability: 'NEGLIGIBLE',
+              },
+            ],
           },
         },
         usage: { promptTokens: 294, completionTokens: 233 },
@@ -834,6 +903,24 @@ describe('doStream', () => {
         providerMetadata: {
           google: {
             groundingMetadata: null,
+            safetyRatings: [
+              {
+                category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+                probability: 'NEGLIGIBLE',
+              },
+              {
+                category: 'HARM_CATEGORY_HATE_SPEECH',
+                probability: 'NEGLIGIBLE',
+              },
+              {
+                category: 'HARM_CATEGORY_HARASSMENT',
+                probability: 'NEGLIGIBLE',
+              },
+              {
+                category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
+                probability: 'NEGLIGIBLE',
+              },
+            ],
           },
         },
         usage: { promptTokens: 294, completionTokens: 233 },
@@ -842,4 +929,44 @@
       },
     ),
   );
+
+  it(
+    'should expose safety ratings in provider metadata on finish',
+    withTestServer(
+      {
+        url: 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent',
+        type: 'stream-values',
+        content: [
+          `data: {"candidates": [{"content": {"parts": [{"text": "test"}],"role": "model"},` +
+            `"finishReason": "STOP","index": 0,"safetyRatings": [` +
+            `{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE",` +
+            `"probabilityScore": 0.1,"severity": "LOW","severityScore": 0.2,"blocked": false}]}]}\n\n`,
+        ],
+      },
+      async () => {
+        const { stream } = await model.doStream({
+          inputFormat: 'prompt',
+          mode: { type: 'regular' },
+          prompt: TEST_PROMPT,
+        });
+
+        const events = await convertReadableStreamToArray(stream);
+        const finishEvent = events.find(event => event.type === 'finish');
+
+        expect(
+          finishEvent?.type === 'finish' &&
+            finishEvent.providerMetadata?.google.safetyRatings,
+        ).toStrictEqual([
+          {
+            category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
+            probability: 'NEGLIGIBLE',
+            probabilityScore: 0.1,
+            severity: 'LOW',
+            severityScore: 0.2,
+            blocked: false,
+          },
+        ]);
+      },
+    ),
+  );
 });
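Both test suites pin down the same pass-through behavior: the first candidate's `safetyRatings` array is copied into provider metadata verbatim, and absence is normalized to `null`. A condensed, self-contained sketch of that mapping (types simplified here; the provider change below uses the zod-inferred candidate type):

```ts
// Simplified stand-ins for the zod-inferred types added below.
type SafetyRating = {
  category: string;
  probability: string;
  probabilityScore?: number | null;
  severity?: string | null;
  severityScore?: number | null;
  blocked?: boolean | null;
};

type Candidate = { safetyRatings?: SafetyRating[] | null };

// `?? null` folds both "field missing" and "explicitly null" into a single
// absent value, mirroring the existing `groundingMetadata ?? null` handling.
function toGoogleProviderMetadata(candidate: Candidate) {
  return {
    google: {
      safetyRatings: candidate.safetyRatings ?? null,
    },
  };
}

// toGoogleProviderMetadata({}) => { google: { safetyRatings: null } }
```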
diff --git a/packages/google/src/google-generative-ai-language-model.ts b/packages/google/src/google-generative-ai-language-model.ts
index a5678f8008cc..c67e3c786772 100644
--- a/packages/google/src/google-generative-ai-language-model.ts
+++ b/packages/google/src/google-generative-ai-language-model.ts
@@ -245,6 +245,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
       providerMetadata: {
         google: {
           groundingMetadata: candidate.groundingMetadata ?? null,
+          safetyRatings: candidate.safetyRatings ?? null,
         },
       },
       request: { body },
@@ -326,6 +327,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
             providerMetadata = {
               google: {
                 groundingMetadata: candidate.groundingMetadata ?? null,
+                safetyRatings: candidate.safetyRatings ?? null,
               },
             };
           }
@@ -473,11 +475,22 @@ export const groundingMetadataSchema = z.object({
     .nullish(),
 });
 
+// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters
+export const safetyRatingSchema = z.object({
+  category: z.string(),
+  probability: z.string(),
+  probabilityScore: z.number().nullish(),
+  severity: z.string().nullish(),
+  severityScore: z.number().nullish(),
+  blocked: z.boolean().nullish(),
+});
+
 const responseSchema = z.object({
   candidates: z.array(
     z.object({
       content: contentSchema.nullish(),
       finishReason: z.string().nullish(),
+      safetyRatings: z.array(safetyRatingSchema).nullish(),
       groundingMetadata: groundingMetadataSchema.nullish(),
     }),
   ),
@@ -498,6 +511,7 @@ const chunkSchema = z.object({
     z.object({
       content: contentSchema.nullish(),
       finishReason: z.string().nullish(),
+      safetyRatings: z.array(safetyRatingSchema).nullish(),
       groundingMetadata: groundingMetadataSchema.nullish(),
     }),
   )
diff --git a/packages/google/src/google-generative-ai-prompt.ts b/packages/google/src/google-generative-ai-prompt.ts
index fd6fbab20b5b..7fc4c145872e 100644
--- a/packages/google/src/google-generative-ai-prompt.ts
+++ b/packages/google/src/google-generative-ai-prompt.ts
@@ -1,4 +1,7 @@
-import { groundingMetadataSchema } from './google-generative-ai-language-model';
+import {
+  groundingMetadataSchema,
+  safetyRatingSchema,
+} from './google-generative-ai-language-model';
 import { z } from 'zod';
 
 export type GoogleGenerativeAIPrompt = {
@@ -46,6 +49,9 @@ export type GoogleGenerativeAIGroundingMetadata = z.infer<
   typeof groundingMetadataSchema
 >;
 
+export type GoogleGenerativeAISafetyRating = z.infer<typeof safetyRatingSchema>;
+
 export interface GoogleGenerativeAIProviderMetadata {
   groundingMetadata: GoogleGenerativeAIGroundingMetadata | null;
+  safetyRatings: GoogleGenerativeAISafetyRating[] | null;
 }
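A closing note on the schema design: because every score field on `safetyRatingSchema` is `.nullish()`, the minimal ratings that appear mid-stream (only `category` and `probability`, as in the streaming fixtures above) validate just as well as fully populated ones. A standalone check, assuming only `zod` as a dependency:

```ts
import { z } from 'zod';

// Copy of the safetyRatingSchema added in this PR.
const safetyRatingSchema = z.object({
  category: z.string(),
  probability: z.string(),
  probabilityScore: z.number().nullish(),
  severity: z.string().nullish(),
  severityScore: z.number().nullish(),
  blocked: z.boolean().nullish(),
});

// Minimal streaming-style rating: only the required string fields.
console.log(
  safetyRatingSchema.safeParse({
    category: 'HARM_CATEGORY_HARASSMENT',
    probability: 'NEGLIGIBLE',
  }).success, // -> true
);

// Fully populated rating from a non-streaming response also passes.
console.log(
  safetyRatingSchema.safeParse({
    category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
    probability: 'HIGH',
    probabilityScore: 0.95,
    severity: 'HARM_SEVERITY_MEDIUM',
    severityScore: 0.43,
    blocked: true,
  }).success, // -> true
);
```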