Commit 463314d: update text encoder
wtomin committed Feb 12, 2025
1 parent 71310a7 commit 463314d
Showing 5 changed files with 78 additions and 58 deletions.
31 changes: 16 additions & 15 deletions examples/hunyuanvideo/README.md
@@ -9,26 +9,19 @@

Please download all checkpoints and convert them into MindSpore checkpoints following this [instruction](./ckpts/README.md).

-### Run text encoder
-
-```bash
-cd hyvideo
-python run_text_encoder.py
-```
-
-
### Run VAE reconstruction

To run a video reconstruction using the CausalVAE, please use the following command:
```bash
-python hyvideo/rec_video.py \
-    --video_path input_video.mp4 \
-    --rec_path rec.mp4 \
-    --height 360 \
-    --width 640 \
-    --num_frames 33 \
+python scripts/run_vae.py \
+    --video-path "path/to/input_video.mp4" \
+    --output-path "path/to/output_directory" \
+    --rec-path "reconstructed_video.mp4" \
+    --height 336 \
+    --width 336 \
+    --num-frames 65
```
-The reconstructed video is saved under `./samples/`. To run video reconstruction on a given folder of input videos, please see `hyvideo/rec_video_folder.py` for more information.
+The reconstructed video is saved under `./save_samples/`. To run reconstruction on an input image or an input folder of videos, please refer to `scripts/vae/recon_image.sh` or `scripts/vae/recon_video_folder.sh`.


## 🔑 Training
@@ -37,6 +30,14 @@

To prepare the dataset for training HuyuanVideo, please refer to the [dataset format](./hyvideo/dataset/README.md).

+### Extract Text Embeddings
+
+```bash
+python scripts/run_text_encoder.py \
+    --data-file-path /path/to/caption.csv \
+    --output-path /path/to/text_embed_folder
+```
+Please refer to `scripts/text_encoder/run_text_encoder.sh`. More details can be found via `python scripts/run_text_encoder.py --help`.
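For illustration, the saved embeddings can be inspected with NumPy. This is a minimal sketch: the key names and shapes follow the `save_emb` docstring in `scripts/run_text_encoder.py` (shown later in this diff), and the file name is hypothetical.

```python
import numpy as np

# One .npz file is written per sample by run_text_encoder.py (hypothetical path).
emb = np.load("text_embed_folder/sample_0001.npz")

prompt_embeds = emb["prompt_embeds"]      # LLM hidden states, shape (text_len, 4096)
prompt_mask = emb["prompt_mask"]          # attention mask, shape (text_len,)
prompt_embeds_2 = emb["prompt_embeds_2"]  # pooled CLIP embedding, shape (768,)

print(prompt_embeds.shape, prompt_mask.shape, prompt_embeds_2.shape)
```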

### Distributed Training

4 changes: 4 additions & 0 deletions examples/hunyuanvideo/hyvideo/dataset/text_dataset.py
@@ -63,9 +63,13 @@ def read_captions(self, dataset):
    def __getitem__(self, idx_text):
        idx = self.caption_sample_indices[idx_text]
        row = self.dataset[idx]
+        assert (
+            self.caption_column in row
+        ), f"Expected caption column {self.caption_column} in dataset, but got {row.keys()}"
        captions = row[self.caption_column]
        if isinstance(captions, str):
            captions = [captions]
+        assert self.file_column in row, f"Expected file path column {self.file_column} in dataset, but got {row.keys()}"
        file_path = row[self.file_column]
        # get the caption id
        first_text_index = self.caption_sample_indices.index(idx)
91 changes: 54 additions & 37 deletions examples/hunyuanvideo/scripts/run_text_encoder.py
@@ -14,7 +14,6 @@

from hyvideo.constants import PRECISIONS, PROMPT_TEMPLATE
from hyvideo.dataset.text_dataset import create_dataloader
-from hyvideo.dataset.transform import text_preprocessing
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.message_utils import print_banner
from hyvideo.utils.ms_utils import init_env
@@ -28,7 +27,7 @@
def parse_args():
    parser = argparse.ArgumentParser(description="HunyuanVideo text encoders")
    parser.add_argument(
-        "--text_encoder_choices",
+        "--text-encoder-choices",
        type=str,
        nargs="+",
        default=["llm", "clipL"],
@@ -38,20 +37,20 @@ def parse_args():

    # text encoder llm
    parser.add_argument(
-        "--text_encoder_name",
+        "--text-encoder-name",
        type=str,
        default="llm",
        help="Name of the text encoder model.",
    )
    parser.add_argument(
-        "--text_encoder_precision",
+        "--text-encoder-precision",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the text encoder model.",
    )
    parser.add_argument(
-        "--text_encoder_path",
+        "--text-encoder-path",
        type=str,
        default="ckpts/text_encoder",
        help="File path of the ckpt of the text encoder.",
@@ -63,77 +62,77 @@ def parse_args():
help="Name of the tokenizer model.",
)
parser.add_argument(
"--tokenizer_path",
"--tokenizer-path",
type=str,
default=None,
help="File path of the ckpt of the tokenizer.",
)
parser.add_argument(
"--text_len",
"--text-len",
type=int,
default=256,
help="Maximum length of the text input.",
)
parser.add_argument(
"--prompt_template",
"--prompt-template",
type=str,
default="dit-llm-encode",
choices=PROMPT_TEMPLATE,
help="Image prompt template for the decoder-only text encoder model.",
)
parser.add_argument(
"--prompt_template_video",
"--prompt-template-video",
type=str,
default="dit-llm-encode-video",
choices=PROMPT_TEMPLATE,
help="Video prompt template for the decoder-only text encoder model.",
)
parser.add_argument(
"--hidden_state_skip_layer",
"--hidden-state-skip-layer",
type=int,
default=2,
help="Skip layer for hidden states.",
)
parser.add_argument(
"--apply_final_norm",
"--apply-final-norm",
action="store_true",
help="Apply final normalization to the used text encoder hidden states.",
)

# text encoder clipL
parser.add_argument(
"--text_encoder_name_2",
"--text-encoder-name-2",
type=str,
default="clipL",
help="Name of the second text encoder model.",
)
parser.add_argument(
"--text_encoder_precision_2",
"--text-encoder-precision-2",
type=str,
default="fp16",
choices=PRECISIONS,
help="Precision mode for the second text encoder model.",
)
parser.add_argument(
"--text_encoder_path_2",
"--text-encoder-path-2",
type=str,
default="ckpts/text_encoder_2",
help="File path of the ckpt of the second text encoder.",
)
parser.add_argument(
"--tokenizer_2",
"--tokenizer-2",
type=str,
default="clipL",
help="Name of the second tokenizer model.",
)
parser.add_argument(
"--tokenizer_path_2",
"--tokenizer-path-2",
type=str,
default=None,
help="File path of the ckpt of the second tokenizer.",
)
parser.add_argument(
"--text_len_2",
"--text-len-2",
type=int,
default=77,
help="Maximum length of the second text input.",
@@ -147,7 +146,7 @@
help="Specify the MindSpore mode: 0 for graph mode, 1 for pynative mode",
)
parser.add_argument(
"--jit_level",
"--jit-level",
default="O0",
type=str,
choices=["O0", "O1", "O2"],
@@ -157,42 +156,36 @@
"O2: Ultimate performance optimization, adopt Sink execution mode.",
)
parser.add_argument(
"--use_parallel",
"--use-parallel",
default=False,
type=str2bool,
help="use parallel",
)
parser.add_argument(
"--device_target",
"--device-target",
type=str,
default="Ascend",
help="Ascend or GPU",
)
parser.add_argument(
"--jit_syntax_level",
"--jit-syntax-level",
default="strict",
choices=["strict", "lax"],
help="Set jit syntax level: strict or lax",
)
parser.add_argument(
"--batch_size",
"--batch-size",
default=1,
type=int,
help="batch size",
)

# others
parser.add_argument(
"--data_file_path",
"--data-file-path",
type=str,
default=None,
help="File path of prompts, must be a txt or csv file.",
)
parser.add_argument(
"--text_preprocessing",
type=str2bool,
default=False,
help="Whether do text preprocessing on input prompts.",
help="File path of prompts, must be a json or csv file.",
)
parser.add_argument(
"--prompt",
Expand All @@ -201,18 +194,18 @@ def parse_args():
help="text prompt",
)
parser.add_argument(
"--output_path",
"--output-path",
type=str,
default=None,
help="Output dir to save the embeddings, if None, will treat the parent dir of data_file_path as output dir.",
)
parser.add_argument(
"--file_column",
"--file-column",
default="path",
help="The column of file path in `data_file_path`. Defaults to `path`.",
)
parser.add_argument(
"--caption_column",
"--caption-column",
default="cap",
help="The column of caption file path in `data_file_path`. Defaults to `cap`.",
)
@@ -222,6 +215,12 @@


def save_emb(output, output_2, output_dir, file_paths):
+    """
+    Save embeddings to .npz files. Each saved .npz contains the following keys:
+        prompt_embeds: embedding from text encoder 1 (LLM), shape (text_len, 4096)
+        prompt_mask: mask of text encoder 1, shape (text_len,)
+        prompt_embeds_2: pooled embedding from text encoder 2 (CLIP), shape (768,)
+    """
    num = output.hidden_state.shape[0] if output is not None else output_2.hidden_state.shape[0]
    for i in range(num):
        fn = Path(str(file_paths[i])).with_suffix(".npz")
@@ -268,6 +267,13 @@ def build_model(args, logger):
            apply_final_norm=args.apply_final_norm,
            logger=logger,
        )
+        logger.debug(
+            "Loaded TextEncoder 1 from path: %s, TextEncoder name: %s, Tokenizer type: %s, Max length: %d",
+            args.text_encoder_path,
+            args.text_encoder_name,
+            args.tokenizer,
+            max_length,
+        )

    # clipL
    if args.text_encoder_name_2 in args.text_encoder_choices:
@@ -280,6 +286,13 @@
            tokenizer_path=args.tokenizer_path_2 if args.tokenizer_path_2 is not None else args.text_encoder_path_2,
            logger=logger,
        )
+        logger.debug(
+            "Loaded TextEncoder 2 from path: %s, TextEncoder name: %s, Tokenizer type: %s, Max length: %d",
+            args.text_encoder_path_2,
+            args.text_encoder_name_2,
+            args.tokenizer_2,
+            args.text_len_2,
+        )

    return text_encoder, text_encoder_2

@@ -297,6 +310,11 @@ def main(args):

    # build dataloader for large amount of captions
    if args.data_file_path is not None:
+        assert isinstance(args.data_file_path, str), "Expect data_file_path to be a string!"
+        assert Path(args.data_file_path).exists(), "data_file_path does not exist!"
+        assert args.data_file_path.endswith(".csv") or args.data_file_path.endswith(
+            ".json"
+        ), "Expect data_file_path to be a csv or json file!"
        ds_config = dict(
            data_file_path=args.data_file_path,
            file_column=args.file_column,
@@ -317,6 +335,7 @@
        dataset_size = dataset.get_dataset_size()
        logger.info(f"Num batches: {dataset_size}")
    elif args.prompt is not None:
+        assert isinstance(args.prompt, str) and len(args.prompt) > 0, "Expect prompt to be a non-empty string!"
        data = {}
        prompt_fn = "-".join((args.prompt.replace("/", "").split(" ")[:16]))
        data["file_path"] = ["./{}.npz".format(prompt_fn)]
@@ -343,18 +362,16 @@
        ds_iter = dataset.create_dict_iterator(1, output_numpy=True)
    else:
        ds_iter = prompt_iter
-    for step, data in tqdm(enumerate(ds_iter), total=dataset_size):
+    for _, data in tqdm(enumerate(ds_iter), total=dataset_size):
        file_paths = data["file_path"]
        captions = data["caption"]
        captions = [str(captions[i]) for i in range(len(captions))]
-        if args.text_preprocessing:
-            captions = [text_preprocessing(prompt) for prompt in captions]

        output, output_2 = None, None
        # llm
        if text_encoder is not None:
            output = text_encoder(captions, data_type="video")
-            print("D--: ", output.hidden_state)
+            # print("D--: ", output.hidden_state)
        # clipL
        if text_encoder_2 is not None:
            output_2 = text_encoder_2(captions, data_type="video")
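As a quick usage sketch for the updated script, a single prompt can also be embedded without a caption file (the prompt text and output directory below are hypothetical; the flags are those defined in `parse_args` above):

```bash
python scripts/run_text_encoder.py \
    --prompt "a cat walks on the grass" \
    --output-path ./text_embed_folder
```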
2 changes: 1 addition & 1 deletion examples/hunyuanvideo/scripts/run_vae.py
@@ -231,7 +231,7 @@ def get_parser():
help="Type of input data: image, video, or folder.",
)
parser.add_argument(
"--output-path", type=str, default="save_videos/", help="Path to save the reconstructed video or image path."
"--output-path", type=str, default="save_samples/", help="Path to save the reconstructed video or image path."
)
# Image Group
parser.add_argument("--image-path", type=str, default="", help="Path to the input image file")
8 changes: 3 additions & 5 deletions examples/hunyuanvideo/scripts/text_encoder/run_text_encoder.sh
@@ -1,5 +1,3 @@
-python hyvideo/run_text_encoder.py \
-    --text_encoder_path /path/to/ckpt \
-    --text_encoder_path_2 /path/to/ckpt \
-    --data_file_path /path/to/caption.json \
-    --output_path /path/to/text_embed_folder \
+python scripts/run_text_encoder.py \
+    --data-file-path /path/to/caption.csv \
+    --output-path /path/to/text_embed_folder
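For reference, a minimal caption file matching the script's default `--file-column path` and `--caption-column cap` might look like the following (hypothetical rows):

```csv
path,cap
videos/clip_0001.mp4,"a dog running on the beach"
videos/clip_0002.mp4,"aerial view of a city at night"
```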
