Commit ee2015a
Showing 6 changed files with 210 additions and 129 deletions.
193 changes: 119 additions & 74 deletions keras_nlp/models/opt/opt_causal_lm.py
@@ -29,7 +29,6 @@
from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tf_utils import tensor_to_string_list
from keras_nlp.utils.tf_utils import truncate_at_token


@keras_nlp_export("keras_nlp.models.OPTCausalLM")
@@ -49,7 +48,7 @@ class OPTCausalLM(Task):
default, `"top_k"` sampling will be used.
This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to raw inputs during
which case it will automatically apply preprocessing to string inputs during
`fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
when creating the model with `from_preset()`.
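For reference, a minimal usage sketch of the string-in, string-out path described here (assuming the `opt_125m_en` preset; any OPT preset works the same way):

```python
import keras_nlp

# With a preprocessor attached (the default with `from_preset()`),
# `generate()` takes raw strings and returns generated strings.
# `"top_k"` sampling is used unless `compile()` overrides it.
causal_lm = keras_nlp.models.OPTCausalLM.from_preset("opt_125m_en")
output = causal_lm.generate("The quick brown fox", max_length=32)
print(output)  # A python string continuing the prompt.
```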
@@ -301,28 +300,23 @@ def make_generate_function(self):

def generate_step(
self,
token_ids,
padding_mask,
inputs,
end_token_id=None,
):
"""A compilable generation function for a single batch of inputs.
This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. It takes in a dense `tf.Tensor` of token
ids, and return a dense `tf.Tensor` of token ids, and includes no
preprocessing. This function is wrapped by the `generate()` method.
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
Args:
token_ids: A dense int Tensor, with shape
`(batch_size, max_length)`. The user provided token ids
padded to `max_length`.
padding_mask: A dense boolean Tensor, with the same shape as
`token_ids`. Positions that are True in the `padding_mask`
are assumed to be user input and never updated.
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user input token ids.
@@ -347,7 +341,7 @@ def next(prompt, cache, index):
cache,
)

return self._sampler(
token_ids = self._sampler(
next=next,
prompt=token_ids,
cache=cache,
@@ -357,6 +351,78 @@ def next(prompt, cache, index):
hidden_states=hidden_states,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = (token_ids == end_token_id) & (~padding_mask)
end_locations = tf.cast(end_locations, tf.int32)
# Use cumsum to get ones in all locations after end_locations.
overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
# Our padding mask is the inverse of these overflow locations.
padding_mask = ~tf.cast(overflow, tf.bool)
else:
# Without early stopping, all locations will have been updated.
padding_mask = tf.ones_like(token_ids, dtype=tf.bool)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
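To see the early-stopping mask computation in isolation, here is a minimal standalone sketch of the same `cumsum` trick on a hand-built batch (token id values are illustrative only):

```python
import tensorflow as tf

end_token_id = 2
# One sequence: a two-token prompt, then generation produced
# `end_token_id` at positions 3 and 5.
token_ids = tf.constant([[5, 6, 7, 2, 8, 2]])
prompt_mask = tf.constant([[True, True, False, False, False, False]])

# End tokens that were generated, i.e. not part of the user prompt.
end_locations = tf.cast((token_ids == end_token_id) & ~prompt_mask, tf.int32)
# The exclusive cumsum is zero up to and including the first end token,
# and positive everywhere after it.
overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
padding_mask = ~tf.cast(overflow, tf.bool)
print(padding_mask.numpy())  # [[ True  True  True  True False False]]
```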

def _normalize_generate_inputs(
self,
inputs,
):
"""Normalize user input to the generate function.
This function converts all inputs to tensors, adds a batch dimension if
necessary, and returns an iterable "dataset like" object (either an
actual `tf.data.Dataset` or a list with a single batch element).
"""
input_is_scalar = False

if isinstance(inputs, tf.data.Dataset):
return inputs, input_is_scalar

if isinstance(inputs, str) or isinstance(inputs, list):
inputs = tf.convert_to_tensor(inputs)

if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0:
input_is_scalar = True
inputs = inputs[tf.newaxis]

# We avoid converting to a dataset purely for speed; for a single batch
# of input, creating a dataset would add significant overhead.
return [inputs], input_is_scalar
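A standalone paraphrase of this normalization, showing what each input type becomes (not the method itself, just its rules):

```python
import tensorflow as tf

def normalize_generate_inputs(inputs):
    # Datasets pass through; strings and lists become tensors; a scalar
    # string gains a batch dimension and is flagged for later squeezing.
    input_is_scalar = False
    if isinstance(inputs, tf.data.Dataset):
        return inputs, input_is_scalar
    if isinstance(inputs, (str, list)):
        inputs = tf.convert_to_tensor(inputs)
    if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0:
        input_is_scalar = True
        inputs = inputs[tf.newaxis]
    return [inputs], input_is_scalar

batches, was_scalar = normalize_generate_inputs("a single prompt")
print(was_scalar, batches[0].shape)  # True (1,)
```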

def _normalize_generate_outputs(
self,
outputs,
input_is_scalar,
):
"""Normalize user output from the generate function.
This function converts all outputs to numpy (for integer output) or
python strings (for string output). If a batch dimension was added to
the input, it is removed from the output (so generate can be string in,
string out).
"""

def normalize(x):
x = tf.concat(x, axis=0)
x = tf.squeeze(x, 0) if input_is_scalar else x
is_string = x.dtype == tf.string
# Convert outputs to a friendly pythonic type. For numerical outputs
# that is numpy, for string outputs that is `list` and `str`.
return tensor_to_string_list(x) if is_string else x.numpy()

if isinstance(outputs[0], dict):
return {
"token_ids": normalize([x["token_ids"] for x in outputs]),
"padding_mask": normalize([x["padding_mask"] for x in outputs]),
}
return normalize([x for x in outputs])
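For example, two hypothetical batches of dictionary output are concatenated key by key before the numpy conversion:

```python
import tensorflow as tf

outputs = [  # Hypothetical per-batch results from the generate loop.
    {"token_ids": tf.constant([[1, 5, 6]]),
     "padding_mask": tf.constant([[True, True, True]])},
    {"token_ids": tf.constant([[1, 7, 0]]),
     "padding_mask": tf.constant([[True, True, False]])},
]
token_ids = tf.concat([x["token_ids"] for x in outputs], axis=0)
print(token_ids.numpy())  # [[1 5 6] [1 7 0]], batches stacked on axis 0.
```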

def generate(
self,
inputs,
@@ -367,14 +433,14 @@ def generate(
This method generates text based on given `inputs`. The sampling method
used for generation can be set in the `compile` method.
If `inputs` is a `tf.data.Dataset`, outputs will be generated
If `inputs` are a `tf.data.Dataset`, outputs will be generated
"batch-by-batch" and concatenated. Otherwise, all inputs will be handled
as a single batch.
If a `preprocessor` is attached to the model, `inputs` should be
strings and returned sequences will be strings. Otherwise, inputs should
be preprocessed into token ids before calling `generate()`, and returned
sequences will also be token ids.
be preprocessed before calling `generate()`, and returned sequences will
be token ids.
Args:
inputs: a string `tf.Tensor`, a `tf.data.Dataset` of strings, a
@@ -383,73 +449,52 @@ def generate(
`tf.Tensor` or `tf.data.Dataset` with keys `"token_ids"` and
`"padding_mask"`.
max_length: Optional. int. The max length of the generated sequence.
Will default to the configured `sequence_length` of the
Will default to the max configured `sequence_length` of the
`preprocessor`. If `preprocessor` is `None`, `inputs` should be
padded to the desired maximum length and this argument will be
ignored.
Returns:
A string or string list if `preprocessor` is set, and an integer
tensor of token ids if `preprocessor is None`.
tensor of token IDs if `preprocessor is None`.
"""
input_is_scalar = False

# Set up our three main passes.
# 1. Optionally preprocess strings to dense integer tensors.
# 2. Generate new tokens via a compiled function on dense tensors.
# 3. Optionally postprocess dense integer tensors back to strings.
generate_function = self.make_generate_function()
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id

def preprocess(x):
return self.preprocessor(
x,
sequence_length=max_length,
return_labels=False,
# We do not append an end token by default during
# generation, as generating directly in the same sequence is
# the most common workflow. If an end token directly after
# a prompt is desired, it can be added to the prompt string.
add_end_token=False,
)

if not isinstance(inputs, tf.data.Dataset):
inputs = tf.convert_to_tensor(inputs)
input_is_scalar = inputs.shape.rank == 0
inputs = inputs[tf.newaxis] if input_is_scalar else inputs
# Wrap a list to avoid the overhead of converting to dataset.
inputs = [preprocess(inputs)]
else:
def preprocess(x):
return self.preprocessor.generate_preprocess(
x, sequence_length=max_length
)

def generate(x):
return generate_function(x, end_token_id=end_token_id)

def postprocess(x):
return self.preprocessor.generate_postprocess(x)

# Normalize inputs, apply our three passes, and normalize outputs.
inputs, input_is_scalar = self._normalize_generate_inputs(inputs)

if self.preprocessor is not None:
if isinstance(inputs, tf.data.Dataset):
inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
inputs = inputs.prefetch(tf.data.AUTOTUNE)
else:
if not isinstance(inputs, tf.data.Dataset):
# Wrap a list to avoid the overhead of converting to dataset.
inputs = [inputs]
else:
# Fast path for non-dataset, single-batch input.
inputs = [preprocess(x) for x in inputs]

generate_function = self.make_generate_function()
outputs = []
for batch in inputs:
token_ids, padding_mask = batch["token_ids"], batch["padding_mask"]
# If `preprocessor` is attached, we can stop after `end_token_id`.
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id
# Run the compiled generate function.
output = generate_function(token_ids, padding_mask, end_token_id)

if self.preprocessor is not None:
# Truncate to ragged by removing tokens after the first
# generated `end_token_id`.
output = truncate_at_token(output, end_token_id, padding_mask)
# Strip start token if added.
if self.preprocessor.add_start_token:
output = output[:, 1:]
# Detokenize.
output = self.preprocessor.tokenizer.detokenize(output)
outputs.append(output)

outputs = tf.concat(outputs, axis=0)
outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
# Convert outputs to a friendly pythonic type. For numerical outputs
# that is numpy, for string outputs that is `list` and `str`.
if outputs.dtype == tf.string:
return tensor_to_string_list(outputs)
return outputs.numpy()
outputs = [generate(x) for x in inputs]

if self.preprocessor is not None:
outputs = [postprocess(x) for x in outputs]

return self._normalize_generate_outputs(outputs, input_is_scalar)
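Putting the new `generate()` together, its control flow reduces to roughly the following (a schematic paraphrase of the code above, assuming a preprocessor is attached and eliding the `tf.data.Dataset` `.map()` branch):

```python
def generate_sketch(model, inputs, max_length=None):
    # Build (or reuse) the compiled generation function.
    generate_function = model.make_generate_function()
    end_token_id = model.preprocessor.tokenizer.end_token_id
    inputs, input_is_scalar = model._normalize_generate_inputs(inputs)
    # 1. Preprocess: strings -> dense {"token_ids", "padding_mask"} batches.
    inputs = [
        model.preprocessor.generate_preprocess(x, sequence_length=max_length)
        for x in inputs
    ]
    # 2. Generate: run the compiled function on each dense batch.
    outputs = [generate_function(x, end_token_id=end_token_id) for x in inputs]
    # 3. Postprocess: token ids -> strings, then restore scalar/batch shape.
    outputs = [model.preprocessor.generate_postprocess(x) for x in outputs]
    return model._normalize_generate_outputs(outputs, input_is_scalar)
```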

@classmethod
def create_layout_map(cls, mesh):
89 changes: 65 additions & 24 deletions keras_nlp/models/opt/opt_causal_lm_preprocessor.py
@@ -14,10 +14,14 @@

"""OPT Causal LM preprocessor layer."""

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.opt.opt_preprocessor import OPTPreprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@@ -95,36 +99,73 @@ def call(
y=None,
sample_weight=None,
sequence_length=None,
add_start_token=None,
add_end_token=None,
return_labels=True,
):
if y is not None or sample_weight is not None:
logging.warning(
"`OPTCausalLMPreprocessor` generates `y` and `sample_weight` "
"`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` "
"based on your input data, but your data already contains `y` "
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)
if return_labels:
# Tokenize with one extra token to account for the truncation below.
sequence_length = (sequence_length or self.sequence_length) + 1
x = super().call(
sequence_length = sequence_length or self.sequence_length

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length,
add_start_token=add_start_token,
add_end_token=add_end_token,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
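A small worked example of the shift above, with hypothetical token ids packed to length 5 (one token longer than the requested sequence length of 4):

```python
import tensorflow as tf

# A 3-token prompt with a start token, padded to 5 positions.
token_ids = tf.constant([[1, 10, 11, 12, 0]])
padding_mask = tf.constant([[True, True, True, True, False]])

x = {
    "token_ids": token_ids[..., :-1],       # [[1, 10, 11, 12]]
    "padding_mask": padding_mask[..., :-1],
}
y = token_ids[..., 1:]                      # [[10, 11, 12, 0]]
sample_weight = padding_mask[..., 1:]       # padded targets get no loss
```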

def generate_preprocess(
self,
x,
sequence_length=None,
):
"""Covert strings to integer token input for generation.
Similar to calling the layer for training, this method takes in strings
or tensor strings, tokenizes and packs the input, and computes a padding
mask masking all inputs not filled in with a padded value.
Unlike calling the the layer for training, this method does not compute
labels and will never append a `tokenizer.end_token_id` to the end of
the sequence (as generation is expected to continue at the end of the
inputted prompt).
"""
x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
token_ids, padding_mask = self.packer(
x, sequence_length=sequence_length, add_end_value=False
)
if return_labels:
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
else:
return x
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
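With the toy vocabulary used in the tests below and `sequence_length=8`, preprocessing a prompt for generation yields a start token but no end token, leaving room for generation to continue:

```python
# `preprocessor` is an OPTCausalLMPreprocessor set up as in the tests.
x = preprocessor.generate_preprocess(" airplane at airport")
# x["token_ids"]    -> [1, 3, 4, 5, 3, 6, 0, 0]
# x["padding_mask"] -> [1, 1, 1, 1, 1, 1, 0, 0]
```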

def generate_postprocess(
self,
x,
):
"""Covert integer token output to strings for generation.
This method reverses `generate_preprocess()`, by first removing all
padding and start/end tokens, and then converting the interger sequence
back to a string.
"""
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
return self.tokenizer.detokenize(token_ids)
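The stripping step can be seen in isolation below. Note that OPT uses the same marker for its start and end tokens, so masking `end_token_id` removes both (ids here follow the toy test vocabulary, where id 1 is that marker):

```python
import tensorflow as tf

end_token_id = 1  # Start and end markers share this id for OPT.
token_ids = tf.constant([[1, 10, 11, 1, 0, 0]])
padding_mask = tf.constant([[True, True, True, True, False, False]])

# Keep only real, non-special tokens; detokenize the ragged result.
keep = padding_mask & (token_ids != end_token_id)
print(tf.ragged.boolean_mask(token_ids, keep))  # <tf.RaggedTensor [[10, 11]]>
```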
19 changes: 11 additions & 8 deletions keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
@@ -109,16 +109,19 @@ def test_dataset(self):
self.assertAllEqual(y, [[3, 4, 5, 3, 6, 1, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)

def test_call_overrides(self):
def test_generate_preprocess(self):
input_data = " airplane at airport"
x, _, _ = self.preprocessor(input_data, add_start_token=False)
self.assertAllEqual(x["token_ids"], [3, 4, 5, 3, 6, 1, 0, 0])
x, _, _ = self.preprocessor(input_data, add_end_token=False)
x = self.preprocessor.generate_preprocess(input_data)
self.assertAllEqual(x["token_ids"], [1, 3, 4, 5, 3, 6, 0, 0])
x, _, _ = self.preprocessor(input_data, sequence_length=4)
self.assertAllEqual(x["token_ids"], [1, 3, 4, 5])
x = self.preprocessor(input_data, return_labels=False)
self.assertAllEqual(x["token_ids"], [1, 3, 4, 5, 3, 6, 1, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])

def test_generate_postprocess(self):
input_data = {
"token_ids": tf.constant([1, 3, 4, 5, 3, 6, 0, 0]),
"padding_mask": tf.cast([1, 1, 1, 1, 1, 1, 0, 0], dtype="bool"),
}
x = self.preprocessor.generate_postprocess(input_data)
self.assertAllEqual(x, " airplane at airport")

def test_serialization(self):
config = keras.utils.serialize_keras_object(self.preprocessor)
15 changes: 11 additions & 4 deletions keras_nlp/models/opt/opt_causal_lm_test.py
@@ -104,17 +104,24 @@ def test_fit_no_xla(self):

def test_generate(self):
# String input.
prompt = " airplane"
output = self.causal_lm.generate(" airplane")
prompt = " airplane at airport"
output = self.causal_lm.generate(" airplane at airport")
self.assertTrue(prompt in output)
# String tensor input.
self.assertIsInstance(self.causal_lm.generate(self.raw_batch)[0], str)
# String dataset input.
self.assertIsInstance(self.causal_lm.generate(self.raw_dataset)[0], str)
# Int tensor input.
self.causal_lm.preprocessor = None
self.assertDTypeEqual(
self.causal_lm.generate(self.preprocessed_batch), tf.int32
outputs = self.causal_lm.generate(self.preprocessed_batch)
# Assert prompt is in output in token id space.
self.assertAllEqual(
outputs["token_ids"][:, :5],
self.preprocessed_batch["token_ids"][:, :5],
)
self.assertAllEqual(
outputs["padding_mask"][:, :5],
self.preprocessed_batch["padding_mask"][:, :5],
)

def test_generate_compilation(self):
17 changes: 3 additions & 14 deletions keras_nlp/models/opt/opt_preprocessor.py
@@ -68,10 +68,6 @@ class OPTPreprocessor(Preprocessor):
sample_weight: Any label weight data. Will be passed through unaltered.
sequence_length: Pass to override the configured `sequence_length` of
the layer.
add_start_token: Pass to override the configure value of
`add_start_token` on the layer.
add_end_token: Pass to override the configure value of
`add_end_token` on the layer.
Examples:
@@ -154,8 +150,6 @@ def call(
y=None,
sample_weight=None,
sequence_length=None,
add_start_token=None,
add_end_token=None,
):
x = convert_inputs_to_list_of_tensor_segments(x)
if len(x) != 1:
@@ -165,17 +159,12 @@ def call(
"for a multi-segment classification task, please refer to "
"classification models like BERT or RoBERTa."
)
if sequence_length is None:
sequence_length = self.sequence_length
if add_start_token is None:
add_start_token = self.add_start_token
if add_end_token is None:
add_end_token = self.add_end_token
sequence_length = sequence_length or self.sequence_length
token_ids, padding_mask = self.packer(
self.tokenizer(x[0]),
sequence_length=sequence_length,
add_start_value=add_start_token,
add_end_value=add_end_token,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
x = {
"token_ids": token_ids,
6 changes: 1 addition & 5 deletions keras_nlp/models/opt/opt_preprocessor_test.py
@@ -99,12 +99,8 @@ def test_tokenize_labeled_dataset(self):
self.assertAllEqual(x["token_ids"], [[1, 3, 4, 5, 3, 6, 1, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4)

def test_call_overrides(self):
def test_sequence_length_override(self):
input_data = " airplane at airport"
x = self.preprocessor(input_data, add_start_token=False)
self.assertAllEqual(x["token_ids"], [3, 4, 5, 3, 6, 1, 0, 0])
x = self.preprocessor(input_data, add_end_token=False)
self.assertAllEqual(x["token_ids"], [1, 3, 4, 5, 3, 6, 0, 0])
x = self.preprocessor(input_data, sequence_length=4)
self.assertAllEqual(x["token_ids"], [1, 3, 4, 1])
