Wav2Vec2 upgrade with Conv1D options (#1758)
* Wav2Vec2 upgrade with Conv1D options

* refining scripts

* refining script again

* fix the formats

* fix the isort format

* refining the library

* update based on the suggestions

* update the variable name

* adding unk_token removal for the Python testing

* adding whitespace

* update Python format

* update variables

* update variables

* update variables

* update variables

---------

Co-authored-by: hkwon <[email protected]>
homink and hkwon authored Aug 19, 2024
1 parent d202032 commit 8ba828c
Showing 6 changed files with 207 additions and 117 deletions.
55 changes: 53 additions & 2 deletions include/ctranslate2/layers/wav2vec2.h
@@ -5,6 +5,52 @@
 namespace ctranslate2 {
   namespace layers {
 
+    class Wav2Vec2LayerNormConvLayer : public Layer {
+    public:
+      Wav2Vec2LayerNormConvLayer(const models::Model& model,
+                                 const std::string& scope,
+                                 dim_t stride,
+                                 dim_t padding);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      dim_t _stride;
+      dim_t _padding;
+      const Conv1D _conv;
+      const LayerNorm _output_norm;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
+    class Wav2Vec2PosConvLayer : public Layer {
+    public:
+      Wav2Vec2PosConvLayer(const models::Model& model, const std::string& scope);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      const Conv1D _conv;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
     class Wav2Vec2Encoder : public Layer {
     public:
       Wav2Vec2Encoder(const models::Model& model, const std::string& scope);
@@ -35,12 +81,17 @@ namespace ctranslate2 {
       }
 
     private:
-      // wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
-      //const ops::Transpose _transpose;
+      const Wav2Vec2LayerNormConvLayer _feat_layer0;
+      const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
+      const LayerNorm _fp_norm;
+      const Dense _fp_ff;
+      const Wav2Vec2PosConvLayer _pos_conv_embed;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
       const dim_t _num_heads;
       const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
       const LayerNorm _output_norm;
       const Dense _lm_head;
     };
 
   }
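For orientation, the two new layers mirror the forward passes of the corresponding Hugging Face modules (Wav2Vec2LayerNormConvLayer and Wav2Vec2PositionalConvEmbedding). The following is a minimal PyTorch sketch of that computation; it is an assumption about what the C++ implementations compute, not the CTranslate2 code itself:

# Sketch only: the math the new layers declare, in PyTorch terms.
import torch
import torch.nn.functional as F

def layer_norm_conv_layer(x, conv, norm):
    # x: (batch, channels, time)
    x = conv(x)              # strided Conv1D (_conv)
    x = x.transpose(-2, -1)  # _transpose: LayerNorm runs over the channel dim
    x = norm(x)              # _output_norm
    x = x.transpose(-2, -1)
    return F.gelu(x)         # _gelu

def pos_conv_layer(x, conv):
    # x: (batch, time, hidden); conv is the grouped (groups=16) Conv1D
    # that this PR's Conv1D options make possible.
    x = conv(x.transpose(1, 2))
    x = x[:, :, :-1]         # drop one frame (HF does this for even kernel sizes)
    return F.gelu(x).transpose(1, 2)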
45 changes: 38 additions & 7 deletions python/ctranslate2/converters/transformers.py
@@ -992,9 +992,8 @@ def architecture_name(self):
         return "Wav2Vec2ForCTC"
 
     def get_model_spec(self, model):
-        # Wav2Vec2 encoder Wav2Vec2PositionalConvEmbedding conv1d has groups 16
-        # that doesn't look available here so we make Wav2Vec2 encoder layers only
         spec = wav2vec2_spec.Wav2Vec2Spec(
+            model.wav2vec2.config.num_feat_extract_layers,
             model.wav2vec2.encoder.config.num_hidden_layers,
             model.wav2vec2.encoder.config.num_attention_heads,
         )
@@ -1007,9 +1006,7 @@ def get_model_spec(self, model):
             layer.fc1 = layer.feed_forward.intermediate_dense
             layer.fc2 = layer.feed_forward.output_dense
 
-        self.set_encoder(spec.encoder, model.wav2vec2.encoder)
-        self.set_linear(spec.lm_head, model.lm_head)
-        # only for Wav2Vec2Spec.get_vocabulary_size()
+        self.set_encoder(spec.encoder, model, model.wav2vec2.config)
         return spec
 
     def set_config(self, config, model, tokenizer):
@@ -1021,8 +1018,42 @@ def get_vocabulary(self, model, tokenizer):
     def set_vocabulary(self, spec, tokens):
         spec.register_vocabulary(tokens)
 
-    def set_encoder(self, spec, encoder):
-        super().set_encoder(spec, encoder)
+    def set_feature_extractor(self, spec, feature_extractor):
+        spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight
+        spec.feat_layer0.conv.bias = feature_extractor.conv_layers[0].conv.bias
+        self.set_layer_norm(
+            spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm
+        )
+        for spec_layer, module_layer in zip(
+            spec.feat_layer, feature_extractor.conv_layers[1:]
+        ):
+            spec_layer.conv.weight = module_layer.conv.weight
+            spec_layer.conv.bias = module_layer.conv.bias
+            self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm)
+
+    def set_feature_projection(self, spec, feature_projection):
+        self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
+        self.set_linear(spec.fp_projection, feature_projection.projection)
+
+    def set_pos_conv_embed(self, spec, encoder, config):
+        # Force the parameters to materialize (some transformers versions
+        # initialize garbage values); the conv weights are float16, so cast to float32.
+        encoder.pos_conv_embed.conv.weight.data = (
+            encoder.pos_conv_embed.conv.weight.data.float()
+        )
+        encoder.pos_conv_embed.conv.bias.data = encoder.pos_conv_embed.conv.bias.float()
+        for param in encoder.pos_conv_embed.parameters():
+            param.data = param.data.float()
+        encoder.pos_conv_embed(torch.randn((1, 1, config.hidden_size)))
+        spec.pos_conv_embed.conv.weight = encoder.pos_conv_embed.conv.weight
+        spec.pos_conv_embed.conv.bias = encoder.pos_conv_embed.conv.bias
+
+    def set_encoder(self, spec, model, config):
+        self.set_feature_extractor(spec, model.wav2vec2.feature_extractor)
+        self.set_feature_projection(spec, model.wav2vec2.feature_projection)
+        self.set_pos_conv_embed(spec, model.wav2vec2.encoder, config)
+        super().set_encoder(spec, model.wav2vec2.encoder)
+        self.set_linear(spec.lm_head, model.lm_head)
 
     def set_common_layers(self, spec, module):
         self.set_layer_norm(spec.layer_norm, module.layer_norm)
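With the feature extractor, feature projection, and positional conv embedding now handled inside set_encoder, a checkpoint converts end to end in one call. A usage sketch (the checkpoint name and output directory are illustrative):

import ctranslate2

converter = ctranslate2.converters.TransformersConverter(
    "facebook/wav2vec2-large-robust-ft-swbd-300h"  # illustrative checkpoint
)
converter.convert("wav2vec2_ct2", force=True)  # overwrite the directory if it exists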
27 changes: 21 additions & 6 deletions python/ctranslate2/specs/wav2vec2_spec.py
@@ -13,10 +13,9 @@ def __init__(self):
 
 
 class Wav2Vec2Spec(model_spec.LanguageModelSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, feat_layers, num_layers, num_heads):
         super().__init__()
-        self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads)
-        self.lm_head = common_spec.LinearSpec()
+        self.encoder = Wav2Vec2EncoderSpec(feat_layers, num_layers, num_heads)
 
     @property
     def name(self):
@@ -30,14 +29,30 @@ def get_default_config(self):
         return Wav2Vec2Config()
 
     def get_vocabulary_size(self):
-        return self.lm_head.weight.shape[0]
+        return self.encoder.lm_head.weight.shape[0]
 
 
+class Wav2Vec2LayerNormConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
+        self.layer_norm = common_spec.LayerNormSpec()
+
+
+class Wav2Vec2PosEmbedConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
+
+
 class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, feat_layers, num_layers, num_heads):
         self.num_heads = np.dtype("int16").type(num_heads)
-        # wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
+        self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
+        self.feat_layer = [Wav2Vec2LayerNormConvLayer() for i in range(feat_layers - 1)]
+        self.fp_layer_norm = common_spec.LayerNormSpec()
+        self.fp_projection = common_spec.LinearSpec()
+        self.pos_conv_embed = Wav2Vec2PosEmbedConvLayer()
         self.layer_norm = common_spec.LayerNormSpec()
         self.layer = [
             transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
         ]
+        self.lm_head = common_spec.LinearSpec()
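For example, the converter now sizes this spec from the Hugging Face config. For a wav2vec2-large style checkpoint the construction would look like this (the numbers are assumed from the usual large config, not taken from this commit):

from ctranslate2.specs import wav2vec2_spec

# Assumed wav2vec2-large values: 7 feature-extractor conv layers,
# 24 transformer layers, 16 attention heads.
spec = wav2vec2_spec.Wav2Vec2Spec(feat_layers=7, num_layers=24, num_heads=16)
# The encoder spec now owns everything: feat_layer0 plus 6 more conv layers,
# the feature-projection (fp_*) weights, pos_conv_embed, and the lm_head.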
89 changes: 17 additions & 72 deletions python/tests/test_transformers.py
@@ -979,24 +979,16 @@ def test_transformers_wav2vec2(
     )
     output_dir = str(tmp_dir.join("ctranslate2_model"))
     output_dir = converter.convert(output_dir)
-    # 24 x Wav2Vec2EncoderLayerStableLayerNorm converted & saved
 
-    w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(model_name)
-    del w2v2_model.wav2vec2.encoder.layers
-    del w2v2_model.wav2vec2.encoder.layer_norm
-    w2v2_model.save_pretrained(output_dir + "/wav2vec2_partial.bin")
     w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
-    torch.save(w2v2_processor, output_dir + "/wav2vec2_processor.bin")
+    w2v2_processor.save_pretrained(output_dir + "/wav2vec2_processor")
+    processor = transformers.AutoProcessor.from_pretrained(
+        output_dir + "/wav2vec2_processor"
+    )
 
     device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
     cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
-    w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(
-        output_dir + "/wav2vec2_partial.bin"
-    ).to(device)
-    del w2v2_model.wav2vec2.encoder.layers
-    del w2v2_model.wav2vec2.encoder.layer_norm
-    w2v2_processor = torch.load(output_dir + "/wav2vec2_processor.bin")
-    ct2_w2v2_model = ctranslate2.models.Wav2Vec2(
+    model = ctranslate2.models.Wav2Vec2(
         output_dir,
         device=device,
         device_index=[0],
@@ -1008,73 +1000,26 @@ def test_transformers_wav2vec2(
     speech_array = np.load(
         os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy")
    )
-    input_values = w2v2_processor(
+    input_values = processor(
         speech_array,
         padding=True,
         return_tensors="pt",
         sampling_rate=16000,
     ).input_values
 
-    with torch.no_grad():
-        extract_features = w2v2_model.wav2vec2.feature_extractor(
-            input_values.to(w2v2_model.device)
-        ).transpose(1, 2)
-        hidden_states, extract_features = w2v2_model.wav2vec2.feature_projection(
-            extract_features
-        )
-        position_embeddings = w2v2_model.wav2vec2.encoder.pos_conv_embed(
-            hidden_states
-        )
-        hidden_states = position_embeddings + hidden_states
-        # hidden_states = w2v2_model.encoder.dropout(hidden_states)
-        # Dropout(p=0.0, inplace=False) bypassed
-
-    if ct2_w2v2_model.device == "cuda":
-        hidden_states = hidden_states.cpu()
-    else:
-        hidden_states.numpy()
-
-    hidden_states = np.ascontiguousarray(hidden_states)
+    hidden_states = np.ascontiguousarray(input_values.unsqueeze(0))
     hidden_states = ctranslate2.StorageView.from_array(hidden_states)
-    to_cpu = (
-        ct2_w2v2_model.device == "cuda" and len(ct2_w2v2_model.device_index) > 1
-    )
-    ct2_output = ct2_w2v2_model.encode(
-        hidden_states,
-        to_cpu=to_cpu,
-    )  # 24 x Wav2Vec2EncoderLayerStableLayerNorm processed
-    if ct2_w2v2_model.device == "cuda":
-        hidden_states = torch.as_tensor(
-            ct2_output,
-            device=ct2_w2v2_model.device,
-        )
+    to_cpu = model.device == "cuda" and len(model.device_index) > 1
+    output = model.encode(hidden_states, to_cpu=to_cpu)
+    if model.device == "cuda":
+        logits = torch.as_tensor(output, device=model.device)[0]
     else:
-        hidden_states = torch.as_tensor(
-            np.array(ct2_output),
-            dtype=torch.float32,
-            device=ct2_w2v2_model.device,
-        )
-
-    encoder_outputs = transformers.modeling_outputs.BaseModelOutput(
-        last_hidden_state=hidden_states,
-        hidden_states=None,
-        attentions=None,
-    )
-    hidden_states = encoder_outputs[0]
-    outputs = transformers.modeling_outputs.Wav2Vec2BaseModelOutput(
-        last_hidden_state=hidden_states,
-        extract_features=extract_features,
-        hidden_states=encoder_outputs.hidden_states,
-        attentions=encoder_outputs.attentions,
-    )
-    hidden_states = outputs[0]
-    # hidden_states = w2v2_model.dropout(hidden_states)
-    # Dropout(p=0.0, inplace=False) bypassed
-
-    with torch.no_grad():
-        logits = w2v2_model.lm_head(hidden_states.to(torch.float32))[0]
+        logits = torch.as_tensor(
+            np.array(output), dtype=torch.float32, device=model.device
+        )[0]
 
     predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = w2v2_processor.decode(predicted_ids, output_word_offsets=True)
+    transcription = processor.decode(predicted_ids, output_word_offsets=True)
+    transcription = transcription[0].replace(processor.tokenizer.unk_token, "")
 
-    assert transcription[0] == expected_transcription[0]
+    assert transcription == expected_transcription[0]
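Outside the test harness, the same flow is all that transcription now requires, since the converted model consumes raw input_values directly. A minimal sketch under the same assumptions as the test above (checkpoint name and file paths are illustrative):

import ctranslate2
import numpy as np
import torch
import transformers

processor = transformers.AutoProcessor.from_pretrained(
    "facebook/wav2vec2-large-robust-ft-swbd-300h"  # illustrative
)
model = ctranslate2.models.Wav2Vec2("wav2vec2_ct2")  # converted model directory

speech = np.load("mr_quilter.npy")  # 16 kHz mono waveform (illustrative path)
inputs = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
features = ctranslate2.StorageView.from_array(
    np.ascontiguousarray(inputs.unsqueeze(0))
)
logits = torch.as_tensor(np.array(model.encode(features)), dtype=torch.float32)[0]
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids, output_word_offsets=True)
print(transcription[0].replace(processor.tokenizer.unk_token, ""))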