From b2459cea896c8442999d1881d59d62277f6168ea Mon Sep 17 00:00:00 2001
From: Zhongkai Zhang <zhzhang@habana.ai>
Date: Thu, 9 Jan 2025 23:24:53 +0000
Subject: [PATCH 1/3] Add Mamba model using a custom kernel to improve performance

---
 README.md                                     |   1 +
 docs/source/index.mdx                         |   1 +
 examples/text-generation/README.md            |  12 +
 optimum/habana/transformers/modeling_utils.py |   4 +
 .../habana/transformers/models/__init__.py    |   2 +
 .../transformers/models/mamba/__init__.py     |   2 +
 .../models/mamba/modeling_mamba.py            | 231 +++++++++++++++++-
 .../transformers/models/mamba/util_mamba.py   |  29 +++
 tests/test_text_generation_example.py         |  30 ++-
 9 files changed, 305 insertions(+), 7 deletions(-)
 create mode 100644 optimum/habana/transformers/models/mamba/util_mamba.py

diff --git a/README.md b/README.md
index e44ca5430c..ccc7cb0433 100644
--- a/README.md
+++ b/README.md
@@ -258,6 +258,7 @@ The following model architectures, tasks and device distributions have been vali
 | Baichuan2 | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 | DeepSeek-V2 |   | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 | ChatGLM | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
+| Mamba |   | <div style="text-align:left"><li>Single card</li></div> |  <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 </div>
 
 - Diffusers:
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 51d6dadf0f..c7233e464f 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -109,6 +109,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
 | Baichuan2 | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 | DeepSeek-V2 |   | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 | ChatGLM     | <div style="text-align:left"><li>DeepSpeed</li></div> |  <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
+| Mamba |   | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 
 - Diffusers
 
diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 7767443c6e..f6e03f5a6e 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -219,6 +219,18 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
 > --sdp_on_bf16
 > ```
 
+To run Mamba-130m inference on a single Gaudi2 card, use the following command. It assumes the custom kernel is at the default path `/root/.cache/huggingface/hub/models--Habana--mamba/blobs/libcustom_tpc_perf_lib.so`; if `libcustom_tpc_perf_lib.so` is in a different folder, set `GC_KERNEL_PATH` accordingly:
+```bash
+GC_KERNEL_PATH=/root/.cache/huggingface/hub/models--Habana--mamba/blobs/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH python run_generation.py \
+--model_name_or_path state-spaces/mamba-130m-hf \
+--max_input_tokens 128 \
+--max_new_tokens 128 \
+--bf16 \
+--use_hpu_graphs \
+--use_kv_cache \
+--batch_size 1024
+```
+
 ### Use any dataset from the Hugging Face Hub
 
 You can also provide the name of a dataset from the Hugging Face Hub to perform generation on it with the argument `--dataset_name`.
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 8fe0ba7b99..d95aba3ddf 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -209,8 +209,10 @@
     gaudi_gpt_neox_model_forward,
     gaudi_invert_attention_mask,
     gaudi_llama_rmsnorm_forward,
+    gaudi_MambaCache_update_conv_state,
     gaudi_MambaForCausalLM_prepare_inputs_for_generation,
     gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+    gaudi_MambaMixer,
     gaudi_mistral_rmsnorm_forward,
     gaudi_mixtral_block_dynamic_moe_forward,
     gaudi_mixtral_block_moe_forward,
@@ -671,6 +673,8 @@ def adapt_transformers_to_gaudi():
     transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaForCausalLM.prepare_inputs_for_generation = (
         gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation
     )
+    transformers.models.mamba.modeling_mamba.MambaMixer = gaudi_MambaMixer
+    transformers.cache_utils.MambaCache.update_conv_state = gaudi_MambaCache_update_conv_state
     transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaModel.forward = gaudi_FalconMambaModel_forward
     transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaRMSNorm.forward = gaudi_llama_rmsnorm_forward
 
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 13b84d48b1..2253ea1c7d 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -158,8 +158,10 @@
 from .llava import GaudiLlavaForConditionalGeneration
 from .llava_next import GaudiLlavaNextForConditionalGeneration
 from .mamba import (
+    gaudi_MambaCache_update_conv_state,
     gaudi_MambaForCausalLM_prepare_inputs_for_generation,
     gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+    gaudi_MambaMixer,
 )
 from .minicpm import MiniCPM3Config, MiniCPM3ForCausalLM
 from .mistral import (
diff --git a/optimum/habana/transformers/models/mamba/__init__.py b/optimum/habana/transformers/models/mamba/__init__.py
index c22d12877c..6bd4566df3 100644
--- a/optimum/habana/transformers/models/mamba/__init__.py
+++ b/optimum/habana/transformers/models/mamba/__init__.py
@@ -1,4 +1,6 @@
 from .modeling_mamba import (
+    gaudi_MambaCache_update_conv_state,
     gaudi_MambaForCausalLM_prepare_inputs_for_generation,
     gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+    gaudi_MambaMixer,
 )
diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py
index e23ce65dd8..7bd10e919b 100644
--- a/optimum/habana/transformers/models/mamba/modeling_mamba.py
+++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -1,17 +1,86 @@
+import os
+from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from transformers.models.mamba.modeling_mamba import (
-    MambaCache,
-)
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.cache_utils import MambaCache
+from transformers.models.mamba.configuration_mamba import MambaConfig
 from transformers.utils import (
     ModelOutput,
     logging,
 )
 
+from .util_mamba import set_mamba_lib
+
+# Fetch the Mamba custom op and TPC kernel libraries from the Hub and give the kernel blob its
+# canonical file name (libcustom_tpc_perf_lib.so).
+env_variables = os.environ.copy()
+new_file_op, new_file_kernel = set_mamba_lib()
+realpath_kfn = os.path.realpath(new_file_kernel)
+new_kfn = os.path.join(os.path.dirname(realpath_kfn), "libcustom_tpc_perf_lib.so")
+if realpath_kfn != new_kfn:
+    os.rename(realpath_kfn, new_kfn)
+
+env_variables["HABANA_CUSTOM_OP_DIR"] = os.path.dirname(new_file_op)
+orig_path = "/usr/lib/habanalabs/libtpc_kernels.so"
+default_path = env_variables.get("GC_KERNEL_PATH", orig_path)
+env_variables["GC_KERNEL_PATH"] = new_kfn + os.pathsep + default_path
+
+base_dir = env_variables["HABANA_CUSTOM_OP_DIR"]
+
+# Load the custom pscan op library if it was downloaded successfully.
+custom_op_lib_path = ""
+if os.path.exists(base_dir):
+    op_candidates = sorted(Path(base_dir).glob("hpu_custom_pscan_all.cpython-*-x86_64-linux-gnu.so"))
+    if op_candidates:
+        custom_op_lib_path = str(op_candidates[0])
+        torch.ops.load_library(custom_op_lib_path)
 
 logger = logging.get_logger(__name__)
 
+is_fast_path_available = False
+
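+# The custom pscan kernels are used only when the op library was loaded and GC_KERNEL_PATH points
+# to a non-default TPC kernel library (see examples/text-generation/README.md for setup).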
+use_pscan_kernel = False
+if os.path.exists(custom_op_lib_path) and default_path != orig_path:
+    use_pscan_kernel = True
+
+def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):
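+    # Add singleton dimensions and transpose the channel/state axes so the inputs match the
+    # 4-D layout expected by the custom pscan TPC kernels.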
+    in_state_h = in_state.unsqueeze(1).transpose(2, 3)
+    in_x_h = in_x.transpose(1, 2).unsqueeze(2)
+    in_dt_h = in_dt.unsqueeze(2)
+    in_A_h = in_A.unsqueeze(0).unsqueeze(1).transpose(2, 3)
+    in_B_h = in_B.unsqueeze(3)
+    in_C_h = in_C.unsqueeze(3)
+    in_D_h = in_D.unsqueeze(0).unsqueeze(1).unsqueeze(2)
+    in_z_h = in_z.transpose(1, 2).unsqueeze(2)
+
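+    # Dispatch to the fp32 or bf16 variants of the custom ops based on the state dtype;
+    # A is cast to bf16 explicitly for the bf16 kernels.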
+    if in_state.dtype == torch.float:
+        state_out_h = torch.ops.custom_op.custom_pscan(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
+        output_h = torch.ops.custom_op.custom_pscan_update(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)
+
+    else:
+        in_A_h = in_A_h.to(torch.bfloat16)
+        state_out_h = torch.ops.custom_op.custom_pscan_bf16(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
+        output_h = torch.ops.custom_op.custom_pscan_update_bf16(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)
+
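+    # Map the kernel outputs back to [batch, intermediate_size, seq_len] and take the state at the
+    # last position as the updated SSM state.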
+    output_hpu = output_h.squeeze(2).transpose(1, 2)
+    state_hpu = state_out_h.transpose(2, 3)
+    state_out = torch.select(state_hpu, 1, output_hpu.shape[2] - 1)
+
+    return output_hpu, state_out
+
+
+def gaudi_MambaCache_update_conv_state(
+    self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+) -> torch.Tensor:
+    conv_state = self.conv_states[layer_idx]
+    cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+    conv_state = conv_state.roll(shifts=-1, dims=-1)
+    # Equivalent to `conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)`,
+    # written as an explicit per-position loop.
+    for c, i in enumerate(cache_position):
+        conv_state[:, :, i] = new_conv_state[:, :, c].to(conv_state.device)
+
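+    # Write the result back into the cache tensor in place (zero then add) rather than rebinding it.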
+    self.conv_states[layer_idx].zero_()
+    self.conv_states[layer_idx] += conv_state
+    return self.conv_states[layer_idx]
+
 
 def gaudi_MambaForCausalLM_update_model_kwargs_for_generation(
     self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs
@@ -94,3 +163,159 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
         }
     )
     return model_inputs
+
+class gaudi_MambaMixer(nn.Module):
+    """
+    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+    and is why Mamba is called **selective** state spaces)
+    Compared to the upstream Transformers implementation, only the slow path is changed: it calls
+    the custom HPU pscan op when it is available.
+    """
+
+    def __init__(self, config: MambaConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.ssm_state_size = config.state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.intermediate_size
+        self.time_step_rank = int(config.time_step_rank)
+        self.layer_idx = layer_idx
+        self.use_conv_bias = config.use_conv_bias
+        self.conv1d = nn.Conv1d(
+            in_channels=self.intermediate_size,
+            out_channels=self.intermediate_size,
+            bias=config.use_conv_bias,
+            kernel_size=config.conv_kernel,
+            groups=self.intermediate_size,
+            padding=config.conv_kernel - 1,
+        )
+
+        self.activation = config.hidden_act
+        self.act = ACT2FN[config.hidden_act]
+
+        self.use_mambapy = config.use_mambapy
+
+        # projection of the input hidden states
+        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
+        # selective projection used to make dt, B and C input dependant
+        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+        # time step projection (discretization)
+        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
+
+        # S4D real initialization. These are not discretized!
+        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+        A = A.expand(self.intermediate_size, -1).contiguous()
+
+        self.A_log = nn.Parameter(torch.log(A))
+        self.D = nn.Parameter(torch.ones(self.intermediate_size))
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
+        self.use_bias = config.use_bias
+
+        if not is_fast_path_available:
+            logger.warning_once(
+                "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
+                " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+            )
+
+    # fmt: off
+    def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
+        """
+        The discretization and recurrence (steps 3.b and 3.c below), together with the D skip connection
+        and gating, are replaced with the custom op `Run_Mamba_Forward_Gaudi`, which removes the Python
+        loop over the sequence length and improves performance.
+        """
+        batch_size, seq_len, _ = input_states.shape
+        dtype = input_states.dtype
+        # 1. Gated MLP's linear projection
+        projected_states = self.in_proj(input_states).transpose(1, 2)                   # [batch, 2 * intermediate_size, seq_len]
+        hidden_states, gate = projected_states.chunk(2, dim=1)
+
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        # 2. Convolution sequence transformation
+        if cache_params is not None:
+            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+            ssm_state = ssm_state.to(hidden_states.device)
+            # use `cache_position.shape[0]` to check whether we are in prefill
+            # stage, it's equivalent to check `cache_position[0] == 0`, which
+            # breaks dynamo fullgraph constraints
+            if cache_position.shape[0] == self.conv_kernel_size:
+                conv_state = nn.functional.pad(
+                    hidden_states,
+                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
+                )
+
+                cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
+                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])     # [batch, intermediate_size, seq_len]
+            else:
+                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+                if self.use_conv_bias:
+                    hidden_states += self.conv1d.bias
+                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)         # [batch, intermediate_size, 1] : decoding
+        else:
+            ssm_state = torch.zeros(
+                (batch_size, self.intermediate_size, self.ssm_state_size),
+                device=hidden_states.device, dtype=dtype
+            )
+            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])         # [batch, intermediate_size, seq_len]
+
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        # 3. State Space Model sequence transformation
+        # 3.a. Selection:  [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+        time_step, B, C = torch.split(
+            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+        )
+        discrete_time_step = self.dt_proj(time_step)                                    # [batch, seq_len, intermediate_size]
+
+        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+        A = -torch.exp(self.A_log.float())                                              # [intermediate_size, ssm_state_size]
+        if use_pscan_kernel:
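+            # The custom pscan ops take the raw dt projection together with A, B, C, D and the gate,
+            # covering the discretization, recurrence and gating that the fallback branch below does
+            # step by step.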
+            scan_output, ssm_state = Run_Mamba_Forward_Gaudi(
+                ssm_state,
+                hidden_states,
+                discrete_time_step,
+                A,
+                B,
+                C,
+                self.D,
+                gate,
+            )
+        else:
+            discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
+            discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
+            discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()       # [batch, intermediate_size, seq_len, ssm_state_size]
+            deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+            scan_outputs = []
+            for i in range(seq_len):
+                ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state]
+                scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
+                scan_outputs.append(scan_output[:, :, 0])
+            scan_output = torch.stack(scan_outputs, dim=-1)                                # [batch, intermediate_size, seq_len]
+            scan_output = scan_output + (hidden_states * self.D[None, :, None])
+            scan_output = (scan_output * self.act(gate))
+
+        if cache_params is not None:
+            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+        # 4. Final linear projection
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+        return contextualized_states
+    # fmt: on
+
+    def forward(
+        self,
+        hidden_states,
+        cache_params: Optional[MambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
diff --git a/optimum/habana/transformers/models/mamba/util_mamba.py b/optimum/habana/transformers/models/mamba/util_mamba.py
new file mode 100644
index 0000000000..3b3e2dd9a7
--- /dev/null
+++ b/optimum/habana/transformers/models/mamba/util_mamba.py
@@ -0,0 +1,29 @@
+import os
+
+from huggingface_hub import hf_hub_download
+
+from ....utils import get_habana_frameworks_version
+
+
+def set_mamba_lib():
+    version_no = get_habana_frameworks_version()
+
+    name_op = "hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu.so"
+    name_kernel = "libcustom_tpc_perf_lib.so"
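+    # Binaries built for Habana frameworks release 1.19 are published on the Hub with a _119 suffix.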
+    if version_no.minor == 19:
+        name_op = "hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu_119.so"
+        name_kernel = "libcustom_tpc_perf_lib_119.so"
+
+    file_op = hf_hub_download(repo_id="Habana/mamba", filename=name_op)
+    file_kernel = hf_hub_download(repo_id="Habana/mamba", filename=name_kernel)
+
+    new_file_op = file_op
+    new_file_kernel = file_kernel
+
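+    # Strip the 1.19-specific suffix so the rest of the code can refer to the canonical .so names.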
+    if version_no.minor == 19:
+        new_file_op = file_op[:-7] + ".so"
+        new_file_kernel = file_kernel[:-7] + ".so"
+        os.rename(file_op, new_file_op)
+        os.rename(file_kernel, new_file_kernel)
+
+    return new_file_op, new_file_kernel
\ No newline at end of file
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index ec1cc67475..9358c9f572 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -43,7 +43,7 @@
             ("Qwen/Qwen1.5-7B", 4, False, 490.8621617893209, False),
             ("google/gemma-7b", 1, False, 109.70751574382221, True),
             ("google/gemma-2-9b", 1, False, 92.302359446567, True),
-            ("state-spaces/mamba-130m-hf", 1536, False, 5385.511100161605, False),
+            ("state-spaces/mamba-130m-hf", 1536, False, 19283.0330042467, True),
             ("Deci/DeciLM-7B", 1, False, 115, False),
             ("Qwen/Qwen2-7B", 256, False, 8870.945160540245, True),
             ("Qwen/Qwen1.5-MoE-A2.7B", 1, True, 44.25834541569395, False),
@@ -113,6 +113,7 @@
         "mistralai/Mistral-7B-v0.1": "DeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system.\n\nDeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system",
         "mistralai/Mixtral-8x7B-v0.1": "DeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\n## Introduction\n\nDeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\n## What is DeepSpeed",
         "Qwen/Qwen2-7B": "DeepSpeed is a machine learning framework that provides a unified interface for training deep learning models. It is designed to be easy to use and to provide high performance. DeepSpeed is built on top of PyTorch and TensorFlow, and it supports a wide range of models, including transformers, convolutional neural networks, and recurrent neural networks.\nDeepSpeed is a machine learning framework that provides a unified interface for training deep learning models. It is designed to be easy to use and to provide high performance. DeepSpeed is built on top of Py",
+        "state-spaces/mamba-130m-hf": "DeepSpeed is a machine learning framework.\n\nThe authors declare no conflict of interest.\n\n![The structure of the *S. aureus* strain used in this study. The *S. aureus* strain was obtained from the National Center for Biotechnology Information (NCBI) database. The strain was isolated from a patient with a history of bacteremia and was identified as *S. aureus* by the presence of a *cpsA* gene. The strain was also isolated from a patient with a history of b",
     }
 else:
     # Gaudi1 CI baselines
@@ -135,7 +136,6 @@
             ("Qwen/Qwen1.5-7B", 1, False, 39.29068423087616, False),
             ("adept/persimmon-8b-base", 1, False, 34.53559807384106, False),
             ("bigcode/starcoder2-3b", 1, False, 82.09655684566117, False),
-            ("state-spaces/mamba-130m-hf", 224, False, 794.542, False),
         ],
         "fp8": [],
         "load_quantized_model_with_autogptq": [],
@@ -221,9 +221,31 @@ def _test_text_generation(
     if "decilm" in model_name.lower():
         command += ["--sdp_on_bf16"]
 
-    if "mamba-130m-hf" in model_name.lower():
-        command += ["--sdp_on_bf16"]
+    if "mamba" in model_name.lower():
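+        # Download the Mamba custom TPC kernel from the Hugging Face Hub and prepend it to
+        # GC_KERNEL_PATH so the custom pscan kernels are used during this test.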
+        from huggingface_hub import hf_hub_download
+
+        from optimum.habana.utils import get_habana_frameworks_version
+
+        version_no = get_habana_frameworks_version()
+
+        name_kernel = "libcustom_tpc_perf_lib.so"
+        if version_no.minor == 19:
+            name_kernel = "libcustom_tpc_perf_lib_119.so"
+
+        file_kernel = hf_hub_download(repo_id="Habana/mamba", filename=name_kernel)
+
+        new_file_kernel = file_kernel
+
+        if version_no.minor == 19:
+            new_file_kernel = file_kernel[:-7] + ".so"
+            os.rename(file_kernel, new_file_kernel)
+
+        realpath_kfn = os.path.realpath(new_file_kernel)
+        new_kfn = os.path.join(os.path.dirname(realpath_kfn), "libcustom_tpc_perf_lib.so")
+        if realpath_kfn != new_kfn:
+            os.rename(realpath_kfn, new_kfn)
 
+        default_path = env_variables.get("GC_KERNEL_PATH", "")
+        env_variables["GC_KERNEL_PATH"] = new_kfn + os.pathsep + default_path
+
     if (reuse_cache or torch_compile) and not parallel_strategy == "tp" and not is_starcoder_first_gen_model:
         command += ["--reuse_cache"]
 

From 110fe3cefee5e92f9ddc47e306379532e6ea9b04 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhang <zhzhang@habana.ai>
Date: Thu, 9 Jan 2025 23:24:53 +0000
Subject: [PATCH 2/3] Add Mamba model using a custom kernel to improve performance

---
 examples/image-to-text/run_pipeline.py        |  2 +-
 examples/language-modeling/run_clm.py         |  2 +-
 .../pytorch-image-models/train_hpu_graph.py   |  4 +-
 .../pytorch-image-models/train_hpu_lazy.py    |  4 +-
 .../run_speech_recognition_ctc.py             |  2 +-
 .../image_to_image_generation.py              |  6 +--
 .../text_to_image_generation.py               |  6 +--
 .../training/train_dreambooth_lora_flux.py    |  2 +-
 .../training/train_dreambooth_lora_sdxl.py    |  4 +-
 .../training/train_text_to_image_sdxl.py      |  8 ++--
 examples/summarization/run_summarization.py   |  6 +--
 examples/text-classification/run_glue.py      | 12 +++---
 examples/text-generation/run_generation.py    | 18 ++++----
 .../text-generation-pipeline/run_pipeline.py  |  6 +--
 .../run_pipeline_langchain.py                 |  4 +-
 examples/text-to-speech/run_pipeline.py       |  2 +-
 .../visual-question-answering/run_pipeline.py |  2 +-
 optimum/habana/accelerate/accelerator.py      |  6 +--
 .../pipeline_stable_diffusion_inpaint.py      |  2 +-
 ...eline_stable_diffusion_instruct_pix2pix.py |  2 +-
 .../pipeline_stable_diffusion_upscale.py      |  2 +-
 .../pipeline_stable_diffusion_xl_inpaint.py   |  2 +-
 optimum/habana/distributed/parallel_state.py  |  8 ++--
 optimum/habana/distributed/serialization.py   |  6 +--
 .../habana/transformers/generation/utils.py   | 41 +++++++++----------
 .../models/baichuan/modeling_baichuan.py      |  6 +--
 .../transformers/models/bart/modeling_bart.py |  3 +-
 .../models/chatglm/modeling_chatglm.py        |  6 +--
 .../models/falcon/modeling_falcon.py          |  4 +-
 .../models/gemma/modeling_gemma.py            |  6 +--
 .../models/gemma2/modeling_gemma2.py          |  6 +--
 .../gpt_bigcode/modeling_gpt_bigcode.py       | 12 +++---
 .../transformers/models/gptj/modeling_gptj.py |  6 +--
 .../models/llama/modeling_llama.py            |  6 +--
 .../models/mamba/modeling_mamba.py            |  5 +--
 .../transformers/models/mamba/util_mamba.py   |  2 +-
 .../models/modeling_all_models.py             |  6 +--
 .../transformers/models/opt/modeling_opt.py   |  3 +-
 .../models/qwen2_moe/modeling_qwen2_moe.py    |  6 +--
 .../seamless_m4t/modeling_seamless_m4t.py     |  2 +-
 .../models/speecht5/modeling_speecht5.py      |  3 +-
 .../transformers/models/t5/modeling_t5.py     |  2 +-
 .../transformers/models/xglm/modeling_xglm.py |  5 +--
 optimum/habana/transformers/trainer.py        |  6 +--
 optimum/habana/trl/trainer/dpo_trainer.py     |  3 +-
 optimum/habana/trl/trainer/sft_trainer.py     |  6 +--
 tests/test_diffusers.py                       | 36 ++++++++--------
 tests/test_encoder_decoder.py                 |  2 +-
 tests/test_text_generation_example.py         |  7 ++--
 tests/test_trainer.py                         |  8 ++--
 .../tests/models/gpt2/test_modeling_gpt2.py   |  6 +--
 .../models/gpt_neox/test_modeling_gpt_neox.py |  6 +--
 .../tests/test_modeling_common.py             |  6 +--
 53 files changed, 164 insertions(+), 170 deletions(-)

diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py
index 44eb8d575a..b218e81daf 100644
--- a/examples/image-to-text/run_pipeline.py
+++ b/examples/image-to-text/run_pipeline.py
@@ -355,7 +355,7 @@ def preprocess(self, image, prompt=None, timeout=None):
     throughput = total_new_tokens_generated / duration
     logger.info(f"result = {result}")
     logger.info(
-        f"time = {(end-start) * 1000 / args.n_iterations }ms, Throughput (including tokenization) = {throughput} tokens/second"
+        f"time = {(end - start) * 1000 / args.n_iterations}ms, Throughput (including tokenization) = {throughput} tokens/second"
     )
 
     # Store results if necessary
diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index feac065364..3ee73bc612 100644
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -472,7 +472,7 @@ def main():
     else:
         model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
         n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
-        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
+        logger.info(f"Training new model from scratch - Total size={n_params / 2**20:.2f}M params")
 
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
diff --git a/examples/pytorch-image-models/train_hpu_graph.py b/examples/pytorch-image-models/train_hpu_graph.py
index 0bcfbe7295..01e11f8e88 100755
--- a/examples/pytorch-image-models/train_hpu_graph.py
+++ b/examples/pytorch-image-models/train_hpu_graph.py
@@ -1092,7 +1092,7 @@ def main():
 
     if utils.is_primary(args):
         _logger.info(
-            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.'
+            f"Scheduled epochs: {num_epochs}. LR stepped per {'epoch' if lr_scheduler.t_in_epochs else 'update'}."
         )
 
     results = []
@@ -1324,7 +1324,7 @@ def _backward(_loss):
             if utils.is_primary(args):
                 _logger.info(
                     f"Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} "
-                    f"({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)]  "
+                    f"({100.0 * (update_idx + 1) / updates_per_epoch:>3.0f}%)]  "
                     f"Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g})  "
                     f"Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s  "
                     f"({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s)  "
diff --git a/examples/pytorch-image-models/train_hpu_lazy.py b/examples/pytorch-image-models/train_hpu_lazy.py
index bca523c9b4..f70ae7d7b6 100755
--- a/examples/pytorch-image-models/train_hpu_lazy.py
+++ b/examples/pytorch-image-models/train_hpu_lazy.py
@@ -1091,7 +1091,7 @@ def main():
 
     if utils.is_primary(args):
         _logger.info(
-            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.'
+            f"Scheduled epochs: {num_epochs}. LR stepped per {'epoch' if lr_scheduler.t_in_epochs else 'update'}."
         )
 
     results = []
@@ -1325,7 +1325,7 @@ def _backward(_loss):
             if utils.is_primary(args):
                 _logger.info(
                     f"Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} "
-                    f"({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)]  "
+                    f"({100.0 * (update_idx + 1) / updates_per_epoch:>3.0f}%)]  "
                     f"Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g})  "
                     f"Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s  "
                     f"({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s)  "
diff --git a/examples/speech-recognition/run_speech_recognition_ctc.py b/examples/speech-recognition/run_speech_recognition_ctc.py
index 9d53e58519..f5da991dbf 100644
--- a/examples/speech-recognition/run_speech_recognition_ctc.py
+++ b/examples/speech-recognition/run_speech_recognition_ctc.py
@@ -504,7 +504,7 @@ def main():
     # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
     # that could be easily picked up by the model
     chars_to_ignore_regex = (
-        f'[{"".join(data_args.chars_to_ignore).replace(" ", "")}]' if data_args.chars_to_ignore is not None else None
+        f"[{''.join(data_args.chars_to_ignore).replace(' ', '')}]" if data_args.chars_to_ignore is not None else None
     )
     text_column_name = data_args.text_column_name
 
diff --git a/examples/stable-diffusion/image_to_image_generation.py b/examples/stable-diffusion/image_to_image_generation.py
index c76d3c0f5a..acc2536a26 100755
--- a/examples/stable-diffusion/image_to_image_generation.py
+++ b/examples/stable-diffusion/image_to_image_generation.py
@@ -370,12 +370,12 @@ def main():
             logger.info(f"Saving images in {image_save_dir.resolve()}...")
             if args.ldm3d:
                 for i, rgb in enumerate(outputs.rgb):
-                    rgb.save(image_save_dir / f"rgb_{i+1}.png")
+                    rgb.save(image_save_dir / f"rgb_{i + 1}.png")
                 for i, depth in enumerate(outputs.depth):
-                    depth.save(image_save_dir / f"depth_{i+1}.png")
+                    depth.save(image_save_dir / f"depth_{i + 1}.png")
             else:
                 for i, image in enumerate(outputs.images):
-                    image.save(image_save_dir / f"image_{i+1}.png")
+                    image.save(image_save_dir / f"image_{i + 1}.png")
         else:
             logger.warning("--output_type should be equal to 'pil' to save images in --image_save_dir.")
 
diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py
index 8fd48c99a8..b4668e7d99 100755
--- a/examples/stable-diffusion/text_to_image_generation.py
+++ b/examples/stable-diffusion/text_to_image_generation.py
@@ -687,12 +687,12 @@ def main():
             logger.info(f"Saving images in {image_save_dir.resolve()}...")
             if args.ldm3d:
                 for i, rgb in enumerate(outputs.rgb):
-                    rgb.save(image_save_dir / f"rgb_{i+1}.png")
+                    rgb.save(image_save_dir / f"rgb_{i + 1}.png")
                 for i, depth in enumerate(outputs.depth):
-                    depth.save(image_save_dir / f"depth_{i+1}.png")
+                    depth.save(image_save_dir / f"depth_{i + 1}.png")
             else:
                 for i, image in enumerate(outputs.images):
-                    image.save(image_save_dir / f"image_{i+1}.png")
+                    image.save(image_save_dir / f"image_{i + 1}.png")
         else:
             logger.warning("--output_type should be equal to 'pil' to save images in --image_save_dir.")
 
diff --git a/examples/stable-diffusion/training/train_dreambooth_lora_flux.py b/examples/stable-diffusion/training/train_dreambooth_lora_flux.py
index 68b5320d19..1117d0a43f 100755
--- a/examples/stable-diffusion/training/train_dreambooth_lora_flux.py
+++ b/examples/stable-diffusion/training/train_dreambooth_lora_flux.py
@@ -784,7 +784,7 @@ def load_model_hook(models, input_dir):
         lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
 
         transformer_state_dict = {
-            f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+            f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
         }
         transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
         incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py b/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py
index b177cf12e6..4e96ee8e0d 100755
--- a/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py
+++ b/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py
@@ -94,7 +94,7 @@ def save_model_card(
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
         img_str += f"""
-        - text: '{validation_prompt if validation_prompt else ' ' }'
+        - text: '{validation_prompt if validation_prompt else " "}'
           output:
             url:
                 "image_{i}.png"
@@ -1083,7 +1083,7 @@ def load_model_hook(models, input_dir):
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
 
-        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
         unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
         incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
         if incompatible_keys is not None:
diff --git a/examples/stable-diffusion/training/train_text_to_image_sdxl.py b/examples/stable-diffusion/training/train_text_to_image_sdxl.py
index b78c84bbe1..7bb96e51a1 100755
--- a/examples/stable-diffusion/training/train_text_to_image_sdxl.py
+++ b/examples/stable-diffusion/training/train_text_to_image_sdxl.py
@@ -884,9 +884,9 @@ def main(args):
     # download the dataset.
     if args.dataset_name is not None:
         if len(args.mediapipe) > 0:
-            assert (
-                args.resolution == args.crop_resolution
-            ), f"To use hardware pipe, --resolution ({args.resolution}) must equal --crop_resolution ({args.crop_resolution})"
+            assert args.resolution == args.crop_resolution, (
+                f"To use hardware pipe, --resolution ({args.resolution}) must equal --crop_resolution ({args.crop_resolution})"
+            )
             if args.local_rank == 0:
                 if not os.path.exists(args.mediapipe):
                     os.mkdir(args.mediapipe)
@@ -1532,7 +1532,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     image_save_dir.mkdir(parents=True, exist_ok=True)
                     logger.info(f"Saving images in {image_save_dir.resolve()}...")
                     for i, image in enumerate(images):
-                        image.save(image_save_dir / f"image_{epoch}_{i+1}.png")
+                        image.save(image_save_dir / f"image_{epoch}_{i + 1}.png")
                 else:
                     logger.warning("--output_type should be equal to 'pil' to save images in --image_save_dir.")
 
diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py
index 65755d24a2..a14e0e1dea 100755
--- a/examples/summarization/run_summarization.py
+++ b/examples/summarization/run_summarization.py
@@ -559,9 +559,9 @@ def main():
         return
 
     if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
-        assert (
-            data_args.lang is not None
-        ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
+        assert data_args.lang is not None, (
+            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
+        )
 
         tokenizer.src_lang = data_args.lang
         tokenizer.tgt_lang = data_args.lang
diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py
index 68f5e9a2aa..6a78ecd91e 100755
--- a/examples/text-classification/run_glue.py
+++ b/examples/text-classification/run_glue.py
@@ -168,9 +168,9 @@ def __post_init__(self):
             train_extension = self.train_file.split(".")[-1]
             assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
             validation_extension = self.validation_file.split(".")[-1]
-            assert (
-                validation_extension == train_extension
-            ), "`validation_file` should have the same extension (csv or json) as `train_file`."
+            assert validation_extension == train_extension, (
+                "`validation_file` should have the same extension (csv or json) as `train_file`."
+            )
 
 
 @dataclass
@@ -338,9 +338,9 @@ def main():
             if data_args.test_file is not None:
                 train_extension = data_args.train_file.split(".")[-1]
                 test_extension = data_args.test_file.split(".")[-1]
-                assert (
-                    test_extension == train_extension
-                ), "`test_file` should have the same extension (csv or json) as `train_file`."
+                assert test_extension == train_extension, (
+                    "`test_file` should have the same extension (csv or json) as `train_file`."
+                )
                 data_files["test"] = data_args.test_file
             else:
                 raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index ef2252a989..e5df7f2c7c 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -526,7 +526,7 @@ def compute_valid_sequence_lengths_tensor(input_tokens):
                 profiling_record_shapes=args.profiling_record_shapes,
             ).cpu()
             first_token_time = iteration_times[0] + encode_duration
-            logger.info(f"Time to first token = {first_token_time*1000}ms")
+            logger.info(f"Time to first token = {first_token_time * 1000}ms")
             return tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
         from optimum.habana.utils import HabanaProfile
@@ -541,10 +541,10 @@ def compute_valid_sequence_lengths_tensor(input_tokens):
         if dyn_prompt_lens is None or len(set(dyn_prompt_lens)) == 1:
             for i in range(args.warmup):
                 if dyn_prompt_lens is None:
-                    print(f"Warming up iteration {i+1}/{args.warmup}", flush=True)
+                    print(f"Warming up iteration {i + 1}/{args.warmup}", flush=True)
                     generate(None, args.reduce_recompile)
                 else:
-                    print(f"Warming up for shape {dyn_prompt_lens[0]} iteration {i+1}/{args.warmup}", flush=True)
+                    print(f"Warming up for shape {dyn_prompt_lens[0]} iteration {i + 1}/{args.warmup}", flush=True)
                     generate(dyn_prompt_lens[0], args.reduce_recompile)
         else:
             if args.bucket_size > 0:
@@ -559,7 +559,7 @@ def rounder(x):
                 for i in range(args.warmup):
                     lst = list(range(min_prompt_len, max_sentence_len + 1, args.bucket_size))
                     for sz in lst:
-                        print(f"Warming up for shape {sz - 1} iteration {i+1}/{args.warmup}", flush=True)
+                        print(f"Warming up for shape {sz - 1} iteration {i + 1}/{args.warmup}", flush=True)
                         generate(sz - 1, args.reduce_recompile)
         torch_hpu.synchronize()
         compilation_duration = time.perf_counter() - t0
@@ -586,12 +586,12 @@ def rounder(x):
         all_inputs = []
         all_outputs = []
         for i, input_sentence in enumerate(zip(input_sentences)):
-            print(f"input {i+1}: {input_sentence}")
+            print(f"input {i + 1}: {input_sentence}")
             all_inputs.append(input_sentence)
             for j, output in enumerate(
                 zip(generated[args.num_return_sequences * i : args.num_return_sequences * (i + 1)])
             ):
-                print(f"output {i+1}.{j+1}: {output}")
+                print(f"output {i + 1}.{j + 1}: {output}")
                 all_outputs.append(output)
             print()
 
@@ -747,10 +747,10 @@ def generate_dataset(batch):
             duration += time.perf_counter() - t0
             total_new_tokens_generated += args.batch_size * args.max_new_tokens
             print(separator)
-            print(f"Batch n°{i+1}")
-            print(f"Input: {prompt[:args.batch_size]}")
+            print(f"Batch n°{i + 1}")
+            print(f"Input: {prompt[: args.batch_size]}")
             print(
-                f"Output: {tokenizer.batch_decode(outputs, skip_special_tokens=True)[:args.batch_size*args.num_return_sequences]}"
+                f"Output: {tokenizer.batch_decode(outputs, skip_special_tokens=True)[: args.batch_size * args.num_return_sequences]}"
             )
             print(separator)
             if args.run_partial_dataset and args.n_iterations == i + 1:
diff --git a/examples/text-generation/text-generation-pipeline/run_pipeline.py b/examples/text-generation/text-generation-pipeline/run_pipeline.py
index 43aea65cec..11e542d7a5 100644
--- a/examples/text-generation/text-generation-pipeline/run_pipeline.py
+++ b/examples/text-generation/text-generation-pipeline/run_pipeline.py
@@ -45,14 +45,14 @@ def main():
 
     duration = 0
     for iteration in range(args.n_iterations):
-        logger.info(f"Running inference iteration {iteration+1}...")
+        logger.info(f"Running inference iteration {iteration + 1}...")
         t0 = time.perf_counter()
         output = pipe(input_sentences)
         duration += time.perf_counter() - t0
 
         for i, (input_sentence, generated_text) in enumerate(zip(input_sentences, output)):
-            print(f"Prompt[{iteration+1}][{i+1}]: {input_sentence}")
-            print(f"Generated Text[{iteration+1}][{i+1}]: {repr(generated_text)}\n")
+            print(f"Prompt[{iteration + 1}][{i + 1}]: {input_sentence}")
+            print(f"Generated Text[{iteration + 1}][{i + 1}]: {repr(generated_text)}\n")
 
     throughput = args.n_iterations * args.batch_size * args.max_new_tokens / duration
     print(f"Inference Duration (for {args.n_iterations} iterations): {duration} seconds")
diff --git a/examples/text-generation/text-generation-pipeline/run_pipeline_langchain.py b/examples/text-generation/text-generation-pipeline/run_pipeline_langchain.py
index 556494cd37..6212e808aa 100644
--- a/examples/text-generation/text-generation-pipeline/run_pipeline_langchain.py
+++ b/examples/text-generation/text-generation-pipeline/run_pipeline_langchain.py
@@ -87,8 +87,8 @@ def main():
         duration += time.perf_counter() - t0
 
         for i, (question, answer) in enumerate(zip(input_questions, responses)):
-            print(f"Question[{iteration+1}][{i+1}]: {question['question']}")
-            print(f"Response[{iteration+1}][{i+1}]: {answer}\n")
+            print(f"Question[{iteration + 1}][{i + 1}]: {question['question']}")
+            print(f"Response[{iteration + 1}][{i + 1}]: {answer}\n")
 
     throughput = args.n_iterations * args.batch_size * args.max_new_tokens / duration
     print(f"Inference Duration (for {args.n_iterations} iterations): {duration} seconds")
diff --git a/examples/text-to-speech/run_pipeline.py b/examples/text-to-speech/run_pipeline.py
index 1d9b53de7d..81546b0cb9 100644
--- a/examples/text-to-speech/run_pipeline.py
+++ b/examples/text-to-speech/run_pipeline.py
@@ -129,7 +129,7 @@ def main():
                 text, batch_size=args.batch_size, forward_params=forward_params, generate_kwargs=generate_kwargs
             )
         end = time.time()
-        logger.info(f"speech = {speech} time = {(end-start) * 1000 / args.n_iterations }ms")
+        logger.info(f"speech = {speech} time = {(end - start) * 1000 / args.n_iterations}ms")
         sf.write("speech.wav", speech[0]["audio"].squeeze(), samplerate=speech[0]["sampling_rate"])
 
 
diff --git a/examples/visual-question-answering/run_pipeline.py b/examples/visual-question-answering/run_pipeline.py
index 7b4e817bb7..82b05933bc 100644
--- a/examples/visual-question-answering/run_pipeline.py
+++ b/examples/visual-question-answering/run_pipeline.py
@@ -135,7 +135,7 @@ def main():
         with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=autocast_enable):
             result = generator(model_input, batch_size=args.batch_size, topk=args.topk)
     end = time.time()
-    logger.info(f"result = {result}, time = {(end-start) * 1000/args.n_iterations}ms")
+    logger.info(f"result = {result}, time = {(end - start) * 1000 / args.n_iterations}ms")
 
 
 if __name__ == "__main__":
diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py
index b2d93730a4..f73769692d 100644
--- a/optimum/habana/accelerate/accelerator.py
+++ b/optimum/habana/accelerate/accelerator.py
@@ -197,9 +197,9 @@ def __init__(
 
         if kwargs_handlers is not None:
             for handler in kwargs_handlers:
-                assert isinstance(
-                    handler, KwargsHandler
-                ), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
+                assert isinstance(handler, KwargsHandler), (
+                    f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
+                )
                 if isinstance(handler, DistributedDataParallelKwargs):
                     if self.ddp_handler is not None:
                         raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.")
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 2884831732..f937423d13 100644
--- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -527,7 +527,7 @@ def __call__(
                         f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                         f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                         f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
-                        f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                        f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                         " `pipeline.unet` or your `mask_image` or `image` input."
                     )
             elif num_channels_unet != 4:
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 0f8eb39f92..c4b0d0e742 100644
--- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -360,7 +360,7 @@ def __call__(
                     f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                     f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                     f" `num_channels_image`: {num_channels_image} "
-                    f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+                    f" = {num_channels_latents + num_channels_image}. Please verify the config of"
                     " `pipeline.unet` or your `image` input."
                 )
 
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 58f2f977a9..136ff0dace 100644
--- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -413,7 +413,7 @@ def __call__(
                     f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                     f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                     f" `num_channels_image`: {num_channels_image} "
-                    f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+                    f" = {num_channels_latents + num_channels_image}. Please verify the config of"
                     " `pipeline.unet` or your `image` input."
                 )
 
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 8d94596e3b..dab18e82e2 100644
--- a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -639,7 +639,7 @@ def denoising_value_valid(dnv):
                         f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                         f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                         f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
-                        f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                        f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                         " `pipeline.unet` or your `mask_image` or `image` input."
                     )
             elif num_channels_unet != 4:
diff --git a/optimum/habana/distributed/parallel_state.py b/optimum/habana/distributed/parallel_state.py
index c370d88229..3d5c5d9a74 100644
--- a/optimum/habana/distributed/parallel_state.py
+++ b/optimum/habana/distributed/parallel_state.py
@@ -146,9 +146,9 @@ def initialize_model_parallel(
 
     enable_ds_sequence_parallel = sequence_parallel_size > 1
     if enable_ds_sequence_parallel:
-        assert (
-            tensor_model_parallel_size == 1 and pipeline_model_parallel_size == 1
-        ), "DeepSpeed's sequence parallel does not work with tensor parallel or pipeline parallel"
+        assert tensor_model_parallel_size == 1 and pipeline_model_parallel_size == 1, (
+            "DeepSpeed's sequence parallel does not work with tensor parallel or pipeline parallel"
+        )
 
         if world_size % sequence_parallel_size != 0:
             raise RuntimeError(
@@ -168,7 +168,7 @@ def initialize_model_parallel(
 
     if virtual_pipeline_model_parallel_size is not None:
         if not pipeline_model_parallel_size > 2:
-            raise RuntimeError("pipeline-model-parallel size should be greater than 2 with " "interleaved schedule")
+            raise RuntimeError("pipeline-model-parallel size should be greater than 2 with interleaved schedule")
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
diff --git a/optimum/habana/distributed/serialization.py b/optimum/habana/distributed/serialization.py
index bf59fb2445..14842d24ca 100644
--- a/optimum/habana/distributed/serialization.py
+++ b/optimum/habana/distributed/serialization.py
@@ -191,9 +191,9 @@ def load_state_dict(
     assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}"
 
     if checkpoint_sharding is not None and checkpoint_sharding != "layer":
-        assert (
-            world_size == len(checkpoints)
-        ), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}"
+        assert world_size == len(checkpoints), (
+            f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}"
+        )
 
         checkpoints = [checkpoints[rank]]
 
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 68b445c1b2..fe198f24d9 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1078,28 +1078,27 @@ def generate(
             assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal"
             assert generation_config.use_cache, "please set use_cache flag to use bucket_internal"
         if generation_config.reuse_cache:
-            assert (
-                self.config.model_type
-                in [
-                    "llama",
-                    "mistral",
-                    "falcon",
-                    "mixtral",
-                    "phi",
-                    "qwen2",
-                    "gptj",
-                    "starcoder2",
-                    "qwen2_moe",
-                    "gemma",
-                    "gemma2",
-                    "baichuan",
-                    "chatglm",
-                ]
-            ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan and chatglm at the moment"
+            assert self.config.model_type in [
+                "llama",
+                "mistral",
+                "falcon",
+                "mixtral",
+                "phi",
+                "qwen2",
+                "gptj",
+                "starcoder2",
+                "qwen2_moe",
+                "gemma",
+                "gemma2",
+                "baichuan",
+                "chatglm",
+            ], (
+                "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan and chatglm at the moment"
+            )
             if not generation_config.bucket_internal:
-                assert (
-                    generation_config.bucket_size <= 0
-                ), "please set bucket_internal along with reuse_cache and bucket_size"
+                assert generation_config.bucket_size <= 0, (
+                    "please set bucket_internal along with reuse_cache and bucket_size"
+                )
             else:
                 assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal"
 
diff --git a/optimum/habana/transformers/models/baichuan/modeling_baichuan.py b/optimum/habana/transformers/models/baichuan/modeling_baichuan.py
index b733712fbb..ca9498e0f1 100644
--- a/optimum/habana/transformers/models/baichuan/modeling_baichuan.py
+++ b/optimum/habana/transformers/models/baichuan/modeling_baichuan.py
@@ -133,9 +133,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             self.inp_seq_len = inp_seq_len
             self.cache = torch.zeros(shape, dtype=dtype, device=device)
         else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            assert self.inp_seq_len == inp_seq_len, (
+                f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            )
             self.cache.fill_(0)
 
     def update(self, prev, cur, dim, idx, inp_seq_len):
diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py
index 3e5f822cb1..2fdfbcc6d0 100644
--- a/optimum/habana/transformers/models/bart/modeling_bart.py
+++ b/optimum/habana/transformers/models/bart/modeling_bart.py
@@ -158,8 +158,7 @@ def gaudi_BartAttention_forward(
     if layer_head_mask is not None:
         if layer_head_mask.size() != (self.num_heads,):
             raise ValueError(
-                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                f" {layer_head_mask.size()}"
+                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
             )
         attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
diff --git a/optimum/habana/transformers/models/chatglm/modeling_chatglm.py b/optimum/habana/transformers/models/chatglm/modeling_chatglm.py
index 01c508aa5d..3afa86c4a9 100644
--- a/optimum/habana/transformers/models/chatglm/modeling_chatglm.py
+++ b/optimum/habana/transformers/models/chatglm/modeling_chatglm.py
@@ -148,9 +148,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             # self.cache = torch.zeros(shape, dtype=dtype, device=device)
             self.cache = torch.zeros(shape, dtype=torch.bfloat16, device=device)
         else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            assert self.inp_seq_len == inp_seq_len, (
+                f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            )
             self.cache.fill_(0)
 
     def update(self, prev, cur, dim, idx, inp_seq_len):
diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py
index 8895f32459..5afa728c4b 100644
--- a/optimum/habana/transformers/models/falcon/modeling_falcon.py
+++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -1048,7 +1048,9 @@ def forward(
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         if use_flash_attention:
-            assert FusedSDPA, "`use_flash_attention` is True, but cannot find FusedSDPA. Please import it as `from habana_frameworks.torch.hpex.kernels import FusedSDPA` or set use_flash_attention to False (at the expense of a possible performance degradation)."
+            assert FusedSDPA, (
+                "`use_flash_attention` is True, but cannot find FusedSDPA. Please import it as `from habana_frameworks.torch.hpex.kernels import FusedSDPA` or set use_flash_attention to False (at the expense of a possible performance degradation)."
+            )
         if flash_attention_recompute:
             assert use_flash_attention, "flash_attention_recompute is set, but use_flash_attention is not"
         if flash_attention_causal_mask:
diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py
index 532539065d..cee3796aff 100755
--- a/optimum/habana/transformers/models/gemma/modeling_gemma.py
+++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -132,9 +132,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             self.inp_seq_len = inp_seq_len
             self.cache = torch.zeros(shape, dtype=dtype, device=device)
         else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            assert self.inp_seq_len == inp_seq_len, (
+                f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            )
             self.cache.fill_(0)
 
     def update(self, prev, cur, dim, idx, inp_seq_len):
diff --git a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py
index 4196775c19..5905e8bf3a 100755
--- a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py
+++ b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py
@@ -214,9 +214,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             self.inp_seq_len = inp_seq_len
             self.cache = torch.zeros(shape, dtype=dtype, device=device)
         else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            assert self.inp_seq_len == inp_seq_len, (
+                f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            )
             self.cache.fill_(0)
 
     def update(self, prev, cur, dim, idx, inp_seq_len):
diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 9f451256c9..f01255624f 100644
--- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -306,9 +306,9 @@ def forward(
         - optimize KV cache
         """
         if use_flash_attention:
-            assert (
-                self.fused_scaled_dot_product_attention is not None
-            ), "Can't load HPU fused scaled dot-product attention kernel. Please retry without flash attention"
+            assert self.fused_scaled_dot_product_attention is not None, (
+                "Can't load HPU fused scaled dot-product attention kernel. Please retry without flash attention"
+            )
 
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn") or not self.is_cross_attention:
@@ -353,9 +353,9 @@ def forward(
             present = torch.cat((key, value), dim=-1) if use_cache else None
         else:
             assert token_idx is not None, "Invalid parameters: token_idx is None at decode stage with bucket_internal"
-            assert (
-                layer_past is not None
-            ), "Invalid parameters: layer_past is None at decode stage with bucket_internal"
+            assert layer_past is not None, (
+                "Invalid parameters: layer_past is None at decode stage with bucket_internal"
+            )
 
             past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
             key = past_key.index_copy_(1, token_idx - 1, key)
diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py
index 3927e1feb9..2617a8e66a 100644
--- a/optimum/habana/transformers/models/gptj/modeling_gptj.py
+++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -38,9 +38,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             self.inp_seq_len = inp_seq_len
             self.cache = torch.zeros(shape, dtype=dtype, device=device)
         else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            assert self.inp_seq_len == inp_seq_len, (
+                f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            )
             self.cache.fill_(0)
 
     def update(self, prev, cur, dim, idx, inp_seq_len):
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 67f07437a1..0afcfbe05a 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -397,9 +397,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             self.inp_seq_len = inp_seq_len
             self.cache = torch.zeros(shape, dtype=dtype, device=device)
         else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            assert self.inp_seq_len == inp_seq_len, (
+                f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            )
             self.cache.fill_(0)
 
     @staticmethod
diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py
index 7bd10e919b..82ad159566 100644
--- a/optimum/habana/transformers/models/mamba/modeling_mamba.py
+++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -17,13 +17,12 @@
 env_variables = os.environ.copy()
 new_file_op, new_file_kernel = set_mamba_lib()
 realpath_kfn = os.path.realpath(new_file_kernel)
-kfn = os.path.basename(realpath_kfn)
 new_kfn = os.path.join(os.path.dirname(realpath_kfn), "libcustom_tpc_perf_lib.so")
 os.rename(realpath_kfn, new_kfn)
 
 env_variables["HABANA_CUSTOM_OP_DIR"] = os.path.dirname(new_file_op)
 default_path = env_variables["GC_KERNEL_PATH"]
-orig_path = '/usr/lib/habanalabs/libtpc_kernels.so'
+orig_path = "/usr/lib/habanalabs/libtpc_kernels.so"
 env_variables["GC_KERNEL_PATH"] = new_kfn + os.pathsep + default_path
 
 base_dir = env_variables["HABANA_CUSTOM_OP_DIR"]
@@ -318,4 +317,4 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
-        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)    
+        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
diff --git a/optimum/habana/transformers/models/mamba/util_mamba.py b/optimum/habana/transformers/models/mamba/util_mamba.py
index 3b3e2dd9a7..a08131b3df 100644
--- a/optimum/habana/transformers/models/mamba/util_mamba.py
+++ b/optimum/habana/transformers/models/mamba/util_mamba.py
@@ -26,4 +26,4 @@ def set_mamba_lib():
         os.rename(file_op, new_file_op)
         os.rename(file_kernel, new_file_kernel)
 
-    return new_file_op, new_file_kernel
\ No newline at end of file
+    return new_file_op, new_file_kernel
diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py
index 5a78359e3a..3f9304db74 100644
--- a/optimum/habana/transformers/models/modeling_all_models.py
+++ b/optimum/habana/transformers/models/modeling_all_models.py
@@ -48,9 +48,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             self.inp_seq_len = inp_seq_len
             self.cache = torch.zeros(shape, dtype=dtype, device=device)
         else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            assert self.inp_seq_len == inp_seq_len, (
+                f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            )
             self.cache.fill_(0)
 
     @staticmethod
diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py
index dda2a6c204..a622eb2a7a 100644
--- a/optimum/habana/transformers/models/opt/modeling_opt.py
+++ b/optimum/habana/transformers/models/opt/modeling_opt.py
@@ -124,8 +124,7 @@ def gaudi_opt_attention_forward(
     if layer_head_mask is not None:
         if layer_head_mask.size() != (self.num_heads,):
             raise ValueError(
-                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                f" {layer_head_mask.size()}"
+                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
             )
         attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
diff --git a/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 721abfa8ff..11039f9636 100755
--- a/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -189,9 +189,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             self.inp_seq_len = inp_seq_len
             self.cache = torch.zeros(shape, dtype=dtype, device=device)
         else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            assert self.inp_seq_len == inp_seq_len, (
+                f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            )
             self.cache.fill_(0)
 
     @staticmethod
diff --git a/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py
index 53cea37255..061aebb3c6 100644
--- a/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py
+++ b/optimum/habana/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -732,7 +732,7 @@ def gaudi_SeamlessM4TForTextToSpeech_generate(
             elif tgt_lang not in lang_code_to_id:
                 raise ValueError(
                     f"""`tgt_lang={tgt_lang}` is not supported by this model.
-                Please specify a `tgt_lang` in {','.join(lang_code_to_id.keys())}. Note that SeamlessM4T supports
+                Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. Note that SeamlessM4T supports
                 more languages for text translation than for speech synthesis."""
                 )
     if kwargs.get("hpu_graphs", True):
diff --git a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py
index 07c4fa8a14..8ce05607a2 100644
--- a/optimum/habana/transformers/models/speecht5/modeling_speecht5.py
+++ b/optimum/habana/transformers/models/speecht5/modeling_speecht5.py
@@ -114,8 +114,7 @@ def gaudi_SpeechT5Attention_forward(
     if layer_head_mask is not None:
         if layer_head_mask.size() != (self.num_heads,):
             raise ValueError(
-                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                f" {layer_head_mask.size()}"
+                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
             )
         attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
diff --git a/optimum/habana/transformers/models/t5/modeling_t5.py b/optimum/habana/transformers/models/t5/modeling_t5.py
index b7d7f9957e..c498916b18 100644
--- a/optimum/habana/transformers/models/t5/modeling_t5.py
+++ b/optimum/habana/transformers/models/t5/modeling_t5.py
@@ -69,7 +69,7 @@ def gaudi_T5Attention_forward(
     if past_key_value is not None:
         if len(past_key_value) != 2:
             raise ValueError(
-                f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
             )
         if token_idx is None:
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
diff --git a/optimum/habana/transformers/models/xglm/modeling_xglm.py b/optimum/habana/transformers/models/xglm/modeling_xglm.py
index ef5a16801a..f69eb3b990 100644
--- a/optimum/habana/transformers/models/xglm/modeling_xglm.py
+++ b/optimum/habana/transformers/models/xglm/modeling_xglm.py
@@ -109,8 +109,7 @@ def gaudi_xglm_attention_forward(
     if layer_head_mask is not None:
         if layer_head_mask.size() != (self.num_heads,):
             raise ValueError(
-                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                f" {layer_head_mask.size()}"
+                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
             )
         attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -300,7 +299,7 @@ def gaudi_xglm_model_forward(
     if self.gradient_checkpointing and self.training:
         if use_cache:
             logger.warning_once(
-                "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache =" " False`..."
+                "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
             )
             use_cache = False
 
diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index ec7d31e3a6..9c44426559 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -1586,9 +1586,9 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te
             self.htcore.mark_step()
 
         if _is_peft_model(self.model) and self.model.peft_type == PeftType.ADALORA:
-            assert not (
-                self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing
-            ), "FP8 precision with gradient_checkpointing is currently not supported with PeftType.ADALORA"
+            assert not (self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing), (
+                "FP8 precision with gradient_checkpointing is currently not supported with PeftType.ADALORA"
+            )
             if self.is_deepspeed_enabled and not is_deepspeed_zero3_enabled():
                 self.accelerator.deepspeed_engine_wrapped.engine.backward(loss)
                 self.model.base_model.update_and_allocate(self.state.global_step)
diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py
index bd07a981bb..84c48f1782 100644
--- a/optimum/habana/trl/trainer/dpo_trainer.py
+++ b/optimum/habana/trl/trainer/dpo_trainer.py
@@ -167,8 +167,7 @@ def __init__(
 
         if isinstance(ref_model, str):
             warnings.warn(
-                "You passed a ref model_id to the DPOTrainer. This will automatically create an "
-                "`AutoModelForCausalLM`"
+                "You passed a ref model_id to the DPOTrainer. This will automatically create an `AutoModelForCausalLM`"
             )
             ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
 
diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py
index 04e648a161..6fb6365655 100644
--- a/optimum/habana/trl/trainer/sft_trainer.py
+++ b/optimum/habana/trl/trainer/sft_trainer.py
@@ -133,9 +133,9 @@ def __init__(
         - num_buckets: Number of buckets. > 0 means apply bucketing, <= 0  means no bucketing
         """
         if num_buckets > 0:
-            assert (
-                data_collator is None
-            ), "For bucketing (num_buckets > 0), we only support data_collator=None (later it becomes DataCollatorForLanguageModeling)"
+            assert data_collator is None, (
+                "For bucketing (num_buckets > 0), we only support data_collator=None (later it becomes DataCollatorForLanguageModeling)"
+            )
         if args is None:
             output_dir = "tmp_trainer"
             warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.")
diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py
index 03663b7fc8..b26878551a 100755
--- a/tests/test_diffusers.py
+++ b/tests/test_diffusers.py
@@ -1616,15 +1616,15 @@ def test_fused_qkv_projections(self):
         image = pipe(**inputs).images
         image_slice_disabled = image[0, -3:, -3:, -1]
 
-        assert np.allclose(
-            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
-        ), "Fusion of QKV projections shouldn't affect the outputs."
-        assert np.allclose(
-            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
-        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
-        assert np.allclose(
-            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
-        ), "Original outputs should match when fused QKV projections are disabled."
+        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+            "Fusion of QKV projections shouldn't affect the outputs."
+        )
+        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        )
+        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+            "Original outputs should match when fused QKV projections are disabled."
+        )
 
 
 class GaudiStableDiffusionControlNetPipelineTester(TestCase):
@@ -2536,7 +2536,7 @@ def test_train_controlnet(self):
 
             cmd_line = f"""
                     python3
-                    {path_to_script.parent.parent.parent / 'gaudi_spawn.py'}
+                    {path_to_script.parent.parent.parent / "gaudi_spawn.py"}
                     --use_mpi
                     --world_size 8
                     {path_to_script}
@@ -2624,7 +2624,7 @@ def _test_dreambooth(self, extra_config, train_text_encoder=False):
                 python3
                 {path_to_script}
                 --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
-                --instance_data_dir {Path(os.path.dirname(__file__))/'resource/img'}
+                --instance_data_dir {Path(os.path.dirname(__file__)) / "resource/img"}
                 --resolution 64
                 --train_batch_size 1
                 --gradient_accumulation_steps 1
@@ -2720,7 +2720,7 @@ def _test_dreambooth_lora_sdxl(self, train_text_encoder=False):
                 python3
                 {path_to_script}
                 --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
-                --instance_data_dir {Path(os.path.dirname(__file__))/'resource/img'}
+                --instance_data_dir {Path(os.path.dirname(__file__)) / "resource/img"}
                 --resolution 64
                 --train_batch_size 1
                 --gradient_accumulation_steps 1
@@ -5939,9 +5939,9 @@ def new_step(self, *args, **kwargs):
             inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
             latents = pipe_1(**inputs_1).images[0]
 
-            assert (
-                expected_steps_1 == done_steps
-            ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+            assert expected_steps_1 == done_steps, (
+                f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+            )
 
             inputs_2 = {
                 **inputs,
@@ -5955,9 +5955,9 @@ def new_step(self, *args, **kwargs):
             pipe_3(**inputs_3).images[0]
 
             assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
-            assert (
-                expected_steps == done_steps
-            ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+            assert expected_steps == done_steps, (
+                f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+            )
 
         for steps in [7, 11, 20]:
             for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
diff --git a/tests/test_encoder_decoder.py b/tests/test_encoder_decoder.py
index 20d808b69f..723739eb5b 100644
--- a/tests/test_encoder_decoder.py
+++ b/tests/test_encoder_decoder.py
@@ -189,7 +189,7 @@ def _test_text_translation(
             "--do_predict",
             "--source_lang en",
             "--target_lang ro",
-            '--source_prefix "translate English to Romanian: "' "--dataset_name wmt16",
+            '--source_prefix "translate English to Romanian: "--dataset_name wmt16',
             "--dataset_config_name ro-en",
             f"--per_device_eval_batch_size {batch_size}",
             f"--generation_num_beams {num_beams}",
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 9358c9f572..3e4c9fb7c5 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -239,7 +239,6 @@ def _test_text_generation(
             os.rename(file_kernel, new_file_kernel)
 
         realpath_kfn = os.path.realpath(new_file_kernel)
-        kfn = os.path.basename(realpath_kfn)
         new_kfn = os.path.join(os.path.dirname(realpath_kfn), "libcustom_tpc_perf_lib.so")
         os.rename(realpath_kfn, new_kfn)
 
@@ -391,9 +390,9 @@ def _test_text_generation(
 
         # Verify output for 1 HPU, BF16
         if check_output:
-            assert (
-                model_name in MODEL_OUTPUTS
-            ), f"Failed functional testing, missing expected output in MODEL_OUTPUTS for model {model_name}"
+            assert model_name in MODEL_OUTPUTS, (
+                f"Failed functional testing, missing expected output in MODEL_OUTPUTS for model {model_name}"
+            )
             expected_output = MODEL_OUTPUTS[model_name]
             assert results["output"][0][0] == expected_output
 
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index eddb82b500..61ff958477 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -540,7 +540,7 @@ def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True):
         keys = list(state_dict.keys())
 
         shard_files = [
-            shard_name.replace(f".{extension}", f"-{idx+1:05d}-of-{len(keys):05d}.{extension}")
+            shard_name.replace(f".{extension}", f"-{idx + 1:05d}-of-{len(keys):05d}.{extension}")
             for idx in range(len(keys))
         ]
         index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}
@@ -1698,9 +1698,9 @@ def test_load_best_model_with_save(self):
             )
             trainer.train()
             # Check that we have the last known step:
-            assert os.path.exists(
-                os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")
-            ), f"Could not find checkpoint-{trainer.state.max_steps}"
+            assert os.path.exists(os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")), (
+                f"Could not find checkpoint-{trainer.state.max_steps}"
+            )
             # And then check the last step
             assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9"
 
diff --git a/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py b/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py
index eae4e5571a..b479f2b237 100644
--- a/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py
+++ b/tests/transformers/tests/models/gpt2/test_modeling_gpt2.py
@@ -392,9 +392,9 @@ def create_and_check_cached_forward_with_and_without_attention_mask(self, config
         model.eval()
 
         # We want this for SDPA, eager works with a `None` attention mask
-        assert (
-            model.config._attn_implementation == "sdpa"
-        ), "This test assumes the model to have the SDPA implementation for its attention calculations."
+        assert model.config._attn_implementation == "sdpa", (
+            "This test assumes the model to have the SDPA implementation for its attention calculations."
+        )
 
         # Prepare cache and non_cache input, needs a full attention mask
         cached_len = input_ids.shape[-1] // 2
diff --git a/tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py
index 14561c2080..5026ff87d8 100644
--- a/tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py
+++ b/tests/transformers/tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -213,9 +213,9 @@ def create_and_check_cached_forward_with_and_without_attention_mask(self, config
         model.to(torch_device)
         model.eval()
         # We want this for SDPA, eager works with a `None` attention mask
-        assert (
-            model.config._attn_implementation == "sdpa"
-        ), "This test assumes the model to have the SDPA implementation for its attention calculations."
+        assert model.config._attn_implementation == "sdpa", (
+            "This test assumes the model to have the SDPA implementation for its attention calculations."
+        )
         # Prepare cache and non_cache input, needs a full attention mask
         cached_len = input_ids.shape[-1] // 2
         input_mask = torch.ones(size=input_ids.size()).to(torch_device)
diff --git a/tests/transformers/tests/test_modeling_common.py b/tests/transformers/tests/test_modeling_common.py
index e08860278b..55c7aa8dae 100755
--- a/tests/transformers/tests/test_modeling_common.py
+++ b/tests/transformers/tests/test_modeling_common.py
@@ -2261,9 +2261,9 @@ def test_model_is_small(self):
         for model_class in self.all_model_classes:
             model = model_class(config)
             num_params = model.num_parameters()
-            assert (
-                num_params < 1000000
-            ), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
+            assert num_params < 1000000, (
+                f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
+            )
 
 
 global_rng = random.Random()

From 1f481abd689c43f5aea91e1bafa6a6b4acebf072 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhang <zhzhang@habana.ai>
Date: Thu, 9 Jan 2025 23:24:53 +0000
Subject: [PATCH 3/3] Added Mamba model using a kernel to improve performance

---
 optimum/habana/transformers/models/mamba/modeling_mamba.py | 3 +++
 tests/test_text_generation_example.py                      | 6 ++++--
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py
index 82ad159566..945c49dcff 100644
--- a/optimum/habana/transformers/models/mamba/modeling_mamba.py
+++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -14,6 +14,7 @@
 
 from .util_mamba import set_mamba_lib
 
+
 env_variables = os.environ.copy()
 new_file_op, new_file_kernel = set_mamba_lib()
 realpath_kfn = os.path.realpath(new_file_kernel)
@@ -39,6 +40,7 @@
 if os.path.exists(custom_op_lib_path) and default_path != orig_path:
     use_pscan_kernel = True
 
+
 def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):
     in_state_h = in_state.unsqueeze(1).transpose(2, 3)
     in_x_h = in_x.transpose(1, 2).unsqueeze(2)
@@ -163,6 +165,7 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
     )
     return model_inputs
 
+
 class gaudi_MambaMixer(nn.Module):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 3e4c9fb7c5..14df7655d3 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -222,8 +222,10 @@ def _test_text_generation(
         command += ["--sdp_on_bf16"]
 
     if "mamba" in model_name.lower():
-        from optimum.habana.utils import  get_habana_frameworks_version
         from huggingface_hub import hf_hub_download
+
+        from optimum.habana.utils import get_habana_frameworks_version
+
         version_no = get_habana_frameworks_version()
 
         name_kernel = "libcustom_tpc_perf_lib.so"
@@ -244,7 +246,7 @@ def _test_text_generation(
 
         default_path = env_variables["GC_KERNEL_PATH"]
         env_variables["GC_KERNEL_PATH"] = new_kfn + os.pathsep + default_path
-        
+
     if (reuse_cache or torch_compile) and not parallel_strategy == "tp" and not is_starcoder_first_gen_model:
         command += ["--reuse_cache"]