-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from tackhwa/main
Added test cases for the online and offline embedding models, corresponding to 天机-任务看板 No.7
- Loading branch information
Showing
8 changed files
with
238 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from typing import List | ||
|
||
|
||
class BaseLocalEmbeddings:
    """Base class for embedding backends that load a model locally.

    Subclasses are expected to override :meth:`get_embedding`.
    """

    def __init__(self, path: str) -> None:
        """Remember where the model should be loaded from.

        Args:
            path: Model identifier or filesystem location; stored unchanged
                on ``self.path`` for subclasses to use.
        """
        self.path = path

    def get_embedding(self, text: str, model: str = "") -> List[float]:
        """Return the embedding vector for ``text``.

        Args:
            text: The text to embed.
            model: Optional model name. It has a default so that callers can
                invoke ``get_embedding(text)`` uniformly on the base type and
                on subclasses that do not take a ``model`` argument.

        Raises:
            NotImplementedError: Always; subclasses must provide the logic.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_embedding()"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from BaseLocal import BaseLocalEmbeddings | ||
from typing import List | ||
import torch | ||
from transformers import AutoModel, AutoTokenizer | ||
from dotenv import load_dotenv | ||
|
||
# 加载.env文件 | ||
load_dotenv() | ||
|
||
|
||
class BgeEmbedding(BaseLocalEmbeddings):
    """Local embedding backend for BGE models loaded through ``transformers``.

    The model and its tokenizer are both loaded from ``self.path`` so that a
    non-default checkpoint (or a local directory) gets its matching tokenizer.
    """

    # The original author noted a local checkout as an alternative default:
    # TIANJI_PATH / "embedding/BAAI/bge-small-zh"
    def __init__(self, path: str = "BAAI/bge-small-zh") -> None:
        """Load the model, tokenizer and target device for ``path``.

        Args:
            path: Model identifier or directory handed to ``from_pretrained``.
        """
        super().__init__(path)
        self._model, self._tokenizer, self._device = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        """Embed ``text`` and return the vector as a plain list of floats.

        Args:
            text: The text to embed.

        Returns:
            The vector at sequence position 0 of the first (only) batch item
            of the model's first output, converted to a Python list.
        """
        # truncation=True keeps over-long inputs within the tokenizer's
        # configured maximum length instead of overflowing the model.
        encoded_input = self._tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to(self._device)
        # Inference only: avoid building an autograd graph for every call.
        with torch.no_grad():
            model_output = self._model(**encoded_input)
        first_token_vector = model_output[0][:, 0][0]
        # Move off the accelerator and convert so the result matches the
        # annotated List[float] return type.
        return first_token_vector.cpu().tolist()

    def load_model(self):
        """Load the model and tokenizer from ``self.path``.

        Returns:
            A ``(model, tokenizer, device)`` tuple; the model has been moved
            to CUDA when available, otherwise it stays on the CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
        # Use the tokenizer that belongs to the selected checkpoint rather
        # than a fixed, unrelated model name.
        tokenizer = AutoTokenizer.from_pretrained(self.path)
        return model, tokenizer, device
|
||
|
||
if __name__ == "__main__":
    # Manual smoke check: embed a short sample and report the result's type
    # and length.
    bge_embedder = BgeEmbedding()
    sample_vector = bge_embedder.get_embedding("你好")
    report = (
        f"Result of Bge Embedding: \n"
        f"\t Type of output: {type(sample_vector)}\n"
        f"\t Shape of output: {len(sample_vector)}"
    )
    print(report)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from BaseLocal import BaseLocalEmbeddings | ||
from typing import List | ||
import os | ||
import torch | ||
from transformers import AutoModel | ||
from dotenv import load_dotenv | ||
|
||
# 加载.env文件 | ||
load_dotenv() | ||
|
||
# os.environ["HF_TOKEN"]="" | ||
|
||
|
||
class JinaEmbedding(BaseLocalEmbeddings):
    """Local embedding backend for Jina models loaded through ``transformers``.

    The checkpoint is loaded with ``trust_remote_code=True`` and embeddings
    are produced by calling ``encode`` on the loaded model.
    """

    # The original author noted a local checkout as an alternative default:
    # TIANJI_PATH / "embedding/jinaai/jina-embeddings-v2-base-zh"
    def __init__(self, path: str = "jinaai/jina-embeddings-v2-base-zh") -> None:
        """Load the model for ``path``.

        Args:
            path: Model identifier or directory handed to ``from_pretrained``.
        """
        super().__init__(path)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        """Embed ``text`` and return the vector as a plain list of floats.

        Args:
            text: The text to embed.

        Returns:
            The first (only) row of ``encode([text])`` as a Python list.
        """
        vector = self._model.encode([text])[0]
        # encode() presumably yields an array-like row; convert it so the
        # result matches the annotated List[float] return type.
        to_list = getattr(vector, "tolist", None)
        return to_list() if callable(to_list) else list(vector)

    def load_model(self):
        """Load the model from ``self.path``.

        Returns:
            The model, moved to CUDA when available, otherwise on the CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
|
||
|
||
if __name__ == "__main__":
    # Manual smoke check: embed a short sample and report the result's type
    # and length.
    jina_embedder = JinaEmbedding()
    sample_vector = jina_embedder.get_embedding("你好")
    report = (
        f"Result of Jina Embedding: \n"
        f"\t Type of output: {type(sample_vector)}\n"
        f"\t Shape of output: {len(sample_vector)}"
    )
    print(report)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from typing import List | ||
|
||
|
||
class BaseOnlineEmbeddings:
    """Base class for embedding backends served by an online API.

    Subclasses are expected to override :meth:`get_embedding`.
    """

    def __init__(self) -> None:
        """Initialise the base; it holds no state of its own."""
        pass

    def get_embedding(self, text: str, model: str = "") -> List[float]:
        """Return the embedding vector for ``text``.

        Args:
            text: The text to embed.
            model: Optional model name. It has a default so that callers can
                invoke ``get_embedding(text)`` uniformly on the base type and
                on subclasses that supply their own default model.

        Raises:
            NotImplementedError: Always; subclasses must provide the logic.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_embedding()"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from BaseOnline import BaseOnlineEmbeddings | ||
from typing import List | ||
import os | ||
import erniebot | ||
from dotenv import load_dotenv | ||
|
||
# 加载.env文件 | ||
load_dotenv() | ||
|
||
# os.environ["BAIDU_API_KEY"]="" | ||
|
||
|
||
class ErnieEmbedding(BaseOnlineEmbeddings):
    """Online embedding backend that calls the ``erniebot`` embedding API."""

    def __init__(self) -> None:
        """Configure ``erniebot`` globally and create the embedding client.

        The access token is read from the ``BAIDU_API_KEY`` environment
        variable and the API type is set to ``"aistudio"``.
        """
        super().__init__()
        access_token = os.getenv("BAIDU_API_KEY")
        erniebot.api_type = "aistudio"
        erniebot.access_token = access_token
        self.client = erniebot.Embedding()

    def get_embedding(
        self, text: str, model: str = "ernie-text-embedding"
    ) -> List[float]:
        """Request an embedding for ``text`` from the given ``model``.

        Args:
            text: The text to embed; sent as a single-element input list.
            model: Name of the embedding model to use.

        Returns:
            The first entry of the response's result.
        """
        embeddings = self.client.create(model=model, input=[text]).get_result()
        return embeddings[0]
|
||
|
||
if __name__ == "__main__":
    # Manual smoke check: embed a short sample and report the result's type
    # and length.
    ernie_embedder = ErnieEmbedding()
    sample_vector = ernie_embedder.get_embedding("你好")
    report = (
        f"Result of Ernie Embedding: \n"
        f"\t Type of output: {type(sample_vector)}\n"
        f"\t Shape of output: {len(sample_vector)}"
    )
    print(report)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from BaseOnline import BaseOnlineEmbeddings | ||
from typing import List | ||
import os | ||
from openai import OpenAI | ||
from dotenv import load_dotenv | ||
|
||
# 加载.env文件 | ||
load_dotenv() | ||
|
||
# os.environ["OPENAI_API_BASE"]="" | ||
# os.environ["OPENAI_API_KEY"]="" | ||
|
||
|
||
class OpenAIEmbedding(BaseOnlineEmbeddings):
    """Online embedding backend that calls the OpenAI embeddings API."""

    def __init__(self) -> None:
        """Create the OpenAI client from environment configuration.

        ``OPENAI_API_KEY`` supplies the key and ``OPENAI_API_BASE`` optionally
        overrides the endpoint. Both are passed to the constructor: assigning
        ``None`` to ``client.base_url`` afterwards fails when
        ``OPENAI_API_BASE`` is unset, whereas a ``None`` constructor argument
        lets the client fall back to its own default.
        """
        super().__init__()
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            # Treat an unset or empty variable as "use the client default".
            base_url=os.getenv("OPENAI_API_BASE") or None,
        )

    def get_embedding(
        self, text: str, model: str = "text-embedding-3-small"
    ) -> List[float]:
        """Request an embedding for ``text`` from the given ``model``.

        Args:
            text: The text to embed; newlines are replaced with spaces
                before the request is sent.
            model: Name of the embedding model to use.

        Returns:
            The ``embedding`` of the first item in the response data.
        """
        text = text.replace("\n", " ")
        response = self.client.embeddings.create(input=[text], model=model)
        return response.data[0].embedding
|
||
|
||
if __name__ == "__main__":
    # Manual smoke check: embed a short sample and report the result's type
    # and length.
    # The instance deliberately is NOT named ``OpenAI``: that would rebind
    # the imported client class at module level and break any later
    # ``OpenAIEmbedding()`` construction.
    openai_embedder = OpenAIEmbedding()
    embedding_result = openai_embedder.get_embedding("你好")
    print(
        f"Result of OpenAI Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from BaseOnline import BaseOnlineEmbeddings | ||
from typing import List | ||
import os | ||
from zhipuai import ZhipuAI | ||
from dotenv import load_dotenv | ||
|
||
# 加载.env文件 | ||
load_dotenv() | ||
|
||
# os.environ["ZHIPUAI_API_KEY"]="" | ||
|
||
|
||
class ZhipuEmbedding(BaseOnlineEmbeddings):
    """Online embedding backend that calls the ZhipuAI embeddings API."""

    def __init__(self) -> None:
        """Create the ZhipuAI client.

        The API key is read from the ``ZHIPUAI_API_KEY`` environment variable.
        """
        super().__init__()
        api_key = os.getenv("ZHIPUAI_API_KEY")
        self.client = ZhipuAI(api_key=api_key)

    def get_embedding(
        self,
        text: str,
        model: str = "embedding-2",
    ) -> List[float]:
        """Request an embedding for ``text`` from the given ``model``.

        Args:
            text: The text to embed; passed to the API as-is.
            model: Name of the embedding model to use.

        Returns:
            The ``embedding`` of the first item in the response data.
        """
        result = self.client.embeddings.create(model=model, input=text)
        first_item = result.data[0]
        return first_item.embedding
|
||
|
||
if __name__ == "__main__":
    # Manual smoke check: embed a short sample and report the result's type
    # and length.
    zhipu_embedder = ZhipuEmbedding()
    sample_vector = zhipu_embedder.get_embedding("你好")
    report = (
        f"Result of Zhipu Embedding: \n"
        f"\t Type of output: {type(sample_vector)}\n"
        f"\t Shape of output: {len(sample_vector)}"
    )
    print(report)