-
Notifications
You must be signed in to change notification settings - Fork 116
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ feat: AI Model Dialog adds gemini model support (#35)
- Loading branch information
Showing
3 changed files
with
251 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package client | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
) | ||
|
||
const ( | ||
openaiAPIURL = "https://api.openai.com/v1/chat/completions" | ||
) | ||
|
||
type ChatGPTClient struct { | ||
APIKey string | ||
Model string | ||
} | ||
|
||
type ChatGPTRequest struct { | ||
Model string `json:"model"` | ||
Messages []ChatGPTPrompt `json:"messages"` | ||
} | ||
|
||
type ChatGPTPrompt struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
} | ||
|
||
type ChatGPTResponse struct { | ||
Choices []struct { | ||
Message struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
} `json:"message"` | ||
} `json:"choices"` | ||
} | ||
|
||
func (c *ChatGPTClient) Name() string { | ||
return "ChatGPT" | ||
} | ||
|
||
func (c *ChatGPTClient) SendRequest(prompt string) (string, error) { | ||
requestData := ChatGPTRequest{ | ||
Model: c.Model, | ||
Messages: []ChatGPTPrompt{ | ||
{ | ||
Role: "system", | ||
Content: "You are a helpful assistant.", | ||
}, | ||
{ | ||
Role: "user", | ||
Content: prompt, | ||
}, | ||
}, | ||
} | ||
|
||
requestBody, err := json.Marshal(requestData) | ||
if err != nil { | ||
return "", fmt.Errorf("error encoding request: %v", err) | ||
} | ||
|
||
req, err := http.NewRequest("POST", openaiAPIURL, bytes.NewBuffer(requestBody)) | ||
if err != nil { | ||
return "", fmt.Errorf("error creating request: %v", err) | ||
} | ||
|
||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Authorization", "Bearer "+c.APIKey) | ||
|
||
client := &http.Client{} | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return "", fmt.Errorf("error sending request: %v", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
body, err := ioutil.ReadAll(resp.Body) | ||
if err != nil { | ||
return "", fmt.Errorf("error reading response: %v", err) | ||
} | ||
|
||
var chatGPTResponse ChatGPTResponse | ||
err = json.Unmarshal(body, &chatGPTResponse) | ||
if err != nil { | ||
return "", fmt.Errorf("error decoding response: %v", err) | ||
} | ||
|
||
if len(chatGPTResponse.Choices) > 0 { | ||
return chatGPTResponse.Choices[0].Message.Content, nil | ||
} | ||
|
||
return "", fmt.Errorf("no valid response from ChatGPT") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
package client | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
) | ||
|
||
const geminiUrl = "https://generativelanguage.googleapis.com/v1beta/models" | ||
|
||
type GeminiClient struct { | ||
APIKey string | ||
Model string | ||
} | ||
|
||
type GeminiRequest struct { | ||
Contents []GeminiContent `json:"contents"` | ||
} | ||
|
||
type GeminiContent struct { | ||
Parts []GeminiPart `json:"parts"` | ||
} | ||
|
||
type GeminiPart struct { | ||
Text string `json:"text"` | ||
} | ||
|
||
type GeminiResponse struct { | ||
Candidates []struct { | ||
Content struct { | ||
Parts []struct { | ||
Text string `json:"text"` | ||
} `json:"parts"` | ||
Role string `json:"role"` | ||
} `json:"content"` | ||
FinishReason string `json:"finishReason"` | ||
SafetyRatings []struct { | ||
Category string `json:"category"` | ||
Probability string `json:"probability"` | ||
} `json:"safetyRatings"` | ||
} `json:"candidates"` | ||
UsageMetadata struct { | ||
PromptTokenCount int `json:"promptTokenCount"` | ||
CandidatesTokenCount int `json:"candidatesTokenCount"` | ||
TotalTokenCount int `json:"totalTokenCount"` | ||
} `json:"usageMetadata"` | ||
ModelVersion string `json:"modelVersion"` | ||
} | ||
|
||
func (g *GeminiClient) Name() string { | ||
return "Gemini" | ||
} | ||
|
||
func (g *GeminiClient) SendRequest(prompt string) (string, error) { | ||
requestData := GeminiRequest{ | ||
Contents: []GeminiContent{ | ||
{ | ||
Parts: []GeminiPart{ | ||
{ | ||
Text: prompt, | ||
}, | ||
}, | ||
}, | ||
}, | ||
} | ||
|
||
requestBody, err := json.Marshal(requestData) | ||
if err != nil { | ||
return "", fmt.Errorf("error encoding request: %v", err) | ||
} | ||
|
||
apiURL := fmt.Sprintf("%s/%s:generateContent?key=%s", geminiUrl, g.Model, g.APIKey) | ||
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(requestBody)) | ||
if err != nil { | ||
return "", fmt.Errorf("error creating request: %v", err) | ||
} | ||
|
||
req.Header.Set("Content-Type", "application/json") | ||
|
||
client := &http.Client{} | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return "", fmt.Errorf("error sending request: %v", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
body, err := ioutil.ReadAll(resp.Body) | ||
if err != nil { | ||
return "", fmt.Errorf("error reading response: %v", err) | ||
} | ||
|
||
var geminiResponse GeminiResponse | ||
err = json.Unmarshal(body, &geminiResponse) | ||
if err != nil { | ||
return "", fmt.Errorf("error decoding response: %v", err) | ||
} | ||
|
||
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { | ||
return geminiResponse.Candidates[0].Content.Parts[0].Text, nil | ||
} | ||
|
||
return "", fmt.Errorf("no valid response from Gemini") | ||
} |
Oops, something went wrong.