
Commit

Merge branch 'master' of https://github.com/gpt-hey/fine_tuned into docker
guanw committed Dec 9, 2023
2 parents dd0fe55 + 1132a16 commit d280068
Showing 10 changed files with 70 additions and 50 deletions.
18 changes: 13 additions & 5 deletions client.py
@@ -2,23 +2,31 @@
 import websockets
 
 async def send_message():
-    uri = "ws://127.0.0.1:8081"  # Replace with the actual WebSocket server URI
+    uri = "ws://0.0.0.0:8081"  # Replace with the actual WebSocket server URI
     # uri = "ws://129.213.151.7:8081"
 
-    async with websockets.connect(uri) as websocket:
-        # message = "@remindme to take notes!"
+    async with websockets.connect(uri, ping_timeout=None) as websocket:
+        # Send your actual message
         message = "@gguf tell me about yourself"
         print(f"Sending message to server: {message}")
         await websocket.send(message)
 
+        while True:
             try:
                 # Receive and print the response from the server
                 response = await websocket.recv()
                 print(f"Received response from server: {response}")
+                if not response:
+                    break
             except websockets.exceptions.ConnectionClosed as e:
                 print(f"Connection closed by the client. Reason: {e.reason}")
+                break
+            except asyncio.TimeoutError:
+                print("Timeout waiting for response. Closing connection.")
+                break
             except Exception as e:
                 print(f"An exception occurred: {e}")
+                break
+
 # Run the WebSocket client
-asyncio.run(send_message())
-# asyncio.get_event_loop().run_until_complete(send_message())
+asyncio.run(send_message())
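
A note on the new asyncio.TimeoutError branch: a bare websocket.recv() never raises that exception on its own, so the branch is only reachable if the receive is wrapped in a timeout. A minimal sketch of how the loop could be adapted to actually trigger it, assuming an illustrative 30-second limit that is not part of the commit:

import asyncio

async def recv_with_timeout(websocket, timeout=30.0):
    # asyncio.wait_for cancels the pending recv() and raises
    # asyncio.TimeoutError if no frame arrives within `timeout` seconds.
    return await asyncio.wait_for(websocket.recv(), timeout=timeout)

Calling response = await recv_with_timeout(websocket) inside the while loop would make the timeout handler live.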
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file removed sample.pdf
13 changes: 2 additions & 11 deletions server/callbacks.py
@@ -1,16 +1,7 @@
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 class WebsocketCallbackHandler(StreamingStdOutCallbackHandler):
-    _instance = None
-    def __new__(cls):
-        if cls._instance is None:
-            cls._instance = super(WebsocketCallbackHandler, cls).__new__(cls)
-            # Initialize the singleton instance here
-        return cls._instance
-
-    def set_websocket(self, websocket):
+    def __init__(self, websocket):
         self.websocket = websocket
     async def on_llm_new_token(self, token, **kwargs):
         """Run on new LLM token. Only available when streaming is enabled."""
-        await self.websocket.send(token)
-
-websocket_callback_singleton = WebsocketCallbackHandler()
+        await self.websocket.send(token)
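
The refactor replaces the process-wide singleton with one handler per connection, so concurrent clients can no longer overwrite each other's socket. A minimal usage sketch, assuming an open websocket object and the langchain CallbackManager used elsewhere in this commit (set_handler makes the new handler the manager's only handler):

from langchain.callbacks.manager import CallbackManager

manager = CallbackManager([])
handler = WebsocketCallbackHandler(websocket)  # one handler per connection
manager.set_handler(handler)  # replaces the handler left by a previous client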
43 changes: 25 additions & 18 deletions server/models/llama2_gguf_model.py
@@ -1,6 +1,5 @@
 from langchain.llms import LlamaCpp
 from langchain.callbacks.manager import CallbackManager
-from callbacks import websocket_callback_singleton
 from text_loader import load_content
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
@@ -9,29 +8,37 @@
 from text_color import TextColor
 from qa_template import TEMPLATE
 
-llm = LlamaCpp(
-    model_path="./server/models/llama-2-7b.gguf.q4_K_M.bin",
-    temperature=0.75,
-    callback_manager=CallbackManager([websocket_callback_singleton]),
-    verbose=True,
-)
-text_chunks = load_content()
-embeddings=HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2', model_kwargs={'device':'cpu'})
-vector_store=FAISS.from_documents(text_chunks, embeddings)
-print(TextColor.BLUE + "convert text chunks into embeddings!" + TextColor.RESET)
-
-# create chain
-qa_prompt=PromptTemplate(template=TEMPLATE, input_variables=['context', 'question'])
-print(TextColor.BLUE + "create q&a template!" + TextColor.RESET)
 
 class Llama2GGUFModel:
+    def __init__(self):
+        # TODO refactor __init__ and swap out callbackManager to make init model faster
+        self.callback_manager = CallbackManager([])
+        self.llm = LlamaCpp(
+            model_path="./server/models/llama-2-7b.gguf.q4_K_M.bin",
+            temperature=0.7,
+            repeat_penalty=1.176,
+            top_p=0.1,
+            max_tokens=-1,
+            callback_manager=self.callback_manager,
+            verbose=False,
+        )
+        text_chunks = load_content()
+        embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2', model_kwargs={'device': 'cpu'})
+        self.vector_store = FAISS.from_documents(text_chunks, embeddings)
+        print(TextColor.BLUE + "convert text chunks into embeddings!" + TextColor.RESET)
+
+        # create chain
+        self.qa_prompt = PromptTemplate(template=TEMPLATE, input_variables=['context', 'question'])
+        print(TextColor.BLUE + "create q&a template!" + TextColor.RESET)
+    def update_callback_handler(self, callback_handler):
+        self.callback_manager.set_handler(callback_handler)
     def is_matched(self, text):
         return "@gguf" in text.lower()
     def execute_action(self, text):
         text = text.replace("@gguf", "")
-        llama2_chain = RetrievalQA.from_chain_type(llm=llm,
+        llama2_chain = RetrievalQA.from_chain_type(llm=self.llm,
                                                    chain_type='stuff',
-                                                   retriever=vector_store.as_retriever(search_kwargs={'k': 2}),
+                                                   retriever=self.vector_store.as_retriever(search_kwargs={'k': 2}),
                                                    return_source_documents=True,
-                                                   chain_type_kwargs={'prompt': qa_prompt})
+                                                   chain_type_kwargs={'prompt': self.qa_prompt})
         llama2_chain({'query': text})
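
Moving the module-level globals into the class means the GGUF weights, embeddings, and FAISS index are built once when Llama2GGUFModel() is constructed, while the initially empty CallbackManager can be re-pointed at each new connection. A usage sketch under that assumption, with websocket and message standing in for an open connection and an incoming frame:

model = Llama2GGUFModel()  # loads the model and builds the index once

# Per connection:
handler = WebsocketCallbackHandler(websocket)
model.update_callback_handler(handler)  # stream tokens to this client
if model.is_matched(message):
    model.execute_action(message)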
42 changes: 28 additions & 14 deletions server/websocket_server.py
@@ -1,41 +1,55 @@
 import asyncio
 import websockets
-from callbacks import websocket_callback_singleton
+from callbacks import WebsocketCallbackHandler
 from models.llama2_gguf_model import Llama2GGUFModel
 from remindme_parser import Remindme
 
 PORT = 8081
+ADDRESS = "0.0.0.0"
 
+async def handle_websocket_close(websocket):
+    try:
+        await websocket.close()
+
+    except websockets.exceptions.ConnectionClosedOK:
+        # Handle the case where the connection is already closed
+        pass
+    except Exception as e:
+        # Handle other exceptions as needed
+        print(f"Error: {e}")
+
+r = Remindme()
+gguf_model = Llama2GGUFModel()
 async def message(websocket, path):
     print(f"Client connected to path: {path}")
-    websocket_callback_singleton.set_websocket(websocket)
 
     try:
        async for message in websocket:
            print(f"Received message from client: {message}")
-            r = Remindme()
-
            if r.is_matched(message):
-                resp = r.execute_action(message)
+                resp = await asyncio.to_thread(r.execute_action, message)
                await websocket.send(f'remindme executed: {resp}')
                return
 
-            gguf_model = Llama2GGUFModel()
+            callback_handler = WebsocketCallbackHandler(websocket)
+            gguf_model.update_callback_handler(callback_handler)
            if gguf_model.is_matched(message):
-                gguf_model.execute_action(message)
+                await asyncio.to_thread(gguf_model.execute_action, message)
                return
 
            await websocket.send('voided conversation with no match')
 
    except websockets.exceptions.ConnectionClosed:
        print("Connection closed by the client.")
+    finally:
+        await handle_websocket_close(websocket)
 
 async def main():
-    # Start the WebSocket server on 127.0.0.1, port 8081
-    server = await websockets.serve(message, "127.0.0.1", PORT)
-
-    print(f"WebSocket server started on ws://127.0.0.1:{PORT}")
-
-    # Keep the server running until it's manually stopped
-    await server.wait_closed()
+    async with websockets.serve(message, ADDRESS, PORT):
+        print(f"WebSocket server started on ws://{ADDRESS}:{PORT}")
+        await asyncio.Future()
 
 # Run the WebSocket server
-asyncio.run(main())
+if __name__ == "__main__":
+    asyncio.run(main())
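
One subtlety with the new asyncio.to_thread calls: execute_action now runs on a worker thread, while WebsocketCallbackHandler.on_llm_new_token is a coroutine that writes to the websocket, and LlamaCpp invokes its callbacks synchronously. One way to bridge the two is to capture the event loop up front and schedule each send from the thread; the variant below is a hypothetical sketch, not code from the commit:

import asyncio

class ThreadSafeWebsocketCallbackHandler(WebsocketCallbackHandler):
    """Hypothetical handler that is safe to call from a worker thread."""
    def __init__(self, websocket, loop):
        super().__init__(websocket)
        self.loop = loop  # event loop that owns the websocket

    def on_llm_new_token(self, token, **kwargs):
        # Hand the async send back to the owning loop; thread-safe.
        asyncio.run_coroutine_threadsafe(self.websocket.send(token), self.loop)

It would be constructed with asyncio.get_running_loop() inside the message handler, before the asyncio.to_thread handoff.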
4 changes: 2 additions & 2 deletions supervisord.conf
@@ -4,8 +4,8 @@ logfile=./log/supervisord.log
 pidfile=./log/supervisord.pid
 
 [program:gpt_server]
-command=python websocket_server.py
-directory=./server
+command=python ./server/websocket_server.py
+directory=./
 autostart=true
 autorestart=true
 stdout_logfile=./log/access.log
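
With directory=./ the program now starts from the repository root, so the relative paths baked into the server (for example ./server/models/llama-2-7b.gguf.q4_K_M.bin) resolve correctly. A quick check, assuming supervisord is installed and invoked from the repo root:

supervisord -c supervisord.conf   # start the daemon with this config
supervisorctl status gpt_server   # expect the program to report RUNNING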
