ModernBERT export to onnx error #35545
Comments
Hi there 👋 Could you try using this branch of Optimum: huggingface/optimum#2131? We used it to create these ONNX exports: https://huggingface.co/answerdotai/ModernBERT-base/tree/main/onnx

If you'd prefer to use it in your own custom scripts, you can adapt this context manager and use it when loading the model:

```python
import torch

class DisableCompileContextManager:
    def __init__(self):
        self._original_compile = torch.compile

    def __enter__(self):
        # Turn torch.compile into a no-op: `@torch.compile(...)` now returns an
        # identity decorator, so decorated functions are left uncompiled
        torch.compile = lambda *args, **kwargs: lambda x: x

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.compile = self._original_compile
```
I used this branch of Optimum (huggingface/optimum#2131) and ran `optimum-cli export onnx -m checkpoints/ --task text-classification classify_model`, but got the same error. @xenova
My package versions:
I am also facing the same error.
I found the problem, and here is my final export script. Flash Attention 2.0 recomputes memory access patterns and partitioning, and the ONNX exporter could not work out an export mapping for that, so it reported an error. With standard (non-Flash-Attention-2.0) attention, the export worked! @DeepakSinghRawat @xenova
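The final script itself is not included here. A minimal sketch of an export along these lines follows, assuming transformers ≥ 4.48, where ModernBERT accepts `attn_implementation="eager"` and, as we assume, a `reference_compile` config flag that turns off its `torch.compile` path. `export_modernbert_onnx` is a hypothetical helper, choosing `"eager"` (rather than another non-Flash-Attention backend) is our own choice, and `opset_version=14` is an assumption you may need to raise.

```python
def export_modernbert_onnx(checkpoint: str, output_path: str = "model.onnx") -> None:
    """Hypothetical helper: export a ModernBERT sequence classifier to ONNX
    without Flash Attention 2.0 and without torch.compile."""
    # Imported lazily so that merely importing this file does not require torch.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        attn_implementation="eager",  # standard attention instead of flash_attention_2
        reference_compile=False,  # assumption: ModernBertConfig flag that skips torch.compile
    )
    model.eval()

    # Dummy input, used only to trace the graph.
    inputs = tokenizer("ONNX export example", return_tensors="pt")

    with torch.no_grad():
        torch.onnx.export(
            model,
            (inputs["input_ids"], inputs["attention_mask"]),
            output_path,
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            # Let batch size and sequence length vary at inference time.
            dynamic_axes={
                "input_ids": {0: "batch", 1: "sequence"},
                "attention_mask": {0: "batch", 1: "sequence"},
                "logits": {0: "batch"},
            },
            opset_version=14,
        )
```

For a checkpoint like the one in the `optimum-cli` command earlier in this thread, this would be called as `export_modernbert_onnx("checkpoints/", "classify_model/model.onnx")`.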
System Info
transformers version: 4.48.0.dev0
Who can help?
@ArthurZucker
Reproduction
After training a classification model based on ModernBERT, I tried to export it to ONNX with the following script.
I got errors. This may be related to pytorch/pytorch#104748.
https://huggingface.co/answerdotai/ModernBERT-base/discussions/10
After reading this post, I modified part of the code as follows.
I got another error.
Expected behavior
The model exports successfully to model.onnx.