From 190bd8e3c0c056667b1ed3e816a0ceb4c09d1b0a Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Thu, 9 Jan 2025 03:24:18 -0800
Subject: [PATCH] only compare the recipe in AttentionParams.fp8_meta

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 transformer_engine/pytorch/attention.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3f0267affb..529a4c57b4 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -303,6 +303,24 @@ class AttentionParams:
     fp8: bool = False
     fp8_meta: Union[Dict[str, Any], None] = None
 
+    def __eq__(self, other):
+        """
+        Overwrite dataclass.__eq__ so that only fp8_meta["recipe"] is compared,
+        since all other entries of fp8_meta are unused in get_attention_backend.
+        """
+        if not isinstance(other, self.__class__):
+            return NotImplemented
+        for field in fields(self):
+            fname = field.name
+            sf = getattr(self, fname)
+            of = getattr(other, fname)
+            if fname != "fp8_meta":
+                if sf != of:
+                    return False
+            elif sf["recipe"] != of["recipe"]:
+                return False
+        return True
+
 
 _alibi_cache = {
     "_num_heads": None,
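
A minimal standalone sketch of the comparison logic this patch adds, using a
hypothetical _ParamsSketch stand-in for AttentionParams (the real class has
many more fields, and its fp8_meta["recipe"] holds a recipe object rather than
a string); this is illustrative only, not part of the patch:

from dataclasses import dataclass, fields
from typing import Any, Dict, Union


@dataclass
class _ParamsSketch:
    """Hypothetical stand-in for AttentionParams, reduced to a few fields."""

    num_heads: int = 16
    fp8: bool = False
    fp8_meta: Union[Dict[str, Any], None] = None

    def __eq__(self, other):
        # Same logic as the patch: every field is compared directly except
        # fp8_meta, which is compared only by its "recipe" entry. Note that
        # @dataclass does not generate __eq__ when the class defines its own.
        if not isinstance(other, self.__class__):
            return NotImplemented
        for field in fields(self):
            fname = field.name
            sf = getattr(self, fname)
            of = getattr(other, fname)
            if fname != "fp8_meta":
                if sf != of:
                    return False
            elif sf["recipe"] != of["recipe"]:
                return False
        return True


# fp8_meta dicts that differ in every entry except "recipe" compare equal.
# As written in the patch, both fp8_meta values must be dicts carrying a
# "recipe" key: indexing a None fp8_meta would raise a TypeError.
a = _ParamsSketch(fp8=True, fp8_meta={"recipe": "delayed", "amax_history": [1.0]})
b = _ParamsSketch(fp8=True, fp8_meta={"recipe": "delayed", "amax_history": [2.0]})
c = _ParamsSketch(fp8=True, fp8_meta={"recipe": "current", "amax_history": [1.0]})
print(a == b)  # True: recipes match, other fp8_meta entries are ignored
print(a == c)  # False: recipes differ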