Added support for new model databricks/dbrx-base #82
base: main
Conversation
Signed-off-by: Ann <[email protected]>
Please add a test case in tests/transformers/test_transformer_pytorch_transforms.py.
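For illustration, a rough sketch of what such a test could look like. The QEffDbrxAttention class name, the KVCacheTransform.apply() signature, and the tiny DbrxConfig values are assumptions inferred from this conversation, not the repository's actual test fixtures:

```python
# Rough sketch only: the QEffDbrxAttention name and the KVCacheTransform.apply()
# signature are assumptions; adapt to the project's real test utilities.
from transformers import AutoModelForCausalLM
from transformers.models.dbrx.configuration_dbrx import DbrxConfig

from QEfficient.transformers.pytorch_transforms import KVCacheTransform  # assumed import path


def test_dbrx_kv_cache_transform_swaps_attention():
    # Tiny DBRX config so the test model builds quickly; "eager" keeps plain DbrxAttention modules.
    config = DbrxConfig(d_model=64, n_heads=4, n_layers=2, max_seq_len=64, vocab_size=128)
    model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")

    model, transformed = KVCacheTransform.apply(model)  # assumed classmethod returning (model, bool)
    assert transformed

    # After the transform, the attention modules should be the QEff replacements.
    attn_class_names = {type(m).__name__ for m in model.modules() if "Attention" in type(m).__name__}
    assert "QEffDbrxAttention" in attn_class_names  # assumed QEff class name
```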
```diff
@@ -114,6 +123,12 @@
     GPT2Block: QEffGPT2Block,
     GPT2Attention: QEffGPT2Attention,
     GPT2LMHeadModel: QEffGPT2LMHeadModel,
+    # Dbrx model layers
```
Please add this in QEfficient/transformers/pytorch_transforms.py::KVCacheTransform too. We will be deprecating this after the 1.18 release.
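A minimal sketch of the requested registration, assuming KVCacheTransform keeps a class-to-class _module_mapping dict and that the QEff DBRX wrappers follow the naming pattern of the GPT2 entries in the hunk above; the module path and class names are assumptions, not part of this PR. In practice the entries would be added directly to the mapping inside pytorch_transforms.py; the update() call below only keeps the snippet self-contained:

```python
# Sketch only: the QEffDbrx* classes, their module path, and the _module_mapping
# attribute are assumptions mirroring the GPT2Block -> QEffGPT2Block pattern above.
from transformers.models.dbrx.modeling_dbrx import DbrxAttention, DbrxBlock, DbrxForCausalLM

from QEfficient.transformers.models.dbrx.modeling_dbrx import (  # hypothetical module path
    QEffDbrxAttention,
    QEffDbrxBlock,
    QEffDbrxForCausalLM,
)
from QEfficient.transformers.pytorch_transforms import KVCacheTransform  # assumed import path

# Register the DBRX layers alongside the existing Llama/Mistral/GPT2 entries.
KVCacheTransform._module_mapping.update(
    {
        # Dbrx model layers
        DbrxAttention: QEffDbrxAttention,
        DbrxBlock: QEffDbrxBlock,
        DbrxForCausalLM: QEffDbrxForCausalLM,
    }
)
```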
```python
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask

DBRX_ATTENTION_CLASSES = {
```
Not used anywhere?
```diff
@@ -219,7 +219,16 @@ def get_padding_shape_from_config(config, batch_size, seq_len):
     ):  # Check for num_key_value_heads (Llama/Mistral)
         n_heads = config.num_key_value_heads
         d_head = config.hidden_size // config.num_attention_heads
-    elif hasattr(config, "n_heads"):  # Check for n_heads and d_model in the config (MPT Model)
+    elif (
```
Isn't this condition the same as lines 231-233? Couldn't we move those lines here and remove lines 223-226?
Actually no, the MPT model needs a more specific check here: simply testing for n_heads lets DBRX satisfy the condition too, because MPT and DBRX have similar config.json file attributes. With only that check, the wrong configuration would be picked for the model.
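To illustrate the point, a sketch of one way to disambiguate the two configs. The attribute names come from the Hugging Face MptConfig and DbrxConfig (DBRX exposes ffn_config and attn_config.kv_n_heads in addition to n_heads/d_model, MPT does not have ffn_config); this is not necessarily the exact condition used in the PR:

```python
# Sketch only: shows why hasattr(config, "n_heads") alone is ambiguous between MPT and
# DBRX, and one way to tighten the check. Not necessarily the PR's exact condition.
def get_heads_and_head_dim(config):
    if hasattr(config, "ffn_config") and hasattr(config, "attn_config"):
        # DBRX: has n_heads/d_model like MPT, but also an MoE ffn_config block
        # and attn_config.kv_n_heads for the KV cache heads.
        n_heads = config.attn_config.kv_n_heads
        d_head = config.d_model // config.n_heads
    elif hasattr(config, "n_heads") and hasattr(config, "d_model"):
        # MPT: plain n_heads/d_model, no ffn_config block.
        n_heads = config.n_heads
        d_head = config.d_model // config.n_heads
    else:
        raise ValueError("Unsupported model config")
    return n_heads, d_head
```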
```python
# Save model to single weight file
params = sum(p.numel() for p in pt_model.parameters())
model_size = math.ceil((params * 4) / Constants.GB)  # fp32 size estimate in GB
if model_size < 380:
    info("ONNX model uses external data. Saving external data as single weight file.")
    loaded_model = onnx.load(f"{gen_models_path}_tmp/{model_base_name}.onnx")
    os.makedirs(f"{gen_models_path}", exist_ok=True)
    shutil.rmtree(f"{gen_models_path}_tmp")
    info("Clearing files .. ")
    onnx.save_model(
        loaded_model,
        os.path.join(gen_models_path, f"{model_base_name}.onnx"),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{model_base_name}.onnxweights.data",
        size_threshold=1024,
        convert_attribute=False,
    )
    onnx.checker.check_model(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
else:
    info("Skip saving external data as a single file.")
    if os.path.exists(f"{gen_models_path}"):
        shutil.rmtree(f"{gen_models_path}")
    shutil.move(f"{gen_models_path}_tmp", f"{gen_models_path}")
```
This can be done with SplitTensorsTransform, which is now merged into main. Has this change been tested with other models as well?