Skip to content

Commit

Permalink
add callback for creating a tab in train UI
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Nov 8, 2022
1 parent 8011be3 commit 1610b32
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
27 changes: 25 additions & 2 deletions modules/script_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi import FastAPI
from gradio import Blocks


def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
Expand Down Expand Up @@ -45,22 +46,29 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
"""Total number of sampling steps planned"""


class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
self.txt2img_preview_params = txt2img_preview_params


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
callbacks_ui_tabs=[],
callbacks_ui_train_tabs=[],
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
callbacks_cfg_denoiser=[]
callbacks_cfg_denoiser=[],
)


def clear_callbacks():
for callback_list in callback_map.values():
callback_list.clear()


def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
try:
Expand All @@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):

def ui_tabs_callback():
res = []

for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
Expand All @@ -89,6 +97,14 @@ def ui_tabs_callback():
return res


def ui_train_tabs_callback(params: UiTrainTabParams):
for c in callback_map['callbacks_ui_train_tabs']:
try:
c.callback(params)
except Exception:
report_exception(c, 'callbacks_ui_train_tabs')


def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']:
try:
Expand Down Expand Up @@ -169,6 +185,13 @@ def on_ui_tabs(callback):
add_callback(callback_map['callbacks_ui_tabs'], callback)


def on_ui_train_tabs(callback):
"""register a function to be called when the UI is creating new tabs for the train tab.
Create your new tabs with gr.Tab.
"""
add_callback(callback_map['callbacks_ui_train_tabs'], callback)


def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
Expand Down
4 changes: 4 additions & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
train_embedding = gr.Button(value="Train Embedding", variant='primary')

params = script_callbacks.UiTrainTabParams(txt2img_preview_params)

script_callbacks.ui_train_tabs_callback(params)

with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
Expand Down

0 comments on commit 1610b32

Please sign in to comment.