Skip to content

Commit

Permalink
Merge pull request #451 from Capsize-Games/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
w4ffl35 authored Jan 28, 2024
2 parents e468f93 + 5ff85a3 commit c1989b2
Show file tree
Hide file tree
Showing 6 changed files with 444 additions and 358 deletions.
221 changes: 36 additions & 185 deletions src/airunner/aihandler/casual_lm_transfformer_base_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.llms.types import ChatResponse
from transformers import AutoModelForCausalLM, TextIteratorStreamer

from llama_index.llms import HuggingFaceLLM, ChatMessage
Expand All @@ -9,7 +8,7 @@
from llama_index import VectorStoreIndex, SimpleDirectoryReader

from airunner.aihandler.tokenizer_handler import TokenizerHandler
from airunner.enums import SignalCode, SelfReflectionCategory
from airunner.enums import SignalCode


class CasualLMTransformerBaseHandler(TokenizerHandler):
Expand All @@ -24,7 +23,6 @@ def __init__(self, *args, **kwargs):
self.documents = None
self.index = None
self.query_engine: BaseQueryEngine = None
self.embeddings_model_path = "BAAI/bge-small-en-v1.5"

def post_load(self):
super().post_load()
Expand Down Expand Up @@ -64,22 +62,15 @@ def load_streamer(self):

def load_llm(self):
self.logger.info("Loading RAG")
variables = {
"username": self.username,
"botname": self.botname,
"bos_token": self.tokenizer.bos_token,
"bot_mood": self.bot_mood,
"bot_personality": self.bot_personality,
}
self.llm = HuggingFaceLLM(
model=self.model,
tokenizer=self.tokenizer,
)

def load_embed_model(self):
self.logger.info("Loading embeddings")
self.logger.info("Loading embedding model")
self.embed_model = HuggingFaceEmbedding(
model_name=self.embeddings_model_path,
model_name=self.settings["llm_generator_settings"]["embeddings_model_path"],
)

def load_service_context(self):
Expand Down Expand Up @@ -114,9 +105,8 @@ def load_query_engine(self):
)

def do_generate(self):
#self.llm_stream()
#self.rag_stream()
self.chat_stream()
#self.rag_stream()

def save_query_engine_to_disk(self):
self.index.storage_context.persist(
Expand All @@ -135,80 +125,47 @@ def load_query_engine_from_disk(self):
streaming=True
)

def prepare_messages(self):
optional_self_reflection = {}
optional_self_reflection[SelfReflectionCategory.ILLEGAL] = "illegal: Illegal activity."
optional_self_reflection[SelfReflectionCategory.HATE_VIOLENCE_HARASSMENT] = "hate violence harassment: Generation of hateful, harassing, or violent content: content that expresses, incites, or promotes hate based on identity, content that intends to harass, threaten, or bully an individual, content that promotes or glorifies violence or celebrates the suffering or humiliation of others."
optional_self_reflection[SelfReflectionCategory.MALWARE] = "malware: Generation of malware: content that attempts to generate code that is designed to disrupt, damage, or gain unauthorized access to a computer system."
optional_self_reflection[SelfReflectionCategory.PHYSICAL_HARM] = "physical harm: activity that has high risk of physical harm, including: weapons development, military and warfare, management or operation of critical infrastructure in energy, transportation, and water, content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders."
optional_self_reflection[SelfReflectionCategory.ECONOMIC_HARM] = "economic harm: activity that has high risk of economic harm, including: multi-level marketing, gambling, payday lending, automated determinations of eligibility for credit, employment, educational institutions, or public assistance services."
optional_self_reflection[SelfReflectionCategory.FRAUD] = "fraud: Fraudulent or deceptive activity, including: scams, coordinated inauthentic behavior, plagiarism, academic dishonesty, astroturfing, such as fake grassroots support or fake review generation, disinformation, spam, pseudo-pharmaceuticals."
optional_self_reflection[SelfReflectionCategory.ADULT] = "adult: Adult content, adult industries, and dating apps, including: content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness), erotic chat, pornography."
optional_self_reflection[SelfReflectionCategory.POLITICAL] = "political: Political campaigning or lobbying, by: generating high volumes of campaign materials, generating campaign materials personalized to or targeted at specific demographics, building conversational or interactive systems such as chatbots that provide information about campaigns or engage in political advocacy or lobbying, building products for political campaigning or lobbying purposes."
optional_self_reflection[SelfReflectionCategory.PRIVACY] = "privacy: Activity that violates people's privacy, including: tracking or monitoring an individual without their consent, facial recognition of private individuals, classifying individuals based on protected characteristics, using biometrics for identification or assessment, unlawful collection or disclosure of personal identifiable information or educational, financial, or other protected records."
optional_self_reflection[SelfReflectionCategory.UNQUALIFIED_LAW] = "unqualified law: Engaging in the unauthorized practice of law, or offering tailored legal advice without a qualified person reviewing the information."
optional_self_reflection[SelfReflectionCategory.UNQUALIFIED_FINANCIAL] = "unqualified financial: Offering tailored financial advice without a qualified person reviewing the information."
optional_self_reflection[SelfReflectionCategory.UNQUALIFIED_HEALTH] = "unqualified health: Telling someone that they have or do not have a certain health condition, or providing instructions on how to cure or treat a health condition."

def build_system_prompt(self):
# The guardrails prompt is optional and can be overriden.
guardrails_prompt = ""
if self.settings["llm_generator_settings"]["guardrails_active"]:
guardrails_prompt = (
"Always assist with care, respect, and truth. "
"Respond with utmost utility yet securely. "
"Avoid harmful, unethical, prejudiced, or negative content. "
"Ensure replies promote fairness and positivity."
)
if self.settings["llm_generator_settings"]["use_guardrails"]:
guardrails_prompt = self.settings["llm_generator_settings"]["guardrails_prompt"]

# The self-reflection prompt is not optional, but the categories are.
def dec(m): return ''.join(chr(ord(c) ^ 0x55) for c in m)
self_reflection_prompt = [
"You're given a list of moderation categories as below:",
dec("xu6=<91u47 &0ou6=<91u&0- 49u47 &0u84!0'<49u:'u4;,u6:;!0;!u!=4!u0-%9:<!&u:'u=4'8&u6=<91'0;{"),
]
system_prompt = []

# All other categories in self-reflection are optional and can be controlled via the interface
self_reflection_categories = self.settings["llm_generator_settings"]["self_reflection_categories"]
for optional_self_reflection in self_reflection_categories:
if optional_self_reflection["active"]:
self_reflection_prompt.append(
f"- {optional_self_reflection['category']}"
)
if self.settings["llm_generator_settings"]["use_guardrails"]:
system_prompt.append(guardrails_prompt)

self_reflection_prompt.extend([
"Please classify the following text into one of these categories, and answer with that single word only.",
"",
"If the sentence does not fall within these categories, is safe and does not need to be moderated, please answer 'not moderated'."
])
self_reflection_prompt = "\n".join(self_reflection_prompt)

# TODO: override this
system_instructions = [
"You are a dungeon master for a roleplaying game. ",
"You will respond to the player's actions and questions. ",
]

system_instructions.append(self_reflection_prompt)
if self.settings["llm_generator_settings"]["use_system_instructions"]:
system_prompt.append(
self.settings["llm_generator_settings"]["system_instructions"]
)

if self.settings["llm_generator_settings"]["assign_names"]:
system_instructions.append(
system_prompt.append(
"Your name is " + self.botname + ". "
)
system_prompt.append(
"The user's name is " + self.username + "."
)

if self.settings["llm_generator_settings"]["use_mood"]:
system_instructions.append(f"Your mood: {self.bot_mood}.")
system_prompt.append(f"Your mood: {self.bot_mood}.")

if self.settings["llm_generator_settings"]["use_personality"]:
system_instructions.append(f"Your personality: {self.bot_personality}.")
system_prompt.append(f"Your personality: {self.bot_personality}.")

system_prompt = "\n".join(system_instructions)
system_prompt = "\n".join(system_prompt)
return system_prompt

messages = [
ChatMessage(
def prepare_messages(self, system_prompt=None):
if system_prompt is None:
system_prompt = ChatMessage(
role="system",
content=system_prompt
content=self.build_system_prompt()
)
messages = [
system_prompt
]
for message in self.history:
messages.append(
Expand All @@ -228,7 +185,13 @@ def dec(m): return ''.join(chr(ord(c) ^ 0x55) for c in m)

def chat_stream(self):
self.logger.info("Generating chat response")
self.add_message_to_history(
self.prompt,
role="user"
)

messages = self.prepare_messages()

streaming_response = self.llm.stream_chat(messages)
is_first_message = True
is_end_of_message = False
Expand All @@ -247,7 +210,6 @@ def chat_stream(self):
if not is_end_of_message:
self.send_final_message()

print("assistant_message: " + assistant_message)
self.add_message_to_history(
assistant_message
)
Expand Down Expand Up @@ -298,9 +260,9 @@ def emit_streamed_text_signal(self, **kwargs):
kwargs
)

def add_message_to_history(self, message):
def add_message_to_history(self, message, role="assistant"):
self.history.append({
"role": "assistant",
"role": role,
"content": message
})

Expand All @@ -309,115 +271,4 @@ def send_final_message(self):
message="",
is_first_message=False,
is_end_of_message=True
)

def llm_stream(self):
prompt = self.prompt
history = []
for message in self.history:
if message["role"] == "user":
# history.append("<s>[INST]" + self.username + ': "'+ message["content"] +'"[/INST]')
history.append(self.username + ': "' + message["content"] + '"')
else:
# history.append(self.botname + ': "'+ message["content"] +'"</s>')
history.append(self.botname + ': "' + message["content"])
history = "\n".join(history)
if history == "":
history = None

# Create a dictionary with the variables
variables = {
"username": self.username,
"botname": self.botname,
"history": history or "",
"input": prompt,
"bos_token": self.tokenizer.bos_token,
"bot_mood": self.bot_mood,
"bot_personality": self.bot_personality,
}

self.history.append({
"role": "user",
"content": prompt
})

# Render the template with the variables
# rendered_template = chat_template.render(variables)

# iterate over variables and replace again, this allows us to use variables
# in custom template variables (for example variables inside of botmood and bot_personality)
rendered_template = self.template
for n in range(2):
for key, value in variables.items():
rendered_template = rendered_template.replace("{{ " + key + " }}", value)

# Encode the rendered template
encoded = self.tokenizer(rendered_template, return_tensors="pt")
model_inputs = encoded.to("cuda" if torch.cuda.is_available() else "cpu")

# Generate the response
self.logger.info("Generating...")
import threading
self.thread = threading.Thread(target=self.model.generate, kwargs=dict(
model_inputs,
min_length=self.min_length,
max_length=self.max_length,
num_beams=self.num_beams,
do_sample=True,
top_k=self.top_k,
eta_cutoff=self.eta_cutoff,
top_p=self.top_p,
num_return_sequences=self.sequences,
eos_token_id=self.tokenizer.eos_token_id,
early_stopping=True,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
streamer=self.streamer
))
self.thread.start()
# strip all new lines from rendered_template:
rendered_template = rendered_template.replace("\n", " ")
rendered_template = "<s>" + rendered_template
skip = True
streamed_template = ""
replaced = False
is_end_of_message = False
is_first_message = True
for new_text in self.streamer:
# strip all newlines from new_text
parsed_new_text = new_text.replace("\n", " ")
streamed_template += parsed_new_text
streamed_template = streamed_template.replace("<s> [INST]", "<s>[INST]")
# iterate over every character in rendered_template and
# check if we have the same character in streamed_template
if not replaced:
for i, char in enumerate(rendered_template):
try:
if char == streamed_template[i]:
skip = False
else:
skip = True
break
except IndexError:
skip = True
break
if skip:
continue
elif not replaced:
replaced = True
streamed_template = streamed_template.replace(rendered_template, "")
else:
if "</s>" in new_text:
streamed_template = streamed_template.replace("</s>", "")
new_text = new_text.replace("</s>", "")
is_end_of_message = True
self.emit(
SignalCode.LLM_TEXT_STREAMED_SIGNAL,
dict(
message=new_text,
is_first_message=is_first_message,
is_end_of_message=is_end_of_message,
name=self.botname,
)
)
is_first_message = False
)
15 changes: 0 additions & 15 deletions src/airunner/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,6 @@ class ServiceCode(Enum):
GET_CALLBACK_FOR_SLIDER = "get_callback_for_slider"


class SelfReflectionCategory(Enum):
ILLEGAL = "illegal",
HATE_VIOLENCE_HARASSMENT = "hate violence harassment"
MALWARE = "malware"
PHYSICAL_HARM = "physical harm"
ECONOMIC_HARM = "economic harm"
FRAUD = "fraud"
ADULT = "adult"
POLITICAL = "political"
PRIVACY = "privacy"
UNQUALIFIED_LAW = "unqualified law"
UNQUALIFIED_FINANCIAL = "unqualified financial"
UNQUALIFIED_HEALTH = "unqualified health"


class SignalCode(Enum):
AI_MODELS_REFRESH_SIGNAL = "refresh_ai_models_signal"
AI_MODELS_SAVE_OR_UPDATE_SIGNAL = "ai_models_save_or_update_signal"
Expand Down
Loading

0 comments on commit c1989b2

Please sign in to comment.