diff --git a/firestore-multimodal-genai/functions/__tests__/genkit/client.test.ts b/firestore-multimodal-genai/functions/__tests__/genkit/client.test.ts index 3d9df12b..aa65790e 100644 --- a/firestore-multimodal-genai/functions/__tests__/genkit/client.test.ts +++ b/firestore-multimodal-genai/functions/__tests__/genkit/client.test.ts @@ -131,8 +131,11 @@ describe('GenkitGenerativeClient', () => { 'gemini-1.5-flash', 'google-ai' ); + expect(modelReference === null).toBe(false); - expect(modelReference.name).toBe('googleai/gemini-1.5-flash'); + expect(modelReference).toHaveProperty('name'); + + expect(modelReference!.name).toBe('googleai/gemini-1.5-flash'); }); it('should call generate with correct options and return response', async () => { @@ -215,3 +218,100 @@ describe('GenkitGenerativeClient', () => { ); }); }); + +describe('GenkitGenerativeClient.shouldUseGenkitClient', () => { + const baseConfig: Config = { + vertex: {model: 'gemini-1.0-pro'}, + googleAi: {model: 'gemini-1.5-flash', apiKey: 'test-api-key'}, + model: 'gemini-1.5-flash', + location: 'us-central1', + projectId: 'test-project', + instanceId: 'test-instance', + prompt: 'Test prompt', + responseField: 'output', + collectionName: 'users/{uid}/discussions/{discussionId}/messages', + temperature: 0.7, + topP: 0.9, + topK: 50, + candidates: { + field: 'candidates', + count: 1, + shouldIncludeCandidatesField: false, + }, + maxOutputTokens: 256, + maxOutputTokensVertex: 1024, + provider: 'google-ai', + apiKey: 'test-api-key', + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + }, + ], + bucketName: 'test-bucket', + imageField: 'image', + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should return false if the model includes "pro-vision"', () => { + const config = {...baseConfig, model: 'gemini-pro-vision'}; + + const result = GenkitGenerativeClient.shouldUseGenkitClient(config); + + expect(result).toBe(false); + }); + + it('should return false if multiple candidates are requested', () => { + const config = { + ...baseConfig, + candidates: { + field: 'candidates', + count: 2, + shouldIncludeCandidatesField: true, + }, + }; + + const result = GenkitGenerativeClient.shouldUseGenkitClient(config); + + expect(result).toBe(false); + }); + + it('should return false if no model reference is found', () => { + const config = {...baseConfig, model: 'unknown-model'}; + + jest + .spyOn(GenkitGenerativeClient, 'createModelReference') + .mockReturnValueOnce(null); + + const result = GenkitGenerativeClient.shouldUseGenkitClient(config); + + expect(result).toBe(false); + }); + + it('should return true if conditions are met for Genkit client usage', () => { + const config = {...baseConfig, model: 'gemini-1.5-flash'}; + + jest + .spyOn(GenkitGenerativeClient, 'createModelReference') + .mockReturnValueOnce({ + name: 'googleai/gemini-1.5-flash', + withVersion: jest.fn(), + withConfig: jest.fn(), + }); + + const result = GenkitGenerativeClient.shouldUseGenkitClient(config); + + expect(result).toBe(true); + }); + + it('should call createModelReference with correct parameters', () => { + const spy = jest.spyOn(GenkitGenerativeClient, 'createModelReference'); + + GenkitGenerativeClient.shouldUseGenkitClient(baseConfig); + + expect(spy).toHaveBeenCalledWith('gemini-1.5-flash', 'google-ai'); + }); +}); diff --git a/firestore-multimodal-genai/functions/src/generative-client/genkit.ts b/firestore-multimodal-genai/functions/src/generative-client/genkit.ts index ef2a1f8d..ffd16240 100644 --- a/firestore-multimodal-genai/functions/src/generative-client/genkit.ts +++ b/firestore-multimodal-genai/functions/src/generative-client/genkit.ts @@ -93,7 +93,7 @@ export class GenkitGenerativeClient extends GenerativeClient< static createModelReference( model: string, provider: string - ): ModelReference { + ): ModelReference | null { const modelReferences = provider === 'google-ai' ? [gemini10ProGoogleAI, gemini15FlashGoogleAI, gemini15ProGoogleAI] @@ -109,7 +109,7 @@ export class GenkitGenerativeClient extends GenerativeClient< return modelReference.withVersion(model); } } - throw new Error('Model reference not found.'); + return null; } private createGenerateOptions(config: Config): GenerateOptions { @@ -117,11 +117,17 @@ export class GenkitGenerativeClient extends GenerativeClient< throw new Error('Model must be specified in the configuration.'); } + const modelRef = GenkitGenerativeClient.createModelReference( + config.model, + config.provider! + ); + + if (!modelRef) { + throw new Error('Model reference not found.'); + } + return { - model: GenkitGenerativeClient.createModelReference( - config.model, - config.provider! - ), + model: modelRef, config: { topP: config.topP, topK: config.topK,