Skip to content

Commit

Permalink
[Hexaflake] 20240614 change config (#579)
Browse files Browse the repository at this point in the history
* test merge

* Hexaflake-20240607

* Hexaflake-20240614 change config
  • Loading branch information
wangxiaoyuvvv authored Jun 26, 2024
1 parent ee54245 commit 1f21be5
Show file tree
Hide file tree
Showing 10 changed files with 1,984 additions and 2 deletions.
20 changes: 20 additions & 0 deletions inference/benchmarks/llama2_7b_mmlu/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,23 @@
from .export import export_model
from .evaluator import evaluator
from .forward import model_forward, engine_forward
import os

# Select the vendor-specific implementation at import time.  The FlagPerf
# harness sets the ``vendor`` environment variable before importing this
# package; a missing variable is a hard configuration error (KeyError).
env = os.environ['vendor']

if env == "hexaflake":
    # Re-bind the benchmark entry points to the hexaflake implementations,
    # overriding the defaults imported above.
    from .hexaflake import (
        hx_dataloader,
        hx_engine_forward,
        hx_evaluator,
        hx_export_model,
        hx_model,
        hx_model_forward,
    )

    build_dataloader = hx_dataloader
    create_model = hx_model
    export_model = hx_export_model
    evaluator = hx_evaluator
    model_forward = hx_model_forward
    engine_forward = hx_engine_forward
206 changes: 206 additions & 0 deletions inference/benchmarks/llama2_7b_mmlu/pytorch/hexaflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import os
import pandas as pd
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader, Dataset
from loguru import logger
from tools import torch_sync
import time

# The 57 MMLU subject names; each maps to a "<task>_dev.csv" /
# "<task>_test.csv" pair under the dataset directory.
TASKS = [
    'abstract_algebra',
    'anatomy',
    'astronomy',
    'business_ethics',
    'clinical_knowledge',
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_medicine',
    'college_physics',
    'computer_security',
    'conceptual_physics',
    'econometrics',
    'electrical_engineering',
    'elementary_mathematics',
    'formal_logic',
    'global_facts',
    'high_school_biology',
    'high_school_chemistry',
    'high_school_computer_science',
    'high_school_european_history',
    'high_school_geography',
    'high_school_government_and_politics',
    'high_school_macroeconomics',
    'high_school_mathematics',
    'high_school_microeconomics',
    'high_school_physics',
    'high_school_psychology',
    'high_school_statistics',
    'high_school_us_history',
    'high_school_world_history',
    'human_aging',
    'human_sexuality',
    'international_law',
    'jurisprudence',
    'logical_fallacies',
    'machine_learning',
    'management',
    'marketing',
    'medical_genetics',
    'miscellaneous',
    'moral_disputes',
    'moral_scenarios',
    'nutrition',
    'philosophy',
    'prehistory',
    'professional_accounting',
    'professional_law',
    'professional_medicine',
    'professional_psychology',
    'public_relations',
    'security_studies',
    'sociology',
    'us_foreign_policy',
    'virology',
    'world_religions'
]
# Answer-option letters; MMLU is always a 4-way multiple-choice task.
choices = ["A", "B", "C", "D"]

def format_subject(subject):
    """Turn an underscore-separated task name into display words.

    Each underscore-delimited token is emitted with a leading space, so the
    result always starts with " " (e.g. "college_biology" -> " college biology"),
    which matches how gen_prompt() splices it into the header sentence.
    """
    # join() replaces the original quadratic string-concatenation loop while
    # producing byte-identical output (including for repeated underscores).
    return " " + " ".join(subject.split("_"))


def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot prompt header plus *k* worked examples.

    train_df holds the dev-split examples for *subject*; k == -1 means
    "use every row of train_df".
    """
    shots = train_df.shape[0] if k == -1 else k
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
    return header + "".join(format_example(train_df, idx) for idx in range(shots))


def format_example(df, idx, include_answer=True):
    """Render row *idx* of an MMLU dataframe as a question block.

    Column layout is: question, one column per answer option, gold answer
    letter last.  With include_answer=True the gold answer and a blank-line
    separator are appended (few-shot exemplar form); otherwise the block
    ends at "\nAnswer:" for the model to complete.
    """
    option_count = df.shape[1] - 2  # columns minus question and answer
    pieces = [df.iloc[idx, 0]]
    for opt in range(option_count):
        pieces.append("\n{}. {}".format(choices[opt], df.iloc[idx, opt + 1]))
    pieces.append("\nAnswer:")
    if include_answer:
        pieces.append(" {}\n\n".format(df.iloc[idx, option_count + 1]))
    return "".join(pieces)



def mmlu(config):
    """Load every MMLU task and build (records, labels).

    records is a flat list of few-shot prompt strings across all 57 tasks;
    labels is the parallel list of gold answer letters.  Prompts are trimmed
    to fit the model's context by dropping few-shot examples one at a time.
    """
    # Tokenizer is only used to measure prompt length against the context
    # budget, not to encode the inputs here.
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(config.data_dir, config.weight_dir))
    records = []
    length = 0  # NOTE(review): never read after this — appears to be dead.
    labels = []

    for task in TASKS:

        logger.debug("Loading 5-shot " + str(task))

        # dev split supplies the few-shot exemplars (first few_shots rows);
        # test split supplies the scored questions.  CSVs have no header row.
        dev_df = pd.read_csv(os.path.join(config.data_dir, config.mmlu_dir, "dev", task + "_dev.csv"), header=None)[:config.few_shots]
        test_df = pd.read_csv(os.path.join(config.data_dir, config.mmlu_dir, "test", task + "_test.csv"), header=None)

        for i in range(test_df.shape[0]):
            k = config.few_shots
            label = test_df.iloc[i, test_df.shape[1]-1]  # gold answer = last column
            prompt_end = format_example(test_df, i, include_answer=False)
            prompt = gen_prompt(dev_df, task, k)
            prompt = prompt + prompt_end
            # Shrink to the 2048-token budget (+1 presumably for a BOS/EOS
            # token — TODO confirm) by repeatedly dropping the oldest
            # few-shot example; index 0 is the header, so pop(1) removes
            # the first exemplar.
            # NOTE(review): if the bare question alone exceeds the budget,
            # pop(1) will eventually raise IndexError — verify inputs.
            while len(tokenizer.tokenize(prompt)) + 1> 2048:
                prompt_split = prompt.split("\n\n")
                prompt_split.pop(1)
                prompt = "\n\n".join(prompt_split)
            records.append(prompt)
            labels.append(label)
    return records, labels


def hx_dataloader(config):
    """Return the MMLU (records, labels) pair used as the "dataloader".

    Batching is not supported by this path, so config.batch_size must be 1.
    """
    # Validate the config BEFORE building the dataset: mmlu() reads every
    # task CSV and tokenizes each prompt, so failing fast here avoids
    # minutes of wasted work on a misconfigured run.
    assert config.batch_size == 1

    return mmlu(config)


def hx_model_forward(model, dataloader, evaluator, config):
    """Eager-mode forward pass — not implemented for hexaflake.

    When validation is disabled the harness expects the usual
    (perf, core_perf, acc) triple of Nones; otherwise this stub simply
    returns None (only the compiled-engine path is supported).
    """
    if not config.no_validation:
        return None
    return None, None, None


def cal_perf(config, tokens, duration, core_time, str_prefix):
    """Log and return (overall_tps, core_tps), each rounded to 3 decimals.

    tokens is the per-repeat token count; duration is total wall time and
    core_time the device-side time, so the two rates differ by host overhead.
    """
    total_tokens = config.repeat * tokens
    overall_tps = total_tokens / duration
    core_tps = total_tokens / core_time
    logger.info(f"{str_prefix}({config.framework}) Perf: {overall_tps} tps")
    logger.info(f"{str_prefix}({config.framework}) core Perf: {core_tps} tps")
    return round(overall_tps, 3), round(core_tps, 3)


def hx_engine_forward(model, dataloader, evaluator, config):
    """Run the compiled engine over all MMLU prompts and score accuracy.

    dataloader is the (records, labels) pair from hx_dataloader.  Each
    model(prompt) call is assumed to return a triple
    (prediction, overhead_time, token_count) — TODO confirm against the
    hexaflake engine wrapper.  Returns (perf_tps, core_perf_tps, accuracy).
    """
    start = time.time()
    core_time = 0.0  # wall time spent between the sync points around model()
    foo_time = 0.0   # accumulated y[1]: engine-reported overhead, deducted from core time

    token_cnt = 0
    correct = 0
    whole = 0

    for times in range(config.repeat):

        logger.debug("Repeat: " + str(times + 1))
        data = dataloader[0]   # prompt strings
        label = dataloader[1]  # gold answer letters
        for i in range(len(data)):
            with torch.no_grad():
                # Sync before starting the timer and again before stopping it
                # so core_time measures only device work for this call.
                torch_sync(config)
                core_time_start = time.time()

                y = model(data[i])

                torch_sync(config)
                core_time += time.time() - core_time_start

            token_cnt += y[2]
            foo_time += y[1]
            model_outputs = y[0]

            # evaluator returns 1 for a correct prediction, 0 otherwise.
            r = evaluator(model_outputs, label[i])

            correct += r
            whole += 1

    logger.info("MMLU" + str(config.few_shots) + "-shots Acc: " + str(correct / whole))

    duration = time.time() - start
    # Core perf excludes the engine-reported overhead (foo_time).
    model_forward_perf, model_forward_core_perf = cal_perf(
        config, token_cnt, duration, core_time - foo_time, "Inference")

    return model_forward_perf, model_forward_core_perf, round(correct / whole, 3)


def hx_model(config):
    """Stub: returns None.

    Model construction is presumably handled elsewhere in the hexaflake
    flow (TODO confirm); the harness only needs this hook to exist.
    """
    return None


def hx_export_model(model, config):
    """Stub: model export is a no-op for hexaflake; always returns None."""
    return None


def hx_evaluator(pred, y):
    """Score one prediction: 1 if pred equals the gold label y, else 0."""
    return int(pred == y)

18 changes: 18 additions & 0 deletions inference/docker_images/hexaflake/hexaflake_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import re

def analysis_log(logpath):
    """Parse an hxsmi monitoring log for peak GPU memory usage.

    Scans lines of the form "...<used>MiB / <total>MiB..." and returns
    (max_used_mib, total_mib, fp16_flops, int8_flops) where the last two
    are the hexaflake card's nominal 30/120 TFLOPS ratings.
    """
    max_usage = 0.0  # peak used memory (MiB) seen so far
    max_mem = 0.0    # total memory (MiB) from the last matching line
    # with-statement closes the file even on error (original leaked the handle)
    with open(logpath) as logfile:
        for line in logfile:
            # hxsmi table rows look like: hxsmi pwr DTemp MUsed Mem
            if "/" in line:
                match = re.search(r'(\d+)MiB / (\d+)MiB', line)
                # Guard: a line may contain "/" without the MiB pattern
                # (e.g. a path or header); the original crashed on those.
                if match is None:
                    continue
                max_mem = float(match.group(2))
                usage = float(match.group(1))
                max_usage = max(max_usage, usage)
    # Numeric literals replace the original eval("30e12") / eval("120e12"):
    # same values, no eval.
    return round(max_usage, 2), round(max_mem, 2), 30e12, 120e12
Loading

0 comments on commit 1f21be5

Please sign in to comment.