
Commit

convert-hf : add proper handling of shared tensors containing token embeddings in T5 models
sszymczy committed Jun 24, 2024
1 parent 5df8988 commit 68b5116
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions convert-hf-to-gguf.py
@@ -2734,6 +2734,10 @@ def write_tensors(self):
 class T5Model(Model):
     model_arch = gguf.MODEL_ARCH.T5
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.shared_token_embeddings_found = False
+
     def set_vocab(self):
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
@@ -2841,12 +2845,17 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
-        # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
-        # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
-        # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
-        if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
-            logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
-            return []
+        # T5-based models store the shared token embeddings tensor under varying names: "encoder.embed_tokens.weight",
+        # "decoder.embed_tokens.weight" or "shared.weight". Some models even contain multiple copies of it in their
+        # safetensors files. We use the first of these three tensors as the token embeddings for both the encoder
+        # and the decoder and ignore the remaining ones.
+        if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
+            if not self.shared_token_embeddings_found:
+                name = "shared.weight"
+                self.shared_token_embeddings_found = True
+            else:
+                logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
+                return []
 
         return [(self.map_tensor_name(name), data_torch)]
 
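Side note (not part of the commit): because checkpoints differ in which of the three candidate names they actually store, it can be handy to inspect a model before converting. Below is a minimal sketch, assuming the safetensors Python package and a single-file checkpoint at a hypothetical path model.safetensors (sharded models would need each shard checked):

from safetensors import safe_open

CANDIDATES = {"shared.weight", "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"}

# List which of the shared token-embedding tensor names this checkpoint contains.
with safe_open("model.safetensors", framework="pt") as f:
    present = sorted(CANDIDATES & set(f.keys()))
print(present)

Any name listed beyond the first is exactly the kind of duplicate this commit now skips during conversion.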
