diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
index cf96ab5ede..71e8deccaf 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"strings"
 	"time"
 
 	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -166,38 +165,13 @@ func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiNa
 }
 
 func (b *baiduProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
-	if isLastChunk || len(chunk) == 0 {
-		return nil, nil
-	}
 	// sample event response:
 	// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089502,"sentence_id":0,"is_end":false,"is_truncated":false,"result":"当然可以,","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}
 	// sample end event response:
 	// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089531,"sentence_id":20,"is_end":true,"is_truncated":false,"result":"","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":420,"total_tokens":425}}
-	responseBuilder := &strings.Builder{}
-	lines := strings.Split(string(chunk), "\n")
-	for _, data := range lines {
-		if len(data) < 6 {
-			// ignore blank line or wrong format
-			continue
-		}
-		data = data[6:]
-		var baiduResponse baiduTextGenStreamResponse
-		if err := json.Unmarshal([]byte(data), &baiduResponse); err != nil {
-			log.Errorf("unable to unmarshal baidu response: %v", err)
-			continue
-		}
-		response := b.streamResponseBaidu2OpenAI(ctx, &baiduResponse)
-		responseBody, err := json.Marshal(response)
-		if err != nil {
-			log.Errorf("unable to marshal response: %v", err)
-			return nil, err
-		}
-		b.appendResponse(responseBuilder, string(responseBody))
-	}
-	modifiedResponseChunk := responseBuilder.String()
-	log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
-	return []byte(modifiedResponseChunk), nil
+	modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, b.streamResponseBaidu2OpenAI)
+	return modifiedResponseChunk, nil
 }
 
 func (b *baiduProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -313,15 +287,21 @@ func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *
 	}
 }
 
-func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenStreamResponse) *chatCompletionResponse {
+func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
+	var response baiduTextGenStreamResponse
+	if err := json.Unmarshal(chunk, &response); err != nil {
+		log.Errorf("unable to unmarshal baidu response: %v", err)
+		return nil
+	}
+
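+	// A nil return tells processStreamEvent to drop this event instead of failing the whole stream.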
 	choice := chatCompletionChoice{
-		Index:   0,
-		Message: &chatMessage{Role: roleAssistant, Content: response.Result},
+		Index: 0,
+		Delta: &chatMessage{Role: roleAssistant, Content: response.Result},
 	}
 	if response.IsEnd {
 		choice.FinishReason = finishReasonStop
 	}
-	return &chatCompletionResponse{
+	openAIResponse := &chatCompletionResponse{
 		Id:      response.Id,
 		Created: time.Now().UnixMilli() / 1000,
 		Model:   ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
@@ -334,8 +314,6 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp
 			TotalTokens: response.Usage.TotalTokens,
 		},
 	}
-}
-
-func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
-	responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
+	return openAIResponse
 }
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
index 2f233cc95e..f9c97aa0b0 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go
@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"strings"
 	"time"
 
 	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -227,36 +226,11 @@ func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiN
 }
 
 func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
-	if isLastChunk || len(chunk) == 0 {
-		return nil, nil
-	}
-
-	responseBuilder := &strings.Builder{}
-	lines := strings.Split(string(chunk), "\n")
-	for _, data := range lines {
-		// only process the line starting with "data:"
-		if strings.HasPrefix(data, "data:") {
-			// extract json data from the line
-			jsonData := strings.TrimPrefix(data, "data:")
-			var claudeResponse claudeTextGenStreamResponse
-			if err := json.Unmarshal([]byte(jsonData), &claudeResponse); err != nil {
-				log.Errorf("unable to unmarshal claude response: %v", err)
-				continue
-			}
-			response := c.streamResponseClaude2OpenAI(ctx, &claudeResponse, log)
-			if response != nil {
-				responseBody, err := json.Marshal(response)
-				if err != nil {
-					log.Errorf("unable to marshal response: %v", err)
-					return nil, err
-				}
-				c.appendResponse(responseBuilder, string(responseBody))
-			}
-		}
-	}
-	modifiedResponseChunk := responseBuilder.String()
-	log.Debugf("modified response chunk: %s", modifiedResponseChunk)
-	return []byte(modifiedResponseChunk), nil
+	// Claude's streaming responses look like:
+	// event: content_block_delta
+	// data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"司开发的"} }
+	modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, c.streamResponseClaude2OpenAI)
+	return modifiedResponseChunk, nil
 }
 
 func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRequest) *claudeTextGenRequest {
@@ -324,36 +298,40 @@ func stopReasonClaude2OpenAI(reason *string) string {
 	}
 }
 
-func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse, log wrapper.Log) *chatCompletionResponse {
-	switch origResponse.Type {
+func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
+	var response claudeTextGenStreamResponse
+	if err := json.Unmarshal(chunk, &response); err != nil {
+		log.Errorf("unable to unmarshal claude response: %v", err)
+		return nil
+	}
+
+	var choice chatCompletionChoice
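+	// Each Claude event type maps to a different choice shape; the shared
+	// response envelope is built once after the switch.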
case "message_delta": - choice := chatCompletionChoice{ + choice = chatCompletionChoice{ Index: 0, Delta: &chatMessage{}, - FinishReason: stopReasonClaude2OpenAI(origResponse.Delta.StopReason), + FinishReason: stopReasonClaude2OpenAI(response.Delta.StopReason), } - return createChatCompletionResponse(ctx, origResponse, choice) case "content_block_stop", "message_stop": - log.Debugf("skip processing response type: %s", origResponse.Type) + log.Debugf("skip processing response type: %s", response.Type) return nil default: - log.Errorf("Unexpected response type: %s", origResponse.Type) + log.Errorf("Unexpected response type: %s", response.Type) return nil } + + return createChatCompletionResponse(ctx, &response, choice) } func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextGenStreamResponse, choice chatCompletionChoice) *chatCompletionResponse { @@ -365,7 +343,3 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG Choices: []chatCompletionChoice{choice}, } } - -func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { - responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 2da0ec1e1f..18b30912d8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -214,36 +214,10 @@ func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiN } func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - log.Infof("chunk body:%s", string(chunk)) - if isLastChunk || len(chunk) == 0 { - return nil, nil - } // sample end event response: // data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}} - responseBuilder := &strings.Builder{} - lines := strings.Split(string(chunk), "\n") - for _, data := range lines { - if len(data) < 6 { - // ignore blank line or wrong format - continue - } - data = data[6:] - var geminiResp geminiChatResponse - if err := json.Unmarshal([]byte(data), &geminiResp); err != nil { - log.Errorf("unable to unmarshal gemini response: %v", err) - continue - } - response := g.buildChatCompletionStreamResponse(ctx, &geminiResp) - responseBody, err := json.Marshal(response) - if err != nil { - log.Errorf("unable to marshal response: %v", err) - return nil, err - } - g.appendResponse(responseBuilder, string(responseBody)) - } - modifiedResponseChunk := responseBuilder.String() - log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) - return []byte(modifiedResponseChunk), nil + modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, g.buildChatCompletionStreamResponse) + return modifiedResponseChunk, nil } func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log 
 	var choice chatCompletionChoice
 	if len(geminiResp.Candidates) > 0 && len(geminiResp.Candidates[0].Content.Parts) > 0 {
 		choice.Delta = &chatMessage{Content: geminiResp.Candidates[0].Content.Parts[0].Text}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
index ae342f31b5..a6f15637e4 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go
@@ -1,7 +1,6 @@
 package provider
 
 import (
-	"bytes"
 	"crypto/hmac"
 	"crypto/sha256"
 	"encoding/hex"
@@ -278,15 +277,16 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
 }
 
 func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
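+	// In original-protocol mode the response is passed through untouched, so skip reading the body altogether.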
+	if m.config.protocol == protocolOriginal {
+		ctx.DontReadResponseBody()
+		return types.ActionContinue, nil
+	}
+
 	_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
 	return types.ActionContinue, nil
 }
 
 func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
-	if m.config.protocol == protocolOriginal {
-		return chunk, nil
-	}
-
 	// hunyuan的流式返回:
 	//data: {"Note":"以上内容为AI生成,不代表开发者立场,请勿删除或修改本标记","Choices":[{"Delta":{"Role":"assistant","Content":"有助于"},"FinishReason":""}],"Created":1716359713,"Id":"086b6b19-8b2c-4def-a65c-db6a7bc86acd","Usage":{"PromptTokens":7,"CompletionTokens":145,"TotalTokens":152}}
@@ -295,107 +295,58 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
 
 	// log.Debugf("#debug nash5# [OnStreamingResponseBody] chunk is: %s", string(chunk))
 
-	// 从上下文获取现有缓冲区数据
-	newBufferedBody := chunk
-	if bufferedBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
-		newBufferedBody = append(bufferedBody, chunk...)
-	}
-
-	// 初始化处理下标,以及将要返回的处理过的chunks
-	var newEventPivot = -1
-	var outputBuffer []byte
-
-	// 从buffer区取出若干完整的chunk,将其转为openAI格式后返回
-	// 处理可能包含多个事件的缓冲区
-	for {
-		eventStartIndex := bytes.Index(newBufferedBody, []byte(ssePrefix))
-		if eventStartIndex == -1 {
-			break // 没有找到新事件,跳出循环
-		}
-
-		// 移除缓冲区前面非事件部分
-		newBufferedBody = newBufferedBody[eventStartIndex+len(ssePrefix):]
-
-		// 查找事件结束的位置(即下一个事件的开始)
-		newEventPivot = bytes.Index(newBufferedBody, []byte("\n\n"))
-		if newEventPivot == -1 && !isLastChunk {
-			// 未找到事件结束标识,跳出循环等待更多数据,若是最后一个chunk,不一定有2个换行符
-			break
-		}
-
-		// 提取并处理一个完整的事件
-		eventData := newBufferedBody[:newEventPivot]
-		// log.Debugf("@@@ <<< ori chun is: %s", string(newBufferedBody[:newEventPivot]))
-		newBufferedBody = newBufferedBody[newEventPivot+2:] // 跳过结束标识
-
-		// 转换并追加到输出缓冲区
-		convertedData, _ := m.convertChunkFromHunyuanToOpenAI(ctx, eventData, log)
-		// log.Debugf("@@@ >>> converted one chunk: %s", string(convertedData))
-		outputBuffer = append(outputBuffer, convertedData...)
-	}
-
-	// 刷新剩余的不完整事件回到上下文缓冲区以便下次继续处理
-	ctx.SetContext(ctxKeyStreamingBody, newBufferedBody)
-
-	log.Debugf("=== modified response chunk: %s", string(outputBuffer))
-	return outputBuffer, nil
+	modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, m.convertChunkFromHunyuanToOpenAI)
+	return modifiedResponseChunk, nil
 }
 
-func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContext, hunyuanChunk []byte, log wrapper.Log) ([]byte, error) {
+func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
 	// 将hunyuan的chunk转为openai的chunk
-	hunyuanFormattedChunk := &hunyuanTextGenDetailedResponseNonStreaming{}
-	if err := json.Unmarshal(hunyuanChunk, hunyuanFormattedChunk); err != nil {
-		return []byte(""), nil
+	response := &hunyuanTextGenDetailedResponseNonStreaming{}
+	if err := json.Unmarshal(chunk, response); err != nil {
+		return nil
 	}
 
-	openAIFormattedChunk := &chatCompletionResponse{
-		Id:      hunyuanFormattedChunk.Id,
+	openAIResponse := &chatCompletionResponse{
+		Id:      response.Id,
 		Created: time.Now().UnixMilli() / 1000,
 		Model:   ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
 		SystemFingerprint: "",
 		Object:            objectChatCompletionChunk,
 		Usage: usage{
-			PromptTokens:     hunyuanFormattedChunk.Usage.PromptTokens,
-			CompletionTokens: hunyuanFormattedChunk.Usage.CompletionTokens,
-			TotalTokens:      hunyuanFormattedChunk.Usage.TotalTokens,
+			PromptTokens:     response.Usage.PromptTokens,
+			CompletionTokens: response.Usage.CompletionTokens,
+			TotalTokens:      response.Usage.TotalTokens,
 		},
 	}
 
 	// tmpStr3, _ := json.Marshal(hunyuanFormattedChunk)
 	// log.Debugf("@@@ --- 源数据是:: %s", tmpStr3)
 
 	// 是否为最后一个chunk?
-	if hunyuanFormattedChunk.Choices[0].FinishReason == hunyuanStreamEndMark {
+	if response.Choices[0].FinishReason == hunyuanStreamEndMark {
 		// log.Debugf("@@@ --- 最后chunk: ")
-		openAIFormattedChunk.Choices = append(openAIFormattedChunk.Choices, chatCompletionChoice{
-			FinishReason: hunyuanFormattedChunk.Choices[0].FinishReason,
+		openAIResponse.Choices = append(openAIResponse.Choices, chatCompletionChoice{
+			FinishReason: response.Choices[0].FinishReason,
 		})
 	} else {
 		deltaMsg := chatMessage{
 			Name:      "",
-			Role:      hunyuanFormattedChunk.Choices[0].Delta.Role,
-			Content:   hunyuanFormattedChunk.Choices[0].Delta.Content,
+			Role:      response.Choices[0].Delta.Role,
+			Content:   response.Choices[0].Delta.Content,
 			ToolCalls: []toolCall{},
 		}
 		// tmpStr2, _ := json.Marshal(deltaMsg)
 		// log.Debugf("@@@ --- 中间chunk: choices.chatMsg 是: %s", tmpStr2)
-		openAIFormattedChunk.Choices = append(
-			openAIFormattedChunk.Choices,
+		openAIResponse.Choices = append(
+			openAIResponse.Choices,
 			chatCompletionChoice{Delta: &deltaMsg},
 		)
 		// tmpStr, _ := json.Marshal(openAIFormattedChunk.Choices)
 		// log.Debugf("@@@ --- 中间chunk: choices 是: %s", tmpStr)
 	}
 
-	// 返回的格式
-	openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
-	var openAIChunk strings.Builder
-	openAIChunk.WriteString(ssePrefix)
-	openAIChunk.WriteString(string(openAIFormattedChunkBytes))
-	openAIChunk.WriteString("\n\n")
-
-	return []byte(openAIChunk.String()), nil
+	return openAIResponse
 }
 
 func (m *hunyuanProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
index 03a1d85a02..0ac294f3fe 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go
@@ -4,8 +4,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"strings"
-
 	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
 	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
 	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
@@ -240,51 +238,19 @@ func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName Api
 
 // OnStreamingResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应
 func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
-	if isLastChunk || len(chunk) == 0 {
-		return nil, nil
-	}
 	// sample event response:
 	// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}
 	// sample end event response:
 	// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"I am from China.","choices":[{"finish_reason":"stop","messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"I am from China."}]}],"usage":{"total_tokens":187},"input_sensitive":false,"output_sensitive":false,"id":"0106b3bc9fd844a9f3de1aa06004e2ab","base_resp":{"status_code":0,"status_msg":""}}
-	responseBuilder := &strings.Builder{}
-	lines := strings.Split(string(chunk), "\n")
-	for _, data := range lines {
-		if len(data) < 6 {
-			// ignore blank line or wrong format
-			continue
-		}
-		data = data[6:]
-		var minimaxResp minimaxChatCompletionV2Resp
-		if err := json.Unmarshal([]byte(data), &minimaxResp); err != nil {
-			log.Errorf("unable to unmarshal minimax response: %v", err)
-			continue
-		}
-		response := m.responseV2ToOpenAI(&minimaxResp)
-		responseBody, err := json.Marshal(response)
-		if err != nil {
-			log.Errorf("unable to marshal response: %v", err)
-			return nil, err
-		}
-		m.appendResponse(responseBuilder, string(responseBody))
-	}
-	modifiedResponseChunk := responseBuilder.String()
-	log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
-	return []byte(modifiedResponseChunk), nil
+	modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, m.responseV2ToOpenAI)
+	return modifiedResponseChunk, nil
 }
 
 // OnResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应
 func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
-	minimaxResp := &minimaxChatCompletionV2Resp{}
-	if err := json.Unmarshal(body, minimaxResp); err != nil {
-		return types.ActionContinue, fmt.Errorf("unable to unmarshal minimax response: %v", err)
-	}
-	if minimaxResp.BaseResp.StatusCode != 0 {
-		return types.ActionContinue, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
-	}
-	response := m.responseV2ToOpenAI(minimaxResp)
-	return types.ActionContinue, replaceJsonResponseBody(response, log)
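+	// responseV2ToOpenAI unmarshals the raw body itself, so the same converter now serves the streaming and non-streaming paths.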
+	modifiedResponseChunk := m.responseV2ToOpenAI(ctx, body, log)
+	return types.ActionContinue, replaceJsonResponseBody(modifiedResponseChunk, log)
 }
 
 // minimaxChatCompletionV2Request 表示ChatCompletion V2请求的结构体
@@ -441,7 +407,13 @@ func (m *minimaxProvider) buildMinimaxChatCompletionV2Request(request *chatCompl
 	return result
 }
 
-func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Resp) *chatCompletionResponse {
+func (m *minimaxProvider) responseV2ToOpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
+	var response minimaxChatCompletionV2Resp
+	if err := json.Unmarshal(chunk, &response); err != nil {
+		log.Errorf("unable to unmarshal minimax response: %v", err)
+		return nil
+	}
+
 	var choices []chatCompletionChoice
 	messageIndex := 0
 	for _, choice := range response.Choices {
@@ -470,7 +442,3 @@ func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Re
 		},
 	}
 }
-
-func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
-	responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
-}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go
index 1366716599..c2e3dad509 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go
@@ -1,18 +1,7 @@
 package provider
 
-import "strings"
-
 const (
-	streamEventIdItemKey        = "id:"
-	streamEventNameItemKey      = "event:"
-	streamBuiltInItemKey        = ":"
-	streamHttpStatusValuePrefix = "HTTP_STATUS/"
-	streamDataItemKey           = "data:"
-	streamEndDataValue          = "[DONE]"
-
-	eventResult = "result"
-
-	httpStatus200 = "200"
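+	// Only the "data:" prefix is still needed; SSE event framing now lives in processStreamEvent.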
+	streamDataItemKey = "data:"
 )
 
 type chatCompletionRequest struct {
@@ -119,28 +108,6 @@ func (m *functionCall) IsEmpty() bool {
 	return m.Name == "" && m.Arguments == ""
 }
 
-type streamEvent struct {
-	Id         string `json:"id"`
-	Event      string `json:"event"`
-	Data       string `json:"data"`
-	HttpStatus string `json:"http_status"`
-}
-
-func (e *streamEvent) setValue(key, value string) {
-	switch key {
-	case streamEventIdItemKey:
-		e.Id = value
-	case streamEventNameItemKey:
-		e.Event = value
-	case streamDataItemKey:
-		e.Data = value
-	case streamBuiltInItemKey:
-		if strings.HasPrefix(value, streamHttpStatusValuePrefix) {
-			e.HttpStatus = value[len(streamHttpStatusValuePrefix):]
-		}
-	}
-}
-
 type embeddingsRequest struct {
 	Input interface{} `json:"input"`
 	Model string      `json:"model"`
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
index 65b3015266..3cf2661fe1 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go
@@ -224,24 +224,6 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
 		return chunk, nil
 	}
 
-	receivedBody := chunk
-	if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
-		receivedBody = append(bufferedStreamingBody, chunk...)
-	}
-
-	incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false)
-
-	eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
-
-	defer func() {
-		if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) {
-			// Just in case the received chunk is not a complete event.
-			ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:])
-		} else {
-			ctx.SetContext(ctxKeyStreamingBody, nil)
-		}
-	}()
-
 	// Sample Qwen event response:
 	//
 	// event:result
@@ -252,56 +234,8 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
 	// :HTTP_STATUS/400
 	// data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"}
 	//
-
-	var responseBuilder strings.Builder
-	currentKey := ""
-	currentEvent := &streamEvent{}
-	i, length := 0, len(receivedBody)
-	for i = 0; i < length; i++ {
-		ch := receivedBody[i]
-		if ch != '\n' {
-			if lineStartIndex == -1 {
-				if eventStartIndex == -1 {
-					eventStartIndex = i
-				}
-				lineStartIndex = i
-				valueStartIndex = -1
-			}
-			if valueStartIndex == -1 {
-				if ch == ':' {
-					valueStartIndex = i + 1
-					currentKey = string(receivedBody[lineStartIndex:valueStartIndex])
-				}
-			} else if valueStartIndex == i && ch == ' ' {
-				// Skip leading spaces in data.
-				valueStartIndex = i + 1
-			}
-			continue
-		}
-
-		if lineStartIndex != -1 {
-			value := string(receivedBody[valueStartIndex:i])
-			currentEvent.setValue(currentKey, value)
-		} else {
-			// Extra new line. The current event is complete.
-			log.Debugf("processing event: %v", currentEvent)
-			if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, incrementalStreaming, log); err != nil {
-				return nil, err
-			}
-			// Reset event parsing state.
-			eventStartIndex = -1
-			currentEvent = &streamEvent{}
-		}
-
-		// Reset line parsing state.
-		lineStartIndex = -1
-		valueStartIndex = -1
-		currentKey = ""
-	}
-
-	modifiedResponseChunk := responseBuilder.String()
-	log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
-	return []byte(modifiedResponseChunk), nil
+	modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, m.buildChatCompletionStreamingResponse)
+	return modifiedResponseChunk, nil
 }
 
 func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -396,7 +330,15 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
 	}
 }
 
-func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool, log wrapper.Log) []*chatCompletionResponse {
+func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []*chatCompletionResponse {
+	qwenResponse := &qwenTextGenResponse{}
+	if err := json.Unmarshal(chunk, qwenResponse); err != nil {
+		log.Errorf("unable to unmarshal Qwen response: %v", err)
+		return nil
+	}
+
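+	// incrementalStreaming used to arrive as a parameter from the removed convertStreamEvent; it is now read from the request context.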
+	incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false)
+
 	baseMessage := chatCompletionResponse{
 		Id:      qwenResponse.RequestId,
 		Created: time.Now().UnixMilli() / 1000,
@@ -472,39 +414,6 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
 	return responses
 }
 
-func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, incrementalStreaming bool, log wrapper.Log) error {
-	if event.Data == streamEndDataValue {
-		m.appendStreamEvent(responseBuilder, event)
-		return nil
-	}
-
-	if event.Event != eventResult || event.HttpStatus != httpStatus200 {
-		// Something goes wrong. Just pass through the event.
-		m.appendStreamEvent(responseBuilder, event)
-		return nil
-	}
-
-	qwenResponse := &qwenTextGenResponse{}
-	if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil {
-		log.Errorf("unable to unmarshal Qwen response: %v", err)
-		return fmt.Errorf("unable to unmarshal Qwen response: %v", err)
-	}
-
-	responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
-	for _, response := range responses {
-		responseBody, err := json.Marshal(response)
-		if err != nil {
-			log.Errorf("unable to marshal response: %v", err)
-			return fmt.Errorf("unable to marshal response: %v", err)
-		}
-		modifiedEvent := &*event
-		modifiedEvent.Data = string(responseBody)
-		m.appendStreamEvent(responseBuilder, modifiedEvent)
-	}
-
-	return nil
-}
-
 func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string, onlyOneSystemBeforeFile bool) int {
 	fileMessage := qwenMessage{
 		Role:    roleSystem,
@@ -539,12 +448,6 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
 	}
 }
 
-func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
-	responseBuilder.WriteString(streamDataItemKey)
-	responseBuilder.WriteString(event.Data)
-	responseBuilder.WriteString("\n\n")
-}
-
 func (m *qwenProvider) buildQwenTextEmbeddingRequest(request *embeddingsRequest) (*qwenTextEmbeddingRequest, error) {
 	var texts []string
 	if str, isString := request.Input.(string); isString {
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go
index 19060849ac..be24e83054 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go
@@ -1,8 +1,10 @@
 package provider
 
 import (
+	"bytes"
 	"encoding/json"
 	"fmt"
+	"strings"
 
 	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
 	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
@@ -62,3 +64,98 @@ func replaceJsonResponseBody(response interface{}, log wrapper.Log) error {
 	}
 	return err
 }
+
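+// chatCompletionResponseConverter is expected to hold one of two converter shapes:
+//   func(wrapper.HttpContext, []byte, wrapper.Log) *chatCompletionResponse
+//   func(wrapper.HttpContext, []byte, wrapper.Log) []*chatCompletionResponse
+// processStreamEvent dispatches on the concrete type via the switch below.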
+type chatCompletionResponseConverter interface{}
+
+// processStreamEvent pulls the buffered data out of the context, appends the new chunk to it,
+// and then converts every complete event found in the buffer.
+func processStreamEvent(
+	ctx wrapper.HttpContext,
+	chunk []byte, isLastChunk bool,
+	log wrapper.Log,
+	streamResponseConvertFunc chatCompletionResponseConverter) []byte {
+
+	if isLastChunk || len(chunk) == 0 {
+		return nil
+	}
+	// Fetch the existing buffer from the context and append the new chunk to it
+	newBufferedBody := chunk
+	if bufferedBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
+		newBufferedBody = append(bufferedBody, chunk...)
+	}
+
+	// Initialize the parsing pivot and the output buffer to be returned
+	var newEventPivot = -1
+	var outputBuffer []byte
+
+	// Pull complete events out of the buffer and convert them to the OpenAI format;
+	// the buffer may contain multiple events
+	for {
+		eventStartIndex := bytes.Index(newBufferedBody, []byte(streamDataItemKey))
+		if eventStartIndex == -1 {
+			break // no new event found, exit the loop
+		}
+
+		// Drop everything in front of the event payload
+		newBufferedBody = newBufferedBody[eventStartIndex+len(streamDataItemKey):]
+
+		// Find the end of the event, i.e. the start of the next one
+		newEventPivot = bytes.Index(newBufferedBody, []byte("\n\n"))
+		if newEventPivot == -1 {
+			// No event terminator found: exit the loop and wait for more data; the final chunk may not carry two newlines
+			break
+		}
+
+		// Extract one complete event
+		eventData := newBufferedBody[:newEventPivot]
+		newBufferedBody = newBufferedBody[newEventPivot+2:] // skip the terminator
+
+		// Convert the event and append it to the output buffer
+		switch fn := streamResponseConvertFunc.(type) {
+		case func(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse:
+			if openAIResponse := fn(ctx, eventData, log); openAIResponse != nil {
+				convertedData, err := appendOpenAIChunk(openAIResponse, log)
+				if err != nil {
+					log.Errorf("failed to append openAI chunk: %v", err)
+				}
+				outputBuffer = append(outputBuffer, convertedData...)
+			}
+		// A single Qwen event may convert into multiple OpenAI chunks
+		case func(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []*chatCompletionResponse:
+			if openAIResponses := fn(ctx, eventData, log); openAIResponses != nil {
+				for _, response := range openAIResponses {
+					convertedData, err := appendOpenAIChunk(response, log)
+					if err != nil {
+						log.Errorf("failed to append openAI chunk: %v", err)
+					}
+					outputBuffer = append(outputBuffer, convertedData...)
+				}
+			}
+		default:
+			log.Errorf("unsupported streamResponseConvertFunc type")
+			return nil
+		}
+	}
+
+	// Flush the remaining incomplete event back into the context buffer so the next call can finish it
+	ctx.SetContext(ctxKeyStreamingBody, newBufferedBody)
+	log.Debugf("=== modified response chunk: %s", string(outputBuffer))
+
+	return outputBuffer
+}
+
+func appendOpenAIChunk(openAIResponse *chatCompletionResponse, log wrapper.Log) ([]byte, error) {
+	openAIFormattedChunk, err := json.Marshal(openAIResponse)
+	if err != nil {
+		log.Errorf("unable to marshal response: %v", err)
+		return nil, err
+	}
+
+	var responseBuilder strings.Builder
+	appendResponse(&responseBuilder, string(openAIFormattedChunk))
+
+	return []byte(responseBuilder.String()), nil
+}
+
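+// appendResponse frames a converted payload as an SSE event: "data: {...}\n\n".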
+func appendResponse(responseBuilder *strings.Builder, responseBody string) {
+	responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
index fc266dfbaa..050cbdb12d 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go
@@ -133,37 +133,8 @@ func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
 }
 
 func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
-	if isLastChunk || len(chunk) == 0 {
-		return nil, nil
-	}
-	responseBuilder := &strings.Builder{}
-	lines := strings.Split(string(chunk), "\n")
-	for _, data := range lines {
-		if len(data) < 6 {
-			// ignore blank line or wrong format
-			continue
-		}
-		data = data[6:]
-		// The final response is `data: [DONE]`
-		if data == "[DONE]" {
-			continue
-		}
-		var sparkResponse sparkStreamResponse
-		if err := json.Unmarshal([]byte(data), &sparkResponse); err != nil {
-			log.Errorf("unable to unmarshal spark response: %v", err)
-			continue
-		}
-		response := p.streamResponseSpark2OpenAI(ctx, &sparkResponse)
-		responseBody, err := json.Marshal(response)
-		if err != nil {
-			log.Errorf("unable to marshal response: %v", err)
-			return nil, err
-		}
-		p.appendResponse(responseBuilder, string(responseBody))
-	}
-	modifiedResponseChunk := responseBuilder.String()
-	log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
-	return []byte(modifiedResponseChunk), nil
+	modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, p.streamResponseSpark2OpenAI)
+	return modifiedResponseChunk, nil
 }
 
 func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkResponse) *chatCompletionResponse {
@@ -184,7 +155,13 @@ func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *
 	}
 }
 
-func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkStreamResponse) *chatCompletionResponse {
+func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
+	var response sparkStreamResponse
+	if err := json.Unmarshal(chunk, &response); err != nil {
+		log.Errorf("unable to unmarshal spark response: %v", err)
+		return nil
+	}
+
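+	// The terminating "data: [DONE]" sentinel is no longer skipped explicitly: it fails to unmarshal and is dropped via the nil return above.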
 	choices := make([]chatCompletionChoice, len(response.Choices))
 	for idx, c := range response.Choices {
 		choices[idx] = chatCompletionChoice{