Added support for new model databricks/dbrx-base #82

Open · wants to merge 1 commit into main

Conversation

quic-akuruvil (Contributor) commented:

  • Model changes for databricks/dbrx-base, a KV range-gather-based model (see the sketch after this list)
  • Test case for the above model
  • Skip saving the ONNX file as external data for models larger than 400 GB
  • Changes in the config function to handle multiple models with overlapping HF config.json attributes
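For readers unfamiliar with the term, the sketch below illustrates the general idea behind a KV range gather: cached key/value tensors are read with explicit position indices (for example via torch.gather) rather than by concatenating and slicing, which keeps tensor shapes static for ONNX export. This is the editor's illustration only; the function and variable names are not taken from this PR.

```python
import torch

def gather_cached_kv(past_kv: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # past_kv: [batch, num_kv_heads, max_cache_len, head_dim]
    # position_ids: [batch, seq_len] -- cache positions to read at this step
    batch, num_kv_heads, _, head_dim = past_kv.shape
    # Broadcast the indices to the KV layout so torch.gather can pick the
    # requested cache rows along the sequence axis (dim=2).
    idx = position_ids[:, None, :, None].expand(batch, num_kv_heads, position_ids.shape[1], head_dim)
    return torch.gather(past_kv, dim=2, index=idx)

# Example: read cache positions 0..3 from a dummy cache.
cache = torch.randn(1, 8, 16, 64)
pos = torch.arange(4)[None, :]
print(gather_cached_kv(cache, pos).shape)  # torch.Size([1, 8, 4, 64])
```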

@ochougul (Contributor) left a comment:

Add a test case in tests/transformers/test_transformer_pytorch_transforms.py.
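A minimal sketch of what such a test might look like, assuming the transform exposes a classmethod-style apply(model) as suggested by the KVCacheTransform comment further below; the tiny config values and the assertion are illustrative only, not the pattern actually used in that test file.

```python
from transformers import AutoModelForCausalLM, DbrxConfig

# Assumed import/API: KVCacheTransform is named in a review comment on this PR;
# the apply(model) -> (model, bool) signature is an assumption.
from QEfficient.transformers.pytorch_transforms import KVCacheTransform


def test_kv_cache_transform_dbrx():
    # Tiny randomly initialised Dbrx so the test never downloads real weights.
    config = DbrxConfig(d_model=64, n_heads=4, n_layers=2, max_seq_len=128, vocab_size=128)
    model = AutoModelForCausalLM.from_config(config)

    model, transformed = KVCacheTransform.apply(model)
    assert transformed, "expected Dbrx modules to be replaced by their QEff counterparts"
```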

@@ -114,6 +123,12 @@
GPT2Block: QEffGPT2Block,
GPT2Attention: QEffGPT2Attention,
GPT2LMHeadModel: QEffGPT2LMHeadModel,
# Dbrx model layers

Please add it in QEfficient/transformers/pytorch_transforms.py::KVCacheTransform too.
We will be deprecating this after the 1.18 release.
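A minimal sketch of the requested registration, with the understanding that the QEff-side class names and module path below (QEffDbrxAttention and friends) are assumptions for illustration and may not match what this PR actually introduces:

```python
from transformers.models.dbrx.modeling_dbrx import (
    DbrxAttention,
    DbrxBlock,
    DbrxForCausalLM,
    DbrxModel,
)

# Hypothetical QEfficient counterparts added by this PR; the module path is assumed.
from QEfficient.transformers.models.dbrx.modeling_dbrx import (
    QEffDbrxAttention,
    QEffDbrxBlock,
    QEffDbrxForCausalLM,
    QEffDbrxModel,
)

# Entries that would need to be registered both in KVCacheTransform (per the
# comment above) and in the module-mapping dict shown in the earlier diff.
DBRX_KV_CACHE_MODULE_MAPPING = {
    DbrxAttention: QEffDbrxAttention,
    DbrxBlock: QEffDbrxBlock,
    DbrxModel: QEffDbrxModel,
    DbrxForCausalLM: QEffDbrxForCausalLM,
}
```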


from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask

DBRX_ATTENTION_CLASSES = {

Not used anywhere?

@@ -219,7 +219,16 @@ def get_padding_shape_from_config(config, batch_size, seq_len):
): # Check for num_key_value_heads (Llama/Mistral)
n_heads = config.num_key_value_heads
d_head = config.hidden_size // config.num_attention_heads
elif hasattr(config, "n_heads"): # Check for n_heads and d_model in the config (MPT Model)
elif (

Isn't this condition the same as lines 231-233?
We can move those lines here and remove lines 223-226?

@quic-akuruvil (Contributor, Author) replied on Aug 14, 2024:

Actually no, the MPT model needs more specific parameters, because simply testing for n_heads lets dbrx satisfy the condition as well. MPT and dbrx have similar config.json attributes, so the current check would fail and the wrong config would be set for the model.
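To make the overlap concrete (editor's illustration, not the exact code in this PR): both the MPT and DBRX configs expose n_heads, d_model, and a nested attn_config, so testing n_heads alone cannot tell them apart; one way to disambiguate is to key off model_type or a DBRX-specific attribute first.

```python
def get_kv_heads_and_head_dim(config):
    """Illustrative only: separate models whose config.json attributes overlap."""
    if getattr(config, "model_type", None) == "dbrx":
        # DbrxConfig: d_model / n_heads, with attn_config.kv_n_heads for grouped KV heads.
        n_heads = config.attn_config.kv_n_heads
        d_head = config.d_model // config.n_heads
    elif hasattr(config, "n_heads") and hasattr(config, "d_model"):
        # MPT-style config: plain multi-head attention.
        n_heads = config.n_heads
        d_head = config.d_model // config.n_heads
    else:
        raise NotImplementedError(f"Unsupported config: {type(config).__name__}")
    return n_heads, d_head

# Example (gated model; requires access and network):
# from transformers import AutoConfig
# cfg = AutoConfig.from_pretrained("databricks/dbrx-base", trust_remote_code=True)
# print(get_kv_heads_and_head_dim(cfg))
```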

@ochougul added the enhancement (New feature or request) and model-enablement labels on Aug 14, 2024
Comment on lines +92 to +115
# Save model to single weight file
params = sum(p.numel() for p in pt_model.parameters())
model_size = math.ceil((params * 4) / Constants.GB)
if model_size < 380:
    info("ONNX model uses external data. Saving external data as single weight file.")
    loaded_model = onnx.load(f"{gen_models_path}_tmp/{model_base_name}.onnx")
    os.makedirs(f"{gen_models_path}", exist_ok=True)
    shutil.rmtree(f"{gen_models_path}_tmp")
    info("Clearing files .. ")
    onnx.save_model(
        loaded_model,
        os.path.join(gen_models_path, f"{model_base_name}.onnx"),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{model_base_name}.onnxweights.data",
        size_threshold=1024,
        convert_attribute=False,
    )
    onnx.checker.check_model(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
else:
    info("Skip saving external data as a single file.")
    if os.path.exists(f"{gen_models_path}"):
        shutil.rmtree(f"{gen_models_path}")
    shutil.move(f"{gen_models_path}_tmp", f"{gen_models_path}")
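For context on the 380 threshold (editor's back-of-the-envelope estimate, assuming Constants.GB counts bytes per GiB and roughly 132B parameters for dbrx-base): the fp32 size estimate computed above lands well past the cutoff, so the single-file consolidation is skipped for this model.

```python
import math

GB = 2**30  # assumption: Constants.GB counts bytes per GiB

params = 132_000_000_000           # approximate parameter count of dbrx-base
model_size = math.ceil((params * 4) / GB)
print(model_size)                  # ~492
print(model_size < 380)            # False -> keep the per-tensor external data as-is
```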

This can be done with SplitTensorsTransform, which is now merged into main. Has this change been tested with other models as well?

@quic-akuruvil added the wip (Work in progress) label on Nov 6, 2024
@quic-rishinr self-requested a review as a code owner on January 10, 2025 at 07:08
Labels: enhancement (New feature or request), model-enablement, wip (Work in progress)
3 participants