
Commit

Merge branch 'main' into readme-fix
subramen authored Jul 29, 2024
2 parents f87991c + 8c1418e commit f7bdf4d
Showing 18 changed files with 1,599 additions and 8 deletions.
13 changes: 13 additions & 0 deletions .github/scripts/spellcheck_conf/wordlist.txt
@@ -1400,6 +1400,19 @@ sqlite
customerservice
fn
ExecuTorch
LLMScore
RecursiveCharacterTextSplitter
TPD
TPM
Tianjun
Zhang
distractor
distractors
frac
numRefusal
totalQA
DirectoryLoader
SitemapLoader
nf
quant
DLAI
97 changes: 97 additions & 0 deletions recipes/quickstart/finetuning/datasets/raft_dataset.py
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.


import copy
from datasets import load_dataset
import itertools

# Check whether the system prompt or user prompt header token sequence appears in the current token list.
def check_header(targets, seq):
    for i in range(len(seq) - 3):
        if seq[i:i+3] in targets:
            return True
    return False

# Replace every occurrence of the target header token sequence with -100 so it is ignored by the loss.
def replace_target(target, seq):
    for i in range(len(seq) - 3):
        if seq[i:i+3] == target:
            seq[i], seq[i+1], seq[i+2] = -100, -100, -100
    return seq

def tokenize_dialog(dialog, tokenizer):
    # If the vocab size is at least 128000, the tokenizer is from the Llama 3 family, so use its chat template to generate the tokens.
    if tokenizer.vocab_size >= 128000:
        dialog_tokens = tokenizer.apply_chat_template(dialog)
        eot_indices = [i for i, n in enumerate(dialog_tokens) if n == 128009]
        labels = copy.copy(dialog_tokens)
        last_idx = 0
        # The system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized to [128006, 9125, 128007];
        # the user prompt header "<|start_header_id|>user<|end_header_id|>" is tokenized to [128006, 882, 128007].
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for n, idx in enumerate(eot_indices):
            current_seq = labels[last_idx:idx+1]
            if check_header(prompt_header_seqs, current_seq):
                # Found a prompt header, indicating that this sequence should be masked.
                labels[last_idx:idx+1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx
        # Lastly, mask the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which is tokenized to [128006, 78191, 128007].
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        dialog_tokens = [dialog_tokens]
        labels_tokens = [labels]
    else:
        raise Exception("This raft_dataset only supports Llama 3 family models, please make sure the tokenizer is from a Llama 3 family model.")

    combined_tokens = {
        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
        "labels": list(itertools.chain(*(t for t in labels_tokens))),
    }

    return dict(combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"]))
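
# --- Illustrative sketch (not part of the original file) ---
# Assuming a Llama 3 Instruct tokenizer from transformers (the checkpoint name below is
# only an example), the masking produced by tokenize_dialog can be inspected like this:
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
#   dialog = [
#       {"role": "system", "content": "You are a helpful chatbot."},
#       {"role": "user", "content": "What does RAFT stand for?"},
#       {"role": "assistant", "content": "Retrieval Augmented Fine Tuning."},
#   ]
#   sample = tokenize_dialog(dialog, tok)
#   # All system/user tokens and the assistant header should be -100 in the labels,
#   # so the loss is only computed on the assistant's answer tokens.
#   print(sum(l == -100 for l in sample["labels"]), len(sample["labels"]))
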
def raft_tokenize(q_a_pair, tokenizer):
    end_tag = "</DOCUMENT>"
    # Find the last end_tag in the instruction; the rest is the question.
    try:
        index = q_a_pair["instruction"].rindex(end_tag) + len(end_tag)
    except ValueError:
        print(q_a_pair["instruction"])
        raise Exception("The instruction does not contain the end tag </DOCUMENT>")
    # Everything after the last end_tag is the question.
    question = q_a_pair["instruction"][index:].strip()
    # Everything before (and including) the last end_tag is the context.
    documents = q_a_pair["instruction"][:index].strip()
    # The output field is the label.
    answer = q_a_pair["output"]
    system_prompt = "You are a helpful chatbot who can provide an answer to every question from the user given a relevant context."
    user_prompt = """
    Question: {question}\nContext: {context}\n
    Answer this question using the information given by multiple documents in the context above. Here are the things to pay attention to:
    - The context contains many documents; each document starts with <DOCUMENT> and ends with </DOCUMENT>.
    - First provide step-by-step reasoning on how to answer the question.
    - In the reasoning, if you need to copy-paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy-pasted from the context.
    - End your response with a final answer in the form <ANSWER>: $answer; the answer should be less than 60 words.
    You MUST begin your final answer with the tag "<ANSWER>:".
    """.format(question=question, context=documents)

    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": answer},
    ]
    return tokenize_dialog(chat, tokenizer)
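
# --- Illustrative sketch (not part of the original file) ---
# Each JSON record consumed by raft_tokenize is expected to look roughly like the example
# below (field contents are made up): one or more <DOCUMENT>...</DOCUMENT> blocks followed
# by the question in "instruction", and the chain-of-thought answer in "output".
#
#   {
#     "instruction": "<DOCUMENT>RAFT combines RAG with fine-tuning ...</DOCUMENT>\n<DOCUMENT>A distractor document ...</DOCUMENT>\nWhat is RAFT?",
#     "output": "##begin_quote##RAFT combines RAG with fine-tuning##end_quote## ... <ANSWER>: Retrieval Augmented Fine Tuning."
#   }
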


def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.9):
    # load_dataset will return a DatasetDict that contains all the data in the train split.
    dataset_dict = load_dataset('json', data_files=dataset_config.data_path)
    dataset = dataset_dict['train']
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)

    dataset = dataset[split].map(lambda sample: {
        "instruction": sample["instruction"],
        "output": sample["cot_answer"],
        },
        batched=True,
    )
    dataset = dataset.map(lambda x: raft_tokenize(x, tokenizer))
    return dataset
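
A minimal sketch of how this custom dataset module could be exercised on its own, assuming the functions above are importable from raft_dataset.py, a Llama 3 Instruct tokenizer is available, and a RAFT-style JSON/JSONL file with "instruction" and "cot_answer" fields exists; the checkpoint name, file path, and the SimpleNamespace config are illustrative only:

from types import SimpleNamespace
from transformers import AutoTokenizer
from raft_dataset import get_custom_dataset

# Illustrative values; substitute your own checkpoint and RAFT output file.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
dataset_config = SimpleNamespace(data_path="raft_output.jsonl")

# 90/10 train/test split, tokenized with the RAFT prompt template defined above.
train_ds = get_custom_dataset(dataset_config, tokenizer, split="train", split_ratio=0.9)
print(train_ds[0].keys())  # original columns plus input_ids, labels, attention_mask
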
@@ -789,7 +789,7 @@
"metadata": {},
"source": [
"\n",
"One good way to quickly obtain labeled training data for a use case is to use the original, non-fine tuned model itself to highlight risky examples to label, while drawing random negatives from below a score threshold. This helps address the class imbalance (attacks and risky prompts can be a very small percentage of all prompts) and includes false positive examples (which tend to be very valuable to train on) in the dataset. The use of synthetic data for specific "
"One good way to quickly obtain labeled training data for a use case is to use the original, non-fine tuned model itself to highlight risky examples to label, while drawing random negatives from below a score threshold. This helps address the class imbalance (attacks and risky prompts can be a very small percentage of all prompts) and includes false positive examples (which tend to be very valuable to train on) in the dataset. Generating synthetic fine-tuning data for specific use cases can also be an effective strategy."
]
}
],
243 changes: 243 additions & 0 deletions recipes/use_cases/end2end-recipes/RAFT-Chatbot/README.md

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions recipes/use_cases/end2end-recipes/RAFT-Chatbot/config.py
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import yaml

def load_config(config_path: str = "./config.yaml"):
    # Read the YAML configuration file
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    return config
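
A minimal usage sketch for load_config, assuming a config.yaml sits alongside this script; the key name in the example is illustrative, not the recipe's actual schema:

config = load_config("./config.yaml")
# yaml.safe_load returns a plain dict, so values are read with normal dict lookups.
print(config.get("model_name"))  # "model_name" is only an example key
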
