Train using OpenLLama
robinhad committed Jun 14, 2023
1 parent 8b72ee9 commit 8803029
Showing 6 changed files with 162 additions and 62 deletions.
98 changes: 98 additions & 0 deletions app.py
@@ -0,0 +1,98 @@
from peft import PeftModel
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
import gradio as gr
from torch.cuda import is_available

if is_available():
    options = dict(
        load_in_8bit=True,
        device_map="auto",
    )
else:
    options = {}

tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_7b")
model = LlamaForCausalLM.from_pretrained(
    "openlm-research/open_llama_7b",
    **options
)
model = PeftModel.from_pretrained(model, "robinhad/open_llama_7b_uk")


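# The Ukrainian prompt template below translates roughly to: "Below is an
# instruction that describes a task[, along with input data that provides
# further context]. Write a response that appropriately completes the request."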
def generate_prompt(instruction, input=None, output=""):
    if input:
        return f"""Унизу надається інструкція, яка описує завдання разом із вхідними даними, які надають додатковий контекст. Напиши відповідь, яка правильно доповнює запит.
### Інструкція:
{instruction}
### Вхідні дані:
{input}
### Відповідь:
{output}"""
    else:
        return f"""Унизу надається інструкція, яка описує завдання. Напиши відповідь, яка правильно доповнює запит.
### Інструкція:
{instruction}
### Відповідь:
{output}"""


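# Conservative decoding: low temperature plus 4-beam search keeps answers short and focused.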
generation_config = GenerationConfig(
    temperature=0.2,
    top_p=0.75,
    num_beams=4,
)

def evaluate(instruction, input=None):
    if input.strip() == "":
        input = None
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    if is_available():
        input_ids = input_ids.cuda()
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=64
    )
    for s in generation_output.sequences:
        output = tokenizer.decode(s, skip_special_tokens=True)
        print("============")
        print(output)
    return output.split("### Відповідь:")[1].strip()


gr.Interface(
    evaluate,
    [
        gr.inputs.Textbox(lines=5, label="Інструкція"),
        gr.inputs.Textbox(lines=5, label="Вхідні дані (необов'язково)"),
    ],
    gr.outputs.Textbox(label="Відповідь"),
    title="Kruk",
    description="Open Llama is a Ukrainian language model trained on the machine-translated Dolly dataset.",
    examples=[
        [
            "Яка найвища гора в Україні?",
            "",
        ],
        [
            "Розкажи історію про Івасика-Телесика.",
            "",
        ],
        [
            "Яка з цих гір не знаходиться у Європі?",
            "Говерла, Монблан, Гран-Парадізо, Еверест"
        ],
        [
            "Чому качки жовтоногі?",
            "",
        ],
        [
            "Чому у качки жовті ноги?",
            "",
        ],
    ]
).launch()
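For a quick smoke test of the demo above, evaluate can be called directly with app.py's model and tokenizer loaded (a minimal sketch, not part of the commit; skip the launch() call when testing this way):

# "What is the highest mountain in Ukraine?" — one of the bundled examples;
# the second argument is the optional input field, left empty here.
print(evaluate("Яка найвища гора в Україні?", ""))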
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,9 +1,10 @@
tqdm
openai
datasets
-git+https://github.com/huggingface/transformers.git@c612628045822f909020f7eb6784c79700813eda # for LLaMa support, should be available in 4.28.0
+transformers # for LLaMa support, should be available in 4.28.0
torch
peft
bitsandbytes
sentencepiece
-tenacity
+tenacity
+scipy # for bitsandbytes
+gradio
14 changes: 14 additions & 0 deletions scripts/alpaca/merge_m2m.py
@@ -0,0 +1,14 @@
import json
import glob

files = glob.glob("./data/*.json")
files = sorted(files, key=lambda x: int(x.replace(".json", "").split("_")[-1]))

items = []
for file in files:
    item = ""
    with open(file, "r") as f:
        item = json.dumps(json.load(f), ensure_ascii=False)

    with open("../../data/cc-by-sa-3.0/databricks-dolly-15k-translated.jsonl", "a") as output:
        output.write(item + "\n")
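As a sanity check after merging, each line of the output file should parse back into one Dolly-style record; a minimal sketch, assuming the upstream databricks-dolly-15k field names (instruction, context, response, category):

import json

# Every JSONL line should round-trip through json.loads with the Dolly fields intact.
with open("../../data/cc-by-sa-3.0/databricks-dolly-15k-translated.jsonl") as f:
    first = json.loads(next(f))
print(sorted(first))  # expected: ['category', 'context', 'instruction', 'response']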
13 changes: 0 additions & 13 deletions scripts/alpaca/merge_nllb.py

This file was deleted.

32 changes: 16 additions & 16 deletions scripts/alpaca/train_ualpaca.py
@@ -3,7 +3,7 @@

from datasets import load_dataset
import transformers
-from transformers import AutoTokenizer, AutoConfig, LlamaForCausalLM
+from transformers import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model

@@ -13,19 +13,20 @@
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3 # we don't need 3 tbh
LEARNING_RATE = 3e-4 # the Karpathy constant
-CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
+CUTOFF_LEN = 512 # 1024 accounts for about 99.5% of the data
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05

+model_name = "openlm-research/open_llama_7b"

model = LlamaForCausalLM.from_pretrained(
-    "decapoda-research/llama-7b-hf",
+    model_name,
    load_in_8bit=True,
    device_map="auto",
)
tokenizer = LlamaTokenizer.from_pretrained(
-    "decapoda-research/llama-7b-hf", add_eos_token=True
+    model_name, add_eos_token=True
)

model = prepare_model_for_int8_training(model)
@@ -40,7 +41,7 @@
)
model = get_peft_model(model, config)
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
-data = load_dataset("json", data_files="../../data/cc-by-nc/alpaca_data_translated.json")
+data = load_dataset("json", data_files="../../data/cc-by-sa-3.0/databricks-dolly-15k-translated.jsonl")

# def generate_prompt(data_point):
# if data_point["input"]:
@@ -60,20 +61,20 @@

# TODO: take a look at translation
def generate_prompt(data_point):
-    if data_point["input"]:
+    if data_point["context"]:
        return f"""Унизу надається інструкція, яка описує завдання разом із вхідними даними, які надають додатковий контекст. Напиши відповідь, яка правильно доповнює запит.
### Інструкція:
{data_point["instruction"]}
### Вхідні дані:
-{data_point["input"]}
+{data_point["context"]}
### Відповідь:
-{data_point["output"]}"""
+{data_point["response"]}"""
    else:
        return f"""Унизу надається інструкція, яка описує завдання. Напиши відповідь, яка правильно доповнює запит.
### Інструкція:
{data_point["instruction"]}
### Відповідь:
-{data_point["output"]}"""
+{data_point["response"]}"""



@@ -83,17 +84,16 @@ def tokenize(prompt):
    result = tokenizer(
        prompt,
        truncation=True,
-        max_length=CUTOFF_LEN + 1,
-        padding="max_length",
+        max_length=CUTOFF_LEN,
+        #padding=True#"max_length",
    )
-    return {
-        "input_ids": result["input_ids"][:-1],
-        "attention_mask": result["attention_mask"][:-1],
-    }
+    return result


data = data.shuffle().map(lambda x: tokenize(generate_prompt(x)))

+original_size = len(data["train"])
+print(f"Source data size: {original_size}")
#hub_token = os.environ["HUB_TOKEN"]
#print(f"Hub token: {hub_token}")

@@ -114,7 +114,7 @@
        #hub_token=hub_token,
        save_strategy="epoch",
    ),
-    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
+    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=1),
)
model.config.use_cache = False
trainer.train(resume_from_checkpoint=False)
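The diff stops at training: how the LoRA adapter reaches the robinhad/open_llama_7b_uk repository that app.py loads is not shown here. A plausible follow-up step, sketched purely as an assumption:

# Assumption: persist just the PEFT adapter (adapter_config.json plus adapter
# weights), not the full 7B base model; the directory name is illustrative.
model.save_pretrained("open_llama_7b_uk-lora")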
@@ -5,56 +5,64 @@
import time
import glob
import json
+from tqdm import tqdm

# Note: run this script inside the scripts/alpaca folder

+#model_name = "facebook/m2m100-12B-avg-5-ckpt"
+model_name = "facebook/m2m100_1.2B"
tokenizer = AutoTokenizer.from_pretrained(
-    "facebook/nllb-200-3.3B", src_lang="eng_Latn"
+    model_name, src_lang="en"
)

device = "cuda"
-model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-3.3B").to(device)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

-instructions = json.load(open('../../data/cc-by-nc/alpaca_data.json'))
+instructions = []
+with open('../../data/cc-by-sa-3.0/databricks-dolly-15k.jsonl') as f:
+    for line in f:
+        instructions.append(json.loads(line))


def translate(sentence):
    with no_grad():
-        inputs = tokenizer(sentence, return_tensors="pt")
+        inputs = tokenizer(sentence, return_tensors="pt", padding=True)
        translated_tokens = model.generate(
-            **inputs.to(device), forced_bos_token_id=tokenizer.lang_code_to_id["ukr_Cyrl"], max_length=512
+            **inputs.to(device), forced_bos_token_id=tokenizer.lang_code_to_id["uk"], max_length=1024
        )
-    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
+    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)#[0]


def translate_item(item):
    new_item = {}
    # for every key in the item
    for key, value in item.items():
        if len(value.strip()) == 0:
            new_item[key] = value
+        elif key == "category":
+            new_item[key] = value
        else:
            # separate the text into paragraphs
            parts = []
            translated = []
            for part in value.split("\n"):
                parts.append(part)
                if len(part.strip()) == 0:
                    translated.append(part)
                else:
                    # separate the paragraphs into sentences
                    sentences = nltk.sent_tokenize(part)
-                    sentence_parts = []
-                    for sentence in sentences:
-                        sentence_parts.append(translate(sentence))
-                    translated.append(" ".join(sentence_parts))
+                    sentence_parts = translate(sentences)
+                    translated.append(" ".join(sentence_parts))
            new_item[key] = "\n".join(translated)
    return new_item

+def translate_and_save(instruct):
+    idx, item = instruct
+    translated = translate_item(item)
+    with open(f"./data/dolly_data_translated_{idx}.json", "w") as f:
+        json.dump(translated, f)

if __name__ == '__main__':
    start = time.perf_counter()
    translated = []
@@ -66,19 +74,11 @@ def translate_item(item):
    else:
        last_json_id = max(jsons)

-    mean_times = []
-    for idx, instruction in enumerate(instructions[last_json_id:], start=last_json_id):
-        item_start = time.perf_counter()
-        with open(f"./data/alpaca_data_translated_{idx}.json", "w") as f:
-            json.dump(translate_item(instruction), f)
-        item_end = time.perf_counter() - item_start
-        # calculate time left based on average of 30
-        if len(mean_times) < 30:
-            mean_times.append(item_end)
-        else:
-            mean_times.pop(0)
-            mean_times.append(item_end)
-        print(f"Item {idx + 1}/{total_instructions} finished in {item_end:.2f} seconds. {(sum(mean_times)/len(mean_times)*(total_instructions-last_json_id)/60/60):.2f} hours left.", end='\r')
+    remaining = total_instructions - last_json_id
+
+    for i in tqdm(enumerate(instructions[last_json_id:], start=last_json_id), total=remaining):
+        translate_and_save(i)

    end = time.perf_counter() - start
    print(f"Finished in {end} seconds")
-    print(f"Item processing time is {end/idx} seconds")
+    print(f"Item processing time is {end/remaining} seconds")
