Skip to content

Commit

Permalink
Merge pull request #46 from tak-bro/feature/add-groq
Browse files Browse the repository at this point in the history
Feature/add groq
  • Loading branch information
tak-bro authored Jun 10, 2024
2 parents 9e9f8cf + a3b5612 commit 5db6817
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 3 deletions.
30 changes: 28 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ _aicommit2_ streamlines interactions with various AI, enabling users to request
- [Gemini](https://gemini.google.com/)
- [Mistral AI](https://mistral.ai/)
- [Cohere](https://cohere.com/)
- [Groq](https://groq.com/)
- [Huggingface **(Unofficial)**](https://huggingface.co/chat/)
- [Clova X **(Unofficial)**](https://clova-x.naver.com/)

Expand Down Expand Up @@ -78,6 +79,11 @@ aicommit2 config set MISTRAL_KEY=<your key>
aicommit2 config set COHERE_KEY=<your key>
```

- [Groq](https://console.groq.com)
```sh
aicommit2 config set GROQ_KEY=<your key>
```

- [Huggingface **(Unofficial)**](https://github.com/tak-bro/aicommit2?tab=readme-ov-file#how-to-get-cookieunofficial-api)
```shell
# Please be cautious of Escape characters(\", \') in browser cookie string
Expand Down Expand Up @@ -306,6 +312,8 @@ aicommit2 config set OPENAI_KEY=<your-api-key> generate=3 locale=en
| `MISTRAL_MODEL` | `mistral-tiny` | The Mistral Model to use |
| `COHERE_KEY` | N/A | The Cohere API Key |
| `COHERE_MODEL` | `command` | The identifier of the Cohere model |
| `GROQ_KEY` | N/A | The Groq API Key |
| `GROQ_MODEL` | `gemma-7b-it` | The Groq model name to use |
| `HUGGING_COOKIE` | N/A | The HuggingFace Cookie string |
| `HUGGING_MODEL` | `mistralai/Mixtral-8x7B-Instruct-v0.1` | The HuggingFace Model to use |
| `CLOVAX_COOKIE` | N/A | The Clova X Cookie string |
Expand Down Expand Up @@ -334,6 +342,7 @@ aicommit2 config set OPENAI_KEY=<your-api-key> generate=3 locale=en
| **Gemini** |||| | |||||
| **Mistral AI** |||| ||||||
| **Cohere** |||| | |||||
| **Groq** |||| ||| | ||
| **Huggingface** |||| ||| | ||
| **Clova X** |||| ||| | ||
| **Ollama** |||| | ⚠<br/>(OLLAMA_TIMEOUT) || |||
Expand Down Expand Up @@ -452,7 +461,7 @@ aicommit2 log removeAll
The Ollama Model. Please see [a list of models available](https://ollama.com/library)

```sh
aicommit2 config set OLLAMA_MODEL=llama3
aicommit2 config set OLLAMA_MODEL="llama3"
aicommit2 config set OLLAMA_MODEL="llama3,codellama" # for multiple models
```

Expand All @@ -466,7 +475,6 @@ The Ollama host
aicommit2 config set OLLAMA_HOST=<host>
```


##### OLLAMA_TIMEOUT

Default: `100_000` (100 seconds)
Expand Down Expand Up @@ -601,6 +609,24 @@ Supported:

> The models mentioned above are subject to change.
### Groq

##### GROQ_KEY

The Groq API key. If you don't have one, please sign up and get the API key in [Groq Console](https://console.groq.com).

##### GROQ_MODEL

Default: `gemma-7b-it`

Supported:
- `llama3-8b-8192`
- 'llama3-70b-8192'
- `mixtral-8x7b-32768`
- `gemma-7b-it`

> The models mentioned above are subject to change.
### HuggingFace Chat

##### HUGGING_COOKIE
Expand Down
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
"llama3",
"llm",
"chatgpt",
"cohere"
"cohere",
"groq"
],
"license": "MIT",
"repository": "tak-bro/aicommit2",
Expand Down Expand Up @@ -74,6 +75,7 @@
"copy-paste": "^1.5.3",
"figlet": "^1.7.0",
"formdata-node": "^6.0.3",
"groq-sdk": "^0.4.0",
"inquirer": "^9.0.3",
"inquirer-reactive-list-prompt": "^1.0.5",
"ollama": "^0.5.0",
Expand Down
18 changes: 18 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/managers/ai-request.manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { AnthropicService } from '../services/ai/anthropic.service.js';
import { ClovaXService } from '../services/ai/clova-x.service.js';
import { CohereService } from '../services/ai/cohere.service.js';
import { GeminiService } from '../services/ai/gemini.service.js';
import { GroqService } from '../services/ai/groq.service.js';
import { HuggingService } from '../services/ai/hugging.service.js';
import { MistralService } from '../services/ai/mistral.service.js';
import { OllamaService } from '../services/ai/ollama.service.js';
Expand Down Expand Up @@ -54,6 +55,8 @@ export class AIRequestManager {
);
case AIType.COHERE:
return AIServiceFactory.create(CohereService, params).generateCommitMessage$();
case AIType.GROQ:
return AIServiceFactory.create(GroqService, params).generateCommitMessage$();
default:
const prefixError = chalk.red.bold(`[${ai}]`);
return of({
Expand Down
1 change: 1 addition & 0 deletions src/services/ai/ai.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export const AIType = {
MISTRAL: 'MISTRAL_KEY',
OLLAMA: 'OLLAMA_MODEL',
COHERE: 'COHERE_KEY',
GROQ: 'GROQ_KEY',
} as const;
export type ApiKeyName = (typeof AIType)[keyof typeof AIType];
export const ApiKeyNames: ApiKeyName[] = Object.values(AIType).map(value => value);
Expand Down
88 changes: 88 additions & 0 deletions src/services/ai/groq.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import chalk from 'chalk';
import Groq from 'groq-sdk';
import { GroqError } from 'groq-sdk/error';
import { ReactiveListChoice } from 'inquirer-reactive-list-prompt';
import { Observable, catchError, concatMap, from, map, of } from 'rxjs';
import { fromPromise } from 'rxjs/internal/observable/innerFrom';

import { AIService, AIServiceParams } from './ai.service.js';
import { createLogResponse } from '../../utils/log.js';
import { deduplicateMessages } from '../../utils/openai.js';
import { extraPrompt, generateDefaultPrompt } from '../../utils/prompt.js';

export class GroqService extends AIService {
private groq: Groq;

constructor(private readonly params: AIServiceParams) {
super(params);
this.colors = {
primary: '#f55036',
secondary: '#fff',
};
this.serviceName = chalk.bgHex(this.colors.primary).hex(this.colors.secondary).bold('[Groq]');
this.errorPrefix = chalk.red.bold(`[Groq]`);
this.groq = new Groq({ apiKey: this.params.config.GROQ_KEY });
}

generateCommitMessage$(): Observable<ReactiveListChoice> {
return fromPromise(this.generateMessage()).pipe(
concatMap(messages => from(messages)),
map(message => ({
name: `${this.serviceName} ${message}`,
value: message,
isError: false,
})),
catchError(this.handleError$)
);
}

private async generateMessage(): Promise<string[]> {
try {
const diff = this.params.stagedDiff.diff;
const { locale, generate, type, prompt: userPrompt, logging } = this.params.config;
const maxLength = this.params.config['max-length'];
const defaultPrompt = generateDefaultPrompt(locale, maxLength, type, userPrompt);
const systemPrompt = `${defaultPrompt}\n${extraPrompt(generate)}`;

const chatCompletion = await this.groq.chat.completions.create(
{
messages: [
{ role: 'system', content: systemPrompt },
{
role: 'user',
content: `Here are diff: ${diff}`,
},
],
model: this.params.config.GROQ_MODEL,
},
{
timeout: this.params.config.timeout,
}
);

const result = chatCompletion.choices[0].message.content || '';
logging && createLogResponse('Groq', diff, systemPrompt, result);
return deduplicateMessages(this.sanitizeMessage(result, this.params.config.type, generate));
} catch (error) {
throw error as any;
}
}

handleError$ = (error: GroqError) => {
let simpleMessage = 'An error occurred';
const regex = /"message":\s*"([^"]*)"/;
const match = error.message.match(regex);
if (match && match[1]) {
simpleMessage = match[1];
}
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
const message = `${error['status']} ${simpleMessage}`;
return of({
name: `${this.errorPrefix} ${message}`,
value: simpleMessage,
isError: true,
disabled: true,
});
};
}
14 changes: 14 additions & 0 deletions src/utils/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,20 @@ const configParsers = {
parseAssert('COHERE_MODEL', supportModels.includes(model), 'Invalid model type of Cohere');
return model;
},
GROQ_KEY(key?: string) {
if (!key) {
return '';
}
return key;
},
GROQ_MODEL(model?: string) {
if (!model || model.length === 0) {
return 'gemma-7b-it';
}
const supportModels = [`llama3-8b-8192`, 'llama3-70b-8192', `mixtral-8x7b-32768`, `gemma-7b-it`];
parseAssert('GROQ_MODEL', supportModels.includes(model), 'Invalid model type of Groq');
return model;
},
} as const;

type ConfigKeys = keyof typeof generalConfigParsers | keyof typeof configParsers;
Expand Down

0 comments on commit 5db6817

Please sign in to comment.