Skip to content

Commit

Permalink
Update run.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jwmao1 authored Jan 18, 2025
1 parent 5158727 commit a7f6a2d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@
parser.add_argument('--base_model_path', default=r"./RealVisXL_V4.0", type=str)
parser.add_argument('--image_encoder_path', type=str, default=r"./IP-Adapter/sdxl_models/image_encoder")
parser.add_argument('--ip_ckpt', default=r"./IP-Adapter/sdxl_models/ip-adapter_sdxl.bin", type=str)
parser.add_argument('--style', type=str, default='comic', choices=["comic","film","realistic"])
parser.add_argument('--device', default="cuda", type=str)
parser.add_argument('--story', default=story1, nargs='+', type=str)

Expand All @@ -497,6 +498,7 @@
image_encoder_path = args.image_encoder_path
ip_ckpt = args.ip_ckpt
device = args.device
style = args.style

def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
Expand Down Expand Up @@ -553,7 +555,7 @@ def image_grid(imgs, rows, cols):

for i, text in enumerate(prompts):
images = storyadapter.generate(pil_image=None, num_samples=1, num_inference_steps=50, seed=seed,
prompt=text, scale=0.3, use_image=False)
prompt=text, scale=0.3, use_image=False, style=style)
grid = image_grid(images, 1, 1)
grid.save(f'./story/results_xl/img_{i}.png')

Expand All @@ -573,7 +575,7 @@ def image_grid(imgs, rows, cols):
print(f'epoch:{i+1}')
for y, text in enumerate(prompts):
image = storyadapter.generate(pil_image=images, num_samples=1, num_inference_steps=50, seed=seed,
prompt=text, scale=scale, use_image=True)
prompt=text, scale=scale, use_image=True, style=style)
new_images.append(image[0].resize((256, 256)))
grid = image_grid(image, 1, 1)
grid.save(f'./story/results_xl{i+1}/img_{y}.png')
Expand Down

0 comments on commit a7f6a2d

Please sign in to comment.