// openaiAPIURL is the default OpenAI chat-completions endpoint, used
// whenever ChatGPTClient.BaseURL is left empty.
const openaiAPIURL = "https://api.openai.com/v1/chat/completions"

// ChatGPTClient calls the OpenAI chat-completions API with a fixed
// "helpful assistant" system prompt.
type ChatGPTClient struct {
	APIKey string // OpenAI API key, sent as a Bearer token
	Model  string // model name, e.g. "gpt-4o-mini"
	// BaseURL optionally overrides openaiAPIURL (useful for proxies and
	// tests). The zero value keeps the previous hard-coded endpoint, so
	// existing callers are unaffected.
	BaseURL string
}

// ChatGPTRequest is the JSON request body for /v1/chat/completions.
type ChatGPTRequest struct {
	Model    string          `json:"model"`
	Messages []ChatGPTPrompt `json:"messages"`
}

// ChatGPTPrompt is a single chat message (role + content).
type ChatGPTPrompt struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatGPTResponse mirrors the subset of the completion response that
// SendRequest reads.
type ChatGPTResponse struct {
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}

// Name identifies this client in user-facing output.
func (c *ChatGPTClient) Name() string {
	return "ChatGPT"
}

// SendRequest sends prompt to the chat-completions API and returns the
// first choice's message content. It returns an error for transport
// failures, non-200 responses, malformed JSON, or an empty choice list.
func (c *ChatGPTClient) SendRequest(prompt string) (string, error) {
	requestData := ChatGPTRequest{
		Model: c.Model,
		Messages: []ChatGPTPrompt{
			{Role: "system", Content: "You are a helpful assistant."},
			{Role: "user", Content: prompt},
		},
	}

	requestBody, err := json.Marshal(requestData)
	if err != nil {
		return "", fmt.Errorf("error encoding request: %v", err)
	}

	url := c.BaseURL
	if url == "" {
		url = openaiAPIURL
	}
	req, err := http.NewRequest("POST", url, bytes.NewReader(requestBody))
	if err != nil {
		return "", fmt.Errorf("error creating request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.APIKey)

	// Bound the call: the original zero-timeout client could hang the
	// interactive chat loop forever on a stuck connection.
	httpClient := &http.Client{Timeout: 60 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("error sending request: %v", err)
	}
	defer resp.Body.Close()

	// io.ReadAll: ioutil.ReadAll has been deprecated since Go 1.16.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("error reading response: %v", err)
	}

	// Surface API errors (401, 429, ...) explicitly instead of letting a
	// failed decode misreport them as "no valid response".
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("unexpected status %s: %s", resp.Status, body)
	}

	var chatGPTResponse ChatGPTResponse
	if err := json.Unmarshal(body, &chatGPTResponse); err != nil {
		return "", fmt.Errorf("error decoding response: %v", err)
	}
	if len(chatGPTResponse.Choices) > 0 {
		return chatGPTResponse.Choices[0].Message.Content, nil
	}
	return "", fmt.Errorf("no valid response from ChatGPT")
}
// geminiUrl is the default Gemini generateContent endpoint prefix, used
// whenever GeminiClient.BaseURL is left empty.
const geminiUrl = "https://generativelanguage.googleapis.com/v1beta/models"

// GeminiClient calls the Google Gemini generateContent API.
type GeminiClient struct {
	APIKey string // Gemini API key, sent via the x-goog-api-key header
	Model  string // model name, e.g. "gemini-1.5-flash"
	// BaseURL optionally overrides geminiUrl (useful for proxies and
	// tests). The zero value keeps the previous hard-coded endpoint, so
	// existing callers are unaffected.
	BaseURL string
}

// GeminiRequest is the JSON request body for :generateContent.
type GeminiRequest struct {
	Contents []GeminiContent `json:"contents"`
}

// GeminiContent is one content item: an ordered list of parts.
type GeminiContent struct {
	Parts []GeminiPart `json:"parts"`
}

// GeminiPart is a single text fragment of a content item.
type GeminiPart struct {
	Text string `json:"text"`
}

// GeminiResponse mirrors the subset of the generateContent response
// that SendRequest reads.
type GeminiResponse struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
			Role string `json:"role"`
		} `json:"content"`
		FinishReason  string `json:"finishReason"`
		SafetyRatings []struct {
			Category    string `json:"category"`
			Probability string `json:"probability"`
		} `json:"safetyRatings"`
	} `json:"candidates"`
	UsageMetadata struct {
		PromptTokenCount     int `json:"promptTokenCount"`
		CandidatesTokenCount int `json:"candidatesTokenCount"`
		TotalTokenCount      int `json:"totalTokenCount"`
	} `json:"usageMetadata"`
	ModelVersion string `json:"modelVersion"`
}

// Name identifies this client in user-facing output.
func (g *GeminiClient) Name() string {
	return "Gemini"
}

// SendRequest sends prompt to the Gemini generateContent API and
// returns the first candidate's first text part. It returns an error
// for transport failures, non-200 responses, malformed JSON, or an
// empty candidate list.
func (g *GeminiClient) SendRequest(prompt string) (string, error) {
	requestData := GeminiRequest{
		Contents: []GeminiContent{
			{Parts: []GeminiPart{{Text: prompt}}},
		},
	}

	requestBody, err := json.Marshal(requestData)
	if err != nil {
		return "", fmt.Errorf("error encoding request: %v", err)
	}

	base := g.BaseURL
	if base == "" {
		base = geminiUrl
	}
	apiURL := fmt.Sprintf("%s/%s:generateContent", base, g.Model)
	req, err := http.NewRequest("POST", apiURL, bytes.NewReader(requestBody))
	if err != nil {
		return "", fmt.Errorf("error creating request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Authenticate via header rather than a ?key= query parameter so the
	// API key does not end up in proxy/server access logs.
	req.Header.Set("x-goog-api-key", g.APIKey)

	// Bound the call: the original zero-timeout client could hang the
	// interactive chat loop forever on a stuck connection.
	httpClient := &http.Client{Timeout: 60 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("error sending request: %v", err)
	}
	defer resp.Body.Close()

	// io.ReadAll: ioutil.ReadAll has been deprecated since Go 1.16.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("error reading response: %v", err)
	}

	// Surface API errors (400, 403, 429, ...) explicitly instead of
	// letting a failed decode misreport them as "no valid response".
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("unexpected status %s: %s", resp.Status, body)
	}

	var geminiResponse GeminiResponse
	if err := json.Unmarshal(body, &geminiResponse); err != nil {
		return "", fmt.Errorf("error decoding response: %v", err)
	}
	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
		return geminiResponse.Candidates[0].Content.Parts[0].Text, nil
	}
	return "", fmt.Errorf("no valid response from Gemini")
}
"https://api.openai.com/v1/chat/completions" - openaiMode = "gpt-4o-mini" + ChatGPT int = iota + 1 + Gemini ) -type ChatGPTRequest struct { - Model string `json:"model"` - Messages []ChatGPTPrompt `json:"messages"` -} - -type ChatGPTPrompt struct { - Role string `json:"role"` - Content string `json:"content"` -} +const ( + chatGPTModel = "gpt-4o-mini" + geminiModel = "gemini-1.5-flash" +) -type ChatGPTResponse struct { - Choices []struct { - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` +type LLMClient interface { + Name() string + SendRequest(prompt string) (string, error) } type TextToSQLRequest struct { @@ -54,10 +42,12 @@ func Chat(manuscript string) { jobName := fmt.Sprintf("%s-postgres-1", manuscript) dockers, err := pkg.RunDockerPs() if err != nil { - log.Fatalf("Error: Failed to get postgres ps: %v", err) + fmt.Printf("Error: Failed to get postgres ps: %v", err) + return } if len(dockers) == 0 { - log.Fatalf("Error: No manuscript postgres found") + fmt.Println("No manuscript postgres found") + return } for _, d := range dockers { if d.Name == jobName { @@ -67,7 +57,15 @@ func Chat(manuscript string) { } for _, m := range manuscripts.Manuscripts { if m.Name == manuscript { - ChatWithLLM(m) + model := promptInput("Manuscript currently offers the following two types of model integration:\n1. ChatGPT\n2. 
Gemini\nSelect model to use(default ChatGPT): ", "1") + chat, err := newChatClient(model) + if err != nil { + log.Printf("Failed to create chat client: %v", err) + return + } + + ChatWithLLM(m, chat) + break } } break @@ -75,16 +73,32 @@ func Chat(manuscript string) { } } -func ChatWithLLM(job pkg.Manuscript) { - apiKey := os.Getenv("OPENAI_API_KEY") - if apiKey == "" { - log.Fatalf("Error: Please set the environment variable OPENAI_API_KEY") - } - model := os.Getenv("OPENAI_MODEL") - if model == "" { - model = openaiMode +func newChatClient(model string) (LLMClient, error) { + switch model { + case "1": + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return nil, fmt.Errorf("OPENAI_API_KEY environment variable not set, please set it to your OpenAI API key. You can obtain an API key from https://platform.openai.com") + } + return &client.ChatGPTClient{ + APIKey: apiKey, + Model: chatGPTModel, + }, nil + case "2": + apiKey := os.Getenv("GEMINI_API_KEY") + if apiKey == "" { + return nil, fmt.Errorf("GEMINI_API_KEY environment variable not set, please set it to your Gemini API key. 
You can obtain an API key from https://ai.google.dev/gemini-api/docs/models/gemini") + } + return &client.GeminiClient{ + APIKey: apiKey, + Model: geminiModel, + }, nil + default: + return nil, fmt.Errorf("error: unknown model %s", model) } +} +func ChatWithLLM(job pkg.Manuscript, client LLMClient) { signalChannel := make(chan os.Signal, 1) signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM) @@ -104,7 +118,7 @@ func ChatWithLLM(job pkg.Manuscript) { go func() { for { - fmt.Print("You: ") + fmt.Print("🏄🏼‍You: ") userInput, _ := reader.ReadString('\n') userInput = strings.TrimSpace(userInput) inputChannel <- userInput @@ -126,19 +140,19 @@ func ChatWithLLM(job pkg.Manuscript) { fmt.Printf("Processing your question...\n") - response, err := sendToChatGPT(prompt, apiKey, model) + response, err := client.SendRequest(prompt) if err != nil { fmt.Printf("Error getting response: %v\n", err) continue } - fmt.Printf("AI: %s\n", response) - sqlQuery, err := extractSQL(response) if err != nil { fmt.Printf("Error extracting SQL: %v\n", err) continue } + fmt.Printf("🔎%s: \u001B[32m%s\u001B[0m\nExecuting SQL......\n", client.Name(), sqlQuery) + executeSQL(pool, sqlQuery) case <-signalChannel: @@ -160,59 +174,6 @@ func extractSQL(response string) (string, error) { return strings.TrimSpace(response), nil } -func sendToChatGPT(prompt string, apiKey, model string) (string, error) { - requestData := ChatGPTRequest{ - Model: model, - Messages: []ChatGPTPrompt{ - { - Role: "system", - Content: "You are a helpful assistant.", - }, - { - Role: "user", - Content: prompt, - }, - }, - } - - requestBody, err := json.Marshal(requestData) - if err != nil { - return "", fmt.Errorf("error encoding request: %v", err) - } - - req, err := http.NewRequest("POST", openaiAPIURL, bytes.NewBuffer(requestBody)) - if err != nil { - return "", fmt.Errorf("error creating request: %v", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) 
- - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("error sending request: %v", err) - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("error reading response: %v", err) - } - - var chatGPTResponse ChatGPTResponse - err = json.Unmarshal(body, &chatGPTResponse) - if err != nil { - return "", fmt.Errorf("error decoding response: %v", err) - } - - if len(chatGPTResponse.Choices) > 0 { - return chatGPTResponse.Choices[0].Message.Content, nil - } - - return "", fmt.Errorf("no valid response from ChatGPT") -} - func connectToDB(ms pkg.Manuscript) (*pgxpool.Pool, error) { dbUrl := fmt.Sprintf("postgres://%s:%s@localhost:%d/%s", ms.DbUser, ms.DbPassword, ms.DbPort, ms.Database) pool, err := pgxpool.Connect(context.Background(), dbUrl) @@ -266,4 +227,5 @@ func executeSQL(pool *pgxpool.Pool, sqlQuery string) { } fmt.Println() } + fmt.Println("Do you have any other questions? Type 'exit' to quit.") }