Skip to content

Commit

Permalink
Added support for custom EncoderDecoder models (#911)
Browse files Browse the repository at this point in the history
  • Loading branch information
ribesstefano authored Nov 6, 2023
1 parent c273b18 commit 654543a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions trl/models/modeling_value_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ def __init__(self, config, **kwargs):
self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()

# some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
if hasattr(config, "hidden_size"):
hidden_size = config.hidden_size
if hasattr(config, "word_embed_proj_dim"):
hidden_size = config.word_embed_proj_dim
else:
hidden_size = config.hidden_size
elif hasattr(config, "is_encoder_decoder"):
if config.is_encoder_decoder and hasattr(config, "decoder"):
if hasattr(config.decoder, "hidden_size"):
hidden_size = config.decoder.hidden_size

self.summary = nn.Linear(hidden_size, 1)

Expand Down

0 comments on commit 654543a

Please sign in to comment.