Skip to content

Commit

Permalink
feat(*): handle content filters for both APIs (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac authored Aug 14, 2023
1 parent f51daf7 commit 3d0bd46
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 107 deletions.
14 changes: 7 additions & 7 deletions firestore-palm-chatbot/functions/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion firestore-palm-chatbot/functions/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
},
"main": "lib/index.js",
"dependencies": {
"@google-ai/generativelanguage": "^0.1.0",
"@google-ai/generativelanguage": "^0.2.0",
"@google-cloud/aiplatform": "^2.17.0",
"firebase-admin": "^11.6.0",
"firebase-functions": "^4.3.0",
Expand Down
6 changes: 6 additions & 0 deletions firestore-palm-chatbot/functions/src/discussion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ export class Discussion {

const [result] = await this.generativeClient.generateMessage(request);

if (result.filters && result.filters.length) {
throw new Error(
'Chat prompt or response filtered by the PaLM API content filter.'
);
}

if (!result.candidates || !result.candidates.length) {
throw new Error('No candidates returned from server.');
}
Expand Down
107 changes: 63 additions & 44 deletions firestore-palm-gen-text/functions/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,6 @@ export type TextGeneratorRequestOptions = Omit<
'prompt' | 'model'
>;

type SafetyAttributes = {
blocked?: boolean;
scores?: number[];
categories?: string[];
};

export type TextGeneratorResponse = {
candidates: string[];
safetyAttributes?: SafetyAttributes;
};

export class TextGenerator {
private generativeClient?: TextServiceClient;
private vertexClient?: v1.PredictionServiceClient;
Expand Down Expand Up @@ -116,14 +105,7 @@ export class TextGenerator {
const request = this.createVertexRequest(promptText, options);
const [result] = await this.vertexClient.predict(request);

const {content, safetyAttributes} =
this.extractVertexCandidateResponse(result);

if (!content) {
return {candidates: [], safetyAttributes};
}

return {candidates: [content], safetyAttributes};
return this.extractVertexCandidateResponse(result);
}

if (!this.generativeClient) {
Expand All @@ -134,11 +116,7 @@ export class TextGenerator {

const [result] = await this.generativeClient.generateText(request);

const candidates = this.extractGenerativeCandidationResponse(result);

return {
candidates,
};
return this.extractGenerativeCandidationResponse(result);
}

private createGenerativeRequest(
Expand Down Expand Up @@ -166,16 +144,7 @@ export class TextGenerator {
if (!result.candidates || !result.candidates.length) {
throw new Error('No candidates returned from server.');
}

//TODO: do we need to filter out empty strings? This seems to be a type issue with the API, why are they optional?
const candidates = result.candidates
.map(candidate => candidate.output)
.filter(output => !!output) as string[];

if (!candidates.length) {
throw new Error('No candidates returned from server.');
}
return candidates;
return convertToTextGeneratorResponse(result as GenerativePrediction);
}

private createVertexRequest(
Expand Down Expand Up @@ -224,20 +193,70 @@ export class TextGenerator {

const predictionValue = result.predictions[0] as protobuf.common.IValue;

const prediction = helpers.fromValue(predictionValue);
const vertexPrediction = helpers.fromValue(predictionValue);

const {safetyAttributes, content} = prediction as {
safetyAttributes?: {
blocked: boolean;
categories: string[];
scores: number[];
};
content?: string;
};
return convertToTextGeneratorResponse(vertexPrediction as VertexPrediction);
}
}

type VertexPrediction = {
safetyAttributes?: {
blocked: boolean;
categories: string[];
scores: number[];
};
content?: string;
};

type GenerativePrediction = {
candidates: {output: string}[];
filters?: {reason: string}[];
safetyFeedback?: {
rating: Record<string, any>;
setting: Record<string, any>;
}[];
};

type TextGeneratorResponse = {
candidates: string[];
safetyMetadata?: {
blocked: boolean;
[key: string]: any;
};
};

function convertToTextGeneratorResponse(
prediction: VertexPrediction | GenerativePrediction
): TextGeneratorResponse {
// if it's generative language
if ('candidates' in prediction) {
const {candidates, filters, safetyFeedback} = prediction;
const blocked = !!filters && filters.length > 0;
const safetyMetadata = {
blocked,
safetyFeedback,
};
if (!candidates.length && !blocked) {
throw new Error('No candidates returned from the Generative API.');
}
return {
content,
candidates: candidates.map(candidate => candidate.output),
safetyMetadata,
};
} else {
// provider will be vertex
const {content, safetyAttributes} = prediction;
const blocked = !!safetyAttributes && !!safetyAttributes.blocked;
const safetyMetadata = {
blocked,
safetyAttributes,
};
if (!content && !blocked) {
throw new Error('No content returned from the Vertex PaLM API.');
}
return {
candidates: blocked ? [] : [content!],
safetyMetadata,
};
}
}
6 changes: 3 additions & 3 deletions firestore-palm-gen-text/functions/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ export const generateText = functions.firestore
'status.updateTime': FieldValue.serverTimestamp(),
};

if (result.safetyAttributes) {
metadata['safetyAttributes'] = result.safetyAttributes;
if (result.safetyMetadata) {
metadata['safetyMetadata'] = result.safetyMetadata;
}

if (result.safetyAttributes?.blocked) {
if (result.safetyMetadata?.blocked) {
return ref.update({
...metadata,
'status.state': 'ERRORED',
Expand Down
14 changes: 7 additions & 7 deletions firestore-palm-summarize-text/functions/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion firestore-palm-summarize-text/functions/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
},
"main": "lib/index.js",
"dependencies": {
"@google-ai/generativelanguage": "^0.1.0",
"@google-ai/generativelanguage": "^0.2.0",
"@google-cloud/aiplatform": "^2.17.0",
"firebase-admin": "^11.5.0",
"firebase-functions": "^4.2.0",
Expand Down
101 changes: 61 additions & 40 deletions firestore-palm-summarize-text/functions/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ export type TextGeneratorRequestOptions = Omit<
APIGenerateTextRequest,
'prompt' | 'model'
>;
export type TextGeneratorResponse = {
candidates: string[];
safetyAttributes?: {
blocked?: boolean;
scores?: number[];
categories?: string[];
};
};

type VertexPredictResponse =
protos.google.cloud.aiplatform.v1beta1.IPredictResponse;
Expand Down Expand Up @@ -107,21 +99,9 @@ export class TextGenerator {

const predictionValue = result.predictions[0] as protobuf.common.IValue;

const prediction = helpers.fromValue(predictionValue);
const vertexPrediction = helpers.fromValue(predictionValue);

const {safetyAttributes, content} = prediction as {
safetyAttributes?: {
blocked: boolean;
categories: string[];
scores: number[];
};
content?: string;
};

return {
content,
safetyAttributes,
};
return convertToTextGeneratorResponse(vertexPrediction as VertexPrediction);
}

async generate(
Expand Down Expand Up @@ -165,13 +145,7 @@ export class TextGenerator {

const [result] = await this.vertexClient.predict(request);

const {content, safetyAttributes} =
this.extractVertexCandidateResponse(result);

if (!content) {
return {candidates: [], safetyAttributes};
}
return {candidates: [content]};
return this.extractVertexCandidateResponse(result);
}

const request = {
Expand All @@ -188,21 +162,68 @@ export class TextGenerator {

const [result] = await this.generativeClient.generateText(request);

if (!result.candidates || !result.candidates.length) {
throw new Error('No candidates returned from server.');
}
return convertToTextGeneratorResponse(result as GenerativePrediction);
}
}

//TODO: do we need to filter out empty strings? This seems to be a type issue with the API, why are they optional?
const candidates = result.candidates
.map(candidate => candidate.output)
.filter(output => !!output) as string[];
type VertexPrediction = {
safetyAttributes?: {
blocked: boolean;
categories: string[];
scores: number[];
};
content?: string;
};

if (!candidates.length) {
throw new Error('No candidates returned from server.');
}
type GenerativePrediction = {
candidates: {output: string}[];
filters?: {reason: string}[];
safetyFeedback?: {
rating: Record<string, any>;
setting: Record<string, any>;
}[];
};

type TextGeneratorResponse = {
candidates: string[];
safetyMetadata?: {
blocked: boolean;
[key: string]: any;
};
};

function convertToTextGeneratorResponse(
prediction: VertexPrediction | GenerativePrediction
): TextGeneratorResponse {
// if it's generative language
if ('candidates' in prediction) {
const {candidates, filters, safetyFeedback} = prediction;
const blocked = !!filters && filters.length > 0;
const safetyMetadata = {
blocked,
safetyFeedback,
};
if (!candidates.length && !blocked) {
throw new Error('No candidates returned from the Generative API.');
}
return {
candidates: candidates.map(candidate => candidate.output),
safetyMetadata,
};
} else {
// provider will be vertex
const {content, safetyAttributes} = prediction;
const blocked = !!safetyAttributes && !!safetyAttributes.blocked;
const safetyMetadata = {
blocked,
safetyAttributes,
};
if (!content && !blocked) {
throw new Error('No content returned from the Vertex PaLM API.');
}
return {
candidates,
candidates: blocked ? [] : [content!],
safetyMetadata,
};
}
}
Loading

0 comments on commit 3d0bd46

Please sign in to comment.