Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: process stream events #1184

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 13 additions & 35 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
Expand Down Expand Up @@ -166,38 +165,13 @@ func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiNa
}

func (b *baiduProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
// sample event response:
// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089502,"sentence_id":0,"is_end":false,"is_truncated":false,"result":"当然可以,","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}

// sample end event response:
// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089531,"sentence_id":20,"is_end":true,"is_truncated":false,"result":"","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":420,"total_tokens":425}}
responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
continue
}
data = data[6:]
var baiduResponse baiduTextGenStreamResponse
if err := json.Unmarshal([]byte(data), &baiduResponse); err != nil {
log.Errorf("unable to unmarshal baidu response: %v", err)
continue
}
response := b.streamResponseBaidu2OpenAI(ctx, &baiduResponse)
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
b.appendResponse(responseBuilder, string(responseBody))
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, b.streamResponseBaidu2OpenAI)
return modifiedResponseChunk, nil
}

func (b *baiduProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down Expand Up @@ -313,15 +287,21 @@ func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *
}
}

func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenStreamResponse) *chatCompletionResponse {
func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
var response baiduTextGenStreamResponse
if err := json.Unmarshal(chunk, &response); err != nil {
log.Errorf("unable to unmarshal baidu response: %v", err)
return nil
}

choice := chatCompletionChoice{
Index: 0,
Message: &chatMessage{Role: roleAssistant, Content: response.Result},
Index: 0,
Delta: &chatMessage{Role: roleAssistant, Content: response.Result},
}
if response.IsEnd {
choice.FinishReason = finishReasonStop
}
return &chatCompletionResponse{
openAIResponse := &chatCompletionResponse{
Id: response.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Expand All @@ -334,8 +314,6 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp
TotalTokens: response.Usage.TotalTokens,
},
}
}

func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
return openAIResponse
}
72 changes: 23 additions & 49 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
Expand Down Expand Up @@ -227,36 +226,11 @@ func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiN
}

func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}

responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
// only process the line starting with "data:"
if strings.HasPrefix(data, "data:") {
// extract json data from the line
jsonData := strings.TrimPrefix(data, "data:")
var claudeResponse claudeTextGenStreamResponse
if err := json.Unmarshal([]byte(jsonData), &claudeResponse); err != nil {
log.Errorf("unable to unmarshal claude response: %v", err)
continue
}
response := c.streamResponseClaude2OpenAI(ctx, &claudeResponse, log)
if response != nil {
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
c.appendResponse(responseBuilder, string(responseBody))
}
}
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
// claude 的流式返回
// event: content_block_delta
// data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"司开发的"} }
modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, c.streamResponseClaude2OpenAI)
return modifiedResponseChunk, nil
}

func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRequest) *claudeTextGenRequest {
Expand Down Expand Up @@ -324,36 +298,40 @@ func stopReasonClaude2OpenAI(reason *string) string {
}
}

func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse, log wrapper.Log) *chatCompletionResponse {
switch origResponse.Type {
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
var response claudeTextGenStreamResponse
if err := json.Unmarshal(chunk, &response); err != nil {
log.Errorf("unable to unmarshal claude response: %v", err)
return nil
}

var choice chatCompletionChoice
switch response.Type {
case "message_start":
choice := chatCompletionChoice{
choice = chatCompletionChoice{
Index: 0,
Delta: &chatMessage{Role: roleAssistant, Content: ""},
}
return createChatCompletionResponse(ctx, origResponse, choice)

case "content_block_delta":
choice := chatCompletionChoice{
choice = chatCompletionChoice{
Index: 0,
Delta: &chatMessage{Content: origResponse.Delta.Text},
Delta: &chatMessage{Content: response.Delta.Text},
}
return createChatCompletionResponse(ctx, origResponse, choice)

case "message_delta":
choice := chatCompletionChoice{
choice = chatCompletionChoice{
Index: 0,
Delta: &chatMessage{},
FinishReason: stopReasonClaude2OpenAI(origResponse.Delta.StopReason),
FinishReason: stopReasonClaude2OpenAI(response.Delta.StopReason),
}
return createChatCompletionResponse(ctx, origResponse, choice)
case "content_block_stop", "message_stop":
log.Debugf("skip processing response type: %s", origResponse.Type)
log.Debugf("skip processing response type: %s", response.Type)
return nil
default:
log.Errorf("Unexpected response type: %s", origResponse.Type)
log.Errorf("Unexpected response type: %s", response.Type)
return nil
}

return createChatCompletionResponse(ctx, &response, choice)
}

func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextGenStreamResponse, choice chatCompletionChoice) *chatCompletionResponse {
Expand All @@ -365,7 +343,3 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG
Choices: []chatCompletionChoice{choice},
}
}

func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
38 changes: 9 additions & 29 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,36 +214,10 @@ func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiN
}

func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
log.Infof("chunk body:%s", string(chunk))
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
// sample end event response:
// data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}}
responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
continue
}
data = data[6:]
var geminiResp geminiChatResponse
if err := json.Unmarshal([]byte(data), &geminiResp); err != nil {
log.Errorf("unable to unmarshal gemini response: %v", err)
continue
}
response := g.buildChatCompletionStreamResponse(ctx, &geminiResp)
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
g.appendResponse(responseBuilder, string(responseBody))
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, g.buildChatCompletionStreamResponse)
return modifiedResponseChunk, nil
}

func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down Expand Up @@ -553,7 +527,13 @@ func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCa
return toolCalls
}

func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, geminiResp *geminiChatResponse) *chatCompletionResponse {
func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
var geminiResp geminiChatResponse
if err := json.Unmarshal(chunk, &geminiResp); err != nil {
log.Errorf("unable to unmarshal gemini response: %v", err)
return nil
}

var choice chatCompletionChoice
if len(geminiResp.Candidates) > 0 && len(geminiResp.Candidates[0].Content.Parts) > 0 {
choice.Delta = &chatMessage{Content: geminiResp.Candidates[0].Content.Parts[0].Text}
Expand Down
Loading
Loading