# Building Llama Banker
Doing RAG for finance using Llama 2. Highly recommend you run this in a GPU-accelerated environment. I used an A100 80GB GPU on Runpod for the video!

## See it live and in action 📺
[![Tutorial](https://i.imgur.com/lqMC3K7.png)](https://youtu.be/abc123 'Tutorial')

# Startup 🚀
1. Clone this repo `git clone https://github.com/nicknochnack/Llama2RAG`
2. Go into the directory `cd Llama2RAG`
3. Start up Jupyter by running `jupyter lab` in a terminal or command prompt
4. Update the `auth_token` variable in the notebook (see the sketch after this list for an environment-variable alternative).
5. Hit `Ctrl + Enter` to run through the notebook!
6. Go back to my YouTube channel and like and subscribe 😉...no seriously...please! lol
7. If you want to start up the streamlit app run `streamlit run app.py` (make sure you update your auth token in there as well!)
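
If you'd rather not hardcode your token, one option (not part of the original repo) is to read it from an environment variable in both the notebook and `app.py`. A minimal sketch, assuming you've exported `HF_TOKEN` in your shell first:

```python
# Minimal sketch - assumes you ran `export HF_TOKEN=<your token>` beforehand.
# Drop this in wherever auth_token is currently hardcoded.
import os

auth_token = os.environ["HF_TOKEN"]
```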

# Other References 🔗
<p>-<a href="https://huggingface.co/meta-llama/Llama-2-70b-chat-hf">Llama 2 70b Chat Model Card</a>: Hugging Face model card for the model used in the video.</p>
<p>-<a href="https://www.llamaindex.ai/">Llama Index Doco</a>: sick library used for RAG.</p>

# Who, When, Why?
👨🏾‍💻 Author: Nick Renotte <br />
📅 Version: 1.x <br />
📜 License: This project is licensed under the MIT license. Feel free to use it, just don't do bad things with it. <br />

# app.py
# Import streamlit for app dev
import streamlit as st

# Import transformer classes for generation
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
# Import torch for datatype attributes
import torch
# Import the prompt wrapper...but for llama index
from llama_index.prompts.prompts import SimpleInputPrompt
# Import the llama index HF Wrapper
from llama_index.llms import HuggingFaceLLM
# Bring in embeddings wrapper
from llama_index.embeddings import LangchainEmbedding
# Bring in HF embeddings - need these to represent document chunks
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
# Bring in stuff to change service context
from llama_index import set_global_service_context
from llama_index import ServiceContext
# Import deps to load documents
from llama_index import VectorStoreIndex, download_loader
from pathlib import Path

# Define variable to hold llama2 weights naming
name = "meta-llama/Llama-2-70b-chat-hf"
# Set auth token variable from hugging face
auth_token = "YOUR HUGGING FACE AUTH TOKEN HERE"

# Cache the model and tokenizer so streamlit doesn't reload them on every rerun
@st.cache_resource
def get_tokenizer_model():
    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir='./model/', use_auth_token=auth_token)

    # Create model
    model = AutoModelForCausalLM.from_pretrained(name, cache_dir='./model/',
                                                 use_auth_token=auth_token, torch_dtype=torch.float16,
                                                 rope_scaling={"type": "dynamic", "factor": 2}, load_in_8bit=True)

    return tokenizer, model
tokenizer, model = get_tokenizer_model()
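# (Note, not in the original code: load_in_8bit=True relies on the bitsandbytes
# package and a CUDA GPU; 8-bit weights are what let a 70B model squeeze onto a
# single A100 80GB.)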

# Create a system prompt
system_prompt = """<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as
helpfully as possible, while being safe. Your answers should not include
any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain
why instead of answering something not correct. If you don't know the answer
to a question, please don't share false information.
Your goal is to provide answers relating to the financial performance of
the company.<</SYS>>
"""
# Throw together the query wrapper
query_wrapper_prompt = SimpleInputPrompt("{query_str} [/INST]")
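# For reference: llama_index stitches the two pieces above together into the
# standard Llama 2 chat template, so the model ultimately sees roughly:
#   <s>[INST] <<SYS>> ...system prompt... <</SYS>> <your query> [/INST]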

# Create a HF LLM using the llama index wrapper
llm = HuggingFaceLLM(context_window=4096,
                     max_new_tokens=256,
                     system_prompt=system_prompt,
                     query_wrapper_prompt=query_wrapper_prompt,
                     model=model,
                     tokenizer=tokenizer)
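# (Optional smoke test, not in the original app - llama_index LLMs expose a
# complete() method, handy for checking the model loads and generates before
# wiring up the index:)
# print(llm.complete("What is a balance sheet?"))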

# Create and dl embeddings instance
embeddings = LangchainEmbedding(
    HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
)
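# (all-MiniLM-L6-v2 is a small sentence-transformers model that encodes each
# text chunk as a 384-dimensional vector - cheap to run alongside the 70B LLM.)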

# Create new service context instance
service_context = ServiceContext.from_defaults(
    chunk_size=1024,
    llm=llm,
    embed_model=embeddings
)
# And set the service context
set_global_service_context(service_context)
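# (chunk_size=1024 sets how the document text is split before embedding: bigger
# chunks mean fewer, longer passages get retrieved for each query.)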

# Download PDF Loader
PyMuPDFReader = download_loader("PyMuPDFReader")
# Create PDF Loader
loader = PyMuPDFReader()
# Load documents
documents = loader.load(file_path=Path('./data/annualreport.pdf'), metadata=True)
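# (Optional sanity check, not in the original app - confirms the PDF parsed:)
# print(f"Loaded {len(documents)} document(s) from the annual report")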

# Create an index - we'll be able to query this in a sec
index = VectorStoreIndex.from_documents(documents)
# Setup index query engine using LLM
query_engine = index.as_query_engine()
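# (Optional, not in the original app - you can exercise the whole RAG pipeline
# from a plain python shell before launching streamlit:)
# print(query_engine.query("What was the company's total revenue this year?"))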

# Create centered main title
st.title('🦙 Llama Banker')
# Create a text input box for the user
prompt = st.text_input('Input your prompt here')

# If the user hits enter
if prompt:
    response = query_engine.query(prompt)
    # ...and write it out to the screen
    st.write(response)

    # Display raw response object
    with st.expander('Response Object'):
        st.write(response)
    # Display source text
    with st.expander('Source Text'):
        st.write(response.get_formatted_sources())