Add context_logits for eval accuracy calculation in case of multi token prediction tasks #11753

Open
wants to merge 10 commits into base: main
13 changes: 7 additions & 6 deletions nemo/collections/llm/api.py
@@ -337,6 +337,7 @@ def deploy(
max_input_len: int = 256,
max_output_len: int = 256,
max_batch_size: int = 8,
output_context_logits: bool = True,
output_generation_logits: bool = True,
):
"""
@@ -364,8 +365,11 @@ def deploy(
Needs to be True to be able to run evaluation. Default: True.
openai_format_response (bool): Return the response from PyTriton server in OpenAI compatible format. Needs to
be True while running evaluation. Default: True.
output_context_logits (bool): If True, builds the trtllm engine with gather_context_logits set to True.
context_logits are used to compute the logProb of the output token in case of multi token prediction benchmarks. Default: True.
output_generation_logits (bool): If True builds trtllm engine with gather_generation_logits set to True.
generation_logits are used to compute the logProb of the output token. Default: True.
generation_logits are used to compute the logProb of the output token in case of single token prediction
benchmarks (like MMLU, lambada). Default: True.
"""
from nemo.collections.llm.deploy.base import get_trtllm_deployable, unset_environment_variables
from nemo.deploy import DeployPyTriton
@@ -383,6 +387,7 @@
max_output_len,
max_batch_size,
dtype,
output_context_logits,
output_generation_logits,
)

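For reference, a minimal sketch of calling the updated deploy() with both logits outputs enabled, so either benchmark style can be evaluated against the same Triton endpoint. Only the arguments visible in this hunk are taken from the PR; the import path, checkpoint argument, and Triton model name argument are assumptions.

from nemo.collections.llm import api

api.deploy(
    nemo_checkpoint="/checkpoints/llama3_8b.nemo",  # assumed argument name and placeholder path
    triton_model_name="llama3_8b",                  # assumed argument name
    max_input_len=256,
    max_output_len=256,
    max_batch_size=8,
    # gather_context_logits: needed for multi token prediction benchmarks
    output_context_logits=True,
    # gather_generation_logits: needed for single token prediction benchmarks (MMLU, lambada)
    output_generation_logits=True,
)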
@@ -425,7 +430,6 @@ def evaluate(
limit: Optional[Union[int, float]] = None,
bootstrap_iters: int = 100000,
# inference params
max_tokens_to_generate: Optional[int] = 256,
temperature: Optional[float] = 0.000000001,
top_p: Optional[float] = 0.0,
top_k: Optional[int] = 1,
@@ -454,7 +458,6 @@
bootstrap_iters (int): Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0
for no stderr calculations to be performed. Default: 100000.
# inference params
max_tokens_to_generate (int): max tokens to generate. Default: 256.
temperature: Optional[float]: float value between 0 and 1. temp of 0 indicates greedy decoding, where the token
with highest prob is chosen. Temperature can't be set to 0.0 currently, due to a bug with TRTLLM
(# TODO to be investigated). Hence using a very small value as the default. Default: 0.000000001.
@@ -480,9 +483,7 @@
# Wait for server to be ready before starting evaluation
evaluation.wait_for_server_ready(url=url, triton_http_port=triton_http_port, model_name=model_name)
# Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate
model = evaluation.NeMoFWLMEval(
model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos
)
model = evaluation.NeMoFWLMEval(model_name, url, tokenizer, temperature, top_p, top_k, add_bos)
results = evaluator.simple_evaluate(
model=model,
tasks=eval_task,
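As a reference for the constructor change above, a hedged sketch of building the lm-evaluation-harness model object without max_tokens_to_generate. The keyword names match the new NeMoFWLMEval signature in this diff; the import paths are inferred from the file layout, and the URL, model name, and tokenizer are placeholders.

from lm_eval import evaluator
from nemo.collections.llm import evaluation

tokenizer = None  # placeholder: in evaluate() this is the tokenizer loaded from the NeMo checkpoint

model = evaluation.NeMoFWLMEval(
    model_name="llama3_8b",         # Triton model name (placeholder)
    api_url="grpc://0.0.0.0:8001",  # PyTriton endpoint (placeholder)
    tokenizer=tokenizer,
    temperature=0.000000001,
    top_p=0.0,
    top_k=1,
    add_bos=False,
)
results = evaluator.simple_evaluate(model=model, tasks=["mmlu"], limit=None, bootstrap_iters=100000)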
2 changes: 2 additions & 0 deletions nemo/collections/llm/deploy/base.py
@@ -63,6 +63,7 @@ def get_trtllm_deployable(
max_output_len,
max_batch_size,
dtype,
output_context_logits,
output_generation_logits,
):
"""
@@ -109,6 +110,7 @@
max_output_len=max_output_len,
max_batch_size=max_batch_size,
dtype=dtype,
gather_context_logits=output_context_logits,
gather_generation_logits=output_generation_logits,
)
except Exception as error:
61 changes: 46 additions & 15 deletions nemo/collections/llm/evaluation/base.py
@@ -33,38 +33,51 @@
Created based on: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.4/docs/model_guide.md
"""

def __init__(self, model_name, api_url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos):
def __init__(self, model_name, api_url, tokenizer, temperature, top_p, top_k, add_bos):
self.model_name = model_name
self.api_url = api_url
self.tokenizer = tokenizer
self.max_tokens_to_generate = max_tokens_to_generate
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.add_bos = add_bos
super().__init__()

def _generate_tokens_logits(self, payload, return_text: bool = False, return_logits: bool = False):
def _generate_tokens_logits(
self, payload, single_prediction_token, return_text: bool = False, return_logits: bool = False
):
Comment on lines +46 to +48. Code scanning / CodeQL notice: Explicit returns mixed with implicit (fall through) returns. Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
"""
A private method that sends post request to the model on PyTriton server and returns either generated text or
logits.
"""
nq = NemoQueryLLM(url=self.api_url, model_name=payload['model'])

output_context_logits = False
output_generation_logits = False
if single_prediction_token:
# In case of single token prediction return the generation logits
output_generation_logits = True
else:
# In case of multiple token prediction return the context logits
output_context_logits = True
response = nq.query_llm(
prompts=payload['prompt'] if isinstance(payload['prompt'], list) else [payload['prompt']],
max_output_len=payload['max_tokens'],
top_k=payload['top_k'],
top_p=payload['top_p'],
temperature=payload['temperature'],
output_generation_logits=True,
output_context_logits=output_context_logits,
output_generation_logits=output_generation_logits,
openai_format_response=True,
)

if return_text:
return response["choices"][0]["text"] # shape[batch_size, 1]
if return_logits:
return response["choices"][0]["generation_logits"] # shape[batch_size, 1, num_tokens, vocab_size]
elif return_logits:
if output_context_logits:
return response["choices"][0]["context_logits"]
else:
return response["choices"][0]["generation_logits"]

def tokenizer_type(self, tokenizer):
"""
@@ -93,6 +106,16 @@
elif tokenizer_type == "AutoTokenizer":
special_tokens_kwargs['add_special_tokens'] = self.add_bos

single_prediction_token = False
# Assuming evaluating on only one benchmark/task at a time, hence all instances in requests are of the same
# task.
mmlu_regex_pattern = r"^mmlu_"
lambada_regex_pattern = r"^lambada_"
if re.match(mmlu_regex_pattern, requests[0].task_name) or re.match(
lambada_regex_pattern, requests[0].task_name
):
single_prediction_token = True

results = []
for request in tqdm(requests):
# get the input prompt from the request
@@ -105,31 +128,39 @@
if self.tokenizer_type(self.tokenizer) == "SentencePieceTokenizer":
continuation_enc = continuation_enc[1:]
num_cont_tokens = len(continuation_enc)
# Update self.max_tokens_to_generate with number of continuation tokens (or output tokens) in the request
self.max_tokens_to_generate = num_cont_tokens
# Hard code max_tokens_to_generate to 1 to always generate just 1 token
self.max_tokens_to_generate = 1
# Drop the last token from the continuation before appending it to the input prompt, by replacing it with an empty string
prompt = context + continuation.replace(self.tokenizer.tokenizer.decode(continuation_enc[-1]), "")
# Create payload to query the model deployed on PyTriton server
payload = {
"model": self.model_name,
"prompt": context,
"prompt": prompt,
"max_tokens": self.max_tokens_to_generate,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
}
# Get the logits from the model
generation_logits = self._generate_tokens_logits(payload, return_logits=True)
# Convert generation_logits to torch tensor to easily get logprobs wo manual implementation of log_softmax
multi_logits = F.log_softmax(torch.tensor(generation_logits[0]), dim=-1)
logits = self._generate_tokens_logits(payload, single_prediction_token, return_logits=True)
# In case of multiple token prediction, where full context logits are returned, get only the logits
# corresponding to the continuation tokens from the context logits tensor. context_logits contains logits
# for all tokens in the input prompt along with the logit for the next token predicted after the final token
# in the prompt. Shape of context_logits: [1, #tokens_in_prompt+1, vocab_size]
if not single_prediction_token:
logits = logits[:, -num_cont_tokens:, :]
# Convert logits to a torch tensor to easily get logprobs without a manual implementation of log_softmax
logProbs = F.log_softmax(torch.tensor(logits), dim=-1)
# Convert encoded continuation tokens to torch tensor
cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)
# Get the greedy token from the logits (i.e. the token with the highest prob)
greedy_tokens = multi_logits.argmax(dim=-1)
greedy_tokens = logProbs.argmax(dim=-1)
# Check if all greedy_tokens match the actual continuation tokens
is_greedy = (greedy_tokens == cont_toks).all()
# Get the logits corresponding to the actual continuation tokens
logits = torch.gather(multi_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
logProbs_actual = torch.gather(logProbs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
# result is tuple of logProb of generating the continuation token and is_greedy
result = (float(logits.sum()), bool(is_greedy))
result = (float(logProbs_actual.sum()), bool(is_greedy))

results.append(result)

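To make the scoring step above concrete, a toy, self-contained illustration (all shapes and token ids are made up): slice the continuation positions out of the context logits, take log_softmax, and gather the logprobs of the actual continuation tokens.

import torch
import torch.nn.functional as F

vocab_size = 10
num_cont_tokens = 2
# Pretend full-context logits for a 5-token prompt: [1, #tokens_in_prompt + 1, vocab_size]
context_logits = torch.randn(1, 6, vocab_size)
# Keep only the positions that predict the continuation tokens
logits = context_logits[:, -num_cont_tokens:, :]                                  # [1, 2, vocab_size]
logProbs = F.log_softmax(logits, dim=-1)                                          # [1, 2, vocab_size]
continuation_enc = [3, 7]                                                         # made-up continuation token ids
cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)         # [1, 2]
greedy_tokens = logProbs.argmax(dim=-1)                                           # [1, 2]
is_greedy = (greedy_tokens == cont_toks).all()
logProbs_actual = torch.gather(logProbs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, 2]
result = (float(logProbs_actual.sum()), bool(is_greedy))
print(result)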
6 changes: 6 additions & 0 deletions nemo/deploy/nlp/query_llm.py
@@ -198,6 +198,7 @@ def query_llm(
end_strings=None,
init_timeout=60.0,
openai_format_response: bool = False,
output_context_logits: bool = False,
output_generation_logits: bool = False,
):
"""
@@ -275,6 +276,9 @@
if end_strings is not None:
inputs["end_strings"] = str_list2numpy(end_strings)

if output_context_logits is not None:
inputs["output_context_logits"] = np.full(prompts.shape, output_context_logits, dtype=np.bool_)

if output_generation_logits is not None:
inputs["output_generation_logits"] = np.full(prompts.shape, output_generation_logits, dtype=np.bool_)

@@ -301,6 +305,8 @@
}
if output_generation_logits:
openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"]
if output_context_logits:
openai_response["choices"][0]["context_logits"] = result_dict["context_logits"]
return openai_response
else:
return sentences
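A hedged client-side sketch of requesting context logits through the new query_llm flag. The parameter names are the ones visible in this diff; the import path is inferred from the file location, and the endpoint URL and model name are placeholders.

from nemo.deploy.nlp import NemoQueryLLM

nq = NemoQueryLLM(url="localhost:8000", model_name="llama3_8b")  # placeholder endpoint and name
response = nq.query_llm(
    prompts=["The capital of France is"],
    max_output_len=1,
    top_k=1,
    top_p=0.0,
    temperature=0.000000001,
    output_context_logits=True,    # new flag added in this PR
    output_generation_logits=False,
    openai_format_response=True,   # required so the logits come back under "choices"
)
# Expected shape per the comments above: [1, #tokens_in_prompt + 1, vocab_size]
context_logits = response["choices"][0]["context_logits"]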
18 changes: 17 additions & 1 deletion nemo/export/tensorrt_llm.py
@@ -959,6 +959,7 @@ def forward(
prompt_embeddings_checkpoint_path: str = None,
streaming: bool = False,
output_log_probs: bool = False,
output_context_logits: bool = False,
output_generation_logits: bool = False,
**sampling_kwargs,
):
@@ -1049,6 +1050,7 @@ def forward(
no_repeat_ngram_size=no_repeat_ngram_size,
output_log_probs=output_log_probs,
multiprocessed_env=multiprocessed_env,
output_context_logits=output_context_logits,
output_generation_logits=output_generation_logits,
**sampling_kwargs,
)
@@ -1133,6 +1135,7 @@ def get_triton_input(self):
Tensor(name="no_repeat_ngram_size", shape=(-1,), dtype=np.single, optional=True),
Tensor(name="task_id", shape=(-1,), dtype=bytes, optional=True),
Tensor(name="lora_uids", shape=(-1,), dtype=bytes, optional=True),
Tensor(name="output_context_logits", shape=(-1,), dtype=np.bool_, optional=False),
Tensor(name="output_generation_logits", shape=(-1,), dtype=np.bool_, optional=False),
)
return inputs
@@ -1142,13 +1145,15 @@ def get_triton_output(self):
outputs = (
Tensor(name="outputs", shape=(-1,), dtype=bytes),
Tensor(name="generation_logits", shape=(-1,), dtype=np.single),
Tensor(name="context_logits", shape=(-1,), dtype=np.single),
)
return outputs

@batch
def triton_infer_fn(self, **inputs: np.ndarray):
"""Triton infer function for streaming"""
output_dict = {}
context_logits_available = False
generation_logits_available = False
try:
infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))}
Expand Down Expand Up @@ -1179,10 +1184,21 @@ def triton_infer_fn(self, **inputs: np.ndarray):
if "output_generation_logits" in inputs:
generation_logits_available = inputs["output_generation_logits"]
infer_input["output_generation_logits"] = inputs.pop("output_generation_logits")[0][0]
if "output_context_logits" in inputs:
context_logits_available = inputs["output_context_logits"]
infer_input["output_context_logits"] = inputs.pop("output_context_logits")[0][0]

if generation_logits_available:
output_texts, generation_logits = self.forward(**infer_input)
output_dict["generation_logits"] = np.array(generation_logits.cpu().numpy())
# generation_logits is a 4d tensor of dim [1,1,#generated_tokens, vocab_size], return just the 3d tensor
# in output dict.
output_dict["generation_logits"] = np.array(generation_logits[0].cpu().numpy())
elif context_logits_available:
output_texts, context_logits = self.forward(**infer_input)
# Convert context logits from a list to a 3d tensor, since it's available as a list of tensors shaped
# [#tokens, vocab_size]
context_logits = context_logits[0].unsqueeze(0)
output_dict["context_logits"] = np.array(context_logits.cpu().numpy())
else:
output_texts = self.forward(**infer_input)
output_dict["outputs"] = cast_output(output_texts, np.bytes_)
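A toy illustration of the shape handling in triton_infer_fn above (sizes are made up): generation logits arrive as a 4d tensor and only the leading element is returned, while context logits arrive as a list of [#tokens, vocab_size] tensors and are re-batched to 3d before being placed in the output dict.

import numpy as np
import torch

vocab_size = 10

# generation_logits as returned by forward(): [1, 1, #generated_tokens, vocab_size]
generation_logits = torch.randn(1, 1, 3, vocab_size)
gen_out = np.array(generation_logits[0].cpu().numpy())  # -> (1, 3, vocab_size)

# context_logits as returned by forward(): a list of tensors shaped [#tokens, vocab_size]
context_logits = [torch.randn(6, vocab_size)]
ctx = context_logits[0].unsqueeze(0)                     # -> [1, 6, vocab_size]
ctx_out = np.array(ctx.cpu().numpy())

print(gen_out.shape, ctx_out.shape)                      # (1, 3, 10) (1, 6, 10)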
3 changes: 3 additions & 0 deletions nemo/export/trt_llm/tensorrt_llm_run.py
@@ -647,6 +647,7 @@ def generate(
streaming: bool = False,
output_log_probs=False,
multiprocessed_env=False,
output_context_logits=False,
output_generation_logits=False,
**sampling_kwargs,
) -> Optional[List[List[str]]]:
@@ -709,6 +710,8 @@

if output_generation_logits:
return output_lines_list, outputs['generation_logits']
elif output_context_logits:
return output_lines_list, outputs['context_logits']
return output_lines_list


2 changes: 2 additions & 0 deletions requirements/requirements_eval.txt
@@ -0,0 +1,2 @@
# Installs EleutherAI's lm-evaluation-harness https://github.com/EleutherAI/lm-evaluation-harness/tree/main
lm-eval