ModernBERT export to onnx error #35545

Closed
2 of 4 tasks
wakaka6 opened this issue Jan 7, 2025 · 4 comments

Comments

wakaka6 commented Jan 7, 2025

System Info

  • transformers version: 4.48.0.dev0
  • Platform: Linux-5.15.0-84-generic-x86_64-with-glibc2.35
  • Python version: 3.11.11
  • Huggingface_hub version: 0.27.0
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 4090

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

After training a classification model based on ModernBERT, I tried to export it to ONNX with the following script.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def export():

    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base", model_max_length=4096)
    model = AutoModelForSequenceClassification.from_pretrained(
            "./checkpoints",
            num_labels=3,
            # reference_compile=False,
            )

    model.eval()

    samples = ['examples']

    tokenized = tokenizer(samples,
            return_tensors='pt',
            max_length=tokenizer.model_max_length,
            padding='max_length',
            truncation=True)
    input_ids = tokenized['input_ids'].to('cuda')
    attention_mask = tokenized['attention_mask'].to('cuda')
    model = model.to('cuda')

    with torch.no_grad():
        torch.onnx.export(
                model,
                (input_ids, attention_mask),
                './model.onnx',
                input_names=["input_ids", "attention_mask"],
                output_names=["logits"],
                )

if __name__ == '__main__':
    export()

I got the following errors. This may be related to pytorch/pytorch#104748.

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in ModernBertForSequenceClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py:711: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_seqlen_in_batch = int(seqlens_in_batch.max().item())
Traceback (most recent call last):
  File "/modernBERT/export_onnx.py", line 39, in <module>
    export()
  File "/modernBERT/export_onnx.py", line 28, in export
    torch.onnx.export(
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/jit/_trace.py", line 139, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/jit/_trace.py", line 130, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 1160, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 895, in forward
    hidden_states = self.embeddings(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 210, in forward
    self.compiled_embeddings(input_ids)
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 444, in _fn
    raise RuntimeError(
RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function. This is not supported at the moment.

https://huggingface.co/answerdotai/ModernBERT-base/discussions/10
After reading this post, I modified part of the code as follows.

    model = AutoModelForSequenceClassification.from_pretrained(
            "./checkpoints",
            num_labels=3,
            reference_compile=False,
            )

I got another error.

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in ModernBertForSequenceClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py:711: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_seqlen_in_batch = int(seqlens_in_batch.max().item())
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:166: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert sin.shape == cos.shape
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:168: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert headdim <= 256, "Only support headdim <= 256"
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:170: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:185: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert seqlen_offsets + seqlen <= seqlen_ro
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:188: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if rotary_dim < headdim and not inplace:
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:193: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if rotary_dim <= 32
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:194: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py:197: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4)
Traceback (most recent call last):
  File "/modernBERT/export_onnx.py", line 39, in <module>
    export()
  File "/modernBERT/export_onnx.py", line 28, in export
    torch.onnx.export(
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/jit/_trace.py", line 139, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/jit/_trace.py", line 130, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 1160, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 913, in forward
    layer_outputs = encoder_layer(
                    ^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 529, in forward
    attn_outputs = self.attn(
                   ^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 487, in forward
    attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 349, in flash_attention_forward
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 178, in forward
    qkv = apply_rotary_unpadded(
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 136, in apply_rotary_unpadded
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 75, in forward
    apply_rotary(
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/flash_attn/ops/triton/rotary.py", line 202, in apply_rotary
    rotary_kernel[grid](
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
             ^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/miniconda3/envs/bert/lib/python3.11/site-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 32:22:
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2
                      ^
IncompatibleTypeErrorImpl('invalid operands of type pointer<int64> and triton.language.int32')

Expected behavior

The model is exported successfully to model.onnx.

@wakaka6 wakaka6 added the bug label Jan 7, 2025
xenova (Contributor) commented Jan 7, 2025

Hi there 👋 Could you try using this branch of Optimum: huggingface/optimum#2131?

We used that to create these ONNX exports: https://huggingface.co/answerdotai/ModernBERT-base/tree/main/onnx

If you'd prefer to use it in your own custom scripts, you can adapt this context manager and use it when loading the model:

class DisableCompileContextManager:
    def __init__(self):
        self._original_compile = torch.compile

    def __enter__(self):
        # Turn torch.compile into a no-op
        torch.compile = lambda *args, **kwargs: lambda x: x

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.compile = self._original_compile
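
For example, here is a minimal usage sketch (the checkpoint path and num_labels are placeholders, not taken from this issue): wrapping the model load in the context manager makes torch.compile a no-op while ModernBERT initializes, so torch.jit.trace never hits a dynamo-optimized function (the RuntimeError shown above).

with DisableCompileContextManager():
    # placeholder path and label count; adjust to your own fine-tuned checkpoint
    model = AutoModelForSequenceClassification.from_pretrained(
        "./checkpoints",
        num_labels=3,
    )
    model.eval()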

wakaka6 (Author) commented Jan 8, 2025

I used this branch of Optimum, huggingface/optimum#2131, and ran:

optimum-cli export onnx -m checkpoints/ --task text-classification classify_model

I got the same error. @xenova

triton.compiler.errors.CompilationError: at 32:22:
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2
                      ^
IncompatibleTypeErrorImpl('invalid operands of type pointer<int64> and triton.language.int32')

My installed packages:

accelerate                1.2.1
aiohappyeyeballs          2.4.4
aiohttp                   3.11.11
aiosignal                 1.3.2
anyio                     4.7.0
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 3.0.0
async-lru                 2.0.4
attrs                     24.3.0
babel                     2.16.0
beautifulsoup4            4.12.3
bleach                    6.2.0
Brotli                    1.0.9
certifi                   2024.12.14
cffi                      1.17.1
charset-normalizer        3.4.1
coloredlogs               15.0.1
comm                      0.2.2
contourpy                 1.3.1
cycler                    0.12.1
datasets                  3.2.0
debugpy                   1.8.11
decorator                 5.1.1
defusedxml                0.7.1
dill                      0.3.8
einops                    0.8.0
evaluate                  0.4.3
executing                 2.1.0
fastjsonschema            2.21.1
filelock                  3.16.1
flash-attn                2.7.2.post1
flatbuffers               24.12.23
fonttools                 4.55.3
fqdn                      1.5.1
frozenlist                1.5.0
fsspec                    2024.9.0
gmpy2                     2.1.2
h11                       0.14.0
httpcore                  1.0.7
httpx                     0.28.1
huggingface-hub           0.27.1
humanfriendly             10.0
idna                      3.10
ipykernel                 6.29.5
ipython                   8.31.0
isoduration               20.11.0
jedi                      0.19.2
Jinja2                    3.1.5
joblib                    1.4.2
json5                     0.10.0
jsonpointer               3.0.0
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter_client            8.6.3
jupyter_core              5.7.2
jupyter-events            0.11.0
jupyter-lsp               2.2.5
jupyter_server            2.15.0
jupyter_server_terminals  0.5.3
jupyterlab                4.3.4
jupyterlab_pygments       0.3.0
jupyterlab_server         2.27.3
kiwisolver                1.4.8
MarkupSafe                3.0.2
matplotlib                3.10.0
matplotlib-inline         0.1.7
mistune                   3.1.0
mkl_fft                   1.3.11
mkl_random                1.2.8
mkl-service               2.4.0
mpmath                    1.3.0
multidict                 6.1.0
multiprocess              0.70.16
nbclient                  0.10.2
nbconvert                 7.16.4
nbformat                  5.10.4
nest-asyncio              1.6.0
networkx                  3.4.2
notebook_shim             0.2.4
numpy                     2.2.1
onnx                      1.17.0
onnxruntime               1.20.1
optimum                   1.24.0.dev0
overrides                 7.7.0
packaging                 24.2
pandas                    2.2.3
pandocfilters             1.5.1
parso                     0.8.4
pexpect                   4.9.0
pillow                    11.0.0
pip                       24.2
platformdirs              4.3.6
prometheus_client         0.21.1
prompt_toolkit            3.0.48
propcache                 0.2.1
protobuf                  5.29.2
psutil                    6.1.1
ptyprocess                0.7.0
pure_eval                 0.2.3
pyarrow                   18.1.0
pycparser                 2.22
Pygments                  2.18.0
pyparsing                 3.2.1
PySocks                   1.7.1
python-dateutil           2.9.0.post0
python-json-logger        3.2.1
pytz                      2024.2
PyYAML                    6.0.2
pyzmq                     26.2.0
referencing               0.35.1
regex                     2024.11.6
requests                  2.32.3
rfc3339-validator         0.1.4
rfc3986-validator         0.1.1
rpds-py                   0.22.3
safetensors               0.5.1
scikit-learn              1.6.0
scipy                     1.15.0
Send2Trash                1.8.3
setuptools                75.1.0
six                       1.17.0
sniffio                   1.3.1
soupsieve                 2.6
stack-data                0.6.3
sympy                     1.13.1
terminado                 0.18.1
threadpoolctl             3.5.0
tinycss2                  1.4.0
tokenizers                0.21.0
torch                     2.5.1
torchaudio                2.5.1
torchvision               0.20.1
tornado                   6.4.2
tqdm                      4.67.1
traitlets                 5.14.3
transformers              4.48.0.dev0
triton                    3.1.0
types-python-dateutil     2.9.0.20241206
typing_extensions         4.12.2
tzdata                    2024.2
uri-template              1.3.0
urllib3                   2.3.0
wcwidth                   0.2.13
webcolors                 24.11.1
webencodings              0.5.1
websocket-client          1.8.0
wheel                     0.44.0
xxhash                    3.5.0
yarl                      1.18.3

@DeepakSinghRawat

I am also facing the same error, IncompatibleTypeErrorImpl('invalid operands of type pointer<int64> and triton.language.int32'), when using huggingface/optimum#2131.

wakaka6 (Author) commented Jan 14, 2025

I found the problem; here is my final export script. Flash Attention 2.0's recalculated memory-access patterns and partitioning policy make the ONNX exporter fail because it cannot compute the export mapping, so switching to the standard attention implementation allows the export to succeed! @DeepakSinghRawat @xenova

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class DisableCompileContextManager:
    def __init__(self):
        self._original_compile = torch.compile

    def __enter__(self):
        # Turn torch.compile into a no-op
        torch.compile = lambda *args, **kwargs: lambda x: x

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.compile = self._original_compile

def export():

    with DisableCompileContextManager():
        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base", model_max_length=4096)
        model = AutoModelForSequenceClassification.from_pretrained(
                "./checkpoints",
                num_labels=3,
                # Flash Attention 2.0's recalculated memory-access pattern and partitioning strategy
                # cause the ONNX exporter to fail because it cannot compute the export mapping
                attn_implementation="eager",  # use the standard attention implementation
                # reference_compile=False, # disable triton compile
                )

        model.eval()

        samples = ['example']

        tokenized = tokenizer(samples,
                return_tensors='pt',
                max_length=tokenizer.model_max_length,
                padding='max_length',
                truncation=True)
        input_ids = tokenized['input_ids'].to('cuda')
        attention_mask = tokenized['attention_mask'].to('cuda')
        model = model.to('cuda')

        with torch.no_grad():
            torch.onnx.export(
                    model,
                    (input_ids, attention_mask),
                    './model.onnx',
                    input_names=["input_ids", "attention_mask"],
                    output_names=["logits"],
                    )



if __name__ == '__main__':
    export()
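
As a quick sanity check of the exported file, here is a minimal sketch (assuming onnxruntime is installed, the export above produced ./model.onnx, and input_ids, attention_mask, and model from the script are still in scope) that runs the ONNX graph on the same inputs and compares its logits against the PyTorch output:

import numpy as np
import onnxruntime as ort

# load the exported graph (path matches the torch.onnx.export call above)
session = ort.InferenceSession('./model.onnx', providers=['CPUExecutionProvider'])

# feed the same tokenized inputs, converted to numpy on CPU
onnx_logits = session.run(
    ['logits'],
    {
        'input_ids': input_ids.cpu().numpy(),
        'attention_mask': attention_mask.cpu().numpy(),
    },
)[0]

# reference logits from the PyTorch model
with torch.no_grad():
    torch_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

# tolerance is illustrative; small numerical differences between runtimes are expected
print(np.allclose(onnx_logits, torch_logits.cpu().numpy(), atol=1e-4))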
