Skip to content

Commit

Permalink
test(firestore-multimodal-genai): add test for checking whether genki…
Browse files Browse the repository at this point in the history
…t is used
  • Loading branch information
cabljac committed Nov 29, 2024
1 parent c91b358 commit 25a53ea
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 7 deletions.
102 changes: 101 additions & 1 deletion firestore-multimodal-genai/functions/__tests__/genkit/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,11 @@ describe('GenkitGenerativeClient', () => {
'gemini-1.5-flash',
'google-ai'
);
expect(modelReference === null).toBe(false);

expect(modelReference.name).toBe('googleai/gemini-1.5-flash');
expect(modelReference).toHaveProperty('name');

expect(modelReference!.name).toBe('googleai/gemini-1.5-flash');
});

it('should call generate with correct options and return response', async () => {
Expand Down Expand Up @@ -215,3 +218,100 @@ describe('GenkitGenerativeClient', () => {
);
});
});

describe('GenkitGenerativeClient.shouldUseGenkitClient', () => {
const baseConfig: Config = {
vertex: {model: 'gemini-1.0-pro'},
googleAi: {model: 'gemini-1.5-flash', apiKey: 'test-api-key'},
model: 'gemini-1.5-flash',
location: 'us-central1',
projectId: 'test-project',
instanceId: 'test-instance',
prompt: 'Test prompt',
responseField: 'output',
collectionName: 'users/{uid}/discussions/{discussionId}/messages',
temperature: 0.7,
topP: 0.9,
topK: 50,
candidates: {
field: 'candidates',
count: 1,
shouldIncludeCandidatesField: false,
},
maxOutputTokens: 256,
maxOutputTokensVertex: 1024,
provider: 'google-ai',
apiKey: 'test-api-key',
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
},
],
bucketName: 'test-bucket',
imageField: 'image',
};

beforeEach(() => {
jest.clearAllMocks();
});

it('should return false if the model includes "pro-vision"', () => {
const config = {...baseConfig, model: 'gemini-pro-vision'};

const result = GenkitGenerativeClient.shouldUseGenkitClient(config);

expect(result).toBe(false);
});

it('should return false if multiple candidates are requested', () => {
const config = {
...baseConfig,
candidates: {
field: 'candidates',
count: 2,
shouldIncludeCandidatesField: true,
},
};

const result = GenkitGenerativeClient.shouldUseGenkitClient(config);

expect(result).toBe(false);
});

it('should return false if no model reference is found', () => {
const config = {...baseConfig, model: 'unknown-model'};

jest
.spyOn(GenkitGenerativeClient, 'createModelReference')
.mockReturnValueOnce(null);

const result = GenkitGenerativeClient.shouldUseGenkitClient(config);

expect(result).toBe(false);
});

it('should return true if conditions are met for Genkit client usage', () => {
const config = {...baseConfig, model: 'gemini-1.5-flash'};

jest
.spyOn(GenkitGenerativeClient, 'createModelReference')
.mockReturnValueOnce({
name: 'googleai/gemini-1.5-flash',
withVersion: jest.fn(),
withConfig: jest.fn(),
});

const result = GenkitGenerativeClient.shouldUseGenkitClient(config);

expect(result).toBe(true);
});

it('should call createModelReference with correct parameters', () => {
const spy = jest.spyOn(GenkitGenerativeClient, 'createModelReference');

GenkitGenerativeClient.shouldUseGenkitClient(baseConfig);

expect(spy).toHaveBeenCalledWith('gemini-1.5-flash', 'google-ai');
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ export class GenkitGenerativeClient extends GenerativeClient<
static createModelReference(
model: string,
provider: string
): ModelReference<any> {
): ModelReference<any> | null {
const modelReferences =
provider === 'google-ai'
? [gemini10ProGoogleAI, gemini15FlashGoogleAI, gemini15ProGoogleAI]
Expand All @@ -109,19 +109,25 @@ export class GenkitGenerativeClient extends GenerativeClient<
return modelReference.withVersion(model);
}
}
throw new Error('Model reference not found.');
return null;
}

private createGenerateOptions(config: Config): GenerateOptions {
if (!config.model) {
throw new Error('Model must be specified in the configuration.');
}

const modelRef = GenkitGenerativeClient.createModelReference(
config.model,
config.provider!
);

if (!modelRef) {
throw new Error('Model reference not found.');
}

return {
model: GenkitGenerativeClient.createModelReference(
config.model,
config.provider!
),
model: modelRef,
config: {
topP: config.topP,
topK: config.topK,
Expand Down

0 comments on commit 25a53ea

Please sign in to comment.