From 654543a8cf90ae83b6e9ea078e8f38c00433e4cc Mon Sep 17 00:00:00 2001 From: Stefan+o Date: Mon, 6 Nov 2023 09:52:10 +0100 Subject: [PATCH] Added support for custom EncoderDecoder models (#911) --- trl/models/modeling_value_head.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py index 6ca057b970..1ce17832c1 100644 --- a/trl/models/modeling_value_head.py +++ b/trl/models/modeling_value_head.py @@ -33,10 +33,14 @@ def __init__(self, config, **kwargs): self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m + if hasattr(config, "hidden_size"): + hidden_size = config.hidden_size if hasattr(config, "word_embed_proj_dim"): hidden_size = config.word_embed_proj_dim - else: - hidden_size = config.hidden_size + elif hasattr(config, "is_encoder_decoder"): + if config.is_encoder_decoder and hasattr(config, "decoder"): + if hasattr(config.decoder, "hidden_size"): + hidden_size = config.decoder.hidden_size self.summary = nn.Linear(hidden_size, 1)