diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 2da0ec1e1f..18b30912d8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -214,36 +214,10 @@ func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiN } func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - log.Infof("chunk body:%s", string(chunk)) - if isLastChunk || len(chunk) == 0 { - return nil, nil - } // sample end event response: // data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}} - responseBuilder := &strings.Builder{} - lines := strings.Split(string(chunk), "\n") - for _, data := range lines { - if len(data) < 6 { - // ignore blank line or wrong format - continue - } - data = data[6:] - var geminiResp geminiChatResponse - if err := json.Unmarshal([]byte(data), &geminiResp); err != nil { - log.Errorf("unable to unmarshal gemini response: %v", err) - continue - } - response := g.buildChatCompletionStreamResponse(ctx, &geminiResp) - responseBody, err := json.Marshal(response) - if err != nil { - log.Errorf("unable to marshal response: %v", err) - return nil, err - } - g.appendResponse(responseBuilder, string(responseBody)) - } - modifiedResponseChunk := responseBuilder.String() - log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) - return []byte(modifiedResponseChunk), nil + modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, g.buildChatCompletionStreamResponse) + return modifiedResponseChunk, nil } func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -553,7 +527,13 @@ func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCa return toolCalls } -func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, geminiResp *geminiChatResponse) *chatCompletionResponse { +func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse { + var geminiResp geminiChatResponse + if err := json.Unmarshal(chunk, &geminiResp); err != nil { + log.Errorf("unable to unmarshal gemini response: %v", err) + return nil + } + var choice chatCompletionChoice if len(geminiResp.Candidates) > 0 && len(geminiResp.Candidates[0].Content.Parts) > 0 { choice.Delta = &chatMessage{Content: geminiResp.Candidates[0].Content.Parts[0].Text} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index fc266dfbaa..050cbdb12d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -133,37 +133,8 @@ func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, } func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - if isLastChunk || len(chunk) == 0 { - return nil, nil - } - responseBuilder := &strings.Builder{} - lines := strings.Split(string(chunk), "\n") - for _, data := range lines { - if len(data) < 6 { - // ignore blank line or wrong format - continue - } - data = data[6:] - // The final response is `data: [DONE]` - if data == "[DONE]" { - continue - } - var sparkResponse sparkStreamResponse - if err := json.Unmarshal([]byte(data), &sparkResponse); err != nil { - log.Errorf("unable to unmarshal spark response: %v", err) - continue - } - response := p.streamResponseSpark2OpenAI(ctx, &sparkResponse) - responseBody, err := json.Marshal(response) - if err != nil { - log.Errorf("unable to marshal response: %v", err) - return nil, err - } - p.appendResponse(responseBuilder, string(responseBody)) - } - modifiedResponseChunk := responseBuilder.String() - log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) - return []byte(modifiedResponseChunk), nil + modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, p.streamResponseSpark2OpenAI) + return modifiedResponseChunk, nil } func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkResponse) *chatCompletionResponse { @@ -184,7 +155,13 @@ func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response * } } -func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkStreamResponse) *chatCompletionResponse { +func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse { + var response sparkStreamResponse + if err := json.Unmarshal(chunk, &response); err != nil { + log.Errorf("unable to unmarshal spark response: %v", err) + return nil + } + choices := make([]chatCompletionChoice, len(response.Choices)) for idx, c := range response.Choices { choices[idx] = chatCompletionChoice{