diff --git a/QEfficient/exporter/export_utils.py b/QEfficient/exporter/export_utils.py
index 8c33bc6ca..15b6df2f8 100644
--- a/QEfficient/exporter/export_utils.py
+++ b/QEfficient/exporter/export_utils.py
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
+import math
 import os
 import shutil
 import sys
@@ -18,6 +19,7 @@
 from onnx import external_data_helper
 
 from QEfficient.base.onnx_transforms import FP16Clip
+from QEfficient.utils.constants import Constants
 
 
 def export_onnx(
@@ -86,27 +88,31 @@ def export_onnx(
         raise RuntimeError("Exporting to ONNX failed. {}".format(e))
 
     onnx.checker.check_model(f"{gen_models_path}_tmp/{model_base_name}.onnx")
-    loaded_model = onnx.load(f"{gen_models_path}_tmp/{model_base_name}.onnx")
-    shutil.rmtree(f"{gen_models_path}_tmp")
-    os.makedirs(f"{gen_models_path}", exist_ok=True)
-    info("Clearing files .. ")
-
-    # Check if model uses external data format to save the weight tensors
-    # model_uses_external_data = check_model_uses_external_data(loaded_model)
-    # if model_uses_external_data:
-    # Save model to single weight file
-    info("ONNX model uses external data. Saving as external data.")
-    onnx.save_model(
-        loaded_model,
-        os.path.join(gen_models_path, f"{model_base_name}.onnx"),
-        save_as_external_data=True,
-        all_tensors_to_one_file=True,
-        location=f"{model_base_name}.onnxweights.data",
-        size_threshold=1024,
-        convert_attribute=False,
-    )
-    onnx.checker.check_model(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
 
+    # Save model weights: models below the size threshold are re-saved with all external data
+    # consolidated into a single weight file; larger models keep the exported files as-is.
+    params = sum(p.numel() for p in pt_model.parameters())
+    model_size = math.ceil((params * 4) / Constants.GB)
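+    # Rough size estimate in GB, assuming fp32 weights (4 bytes per parameter).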
+    if model_size < 380:
+        info("ONNX model uses external data. Saving external data as single weight file.")
+        loaded_model = onnx.load(f"{gen_models_path}_tmp/{model_base_name}.onnx")
+        os.makedirs(f"{gen_models_path}", exist_ok=True)
+        shutil.rmtree(f"{gen_models_path}_tmp")
+        info("Clearing files .. ")
+        onnx.save_model(
+            loaded_model,
+            os.path.join(gen_models_path, f"{model_base_name}.onnx"),
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location=f"{model_base_name}.onnxweights.data",
+            size_threshold=1024,
+            convert_attribute=False,
+        )
+        onnx.checker.check_model(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
+    else:
+        info("Skip saving external data as a single file.")
+        if os.path.exists(f"{gen_models_path}"):
+            shutil.rmtree(f"{gen_models_path}")
+        shutil.move(f"{gen_models_path}_tmp", f"{gen_models_path}")
     # Run shape inference in intial model itself
     onnx.shape_inference.infer_shapes_path(
         os.path.join(gen_models_path, f"{model_base_name}.onnx"),
diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py
index 2806fd452..94ded803b 100644
--- a/QEfficient/transformers/modeling_utils.py
+++ b/QEfficient/transformers/modeling_utils.py
@@ -14,6 +14,7 @@
     CodeGenForCausalLM,
     CodeGenModel,
 )
+from transformers.models.dbrx.modeling_dbrx import DbrxAttention, DbrxExperts, DbrxForCausalLM, DbrxModel, DbrxRouter
 from transformers.models.falcon.modeling_falcon import (
     FalconAttention,
     FalconForCausalLM,
@@ -52,6 +53,13 @@
     QEffCodeGenForCausalLM,
     QEffCodeGenModel,
 )
+from .models.dbrx.modeling_dbrx import (
+    QEffDbrxAttention,
+    QEffDbrxExperts,
+    QEffDbrxForCausalLM,
+    QEffDbrxModel,
+    QEffDbrxRouter,
+)
 from .models.falcon.modeling_falcon import (
     QEffFalconAttention,
     QEffFalconForCausalLM,
@@ -103,6 +111,7 @@
         FalconForCausalLM.__name__,
         Qwen2ForCausalLM.__name__,
         Starcoder2ForCausalLM.__name__,
+        DbrxForCausalLM.__name__,
     ]
 )
 
@@ -114,6 +123,12 @@
     GPT2Block: QEffGPT2Block,
     GPT2Attention: QEffGPT2Attention,
     GPT2LMHeadModel: QEffGPT2LMHeadModel,
+    # Dbrx model layers
+    DbrxAttention: QEffDbrxAttention,
+    DbrxRouter: QEffDbrxRouter,
+    DbrxExperts: QEffDbrxExperts,
+    DbrxModel: QEffDbrxModel,
+    DbrxForCausalLM: QEffDbrxForCausalLM,
     # GPTJ model layers
     GPTJModel: QEffGPTJModel,
     GPTJAttention: QEffGPTJAttention,
diff --git a/QEfficient/transformers/models/dbrx/__init__.py b/QEfficient/transformers/models/dbrx/__init__.py
new file mode 100755
index 000000000..91fee0a49
--- /dev/null
+++ b/QEfficient/transformers/models/dbrx/__init__.py
@@ -0,0 +1,7 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c)  2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
diff --git a/QEfficient/transformers/models/dbrx/modeling_dbrx.py b/QEfficient/transformers/models/dbrx/modeling_dbrx.py
new file mode 100755
index 000000000..ad7b87636
--- /dev/null
+++ b/QEfficient/transformers/models/dbrx/modeling_dbrx.py
@@ -0,0 +1,434 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""PyTorch Dbrx model."""
+
+import math
+from typing import Any, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+)
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from transformers.models.dbrx.modeling_dbrx import (
+    DbrxAttention,
+    DbrxExperts,
+    DbrxForCausalLM,
+    DbrxModel,
+    DbrxRouter,
+    apply_rotary_pos_emb,
+    load_balancing_loss_func,
+    logger,
+    repeat_kv,
+)
+
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+
+DBRX_ATTENTION_CLASSES = {
+    "eager": DbrxAttention,
+}
+
+
+class QEffDbrxAttention(DbrxAttention):
+    """Multi-head self attention."""
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Any,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.Wqkv(hidden_states)
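+        # Clamp QKV to [-clip_qkv, clip_qkv] without an explicit branch; this assumes
+        # self.clip_qkv is set (as in DBRX configs), since clamp() needs at least one bound.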
+        min_val = -self.clip_qkv if self.clip_qkv is not None else None
+        max_val = self.clip_qkv
+        qkv_states = qkv_states.clamp(min=min_val, max=max_val)
+
+        query_states, key_states, value_states = qkv_states.split(
+            [
+                self.hidden_size,
+                self.num_key_value_heads * self.head_dim,
+                self.num_key_value_heads * self.head_dim,
+            ],
+            dim=2,
+        )
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        kv_seq_len = key_states.shape[-2]
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        if past_key_value is not None:
+            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.block_idx)
+            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
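+        # Masked positions are set to a large negative constant (-10000.0) rather than -inf,
+        # a common fp16-friendly masking value.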
+        if attention_mask is not None:
+            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_pdrop, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                + f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class QEffDbrxRouter(DbrxRouter):
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+        if self.training and self.moe_jitter_eps is not None:
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps
+            )
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32)
+        top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
+
+        # Normalize top_weights by its L1 norm (sum of absolute values) so the selected
+        # expert weights sum to one.
+        top_weights_scale = torch.sum(torch.abs(top_weights), dim=-1, keepdim=True)
+        top_weights = top_weights / top_weights_scale
+
+        weights = weights.to(hidden_states.dtype)
+        top_weights = top_weights.to(hidden_states.dtype)
+        return weights, top_weights, top_experts
+
+
+class QEffDbrxExperts(DbrxExperts):
+    def forward(
+        self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
+    ) -> torch.Tensor:
+        bsz, q_len, hidden_size = x.shape
+        x = x.view(-1, hidden_size)
+        out = torch.zeros_like(x)
+
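+        # Every token is processed by every expert; each expert's output is scaled by the routing
+        # weight and zeroed via torch.where for tokens not routed to it, keeping tensor shapes static.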
+        expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+        # Chunk experts at once to avoid storing full parameter multiple times in autograd
+        w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+        v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+        w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
+        for expert_idx in range(0, self.moe_num_experts):
+            expert_mask_tr = expert_mask[expert_idx].transpose(0, 1)
+            expert_out = (
+                self.mlp(x, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
+                * (top_weights * expert_mask_tr).sum(1)[:, None]
+            )
+            expert_out = torch.where(
+                (top_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None], expert_out, torch.tensor(0.0)
+            )
+            out = out + expert_out
+        out = out.reshape(bsz, q_len, hidden_size)
+        return out
+
+
+class QEffDbrxModel(DbrxModel):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, MoeModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
+            use_cache = False
+        if input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        else:
+            batch_size, seq_length = inputs_embeds.shape[:2]
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        past_key_values_length = 0
+
+        inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
+
+        past_seen_tokens = 0
+        if use_cache:  # kept for BC (cache positions)
+            if not isinstance(past_key_values, StaticCache):
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()
+                past_key_values_length = past_key_values.get_usable_length(seq_length)
+        if cache_position is None:
+            if isinstance(past_key_values, StaticCache):
+                raise ValueError("cache_position is a required argument when using StaticCache.")
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+        if attention_mask is None:
+            # Build the causal mask from position_ids (no sliding window for DBRX)
+            # Change for Cloud AI 100 (vbaddi)
+            attention_mask = _create_causal_mask(
+                position_ids=position_ids,
+                target_length=past_key_values_length,
+                sliding_window=None,
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=None,
+            )
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_router_logits = () if output_router_logits else None
+        next_decoder_cache = None
+
+        for block in self.blocks:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                block_outputs = self._gradient_checkpointing_func(
+                    block.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    output_router_logits,
+                    use_cache,
+                    cache_position,
+                )
+            else:
+                block_outputs = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    output_router_logits=output_router_logits,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                )
+
+            hidden_states = block_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = block_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (block_outputs[1],)
+
+            if output_router_logits:
+                all_router_logits += (block_outputs[-1],)
+
+        hidden_states = self.norm_f(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = None
+        if use_cache:
+            next_cache = (
+                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+            )
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                if v is not None
+            )
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            router_logits=all_router_logits,
+        )
+
+
+class QEffDbrxForCausalLM(DbrxForCausalLM):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+        r"""Forward function for causal language modeling.
+
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >> from transformers import AutoTokenizer, DbrxForCausalLM
+
+        >> model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct")
+        >> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
+
+        >> prompt = "Hey, are you conscious? Can you talk to me?"
+        >> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >> # Generate
+        >> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.transformer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
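+        # Gather only the hidden state at the last valid position (argmax of position_ids) per
+        # sequence before the LM head, so logits are produced for a single token per batch entry.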
+        # Cast to int32 to avoid ONNXRT issue
+        logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True)
+        hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        aux_loss = None
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+            if labels is not None and loss is not None:
+                loss += self.moe_loss_weight * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            if output_router_logits:
+                output = (aux_loss,) + output
+            return (loss,) + output if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index d3283ece1..b17ef769a 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -219,7 +219,16 @@ def get_padding_shape_from_config(config, batch_size, seq_len):
     ):  # Check for num_key_value_heads (Llama/Mistral)
         n_heads = config.num_key_value_heads
         d_head = config.hidden_size // config.num_attention_heads
-    elif hasattr(config, "n_heads"):  # Check for n_heads and d_model in the config (MPT Model)
+    elif (
+        hasattr(config, "auto_map") and config.auto_map["AutoModelForCausalLM"] == "modeling_mpt.MPTForCausalLM"
+    ):  # check for MPT
+        n_heads = config.n_heads
+        d_head = config.d_model // config.n_heads
+    elif hasattr(config, "ffn_config") and config.ffn_config.moe_top_k:  # Check for Dbrx
+        if config.attn_config.kv_n_heads is not None:
+            n_heads = config.attn_config.kv_n_heads
+            d_head = config.d_model // config.n_heads
+    elif hasattr(config, "n_heads"):  # Check for n_heads and d_model in the config
         n_heads = config.n_heads
         d_head = config.d_model // config.n_heads
     elif hasattr(config, "multi_query"):  # Check for Falcon
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index 313e666f4..8df50ca77 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -23,6 +23,7 @@
     "wtang06/mpt-125m-c4",
     "hakurei/gpt-j-random-tinier",
     "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "databricks/dbrx-base",
 ]