From 7c6ef7f41afae85f929c8348abf81e4e1d477821 Mon Sep 17 00:00:00 2001 From: bjwswang Date: Mon, 15 Jan 2024 07:30:41 +0000 Subject: [PATCH] feat: generate evaluation test dataset Signed-off-by: bjwswang --- apiserver/pkg/chat/chat.go | 10 +- apiserver/pkg/chat/chat_type.go | 2 +- apiserver/pkg/common/schema.go | 6 + cmd/arctl/main.go | 18 +- ...dia_v1alpha1_worker_bge-large-zh-v1.5.yaml | 6 +- .../app_run.go => appruntime/app_runtime.go} | 25 +- .../base/context.go | 6 +- pkg/{application => appruntime}/base/input.go | 0 pkg/{application => appruntime}/base/node.go | 0 .../base/output.go | 0 .../chain/common.go | 0 .../chain/llmchain.go | 21 +- .../chain/retrievalqachain.go | 4 +- .../knowledgebase/knowledgebase.go | 2 +- pkg/{application => appruntime}/llm/llm.go | 2 +- .../prompt/prompt.go | 2 +- .../retriever/knowledgebaseretriever.go | 11 +- pkg/arctl/chat.go | 221 ------- pkg/arctl/dataset.go | 559 ------------------ pkg/arctl/datasource.go | 51 +- pkg/arctl/eval.go | 179 ++++++ pkg/evaluation/evaluation.go | 138 ++++- pkg/evaluation/output.go | 46 ++ 23 files changed, 463 insertions(+), 846 deletions(-) rename pkg/{application/app_run.go => appruntime/app_runtime.go} (90%) rename pkg/{application => appruntime}/base/context.go (90%) rename pkg/{application => appruntime}/base/input.go (100%) rename pkg/{application => appruntime}/base/node.go (100%) rename pkg/{application => appruntime}/base/output.go (100%) rename pkg/{application => appruntime}/chain/common.go (100%) rename pkg/{application => appruntime}/chain/llmchain.go (86%) rename pkg/{application => appruntime}/chain/retrievalqachain.go (97%) rename pkg/{application => appruntime}/knowledgebase/knowledgebase.go (94%) rename pkg/{application => appruntime}/llm/llm.go (97%) rename pkg/{application => appruntime}/prompt/prompt.go (97%) rename pkg/{application => appruntime}/retriever/knowledgebaseretriever.go (97%) delete mode 100644 pkg/arctl/chat.go delete mode 100644 pkg/arctl/dataset.go create mode 
100644 pkg/arctl/eval.go create mode 100644 pkg/evaluation/output.go diff --git a/apiserver/pkg/chat/chat.go b/apiserver/pkg/chat/chat.go index 8771572e9..68112a649 100644 --- a/apiserver/pkg/chat/chat.go +++ b/apiserver/pkg/chat/chat.go @@ -33,9 +33,9 @@ import ( "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/apiserver/pkg/auth" "github.com/kubeagi/arcadia/apiserver/pkg/client" - "github.com/kubeagi/arcadia/pkg/application" - "github.com/kubeagi/arcadia/pkg/application/base" - "github.com/kubeagi/arcadia/pkg/application/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime" + "github.com/kubeagi/arcadia/pkg/appruntime/base" + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) var ( @@ -99,12 +99,12 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat Answer: "", }) ctx = base.SetAppNamespace(ctx, req.AppNamespace) - appRun, err := application.NewAppOrGetFromCache(ctx, app, c) + appRun, err := appruntime.NewAppOrGetFromCache(ctx, c, app) if err != nil { return nil, err } klog.FromContext(ctx).Info("begin to run application", "appName", req.APPName, "appNamespace", req.AppNamespace) - out, err := appRun.Run(ctx, c, respStream, application.Input{Question: req.Query, NeedStream: req.ResponseMode.IsStreaming(), History: conversation.History}) + out, err := appRun.Run(ctx, c, respStream, appruntime.Input{Question: req.Query, NeedStream: req.ResponseMode.IsStreaming(), History: conversation.History}) if err != nil { return nil, err } diff --git a/apiserver/pkg/chat/chat_type.go b/apiserver/pkg/chat/chat_type.go index a710e6cab..348c1ee73 100644 --- a/apiserver/pkg/chat/chat_type.go +++ b/apiserver/pkg/chat/chat_type.go @@ -21,7 +21,7 @@ import ( "github.com/tmc/langchaingo/memory" - "github.com/kubeagi/arcadia/pkg/application/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) type ResponseMode string diff --git a/apiserver/pkg/common/schema.go b/apiserver/pkg/common/schema.go index 
77514915c..8355d9b57 100644 --- a/apiserver/pkg/common/schema.go +++ b/apiserver/pkg/common/schema.go @@ -29,6 +29,7 @@ import ( apiprompt "github.com/kubeagi/arcadia/api/app-node/prompt/v1alpha1" apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" + evalv1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1" ) var ( @@ -102,6 +103,11 @@ var ( Version: v1alpha1.GroupVersion.Version, Resource: "embedders", }, + "rag": { + Group: evalv1alpha1.GroupVersion.Group, + Version: evalv1alpha1.GroupVersion.Version, + Resource: "rags", + }, } ) diff --git a/cmd/arctl/main.go b/cmd/arctl/main.go index 2cad132af..323f556f6 100644 --- a/cmd/arctl/main.go +++ b/cmd/arctl/main.go @@ -21,9 +21,7 @@ import ( "path/filepath" "github.com/spf13/cobra" - "k8s.io/client-go/dynamic" - "github.com/kubeagi/arcadia/apiserver/pkg/client" arctlPkg "github.com/kubeagi/arcadia/pkg/arctl" ) @@ -32,8 +30,6 @@ var ( home string namespace string - - kubeClient dynamic.Interface ) func NewCLI() *cobra.Command { @@ -41,18 +37,11 @@ func NewCLI() *cobra.Command { Use: "arctl [usage]", Short: "Command line tools for Arcadia", PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if _, err := os.Stat(home); os.IsNotExist(err) { + if _, err = os.Stat(home); os.IsNotExist(err) { if err := os.MkdirAll(home, 0700); err != nil { return err } } - - // initialize a kube client - kubeClient, err = client.GetClient(nil) - if err != nil { - return err - } - return nil }, } @@ -60,9 +49,8 @@ func NewCLI() *cobra.Command { arctl.PersistentFlags().StringVar(&home, "home", filepath.Join(os.Getenv("HOME"), ".arcadia"), "home directory to use") arctl.PersistentFlags().StringVarP(&namespace, "namespace", "n", "default", "namespace to use") - arctl.AddCommand(arctlPkg.NewDatasourceCmd(kubeClient, namespace)) - arctl.AddCommand(arctlPkg.NewDatasetCmd(home)) - arctl.AddCommand(arctlPkg.NewChatCmd(home)) + 
arctl.AddCommand(arctlPkg.NewDatasourceCmd(&namespace)) + arctl.AddCommand(arctlPkg.NewEvalCmd(&home, &namespace)) return arctl } diff --git a/config/samples/arcadia_v1alpha1_worker_bge-large-zh-v1.5.yaml b/config/samples/arcadia_v1alpha1_worker_bge-large-zh-v1.5.yaml index 8cf01d91e..845da0f13 100644 --- a/config/samples/arcadia_v1alpha1_worker_bge-large-zh-v1.5.yaml +++ b/config/samples/arcadia_v1alpha1_worker_bge-large-zh-v1.5.yaml @@ -1,8 +1,8 @@ -\apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Worker metadata: name: bge-large-zh - namespace: arcadia + namespace: kubeagi-system spec: displayName: BGE模型服务 description: "这是一个Embedding模型服务,由BGE提供" @@ -10,4 +10,4 @@ spec: replicas: 1 model: kind: "Models" - name: "bge-large-zh-v1.5" + name: "bge-large-zh-v1.5" \ No newline at end of file diff --git a/pkg/application/app_run.go b/pkg/appruntime/app_runtime.go similarity index 90% rename from pkg/application/app_run.go rename to pkg/appruntime/app_runtime.go index 4189aee3a..e5a2e14a2 100644 --- a/pkg/application/app_run.go +++ b/pkg/appruntime/app_runtime.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package application +package appruntime import ( "container/list" @@ -28,12 +28,12 @@ import ( "k8s.io/utils/strings/slices" arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" - "github.com/kubeagi/arcadia/pkg/application/chain" - "github.com/kubeagi/arcadia/pkg/application/knowledgebase" - "github.com/kubeagi/arcadia/pkg/application/llm" - "github.com/kubeagi/arcadia/pkg/application/prompt" - "github.com/kubeagi/arcadia/pkg/application/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime/base" + "github.com/kubeagi/arcadia/pkg/appruntime/chain" + "github.com/kubeagi/arcadia/pkg/appruntime/knowledgebase" + "github.com/kubeagi/arcadia/pkg/appruntime/llm" + "github.com/kubeagi/arcadia/pkg/appruntime/prompt" + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) type Input struct { @@ -62,10 +62,14 @@ type Application struct { // return app.Namespace + "/" + app.Name //} -func NewAppOrGetFromCache(ctx context.Context, app *arcadiav1alpha1.Application, cli dynamic.Interface) (*Application, error) { +func NewAppOrGetFromCache(ctx context.Context, cli dynamic.Interface, app *arcadiav1alpha1.Application) (*Application, error) { if app == nil || app.Name == "" || app.Namespace == "" { return nil, errors.New("app has no name or namespace") } + // make sure namespace value exists in context + if base.GetAppNamespace(ctx) == "" { + ctx = base.SetAppNamespace(ctx, app.Namespace) + } // TODO: disable cache for now. 
// https://github.com/kubeagi/arcadia/issues/391 // a, ok := cache[cacheKey(app)] @@ -150,6 +154,11 @@ func (a *Application) Init(ctx context.Context, cli dynamic.Interface) (err erro } func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream chan string, input Input) (output Output, err error) { + // make sure ns value set + if base.GetAppNamespace(ctx) == "" { + ctx = base.SetAppNamespace(ctx, a.Namespace) + } + out := map[string]any{ "question": input.Question, "_answer_stream": respStream, diff --git a/pkg/application/base/context.go b/pkg/appruntime/base/context.go similarity index 90% rename from pkg/application/base/context.go rename to pkg/appruntime/base/context.go index f37780714..54a488326 100644 --- a/pkg/application/base/context.go +++ b/pkg/appruntime/base/context.go @@ -25,7 +25,11 @@ const ( ) func GetAppNamespace(ctx context.Context) string { - return ctx.Value(AppNamespaceContextKey).(string) + ns := ctx.Value(AppNamespaceContextKey) + if ns == nil { + return "" + } + return ns.(string) } func SetAppNamespace(ctx context.Context, ns string) context.Context { diff --git a/pkg/application/base/input.go b/pkg/appruntime/base/input.go similarity index 100% rename from pkg/application/base/input.go rename to pkg/appruntime/base/input.go diff --git a/pkg/application/base/node.go b/pkg/appruntime/base/node.go similarity index 100% rename from pkg/application/base/node.go rename to pkg/appruntime/base/node.go diff --git a/pkg/application/base/output.go b/pkg/appruntime/base/output.go similarity index 100% rename from pkg/application/base/output.go rename to pkg/appruntime/base/output.go diff --git a/pkg/application/chain/common.go b/pkg/appruntime/chain/common.go similarity index 100% rename from pkg/application/chain/common.go rename to pkg/appruntime/chain/common.go diff --git a/pkg/application/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go similarity index 86% rename from pkg/application/chain/llmchain.go rename to 
pkg/appruntime/chain/llmchain.go index 3ece671e6..3bc2d91cc 100644 --- a/pkg/application/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -32,7 +32,7 @@ import ( "k8s.io/klog/v2" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" ) type LLMChain struct { @@ -64,13 +64,14 @@ func (l *LLMChain) Run(ctx context.Context, cli dynamic.Interface, args map[stri if !ok { return args, errors.New("prompt not prompts.FormatPrompter") } - v3, ok := args["_history"] - if !ok { - return args, errors.New("no history") - } - history, ok := v3.(langchaingoschema.ChatMessageHistory) - if !ok { - return args, errors.New("history not memory.ChatMessageHistory") + // _history is optional + // if set ,only ChatMessageHistory allowed + var history langchaingoschema.ChatMessageHistory + if v3, ok := args["_history"]; ok && v3 != nil { + history, ok = v3.(langchaingoschema.ChatMessageHistory) + if !ok { + return args, errors.New("history not memory.ChatMessageHistory") + } } ns := base.GetAppNamespace(ctx) @@ -87,7 +88,9 @@ func (l *LLMChain) Run(ctx context.Context, cli dynamic.Interface, args map[stri options := getChainOptions(instance.Spec.CommonChainConfig) chain := chains.NewLLMChain(llm, prompt) - chain.Memory = getMemory(llm, instance.Spec.Memory, history) + if history != nil { + chain.Memory = getMemory(llm, instance.Spec.Memory, history) + } l.LLMChain = *chain var out string if needStream, ok := args["_need_stream"].(bool); ok && needStream { diff --git a/pkg/application/chain/retrievalqachain.go b/pkg/appruntime/chain/retrievalqachain.go similarity index 97% rename from pkg/application/chain/retrievalqachain.go rename to pkg/appruntime/chain/retrievalqachain.go index f94dd3ad5..d846f1954 100644 --- a/pkg/application/chain/retrievalqachain.go +++ b/pkg/appruntime/chain/retrievalqachain.go @@ -32,8 +32,8 @@ import ( "k8s.io/klog/v2" 
"github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" - appretriever "github.com/kubeagi/arcadia/pkg/application/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime/base" + appretriever "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) type RetrievalQAChain struct { diff --git a/pkg/application/knowledgebase/knowledgebase.go b/pkg/appruntime/knowledgebase/knowledgebase.go similarity index 94% rename from pkg/application/knowledgebase/knowledgebase.go rename to pkg/appruntime/knowledgebase/knowledgebase.go index 8074e337d..38d8290d0 100644 --- a/pkg/application/knowledgebase/knowledgebase.go +++ b/pkg/appruntime/knowledgebase/knowledgebase.go @@ -21,7 +21,7 @@ import ( "k8s.io/client-go/dynamic" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" ) type Knowledgebase struct { diff --git a/pkg/application/llm/llm.go b/pkg/appruntime/llm/llm.go similarity index 97% rename from pkg/application/llm/llm.go rename to pkg/appruntime/llm/llm.go index 3846658b9..c6ce4f158 100644 --- a/pkg/application/llm/llm.go +++ b/pkg/appruntime/llm/llm.go @@ -27,7 +27,7 @@ import ( "k8s.io/client-go/dynamic" "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/langchainwrap" ) diff --git a/pkg/application/prompt/prompt.go b/pkg/appruntime/prompt/prompt.go similarity index 97% rename from pkg/application/prompt/prompt.go rename to pkg/appruntime/prompt/prompt.go index a5644ed6b..11a9b9c3a 100644 --- a/pkg/application/prompt/prompt.go +++ b/pkg/appruntime/prompt/prompt.go @@ -27,7 +27,7 @@ import ( "k8s.io/client-go/dynamic" "github.com/kubeagi/arcadia/api/app-node/prompt/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" ) type Prompt struct { diff --git 
a/pkg/application/retriever/knowledgebaseretriever.go b/pkg/appruntime/retriever/knowledgebaseretriever.go similarity index 97% rename from pkg/application/retriever/knowledgebaseretriever.go rename to pkg/appruntime/retriever/knowledgebaseretriever.go index 0bc3b1a43..da244e32d 100644 --- a/pkg/application/retriever/knowledgebaseretriever.go +++ b/pkg/appruntime/retriever/knowledgebaseretriever.go @@ -18,6 +18,7 @@ package retriever import ( "context" + "encoding/json" "fmt" "strconv" "strings" @@ -34,7 +35,7 @@ import ( apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/langchainwrap" pkgvectorstore "github.com/kubeagi/arcadia/pkg/vectorstore" ) @@ -52,6 +53,14 @@ type Reference struct { LineNumber int `json:"line_number" example:"7"` } +func (reference Reference) String() string { + bytes, err := json.Marshal(&reference) + if err != nil { + return "" + } + return string(bytes) +} + type KnowledgeBaseRetriever struct { langchaingoschema.Retriever base.BaseNode diff --git a/pkg/arctl/chat.go b/pkg/arctl/chat.go deleted file mode 100644 index c274306e6..000000000 --- a/pkg/arctl/chat.go +++ /dev/null @@ -1,221 +0,0 @@ -/* -Copyright 2023 KubeAGI. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package arctl - -import ( - "context" - "errors" - "fmt" - "path/filepath" - - "github.com/spf13/cobra" - "github.com/tmc/langchaingo/embeddings" - "github.com/tmc/langchaingo/llms/openai" - "github.com/tmc/langchaingo/schema" - "github.com/tmc/langchaingo/vectorstores" - "github.com/tmc/langchaingo/vectorstores/chroma" - - zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" - "github.com/kubeagi/arcadia/pkg/llms" - "github.com/kubeagi/arcadia/pkg/llms/zhipuai" -) - -var ( - question string - - // chat with LLM - model string - method string - temperature float32 - topP float32 - - // similarity search - scoreThreshold float32 - numDocs int - - // promptsWithSimilaritySearch = []zhipuai.Prompt{ - // { - // Role: zhipuai.User, - // Content: `Hi there, I am going to ask you a question, which I would like you to answer - // based only on the provided context, and not any other information. - // If there is not enough information in the context to answer the question,'say \"I am not sure\", then try to make a guess.' 
- // Break your answer up into nicely readable paragraphs.`, - // }, - // { - // Role: zhipuai.Assistant, - // Content: "Sure.Please provide your documents.", - // }, - // } - promptsWithSimilaritySearchCN = []zhipuai.Prompt{ - { - Role: zhipuai.User, - Content: ` - 我将要询问一些问题,希望你仅使用我提供的上下文信息回答。 - 请不要在回答中添加其他信息。 - 若我提供的上下文不足以回答问题, - 请回复"我不确定",再做出适当的猜测。 - 请将回答内容分割为适于阅读的段落。 - `, - }, - { - Role: zhipuai.Assistant, - Content: ` - 好的,我将尝试仅使用你提供的上下文信息回答,并在信息不足时提供一些合理推测。 - `, - }, - } -) - -func NewChatCmd(homePath string) *cobra.Command { - cmd := &cobra.Command{ - Use: "chat [usage]", - Short: "Do LLM chat with similarity search(optional)", - RunE: func(cmd *cobra.Command, args []string) error { - var docs []schema.Document - var err error - - if dataset != "" { - docs, err = SimilaritySearch(context.Background(), homePath) - if err != nil { - return err - } - fmt.Printf("similarDocs: %v \n", docs) - } - - return Chat(context.Background(), docs) - }, - } - - // Similarity search params - cmd.Flags().StringVar(&dataset, "dataset", "", "dataset(namespace/collection) to query from") - cmd.Flags().Float32Var(&scoreThreshold, "score-threshold", 0, "score threshold for similarity search(Higher is better)") - cmd.Flags().IntVar(&numDocs, "num-docs", 5, "number of documents to be returned with SimilarSearch") - - // For LLM chat - cmd.Flags().StringVar(&llmType, "llm-type", string(llms.ZhiPuAI), "llm type to use for embedding & chat(Only zhipuai,openai supported now)") - cmd.Flags().StringVar(&apiKey, "llm-apikey", "", "apiKey to access llm service.Must required when embedding similarity search is enabled") - cmd.Flags().StringVar(&question, "question", "", "question text to be asked") - if err := cmd.MarkFlagRequired("llm-apikey"); err != nil { - panic(err) - } - if err := cmd.MarkFlagRequired("question"); err != nil { - panic(err) - } - - // LLM Chat params - cmd.PersistentFlags().StringVar(&model, "model", string(llms.ZhiPuAILite), "which model to use: 
chatglm_lite/chatglm_std/chatglm_pro") - cmd.PersistentFlags().StringVar(&method, "method", "sse-invoke", "Invoke method used when access LLM service(invoke/sse-invoke)") - cmd.PersistentFlags().Float32Var(&temperature, "temperature", 0.95, "temperature for chat") - cmd.PersistentFlags().Float32Var(&topP, "top-p", 0.7, "top-p for chat") - - return cmd -} - -func SimilaritySearch(ctx context.Context, homePath string) ([]schema.Document, error) { - var embedder embeddings.Embedder - var err error - - ds, err := loadCachedDataset(filepath.Join(homePath, "dataset", dataset)) - if err != nil { - return nil, err - } - if ds.Name == "" { - return nil, fmt.Errorf("dataset %s does not exist", dataset) - } - - switch ds.LLMType { - case "zhipuai": - embedder, err = zhipuaiembeddings.NewZhiPuAI( - zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(ds.LLMApiKey)), - ) - if err != nil { - return nil, err - } - case "openai": - llm, err := openai.New() - if err != nil { - return nil, err - } - embedder, err = embeddings.NewEmbedder(llm) - if err != nil { - return nil, err - } - default: - return nil, errors.New("unsupported embedding type") - } - - chroma, err := chroma.New( - chroma.WithChromaURL(ds.VectorStore), - chroma.WithEmbedder(embedder), - chroma.WithNameSpace(dataset), - ) - if err != nil { - return nil, err - } - - return chroma.SimilaritySearch(ctx, question, numDocs, vectorstores.WithScoreThreshold(scoreThreshold)) -} - -func Chat(ctx context.Context, similarDocs []schema.Document) error { - // Only for zhipuai - client := zhipuai.NewZhiPuAI(apiKey) - - params := zhipuai.DefaultModelParams() - params.Model = model - params.Method = zhipuai.Method(method) - params.Temperature = temperature - params.TopP = topP - - var prompts []zhipuai.Prompt - if len(similarDocs) != 0 { - var docString string - for _, doc := range similarDocs { - docString += doc.PageContent - } - prompts = append(prompts, promptsWithSimilaritySearchCN...) 
- prompts = append(prompts, zhipuai.Prompt{ - Role: zhipuai.User, - Content: fmt.Sprintf("我的问题是: %s. 以下是我提供的上下文:%s", question, docString), - }) - } else { - prompts = append(prompts, zhipuai.Prompt{ - Role: zhipuai.User, - Content: question, - }) - } - - fmt.Printf("Prompts: %v \n", prompts) - params.Prompt = prompts - if params.Method == zhipuai.ZhiPuAIInvoke { - resp, err := client.Invoke(params) - if err != nil { - return err - } - if resp.Code != 200 { - return fmt.Errorf("chat failed: %s", resp.String()) - } - fmt.Println(resp.Data.Choices[0].Content) - return nil - } - - err := client.SSEInvoke(params, nil) - if err != nil { - return err - } - - return nil -} diff --git a/pkg/arctl/dataset.go b/pkg/arctl/dataset.go deleted file mode 100644 index 1b27ac7dc..000000000 --- a/pkg/arctl/dataset.go +++ /dev/null @@ -1,559 +0,0 @@ -/* -Copyright 2023 KubeAGI. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package arctl - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strings" - "time" - - chromago "github.com/amikos-tech/chroma-go" - chromaopenapi "github.com/amikos-tech/chroma-go/swagger" - "github.com/spf13/cobra" - "github.com/tmc/langchaingo/documentloaders" - "github.com/tmc/langchaingo/embeddings" - "github.com/tmc/langchaingo/llms/openai" - "github.com/tmc/langchaingo/schema" - "github.com/tmc/langchaingo/textsplitter" - "github.com/tmc/langchaingo/vectorstores/chroma" - "k8s.io/klog/v2" - - zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" - "github.com/kubeagi/arcadia/pkg/llms" - "github.com/kubeagi/arcadia/pkg/llms/zhipuai" -) - -var ( - dataset string - - llmType string - apiKey string - - // path to documents separated by comma - inputDocuments string - - vectorStore string - documentLanguage string - textSplitter string - chunkSize int - chunkOverlap int - - // force remove - resetVectorStore bool -) - -func NewDatasetCmd(homePath string) *cobra.Command { - cmd := &cobra.Command{ - Use: "dataset [usage]", - Short: "Manage dataset locally", - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - datasetDir := filepath.Join(homePath, "dataset") - if _, err := os.Stat(datasetDir); os.IsNotExist(err) { - if err := os.MkdirAll(datasetDir, 0700); err != nil { - return err - } - } - return nil - }, - } - - cmd.AddCommand(DatasetListCmd(homePath)) - cmd.AddCommand(DatasetCreateCmd(homePath)) - cmd.AddCommand(DatasetShowCmd(homePath)) - cmd.AddCommand(DatasetExecuteCmd(homePath)) - cmd.AddCommand(DatasetDeleteCmd(homePath)) - - return cmd -} - -// DatasetListCmd returns a Cobra command for listing datasets. 
-func DatasetListCmd(homePath string) *cobra.Command { - cmd := &cobra.Command{ - Use: "list [usage]", - Short: "List dataset", - RunE: func(cmd *cobra.Command, args []string) error { - fmt.Printf("| DATASET | FILES |EMBEDDING MODEL | VECTOR STORE | DOCUMENT LANGUAGE | TEXT SPLITTER | CHUNK SIZE | CHUNK OVERLAP |\n") - err := filepath.Walk(filepath.Join(homePath, "dataset"), func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - // skip directory - if info.IsDir() { - return nil - } - ds, err := loadCachedDataset(path) - if err != nil { - return fmt.Errorf("failed to load cached dataset %s: %v", info.Name(), err) - } - // print item - fmt.Printf("| %s | %d | %s | %s | %s | %s | %d | %d |\n", ds.Name, len(ds.Files), ds.LLMType, ds.VectorStore, ds.DocumentLanguage, ds.TextSplitter, ds.ChunkSize, ds.ChunkOverlap) - return nil - }) - if err != nil { - return err - } - return nil - }, - } - - return cmd -} - -// DatasetCreateCmd returns a new instance of the cobra.Command that is used to create a dataset. 
-func DatasetCreateCmd(homePath string) *cobra.Command { - cmd := &cobra.Command{ - Use: "create [usage]", - Short: "Create dataset", - RunE: func(cmd *cobra.Command, args []string) error { - klog.Infof("Create dataset: %s \n", dataset) - ds, err := loadCachedDataset(filepath.Join(homePath, "dataset", dataset)) - if err != nil { - return err - } - if ds.Name != "" { - return fmt.Errorf("dataset %s already exists", dataset) - } - // set dataset - ds.Name = dataset - ds.CreateTime = time.Now().String() - ds.LLMApiKey = apiKey - ds.LLMType = llmType - ds.VectorStore = vectorStore - ds.DocumentLanguage = documentLanguage - ds.TextSplitter = textSplitter - ds.ChunkSize = chunkSize - ds.ChunkOverlap = chunkOverlap - - err = ds.execute(context.Background()) - if err != nil { - // only print error but do not exit - klog.Errorf("failed to execute dataset %s: %v", dataset, err) - } - - // cache the dataset to local - cache, err := json.Marshal(ds) - if err != nil { - return fmt.Errorf("failed to marshal dataset %s: %v", dataset, err) - } - err = os.WriteFile(filepath.Join(homePath, "dataset", dataset), cache, 0644) - if err != nil { - return err - } - klog.Infof("Successfully created dataset %s\n", dataset) - - return showDataset(homePath, dataset) - }, - } - cmd.Flags().StringVar(&dataset, "name", "", "dataset(namespace/collection) of the document to load into") - if err := cmd.MarkFlagRequired("name"); err != nil { - panic(err) - } - - cmd.Flags().StringVar(&inputDocuments, "documents", "", "path of the documents/document directories to load(separated by comma and directories supported)") - if err := cmd.MarkFlagRequired("documents"); err != nil { - panic(err) - } - - cmd.Flags().StringVar(&llmType, "llm-type", string(llms.ZhiPuAI), "llm type to use(Only zhipuai,openai supported now)") - cmd.Flags().StringVar(&apiKey, "llm-apikey", "", "apiKey to access embedding service") - if err := cmd.MarkFlagRequired("llm-apikey"); err != nil { - panic(err) - } - - 
cmd.Flags().StringVar(&vectorStore, "vector-store", "http://127.0.0.1:8000", "vector stores to use(Only chroma supported now)") - cmd.Flags().StringVar(&documentLanguage, "document-language", "text", "language of the document(Only text,html,csv supported now)") - cmd.Flags().StringVar(&textSplitter, "text-splitter", "character", "text splitter to use(Only character,token,markdown supported now)") - cmd.Flags().IntVar(&chunkSize, "chunk-size", 300, "chunk size for embedding") - cmd.Flags().IntVar(&chunkOverlap, "chunk-overlap", 30, "chunk overlap for embedding") - - return cmd -} - -func DatasetShowCmd(homePath string) *cobra.Command { - cmd := &cobra.Command{ - Use: "show [usage]", - Short: "Load more documents to dataset", - RunE: func(cmd *cobra.Command, args []string) error { - klog.Infof("Show dataset: %s \n", dataset) - return showDataset(homePath, dataset) - }, - } - - cmd.Flags().StringVar(&dataset, "name", "", "dataset(namespace/collection) of the document to load into") - if err := cmd.MarkFlagRequired("name"); err != nil { - panic(err) - } - - return cmd -} - -func showDataset(homePath, dataset string) error { - cachedDatasetFile, err := os.OpenFile(filepath.Join(homePath, "dataset", dataset), os.O_RDWR, 0644) - if err != nil { - if os.IsNotExist(err) { - klog.Errorf("dataset %s does not exist", dataset) - return nil - } else { - return fmt.Errorf("failed to open cached dataset file: %v", err) - } - } - defer cachedDatasetFile.Close() - - data, err := io.ReadAll(cachedDatasetFile) - if err != nil { - return fmt.Errorf("failed to read cached dataset file: %v", err) - } - // Create a buffer to store the formatted JSON - var formattedJSON bytes.Buffer - - // Indent and format the JSON - err = json.Indent(&formattedJSON, data, "", " ") - if err != nil { - return fmt.Errorf("failed to format cached dataset file: %v", err) - } - - // print dataset - klog.Infof("\n%s", formattedJSON.String()) - - return nil -} - -func DatasetExecuteCmd(homePath string) 
*cobra.Command { - cmd := &cobra.Command{ - Use: "execute [usage]", - Short: "Execute dataset to load documents to dataset", - RunE: func(cmd *cobra.Command, args []string) error { - klog.Infof("Execute dataset: %s \n", dataset) - ds, err := loadCachedDataset(filepath.Join(homePath, "dataset", dataset)) - if err != nil { - return err - } - if ds.Name == "" { - return fmt.Errorf("dataset %s does not exist", dataset) - } - - err = ds.execute(context.Background()) - if err != nil { - // only print error but do not exit - klog.Errorf("failed to execute dataset %s: %v", dataset, err) - } - - // cache the dataset to local - klog.Infof("Caching dataset %s", dataset) - cache, err := json.Marshal(ds) - if err != nil { - return fmt.Errorf("failed to marshal dataset %s: %v", dataset, err) - } - err = os.WriteFile(filepath.Join(homePath, "dataset", dataset), cache, 0644) - if err != nil { - return err - } - klog.Infof("Successfully execute dataset: %s \n", dataset) - return nil - }, - } - - cmd.Flags().StringVar(&dataset, "name", "", "dataset(namespace/collection) of the document to load into") - if err := cmd.MarkFlagRequired("name"); err != nil { - panic(err) - } - - cmd.Flags().StringVar(&inputDocuments, "documents", "", "path of the documents/document directories to load(separated by comma and directories supported)") - if err := cmd.MarkFlagRequired("documents"); err != nil { - panic(err) - } - - return cmd -} -func DatasetDeleteCmd(homePath string) *cobra.Command { - cmd := &cobra.Command{ - Use: "delete [usage]", - Short: "Delete dataset", - RunE: func(cmd *cobra.Command, args []string) error { - klog.Infof("Delete dataset: %s \n", dataset) - ds, err := loadCachedDataset(filepath.Join(homePath, "dataset", dataset)) - if err != nil { - return fmt.Errorf("failed to load cached dataset %s: %v", dataset, err) - } - if ds.Name == "" { - klog.Infof("Dataset %s does not exist", dataset) - return nil - } - - // remove dateset from remote vector store - if resetVectorStore { - 
configuration := chromaopenapi.NewConfiguration() - configuration.Servers = chromaopenapi.ServerConfigurations{ - { - URL: ds.VectorStore, - Description: "chroma server url for this store", - }, - } - client := &chromago.Client{ - ApiClient: chromaopenapi.NewAPIClient(configuration), - } - _, err := client.Reset() - if err != nil { - return err - } - } - // remove local cache - if err := os.Remove(filepath.Join(homePath, "dataset", dataset)); err != nil { - panic(err) - } - klog.Infof("Successfully delete dataset: %s \n", dataset) - return nil - }, - } - - cmd.Flags().StringVar(&dataset, "name", "arcadia", "dataset(namespace/collection) of the document to load into") - if err := cmd.MarkFlagRequired("name"); err != nil { - panic(err) - } - - cmd.Flags().BoolVar(&resetVectorStore, "reset-vector-store", false, "forcely reset dataset from remote vector store") - - return cmd -} - -type Dataset struct { - Name string `json:"name"` - CreateTime string `json:"create_time"` - - // Parameters for embedding service - LLMType string `json:"llm_type"` - LLMApiKey string `json:"llm_api_key"` - - // Parameters for vectorization - VectorStore string `json:"vector_store"` - DocumentLanguage string `json:"document_language"` - TextSplitter string `json:"text_splitter"` - ChunkSize int `json:"chunk_size"` - ChunkOverlap int `json:"chunk_overlap"` - - Files map[string]File `json:"files"` -} - -type File struct { - // basic info - Path string `json:"path"` - Size int64 `json:"size"` - - // embedding status - // Chunks is the number of split chunks - Chunks int `json:"chunks"` - // ChunksLoaded is the number of chunks loaded - ChunksLoaded int `json:"chunks_loaded"` - - TimeCost float64 `json:"time_cost"` -} - -func loadCachedDataset(cachedDatasetFilePath string) (*Dataset, error) { - cachedDatasetFile, err := os.OpenFile(cachedDatasetFilePath, os.O_RDWR, 0644) - if err != nil { - if os.IsNotExist(err) { - cachedDatasetFile, err = os.Create(cachedDatasetFilePath) - if err != nil { - 
return nil, fmt.Errorf("failed to create cached dataset file: %v", err) - } - } else { - return nil, fmt.Errorf("failed to open/create cached dataset file: %v", err) - } - } - defer cachedDatasetFile.Close() - - content, err := io.ReadAll(cachedDatasetFile) - if err != nil { - return nil, fmt.Errorf("failed to read cached dataset file: %v", err) - } - ds := &Dataset{ - Files: map[string]File{}, - } - if len(content) != 0 { - err = json.Unmarshal(content, ds) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal cached dataset file: %v", err) - } - } - return ds, nil -} - -// LoadDocuments loads documents to vector store. -func (cachedDS *Dataset) execute(ctx context.Context) error { - for _, document := range strings.Split(inputDocuments, ",") { - fileInfo, err := os.Stat(document) - if err != nil { - return err - } - // load documents under a document directory - if fileInfo.IsDir() { - if err = filepath.Walk(document, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - // skip if it is a directory - if info.IsDir() { - return nil - } - // process documents - klog.Infof("Loading document: %s \n", path) - return cachedDS.loadDocument(ctx, path) - }); err != nil { - return err - } - } else { - klog.Infof("Loading document: %s \n", document) - err := cachedDS.loadDocument(ctx, document) - if err != nil { - return err - } - } - } - return nil -} - -// LoadDocument loads a document from a file and splits it into multiple documents. 
-func (cachedDS *Dataset) loadDocument(ctx context.Context, document string) error { - start := time.Now() - // read document - file, err := os.Open(document) - if err != nil { - return err - } - defer file.Close() - data, err := io.ReadAll(file) - if err != nil { - return fmt.Errorf("failed to read file content: %v", err) - } - - // skip if all chunks has been loaded - hash := sha256.New() - hash.Write(data) - digest := hash.Sum(nil) - cachedFile, ok := cachedDS.Files[hex.EncodeToString(digest)] - // TODO: check if cached.Chunks == cached.ChunksLoaded - if ok && cachedFile.Chunks == cachedFile.ChunksLoaded { - klog.Infof("Document %s has been loaded.Skip loading", document) - return nil - } - - dataReader := bytes.NewReader(data) - var loader documentloaders.Loader - switch documentLanguage { - case "text": - loader = documentloaders.NewText(dataReader) - case "csv": - loader = documentloaders.NewCSV(dataReader) - case "html": - loader = documentloaders.NewHTML(dataReader) - default: - return errors.New("unsupported document language") - } - - // initialize text splitter - var split textsplitter.TextSplitter - switch cachedDS.TextSplitter { - case "token": - split = textsplitter.NewTokenSplitter( - textsplitter.WithChunkSize(chunkSize), - textsplitter.WithChunkOverlap(chunkOverlap), - ) - case "markdown": - split = textsplitter.NewMarkdownTextSplitter( - textsplitter.WithChunkSize(chunkSize), - textsplitter.WithChunkOverlap(chunkOverlap), - ) - default: - split = textsplitter.NewRecursiveCharacter( - textsplitter.WithChunkSize(chunkSize), - textsplitter.WithChunkOverlap(chunkOverlap), - ) - } - - documents, err := loader.LoadAndSplit(ctx, split) - if err != nil { - return err - } - - err = cachedDS.embedDocuments(context.Background(), documents) - if err != nil { - return err - } - - // cache the document - fileInfo, err := os.Stat(document) - if err != nil { - return fmt.Errorf("failed to get file info: %v", err) - } - cacheFile := File{ - Path: document, - Size: 
fileInfo.Size(), - Chunks: len(documents), - ChunksLoaded: len(documents), - TimeCost: time.Since(start).Seconds(), - } - cachedDS.Files[hex.EncodeToString(digest)] = cacheFile - - klog.Infof("Time cost %.2f seconds for loading document: %s \n", cacheFile.TimeCost, document) - return nil -} - -func (cachedDS *Dataset) embedDocuments(ctx context.Context, documents []schema.Document) error { - var embedder embeddings.Embedder - var err error - - switch llmType { - case "zhipuai": - embedder, err = zhipuaiembeddings.NewZhiPuAI( - zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(cachedDS.LLMApiKey)), - ) - if err != nil { - return err - } - case "openai": - llm, err := openai.New() - if err != nil { - return err - } - embedder, err = embeddings.NewEmbedder(llm) - if err != nil { - return err - } - default: - return errors.New("unsupported embedding type") - } - - chroma, err := chroma.New( - chroma.WithChromaURL(cachedDS.VectorStore), - chroma.WithEmbedder(embedder), - chroma.WithNameSpace(cachedDS.Name), - ) - if err != nil { - return err - } - _, err = chroma.AddDocuments(ctx, documents) - return err -} diff --git a/pkg/arctl/datasource.go b/pkg/arctl/datasource.go index e12f59fa4..009e0ded2 100644 --- a/pkg/arctl/datasource.go +++ b/pkg/arctl/datasource.go @@ -25,10 +25,10 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/client-go/dynamic" "k8s.io/klog/v2" "github.com/kubeagi/arcadia/apiserver/graph/generated" + "github.com/kubeagi/arcadia/apiserver/pkg/client" "github.com/kubeagi/arcadia/apiserver/pkg/datasource" "github.com/kubeagi/arcadia/pkg/arctl/printer" ) @@ -41,21 +41,21 @@ var ( description string ) -func NewDatasourceCmd(kubeClient dynamic.Interface, namespace string) *cobra.Command { +func NewDatasourceCmd(namespace *string) *cobra.Command { cmd := &cobra.Command{ Use: "datasource [usage]", Short: "Manage datasources", } - cmd.AddCommand(DatasourceCreateCmd(kubeClient, 
namespace)) - cmd.AddCommand(DatasourceGetCmd(kubeClient, namespace)) - cmd.AddCommand(DatasourceDeleteCmd(kubeClient, namespace)) - cmd.AddCommand(DatasourceListCmd(kubeClient, namespace)) + cmd.AddCommand(DatasourceCreateCmd(namespace)) + cmd.AddCommand(DatasourceGetCmd(namespace)) + cmd.AddCommand(DatasourceDeleteCmd(namespace)) + cmd.AddCommand(DatasourceListCmd(namespace)) return cmd } -func DatasourceCreateCmd(kubeClient dynamic.Interface, namespace string) *cobra.Command { +func DatasourceCreateCmd(namespace *string) *cobra.Command { var empytDatasource bool // endpoint flags var endpointURL, endpointAuthUser, endpointAuthPwd string @@ -79,7 +79,7 @@ func DatasourceCreateCmd(kubeClient dynamic.Interface, namespace string) *cobra. input := generated.CreateDatasourceInput{ Name: name, - Namespace: namespace, + Namespace: *namespace, DisplayName: &displayName, Description: &description, Endpointinput: generated.EndpointInput{ @@ -100,7 +100,11 @@ func DatasourceCreateCmd(kubeClient dynamic.Interface, namespace string) *cobra. input.Ossinput = &generated.OssInput{Bucket: ossBucket} } - _, err := datasource.CreateDatasource(cmd.Context(), kubeClient, input) + kubeClient, err := client.GetClient(nil) + if err != nil { + return err + } + _, err = datasource.CreateDatasource(cmd.Context(), kubeClient, input) if err != nil { return err } @@ -132,7 +136,7 @@ func DatasourceCreateCmd(kubeClient dynamic.Interface, namespace string) *cobra. 
return cmd } -func DatasourceGetCmd(kubeClient dynamic.Interface, namespace string) *cobra.Command { +func DatasourceGetCmd(namespace *string) *cobra.Command { cmd := &cobra.Command{ Use: "get [name]", Short: "Get datasource", @@ -142,7 +146,11 @@ func DatasourceGetCmd(kubeClient dynamic.Interface, namespace string) *cobra.Com } name := os.Args[3] - ds, err := datasource.ReadDatasource(cmd.Context(), kubeClient, name, namespace) + kubeClient, err := client.GetClient(nil) + if err != nil { + return err + } + ds, err := datasource.ReadDatasource(cmd.Context(), kubeClient, name, *namespace) if err != nil { return fmt.Errorf("failed to find datasource: %w", err) } @@ -154,13 +162,17 @@ func DatasourceGetCmd(kubeClient dynamic.Interface, namespace string) *cobra.Com return cmd } -func DatasourceListCmd(kubeClient dynamic.Interface, namespace string) *cobra.Command { +func DatasourceListCmd(namespace *string) *cobra.Command { cmd := &cobra.Command{ Use: "list [usage]", Short: "List datasources", RunE: func(cmd *cobra.Command, args []string) error { + kubeClient, err := client.GetClient(nil) + if err != nil { + return err + } list, err := datasource.ListDatasources(cmd.Context(), kubeClient, generated.ListCommonInput{ - Namespace: namespace, + Namespace: *namespace, }) if err != nil { return err @@ -177,7 +189,7 @@ func DatasourceListCmd(kubeClient dynamic.Interface, namespace string) *cobra.Co return cmd } -func DatasourceDeleteCmd(kubeClient dynamic.Interface, namespace string) *cobra.Command { +func DatasourceDeleteCmd(namespace *string) *cobra.Command { cmd := &cobra.Command{ Use: "delete [name]", Short: "Delete a datasource", @@ -186,7 +198,12 @@ func DatasourceDeleteCmd(kubeClient dynamic.Interface, namespace string) *cobra. 
return errors.New("missing datasource name") } name := os.Args[3] - ds, err := datasource.ReadDatasource(cmd.Context(), kubeClient, name, namespace) + + kubeClient, err := client.GetClient(nil) + if err != nil { + return err + } + ds, err := datasource.ReadDatasource(cmd.Context(), kubeClient, name, *namespace) if err != nil { return fmt.Errorf("failed to get datasource: %w", err) } @@ -196,7 +213,7 @@ func DatasourceDeleteCmd(kubeClient dynamic.Interface, namespace string) *cobra. Group: corev1.SchemeGroupVersion.Group, Version: corev1.SchemeGroupVersion.Version, Resource: "secrets", - }).Namespace(namespace).Delete(cmd.Context(), ds.Endpoint.AuthSecret.Name, metav1.DeleteOptions{}) + }).Namespace(*namespace).Delete(cmd.Context(), ds.Endpoint.AuthSecret.Name, metav1.DeleteOptions{}) if err != nil { return fmt.Errorf("failed to delete auth secret: %w", err) } @@ -204,7 +221,7 @@ func DatasourceDeleteCmd(kubeClient dynamic.Interface, namespace string) *cobra. } _, err = datasource.DeleteDatasources(cmd.Context(), kubeClient, &generated.DeleteCommonInput{ Name: &name, - Namespace: namespace, + Namespace: *namespace, }) if err != nil { return fmt.Errorf("failed to delete datasource: %w", err) diff --git a/pkg/arctl/eval.go b/pkg/arctl/eval.go new file mode 100644 index 000000000..08b1c7e05 --- /dev/null +++ b/pkg/arctl/eval.go @@ -0,0 +1,179 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package arctl + +import ( + "bytes" + "context" + "encoding/csv" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" + + "github.com/spf13/cobra" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/dynamic" + + basev1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/apiserver/graph/generated" + "github.com/kubeagi/arcadia/apiserver/pkg/client" + "github.com/kubeagi/arcadia/apiserver/pkg/common" + "github.com/kubeagi/arcadia/pkg/evaluation" +) + +func NewEvalCmd(home *string, namespace *string) *cobra.Command { + var appName string + + cmd := &cobra.Command{ + Use: "eval", + Short: "Manage evaluations", + } + + cmd.PersistentFlags().StringVar(&appName, "application", "", "The application to be evaluated") + if err := cmd.MarkPersistentFlagRequired("application"); err != nil { + panic(err) + } + + cmd.AddCommand(EvalGenTestDataset(home, namespace, &appName)) + + return cmd +} + +func EvalGenTestDataset(home *string, namespace *string, appName *string) *cobra.Command { + var inputDir string + var questionColumn string + var groundTruthsColumn string + var outputMethod string + var outputDir string + + cmd := &cobra.Command{ + Use: "gen_test_dataset", + Short: "Generate a test dataset for evaluation", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + if outputDir == "" { + outputDir = *home + } + + // init kubeclient + kubeClient, err := client.GetClient(nil) + if err != nil { + return err + } + + // read files + app := &basev1alpha1.Application{} + obj, err := common.ResouceGet(ctx, kubeClient, generated.TypedObjectReferenceInput{ + APIGroup: &common.ArcadiaAPIGroup, + Kind: "Application", + Namespace: namespace, + Name: *appName, + }, v1.GetOptions{}) + if err != nil { + return err + } + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), app) + if err != nil { + return err + } + + // read files 
from input directory + files, err := os.ReadDir(inputDir) + if err != nil { + log.Fatal(err) + } + for _, file := range files { + if file.IsDir() || filepath.Ext(file.Name()) != ".csv" || strings.HasPrefix(file.Name(), "ragas-") { + continue + } + var output evaluation.Output + switch outputMethod { + case "csv": + outputCSVFile, err := os.Create(filepath.Join(inputDir, fmt.Sprintf("ragas-%s", file.Name()))) + if err != nil { + return err + } + defer outputCSVFile.Close() + csvOutput := &evaluation.CSVOutput{ + W: csv.NewWriter(outputCSVFile), + } + defer csvOutput.W.Flush() + output = csvOutput + default: + output = &evaluation.PrintOutput{} + } + // read file from dataset + err = GenDatasetOnSingleFile(ctx, kubeClient, app, + filepath.Join(inputDir, file.Name()), + evaluation.WithQuestionColumn(questionColumn), + evaluation.WithGroundTruthsColumn(groundTruthsColumn), + evaluation.WithOutput(output), + ) + if err != nil { + return err + } + } + + return nil + }, + } + + cmd.Flags().StringVar(&inputDir, "input-dir", "", "The input directory where to load original dataset files") + if err := cmd.MarkFlagRequired("input-dir"); err != nil { + panic(err) + } + + cmd.Flags().StringVar(&questionColumn, "question-column", "q", "The column name which provides questions") + cmd.Flags().StringVar(&groundTruthsColumn, "ground-truths-column", "a", "The column name which provides the answers") + cmd.Flags().StringVar(&outputMethod, "output", "", "The way to output the generated dataset rows. We support two ways: \n - stdout: print row \n - csv: save row to csv file") + + return cmd +} + +func GenDatasetOnSingleFile(ctx context.Context, kubeClient dynamic.Interface, app *basev1alpha1.Application, file string, genOpts ...evaluation.GenOptions) error { + f, err := os.Open(file) + if err != nil { + return err + } + // read file content; close the descriptor when done (was previously leaked) + defer f.Close() + data, err := io.ReadAll(f) + if err != nil { + return err + } + + // init evaluation dataset generator + generator, err :=
evaluation.NewRagasDatasetGenerator(ctx, kubeClient, app, genOpts...) + if err != nil { + return err + } + + // generate test dataset + err = generator.Generate( + ctx, + bytes.NewReader(data), + ) + if err != nil { + return err + } + + return nil +} diff --git a/pkg/evaluation/evaluation.go b/pkg/evaluation/evaluation.go index f209b535f..0e7a9d977 100644 --- a/pkg/evaluation/evaluation.go +++ b/pkg/evaluation/evaluation.go @@ -16,4 +16,140 @@ limitations under the License. package evaluation -// TO BE DEFINED +import ( + "context" + "fmt" + "io" + + "k8s.io/client-go/dynamic" + + "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/pkg/appruntime" + pkgdocumentloaders "github.com/kubeagi/arcadia/pkg/documentloaders" +) + +// RagasDataRow which adapts to the Ragas evaluation framework +type RagasDataRow struct { + // Question by QAGeneration or manually input + Question string `json:"question"` + // GroundTruths by QAGeneration or manually input + GroundTruths []string `json:"ground_truths"` + // Contexts by similarity search to knowledgebase + Contexts []string `json:"contexts"` + // Answer by Application + Answer string `json:"answer"` +} + +// RagasDatasetGenerator generates datasets which adapts to the ragas framework +type RagasDatasetGenerator struct { + cli dynamic.Interface + + app appruntime.Application + + options *genOptions +} + +func NewRagasDatasetGenerator(ctx context.Context, cli dynamic.Interface, app *v1alpha1.Application, genOptions ...GenOptions) (*RagasDatasetGenerator, error) { + // set generation options + genOpts := defaultGenOptions() + for _, o := range genOptions { + o(genOpts) + } + + // output header + err := genOpts.output.Output(RagasDataRow{ + Question: "question", + GroundTruths: []string{"ground_truths"}, + Contexts: []string{"contexts"}, + Answer: "answer", + }) + if err != nil { + return nil, err + } + + runapp, err := appruntime.NewAppOrGetFromCache(ctx, cli, app) + if err != nil { + return nil, err + } + 
return &RagasDatasetGenerator{cli: cli, app: *runapp, options: genOpts}, nil +} + +type genOptions struct { + // questionColumn in csv file which has the question + questionColumn string + // groundTruthsColumn in csv file which has the correct answer + groundTruthsColumn string + + output Output +} + +func defaultGenOptions() *genOptions { + return &genOptions{ + questionColumn: "q", + groundTruthsColumn: "a", + output: &PrintOutput{}, + } +} + +func WithQuestionColumn(questionColumn string) GenOptions { + return func(genOpts *genOptions) { + genOpts.questionColumn = questionColumn + } +} + +func WithGroundTruthsColumn(groundTruthsColumn string) GenOptions { + return func(genOpts *genOptions) { + genOpts.groundTruthsColumn = groundTruthsColumn + } +} + +func WithOutput(output Output) GenOptions { + return func(genOpts *genOptions) { + genOpts.output = output + } +} + +type GenOptions func(*genOptions) + +// Generate a test dataset from a file (csv) +func (eval *RagasDatasetGenerator) Generate(ctx context.Context, csvData io.Reader, genOptions ...GenOptions) error { + // set or update options + for _, o := range genOptions { + o(eval.options) + } + + // load csv to langchain documents + loader := pkgdocumentloaders.NewQACSV(csvData, "", eval.options.questionColumn, eval.options.groundTruthsColumn) + langchainDocuments, err := loader.Load(ctx) + if err != nil { + return err + } + + // convert langchain documents to ragas dataset + for _, doc := range langchainDocuments { + // comma-ok guards against a missing or non-string ground-truth cell (was a panic) + groundTruth, _ := doc.Metadata[eval.options.groundTruthsColumn].(string) + ragasRow := RagasDataRow{Question: doc.PageContent, + GroundTruths: []string{groundTruth}} + + // chat with application + out, err := eval.app.Run(ctx, eval.cli, nil, appruntime.Input{Question: ragasRow.Question, NeedStream: false, History: nil}) + if err != nil { + return err + } + ragasRow.Answer = out.Answer + + // handle context + contexts := make([]string, len(out.References)) + for refIndex, reference := range out.References {
contexts[refIndex] = reference.String() + } + ragasRow.Contexts = contexts + + if err = eval.options.output.Output(ragasRow); err != nil { + return fmt.Errorf("output: %w", err) + } + } + + return nil +} diff --git a/pkg/evaluation/output.go b/pkg/evaluation/output.go new file mode 100644 index 000000000..94c0e8078 --- /dev/null +++ b/pkg/evaluation/output.go @@ -0,0 +1,46 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evaluation + +import ( + "encoding/csv" + "fmt" + "strings" +) + +type Output interface { + Output(RagasDataRow) error +} + +// PrintOutput prints each generated dataset row to standard output +type PrintOutput struct{} + +// Output this row to standard output +func (p *PrintOutput) Output(row RagasDataRow) error { + fmt.Printf("question:%s \nground_truths:%s \n answer:%s \n contexts:%v \n", row.Question, row.GroundTruths, row.Answer, row.Contexts) + return nil +} + +// CSVOutput writes row to csv +type CSVOutput struct { + W *csv.Writer +} + +// Output a row to csv +func (c *CSVOutput) Output(row RagasDataRow) error { + return c.W.Write([]string{row.Question, strings.Join(row.GroundTruths, ";"), row.Answer, strings.Join(row.Contexts, ";")}) +}