Skip to content

Commit

Permalink
Add cohere text embedding
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Jan 24, 2025
1 parent 16cbdfb commit afe7892
Show file tree
Hide file tree
Showing 8 changed files with 687 additions and 49 deletions.
147 changes: 147 additions & 0 deletions internal/util/function/cohere_embedding_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/

package function

import (
"fmt"
"os"
"strings"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/cohere"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type CohereEmbeddingProvider struct {
fieldDim int64

client *cohere.CohereEmbedding
modelName string
truncate string

maxBatch int
timeoutSec int64
}

func createCohereEmbeddingClient(apiKey string, url string) (*cohere.CohereEmbedding, error) {
if apiKey == "" {
apiKey = os.Getenv(cohereAIAKEnvStr)
}

Check warning on line 45 in internal/util/function/cohere_embedding_provider.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/cohere_embedding_provider.go#L44-L45

Added lines #L44 - L45 were not covered by tests
if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", cohereAIAKEnvStr)
}

Check warning on line 48 in internal/util/function/cohere_embedding_provider.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/cohere_embedding_provider.go#L47-L48

Added lines #L47 - L48 were not covered by tests

if url == "" {
url = "https://api.cohere.com/v2/embed"
}

c := cohere.NewCohereEmbeddingClient(apiKey, url)
return c, nil
}

func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*CohereEmbeddingProvider, error) {
fieldDim, err := typeutil.GetDim(fieldSchema)
if err != nil {
return nil, err
}

Check warning on line 62 in internal/util/function/cohere_embedding_provider.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/cohere_embedding_provider.go#L61-L62

Added lines #L61 - L62 were not covered by tests
var apiKey, url, modelName string
truncate := "END"
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case modelNameParamKey:
modelName = param.Value
case apiKeyParamKey:
apiKey = param.Value
case embeddingURLParamKey:
url = param.Value
case truncateParamKey:
if param.Value != "NONE" && param.Value != "START" && param.Value != "END" {
return nil, fmt.Errorf("")
}
truncate = param.Value
default:
}
}

if modelName != embedEnglishV30 && modelName != embedMultilingualV30 && modelName != embedEnglishLightV30 && modelName != embedMultilingualLightV30 && modelName != embedEnglishV20 && modelName != embedEnglishLightV20 && modelName != embedMultilingualV20 {
return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]",
modelName, embedEnglishV30, embedMultilingualV30, embedEnglishLightV30, embedMultilingualLightV30, embedEnglishV20, embedEnglishLightV20, embedMultilingualV20)
}

c, err := createCohereEmbeddingClient(apiKey, url)
if err != nil {
return nil, err
}

Check warning on line 90 in internal/util/function/cohere_embedding_provider.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/cohere_embedding_provider.go#L89-L90

Added lines #L89 - L90 were not covered by tests

provider := CohereEmbeddingProvider{
client: c,
fieldDim: fieldDim,
modelName: modelName,
truncate: truncate,
maxBatch: 96,
timeoutSec: 30,
}
return &provider, nil
}

func (provider *CohereEmbeddingProvider) MaxBatch() int {
return 5 * provider.maxBatch

Check warning on line 104 in internal/util/function/cohere_embedding_provider.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/cohere_embedding_provider.go#L103-L104

Added lines #L103 - L104 were not covered by tests
}

func (provider *CohereEmbeddingProvider) FieldDim() int64 {
return provider.fieldDim

Check warning on line 108 in internal/util/function/cohere_embedding_provider.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/cohere_embedding_provider.go#L107-L108

Added lines #L107 - L108 were not covered by tests
}

// Specifies the type of input passed to the model. Required for embedding models v3 and higher.
func (provider *CohereEmbeddingProvider) getInputType(mode TextEmbeddingMode) string {
if provider.modelName == embedEnglishV20 || provider.modelName == embedEnglishLightV20 || provider.modelName == embedMultilingualV20 {
return ""
}
if mode == InsertMode {
return "search_document" // Used for embeddings stored in a vector database for search use-cases.
}
return "search_query" // Used for embeddings of search queries run against a vector DB to find relevant documents.
}

func (provider *CohereEmbeddingProvider) CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error) {
numRows := len(texts)
inputType := provider.getInputType(mode)
data := make([][]float32, 0, numRows)
for i := 0; i < numRows; i += provider.maxBatch {
end := i + provider.maxBatch
if end > numRows {
end = numRows
}
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], inputType, "float", provider.truncate, provider.timeoutSec)
if err != nil {
return nil, err
}

Check warning on line 134 in internal/util/function/cohere_embedding_provider.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/cohere_embedding_provider.go#L133-L134

Added lines #L133 - L134 were not covered by tests
if end-i != len(resp.Embeddings.Float) {
return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Embeddings.Float))
}
for _, embedding := range resp.Embeddings.Float {
if len(embedding) != int(provider.fieldDim) {
return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
provider.fieldDim, len(embedding))
}

Check warning on line 142 in internal/util/function/cohere_embedding_provider.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/cohere_embedding_provider.go#L140-L142

Added lines #L140 - L142 were not covered by tests
}
data = append(data, resp.Embeddings.Float...)
}
return data, nil
}
226 changes: 226 additions & 0 deletions internal/util/function/cohere_embedding_provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/

package function

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/suite"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models/ali"
)

func TestCohereTextEmbeddingProvider(t *testing.T) {
suite.Run(t, new(CohereTextEmbeddingProviderSuite))
}

type CohereTextEmbeddingProviderSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
providers []string
}

func (s *CohereTextEmbeddingProviderSuite) SetupTest() {
s.schema = &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
s.providers = []string{cohereProvider}
}

func createCohereProvider(url string, schema *schemapb.FieldSchema, providerName string) (textEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: embedEnglishLightV20},
{Key: apiKeyParamKey, Value: "mock"},
{Key: embeddingURLParamKey, Value: url},
},
}
switch providerName {
case cohereProvider:
return NewCohereEmbeddingProvider(schema, functionSchema)
default:
return nil, fmt.Errorf("Unknow provider")
}
}

func (s *CohereTextEmbeddingProviderSuite) TestEmbedding() {
ts := CreateCohereEmbeddingServer()

defer ts.Close()
for _, provderName := range s.providers {
provder, err := createCohereProvider(ts.URL, s.schema.Fields[2], provderName)
s.NoError(err)
{
data := []string{"sentence"}
ret, err2 := provder.CallEmbedding(data, InsertMode)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(4, len(ret[0]))
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0])
}
{
data := []string{"sentence 1", "sentence 2", "sentence 3"}
ret, _ := provder.CallEmbedding(data, SearchMode)
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret)
}
}
}

func (s *CohereTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res ali.EmbeddingResponse
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
TextIndex: 0,
})

res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
Embedding: []float32{1.0, 1.0},
TextIndex: 1,
})
res.Usage = ali.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))

defer ts.Close()
for _, providerName := range s.providers {
provder, err := createCohereProvider(ts.URL, s.schema.Fields[2], providerName)
s.NoError(err)

// embedding dim not match
data := []string{"sentence", "sentence"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}

func (s *CohereTextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var res ali.EmbeddingResponse
res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{
Embedding: []float32{1.0, 1.0, 1.0, 1.0},
TextIndex: 0,
})
res.Usage = ali.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))

defer ts.Close()
for _, provderName := range s.providers {
provder, err := createCohereProvider(ts.URL, s.schema.Fields[2], provderName)

s.NoError(err)

// embedding dim not match
data := []string{"sentence", "sentence2"}
_, err2 := provder.CallEmbedding(data, InsertMode)
s.Error(err2)
}
}

func (s *CohereTextEmbeddingProviderSuite) TestNewCohereProvider() {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: embedEnglishLightV20},
{Key: apiKeyParamKey, Value: "mock"},
},
}

provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.NoError(err)
s.Equal(provider.truncate, "END")

functionSchema.Params = append(functionSchema.Params, &commonpb.KeyValuePair{Key: truncateParamKey, Value: "START"})
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.NoError(err)
s.Equal(provider.truncate, "START")

// Invalid truncateParam
functionSchema.Params[2].Value = "Unknow"
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)

// Invalid ModelName
functionSchema.Params[2].Value = "END"
functionSchema.Params[0].Value = "Unknow"
_, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.Error(err)
}

func (s *CohereTextEmbeddingProviderSuite) TestGetInputType() {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: embedEnglishLightV20},
{Key: apiKeyParamKey, Value: "mock"},
},
}

provider, err := NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.NoError(err)
s.Equal(provider.getInputType(InsertMode), "")
s.Equal(provider.getInputType(SearchMode), "")

functionSchema.Params[0].Value = embedEnglishLightV30
provider, err = NewCohereEmbeddingProvider(s.schema.Fields[2], functionSchema)
s.NoError(err)
s.Equal(provider.getInputType(InsertMode), "search_document")
s.Equal(provider.getInputType(SearchMode), "search_query")
}
16 changes: 16 additions & 0 deletions internal/util/function/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,22 @@ const (
voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY"
)

// cohere

const (
embedEnglishV30 string = "embed-english-v3.0"
embedMultilingualV30 string = "embed-multilingual-v3.0"
embedEnglishLightV30 string = "embed-english-light-v3.0"
embedMultilingualLightV30 string = "embed-multilingual-light-v3.0"
embedEnglishV20 string = "embed-english-v2.0"
embedEnglishLightV20 string = "embed-english-light-v2.0"
embedMultilingualV20 string = "embed-multilingual-v2.0"

truncateParamKey string = "truncate"

cohereAIAKEnvStr string = "MILVUSAI_COHERE_API_KEY"
)

func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) {
dim, err := strconv.ParseInt(dimStr, 10, 64)
if err != nil {
Expand Down
Loading

0 comments on commit afe7892

Please sign in to comment.