diff --git a/README.md b/README.md index 442dde6..6e8375b 100644 --- a/README.md +++ b/README.md @@ -105,12 +105,17 @@ pip install . 为确保项目正常运行,**请在项目内新建`.env`文件,并在其中设置你的API密钥**,你可以根据下列例子写入对应的 key,即可成功运行调用,目前默认使用 zhipuai,你可以仅写入`ZHIPUAI_API_KEY`即可使用。 +如果在从Hugging Face下载模型时遇到速度极慢或无法下载的问题,请在.env文件中设置`HF_ENDPOINT`的值为'https://hf-mirror.com'。请注意,某些Hugging Face仓库可能需要访问权限(例如Jina Ai)。为此,请注册一个Hugging Face账号,并在.env文件中添加`HF_TOKEN`。你可以在[这里](https://huggingface.co/settings/tokens)找到并获取你的token。 + ``` OPENAI_API_KEY= OPENAI_API_BASE= ZHIPUAI_API_KEY= BAIDU_API_KEY= OPENAI_API_MODEL= +HF_HOME='./cache/' +HF_ENDPOINT = 'https://hf-mirror.com' +HF_TOKEN= ``` ## 文件目录说明 diff --git a/test/embedding/local/BaseLocal.py b/test/embedding/local/BaseLocal.py new file mode 100644 index 0000000..23f435b --- /dev/null +++ b/test/embedding/local/BaseLocal.py @@ -0,0 +1,13 @@ +from typing import List + + +class BaseLocalEmbeddings: + """ + Base class for local embeddings + """ + + def __init__(self, path: str) -> None: + self.path = path + + def get_embedding(self, text: str, model: str) -> List[float]: + raise NotImplementedError diff --git a/test/embedding/local/Bge.py b/test/embedding/local/Bge.py new file mode 100644 index 0000000..c4698ea --- /dev/null +++ b/test/embedding/local/Bge.py @@ -0,0 +1,44 @@ +from BaseLocal import BaseLocalEmbeddings +from typing import List +import torch +from transformers import AutoModel, AutoTokenizer +from dotenv import load_dotenv + +# 加载.env文件 +load_dotenv() + + +class BgeEmbedding(BaseLocalEmbeddings): + """ + class for Bge embeddings + """ + + # path:str = TIANJI_PATH / "embedding/BAAI/bge-small-zh" + def __init__(self, path: str = "BAAI/bge-small-zh") -> None: + super().__init__(path) + self._model, self._tokenizer, self._device = self.load_model() + + def get_embedding(self, text: str) -> List[float]: + encoded_input = self._tokenizer( + text, padding=True, truncation=False, return_tensors="pt" + ).to(self._device) + return 
self._model(**encoded_input)[0][:, 0][0] + + def load_model(self): + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device) + tokenizer = AutoTokenizer.from_pretrained(self.path) + return model, tokenizer, device + + +if __name__ == "__main__": + Bge = BgeEmbedding() + embedding_result = Bge.get_embedding("你好") + print( + f"Result of Bge Embedding: \n" + f"\t Type of output: {type(embedding_result)}\n" + f"\t Shape of output: {len(embedding_result)}" + ) diff --git a/test/embedding/local/Jina.py b/test/embedding/local/Jina.py new file mode 100644 index 0000000..3d4b2e2 --- /dev/null +++ b/test/embedding/local/Jina.py @@ -0,0 +1,43 @@ +from BaseLocal import BaseLocalEmbeddings +from typing import List +import os +import torch +from transformers import AutoModel +from dotenv import load_dotenv + +# 加载.env文件 +load_dotenv() + +# os.environ["HF_TOKEN"]="" + + +class JinaEmbedding(BaseLocalEmbeddings): + """ + class for Jina embeddings + """ + + # path:str = TIANJI_PATH / "embedding/jinaai/jina-embeddings-v2-base-zh" + def __init__(self, path: str = "jinaai/jina-embeddings-v2-base-zh") -> None: + super().__init__(path) + self._model = self.load_model() + + def get_embedding(self, text: str) -> List[float]: + return self._model.encode([text])[0] + + def load_model(self): + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device) + return model + + +if __name__ == "__main__": + Jina = JinaEmbedding() + embedding_result = Jina.get_embedding("你好") + print( + f"Result of Jina Embedding: \n" + f"\t Type of output: {type(embedding_result)}\n" + f"\t Shape of output: {len(embedding_result)}" + ) diff --git a/test/embedding/online/BaseOnline.py b/test/embedding/online/BaseOnline.py new file mode 100644 index 
0000000..7a1e8c1 --- /dev/null +++ b/test/embedding/online/BaseOnline.py @@ -0,0 +1,13 @@ +from typing import List + + +class BaseOnlineEmbeddings: + """ + Base class for online embeddings + """ + + def __init__(self) -> None: + pass + + def get_embedding(self, text: str, model: str) -> List[float]: + raise NotImplementedError diff --git a/test/embedding/online/Ernie.py b/test/embedding/online/Ernie.py new file mode 100644 index 0000000..0560329 --- /dev/null +++ b/test/embedding/online/Ernie.py @@ -0,0 +1,38 @@ +from BaseOnline import BaseOnlineEmbeddings +from typing import List +import os +import erniebot +from dotenv import load_dotenv + +# 加载.env文件 +load_dotenv() + +# os.environ["BAIDU_API_KEY"]="" + + +class ErnieEmbedding(BaseOnlineEmbeddings): + """ + class for Ernie embeddings + """ + + def __init__(self) -> None: + super().__init__() + erniebot.api_type = "aistudio" + erniebot.access_token = os.getenv("BAIDU_API_KEY") + self.client = erniebot.Embedding() + + def get_embedding( + self, text: str, model: str = "ernie-text-embedding" + ) -> List[float]: + response = self.client.create(model=model, input=[text]) + return response.get_result()[0] + + +if __name__ == "__main__": + Ernie = ErnieEmbedding() + embedding_result = Ernie.get_embedding("你好") + print( + f"Result of Ernie Embedding: \n" + f"\t Type of output: {type(embedding_result)}\n" + f"\t Shape of output: {len(embedding_result)}" + ) diff --git a/test/embedding/online/OpenAI.py b/test/embedding/online/OpenAI.py new file mode 100644 index 0000000..f433409 --- /dev/null +++ b/test/embedding/online/OpenAI.py @@ -0,0 +1,41 @@ +from BaseOnline import BaseOnlineEmbeddings +from typing import List +import os +from openai import OpenAI +from dotenv import load_dotenv + +# 加载.env文件 +load_dotenv() + +# os.environ["OPENAI_API_BASE"]="" +# os.environ["OPENAI_API_KEY"]="" + + +class OpenAIEmbedding(BaseOnlineEmbeddings): + """ + class for OpenAI embeddings + """ + + def __init__(self) -> None: + 
super().__init__() + self.client = OpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENAI_API_BASE")) + + def get_embedding( + self, text: str, model: str = "text-embedding-3-small" + ) -> List[float]: + text = text.replace("\n", " ") + return ( + self.client.embeddings.create(input=[text], model=model).data[0].embedding + ) + + +if __name__ == "__main__": + openai_embedding = OpenAIEmbedding() + embedding_result = openai_embedding.get_embedding("你好") + print( + f"Result of OpenAI Embedding: \n" + f"\t Type of output: {type(embedding_result)}\n" + f"\t Shape of output: {len(embedding_result)}" + ) diff --git a/test/embedding/online/Zhipu.py b/test/embedding/online/Zhipu.py new file mode 100644 index 0000000..32c82a3 --- /dev/null +++ b/test/embedding/online/Zhipu.py @@ -0,0 +1,41 @@ +from BaseOnline import BaseOnlineEmbeddings +from typing import List +import os +from zhipuai import ZhipuAI +from dotenv import load_dotenv + +# 加载.env文件 +load_dotenv() + +# os.environ["ZHIPUAI_API_KEY"]="" + + +class ZhipuEmbedding(BaseOnlineEmbeddings): + """ + class for Zhipu embeddings + """ + + def __init__(self) -> None: + super().__init__() + self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY")) + + def get_embedding( + self, + text: str, + model: str = "embedding-2", + ) -> List[float]: + response = self.client.embeddings.create( + model=model, + input=text, + ) + return response.data[0].embedding + + +if __name__ == "__main__": + Zhipu = ZhipuEmbedding() + embedding_result = Zhipu.get_embedding("你好") + print( + f"Result of Zhipu Embedding: \n" + f"\t Type of output: {type(embedding_result)}\n" + f"\t Shape of output: {len(embedding_result)}" + )