Commit
Update phi3 transformer code and add test
Hải Trường authored and committed on Jan 20, 2025
1 parent 01baaca, commit b9ba243
Showing 3 changed files with 230 additions and 8 deletions.
tests/transformers_tests/models/phi3/test_modeling_phi3.py (230 additions, 0 deletions)
@@ -0,0 +1,230 @@
# This module contains test cases that are defined in the `.test_cases.py` file, structured as lists or tuples like
# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map].
#
# Each defined case corresponds to a pair consisting of PyTorch and MindSpore modules, including their respective
# initialization parameters and inputs for the forward pass. The testing framework adopted here is designed to
# generically parse these parameters to assess and compare the precision of forward outcomes between the two
# frameworks.
#
# In cases where models have unique initialization procedures or require testing with specialized output formats,
# it is necessary to develop distinct, dedicated test cases.

import inspect
import logging

import numpy as np
import pytest
import torch
from transformers import Phi3Config

import mindspore as ms

from tests.modeling_test_utils import (
    MS_DTYPE_MAPPING,
    PT_DTYPE_MAPPING,
    compute_diffs,
    generalized_parse_args,
    get_modules,
)
from tests.transformers_tests.models.modeling_common import ids_numpy

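# Per-dtype tolerances used when comparing PyTorch and MindSpore forward outputs, and the MindSpore
# execution modes to cover: 0 = graph mode (ms.GRAPH_MODE), 1 = PyNative mode (ms.PYNATIVE_MODE).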
DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-3}
MODES = [0, 1]

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Phi3ModelTester:
    def __init__(
        self,
        batch_size=13,
        seq_length=7,
        is_training=False,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.scope = scope
        self.head_dim = self.hidden_size // self.num_attention_heads

    # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
    def prepare_config_and_inputs(self):
        input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
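            # 2D attention mask built from the lower triangle of a (batch_size, seq_length) matrix of ones,
            # so batch rows end up with varying numbers of valid (unmasked) positions.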
            input_mask = np.tril(np.ones_like(input_ids))

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_numpy([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_numpy([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_numpy([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_numpy([self.batch_size], self.num_choices)

        config = self.get_config()
        # logger.info(f"config: {config}")
        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        return Phi3Config(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            use_cache=False,
        )


model_tester = Phi3ModelTester()
(
    config,
    input_ids,
    token_type_ids,
    input_mask,
    sequence_labels,
    token_labels,
    choice_labels,
) = model_tester.prepare_config_and_inputs()


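# Each case follows the structure described at the top of this file. `outputs_map` maps a named field of the
# PyTorch model output (here `last_hidden_state`) to the index of the corresponding entry in the MindSpore
# output tuple so the two can be compared element-wise.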
PHI3_CASES = [
    [
        "Phi3Model",
        "transformers.Phi3Model",
        "mindone.transformers.Phi3Model",
        (config,),
        {},
        (input_ids,),
        {
            "attention_mask": input_mask,
        },
        {
            "last_hidden_state": 0,
        },
    ],
]


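# For every (case, dtype, mode) combination, the test instantiates the PyTorch and MindSpore modules from the
# same config, runs a single forward pass on identical inputs, and asserts that the outputs agree within the
# dtype-specific threshold.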
# requires transformers >= 4.41.2
@pytest.mark.parametrize(
    "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode",
    [
        case
        + [
            dtype,
        ]
        + [
            mode,
        ]
        for case in PHI3_CASES
        for dtype in DTYPE_AND_THRESHOLDS.keys()
        for mode in MODES
    ],
)
def test_named_modules(
    name,
    pt_module,
    ms_module,
    init_args,
    init_kwargs,
    inputs_args,
    inputs_kwargs,
    outputs_map,
    dtype,
    mode,
):
    ms.set_context(mode=mode)

    (
        pt_model,
        ms_model,
        pt_dtype,
        ms_dtype,
    ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs)
    pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args(
        pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs
    )

    # set `hidden_dtype` if required; some modules always compute in float precision
    # and need a specific `hidden_dtype` to cast to before returning
if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: | ||
pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) | ||
ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) | ||
with torch.no_grad(): | ||
pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) | ||
ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) | ||
# logger.info(f"ms:{ms_outputs}") | ||
# logger.info(f"pt:{pt_outputs}" ) | ||
if outputs_map: | ||
pt_outputs_n = [] | ||
ms_outputs_n = [] | ||
for pt_key, ms_idx in outputs_map.items(): | ||
pt_output = getattr(pt_outputs, pt_key) | ||
ms_output = ms_outputs[ms_idx] | ||
if isinstance(pt_output, (list, tuple)): | ||
pt_outputs_n += list(pt_output) | ||
ms_outputs_n += list(ms_output) | ||
else: | ||
pt_outputs_n.append(pt_output) | ||
ms_outputs_n.append(ms_output) | ||
diffs = compute_diffs(pt_outputs_n, ms_outputs_n) | ||
else: | ||
diffs = compute_diffs(pt_outputs, ms_outputs) | ||
logger.info(f"Differences: {diffs}") | ||
THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] | ||
assert (np.array(diffs) < THRESHOLD).all(), ( | ||
f"ms_dtype: {ms_dtype}, pt_type: {pt_dtype}, " | ||
f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" | ||
) |
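
The parametrized test can be run with pytest, typically from the repository root so that the `tests` package is importable, for example:

pytest tests/transformers_tests/models/phi3/test_modeling_phi3.py -k "fp32"

The -k "fp32" filter restricts the run to the fp32 cases while still covering both MindSpore execution modes; dropping it runs all dtype and mode combinations.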