From a7f6a2d4d63fb5abcffa1e451d3408a13c8610e9 Mon Sep 17 00:00:00 2001 From: johnweck <83259326+jwmao1@users.noreply.github.com> Date: Sat, 18 Jan 2025 16:59:08 +0800 Subject: [PATCH] Update run.py --- run.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/run.py b/run.py index ffdf81e..33f37a5 100644 --- a/run.py +++ b/run.py @@ -488,6 +488,7 @@ parser.add_argument('--base_model_path', default=r"./RealVisXL_V4.0", type=str) parser.add_argument('--image_encoder_path', type=str, default=r"./IP-Adapter/sdxl_models/image_encoder") parser.add_argument('--ip_ckpt', default=r"./IP-Adapter/sdxl_models/ip-adapter_sdxl.bin", type=str) +parser.add_argument('--style', type=str, default='comic', choices=["comic","film","realistic"]) parser.add_argument('--device', default="cuda", type=str) parser.add_argument('--story', default=story1, nargs='+', type=str) @@ -497,6 +498,7 @@ image_encoder_path = args.image_encoder_path ip_ckpt = args.ip_ckpt device = args.device +style = args.style def image_grid(imgs, rows, cols): assert len(imgs) == rows * cols @@ -553,7 +555,7 @@ def image_grid(imgs, rows, cols): for i, text in enumerate(prompts): images = storyadapter.generate(pil_image=None, num_samples=1, num_inference_steps=50, seed=seed, - prompt=text, scale=0.3, use_image=False) + prompt=text, scale=0.3, use_image=False, style=style) grid = image_grid(images, 1, 1) grid.save(f'./story/results_xl/img_{i}.png') @@ -573,7 +575,7 @@ def image_grid(imgs, rows, cols): print(f'epoch:{i+1}') for y, text in enumerate(prompts): image = storyadapter.generate(pil_image=images, num_samples=1, num_inference_steps=50, seed=seed, - prompt=text, scale=scale, use_image=True) + prompt=text, scale=scale, use_image=True, style=style) new_images.append(image[0].resize((256, 256))) grid = image_grid(image, 1, 1) grid.save(f'./story/results_xl{i+1}/img_{y}.png')