Question About register_cross_attention_hook and replace_call_method_for_sd3 in Attention Map Visualization #14

Open
Passenger12138 opened this issue Dec 27, 2024 · 0 comments


Hello,

First of all, thank you for your excellent work on visualizing attention maps for DiT (Diffusion Transformer). I am currently extending your approach to visualize attention maps for video-based DiT models.

While going through the source code, I encountered the following snippet:

pipeline.transformer = register_cross_attention_hook(pipeline.transformer, hook_function, 'attn')  
pipeline.transformer = replace_call_method_for_sd3(pipeline.transformer)  

I understand that the register_cross_attention_hook function is used to define a hook to capture the attention map during the forward pass. However, I am confused about the necessity of the second line, replace_call_method_for_sd3.

As I understand it, the second line replaces the forward method of SD3Transformer2DModel and its submodules. Yet the code does not seem to define a custom forward process for SD3Transformer2DModel, and the original attention computation appears to be sufficient for capturing the attention map.
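To make concrete what I mean by "replacing the forward method", here is a minimal, dependency-free sketch. The names ToyAttention, forward_with_capture, and replace_forward are hypothetical and are not taken from the repository. The sketch only assumes that such a replacement might exist to expose an intermediate value that the original forward discards; I have not confirmed that this is what replace_call_method_for_sd3 does.

```python
import types


class ToyAttention:
    """Toy stand-in for an attention module (not the real SD3 class)."""

    def forward(self, scores):
        # Normalize the scores into weights, then return only the
        # weighted sum. The intermediate weights are thrown away.
        total = sum(scores)
        weights = [s / total for s in scores]
        return sum(w * s for w, s in zip(weights, scores))


def forward_with_capture(self, scores):
    """Replacement forward: same output, but keeps the weights on the module."""
    total = sum(scores)
    weights = [s / total for s in scores]
    self.captured_weights = weights  # a hook could read this afterwards
    return sum(w * s for w, s in zip(weights, scores))


def replace_forward(module):
    """Hypothetical helper: bind the replacement forward to one instance."""
    module.forward = types.MethodType(forward_with_capture, module)
    return module


attn = ToyAttention()
before = attn.forward([1.0, 3.0])   # original forward
attn = replace_forward(attn)
after = attn.forward([1.0, 3.0])    # replaced forward, same numeric result
```

In this toy case the output is unchanged by the replacement; the only difference is that the intermediate weights become readable on the module afterwards. If the repository's replacement serves a different purpose, that is exactly what I would like to understand.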

Could you please explain:

  1. Why is replace_call_method_for_sd3 necessary in this context?
  2. If the forward process is not altered, what specific purpose does this replacement serve?

Any clarification or suggestions on this would be greatly appreciated. Thank you again for your work and support!

Best regards,
