feat (provider/google): Include safety ratings response detail. (#4051)
shaper authored Dec 12, 2024
1 parent 6cc182b commit e07439a
Showing 6 changed files with 226 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .changeset/pretty-rabbits-smell.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@ai-sdk/google-vertex': patch
'@ai-sdk/google': patch
---

feat (provider/google): Include safety ratings response detail.
47 changes: 44 additions & 3 deletions content/providers/01-ai-sdk-providers/11-google-vertex.mdx
@@ -296,7 +296,7 @@ The following optional settings are available for Google Vertex models:

You can use Google Vertex language models to generate text with the `generateText` function:

-```ts highlight="1,5"
+```ts highlight="1,4"
import { vertex } from '@ai-sdk/google-vertex';
import { generateText } from 'ai';

@@ -352,7 +352,7 @@ With [search grounding](https://cloud.google.com/vertex-ai/generative-ai/docs/gr
the model has access to the latest information using Google search.
Search grounding can be used to provide answers around current events:

-```ts highlight="6,14-17"
+```ts highlight="7,14-20"
import { vertex } from '@ai-sdk/google-vertex';
import { GoogleGenerativeAIProviderMetadata } from '@ai-sdk/google';
import { generateText } from 'ai';
@@ -372,6 +372,7 @@ const metadata = experimental_providerMetadata?.google as
| GoogleGenerativeAIProviderMetadata
| undefined;
const groundingMetadata = metadata?.groundingMetadata;
const safetyRatings = metadata?.safetyRatings;
```
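As an illustrative sketch (not part of this commit), the `safetyRatings` array read above can be scanned for blocked categories with a small dependency-free helper. The `SafetyRating` shape here is an assumption mirroring the fields shown in the example responses in this change, with all score fields optional:

```typescript
// Assumed shape of one safety rating as surfaced in provider metadata;
// only `category` and `probability` are treated as required.
interface SafetyRating {
  category: string;
  probability: string;
  probabilityScore?: number;
  severity?: string;
  severityScore?: number;
  blocked?: boolean;
}

// Returns the categories the model flagged as blocked, if any.
function blockedCategories(
  ratings: SafetyRating[] | null | undefined,
): string[] {
  return (ratings ?? [])
    .filter(rating => rating.blocked === true)
    .map(rating => rating.category);
}

// Example data shaped like the response excerpts in this commit:
const example: SafetyRating[] = [
  { category: 'HARM_CATEGORY_HATE_SPEECH', probability: 'NEGLIGIBLE' },
  {
    category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
    probability: 'HIGH',
    blocked: true,
  },
];
```

An empty result from `blockedCategories` is a cheap precondition before showing the response to a user.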

The grounding metadata includes detailed information about how search results were used to ground the model's response. Here are the available fields:
@@ -396,7 +397,7 @@ The grounding metadata includes detailed information about how search results we
- **`groundingChunkIndices`**: References to supporting search result chunks
- **`confidenceScores`**: Confidence scores (0-1) for each supporting chunk

-Example response:
+Example response excerpt:

```json
{
@@ -420,6 +421,46 @@ Example response:
}
```
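As a hedged sketch (not part of the commit), the `groundingChunkIndices` and `confidenceScores` fields listed above can be aggregated into a per-chunk confidence map; the `GroundingSupport` shape below is an assumption based on that field list, with both arrays optional:

```typescript
// Assumed minimal shape of one grounding support entry: chunk indices
// paired positionally with 0-1 confidence scores.
interface GroundingSupport {
  groundingChunkIndices?: number[];
  confidenceScores?: number[];
}

// Collects the highest confidence score observed for each chunk index
// across all support entries.
function maxConfidencePerChunk(
  supports: GroundingSupport[],
): Map<number, number> {
  const best = new Map<number, number>();
  for (const support of supports) {
    const indices = support.groundingChunkIndices ?? [];
    const scores = support.confidenceScores ?? [];
    indices.forEach((chunkIndex, i) => {
      const score = scores[i] ?? 0;
      best.set(chunkIndex, Math.max(best.get(chunkIndex) ?? 0, score));
    });
  }
  return best;
}
```

A map like this makes it easy to rank search-result chunks by how strongly they supported the response.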

The safety ratings provide insight into how the model's response was scored against the configured safety categories. See [Google Vertex AI documentation on configuring safety filters](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters).

Example response excerpt:

```json
{
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.11027937,
"severity": "HARM_SEVERITY_LOW",
"severityScore": 0.28487435
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "HIGH",
"blocked": true,
"probabilityScore": 0.95422274,
"severity": "HARM_SEVERITY_MEDIUM",
"severityScore": 0.43398145
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.11085559,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.19027223
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.22901751,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.09089675
}
]
}
```
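A response like the excerpt above can be triaged with a small dependency-free filter. This is a sketch only; the threshold value is an arbitrary assumption, not something the SDK or API prescribes:

```typescript
// Safety rating fields as they appear in the example excerpt above;
// score fields are optional since the API may omit them.
interface SafetyRating {
  category: string;
  probability: string;
  probabilityScore?: number;
  severityScore?: number;
  blocked?: boolean;
}

// Returns ratings whose probability score exceeds the given threshold,
// sorted most-probable first.
function ratingsAbove(
  ratings: SafetyRating[],
  threshold: number,
): SafetyRating[] {
  return ratings
    .filter(r => (r.probabilityScore ?? 0) > threshold)
    .sort((a, b) => (b.probabilityScore ?? 0) - (a.probabilityScore ?? 0));
}
```

Applied to the excerpt above with a threshold of 0.5, only the `HARM_CATEGORY_DANGEROUS_CONTENT` entry would surface.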

For more details, see the [Google Vertex AI documentation on grounding with Google Search](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/ground-gemini#ground-to-search).

### Troubleshooting
28 changes: 28 additions & 0 deletions examples/ai-core/src/e2e/google-vertex.test.ts
@@ -116,6 +116,20 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
expect(Array.isArray(support?.groundingChunkIndices)).toBe(true);
expect(Array.isArray(support?.confidenceScores)).toBe(true);

// Verify safety ratings
const safetyRatings = metadata?.safetyRatings;
expect(Array.isArray(safetyRatings)).toBe(true);
expect(safetyRatings?.length).toBeGreaterThan(0);

// Verify each safety rating has required properties
safetyRatings?.forEach(rating => {
expect(rating.category).toBeDefined();
expect(rating.probability).toBeDefined();
expect(typeof rating.probabilityScore).toBe('number');
expect(rating.severity).toBeDefined();
expect(typeof rating.severityScore).toBe('number');
});

// Basic response checks
expect(result.text).toBeTruthy();
expect(result.usage?.totalTokens).toBeGreaterThan(0);
@@ -261,6 +275,20 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
expect(Array.isArray(support?.groundingChunkIndices)).toBe(true);
expect(Array.isArray(support?.confidenceScores)).toBe(true);

// Verify safety ratings
const safetyRatings = metadata?.safetyRatings;
expect(Array.isArray(safetyRatings)).toBe(true);
expect(safetyRatings?.length).toBeGreaterThan(0);

// Verify each safety rating has required properties
safetyRatings?.forEach(rating => {
expect(rating.category).toBeDefined();
expect(rating.probability).toBeDefined();
expect(typeof rating.probabilityScore).toBe('number');
expect(rating.severity).toBeDefined();
expect(typeof rating.severityScore).toBe('number');
});

// Basic response checks
expect(chunks.join('')).toBeTruthy();
expect((await result.usage)?.totalTokens).toBeGreaterThan(0);
127 changes: 127 additions & 0 deletions packages/google/src/google-generative-ai-language-model.test.ts
@@ -632,6 +632,57 @@ describe('doGenerate', () => {
}),
);
});

it(
'should expose safety ratings in provider metadata',
withTestServer(
{
url: 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent',
type: 'json-value',
content: {
candidates: [
{
content: {
parts: [{ text: 'test response' }],
role: 'model',
},
finishReason: 'STOP',
index: 0,
safetyRatings: [
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
probability: 'NEGLIGIBLE',
probabilityScore: 0.1,
severity: 'LOW',
severityScore: 0.2,
blocked: false,
},
],
},
],
promptFeedback: { safetyRatings: SAFETY_RATINGS },
},
},
async () => {
const { providerMetadata } = await model.doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

expect(providerMetadata?.google.safetyRatings).toStrictEqual([
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
probability: 'NEGLIGIBLE',
probabilityScore: 0.1,
severity: 'LOW',
severityScore: 0.2,
blocked: false,
},
]);
},
),
);
});

describe('doStream', () => {
@@ -678,6 +729,24 @@ describe('doStream', () => {
providerMetadata: {
google: {
groundingMetadata: null,
safetyRatings: [
{
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
probability: 'NEGLIGIBLE',
},
{
category: 'HARM_CATEGORY_HATE_SPEECH',
probability: 'NEGLIGIBLE',
},
{
category: 'HARM_CATEGORY_HARASSMENT',
probability: 'NEGLIGIBLE',
},
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
probability: 'NEGLIGIBLE',
},
],
},
},
usage: { promptTokens: 294, completionTokens: 233 },
@@ -834,6 +903,24 @@ describe('doStream', () => {
providerMetadata: {
google: {
groundingMetadata: null,
safetyRatings: [
{
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
probability: 'NEGLIGIBLE',
},
{
category: 'HARM_CATEGORY_HATE_SPEECH',
probability: 'NEGLIGIBLE',
},
{
category: 'HARM_CATEGORY_HARASSMENT',
probability: 'NEGLIGIBLE',
},
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
probability: 'NEGLIGIBLE',
},
],
},
},
usage: { promptTokens: 294, completionTokens: 233 },
@@ -842,4 +929,44 @@
},
),
);

it(
'should expose safety ratings in provider metadata on finish',
withTestServer(
{
url: 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent',
type: 'stream-values',
content: [
`data: {"candidates": [{"content": {"parts": [{"text": "test"}],"role": "model"},` +
`"finishReason": "STOP","index": 0,"safetyRatings": [` +
`{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE",` +
`"probabilityScore": 0.1,"severity": "LOW","severityScore": 0.2,"blocked": false}]}]}\n\n`,
],
},
async () => {
const { stream } = await model.doStream({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

const events = await convertReadableStreamToArray(stream);
const finishEvent = events.find(event => event.type === 'finish');

expect(
finishEvent?.type === 'finish' &&
finishEvent.providerMetadata?.google.safetyRatings,
).toStrictEqual([
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
probability: 'NEGLIGIBLE',
probabilityScore: 0.1,
severity: 'LOW',
severityScore: 0.2,
blocked: false,
},
]);
},
),
);
});
14 changes: 14 additions & 0 deletions packages/google/src/google-generative-ai-language-model.ts
@@ -245,6 +245,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
providerMetadata: {
google: {
groundingMetadata: candidate.groundingMetadata ?? null,
safetyRatings: candidate.safetyRatings ?? null,
},
},
request: { body },
@@ -326,6 +327,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
providerMetadata = {
google: {
groundingMetadata: candidate.groundingMetadata ?? null,
safetyRatings: candidate.safetyRatings ?? null,
},
};
}
@@ -473,11 +475,22 @@ export const groundingMetadataSchema = z.object({
.nullish(),
});

// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters
export const safetyRatingSchema = z.object({
category: z.string(),
probability: z.string(),
probabilityScore: z.number().nullish(),
severity: z.string().nullish(),
severityScore: z.number().nullish(),
blocked: z.boolean().nullish(),
});

const responseSchema = z.object({
candidates: z.array(
z.object({
content: contentSchema.nullish(),
finishReason: z.string().nullish(),
safetyRatings: z.array(safetyRatingSchema).nullish(),
groundingMetadata: groundingMetadataSchema.nullish(),
}),
),
@@ -498,6 +511,7 @@ const chunkSchema = z.object({
z.object({
content: contentSchema.nullish(),
finishReason: z.string().nullish(),
safetyRatings: z.array(safetyRatingSchema).nullish(),
groundingMetadata: groundingMetadataSchema.nullish(),
}),
)
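Outside the commit itself, the `safetyRatingSchema` added above (required string `category` and `probability`, nullish score and `blocked` fields) can be mirrored by a dependency-free runtime guard. This sketch only approximates zod's `.nullish()` semantics and is not how the SDK validates responses:

```typescript
// Plain-TypeScript type guard approximating safetyRatingSchema:
// `category` and `probability` must be strings; the remaining fields
// may be absent, null, or of their declared type.
function isSafetyRating(value: unknown): boolean {
  if (typeof value !== 'object' || value === null) return false;
  const v = value as Record<string, unknown>;
  const optionalNumber = (x: unknown) => x == null || typeof x === 'number';
  const optionalString = (x: unknown) => x == null || typeof x === 'string';
  const optionalBoolean = (x: unknown) => x == null || typeof x === 'boolean';
  return (
    typeof v.category === 'string' &&
    typeof v.probability === 'string' &&
    optionalNumber(v.probabilityScore) &&
    optionalString(v.severity) &&
    optionalNumber(v.severityScore) &&
    optionalBoolean(v.blocked)
  );
}
```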
8 changes: 7 additions & 1 deletion packages/google/src/google-generative-ai-prompt.ts
@@ -1,4 +1,7 @@
import { groundingMetadataSchema } from './google-generative-ai-language-model';
import {
groundingMetadataSchema,
safetyRatingSchema,
} from './google-generative-ai-language-model';
import { z } from 'zod';

export type GoogleGenerativeAIPrompt = {
@@ -46,6 +49,9 @@ export type GoogleGenerativeAIGroundingMetadata = z.infer<
typeof groundingMetadataSchema
>;

export type GoogleGenerativeAISafetyRating = z.infer<typeof safetyRatingSchema>;

export interface GoogleGenerativeAIProviderMetadata {
groundingMetadata: GoogleGenerativeAIGroundingMetadata | null;
safetyRatings: GoogleGenerativeAISafetyRating[] | null;
}
