diff --git a/firestore-palm-chatbot/functions/package-lock.json b/firestore-palm-chatbot/functions/package-lock.json index 4e13a3e6..4c2a6061 100644 --- a/firestore-palm-chatbot/functions/package-lock.json +++ b/firestore-palm-chatbot/functions/package-lock.json @@ -6,7 +6,7 @@ "": { "name": "firestore-palm-chatbot", "dependencies": { - "@google-ai/generativelanguage": "^0.1.0", + "@google-ai/generativelanguage": "^0.2.0", "@google-cloud/aiplatform": "^2.17.0", "firebase-admin": "^11.6.0", "firebase-functions": "^4.3.0", @@ -800,9 +800,9 @@ } }, "node_modules/@google-ai/generativelanguage": { - "version": "0.1.1", - "resolved": "https://registry.npmjs.org/@google-ai/generativelanguage/-/generativelanguage-0.1.1.tgz", - "integrity": "sha512-hwCLm/O9CdGURSbxreaecSsniyKgGxEdd5Uz1t9FMui9F6DqB9jTgzlzfbZVGjE1S3r5WZOpRNAKXlGS5vJmeg==", + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/@google-ai/generativelanguage/-/generativelanguage-0.2.1.tgz", + "integrity": "sha512-oqEQScnGO6UoEqdKMIGiRfLWNpc83RtLWcO/g/VH3+2PnqIwEqJThDAMCHmRZ9B3zUiiL2cd4FaHx3ZU93CXEA==", "dependencies": { "google-gax": "^3.5.8" }, @@ -8405,9 +8405,9 @@ } }, "@google-ai/generativelanguage": { - "version": "0.1.1", - "resolved": "https://registry.npmjs.org/@google-ai/generativelanguage/-/generativelanguage-0.1.1.tgz", - "integrity": "sha512-hwCLm/O9CdGURSbxreaecSsniyKgGxEdd5Uz1t9FMui9F6DqB9jTgzlzfbZVGjE1S3r5WZOpRNAKXlGS5vJmeg==", + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/@google-ai/generativelanguage/-/generativelanguage-0.2.1.tgz", + "integrity": "sha512-oqEQScnGO6UoEqdKMIGiRfLWNpc83RtLWcO/g/VH3+2PnqIwEqJThDAMCHmRZ9B3zUiiL2cd4FaHx3ZU93CXEA==", "requires": { "google-gax": "^3.5.8" } diff --git a/firestore-palm-chatbot/functions/package.json b/firestore-palm-chatbot/functions/package.json index c3269cba..d1294512 100644 --- a/firestore-palm-chatbot/functions/package.json +++ b/firestore-palm-chatbot/functions/package.json @@ -10,7 +10,7 @@ }, "main": "lib/index.js", "dependencies": { - "@google-ai/generativelanguage": "^0.1.0", + "@google-ai/generativelanguage": "^0.2.0", "@google-cloud/aiplatform": "^2.17.0", "firebase-admin": "^11.6.0", "firebase-functions": "^4.3.0", diff --git a/firestore-palm-chatbot/functions/src/discussion.ts b/firestore-palm-chatbot/functions/src/discussion.ts index d53bed36..8a8d31ae 100644 --- a/firestore-palm-chatbot/functions/src/discussion.ts +++ b/firestore-palm-chatbot/functions/src/discussion.ts @@ -189,6 +189,12 @@ export class Discussion { const [result] = await this.generativeClient.generateMessage(request); + if (result.filters && result.filters.length) { + throw new Error( + 'Chat prompt or response filtered by the PaLM API content filter.' + ); + } + if (!result.candidates || !result.candidates.length) { throw new Error('No candidates returned from server.'); } diff --git a/firestore-palm-gen-text/functions/src/generator.ts b/firestore-palm-gen-text/functions/src/generator.ts index 32ef2d31..eec1428f 100644 --- a/firestore-palm-gen-text/functions/src/generator.ts +++ b/firestore-palm-gen-text/functions/src/generator.ts @@ -41,17 +41,6 @@ export type TextGeneratorRequestOptions = Omit< 'prompt' | 'model' >; -type SafetyAttributes = { - blocked?: boolean; - scores?: number[]; - categories?: string[]; -}; - -export type TextGeneratorResponse = { - candidates: string[]; - safetyAttributes?: SafetyAttributes; -}; - export class TextGenerator { private generativeClient?: TextServiceClient; private vertexClient?: v1.PredictionServiceClient; @@ -116,14 +105,7 @@ export class TextGenerator { const request = this.createVertexRequest(promptText, options); const [result] = await this.vertexClient.predict(request); - const {content, safetyAttributes} = - this.extractVertexCandidateResponse(result); - - if (!content) { - return {candidates: [], safetyAttributes}; - } - - return {candidates: [content], safetyAttributes}; + return this.extractVertexCandidateResponse(result); } if (!this.generativeClient) { @@ -134,11 +116,7 @@ export class TextGenerator { const [result] = await this.generativeClient.generateText(request); - const candidates = this.extractGenerativeCandidationResponse(result); - - return { - candidates, - }; + return this.extractGenerativeCandidationResponse(result); } private createGenerativeRequest( @@ -166,16 +144,7 @@ export class TextGenerator { if (!result.candidates || !result.candidates.length) { throw new Error('No candidates returned from server.'); } - - //TODO: do we need to filter out empty strings? This seems to be a type issue with the API, why are they optional? - const candidates = result.candidates - .map(candidate => candidate.output) - .filter(output => !!output) as string[]; - - if (!candidates.length) { - throw new Error('No candidates returned from server.'); - } - return candidates; + return convertToTextGeneratorResponse(result as GenerativePrediction); } private createVertexRequest( @@ -224,20 +193,70 @@ export class TextGenerator { const predictionValue = result.predictions[0] as protobuf.common.IValue; - const prediction = helpers.fromValue(predictionValue); + const vertexPrediction = helpers.fromValue(predictionValue); - const {safetyAttributes, content} = prediction as { - safetyAttributes?: { - blocked: boolean; - categories: string[]; - scores: number[]; - }; - content?: string; - }; + return convertToTextGeneratorResponse(vertexPrediction as VertexPrediction); + } +} + +type VertexPrediction = { + safetyAttributes?: { + blocked: boolean; + categories: string[]; + scores: number[]; + }; + content?: string; +}; +type GenerativePrediction = { + candidates: {output: string}[]; + filters?: {reason: string}[]; + safetyFeedback?: { + rating: Record; + setting: Record; + }[]; +}; + +type TextGeneratorResponse = { + candidates: string[]; + safetyMetadata?: { + blocked: boolean; + [key: string]: any; + }; +}; + +function convertToTextGeneratorResponse( + prediction: VertexPrediction | GenerativePrediction +): TextGeneratorResponse { + // if it's generative language + if ('candidates' in prediction) { + const {candidates, filters, safetyFeedback} = prediction; + const blocked = !!filters && filters.length > 0; + const safetyMetadata = { + blocked, + safetyFeedback, + }; + if (!candidates.length && !blocked) { + throw new Error('No candidates returned from the Generative API.'); + } return { - content, + candidates: candidates.map(candidate => candidate.output), + safetyMetadata, + }; + } else { + // provider will be vertex + const {content, safetyAttributes} = prediction; + const blocked = !!safetyAttributes && !!safetyAttributes.blocked; + const safetyMetadata = { + blocked, safetyAttributes, }; + if (!content && !blocked) { + throw new Error('No content returned from the Vertex PaLM API.'); + } + return { + candidates: blocked ? [] : [content!], + safetyMetadata, + }; } } diff --git a/firestore-palm-gen-text/functions/src/index.ts b/firestore-palm-gen-text/functions/src/index.ts index 6561589f..2f2f5e1a 100644 --- a/firestore-palm-gen-text/functions/src/index.ts +++ b/firestore-palm-gen-text/functions/src/index.ts @@ -122,11 +122,11 @@ export const generateText = functions.firestore 'status.updateTime': FieldValue.serverTimestamp(), }; - if (result.safetyAttributes) { - metadata['safetyAttributes'] = result.safetyAttributes; + if (result.safetyMetadata) { + metadata['safetyMetadata'] = result.safetyMetadata; } - if (result.safetyAttributes?.blocked) { + if (result.safetyMetadata?.blocked) { return ref.update({ ...metadata, 'status.state': 'ERRORED', diff --git a/firestore-palm-summarize-text/functions/package-lock.json b/firestore-palm-summarize-text/functions/package-lock.json index c6216d01..30b30926 100644 --- a/firestore-palm-summarize-text/functions/package-lock.json +++ b/firestore-palm-summarize-text/functions/package-lock.json @@ -6,7 +6,7 @@ "": { "name": "firestore-palm-summarize-text", "dependencies": { - "@google-ai/generativelanguage": "^0.1.0", + "@google-ai/generativelanguage": "^0.2.0", "@google-cloud/aiplatform": "^2.17.0", "firebase-admin": "^11.5.0", "firebase-functions": "^4.2.0", @@ -800,9 +800,9 @@ } }, "node_modules/@google-ai/generativelanguage": { - "version": "0.1.1", - "resolved": "https://registry.npmjs.org/@google-ai/generativelanguage/-/generativelanguage-0.1.1.tgz", - "integrity": "sha512-hwCLm/O9CdGURSbxreaecSsniyKgGxEdd5Uz1t9FMui9F6DqB9jTgzlzfbZVGjE1S3r5WZOpRNAKXlGS5vJmeg==", + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/@google-ai/generativelanguage/-/generativelanguage-0.2.1.tgz", + "integrity": "sha512-oqEQScnGO6UoEqdKMIGiRfLWNpc83RtLWcO/g/VH3+2PnqIwEqJThDAMCHmRZ9B3zUiiL2cd4FaHx3ZU93CXEA==", "dependencies": { "google-gax": "^3.5.8" }, @@ -8456,9 +8456,9 @@ } }, "@google-ai/generativelanguage": { - "version": "0.1.1", - "resolved": "https://registry.npmjs.org/@google-ai/generativelanguage/-/generativelanguage-0.1.1.tgz", - "integrity": "sha512-hwCLm/O9CdGURSbxreaecSsniyKgGxEdd5Uz1t9FMui9F6DqB9jTgzlzfbZVGjE1S3r5WZOpRNAKXlGS5vJmeg==", + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/@google-ai/generativelanguage/-/generativelanguage-0.2.1.tgz", + "integrity": "sha512-oqEQScnGO6UoEqdKMIGiRfLWNpc83RtLWcO/g/VH3+2PnqIwEqJThDAMCHmRZ9B3zUiiL2cd4FaHx3ZU93CXEA==", "requires": { "google-gax": "^3.5.8" } diff --git a/firestore-palm-summarize-text/functions/package.json b/firestore-palm-summarize-text/functions/package.json index 8f5f8a00..d1c01cb3 100644 --- a/firestore-palm-summarize-text/functions/package.json +++ b/firestore-palm-summarize-text/functions/package.json @@ -10,7 +10,7 @@ }, "main": "lib/index.js", "dependencies": { - "@google-ai/generativelanguage": "^0.1.0", + "@google-ai/generativelanguage": "^0.2.0", "@google-cloud/aiplatform": "^2.17.0", "firebase-admin": "^11.5.0", "firebase-functions": "^4.2.0", diff --git a/firestore-palm-summarize-text/functions/src/generator.ts b/firestore-palm-summarize-text/functions/src/generator.ts index 19eda777..e6fcd86a 100644 --- a/firestore-palm-summarize-text/functions/src/generator.ts +++ b/firestore-palm-summarize-text/functions/src/generator.ts @@ -36,14 +36,6 @@ export type TextGeneratorRequestOptions = Omit< APIGenerateTextRequest, 'prompt' | 'model' >; -export type TextGeneratorResponse = { - candidates: string[]; - safetyAttributes?: { - blocked?: boolean; - scores?: number[]; - categories?: string[]; - }; -}; type VertexPredictResponse = protos.google.cloud.aiplatform.v1beta1.IPredictResponse; @@ -107,21 +99,9 @@ export class TextGenerator { const predictionValue = result.predictions[0] as protobuf.common.IValue; - const prediction = helpers.fromValue(predictionValue); + const vertexPrediction = helpers.fromValue(predictionValue); - const {safetyAttributes, content} = prediction as { - safetyAttributes?: { - blocked: boolean; - categories: string[]; - scores: number[]; - }; - content?: string; - }; - - return { - content, - safetyAttributes, - }; + return convertToTextGeneratorResponse(vertexPrediction as VertexPrediction); } async generate( @@ -165,13 +145,7 @@ export class TextGenerator { const [result] = await this.vertexClient.predict(request); - const {content, safetyAttributes} = - this.extractVertexCandidateResponse(result); - - if (!content) { - return {candidates: [], safetyAttributes}; - } - return {candidates: [content]}; + return this.extractVertexCandidateResponse(result); } const request = { @@ -188,21 +162,68 @@ export class TextGenerator { const [result] = await this.generativeClient.generateText(request); - if (!result.candidates || !result.candidates.length) { - throw new Error('No candidates returned from server.'); - } + return convertToTextGeneratorResponse(result as GenerativePrediction); + } +} - //TODO: do we need to filter out empty strings? This seems to be a type issue with the API, why are they optional? - const candidates = result.candidates - .map(candidate => candidate.output) - .filter(output => !!output) as string[]; +type VertexPrediction = { + safetyAttributes?: { + blocked: boolean; + categories: string[]; + scores: number[]; + }; + content?: string; +}; - if (!candidates.length) { - throw new Error('No candidates returned from server.'); - } +type GenerativePrediction = { + candidates: {output: string}[]; + filters?: {reason: string}[]; + safetyFeedback?: { + rating: Record; + setting: Record; + }[]; +}; + +type TextGeneratorResponse = { + candidates: string[]; + safetyMetadata?: { + blocked: boolean; + [key: string]: any; + }; +}; +function convertToTextGeneratorResponse( + prediction: VertexPrediction | GenerativePrediction +): TextGeneratorResponse { + // if it's generative language + if ('candidates' in prediction) { + const {candidates, filters, safetyFeedback} = prediction; + const blocked = !!filters && filters.length > 0; + const safetyMetadata = { + blocked, + safetyFeedback, + }; + if (!candidates.length && !blocked) { + throw new Error('No candidates returned from the Generative API.'); + } + return { + candidates: candidates.map(candidate => candidate.output), + safetyMetadata, + }; + } else { + // provider will be vertex + const {content, safetyAttributes} = prediction; + const blocked = !!safetyAttributes && !!safetyAttributes.blocked; + const safetyMetadata = { + blocked, + safetyAttributes, + }; + if (!content && !blocked) { + throw new Error('No content returned from the Vertex PaLM API.'); + } return { - candidates, + candidates: blocked ? [] : [content!], + safetyMetadata, }; } } diff --git a/firestore-palm-summarize-text/functions/src/index.ts b/firestore-palm-summarize-text/functions/src/index.ts index f15cf800..8437042e 100644 --- a/firestore-palm-summarize-text/functions/src/index.ts +++ b/firestore-palm-summarize-text/functions/src/index.ts @@ -79,16 +79,16 @@ export const generateSummary = functions.firestore 'status.updateTime': FieldValue.serverTimestamp(), }; - if (result.safetyAttributes) { - metadata['safetyAttributes'] = result.safetyAttributes; + if (result.safetyMetadata) { + metadata['safetyMetadata'] = result.safetyMetadata; } - if (result.safetyAttributes?.blocked) { + if (result.safetyMetadata?.blocked) { return ref.update({ ...metadata, 'status.state': 'ERRORED', 'status.error': - 'The text provided was blocked by the Vertex AI content filter.', + 'The prompt or summary was blocked by the PaLM content filter.', }); } return ref.update({