use transformers output
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Mar 25, 2024
1 parent 4d3dc40 commit e4893c4
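This commit replaces the custom attention_show_flg / output_all_encoded_layers flags with the standard transformers arguments (output_attentions, output_hidden_states, return_dict) and returns BaseModelOutput / BaseModelOutputWithPooling objects instead of ad-hoc tuples. Below is a minimal usage sketch of the new interface; the RnaBertModel class name, the import paths, and the usability of a default RnaBertConfig are assumptions based on the file layout shown in this diff, not something the diff itself confirms.

import torch

from rnabert.configuration_rnabert import RnaBertConfig  # path assumed from the relative import in the diff
from rnabert.modeling_rnabert import RnaBertModel  # class name assumed; not visible in these hunks

config = RnaBertConfig()  # assumes the default config values are usable
model = RnaBertModel(config)
input_ids = torch.randint(0, 5, (1, 16))  # dummy token ids, assumed to be in-vocabulary

# New-style call: standard transformers flags instead of attention_show_flg.
outputs = model(
    input_ids,
    output_attentions=True,
    output_hidden_states=True,
    return_dict=True,
)
print(outputs.last_hidden_state.shape)  # (batch_size, seq_length, hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, hidden_size)
print(len(outputs.attentions))          # one attention tensor per encoder layer

# With return_dict=False the model falls back to a plain tuple, mirroring transformers.
sequence_output, pooled_output = model(input_ids, return_dict=False)[:2]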
Showing 1 changed file with 75 additions and 72 deletions.
rnabert/modeling_rnabert.py: 147 changes (75 additions, 72 deletions)
@@ -3,6 +3,7 @@
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput

from .configuration_rnabert import RnaBertConfig

@@ -53,24 +54,17 @@ def forward(self, input_ids, token_type_ids=None):
class RnaBertLayer(nn.Module):
def __init__(self, config):
super().__init__()

self.attention = RnaBertAttention(config)

self.intermediate = RnaBertIntermediate(config)

self.output = RnaBertOutput(config)

def forward(self, hidden_states, attention_mask, attention_show_flg=False):
if attention_show_flg:
attention_output, attention_probs = self.attention(hidden_states, attention_mask, attention_show_flg)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output, attention_probs
else:
attention_output = self.attention(hidden_states, attention_mask, attention_show_flg)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output # [batch_size, seq_length, hidden_size]
def forward(self, hidden_states, attention_mask, output_attentions=False):
self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
attention_output, outputs = self_attention_outputs[0], self_attention_outputs[1:]
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
outputs = (layer_output,) + outputs
return outputs


class RnaBertAttention(nn.Module):
@@ -79,15 +73,11 @@ def __init__(self, config):
self.selfattn = RnaBertSelfAttention(config)
self.output = RnaBertSelfOutput(config)

def forward(self, input_tensor, attention_mask, attention_show_flg=False):
if attention_show_flg:
self_output, attention_probs = self.selfattn(input_tensor, attention_mask, attention_show_flg)
attention_output = self.output(self_output, input_tensor)
return attention_output, attention_probs
else:
self_output = self.selfattn(input_tensor, attention_mask, attention_show_flg)
attention_output = self.output(self_output, input_tensor)
return attention_output
def forward(self, hidden_states, attention_mask, output_attentions=False):
self_outputs = self.selfattn(hidden_states, attention_mask, output_attentions=output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs


class RnaBertSelfAttention(nn.Module):
@@ -114,7 +104,7 @@ def transpose_for_scores(self, x):
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states, attention_mask, attention_show_flg=False):
def forward(self, hidden_states, attention_mask, output_attentions=False):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
@@ -136,16 +126,13 @@ def forward(self, hidden_states, attention_mask, attention_show_flg=False):
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)

if attention_show_flg:
return context_layer, attention_probs
else:
return context_layer
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs


class RnaBertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()

self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -203,25 +190,40 @@ def forward(
self,
hidden_states,
attention_mask,
output_all_encoded_layers=True,
attention_show_flg=False,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_encoder_layers = []
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for layer in self.layer:
if attention_show_flg:
hidden_states, attention_probs = layer(hidden_states, attention_mask, attention_show_flg)
else:
hidden_states = layer(hidden_states, attention_mask, attention_show_flg)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)

if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)

if attention_show_flg:
return all_encoder_layers, attention_probs
else:
return all_encoder_layers
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer(hidden_states, attention_mask, output_attentions)
hidden_states = layer_outputs[0]

if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v
for v in [
hidden_states,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)


class RnaBertPooler(nn.Module):
@@ -283,8 +285,9 @@ def forward(
input_ids,
token_type_ids=None,
attention_mask=None,
output_all_encoded_layers=True,
attention_show_flg=False,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
@@ -296,32 +299,32 @@
extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

embedding_output = self.embeddings(input_ids, token_type_ids)

if attention_show_flg:
encoded_layers, attention_probs = self.encoder(
embedding_output,
extended_attention_mask,
output_all_encoded_layers,
attention_show_flg,
)
else:
encoded_layers = self.encoder(
embedding_output,
extended_attention_mask,
output_all_encoded_layers,
attention_show_flg,
)

pooled_output = self.pooler(encoded_layers[-1])
embedding_output = self.embeddings(
input_ids=input_ids,
token_type_ids=token_type_ids,
# attention_mask=attention_mask,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

if attention_show_flg:
return encoded_layers, pooled_output, attention_probs
else:
return encoded_layers, pooled_output
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


class RnaBertPreTrainingHeads(nn.Module):
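Callers that previously passed attention_show_flg=True should now request attentions with output_attentions=True and read them from the attentions field of the returned output (or from the trailing tuple entries when return_dict=False). A hypothetical caller-side migration sketch, reusing the model and input_ids assumptions from the note above:

from transformers.modeling_outputs import BaseModelOutputWithPooling


def last_layer_attention(model, input_ids):
    """Pooled output and last-layer attention probabilities via the new interface."""
    outputs: BaseModelOutputWithPooling = model(input_ids, output_attentions=True)
    # Pre-commit equivalent (removed in this diff):
    #   _, pooled_output, attention_probs = model(
    #       input_ids, output_all_encoded_layers=False, attention_show_flg=True
    #   )
    return outputs.pooler_output, outputs.attentions[-1]

Returning only the last layer's attentions mirrors the old behaviour, since the removed encoder loop kept only the attention probabilities of its final layer.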
