-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
123 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
from typing import List | ||
|
||
|
||
class BaseLocalEmbeddings:
    """
    Base class for local embeddings

    Stores the model location and declares the ``get_embedding`` hook that
    concrete backends override.
    """

    def __init__(self, path: str) -> None:
        """Remember where the model lives.

        Args:
            path: Model location; this class only stores it on ``self.path``.
        """
        self.path = path

    def get_embedding(self, text: str, model: str = "") -> List[float]:
        """Return the embedding vector for ``text``.

        Args:
            text: Text to embed.
            model: Optional model name. It has a default so that callers can
                use the one-argument form ``get_embedding(text)``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,44 @@ | ||
from BaseLocal import BaseLocalEmbeddings | ||
from typing import List | ||
import torch | ||
from transformers import AutoModel, AutoTokenizer | ||
from dotenv import load_dotenv | ||
|
||
# Load the .env file | ||
load_dotenv() | ||
|
||
|
||
class BgeEmbedding(BaseLocalEmbeddings):
    """
    class for Bge embeddings

    Loads a model and its tokenizer from ``path`` with Hugging Face
    ``transformers`` and embeds one text at a time.
    """

    # A local checkout can be passed instead of the hub id, e.g.
    # path = TIANJI_PATH / "embedding/BAAI/bge-small-zh"
    def __init__(self, path: str = "BAAI/bge-small-zh") -> None:
        """Load the model, tokenizer and device for ``path``.

        Args:
            path: Hub id or local directory handed to ``from_pretrained``.
        """
        super().__init__(path)
        self._model, self._tokenizer, self._device = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        """Embed ``text`` and return the vector as a plain list of floats.

        The vector is position 0 of the first sequence in the model's first
        output (presumably the [CLS] hidden state -- confirm against the
        checkpoint's output layout).

        Args:
            text: Text to embed.

        Returns:
            The embedding as ``List[float]``, matching the annotation.
        """
        import torch

        # truncation=True keeps over-long input within the tokenizer's
        # configured maximum length instead of passing it on unshortened.
        encoded_input = self._tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to(self._device)
        # Inference only: skip autograd bookkeeping so the result does not
        # hold on to a computation graph.
        with torch.no_grad():
            model_output = self._model(**encoded_input)
        return model_output[0][:, 0][0].tolist()

    def load_model(self):
        """Load the model and tokenizer and choose the device.

        Returns:
            A ``(model, tokenizer, device)`` tuple. The device is CUDA when
            ``torch.cuda.is_available()`` is true, otherwise CPU.
        """
        import torch
        from transformers import AutoModel, AutoTokenizer

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
        # Load the tokenizer from the same location as the model so the two
        # always belong to the same checkpoint.
        tokenizer = AutoTokenizer.from_pretrained(self.path)
        return model, tokenizer, device
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: embed a short text and report the result's type and length.
    bge_embedding = BgeEmbedding()
    embedding_result = bge_embedding.get_embedding("你好")
    print(
        f"Result of Bge Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,43 @@ | ||
from BaseLocal import BaseLocalEmbeddings | ||
from typing import List | ||
import os | ||
#os.environ["HF_TOKEN"]="" | ||
import torch | ||
from transformers import AutoModel | ||
from dotenv import load_dotenv | ||
|
||
# Load the .env file | ||
load_dotenv() | ||
|
||
# os.environ["HF_TOKEN"]="" | ||
|
||
|
||
class JinaEmbedding(BaseLocalEmbeddings):
    """
    class for Jina embeddings

    Loads the model from ``path`` with Hugging Face ``transformers`` and
    embeds one text at a time through the model's ``encode`` method.
    """

    # A local checkout can be passed instead of the hub id, e.g.
    # path = TIANJI_PATH / "embedding/jinaai/jina-embeddings-v2-base-zh"
    def __init__(self, path: str = "jinaai/jina-embeddings-v2-base-zh") -> None:
        """Load the model for ``path``.

        Args:
            path: Hub id or local directory handed to ``from_pretrained``.
        """
        super().__init__(path)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        """Embed ``text`` and return the vector as a plain list of floats.

        Args:
            text: Text to embed.

        Returns:
            The first (only) row of ``encode([text])`` as ``List[float]``.
        """
        vector = self._model.encode([text])[0]
        # NOTE(review): encode() presumably yields an array or tensor row;
        # both expose tolist(). Fall back to list() for anything else.
        if hasattr(vector, "tolist"):
            return vector.tolist()
        return list(vector)

    def load_model(self):
        """Load the model onto CUDA when available, otherwise onto the CPU.

        Returns:
            The model returned by ``AutoModel.from_pretrained``.
        """
        import torch
        from transformers import AutoModel

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # trust_remote_code=True lets transformers run the modelling code
        # shipped with the checkpoint.
        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
        return model
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: embed a short text and report the result's type and length.
    jina_embedding = JinaEmbedding()
    embedding_result = jina_embedding.get_embedding("你好")
    print(
        f"Result of Jina Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,13 @@ | ||
from typing import List | ||
|
||
|
||
class BaseOnlineEmbeddings:
    """
    Base class for online embeddings

    Declares the ``get_embedding`` hook that concrete API-backed
    implementations override.
    """

    def __init__(self) -> None:
        """Nothing to set up here; subclasses create their own API client."""
        pass

    def get_embedding(self, text: str, model: str = "") -> List[float]:
        """Return the embedding vector for ``text``.

        Args:
            text: Text to embed.
            model: Optional model name. It has a default so that callers can
                use the one-argument form ``get_embedding(text)``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,38 @@ | ||
from BaseOnline import BaseOnlineEmbeddings | ||
from typing import List | ||
import os | ||
import erniebot | ||
from dotenv import load_dotenv | ||
|
||
# Load the .env file | ||
load_dotenv() | ||
|
||
# os.environ["BAIDU_API_KEY"]="" | ||
|
||
|
||
class ErnieEmbedding(BaseOnlineEmbeddings):
    """
    class for Ernie embeddings

    Configures the ``erniebot`` module from the ``BAIDU_API_KEY`` environment
    variable and requests embeddings through ``erniebot.Embedding``.
    """

    def __init__(self) -> None:
        """Set the erniebot API type and access token, then create the client."""
        super().__init__()
        import erniebot

        access_token = os.getenv("BAIDU_API_KEY")
        # These are module-level settings on erniebot, not per-client ones.
        erniebot.api_type = "aistudio"
        erniebot.access_token = access_token
        self.client = erniebot.Embedding()

    def get_embedding(
        self, text: str, model: str = "ernie-text-embedding"
    ) -> List[float]:
        """Request the embedding of ``text`` from ``model``.

        Args:
            text: Text to embed; it is sent as a one-element batch.
            model: Name of the embedding model to call.

        Returns:
            The first entry of the response's result.
        """
        embeddings = self.client.create(model=model, input=[text]).get_result()
        return embeddings[0]
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: embed a short text and report the result's type and length.
    ernie_embedding = ErnieEmbedding()
    embedding_result = ernie_embedding.get_embedding("你好")
    print(
        f"Result of Ernie Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,41 @@ | ||
from BaseOnline import BaseOnlineEmbeddings | ||
from typing import List | ||
import os | ||
# os.environ["OPENAI_BASE_URL"]="" | ||
from openai import OpenAI | ||
from dotenv import load_dotenv | ||
|
||
# Load the .env file | ||
load_dotenv() | ||
|
||
# os.environ["OPENAI_API_BASE"]="" | ||
# os.environ["OPENAI_API_KEY"]="" | ||
|
||
|
||
class OpenAIEmbedding(BaseOnlineEmbeddings):
    """
    class for OpenAI embeddings

    Builds an ``openai.OpenAI`` client from the ``OPENAI_API_KEY`` and
    ``OPENAI_API_BASE`` environment variables.
    """

    def __init__(self) -> None:
        """Create the API client.

        The endpoint and key are passed to the constructor rather than
        assigned afterwards: when ``OPENAI_API_BASE`` is unset or empty the
        client keeps its own default endpoint instead of being handed
        ``None`` as a base URL.
        """
        super().__init__()
        from openai import OpenAI

        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_API_BASE") or None,
        )

    def get_embedding(
        self, text: str, model: str = "text-embedding-3-small"
    ) -> List[float]:
        """Request the embedding of ``text`` from ``model``.

        Args:
            text: Text to embed; newlines are replaced by spaces before sending.
            model: Name of the embedding model to call.

        Returns:
            The ``embedding`` field of the first item in the response data.
        """
        text = text.replace("\n", " ")
        response = self.client.embeddings.create(input=[text], model=model)
        return response.data[0].embedding
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: embed a short text and report the result's type and length.
    # The instance is deliberately not called ``OpenAI``: that name would
    # rebind the SDK class imported at module level.
    openai_embedding = OpenAIEmbedding()
    embedding_result = openai_embedding.get_embedding("你好")
    print(
        f"Result of OpenAI Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,41 @@ | ||
from BaseOnline import BaseOnlineEmbeddings | ||
from typing import List | ||
import os | ||
from zhipuai import ZhipuAI | ||
from dotenv import load_dotenv | ||
|
||
# Load the .env file | ||
load_dotenv() | ||
|
||
# os.environ["ZHIPUAI_API_KEY"]="" | ||
|
||
|
||
class ZhipuEmbedding(BaseOnlineEmbeddings):
    """
    class for Zhipu embeddings

    Builds a ``ZhipuAI`` client from the ``ZHIPUAI_API_KEY`` environment
    variable and requests embeddings through it.
    """

    def __init__(self) -> None:
        """Create the API client with the key read from the environment."""
        super().__init__()
        from zhipuai import ZhipuAI

        api_key = os.getenv("ZHIPUAI_API_KEY")
        self.client = ZhipuAI(api_key=api_key)

    def get_embedding(
        self,
        text: str,
        model: str = "embedding-2",
    ) -> List[float]:
        """Request the embedding of ``text`` from ``model``.

        Args:
            text: Text to embed; it is sent as a single string.
            model: Name of the embedding model to call.

        Returns:
            The ``embedding`` field of the first item in the response data.
        """
        first_item = self.client.embeddings.create(model=model, input=text).data[0]
        return first_item.embedding
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: embed a short text and report the result's type and length.
    zhipu_embedding = ZhipuEmbedding()
    embedding_result = zhipu_embedding.get_embedding("你好")
    print(
        f"Result of Zhipu Embedding: \n"
        f"\t Type of output: {type(embedding_result)}\n"
        f"\t Shape of output: {len(embedding_result)}"
    )