Skip to content

Commit

Permalink
support spark and gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
cr7258 committed Aug 9, 2024
1 parent e25ce11 commit b03e18b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 61 deletions.
38 changes: 9 additions & 29 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,36 +214,10 @@ func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiN
}

// OnStreamingResponseBody turns one streamed Gemini response chunk into the
// bytes handed back to the caller.
//
// NOTE(review): this span comes from a rendered commit diff with the +/- markers
// stripped. Going by the "9 additions & 29 deletions" header, the inline parsing
// from log.Infof down to the first `return []byte(modifiedResponseChunk), nil`
// looks like the removed implementation, and the trailing processStreamEvent
// call looks like its replacement — confirm against the real gemini.go.
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
// Logs the raw chunk at info level before any processing.
log.Infof("chunk body:%s", string(chunk))
// The last chunk and empty chunks produce no output and no error.
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
// sample end event response:
// data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}}
responseBuilder := &strings.Builder{}
// Each line of the chunk is handled independently.
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
continue
}
// Drops the first 6 bytes of the line; this matches the length of the
// `data: ` prefix shown in the sample above, but the prefix itself is
// never checked.
data = data[6:]
var geminiResp geminiChatResponse
// A line that is not valid JSON is logged and skipped; the loop continues.
if err := json.Unmarshal([]byte(data), &geminiResp); err != nil {
log.Errorf("unable to unmarshal gemini response: %v", err)
continue
}
response := g.buildChatCompletionStreamResponse(ctx, &geminiResp)
responseBody, err := json.Marshal(response)
// Unlike an unmarshal failure, a marshal failure aborts the whole chunk
// and discards anything already accumulated in responseBuilder.
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
// appendResponse is defined outside this view; presumably it writes the
// event framing around the JSON body — verify against its definition.
g.appendResponse(responseBuilder, string(responseBody))
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
// processStreamEvent is defined outside this view. It receives the raw chunk
// and g.buildChatCompletionStreamResponse as a callback; the error result
// returned alongside its output here is always nil.
modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, g.buildChatCompletionStreamResponse)
return modifiedResponseChunk, nil
}

func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down Expand Up @@ -553,7 +527,13 @@ func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCa
return toolCalls
}

func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, geminiResp *geminiChatResponse) *chatCompletionResponse {
func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
var geminiResp geminiChatResponse
if err := json.Unmarshal(chunk, &geminiResp); err != nil {
log.Errorf("unable to unmarshal gemini response: %v", err)
return nil
}

var choice chatCompletionChoice
if len(geminiResp.Candidates) > 0 && len(geminiResp.Candidates[0].Content.Parts) > 0 {
choice.Delta = &chatMessage{Content: geminiResp.Candidates[0].Content.Parts[0].Text}
Expand Down
41 changes: 9 additions & 32 deletions plugins/wasm-go/extensions/ai-proxy/provider/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,37 +133,8 @@ func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
}

// OnStreamingResponseBody turns one streamed Spark response chunk into the
// bytes handed back to the caller.
//
// NOTE(review): this span comes from a rendered commit diff with the +/- markers
// stripped. Going by the "9 additions & 32 deletions" header, the inline parsing
// down to the first `return []byte(modifiedResponseChunk), nil` looks like the
// removed implementation, and the trailing processStreamEvent call looks like
// its replacement — confirm against the real spark.go.
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
// The last chunk and empty chunks produce no output and no error.
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
responseBuilder := &strings.Builder{}
// Each line of the chunk is handled independently.
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
continue
}
// Drops the first 6 bytes of the line; this is the length of the `data: `
// prefix mentioned below, but the prefix itself is never checked.
data = data[6:]
// The final response is `data: [DONE]`
if data == "[DONE]" {
continue
}
var sparkResponse sparkStreamResponse
// A line that is not valid JSON is logged and skipped; the loop continues.
if err := json.Unmarshal([]byte(data), &sparkResponse); err != nil {
log.Errorf("unable to unmarshal spark response: %v", err)
continue
}
response := p.streamResponseSpark2OpenAI(ctx, &sparkResponse)
responseBody, err := json.Marshal(response)
// Unlike an unmarshal failure, a marshal failure aborts the whole chunk
// and discards anything already accumulated in responseBuilder.
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
// appendResponse is defined outside this view; presumably it writes the
// event framing around the JSON body — verify against its definition.
p.appendResponse(responseBuilder, string(responseBody))
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
// processStreamEvent is defined outside this view. It receives the raw chunk
// and p.streamResponseSpark2OpenAI as a callback; the error result returned
// alongside its output here is always nil.
modifiedResponseChunk := processStreamEvent(ctx, chunk, isLastChunk, log, p.streamResponseSpark2OpenAI)
return modifiedResponseChunk, nil
}

func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkResponse) *chatCompletionResponse {
Expand All @@ -184,7 +155,13 @@ func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *
}
}

func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkStreamResponse) *chatCompletionResponse {
func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) *chatCompletionResponse {
var response sparkStreamResponse
if err := json.Unmarshal(chunk, &response); err != nil {
log.Errorf("unable to unmarshal spark response: %v", err)
return nil
}

choices := make([]chatCompletionChoice, len(response.Choices))
for idx, c := range response.Choices {
choices[idx] = chatCompletionChoice{
Expand Down

0 comments on commit b03e18b

Please sign in to comment.