
Commit

fix the formats
hkwon committed Aug 13, 2024
1 parent 2d4670f commit 85654ec
Showing 3 changed files with 16 additions and 15 deletions.
11 changes: 7 additions & 4 deletions python/ctranslate2/converters/transformers.py
@@ -1022,10 +1022,11 @@ def set_feature_extractor(self, spec, feature_extractor):
         spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight
         spec.feat_layer0.conv.bias = feature_extractor.conv_layers[0].conv.bias
         self.set_layer_norm(
-            spec.feat_layer0.layer_norm,
-            feature_extractor.conv_layers[0].layer_norm
+            spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm
         )
-        for spec_layer, module_layer in zip(spec.feat_layer, feature_extractor.conv_layers[1:]):
+        for spec_layer, module_layer in zip(
+            spec.feat_layer, feature_extractor.conv_layers[1:]
+        ):
             spec_layer.conv.weight = module_layer.conv.weight
             spec_layer.conv.bias = module_layer.conv.bias
             self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm)
@@ -1037,7 +1038,9 @@ def set_feature_projection(self, spec, feature_projection):
     def set_pos_conv_embed(self, spec, encoder, config):
         # force the parameters because some transformers versions initialize them with garbage values
         # the conv parameters are float16, so cast them to float32 for loading
-        encoder.pos_conv_embed.conv.weight.data = encoder.pos_conv_embed.conv.weight.data.float()
+        encoder.pos_conv_embed.conv.weight.data = (
+            encoder.pos_conv_embed.conv.weight.data.float()
+        )
         encoder.pos_conv_embed.conv.bias.data = encoder.pos_conv_embed.conv.bias.float()
         for param in encoder.pos_conv_embed.parameters():
             param.data = param.data.float()
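
For reference, the cast-before-load pattern in set_pos_conv_embed can be reproduced in isolation. Below is a minimal sketch against a generic torch.nn.Conv1d; the module is a hypothetical stand-in, not part of this commit:

import torch

# Start from float16 parameters, as the pos_conv_embed weights may be.
conv = torch.nn.Conv1d(16, 16, kernel_size=3).half()

# Force every parameter back to float32 before loading, mirroring the loop above.
for param in conv.parameters():
    param.data = param.data.float()

assert all(p.dtype == torch.float32 for p in conv.parameters())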
10 changes: 2 additions & 8 deletions python/ctranslate2/specs/wav2vec2_spec.py
@@ -16,11 +16,7 @@ def __init__(self):
 class Wav2Vec2Spec(model_spec.LanguageModelSpec):
     def __init__(self, feat_layers, num_layers, num_heads):
         super().__init__()
-        self.encoder = Wav2Vec2EncoderSpec(
-            feat_layers,
-            num_layers,
-            num_heads
-        )
+        self.encoder = Wav2Vec2EncoderSpec(feat_layers, num_layers, num_heads)
 
     @property
     def name(self):
@@ -52,9 +48,7 @@ class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
     def __init__(self, feat_layers, num_layers, num_heads):
         self.num_heads = np.dtype("int16").type(num_heads)
         self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
-        self.feat_layer = [
-            Wav2Vec2LayerNormConvLayer() for i in range(feat_layers - 1)
-        ]
+        self.feat_layer = [Wav2Vec2LayerNormConvLayer() for i in range(feat_layers - 1)]
         self.fp_layer_norm = common_spec.LayerNormSpec()
         self.fp_projection = common_spec.LinearSpec()
         self.pos_conv_embed = Wav2Vec2PosEmbedConvLayer()
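
As a usage sketch, the simplified constructor now fits on one line. The hyperparameter values below are illustrative assumptions (roughly a wav2vec2-base configuration), not values taken from this commit:

from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec

# Assumed values: 7 conv feature layers, 12 encoder layers, 12 attention heads.
spec = Wav2Vec2Spec(feat_layers=7, num_layers=12, num_heads=12)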
10 changes: 7 additions & 3 deletions python/tests/test_transformers.py
@@ -982,7 +982,9 @@ def test_transformers_wav2vec2(
 
         w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
         w2v2_processor.save_pretrained(output_dir + "/wav2vec2_processor")
-        processor = transformers.AutoProcessor.from_pretrained(output_dir + "/wav2vec2_processor")
+        processor = transformers.AutoProcessor.from_pretrained(
+            output_dir + "/wav2vec2_processor"
+        )
 
         device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
         cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
@@ -1007,12 +1009,14 @@
 
         hidden_states = np.ascontiguousarray(input_values.unsqueeze(0))
         hidden_states = ctranslate2.StorageView.from_array(hidden_states)
-        to_cpu = (model.device == "cuda" and len(model.device_index) > 1)
+        to_cpu = model.device == "cuda" and len(model.device_index) > 1
         output = model.encode(hidden_states, to_cpu=to_cpu)
         if model.device == "cuda":
             logits = torch.as_tensor(output, device=model.device)[0]
         else:
-            logits = torch.as_tensor(np.array(output), dtype=torch.float32, device=model.device)[0]
+            logits = torch.as_tensor(
+                np.array(output), dtype=torch.float32, device=model.device
+            )[0]
 
         predicted_ids = torch.argmax(logits, dim=-1)
         transcription = processor.decode(predicted_ids, output_word_offsets=True)
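
The if/else above reflects how a CTranslate2 StorageView is consumed: on CUDA it can be wrapped directly by torch.as_tensor, while on CPU it is first materialized through NumPy. A hedged sketch of that conversion as a standalone helper (the function name is hypothetical):

import numpy as np
import torch

def storage_view_to_tensor(output, device):
    # On CUDA, torch.as_tensor can wrap the StorageView without copying.
    if device == "cuda":
        return torch.as_tensor(output, device=device)
    # On CPU, materialize via NumPy first, then wrap as a float32 tensor.
    return torch.as_tensor(np.array(output), dtype=torch.float32, device=device)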
