diff --git a/BERT_explainability/modules/BERT/BertForSequenceClassification.py b/BERT_explainability/modules/BERT/BertForSequenceClassification.py index 71162e5..d9eb2da 100644 --- a/BERT_explainability/modules/BERT/BertForSequenceClassification.py +++ b/BERT_explainability/modules/BERT/BertForSequenceClassification.py @@ -7,6 +7,8 @@ from typing import List, Any import torch from BERT_rationale_benchmark.models.model_utils import PaddedSequence +from transformers.modeling_outputs import SequenceClassifierOutput + class BertForSequenceClassification(BertPreTrainedModel):