solve import and .env issue
tackhwa committed Mar 3, 2024
1 parent 686ef4a commit 706f513
Showing 8 changed files with 123 additions and 61 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -105,12 +105,17 @@ pip install .

To make sure the project runs correctly, **create a `.env` file in the project root and set your API keys in it**. Fill in the corresponding keys following the example below and the calls will work; zhipuai is currently the default backend, so setting only `ZHIPUAI_API_KEY` is enough to get started.

If downloading models from Hugging Face is extremely slow or fails, set `HF_ENDPOINT` to 'https://hf-mirror.com' in the .env file. Note that some Hugging Face repositories (for example Jina AI) require access permission; in that case, register a Hugging Face account and add `HF_TOKEN` to the .env file. You can create and copy your token [here](https://huggingface.co/settings/tokens). A short sketch of how these values are consumed follows the example below.

```
OPENAI_API_KEY=
OPENAI_API_BASE=
ZHIPUAI_API_KEY=
BAIDU_API_KEY=
OPENAI_API_MODEL=
HF_HOME='./cache/'
HF_ENDPOINT='https://hf-mirror.com'
HF_TOKEN=
```
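
For orientation only, here is a minimal sketch of how the scripts in this commit pick these values up, assuming python-dotenv is installed and the code is run from the project root where `.env` lives:

```python
import os
from dotenv import load_dotenv

# Read key/value pairs from the project's .env file into the process environment.
load_dotenv()

# The online backends read their credentials from the environment, e.g.
# test/embedding/online/Zhipu.py builds ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY")).
zhipu_key = os.getenv("ZHIPUAI_API_KEY")

# HF_ENDPOINT / HF_TOKEN steer Hugging Face downloads for the local models
# (BAAI/bge-small-zh, jinaai/jina-embeddings-v2-base-zh) under test/embedding/local/.
print("ZHIPUAI_API_KEY set:", bool(zhipu_key), "| HF_ENDPOINT:", os.getenv("HF_ENDPOINT"))
```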

## Directory Structure
5 changes: 3 additions & 2 deletions test/embedding/local/BaseLocal.py
@@ -1,12 +1,13 @@
```python
from typing import List


class BaseLocalEmbeddings:
    """
    Base class for local embeddings
    """

    def __init__(self, path: str) -> None:
        self.path = path

    def get_embedding(self, text: str, model: str) -> List[float]:
        raise NotImplementedError
```
41 changes: 25 additions & 16 deletions test/embedding/local/Bge.py
@@ -1,35 +1,44 @@
```python
from BaseLocal import BaseLocalEmbeddings
from typing import List
import torch
from transformers import AutoModel, AutoTokenizer
from dotenv import load_dotenv

# Load the .env file
load_dotenv()


class BgeEmbedding(BaseLocalEmbeddings):
    """
    class for Bge embeddings
    """

    # path:str = TIANJI_PATH / "embedding/BAAI/bge-small-zh"
    def __init__(self, path: str = "BAAI/bge-small-zh") -> None:
        super().__init__(path)
        self._model, self._tokenizer, self._device = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        encoded_input = self._tokenizer(
            text, padding=True, truncation=False, return_tensors="pt"
        ).to(self._device)
        return self._model(**encoded_input)[0][:, 0][0]

    def load_model(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
        tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-zh-v1.5")
        return model, tokenizer, device


if __name__ == "__main__":
    Bge = BgeEmbedding()
    embedding_result = Bge.get_embedding("你好")
    print(
        f"Result of Bge Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
```
32 changes: 21 additions & 11 deletions test/embedding/local/Jina.py
@@ -1,33 +1,43 @@
```python
from BaseLocal import BaseLocalEmbeddings
from typing import List
import os
import torch
from transformers import AutoModel
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# os.environ["HF_TOKEN"]=""


class JinaEmbedding(BaseLocalEmbeddings):
    """
    class for Jina embeddings
    """

    # path:str = TIANJI_PATH / "embedding/jinaai/jina-embeddings-v2-base-zh"
    def __init__(self, path: str = "jinaai/jina-embeddings-v2-base-zh") -> None:
        super().__init__(path)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        return self._model.encode([text])[0]

    def load_model(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
        return model


if __name__ == "__main__":
    Jina = JinaEmbedding()
    embedding_result = Jina.get_embedding("你好")
    print(
        f"Result of Jina Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
```
6 changes: 3 additions & 3 deletions test/embedding/online/BaseOnline.py
@@ -1,13 +1,13 @@
```python
from typing import List


class BaseOnlineEmbeddings:
    """
    Base class for online embeddings
    """

    def __init__(self) -> None:
        pass

    def get_embedding(self, text: str, model: str) -> List[float]:
        raise NotImplementedError
```
31 changes: 20 additions & 11 deletions test/embedding/online/Ernie.py
@@ -1,29 +1,38 @@
```python
from BaseOnline import BaseOnlineEmbeddings
from typing import List
import os
import erniebot
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# os.environ["BAIDU_API_KEY"]=""


class ErnieEmbedding(BaseOnlineEmbeddings):
    """
    class for Ernie embeddings
    """

    def __init__(self) -> None:
        super().__init__()
        erniebot.api_type = "aistudio"
        erniebot.access_token = os.getenv("BAIDU_API_KEY")
        self.client = erniebot.Embedding()

    def get_embedding(
        self, text: str, model: str = "ernie-text-embedding"
    ) -> List[float]:
        response = self.client.create(model=model, input=[text])
        return response.get_result()[0]


if __name__ == "__main__":
    Ernie = ErnieEmbedding()
    embedding_result = Ernie.get_embedding("你好")
    print(
        f"Result of Ernie Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
```
32 changes: 23 additions & 9 deletions test/embedding/online/OpenAI.py
@@ -1,27 +1,41 @@
```python
from BaseOnline import BaseOnlineEmbeddings
from typing import List
import os
from openai import OpenAI
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# os.environ["OPENAI_API_BASE"]=""
# os.environ["OPENAI_API_KEY"]=""


class OpenAIEmbedding(BaseOnlineEmbeddings):
    """
    class for OpenAI embeddings
    """

    def __init__(self) -> None:
        super().__init__()
        self.client = OpenAI()
        self.client.base_url = os.getenv("OPENAI_API_BASE")
        self.client.api_key = os.getenv("OPENAI_API_KEY")

    def get_embedding(
        self, text: str, model: str = "text-embedding-3-small"
    ) -> List[float]:
        text = text.replace("\n", " ")
        return (
            self.client.embeddings.create(input=[text], model=model).data[0].embedding
        )


if __name__ == "__main__":
    OpenAI = OpenAIEmbedding()
    embedding_result = OpenAI.get_embedding("你好")
    print(
        f"Result of OpenAI Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
```
32 changes: 23 additions & 9 deletions test/embedding/online/Zhipu.py
@@ -1,27 +1,41 @@
```python
from BaseOnline import BaseOnlineEmbeddings
from typing import List
import os
from zhipuai import ZhipuAI
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# os.environ["ZHIPUAI_API_KEY"]=""


class ZhipuEmbedding(BaseOnlineEmbeddings):
    """
    class for Zhipu embeddings
    """

    def __init__(self) -> None:
        super().__init__()
        self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))

    def get_embedding(
        self,
        text: str,
        model: str = "embedding-2",
    ) -> List[float]:
        response = self.client.embeddings.create(
            model=model,
            input=text,
        )
        return response.data[0].embedding


if __name__ == "__main__":
    Zhipu = ZhipuEmbedding()
    embedding_result = Zhipu.get_embedding("你好")
    print(
        f"Result of Zhipu Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
```
