
Wav2Vec2BertForSequenceClassification. return_attention_mask work wrong #35495

Closed
2 of 4 tasks
HERIUN opened this issue Jan 3, 2025 · 3 comments

@HERIUN
Contributor

HERIUN commented Jan 3, 2025

System Info

  • transformers version: 4.47.1
  • Platform: Linux-6.8.0-47-generic-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • Huggingface_hub version: 0.27.0
  • Safetensors version: 0.4.5
  • Accelerate version: 0.26.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA RTX A5000

Who can help?

@ylacombe

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm using https://github.com/huggingface/transformers/blob/main/examples/pytorch/audio-classification/run_audio_classification.py

from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForAudioClassification
)

# data_args, model_args, raw_datasets and random_subsample are defined in run_audio_classification.py

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate


labels = raw_datasets["train"].features[data_args.label_column_name].names
label2id, id2label = {}, {}
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label


model_name_or_path = "facebook/w2v-bert-2.0"

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
# Name of the main input produced by the feature extractor (e.g. "input_features")
model_input_name = feature_extractor.model_input_names[0]
config = AutoConfig.from_pretrained(
    model_args.config_name or model_args.model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="audio-classification",
)
model = AutoModelForAudioClassification.from_pretrained(
    model_name_or_path,
    config=config,
)

def train_transforms(batch):
    """Apply train_transforms across a batch."""
    subsampled_wavs = []
    for audio in batch[data_args.audio_column_name]:
        wav = random_subsample(
            audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
        )
        subsampled_wavs.append(wav)
    inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
    output_batch = {model_input_name: inputs.get(model_input_name)}
    output_batch["labels"] = list(batch[data_args.label_column_name])

    return output_batch

def val_transforms(batch):
    """Apply val_transforms across a batch."""
    wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
    inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
    output_batch = {model_input_name: inputs.get(model_input_name)}
    output_batch["labels"] = list(batch[data_args.label_column_name])

    return output_batch

raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)

...
## I don't know..
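As posted, both `train_transforms` and `val_transforms` build `output_batch` from `model_input_name` and the labels only, so any `attention_mask` returned by `feature_extractor` is not passed on. The sketch below is one hedged way to keep it; `make_transform` and `toy_extractor` are hypothetical helpers written for illustration, not part of the example script or of transformers. The toy extractor simply right-pads raw samples with zeros so the example runs without downloading a model.

```python
from typing import Any, Callable, Dict, List


def make_transform(
    feature_extractor: Callable[..., Dict[str, Any]],
    model_input_name: str,
    audio_column: str = "audio",
    label_column: str = "label",
    sampling_rate: int = 16_000,
) -> Callable[[Dict[str, List[Any]]], Dict[str, Any]]:
    """Build a batch transform that keeps attention_mask alongside the model inputs."""

    def transform(batch: Dict[str, List[Any]]) -> Dict[str, Any]:
        wavs = [audio["array"] for audio in batch[audio_column]]
        inputs = feature_extractor(wavs, sampling_rate=sampling_rate)
        output_batch = {model_input_name: inputs.get(model_input_name)}
        # Forward the mask only when the extractor actually returned one.
        mask = inputs.get("attention_mask")
        if mask is not None:
            output_batch["attention_mask"] = mask
        output_batch["labels"] = list(batch[label_column])
        return output_batch

    return transform


def toy_extractor(wavs: List[List[float]], sampling_rate: int) -> Dict[str, Any]:
    """Hypothetical stand-in for a feature extractor: right-pads raw samples with zeros."""
    max_len = max(len(w) for w in wavs)
    features = [list(w) + [0.0] * (max_len - len(w)) for w in wavs]
    mask = [[1] * len(w) + [0] * (max_len - len(w)) for w in wavs]
    return {"input_features": features, "attention_mask": mask}


transform = make_transform(toy_extractor, model_input_name="input_features")
batch = {"audio": [{"array": [0.1, 0.2, 0.3]}, {"array": [0.4]}], "label": [0, 1]}
out = transform(batch)
print(out["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
```

With a real transformers feature extractor, `inputs` is a `BatchFeature`, which also supports `.get`, so the same pattern applies unchanged.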

Expected behavior

When a batch is padded, attention_mask is always [1,1,1,...,1], but I expect [0,0,0,...,1,1].
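For reference, a minimal sketch in plain Python of what the masks of a padded batch look like (the helper `attention_masks` is hypothetical, not a transformers API). Only sequences shorter than the longest one receive 0s, so the longest row is all ones with either padding side; under this convention, a batch in which every row is all ones has received no padding at all.

```python
from typing import List


def attention_masks(lengths: List[int], side: str = "right") -> List[List[int]]:
    """Return 0/1 attention masks for sequences padded to the longest length.

    1 marks a real frame, 0 marks a padding frame; `side` chooses where padding goes.
    """
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    max_len = max(lengths)
    masks = []
    for n in lengths:
        real, pad = [1] * n, [0] * (max_len - n)
        masks.append(real + pad if side == "right" else pad + real)
    return masks


print(attention_masks([3, 5], side="right"))  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
print(attention_masks([3, 5], side="left"))   # [[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]
```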

HERIUN added the bug label Jan 3, 2025
HERIUN changed the title from "Wav2Vec2BertForSequenceClassification. return_attention_mask work weried." to "Wav2Vec2BertForSequenceClassification. return_attention_mask work wrong" Jan 3, 2025
@LysandreJik
Member

cc @eustlb

@sambhavnoobcoder

Hey @HERIUN, could you confirm whether you strictly expected left-padded output, or whether a right-padded output like [1,1,1,...,0,0,0] would be equally acceptable for you? I looked into this issue and used the following script to reproduce it:

import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, AutoConfig

# Load a small test dataset
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.select(range(4))  # Just use 4 examples for testing

# Setup labels
label_list = ["label1", "label2"]  # Dummy labels for testing
label2id = {label: str(i) for i, label in enumerate(label_list)}
id2label = {str(i): label for i, label in enumerate(label_list)}

# Load model and feature extractor
model_name = "facebook/w2v-bert-2.0"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(label_list),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="audio-classification",
)
model = AutoModelForAudioClassification.from_pretrained(model_name, config=config)

# Process audio inputs
audio_inputs = [x["audio"]["array"] for x in dataset]
sampling_rate = dataset.features["audio"].sampling_rate

# Create batch with different length inputs to force padding
inputs = feature_extractor(
    audio_inputs, 
    sampling_rate=sampling_rate,
    padding=True,
    return_tensors="pt"
)

print("Attention mask shape:", inputs.attention_mask.shape)
print("Sample attention mask:\n", inputs.attention_mask[0])

This is the output I got:

Attention mask shape: torch.Size([4, 624])
Sample attention mask:
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       dtype=torch.int32)

This differs from your finding of all 1s. I just want to verify whether this is in line with your expected output; if not, I will look into the cause of this issue.
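To check which case a batch is in without reading a 624-element tensor by eye, here is a minimal sketch that classifies a single mask row. The helper `detect_padding_side` is hypothetical, not a transformers API; a row from `inputs.attention_mask` can be converted with `.tolist()` before passing it in.

```python
from typing import Sequence


def detect_padding_side(mask: Sequence[int]) -> str:
    """Classify a 1-D 0/1 attention mask as 'none', 'right', 'left', 'mixed' or 'empty'."""
    values = list(mask)
    n, n_real = len(values), sum(values)
    if n_real == 0:
        return "empty"  # every position is padding
    if n_real == n:
        return "none"   # no padding in this row (e.g. the longest sequence in the batch)
    if values == [1] * n_real + [0] * (n - n_real):
        return "right"
    if values == [0] * (n - n_real) + [1] * n_real:
        return "left"
    return "mixed"


print(detect_padding_side([1, 1, 1, 0, 0]))  # right
print(detect_padding_side([0, 0, 1, 1, 1]))  # left
print(detect_padding_side([1, 1, 1]))        # none
```

A row like the one printed above (a run of 1s followed by a run of 0s) would be reported as "right".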

@HERIUN
Contributor Author

HERIUN commented Jan 13, 2025

You're right. I was confused.

HERIUN closed this as completed Jan 13, 2025