diff --git a/examples/hunyuanvideo/README.md b/examples/hunyuanvideo/README.md
index 6a2d502454..97d2ff3c93 100644
--- a/examples/hunyuanvideo/README.md
+++ b/examples/hunyuanvideo/README.md
@@ -9,26 +9,19 @@
 Please download all checkpoints and convert them into MindSpore checkpoints following this [instruction](./ckpts/README.md).

-### Run text encoder
-
-```bash
-cd hyvideo
-python run_text_encoder.py
-```
-
-
 ### Run VAE reconstruction

 To run a video reconstruction using the CausalVAE, please use the following command:

 ```bash
-python hyvideo/rec_video.py \
-    --video_path input_video.mp4 \
-    --rec_path rec.mp4 \
-    --height 360 \
-    --width 640 \
-    --num_frames 33 \
+python scripts/run_vae.py \
+    --video-path "path/to/input_video.mp4" \
+    --output-path "path/to/output_directory" \
+    --rec-path "reconstructed_video.mp4" \
+    --height 336 \
+    --width 336 \
+    --num-frames 65
 ```

-The reconstructed video is saved under `./samples/`. To run video reconstruction on a given folder of input videos, please see `hyvideo/rec_video_folder.py` for more information.
+The reconstructed video is saved under `./save_samples/`. To run reconstruction on an input image or an input folder of videos, please refer to `scripts/vae/recon_image.sh` or `scripts/vae/recon_video_folder.sh`.

 ## 🔑 Training

@@ -37,6 +30,14 @@
 To prepare the dataset for training HunyuanVideo, please refer to the [dataset format](./hyvideo/dataset/README.md).

+### Extract Text Embeddings
+
+```bash
+python scripts/run_text_encoder.py \
+    --data-file-path /path/to/caption.csv \
+    --output-path /path/to/text_embed_folder
+```
+
+A ready-to-use launch script is provided in `scripts/text_encoder/run_text_encoder.sh`. More details can be found by running `python scripts/run_text_encoder.py --help`.
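+
+For illustration, a minimal `caption.csv` (assuming the default `path` and `cap` columns; the file names and captions below are hypothetical) could look like:
+
+```csv
+path,cap
+videos/clip_0001.mp4,"A corgi runs on a beach at sunset."
+videos/clip_0002.mp4,"Aerial view of a winding mountain road."
+```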

 ### Distributed Training
diff --git a/examples/hunyuanvideo/hyvideo/dataset/text_dataset.py b/examples/hunyuanvideo/hyvideo/dataset/text_dataset.py
index cb433ca0c1..120000419f 100644
--- a/examples/hunyuanvideo/hyvideo/dataset/text_dataset.py
+++ b/examples/hunyuanvideo/hyvideo/dataset/text_dataset.py
@@ -63,9 +63,13 @@ def read_captions(self, dataset):
     def __getitem__(self, idx_text):
         idx = self.caption_sample_indices[idx_text]
         row = self.dataset[idx]
+        assert (
+            self.caption_column in row
+        ), f"Expected caption column {self.caption_column} in dataset, but got {row.keys()}"
         captions = row[self.caption_column]
         if isinstance(captions, str):
             captions = [captions]
+        assert self.file_column in row, f"Expected file path column {self.file_column} in dataset, but got {row.keys()}"
         file_path = row[self.file_column]
         # get the caption id
         first_text_index = self.caption_sample_indices.index(idx)
diff --git a/examples/hunyuanvideo/scripts/run_text_encoder.py b/examples/hunyuanvideo/scripts/run_text_encoder.py
index bd4b68729f..1dbef9342a 100644
--- a/examples/hunyuanvideo/scripts/run_text_encoder.py
+++ b/examples/hunyuanvideo/scripts/run_text_encoder.py
@@ -14,7 +14,6 @@
 from hyvideo.constants import PRECISIONS, PROMPT_TEMPLATE
 from hyvideo.dataset.text_dataset import create_dataloader
-from hyvideo.dataset.transform import text_preprocessing
 from hyvideo.text_encoder import TextEncoder
 from hyvideo.utils.message_utils import print_banner
 from hyvideo.utils.ms_utils import init_env
@@ -28,7 +27,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description="HunyuanVideo text encoders")
     parser.add_argument(
-        "--text_encoder_choices",
+        "--text-encoder-choices",
         type=str,
         nargs="+",
         default=["llm", "clipL"],
@@ -38,20 +37,20 @@
     # text encoder llm
     parser.add_argument(
-        "--text_encoder_name",
+        "--text-encoder-name",
         type=str,
         default="llm",
         help="Name of the text encoder model.",
     )
     parser.add_argument(
-        "--text_encoder_precision",
+        "--text-encoder-precision",
         type=str,
         default="fp16",
         choices=PRECISIONS,
         help="Precision mode for the text encoder model.",
     )
     parser.add_argument(
-        "--text_encoder_path",
+        "--text-encoder-path",
         type=str,
         default="ckpts/text_encoder",
         help="File path of the ckpt of the text encoder.",
     )
@@ -63,77 +62,77 @@ def parse_args():
         help="Name of the tokenizer model.",
     )
     parser.add_argument(
-        "--tokenizer_path",
+        "--tokenizer-path",
         type=str,
         default=None,
         help="File path of the ckpt of the tokenizer.",
     )
     parser.add_argument(
-        "--text_len",
+        "--text-len",
         type=int,
         default=256,
         help="Maximum length of the text input.",
     )
     parser.add_argument(
-        "--prompt_template",
+        "--prompt-template",
         type=str,
         default="dit-llm-encode",
         choices=PROMPT_TEMPLATE,
         help="Image prompt template for the decoder-only text encoder model.",
     )
     parser.add_argument(
-        "--prompt_template_video",
+        "--prompt-template-video",
         type=str,
         default="dit-llm-encode-video",
         choices=PROMPT_TEMPLATE,
         help="Video prompt template for the decoder-only text encoder model.",
     )
     parser.add_argument(
-        "--hidden_state_skip_layer",
+        "--hidden-state-skip-layer",
         type=int,
         default=2,
         help="Skip layer for hidden states.",
     )
     parser.add_argument(
-        "--apply_final_norm",
+        "--apply-final-norm",
         action="store_true",
         help="Apply final normalization to the used text encoder hidden states.",
     )

     # text encoder clipL
     parser.add_argument(
-        "--text_encoder_name_2",
+        "--text-encoder-name-2",
         type=str,
         default="clipL",
         help="Name of the second text encoder model.",
     )
     parser.add_argument(
"--text_encoder_precision_2", + "--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS, help="Precision mode for the second text encoder model.", ) parser.add_argument( - "--text_encoder_path_2", + "--text-encoder-path-2", type=str, default="ckpts/text_encoder_2", help="File path of the ckpt of the second text encoder.", ) parser.add_argument( - "--tokenizer_2", + "--tokenizer-2", type=str, default="clipL", help="Name of the second tokenizer model.", ) parser.add_argument( - "--tokenizer_path_2", + "--tokenizer-path-2", type=str, default=None, help="File path of the ckpt of the second tokenizer.", ) parser.add_argument( - "--text_len_2", + "--text-len-2", type=int, default=77, help="Maximum length of the second text input.", @@ -147,7 +146,7 @@ def parse_args(): help="Specify the MindSpore mode: 0 for graph mode, 1 for pynative mode", ) parser.add_argument( - "--jit_level", + "--jit-level", default="O0", type=str, choices=["O0", "O1", "O2"], @@ -157,25 +156,25 @@ def parse_args(): "O2: Ultimate performance optimization, adopt Sink execution mode.", ) parser.add_argument( - "--use_parallel", + "--use-parallel", default=False, type=str2bool, help="use parallel", ) parser.add_argument( - "--device_target", + "--device-target", type=str, default="Ascend", help="Ascend or GPU", ) parser.add_argument( - "--jit_syntax_level", + "--jit-syntax-level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax", ) parser.add_argument( - "--batch_size", + "--batch-size", default=1, type=int, help="batch size", @@ -183,16 +182,10 @@ def parse_args(): # others parser.add_argument( - "--data_file_path", + "--data-file-path", type=str, default=None, - help="File path of prompts, must be a txt or csv file.", - ) - parser.add_argument( - "--text_preprocessing", - type=str2bool, - default=False, - help="Whether do text preprocessing on input prompts.", + help="File path of prompts, must be a json or csv file.", ) parser.add_argument( "--prompt", @@ -201,18 +194,18 @@ def parse_args(): help="text prompt", ) parser.add_argument( - "--output_path", + "--output-path", type=str, default=None, help="Output dir to save the embeddings, if None, will treat the parent dir of data_file_path as output dir.", ) parser.add_argument( - "--file_column", + "--file-column", default="path", help="The column of file path in `data_file_path`. Defaults to `path`.", ) parser.add_argument( - "--caption_column", + "--caption-column", default="cap", help="The column of caption file path in `data_file_path`. 
         help="The column of caption file path in `data_file_path`. Defaults to `cap`.",
     )
@@ -222,6 +215,12 @@
 def save_emb(output, output_2, output_dir, file_paths):
+    """
+    Save embeddings to .npz files. Each saved .npz contains the following keys:
+        prompt_embeds: the embedding from text encoder 1 (LLM), shape (text_len, 4096)
+        prompt_mask: the mask from text encoder 1, shape (text_len,)
+        prompt_embeds_2: the embedding from text encoder 2 (CLIP), shape (768,)
+    """
     num = output.hidden_state.shape[0] if output is not None else output_2.hidden_state.shape[0]
     for i in range(num):
         fn = Path(str(file_paths[i])).with_suffix(".npz")
@@ -268,6 +267,13 @@ def build_model(args, logger):
         apply_final_norm=args.apply_final_norm,
         logger=logger,
     )
+    logger.debug(
+        "Loaded TextEncoder 1 from path: %s, TextEncoder name: %s, Tokenizer type: %s, Max length: %d",
+        args.text_encoder_path,
+        args.text_encoder_name,
+        args.tokenizer,
+        max_length,
+    )

     # clipL
     if args.text_encoder_name_2 in args.text_encoder_choices:
@@ -280,6 +286,13 @@ def build_model(args, logger):
         tokenizer_path=args.tokenizer_path_2 if args.tokenizer_path_2 is not None else args.text_encoder_path_2,
         logger=logger,
     )
+    logger.debug(
+        "Loaded TextEncoder 2 from path: %s, TextEncoder name: %s, Tokenizer type: %s, Max length: %d",
+        args.text_encoder_path_2,
+        args.text_encoder_name_2,
+        args.tokenizer_2,
+        args.text_len_2,
+    )

     return text_encoder, text_encoder_2
@@ -297,6 +310,11 @@ def main(args):
     # build dataloader for large amount of captions
     if args.data_file_path is not None:
+        assert isinstance(args.data_file_path, str), "Expect data_file_path to be a string!"
+        assert Path(args.data_file_path).exists(), "data_file_path does not exist!"
+        assert args.data_file_path.endswith(".csv") or args.data_file_path.endswith(
+            ".json"
+        ), "Expect data_file_path to be a csv or json file!"
         ds_config = dict(
             data_file_path=args.data_file_path,
             file_column=args.file_column,
@@ -317,6 +335,7 @@ def main(args):
         dataset_size = dataset.get_dataset_size()
         logger.info(f"Num batches: {dataset_size}")
     elif args.prompt is not None:
+        assert isinstance(args.prompt, str) and len(args.prompt) > 0, "Expect prompt to be a non-empty string!"
         data = {}
         prompt_fn = "-".join((args.prompt.replace("/", "").split(" ")[:16]))
         data["file_path"] = ["./{}.npz".format(prompt_fn)]
@@ -343,18 +362,16 @@ def main(args):
         ds_iter = dataset.create_dict_iterator(1, output_numpy=True)
     else:
         ds_iter = prompt_iter
-    for step, data in tqdm(enumerate(ds_iter), total=dataset_size):
+    for data in tqdm(ds_iter, total=dataset_size):
         file_paths = data["file_path"]
         captions = data["caption"]
         captions = [str(captions[i]) for i in range(len(captions))]
-        if args.text_preprocessing:
-            captions = [text_preprocessing(prompt) for prompt in captions]

         output, output_2 = None, None
         # llm
         if text_encoder is not None:
             output = text_encoder(captions, data_type="video")
-            print("D--: ", output.hidden_state)
         # clipL
         if text_encoder_2 is not None:
             output_2 = text_encoder_2(captions, data_type="video")
diff --git a/examples/hunyuanvideo/scripts/run_vae.py b/examples/hunyuanvideo/scripts/run_vae.py
index f039415f52..0574ccb7f6 100644
--- a/examples/hunyuanvideo/scripts/run_vae.py
+++ b/examples/hunyuanvideo/scripts/run_vae.py
@@ -231,7 +231,7 @@ def get_parser():
         help="Type of input data: image, video, or folder.",
     )
     parser.add_argument(
-        "--output-path", type=str, default="save_videos/", help="Path to save the reconstructed video or image path."
+ "--output-path", type=str, default="save_samples/", help="Path to save the reconstructed video or image path." ) # Image Group parser.add_argument("--image-path", type=str, default="", help="Path to the input image file") diff --git a/examples/hunyuanvideo/scripts/text_encoder/run_text_encoder.sh b/examples/hunyuanvideo/scripts/text_encoder/run_text_encoder.sh index 74fc69abb8..5cefb0751f 100644 --- a/examples/hunyuanvideo/scripts/text_encoder/run_text_encoder.sh +++ b/examples/hunyuanvideo/scripts/text_encoder/run_text_encoder.sh @@ -1,5 +1,3 @@ -python hyvideo/run_text_encoder.py \ - --text_encoder_path /path/to/ckpt \ - --text_encoder_path_2 /path/to/ckpt \ - --data_file_path /path/to/caption.json \ - --output_path /path/to/text_embed_folder \ +python scripts/run_text_encoder.py \ + --data-file-path /path/to/caption.csv \ + --output-path /path/to/text_embed_folder \