Skip to content

Commit

Permalink
feat(runtime): add node preview display and callbacks (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaoses-Ib committed Apr 21, 2024
1 parent c555758 commit 6e7cd33
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 8 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "comfy-script"
version = "0.4.6"
version = "0.5.0a1"
description = "A Python front end and library for ComfyUI"
readme = "README.md"
# ComfyUI: >=3.8
Expand Down Expand Up @@ -28,6 +28,11 @@ client = [

# 1.5.9: https://github.com/erdewit/nest_asyncio/issues/87
"nest_asyncio ~= 1.0, >= 1.5.9",

# Already required by ComfyUI
"Pillow",

"aenum ~= 3.1"
]

# Transpiler
Expand Down
58 changes: 58 additions & 0 deletions src/comfy_script/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from io import BytesIO
import json
import os
from pathlib import PurePath
import struct
import sys
import traceback
from typing import Callable

import asyncio
from warnings import warn
from PIL import Image
import nest_asyncio
import aiohttp
from yarl import URL
Expand Down Expand Up @@ -136,6 +142,58 @@ def default(self, o):
return str(o)
return super().default(o)

class BinaryEventTypes(IntEnum):
# See ComfyUI::server.BinaryEventTypes
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
'''Only used internally in ComfyUI.'''

@dataclass
class BinaryEvent:
type: BinaryEventTypes | int
data: bytes

@staticmethod
def from_bytes(data: bytes) -> BinaryEvent:
# See ComfyUI::server.encode_bytes()
type_int = struct.unpack('>I', data[:4])[0]
try:
type = BinaryEventTypes(type_int)
except ValueError:
warn(f'Unknown binary event type: {data[:4]}')
type = type_int
data = data[4:]
return BinaryEvent(type, data)

def to_object(self) -> Image.Image | bytes:
if self.type == BinaryEventTypes.PREVIEW_IMAGE:
return _PreviewImage.from_bytes(self.data).image
return self

class _PreviewImageFormat(IntEnum):
'''`format.name` is compatible with PIL.'''
JPEG = 1
PNG = 2

@dataclass
class _PreviewImage:
format: _PreviewImageFormat
image: Image.Image

@staticmethod
def from_bytes(data: bytes) -> _PreviewImage:
# See ComfyUI::LatentPreviewer
format_int = struct.unpack('>I', data[:4])[0]
format = None
try:
format = _PreviewImageFormat(format_int).name
except ValueError:
warn(f'Unknown image format: {data[:4]}')

image = Image.open(BytesIO(data[4:]), formats=(format,) if format is not None else None)

return _PreviewImage(format, image)

__all__ = [
'client',
'Client',
Expand Down
54 changes: 47 additions & 7 deletions src/comfy_script/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
from pathlib import Path
import sys
import threading
import traceback
from typing import Callable, Iterable
import uuid
from warnings import warn

import asyncio
import nest_asyncio
import aiohttp
from PIL import Image

nest_asyncio.apply()

Expand Down Expand Up @@ -443,6 +446,7 @@ def __init__(self):
self._queue_empty_callback = None
self._queue_remaining_callbacks = [self._when_empty_callback]
self._watch_display_node = None
self._watch_display_node_preview = None
self._watch_display_task = None
self.queue_remaining = 0

Expand All @@ -463,6 +467,7 @@ async def _watch(self):
async with session.ws_connect(f'{client.client.base_url}ws', params={'clientId': _client_id}) as ws:
self.queue_remaining = 0
executing = False
progress_data = None
async for msg in ws:
# print(msg.type)
if msg.type == aiohttp.WSMsgType.TEXT:
Expand Down Expand Up @@ -513,15 +518,25 @@ async def _watch(self):
if self._watch_display_node:
print(f'Queue remaining: {self.queue_remaining}')
elif msg['type'] == 'progress':
# TODO: https://github.com/comfyanonymous/ComfyUI/issues/2425
data = msg['data']
_print_progress(data['value'], data['max'])
# See ComfyUI::main.hijack_progress
# 'prompt_id', 'node': https://github.com/comfyanonymous/ComfyUI/issues/2425
progress_data = msg['data']
# TODO: Node
_print_progress(progress_data['value'], progress_data['max'])
elif msg.type == aiohttp.WSMsgType.BINARY:
pass
event = client.BinaryEvent.from_bytes(msg.data)
if event.type == client.BinaryEventTypes.PREVIEW_IMAGE:
prompt_id = progress_data.get('prompt_id')
if prompt_id is not None:
task: Task = self._tasks.get(prompt_id)
task._set_node_preview(progress_data['node'], event.to_object(), self._watch_display_node_preview)
else:
warn(f'Cannot get preview node, please update the ComfyUI server to at least 66831eb6e96cd974fb2d0fc4f299b23c6af16685 (2024-01-02)')
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
break
except Exception as e:
print(f'ComfyScript: Failed to watch, will retry in 5 seconds: {e}')
traceback.print_exc()
await asyncio.sleep(5)
'''
{'type': 'status', 'data': {'status': {'exec_info': {'queue_remaining': 0}}, 'sid': 'adc24049-b013-4a58-956b-edbc591dc6e2'}}
Expand All @@ -539,19 +554,27 @@ async def _watch(self):
{'type': 'executing', 'data': {'node': None, 'prompt_id': '3328f0c8-9368-4070-90e7-087e854fe315'}}
'''

def start_watch(self, display_node: bool = True, display_task: bool = True):
def start_watch(self, display_node: bool = True, display_task: bool = True, display_node_preview: bool = True):
'''
- `display_node`: When an output node is finished, display its result.
- `display_task`: When a task is finished (all output nodes are finished), display all the results.
`load()` will `start_watch()` by default.
## Previewing
Previewing is disabled by default. Pass `--preview-method auto` to ComfyUI to enable previewing.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
The default maximum preview resolution is 512x512. The only way to change it is to modify ComfyUI::MAX_PREVIEW_RESOLUTION.
'''

if display_node or display_task:
if display_node or display_task or display_node_preview:
try:
import IPython
self._watch_display_node = display_node
self._watch_display_task = display_task
self._watch_display_node_preview = display_node_preview
except ImportError:
print('ComfyScript: IPython is not available, cannot display task results')

Expand All @@ -567,13 +590,14 @@ def remove_queue_remaining_callback(self, callback: Callable[[int], None]):
if callback in self._queue_remaining_callbacks:
self._queue_remaining_callbacks.remove(callback)

def watch_display(self, display_node: bool = True, display_task: bool = True):
def watch_display(self, display_node: bool = True, display_task: bool = True, display_node_preview: bool = True):
'''
- `display_node`: When an output node is finished, display its result.
- `display_task`: When a task is finished (all output nodes are finished), display all the results.
'''
self._watch_display_node = display_node
self._watch_display_task = display_task
self._watch_display_node_preview = display_node_preview

async def _put(self, workflow: data.NodeOutput | Iterable[data.NodeOutput] | Workflow, source = None) -> Task | None:
global _client_id
Expand Down Expand Up @@ -685,13 +709,23 @@ def __init__(self, prompt_id: str, number: int, id: data.IdManager):
self._id = id
self._new_outputs = {}
self._fut = asyncio.Future()
self._node_preview_callbacks: list[Callable[[Task, str, Image.Image]]] = []

def __str__(self):
return f'Task {self.number} ({self.prompt_id})'

def __repr__(self):
return f'Task(n={self.number}, id={self.prompt_id})'

def _set_node_preview(self, node_id: str, preview: Image.Image, display: bool):
for callback in self._node_preview_callbacks:
callback(self, node_id, preview)

if display:
from IPython.display import display

display(preview, clear=True)

async def _set_result_threadsafe(self, node_id: str | None, output: dict, display_result: bool = False) -> None:
if node_id is not None:
self._new_outputs[node_id] = output
Expand Down Expand Up @@ -781,6 +815,12 @@ def wait_result(self, output: data.NodeOutput) -> data.Result | None:
# def __await__(self):
# return self._wait().__await__()

def add_preview_callback(self, callback: Callable[[Task, str, Image.Image], None]):
self._node_preview_callbacks.append(callback)

def remove_preview_callback(self, callback: Callable[[Task, str, Image.Image], None]):
self._node_preview_callbacks.remove(callback)

def done(self) -> bool:
"""Return True if the task is done.
Expand Down

0 comments on commit 6e7cd33

Please sign in to comment.