Skip to content

Commit

Permalink
Merge pull request #448 from Capsize-Games/develop
Browse files Browse the repository at this point in the history
Add more LLM options and improve chat
  • Loading branch information
w4ffl35 authored Jan 28, 2024
2 parents 8b806bc + 097d59b commit e468f93
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 65 deletions.
238 changes: 175 additions & 63 deletions src/airunner/aihandler/casual_lm_transfformer_base_handler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import torch
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.llms.types import ChatResponse
from transformers import AutoModelForCausalLM, TextIteratorStreamer

from llama_index.llms import HuggingFaceLLM
from llama_index.llms import HuggingFaceLLM, ChatMessage
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index import ServiceContext, StorageContext, load_index_from_storage
from llama_index import VectorStoreIndex, SimpleDirectoryReader

from airunner.aihandler.tokenizer_handler import TokenizerHandler
from airunner.enums import SignalCode
from airunner.enums import SignalCode, SelfReflectionCategory


class CasualLMTransformerBaseHandler(TokenizerHandler):
Expand Down Expand Up @@ -63,9 +64,16 @@ def load_streamer(self):

def load_llm(self):
    """Instantiate the llama-index HuggingFaceLLM wrapper for chat/RAG.

    Wraps the already-loaded ``self.model`` and ``self.tokenizer``; stores
    the result on ``self.llm`` for use by ``chat_stream``.
    """
    self.logger.info("Loading RAG")
    # Fixed: the diff rendering left two `tokenizer=` lines (a duplicate
    # keyword argument, which is a SyntaxError); keep the single keyword.
    # NOTE(review): a local `variables` dict (username/botname/bos_token/
    # bot_mood/bot_personality) was built here but never passed anywhere;
    # removed as dead code -- confirm it was not meant as an LLM kwarg.
    self.llm = HuggingFaceLLM(
        model=self.model,
        tokenizer=self.tokenizer,
    )

def load_embed_model(self):
Expand Down Expand Up @@ -107,7 +115,8 @@ def load_query_engine(self):

def do_generate(self):
    """Generate a response for the current prompt via the chat pipeline.

    Only the chat path is active; the alternative strategies are kept
    commented for reference. The diff residue had left both rag_stream()
    and chat_stream() active, which would generate the response twice.
    """
    # self.llm_stream()
    # self.rag_stream()
    self.chat_stream()

def save_query_engine_to_disk(self):
self.index.storage_context.persist(
Expand All @@ -126,78 +135,181 @@ def load_query_engine_from_disk(self):
streaming=True
)

def prepare_template(self):
prompt = self.prompt
history = []
for message in self.history:
if message["role"] == "user":
# history.append("<s>[INST]" + self.username + ': "'+ message["content"] +'"[/INST]')
history.append(self.username + ': "' + message["content"] + '"')
else:
# history.append(self.botname + ': "'+ message["content"] +'"</s>')
history.append(self.botname + ': "' + message["content"])
history = "\n".join(history)
if history == "":
history = None
def prepare_messages(self):
# NOTE(review): this block is a diff rendering -- lines from the deleted
# prepare_template() are interleaved with the new prepare_messages() body.
# Artifact spans are flagged inline below; code left byte-identical.
# Purpose: build the ChatMessage list (system prompt built from guardrails,
# self-reflection moderation categories, name/mood/personality settings,
# plus conversation history and the current user prompt).
# NOTE(review): this mapping is shadowed by the loop variable of the same
# name further down and is never read afterwards -- looks dead or residue.
optional_self_reflection = {}
optional_self_reflection[SelfReflectionCategory.ILLEGAL] = "illegal: Illegal activity."
optional_self_reflection[SelfReflectionCategory.HATE_VIOLENCE_HARASSMENT] = "hate violence harassment: Generation of hateful, harassing, or violent content: content that expresses, incites, or promotes hate based on identity, content that intends to harass, threaten, or bully an individual, content that promotes or glorifies violence or celebrates the suffering or humiliation of others."
optional_self_reflection[SelfReflectionCategory.MALWARE] = "malware: Generation of malware: content that attempts to generate code that is designed to disrupt, damage, or gain unauthorized access to a computer system."
optional_self_reflection[SelfReflectionCategory.PHYSICAL_HARM] = "physical harm: activity that has high risk of physical harm, including: weapons development, military and warfare, management or operation of critical infrastructure in energy, transportation, and water, content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders."
optional_self_reflection[SelfReflectionCategory.ECONOMIC_HARM] = "economic harm: activity that has high risk of economic harm, including: multi-level marketing, gambling, payday lending, automated determinations of eligibility for credit, employment, educational institutions, or public assistance services."
optional_self_reflection[SelfReflectionCategory.FRAUD] = "fraud: Fraudulent or deceptive activity, including: scams, coordinated inauthentic behavior, plagiarism, academic dishonesty, astroturfing, such as fake grassroots support or fake review generation, disinformation, spam, pseudo-pharmaceuticals."
optional_self_reflection[SelfReflectionCategory.ADULT] = "adult: Adult content, adult industries, and dating apps, including: content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness), erotic chat, pornography."
optional_self_reflection[SelfReflectionCategory.POLITICAL] = "political: Political campaigning or lobbying, by: generating high volumes of campaign materials, generating campaign materials personalized to or targeted at specific demographics, building conversational or interactive systems such as chatbots that provide information about campaigns or engage in political advocacy or lobbying, building products for political campaigning or lobbying purposes."
optional_self_reflection[SelfReflectionCategory.PRIVACY] = "privacy: Activity that violates people's privacy, including: tracking or monitoring an individual without their consent, facial recognition of private individuals, classifying individuals based on protected characteristics, using biometrics for identification or assessment, unlawful collection or disclosure of personal identifiable information or educational, financial, or other protected records."
optional_self_reflection[SelfReflectionCategory.UNQUALIFIED_LAW] = "unqualified law: Engaging in the unauthorized practice of law, or offering tailored legal advice without a qualified person reviewing the information."
optional_self_reflection[SelfReflectionCategory.UNQUALIFIED_FINANCIAL] = "unqualified financial: Offering tailored financial advice without a qualified person reviewing the information."
optional_self_reflection[SelfReflectionCategory.UNQUALIFIED_HEALTH] = "unqualified health: Telling someone that they have or do not have a certain health condition, or providing instructions on how to cure or treat a health condition."

# The guardrails prompt is optional and can be overriden.
guardrails_prompt = ""
if self.settings["llm_generator_settings"]["guardrails_active"]:
guardrails_prompt = (
"Always assist with care, respect, and truth. "
"Respond with utmost utility yet securely. "
"Avoid harmful, unethical, prejudiced, or negative content. "
"Ensure replies promote fairness and positivity."
)
# NOTE(review): guardrails_prompt is built but never appended to the
# system instructions in the visible code -- verify against the real file.

# NOTE(review): the following dict references `history` and `prompt`,
# which are not defined in this method -- residue of the deleted
# prepare_template() implementation.
# Create a dictionary with the variables
variables = {
"username": self.username,
"botname": self.botname,
"history": history or "",
"input": prompt,
"bos_token": self.tokenizer.bos_token,
"bot_mood": self.bot_mood,
"bot_personality": self.bot_personality,
}
# The self-reflection prompt is not optional, but the categories are.
# NOTE(review): dec() XOR-decodes an obscured moderation-category string at
# runtime; confirm its plaintext matches the other category descriptions.
def dec(m): return ''.join(chr(ord(c) ^ 0x55) for c in m)
self_reflection_prompt = [
"You're given a list of moderation categories as below:",
dec("xu6=<91u47 &0ou6=<91u&0- 49u47 &0u84!0'<49u:'u4;,u6:;!0;!u!=4!u0-%9:<!&u:'u=4'8&u6=<91'0;{"),
]

# All other categories in self-reflection are optional and can be controlled via the interface
self_reflection_categories = self.settings["llm_generator_settings"]["self_reflection_categories"]
for optional_self_reflection in self_reflection_categories:
if optional_self_reflection["active"]:
self_reflection_prompt.append(
f"- {optional_self_reflection['category']}"
)

# NOTE(review): `prompt` is undefined here (residue of prepare_template);
# the new-code path appends self.prompt at the bottom of this method.
self.history.append({
"role": "user",
"content": prompt
})
self_reflection_prompt.extend([
"Please classify the following text into one of these categories, and answer with that single word only.",
"",
"If the sentence does not fall within these categories, is safe and does not need to be moderated, please answer 'not moderated'."
])
self_reflection_prompt = "\n".join(self_reflection_prompt)

# TODO: override this
system_instructions = [
"You are a dungeon master for a roleplaying game. ",
"You will respond to the player's actions and questions. ",
]

system_instructions.append(self_reflection_prompt)

if self.settings["llm_generator_settings"]["assign_names"]:
system_instructions.append(
"Your name is " + self.botname + ". "
"The user's name is " + self.username + "."
)

# Render the template with the variables
# rendered_template = chat_template.render(variables)
if self.settings["llm_generator_settings"]["use_mood"]:
system_instructions.append(f"Your mood: {self.bot_mood}.")

# NOTE(review): the next loop and its unconditional `return` are residue of
# the deleted prepare_template() -- as rendered, they would make the rest
# of this method unreachable; verify against the real file.
# iterate over variables and replace again, this allows us to use variables
# in custom template variables (for example variables inside of botmood and bot_personality)
rendered_template = self.template
for n in range(2):
for key, value in variables.items():
rendered_template = rendered_template.replace("{{ " + key + " }}", value)
return rendered_template
if self.settings["llm_generator_settings"]["use_personality"]:
system_instructions.append(f"Your personality: {self.bot_personality}.")

system_prompt = "\n".join(system_instructions)

messages = [
ChatMessage(
role="system",
content=system_prompt
)
]
for message in self.history:
messages.append(
ChatMessage(
role=message["role"],
content=message["content"]
)
)
if self.prompt:
messages.append(
ChatMessage(
role="user",
content=self.prompt
)
)
return messages

def chat_stream(self):
    """Stream a chat completion from the loaded LLM and relay it to the UI.

    Builds the message list, streams the response, emits each delta via
    emit_streamed_text_signal, guarantees a terminal end-of-message signal,
    and records the complete assistant reply in the conversation history.
    """
    self.logger.info("Generating chat response")
    messages = self.prepare_messages()
    streaming_response = self.llm.stream_chat(messages)
    is_first_message = True
    is_end_of_message = False
    assistant_message = ""
    for chat_response in streaming_response:
        content, is_end_of_message = self.parse_chat_response(chat_response)
        # NOTE(review): presumably stream_chat yields cumulative content --
        # the replace strips the already-emitted text so only the new delta
        # is sent downstream; confirm against llama-index stream semantics.
        content = content.replace(assistant_message, "")
        assistant_message += content
        self.emit_streamed_text_signal(
            message=content,
            is_first_message=is_first_message,
            is_end_of_message=is_end_of_message
        )
        is_first_message = False

    if not is_end_of_message:
        # The stream ended without an explicit EOS token; close it out.
        self.send_final_message()

    # Fixed: was a stray debug print() to stdout; route through the logger.
    self.logger.debug("assistant_message: %s", assistant_message)
    self.add_message_to_history(
        assistant_message
    )

def rag_stream(self):
    """Stream a RAG (query-engine) response and relay it to the UI.

    Queries self.query_engine with the current prompt, emits each streamed
    chunk via emit_streamed_text_signal, guarantees a terminal end-of-message
    signal, and records the full assistant reply in the history.

    Fixed: the diff rendering had left both the old and new versions of the
    body interleaved (a duplicated query() call and a leftover, unbalanced
    self.emit(...) block), which was not valid Python; this is the coherent
    new-side implementation.
    """
    self.logger.info("Generating RAG response")
    streaming_response = self.query_engine.query(self.prompt)
    is_first_message = True
    is_end_of_message = False
    assistant_message = ""
    for new_text in streaming_response.response_gen:
        content, is_end_of_message = self.parse_rag_response(new_text)
        assistant_message += content
        self.emit_streamed_text_signal(
            message=content,
            is_first_message=is_first_message,
            is_end_of_message=is_end_of_message
        )
        is_first_message = False

    if not is_end_of_message:
        # The stream ended without an explicit EOS token; close it out.
        self.send_final_message()

    self.add_message_to_history(
        assistant_message
    )

def parse_rag_response(self, content):
    """Strip the EOS token from a streamed RAG chunk.

    Returns a ``(text, is_end_of_message)`` tuple, where the flag is True
    when the chunk contained the ``</s>`` end-of-sequence marker.
    """
    eos = "</s>"
    finished = eos in content
    if finished:
        content = content.replace(eos, "")
    return content, finished

def parse_chat_response(self, chat_response):
    """Extract the text of a streamed ChatResponse and detect the EOS token.

    Returns a ``(text, is_end_of_message)`` tuple; the flag is True when
    the message content contained the ``</s>`` end-of-sequence marker,
    which is stripped from the returned text.
    """
    text = chat_response.message.content
    eos_seen = "</s>" in text
    if eos_seen:
        text = text.replace("</s>", "")
    return text, eos_seen

def emit_streamed_text_signal(self, **kwargs):
    """Emit an LLM_TEXT_STREAMED_SIGNAL, tagging the payload with the bot name.

    Forwards all keyword arguments as the signal payload, adding the
    ``name`` key set to ``self.botname``.
    """
    payload = {**kwargs, "name": self.botname}
    self.emit(SignalCode.LLM_TEXT_STREAMED_SIGNAL, payload)

def add_message_to_history(self, message):
    """Record a completed assistant reply in the conversation history."""
    entry = {
        "role": "assistant",
        "content": message,
    }
    self.history.append(entry)

def send_final_message(self):
    """Close out a streamed response by emitting an empty terminal chunk."""
    terminal = dict(
        message="",
        is_first_message=False,
        is_end_of_message=True,
    )
    self.emit_streamed_text_signal(**terminal)

def llm_stream(self):
prompt = self.prompt
Expand Down
13 changes: 13 additions & 0 deletions src/airunner/aihandler/tokenizer_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,25 @@ def post_load(self):
def load_tokenizer(self, local_files_only=None):
self.logger.info(f"Loading tokenizer from {self.current_model_path}")
local_files_only = self.local_files_only if local_files_only is None else local_files_only
chat_template = (
"{{ bos_token }}"
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"{{ '[INST] <<SYS>>' + message['content'] + ' <</SYS>>[/INST]' }}"
"{% elif message['role'] == 'user' %}"
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ message['content'] + eos_token + ' ' }}"
"{% endif %}"
"{% endfor %}"
)
try:
self.tokenizer = self.tokenizer_class_.from_pretrained(
self.tokenizer_path,
local_files_only=local_files_only,
token=self.request_data.get("hf_api_key_read_key"),
device_map=self.device,
chat_template=chat_template,
)
self.logger.info("Tokenizer loaded")
except OSError as e:
Expand Down
15 changes: 15 additions & 0 deletions src/airunner/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ class ServiceCode(Enum):
GET_CALLBACK_FOR_SLIDER = "get_callback_for_slider"


class SelfReflectionCategory(Enum):
    """Moderation categories the LLM can self-classify its output into.

    Member values are the human-readable labels used when assembling the
    self-reflection prompt and when matching category settings.
    """
    # Fixed: a trailing comma made this member's value the 1-tuple
    # ("illegal",) instead of the string, inconsistent with all siblings.
    ILLEGAL = "illegal"
    HATE_VIOLENCE_HARASSMENT = "hate violence harassment"
    MALWARE = "malware"
    PHYSICAL_HARM = "physical harm"
    ECONOMIC_HARM = "economic harm"
    FRAUD = "fraud"
    ADULT = "adult"
    POLITICAL = "political"
    PRIVACY = "privacy"
    UNQUALIFIED_LAW = "unqualified law"
    UNQUALIFIED_FINANCIAL = "unqualified financial"
    UNQUALIFIED_HEALTH = "unqualified health"


class SignalCode(Enum):
AI_MODELS_REFRESH_SIGNAL = "refresh_ai_models_signal"
AI_MODELS_SAVE_OR_UPDATE_SIGNAL = "ai_models_save_or_update_signal"
Expand Down
Loading

0 comments on commit e468f93

Please sign in to comment.