using_llm_api_as_llm+ollama_embedding.py
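
# Example: use an OpenAI-compatible LLM API for generation and a local Ollama
# model for embeddings with nano-graphrag.
#
# The base URL, API key, and model name below are placeholders for your own
# OpenAI-compatible service; the embedding model is assumed to be pulled
# locally in Ollama (e.g. `ollama pull nomic-embed-text`).
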
import os
import logging
import ollama
import numpy as np
from openai import AsyncOpenAI
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs

logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.INFO)

# Assumed LLM API settings
LLM_BASE_URL = "https://your.api.url"
LLM_API_KEY = "your_api_key"
MODEL = "your_model_name"

# Assumed embedding model settings
EMBEDDING_MODEL = "nomic-embed-text"
EMBEDDING_MODEL_DIM = 768
EMBEDDING_MODEL_MAX_TOKENS = 8192
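

# Chat-completion wrapper: calls the OpenAI-compatible endpoint above and uses
# nano-graphrag's KV storage ("hashing_kv", passed in via kwargs when LLM
# caching is enabled) as a simple response cache.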
async def llm_model_if_cache(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    openai_async_client = AsyncOpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Return the cached response if one exists -----------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------------------

    response = await openai_async_client.chat.completions.create(
        model=MODEL, messages=messages, **kwargs
    )

    # Cache the response before returning it -------------------------
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
        )
    # -----------------------------------------------------------------
    return response.choices[0].message.content
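

# Helper to delete stale index files so insert() rebuilds from scratch.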
def remove_if_exist(file):
    if os.path.exists(file):
        os.remove(file)


WORKING_DIR = "./nano_graphrag_cache_llm_TEST"
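

# Query the graph built by insert(), using nano-graphrag's "global" query mode.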
def query():
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        best_model_func=llm_model_if_cache,
        cheap_model_func=llm_model_if_cache,
        embedding_func=ollama_embedding,
    )
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )
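

# Build the index: clear artifacts from any previous run, then insert the mock
# corpus using the custom LLM and embedding functions defined in this file.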
def insert():
    from time import time

    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        FAKE_TEXT = f.read()

    remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        best_model_func=llm_model_if_cache,
        cheap_model_func=llm_model_if_cache,
        embedding_func=ollama_embedding,
    )
    start = time()
    rag.insert(FAKE_TEXT)
    print("indexing time:", time() - start)
    # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
    # rag.insert(FAKE_TEXT[half_len:])


# We're using Ollama to generate embeddings with the nomic-embed-text model
# configured above.
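# wrap_embedding_func_with_attrs records the embedding dimension and max token
# size on the function so GraphRAG knows the shape and limits of its vectors.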
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBEDDING_MODEL_DIM,
    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
)
async def ollama_embedding(texts: list[str]) -> np.ndarray:
    embed_text = []
    for text in texts:
        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
        embed_text.append(data["embedding"])
    # Convert to a numpy array to match the annotated return type
    return np.array(embed_text)


if __name__ == "__main__":
    insert()
    query()