
Support for Flux model in diffusers #763

Open · JingyaHuang opened this issue Jan 15, 2025 · 3 comments

JingyaHuang (Collaborator) commented Jan 15, 2025

Feature request

Hey @yahavb, to continue our discussion on supporting the Flux model here.

1st step: Export support

Add export support for each component

  • Text encoder 1

No need to add a wrapper; this should already be handled by CLIPTextNeuronConfig.

  • Text encoder 2

We can reuse the T5EncoderForDiffusersNeuronConfig class, in which the T5EncoderWrapper is equivalent to TracingT5TextEncoderWrapper.
For tensor parallelism (tp) support, we will need to add a tensor_parallel_size argument to several functions in optimum/exporters/neuron/__main__.py (e.g. _get_submodels_and_neuron_configs_for_stable_diffusion), and write a get_parallel_callable for the neuron config class to shard the model; the one in T5EncoderForTransformersNeuronConfig can serve as a reference.

  • Transformer

In the script, transformer_embedders / transformer_blocks / single_transformer_blocks / transformer_out_layers are traced separately. Would it be possible to fit on Neuron devices while tracing the whole transformer together? If not, then in the Optimum Neuron logic each traced component corresponds to a NeuronConfig class, so you will need to create a class for each component and build a function to fetch each module (similar to _get_submodels_and_neuron_configs_for_stable_diffusion). Otherwise, you can create a single FluxTransformerNeuronConfig and add the tp config there (a rough skeleton is sketched after this list).

  • VAE decoder

Nothing to do, already supported
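
For the single-config option mentioned under Transformer, a rough skeleton is shown below. The class name FluxTransformerNeuronConfig does not exist yet, and the base class, attributes, and input/output names are assumptions modelled on the existing diffusion configs in optimum/exporters/neuron/model_configs, not a definitive implementation.

```python
# Speculative skeleton only: FluxTransformerNeuronConfig does not exist yet in
# optimum-neuron; the base class and attribute names below are assumptions based
# on how the existing UNet/transformer diffusion configs are structured.
from typing import Dict, List

from optimum.exporters.neuron.config import VisionNeuronConfig  # assumed base class


class FluxTransformerNeuronConfig(VisionNeuronConfig):
    # A single config for the whole transformer; the tp setup would live here so
    # the model is traced as one module instead of four separate ones.
    MODEL_TYPE = "flux-transformer"

    @property
    def inputs(self) -> List[str]:
        # Input names are placeholders; they must match the dummy input generators.
        return ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"]

    @property
    def outputs(self) -> List[str]:
        return ["out_hidden_states"]

    def patch_model_for_export(self, model, dummy_inputs: Dict):
        # When tensor_parallel_size > 1, return the callable that loads and shards
        # the transformer inside each tp worker (see the T5 encoder discussion
        # below); otherwise return a plain traced wrapper around `model`.
        ...
```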

Test

Here is an example of testing the compilation:

optimum-cli export neuron --model black-forest-labs/FLUX.1-dev --tensor_parallel_size 8 --batch_size 1 --height 1024 --width 1024 --num_images_per_prompt 1 --sequence_length 512 --torch_dtype bfloat16 flux_neuron/

2nd step: Inference support

Whether we need to override some functions of FluxPipeline depends on the 1st step (i.e. whether we trace the transformer as multiple modules); in any case, the overrides go under optimum/neuron/pipelines/diffusers. We will also need to add the following class to optimum/neuron/modeling_diffusion.py:

class NeuronFluxPipeline(NeuronDiffusionPipelineBase, FluxPipeline):
    main_input_name = "prompt"
    auto_model_class = FluxPipeline
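
Once that class exists, loading and running the exported artifacts should follow the same pattern as the existing Neuron diffusion pipelines. The snippet below is a hedged usage sketch, not a final API: NeuronFluxPipeline is not implemented yet, and its import location is assumed.

```python
# Hypothetical usage sketch: NeuronFluxPipeline does not exist yet; the loading
# path mirrors the existing Neuron diffusion pipelines in optimum-neuron.
from optimum.neuron import NeuronFluxPipeline  # assumed export location

# Directory produced by the optimum-cli command shown in the Test section above.
pipe = NeuronFluxPipeline.from_pretrained("flux_neuron/")
image = pipe(prompt="A cat holding a sign that says hello world").images[0]
image.save("flux_neuron.png")
```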

Reference PRs:

Motivation

Add Flux support

Your contribution

Will collaborate with @yahavb on this.

@JingyaHuang JingyaHuang self-assigned this Jan 15, 2025
yahavb (Contributor) commented Jan 21, 2025

I started working on the T5 encoder and need help with the expected call flow for model export (not inference yet). Changes are in https://github.com/yahavb/optimum-neuron/tree/main/optimum/exporters/neuron. I can't figure out how to add the sharding functions into the flow. Here are the steps I followed (and the changes I made):

  1. Added the tp size in _get_submodels_and_neuron_configs_for_stable_diffusion and passed it to T5EncoderForDiffusersNeuronConfig.
  2. In T5EncoderForDiffusersNeuronConfig I added create_optimized_model, which calls a new wrapper, and get_parallel_callable, which does the same thing.
  3. I added the T5EncoderWrapperWithTP wrapper in model_wrappers.py, and text_encoder_sharding.py with all the sharding functions.

I can't figure out how to trigger the call to create_optimized_model, so I would appreciate any hint on whether I'm going in the right direction.

yahavb (Contributor) commented Jan 22, 2025

I also see that optimum/exporters/neuron/model_configs/traced_configs.py already includes tensor_parallel_size and calls neuronx_distributed.parallel_layers.load when tp > 1. Isn't this the flow you meant for Text encoder 2 above?

JingyaHuang (Collaborator, Author) commented

Hi @yahavb, as you mentioned, the attributes and functions related to the sharding of the T5 encoder should live in the NeuronConfig (e.g. tensor_parallel_size and create_optimized_model). The idea is that everything model-specific should be within the Neuron config dedicated to that model (I saw you put the T5 config in optimum/exporters/neuron/__main__.py; there we only put functions that are general to all models).

To see where the sharding is done, let's take the T5 encoder as an example:

  • 1st step: load the vanilla model and define its neuron config (model.neuron_config)

[In __main__.py] main_export -> load_models_and_neuron_configs -> get_submodels_and_neuron_configs -> _get_submodels_and_neuron_configs_for_stable_diffusion -> [In utils.py] get_diffusion_models_for_export (you shall add tensor_parallel_size as an argument, as in get_encoder_decoder_models_for_export) -> instantiate the neuron config as text_encoder_config_constructor.

Until here, the model is not yet sharded, but we have the elements we need for the sharding within the neuron config of the T5 encoder (tensor_parallel_size; the wrapper; patch_model_for_export, which returns the callable sent to neuronx_distributed.trace.parallel_model_trace; and generate_io_aliases, which we don't need for Flux).

  • 2nd step: send the callable defining how to shard the model to neuronx_distributed.trace.parallel_model_trace during the export

[In __main__.py] main_export -> [In convert.py] export_neuronx, in which we call config.patch_model_for_export (what we defined in the 1st step) before sending the result to trace.parallel_model_trace (a simplified sketch of this hand-off is below).

Besides, for sharding the model we actually call ParallelizersManager.parallelizer_for_model, which is shared with the training...
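
To make the hand-off concrete, here is a heavily simplified sketch of the 2nd step, not the literal convert.py code; the helper name export_with_tp, the simplified signature of patch_model_for_export, and the exact keyword arguments of parallel_model_trace are assumptions.

```python
# Simplified sketch of the export hand-off described above (not actual
# optimum-neuron code): the neuron config supplies the sharding callable, and
# the exporter forwards it to neuronx_distributed's parallel tracing.
import neuronx_distributed


def export_with_tp(neuron_config, model, dummy_inputs, tensor_parallel_size):
    # patch_model_for_export (defined on the neuron config, 1st step) returns the
    # callable that loads + shards the model inside each tensor-parallel worker.
    # Its exact signature is simplified here.
    checkpoint_loader = neuron_config.patch_model_for_export(model, dummy_inputs)

    # parallel_model_trace replays the callable on every tp rank and compiles the
    # sharded graph; the keyword names are assumptions about the nxd API.
    return neuronx_distributed.trace.parallel_model_trace(
        checkpoint_loader,
        tuple(dummy_inputs.values()),
        tp_degree=tensor_parallel_size,
    )
```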

So,

  • You shall complete T5EncoderForDiffusersNeuronConfig (I built this class for PixArt) with everything we need for tp. It should be quite similar to T5EncoderForTransformersNeuronConfig, the one we have for T5 text generation; the differences are that we don't need past key values as output, and we use a different encoder wrapper (no need to initialize the past key values with the encoder). A rough sketch of the pieces to add is below.

Yeah, that's what I meant in the description for Text encoder 2. Let me know if it's not clear enough.
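
As a rough illustration of where those pieces would live, the additions to T5EncoderForDiffusersNeuronConfig might look roughly like the sketch below. The helper _load_and_shard_t5_encoder is a placeholder; only the roles of patch_model_for_export, the no-past-key-values wrapper, and the absence of io aliases are taken from the discussion above.

```python
# Speculative sketch of the tp additions to T5EncoderForDiffusersNeuronConfig;
# everything model-specific stays inside this config, per the discussion above.
from functools import partial

from transformers import T5EncoderModel


def _load_and_shard_t5_encoder(model_name_or_path, sequence_length):
    """Placeholder: runs inside every tp worker spawned by parallel_model_trace."""
    model = T5EncoderModel.from_pretrained(model_name_or_path)
    model.eval()
    # ...wrap with the diffusion-specific encoder wrapper (no past key values) and
    # replace attention/MLP linears with neuronx_distributed parallel layers, as
    # done for T5EncoderForTransformersNeuronConfig...
    aliases = {}  # generate_io_aliases is not needed for Flux
    return model, aliases


class T5EncoderForDiffusersNeuronConfig:  # real base class omitted in this sketch
    def patch_model_for_export(self, model, **kwargs):
        # Return the callable handed to neuronx_distributed.trace.parallel_model_trace
        # (2nd step above); no sharding happens until each tp rank invokes it.
        return partial(
            _load_and_shard_t5_encoder,
            model_name_or_path=model.config._name_or_path,
            sequence_length=kwargs.get("sequence_length"),
        )
```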
