diff --git a/src/airunner/aihandler/runner.py b/src/airunner/aihandler/runner.py index b5f16601c..4ef0107d7 100644 --- a/src/airunner/aihandler/runner.py +++ b/src/airunner/aihandler/runner.py @@ -970,17 +970,28 @@ def call_pipe(self, **kwargs): try: args.update({ "prompt_embeds": self.prompt_embeds, - "negative_prompt_embeds": self.negative_prompt_embeds, }) except Exception as _e: Logger.warning("Compel failed: " + str(_e)) args.update({ "prompt": self.prompt, - "negative_prompt": self.negative_prompt, }) else: args.update({ "prompt": self.prompt, + }) + if self.use_compel: + try: + args.update({ + "negative_prompt_embeds": self.negative_prompt_embeds, + }) + except Exception as _e: + Logger.warning("Compel failed: " + str(_e)) + args.update({ + "negative_prompt": self.negative_prompt, + }) + else: + args.update({ "negative_prompt": self.negative_prompt, }) args["callback_steps"] = 1 diff --git a/src/airunner/aihandler/settings_manager.py b/src/airunner/aihandler/settings_manager.py index 6add3bd4d..5a5a311ea 100644 --- a/src/airunner/aihandler/settings_manager.py +++ b/src/airunner/aihandler/settings_manager.py @@ -1,11 +1,11 @@ from PyQt6.QtCore import QObject, pyqtSignal from airunner.aihandler.qtvar import StringVar, IntVar, BooleanVar, FloatVar, DictVar -from airunner.data.db import session -from airunner.data.models import LLMGenerator, Settings, GeneratorSetting, AIModel, Pipeline, ControlnetModel, ImageFilter, Prompt, \ - SavedPrompt, PromptCategory, PromptVariable, PromptVariableCategory, PromptVariableCategoryWeight, StandardImageWidgetSettings -from airunner.utils import save_session +from airunner.data.models import LLMGenerator, Settings, GeneratorSetting, AIModel, Pipeline, ControlnetModel, ImageFilter, \ + SavedPrompt, StandardImageWidgetSettings +from airunner.utils import save_session, get_session from airunner.aihandler.logger import Logger as logger +from airunner.data.models import Document document = None _app = None @@ -17,6 +17,7 @@ "FLOAT": FloatVar, "JSON": DictVar, } +session = get_session() class SettingsSignal(QObject): @@ -236,8 +237,7 @@ def __init__(self, app=None, *args, **kwargs): _app = app document = _app.document else: - from airunner.data.db import session - from airunner.data.models import Document + session = get_session() document = session.query(Document).first() super().__init__(*args, **kwargs) diff --git a/src/airunner/aihandler/transformer_runner.py b/src/airunner/aihandler/transformer_runner.py index 4c668c4ce..7cc21e808 100644 --- a/src/airunner/aihandler/transformer_runner.py +++ b/src/airunner/aihandler/transformer_runner.py @@ -12,7 +12,7 @@ from airunner.aihandler.settings_manager import SettingsManager from airunner.data.models import LLMGenerator -from airunner.data.db import session +from airunner.utils import get_session from airunner.aihandler.logger import Logger @@ -66,6 +66,7 @@ class TransformerRunner(QObject): @property def generator(self): try: + session = get_session() if not self._generator or self.current_generator_name != self.requested_generator_name: self.current_generator_name = self.requested_generator_name self._generator = session.query(LLMGenerator).filter_by(name=self.current_generator_name).first() diff --git a/src/airunner/data/db.py b/src/airunner/data/db.py index a790d344f..5b07d1609 100644 --- a/src/airunner/data/db.py +++ b/src/airunner/data/db.py @@ -1,5 +1,4 @@ from airunner.data.bootstrap.controlnet_bootstrap_data import controlnet_bootstrap_data -from airunner.data.bootstrap.generator_bootstrap_data import sections_bootstrap_data from airunner.data.bootstrap.imagefilter_bootstrap_data import imagefilter_bootstrap_data from airunner.data.bootstrap.llm import seed_data from airunner.data.bootstrap.model_bootstrap_data import model_bootstrap_data @@ -14,502 +13,440 @@ LLMGeneratorSetting, LLMGenerator, LLMModelVersion from airunner.utils import get_session from alembic.config import Config -from airunner.aihandler.logger import Logger from alembic import command import os import configparser -session = get_session() -do_stamp_alembic = False -# check if database is blank: -if not session.query(Prompt).first(): - do_stamp_alembic = True - # Add Prompt objects - for prompt_option, data in prompt_bootstrap_data.items(): - category = PromptCategory(name=prompt_option, negative_prompt=data["negative_prompt"]) - prompt = Prompt( - name=f"Standard {prompt_option} prompt", - category=category - ) - session.add(prompt) - session.commit() - prompt_id = prompt.id - - prompt_variables = [] - for category_name, variable_values in data["variables"].items(): - # add prompt category - cat = session.query(PromptVariableCategory).filter_by(name=category_name).first() - if not cat: - cat = PromptVariableCategory(name=category_name) - session.add(cat) - session.commit() - - # add prompt variable category weight - weight = session.query(PromptVariableCategoryWeight).filter_by( - prompt_category=category, - variable_category=cat - ).first() - if not weight: - try: - weight_value = data["weights"][category_name] - except KeyError: - weight_value = 1.0 - weight = PromptVariableCategoryWeight( - prompt_category=category, - variable_category=cat, - weight=weight_value - ) - session.add(weight) - session.commit() - # add prompt variables - for var in variable_values: - session.add(PromptVariable( - value=var, +def prepare_database(): + my_session = get_session() + do_stamp_alembic = False + + # check if database is blank: + if not my_session.query(Prompt).first(): + do_stamp_alembic = True + + # Add Prompt objects + for prompt_option, data in prompt_bootstrap_data.items(): + category = PromptCategory(name=prompt_option, negative_prompt=data["negative_prompt"]) + prompt = Prompt( + name=f"Standard {prompt_option} prompt", + category=category + ) + my_session.add(prompt) + my_session.commit() + prompt_id = prompt.id + + prompt_variables = [] + for category_name, variable_values in data["variables"].items(): + # add prompt category + cat = my_session.query(PromptVariableCategory).filter_by(name=category_name).first() + if not cat: + cat = PromptVariableCategory(name=category_name) + my_session.add(cat) + my_session.commit() + + # add prompt variable category weight + weight = my_session.query(PromptVariableCategoryWeight).filter_by( prompt_category=category, variable_category=cat - )) - session.commit() - - def insert_variables(variables, prev_object=None): - for option in variables: - text = option.get("text", None) - cond = option.get("cond", "") - else_cond = option.get("else", "") - next_cond = option.get("next", None) - or_cond = option.get("or_cond", None) - prompt_option = PromptOption( - text=text, - cond=cond, - else_cond=else_cond, - or_cond=or_cond, - prompt_id=prompt_id - ) - if prev_object: - session.add(prompt_option) - session.commit() - prev_object.next_cond_id = prompt_option.id - session.add(prev_object) - session.commit() - prev_object = prompt_option - else: - session.add(prompt_option) - session.commit() - prev_object = prompt_option - if next_cond: - prev_object = insert_variables( - variables=next_cond, - prev_object=prev_object, + ).first() + if not weight: + try: + weight_value = data["weights"][category_name] + except KeyError: + weight_value = 1.0 + weight = PromptVariableCategoryWeight( + prompt_category=category, + variable_category=cat, + weight=weight_value ) - return prev_object - - insert_variables(data["builder"]) - - session.commit() + my_session.add(weight) + my_session.commit() + + # add prompt variables + for var in variable_values: + my_session.add(PromptVariable( + value=var, + prompt_category=category, + variable_category=cat + )) + my_session.commit() + + def insert_variables(variables, prev_object=None): + for option in variables: + text = option.get("text", None) + cond = option.get("cond", "") + else_cond = option.get("else", "") + next_cond = option.get("next", None) + or_cond = option.get("or_cond", None) + prompt_option = PromptOption( + text=text, + cond=cond, + else_cond=else_cond, + or_cond=or_cond, + prompt_id=prompt_id + ) + if prev_object: + my_session.add(prompt_option) + my_session.commit() + prev_object.next_cond_id = prompt_option.id + my_session.add(prev_object) + my_session.commit() + prev_object = prompt_option + else: + my_session.add(prompt_option) + my_session.commit() + prev_object = prompt_option + if next_cond: + prev_object = insert_variables( + variables=next_cond, + prev_object=prev_object, + ) + return prev_object + + insert_variables(data["builder"]) + + my_session.commit() + + for variable_category, data in variable_bootstrap_data.items(): + category = my_session.query(PromptVariableCategory).filter_by(name=variable_category).first() + if not category: + category = PromptVariableCategory(name=variable_category) + my_session.add(category) + my_session.commit() + for variable in data: + my_session.add(PromptVariable( + value=variable, + variable_category=category + )) + my_session.commit() + + # Add PromptStyle objects + for style_category, data in style_bootstrap_data.items(): + category = PromptStyleCategory(name=style_category, negative_prompt=data["negative_prompt"]) + my_session.add(category) + my_session.commit() + for style in data["styles"]: + my_session.add(PromptStyle( + name=style, + style_category=category + )) + my_session.commit() - for variable_category, data in variable_bootstrap_data.items(): - category = session.query(PromptVariableCategory).filter_by(name=variable_category).first() - if not category: - category = PromptVariableCategory(name=variable_category) - session.add(category) - session.commit() - for variable in data: - session.add(PromptVariable( - value=variable, - variable_category=category - )) - session.commit() + # Add ControlnetModel objects + for name, path in controlnet_bootstrap_data.items(): + my_session.add(ControlnetModel(name=name, path=path)) + my_session.commit() - # Add PromptStyle objects - for style_category, data in style_bootstrap_data.items(): - category = PromptStyleCategory(name=style_category, negative_prompt=data["negative_prompt"]) - session.add(category) - session.commit() - for style in data["styles"]: - session.add(PromptStyle( - name=style, - style_category=category - )) - session.commit() - # Add ControlnetModel objects - for name, path in controlnet_bootstrap_data.items(): - session.add(ControlnetModel(name=name, path=path)) - session.commit() + # Add AIModel objects + for model_data in model_bootstrap_data: + my_session.add(AIModel(**model_data)) + my_session.commit() - # Add AIModel objects - for model_data in model_bootstrap_data: - session.add(AIModel(**model_data)) - session.commit() + # Add Pipeline objects + for pipeline_data in pipeline_bootstrap_data: + my_session.add(Pipeline(**pipeline_data)) + my_session.commit() - # Add Pipeline objects - for pipeline_data in pipeline_bootstrap_data: - session.add(Pipeline(**pipeline_data)) - session.commit() + # Add PathSettings objects + my_session.add(PathSettings()) + my_session.commit() - # Add PathSettings objects - session.add(PathSettings()) - session.commit() + # Add BrushSettings objects + my_session.add(BrushSettings()) + my_session.commit() - # Add BrushSettings objects - session.add(BrushSettings()) - session.commit() + # Add GridSettings objects + my_session.add(GridSettings()) + my_session.commit() + my_session.add(DeterministicSettings()) + my_session.commit() - # Add GridSettings objects - session.add(GridSettings()) - session.commit() - session.add(DeterministicSettings()) - session.commit() + # Add MetadataSettings objects + my_session.add(MetadataSettings()) + my_session.commit() - # Add MetadataSettings objects - session.add(MetadataSettings()) - session.commit() + # Add MemorySettings objects + my_session.add(MemorySettings()) + my_session.commit() + # Add ActiveGridSettings object + my_session.add(ActiveGridSettings()) + my_session.commit() - # Add MemorySettings objects - session.add(MemorySettings()) - session.commit() + # Add ImageFilter objects + for filter in imagefilter_bootstrap_data: + image_filter = ImageFilter( + display_name=filter[0], + name=filter[1], + filter_class=filter[2] + ) + for filter_value in filter[3]: + image_filter.image_filter_values.append(ImageFilterValue( + name=filter_value[0], + value=filter_value[1], + value_type=filter_value[2], + min_value=filter_value[3] if len(filter_value) > 3 else None, + max_value=filter_value[4] if len(filter_value) > 4 else None + )) + my_session.add(image_filter) + my_session.commit() - # Add ActiveGridSettings object - session.add(ActiveGridSettings()) - session.commit() + image_filter = my_session.query(ImageFilter).filter_by(name='color_balance').first() + # Access its image_filter_values + filter_values = image_filter.image_filter_values - # Add ImageFilter objects - for filter in imagefilter_bootstrap_data: - image_filter = ImageFilter( - display_name=filter[0], - name=filter[1], - filter_class=filter[2] - ) - for filter_value in filter[3]: - image_filter.image_filter_values.append(ImageFilterValue( - name=filter_value[0], - value=filter_value[1], - value_type=filter_value[2], - min_value=filter_value[3] if len(filter_value) > 3 else None, - max_value=filter_value[4] if len(filter_value) > 4 else None - )) - session.add(image_filter) - session.commit() - - image_filter = session.query(ImageFilter).filter_by(name='color_balance').first() - - # Access its image_filter_values - filter_values = image_filter.image_filter_values - - # Add Document object - settings = Settings(nsfw_filter=True) - settings.prompt_generator_settings.append( - PromptGeneratorSetting( - name="Prompt A", - active=True, - settings_id=settings.id + # Add Document object + settings = Settings(nsfw_filter=True) + settings.prompt_generator_settings.append( + PromptGeneratorSetting( + name="Prompt A", + active=True, + settings_id=settings.id + ) ) - ) - settings.prompt_generator_settings.append( - PromptGeneratorSetting( - name="Prompt B", - settings_id=settings.id + settings.prompt_generator_settings.append( + PromptGeneratorSetting( + name="Prompt B", + settings_id=settings.id + ) ) - ) - settings.splitter_sizes.append(SplitterSection( - name="content_splitter", - order=0, - size=390 - )) - settings.splitter_sizes.append(SplitterSection( - name="content_splitter", - order=1, - size=512 - )) - settings.splitter_sizes.append(SplitterSection( - name="content_splitter", - order=2, - size=200 - )) - settings.splitter_sizes.append(SplitterSection( - name="content_splitter", - order=3, - size=64 - )) - settings.splitter_sizes.append(SplitterSection( - name="main_splitter", - order=0, - size=520 - )) - settings.splitter_sizes.append(SplitterSection( - name="main_splitter", - order=1, - size=-1 - )) - settings.splitter_sizes.append(SplitterSection( - name="canvas_splitter", - order=0, - size=520 - )) - settings.splitter_sizes.append(SplitterSection( - name="canvas_splitter", - order=1, - size=-1 - )) - session.add(settings) - - settings.brush_settings = session.query(BrushSettings).first() - settings.path_settings = session.query(PathSettings).first() - settings.grid_settings = session.query(GridSettings).first() - settings.deterministic_settings = session.query(DeterministicSettings).first() - settings.metadata_settings = session.query(MetadataSettings).first() - settings.memory_settings = session.query(MemorySettings).first() - settings.active_grid_settings = session.query(ActiveGridSettings).first() - - active_grid_colors = { - "stablediffusion": { - "border": { - "txt2img": "#00FF00", - "outpaint": "#00FFFF", - "depth2img": "#0000FF", - "pix2pix": "#FFFF00", - "upscale": "#00FFFF", - "superresolution": "#FF00FF", - "txt2vid": "#999999", + settings.splitter_sizes.append(SplitterSection( + name="content_splitter", + order=0, + size=390 + )) + settings.splitter_sizes.append(SplitterSection( + name="content_splitter", + order=1, + size=512 + )) + settings.splitter_sizes.append(SplitterSection( + name="content_splitter", + order=2, + size=200 + )) + settings.splitter_sizes.append(SplitterSection( + name="content_splitter", + order=3, + size=64 + )) + settings.splitter_sizes.append(SplitterSection( + name="main_splitter", + order=0, + size=520 + )) + settings.splitter_sizes.append(SplitterSection( + name="main_splitter", + order=1, + size=-1 + )) + settings.splitter_sizes.append(SplitterSection( + name="canvas_splitter", + order=0, + size=520 + )) + settings.splitter_sizes.append(SplitterSection( + name="canvas_splitter", + order=1, + size=-1 + )) + my_session.add(settings) + + settings.brush_settings = my_session.query(BrushSettings).first() + settings.path_settings = my_session.query(PathSettings).first() + settings.grid_settings = my_session.query(GridSettings).first() + settings.deterministic_settings = my_session.query(DeterministicSettings).first() + settings.metadata_settings = my_session.query(MetadataSettings).first() + settings.memory_settings = my_session.query(MemorySettings).first() + settings.active_grid_settings = my_session.query(ActiveGridSettings).first() + + active_grid_colors = { + "stablediffusion": { + "border": { + "txt2img": "#00FF00", + "outpaint": "#00FFFF", + "depth2img": "#0000FF", + "pix2pix": "#FFFF00", + "upscale": "#00FFFF", + "superresolution": "#FF00FF", + "txt2vid": "#999999", + }, + # choose complimentary colors for the fill + "fill": { + "txt2img": "#FF0000", + "outpaint": "#FF00FF", + "depth2img": "#FF8000", + "pix2pix": "#8000FF", + "upscale": "#00FF80", + "superresolution": "#00FF00", + "txt2vid": "#000000", + + } }, - # choose complimentary colors for the fill - "fill": { - "txt2img": "#FF0000", - "outpaint": "#FF00FF", - "depth2img": "#FF8000", - "pix2pix": "#8000FF", - "upscale": "#00FF80", - "superresolution": "#00FF00", - "txt2vid": "#000000", - - } - }, - } - - generator_section = "txt2img" - generator_name = "stablediffusion" - session.add(GeneratorSetting( - section=generator_section, - generator_name=generator_name, - active_grid_border_color=active_grid_colors[generator_name]["border"][generator_section], - active_grid_fill_color=active_grid_colors[generator_name]["fill"][generator_section] - )) - - session.add(Document( - name="Untitled", - settings=settings, - active=True - )) - session.commit() - - - available_schedulers = {} - for scheduler_data in [ - ("Euler a", "EULER_ANCESTRAL"), - ("Euler", "EULER"), - ("LMS", "LMS"), - ("Heun", "HEUN"), - ("DPM2", "DPM2"), - ("DPM++ 2M", "DPM_PP_2M"), - ("DPM2 Karras", "DPM2_K"), - ("DPM2 a Karras", "DPM2_A_K"), - ("DPM++ 2M Karras", "DPM_PP_2M_K"), - ("DPM++ 2M SDE Karras", "DPM_PP_2M_SDE_K"), - ("DDIM", "DDIM"), - ("UniPC", "UNIPC"), - ("DDPM", "DDPM"), - ("DEIS", "DEIS"), - ("DPM 2M SDE Karras", "DPM_2M_SDE_K"), - ("PLMS", "PLMS"), - ]: - obj = Scheduler( - name=scheduler_data[1], - display_name=scheduler_data[0] - ) - session.add(obj) - available_schedulers[scheduler_data[1]] = obj - session.commit() - - generator_sections = { - "stablediffusion": { - "upscale": ["EULER"], - "superresolution": ["DDIM", "LMS", "PLMS"], - }, - } - - # add all of the schedulers for the defined generator sections - for generator, sections in generator_sections.items(): - for section, schedulers in sections.items(): - for scheduler in schedulers: - session.add(ActionScheduler( - section=section, - generator_name=generator, - scheduler_id=session.query(Scheduler).filter_by(name=scheduler).first().id - )) - session.commit() - - # add the rest of the stable diffusion schedulers - for k, v in available_schedulers.items(): - for section in [ - "txt2img", "depth2img", "pix2pix", "vid2vid", - "outpaint", "controlnet", "txt2vid" + } + + generator_section = "txt2img" + generator_name = "stablediffusion" + my_session.add(GeneratorSetting( + section=generator_section, + generator_name=generator_name, + active_grid_border_color=active_grid_colors[generator_name]["border"][generator_section], + active_grid_fill_color=active_grid_colors[generator_name]["fill"][generator_section] + )) + + my_session.add(Document( + name="Untitled", + settings=settings, + active=True + )) + my_session.commit() + + + available_schedulers = {} + for scheduler_data in [ + ("Euler a", "EULER_ANCESTRAL"), + ("Euler", "EULER"), + ("LMS", "LMS"), + ("Heun", "HEUN"), + ("DPM2", "DPM2"), + ("DPM++ 2M", "DPM_PP_2M"), + ("DPM2 Karras", "DPM2_K"), + ("DPM2 a Karras", "DPM2_A_K"), + ("DPM++ 2M Karras", "DPM_PP_2M_K"), + ("DPM++ 2M SDE Karras", "DPM_PP_2M_SDE_K"), + ("DDIM", "DDIM"), + ("UniPC", "UNIPC"), + ("DDPM", "DDPM"), + ("DEIS", "DEIS"), + ("DPM 2M SDE Karras", "DPM_2M_SDE_K"), + ("PLMS", "PLMS"), ]: - obj = ActionScheduler( - section=section, - generator_name="stablediffusion", - scheduler_id=v.id + obj = Scheduler( + name=scheduler_data[1], + display_name=scheduler_data[0] ) - session.add(obj) - session.commit() - - # create tab sections - session.add(TabSection( - panel="center_tab", - active_tab="Canvas" - )) - session.add(TabSection( - panel="tool_tab_widget", - active_tab="Embeddings" - )) - session.add(TabSection( - panel="prompt_builder.ui.tabs", - active_tab="0" - )) - session.commit() - - session.add(PromptBuilder( - name="Prompt A", - active=True - )) - session.add(PromptBuilder( - name="Prompt B", - active=True - )) - session.commit() - - session.add(CanvasSettings()) - session.commit() - - - for generator_name, generator_data in seed_data.items(): - generator = LLMGenerator(name=generator_name) - session.add(generator) - - # create GeneratorSetting with property, value and property_type based on value type - setting = LLMGeneratorSetting() - setting.generator = generator - for k, v in generator_data["generator_settings"].items(): - setting.__setattr__(k, v) - session.add(setting) - - if "model_versions" in generator_data: - model_versions = [] - for name in generator_data["model_versions"]: - print("Name", name) - model_versions.append(LLMModelVersion(name=name)) - - for version in model_versions: - generator.model_versions.append(version) - - session.add(generator) - session.commit() - - from airunner.data.bootstrap.prompt_templates import prompt_template_seed_data - for data in prompt_template_seed_data: - prompt_template = LLMPromptTemplate( - name=data["name"], - template=data["template"] - ) - session.add(prompt_template) - session.commit() - - - default_models = [ - { - "name": "Stable Diffusion 2.1 512", - "pipeline": "txt2img", - "toolname": "txt2img" - }, - { - "name": "Stable Diffusion Inpaint 2", - "pipeline": "outpaint", - "toolname": "outpaint" - }, - { - "name": "Stable Diffusion Depth2Img", - "pipeline": "depth2img", - "toolname": "depth2img" - }, - { - "name": "Stable Diffusion 1.5", - "pipeline": "controlnet", - "toolname": "controlnet" - }, - { - "name": "Stability AI 4x resolution", - "pipeline": "superresolution", - "toolname": "superresolution" - }, - { - "name": "Instruct pix2pix", - "pipeline": "pix2pix", - "toolname": "pix2pix" - }, - { - "name": "SD Image Variations", - "pipeline": "vid2vid", - "toolname": "vid2vid" - }, - { - "name": "sd-x2-latent-upscaler", - "pipeline": "upscale", - "toolname": "upscale" - }, - { - "name": "Inpaint vae", - "pipeline": "inpaint_vae", - "toolname": "inpaint_vae" - }, - { - "name": "Salesforce InstructBlip Flan T5 XL", - "pipeline": "visualqa", - "toolname": "visualqa" - }, - { - "name": "Llama 2 7b Chat", - "pipeline": "casuallm", - "toolname": "casuallm" - }, - { - "name": "Flan T5 XL", - "pipeline": "seq2seq", - "toolname": "prompt_generation" - }, - ] - -HERE = os.path.abspath(os.path.dirname(__file__)) -alembic_ini_path = os.path.join(HERE, "../alembic.ini") - -config = configparser.ConfigParser() -config.read(f"{alembic_ini_path}.config") - -home_dir = os.path.expanduser("~") -db_path = f'sqlite:///{home_dir}/.airunner/airunner.db' - -config.set('alembic', 'sqlalchemy.url', db_path) -with open(alembic_ini_path, 'w') as configfile: - config.write(configfile) -alembic_cfg = Config(alembic_ini_path) -if not do_stamp_alembic: - command.upgrade(alembic_cfg, "head") -else: - command.stamp(alembic_cfg, "head") + my_session.add(obj) + available_schedulers[scheduler_data[1]] = obj + my_session.commit() + + generator_sections = { + "stablediffusion": { + "upscale": ["EULER"], + "superresolution": ["DDIM", "LMS", "PLMS"], + }, + } + + # add all of the schedulers for the defined generator sections + for generator, sections in generator_sections.items(): + for section, schedulers in sections.items(): + for scheduler in schedulers: + my_session.add(ActionScheduler( + section=section, + generator_name=generator, + scheduler_id=my_session.query(Scheduler).filter_by(name=scheduler).first().id + )) + my_session.commit() + + # add the rest of the stable diffusion schedulers + for k, v in available_schedulers.items(): + for section in [ + "txt2img", "depth2img", "pix2pix", "vid2vid", + "outpaint", "controlnet", "txt2vid" + ]: + obj = ActionScheduler( + section=section, + generator_name="stablediffusion", + scheduler_id=v.id + ) + my_session.add(obj) + my_session.commit() + + # create tab sections + my_session.add(TabSection( + panel="center_tab", + active_tab="Canvas" + )) + my_session.add(TabSection( + panel="tool_tab_widget", + active_tab="Embeddings" + )) + my_session.add(TabSection( + panel="prompt_builder.ui.tabs", + active_tab="0" + )) + my_session.commit() + + my_session.add(PromptBuilder( + name="Prompt A", + active=True + )) + my_session.add(PromptBuilder( + name="Prompt B", + active=True + )) + my_session.commit() + + my_session.add(CanvasSettings()) + my_session.commit() + + + for generator_name, generator_data in seed_data.items(): + generator = LLMGenerator(name=generator_name) + my_session.add(generator) + + # create GeneratorSetting with property, value and property_type based on value type + setting = LLMGeneratorSetting() + setting.generator = generator + for k, v in generator_data["generator_settings"].items(): + setting.__setattr__(k, v) + my_session.add(setting) + + if "model_versions" in generator_data: + model_versions = [] + for name in generator_data["model_versions"]: + print("Name", name) + model_versions.append(LLMModelVersion(name=name)) + + for version in model_versions: + generator.model_versions.append(version) + + my_session.add(generator) + my_session.commit() + + from airunner.data.bootstrap.prompt_templates import prompt_template_seed_data + for data in prompt_template_seed_data: + prompt_template = LLMPromptTemplate( + name=data["name"], + template=data["template"] + ) + my_session.add(prompt_template) + my_session.commit() + + HERE = os.path.abspath(os.path.dirname(__file__)) + alembic_ini_path = os.path.join(HERE, "../alembic.ini") + + config = configparser.ConfigParser() + config.read(f"{alembic_ini_path}.config") + + home_dir = os.path.expanduser("~") + db_path = f'sqlite:///{home_dir}/.airunner/airunner.db' + + config.set('alembic', 'sqlalchemy.url', db_path) + with open(alembic_ini_path, 'w') as configfile: + config.write(configfile) + alembic_cfg = Config(alembic_ini_path) + if not do_stamp_alembic: + command.upgrade(alembic_cfg, "head") + else: + command.stamp(alembic_cfg, "head") diff --git a/src/airunner/main.py b/src/airunner/main.py index 0b3b73b55..e3bf7f1bc 100644 --- a/src/airunner/main.py +++ b/src/airunner/main.py @@ -6,7 +6,8 @@ ******************************************************************************* """ import os - +from airunner.data.db import prepare_database +prepare_database() from airunner.aihandler.settings_manager import SettingsManager settings_manager = SettingsManager() hf_cache_path = settings_manager.path_settings.hf_cache_path @@ -80,6 +81,7 @@ def show_main_application(app, splash, watch_files=False): try: window = MainWindow() if watch_files: + print("Watching style files for changes...") # get existing app watcher = watch_frontend_files() watcher.emitter.file_changed.connect(window.redraw) diff --git a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py index 521d9e9f6..128dbd6a1 100644 --- a/src/airunner/widgets/canvas_plus/canvas_plus_widget.py +++ b/src/airunner/widgets/canvas_plus/canvas_plus_widget.py @@ -9,18 +9,43 @@ from PyQt6.QtGui import QBrush, QColor, QPen, QPixmap, QPainter, QCursor from PyQt6.QtWidgets import QGraphicsScene, QGraphicsItem, QGraphicsPixmapItem, QGraphicsLineItem from PyQt6 import QtWidgets, QtCore +from PyQt6.QtCore import QThread, pyqtSignal from airunner.aihandler.logger import Logger from airunner.aihandler.settings_manager import SettingsManager from airunner.cursors.circle_brush import CircleCursor -from airunner.data.db import session from airunner.data.models import Layer, CanvasSettings, ActiveGridSettings -from airunner.utils import save_session +from airunner.utils import save_session, get_session from airunner.widgets.canvas_plus.canvas_base_widget import CanvasBaseWidget from airunner.widgets.canvas_plus.templates.canvas_plus_ui import Ui_canvas from airunner.utils import apply_opacity_to_image +class ImageAdder(QThread): + finished = pyqtSignal() + + def __init__(self, widget, image, is_outpaint, image_root_point): + super().__init__() + self.widget = widget + self.image = image + self.is_outpaint = is_outpaint + self.image_root_point = image_root_point + + def run(self): + session = get_session() + self.widget.current_active_image = self.image + if self.image_root_point is not None: + self.widget.current_layer.pos_x = self.image_root_point.x() + self.widget.current_layer.pos_y = self.image_root_point.y() + elif not self.is_outpaint: + self.widget.current_layer.pos_x = self.widget.active_grid_area_rect.x() + self.widget.current_layer.pos_y = self.widget.active_grid_area_rect.y() + session.add(self.widget.current_layer) + save_session() + self.widget.do_draw() + self.finished.emit() + + class DraggablePixmap(QGraphicsPixmapItem): def __init__(self, parent, pixmap): self.parent = parent @@ -309,6 +334,7 @@ def active_grid_settings(self): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + session = get_session() self.canvas_settings = session.query(CanvasSettings).first() self.ui.central_widget.resizeEvent = self.resizeEvent self.app.add_image_to_canvas_signal.connect(self.handle_add_image_to_canvas) @@ -736,16 +762,22 @@ def switch_to_layer(self, layer_index): self.current_layer_index = layer_index def add_image_to_scene(self, image, is_outpaint=False, image_root_point=None): - self.current_active_image = image - if image_root_point is not None: - self.current_layer.pos_x = image_root_point.x() - self.current_layer.pos_y = image_root_point.y() - elif not is_outpaint: - self.current_layer.pos_x = self.active_grid_area_rect.x() - self.current_layer.pos_y = self.active_grid_area_rect.y() - session.add(self.current_layer) - save_session() - self.do_draw() + # self.current_active_image = image + # if image_root_point is not None: + # self.current_layer.pos_x = image_root_point.x() + # self.current_layer.pos_y = image_root_point.y() + # elif not is_outpaint: + # self.current_layer.pos_x = self.active_grid_area_rect.x() + # self.current_layer.pos_y = self.active_grid_area_rect.y() + # session.add(self.current_layer) + # save_session() + # self.do_draw() + self.image_adder = ImageAdder(self, image, is_outpaint, image_root_point) + self.image_adder.finished.connect(self.on_image_adder_finished) + self.image_adder.start() + + def on_image_adder_finished(self): + pass def image_to_system_clipboard_windows(self, pixmap): if not pixmap: diff --git a/src/airunner/widgets/canvas_plus/standard_image_widget.py b/src/airunner/widgets/canvas_plus/standard_image_widget.py index e576708a4..0182f1c5a 100644 --- a/src/airunner/widgets/canvas_plus/standard_image_widget.py +++ b/src/airunner/widgets/canvas_plus/standard_image_widget.py @@ -348,8 +348,7 @@ def load_models(self): current_model = self.settings_manager.generator.model if current_model != "": self.ui.model.setCurrentText(current_model) - else: - self.settings_manager.set_value("generator.model", self.ui.model.currentText()) + self.settings_manager.set_value("generator.model", self.ui.model.currentText()) self.ui.model.blockSignals(False) def load_schedulers(self): diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py index c5275e0be..3a85adb89 100644 --- a/src/airunner/widgets/generator_form/generator_form_widget.py +++ b/src/airunner/widgets/generator_form/generator_form_widget.py @@ -4,10 +4,10 @@ from PyQt6.QtCore import pyqtSignal, QRect from airunner.aihandler.settings import MAX_SEED -from airunner.data.db import session from airunner.data.models import ActiveGridSettings, CanvasSettings from airunner.widgets.base_widget import BaseWidget from airunner.widgets.generator_form.templates.generatorform_ui import Ui_generator_form +from airunner.utils import get_session class GeneratorForm(BaseWidget): @@ -118,6 +118,7 @@ def controlnet_image(self): return self.app.standard_image_panel.ui.controlnet_settings.current_controlnet_image def __init__(self, *args, **kwargs): + session = get_session() super().__init__(*args, **kwargs) self.ui.generator_form_tabs.tabBar().hide() self.active_grid_settings = session.query(ActiveGridSettings).first() @@ -594,7 +595,7 @@ def set_progress_bar_value(self, tab_section, section, value): progressbar.setRange(0, 100) progressbar.setValue(value) - def stop_progress_bar(self, tab_section, section): + def stop_progress_bar(self): progressbar = self.ui.progress_bar if not progressbar: return diff --git a/src/airunner/widgets/generator_form/generator_tab_widget.py b/src/airunner/widgets/generator_form/generator_tab_widget.py index a949f7a2c..6fb82ac9e 100644 --- a/src/airunner/widgets/generator_form/generator_tab_widget.py +++ b/src/airunner/widgets/generator_form/generator_tab_widget.py @@ -168,12 +168,9 @@ def set_progress_bar_value(self, tab_section, section, value): progressbar.setRange(0, 100) progressbar.setValue(value) - def stop_progress_bar(self, tab_section, section): - progressbar = self.find_widget("progress_bar", tab_section, section) - if not progressbar: - return - progressbar.setRange(0, 100) - progressbar.setValue(100) + def stop_progress_bar(self): + self.generate_form.progress_bar.setRange(0, 100) + self.generate_form.progress_bar.setValue(100) def add_widget_to_grid(self, widget, row=None, col=0): if row is None: diff --git a/src/airunner/widgets/image/image_widget.py b/src/airunner/widgets/image/image_widget.py index 786cfc78b..5aaf18ca2 100644 --- a/src/airunner/widgets/image/image_widget.py +++ b/src/airunner/widgets/image/image_widget.py @@ -51,8 +51,14 @@ def set_image(self, image_path): path = self.image_path + ".thumbnail.png" if not os.path.exists(path): image = Image.open(self.image_path) - image.thumbnail((size, size)) - image.save(path) + try: + image.thumbnail((size, size)) + except OSError: + pass + try: + image.save(path) + except OSError: + pass if self.is_thumbnail: pixmap = QPixmap(path) else: diff --git a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py index e02cfa613..96a908902 100644 --- a/src/airunner/widgets/llm/chat_prompt_widget.py +++ b/src/airunner/widgets/llm/chat_prompt_widget.py @@ -2,12 +2,11 @@ from PyQt6.QtWidgets import QSpacerItem, QSizePolicy from airunner.aihandler.enums import MessageCode -from airunner.data.db import session from airunner.data.models import Conversation, LLMPromptTemplate, Message from airunner.widgets.base_widget import BaseWidget from airunner.widgets.llm.templates.chat_prompt_ui import Ui_chat_prompt from airunner.widgets.llm.message_widget import MessageWidget -from airunner.utils import save_session +from airunner.utils import save_session, get_session from airunner.aihandler.logger import Logger @@ -51,6 +50,7 @@ def instructions(self): return f"{self.generator.botname} loves {self.generator.username}. {self.generator.botname} is very nice. {self.generator.botname} uses compliments, kind responses, and nice words. Everything {self.generator.botname} says is nice. {self.generator.botname} is kind." def load_data(self): + session = get_session() self.conversation = session.query(Conversation).first() if self.conversation is None: self.conversation = Conversation() @@ -122,6 +122,7 @@ def handle_text_generated(self, message): message=message, conversation=self.conversation ) + session = get_session() session.add(message_object) session.commit() @@ -208,7 +209,7 @@ def action_button_clicked_send(self, image_override=None, prompt_override=None, Logger.warning("Prompt is empty") return - print(self.generator.prompt_template) + session = get_session() prompt_template = session.query(LLMPromptTemplate).filter( LLMPromptTemplate.name == self.generator.prompt_template ).first() diff --git a/src/airunner/widgets/prompt_builder/prompt_builder_widget.py b/src/airunner/widgets/prompt_builder/prompt_builder_widget.py index 3bbf0c3db..58326ad8f 100644 --- a/src/airunner/widgets/prompt_builder/prompt_builder_widget.py +++ b/src/airunner/widgets/prompt_builder/prompt_builder_widget.py @@ -1,9 +1,8 @@ import random from airunner.aihandler.settings import MAX_SEED -from airunner.data.db import session from airunner.data.models import TabSection, PromptBuilder -from airunner.utils import save_session +from airunner.utils import save_session, get_session from airunner.widgets.base_widget import BaseWidget from airunner.widgets.prompt_builder.prompt_builder_form_widget import PromptBuilderForm from airunner.widgets.prompt_builder.templates.prompt_builder_ui import Ui_prompt_builder @@ -142,6 +141,7 @@ def negative_prompt_generator_suffix(self, value): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + session = get_session() self.prompt_generator_settings = session.query(PromptBuilder).all() ts = session.query(TabSection).filter(TabSection.panel == "prompt_builder.ui.tabs").first() self.ui.tabs.blockSignals(True) @@ -321,6 +321,7 @@ def inject_prompt(self, options): def tab_changed(self, val): print("tab_changed", val) + session = get_session() ts = session.query(TabSection).filter(TabSection.panel == "prompt_builder.ui.tabs").first() ts.active_tab = str(val) save_session() diff --git a/src/airunner/windows/main/main_window.py b/src/airunner/windows/main/main_window.py index f9837267a..064a8a923 100644 --- a/src/airunner/windows/main/main_window.py +++ b/src/airunner/windows/main/main_window.py @@ -1,4 +1,5 @@ import os +import queue import pickle import platform import subprocess @@ -7,12 +8,10 @@ from functools import partial from PyQt6 import uic, QtCore -from PyQt6.QtCore import pyqtSlot, Qt, pyqtSignal, QTimer +from PyQt6.QtCore import pyqtSlot, Qt, pyqtSignal, QTimer, QObject, QThread from PyQt6.QtGui import QGuiApplication -from PyQt6.QtWidgets import QApplication, QFileDialog, QMainWindow, QWidget, QSpacerItem, QSizePolicy -from PyQt6.QtCore import Qt +from PyQt6.QtWidgets import QApplication, QFileDialog, QMainWindow, QWidget from PyQt6 import QtGui -from PyQt6.QtCore import QTimer from airunner.resources_light_rc import * from airunner.resources_dark_rc import * @@ -23,13 +22,12 @@ from airunner.aihandler.settings import LOG_LEVEL from airunner.aihandler.settings_manager import SettingsManager from airunner.airunner_api import AIRunnerAPI -from airunner.data.db import session from airunner.data.models import SplitterSection, Prompt, TabSection, LLMGenerator from airunner.filters.windows.filter_base import FilterBase from airunner.input_event_manager import InputEventManager from airunner.mixins.history_mixin import HistoryMixin from airunner.settings import BASE_PATH -from airunner.utils import get_version, auto_export_image, save_session, \ +from airunner.utils import get_version, auto_export_image, save_session, get_session, \ create_airunner_paths, default_hf_cache_dir from airunner.widgets.status.status_widget import StatusWidget from airunner.windows.about.about import AboutWindow @@ -42,7 +40,63 @@ from airunner.data.models import TabSection from airunner.widgets.brushes.brushes_container import BrushesContainer from airunner.data.models import Document -from airunner.utils import get_session + + +class ImageDataWorker(QObject): + finished = pyqtSignal() + stop_progress_bar = pyqtSignal() + + def __init__(self, parent): + super().__init__() + self.parent = parent + + @pyqtSlot() + def process(self): + while True: + item = self.parent.image_data_queue.get() + self.process_image_data(item) + self.finished.emit() + + def process_image_data(self, message): + images = message["images"] + data = message["data"] + nsfw_content_detected = message["nsfw_content_detected"] + self.parent.clear_status_message() + self.parent.data = data + print("process_image_data 3") + if data["action"] == "txt2vid": + return self.parent.video_handler(data) + self.stop_progress_bar.emit() + print("process_image_data 4") + path = "" + if self.parent.settings_manager.auto_export_images: + procesed_images = [] + for image in images: + path, image = auto_export_image( + image=image, + data=data, + seed=data["options"]["seed"], + latents_seed=data["options"]["latents_seed"] + ) + if path is not None: + self.parent.set_status_label(f"Image exported to {path}") + procesed_images.append(image) + images = procesed_images + if nsfw_content_detected and self.parent.settings_manager.nsfw_filter: + self.parent.message_handler({ + "message": "Explicit content detected, try again.", + "code": MessageCode.ERROR + }) + + images = self.parent.post_process_images(images) + self.parent.image_data.emit({ + "images": images, + "path": path, + "data": data + }) + self.parent.message_handler("") + self.parent.ui.layer_widget.show_layers() + self.parent.image_generated.emit(True) class MainWindow( @@ -219,6 +273,9 @@ def current_layer(self): @property def current_canvas(self): return self.standard_image_panel + + def stop_progress_bar(self): + self.generator_tab_widget.stop_progress_bar() def describe_image(self, image, callback): self.generator_tab_widget.ui.ai_tab_widget.describe_image( @@ -314,9 +371,30 @@ def __init__(self, *args, **kwargs): # This is used to check the state of the window and save splitter sizes if they have changed self.start_splitter_timer() + + self.initialize_image_worker() self.loaded.emit() + def initialize_image_worker(self): + self.image_data_queue = queue.Queue() + + self.worker_thread = QThread() + self.worker = ImageDataWorker(self) + self.worker.stop_progress_bar.connect(self.stop_progress_bar) + + self.worker.moveToThread(self.worker_thread) + + self.worker.finished.connect(self.worker_thread.quit) + self.worker.finished.connect(self.worker.deleteLater) + self.worker_thread.finished.connect(self.worker_thread.deleteLater) + + self.worker_thread.started.connect(self.worker.process) + self.worker_thread.start() + + def handle_image_generated(self, message): + self.image_data_queue.put(message) + def initialize_panel_tabs(self): """ Iterate over each TabSection entry from database and set the active tab @@ -324,6 +402,7 @@ def initialize_panel_tabs(self): :return: """ self.ui.mode_tab_widget.currentChanged.connect(self.mode_tab_index_changed) + session = get_session() tabsections = session.query(TabSection).filter( TabSection.panel != "generator_tabs" ).all() @@ -564,6 +643,7 @@ def action_open_discord(self): webbrowser.open("https://discord.gg/ukcgjEpc5f") def tool_tab_index_changed(self, index): + session = get_session() tab_section = session.query(TabSection).filter_by( panel="tool_tab_widget" ).first() @@ -571,6 +651,7 @@ def tool_tab_index_changed(self, index): session.commit() def center_panel_tab_index_changed(self, val): + session = get_session() tab_section = session.query(TabSection).filter_by( panel="center_tab" ).first() @@ -579,6 +660,7 @@ def center_panel_tab_index_changed(self, val): session.commit() def bottom_panel_tab_index_changed(self, index): + session = get_session() tab_section = session.query(TabSection).filter_by( panel="bottom_panel_tab_widget" ).first() @@ -813,6 +895,7 @@ def initialize_splitter_sizes(self): ) ) splitter_names = ["main", "content", "canvas"] + session = get_session() for name in splitter_names: self.splitters[name]["sizes"] = session.query(SplitterSection).filter( SplitterSection.name == f"{name}_splitter" @@ -1103,63 +1186,6 @@ def video_handler(self, data): filename = data["video_filename"] VideoPopup(settings_manager=self.settings_manager, file_path=filename) - def handle_image_generated(self, message): - images = message["images"] - data = message["data"] - nsfw_content_detected = message["nsfw_content_detected"] - self.clear_status_message() - self.data = data - if data["action"] == "txt2vid": - return self.video_handler(data) - - self.generator_tab_widget.stop_progress_bar( - data["tab_section"], data["action"] - ) - path = "" - if self.settings_manager.auto_export_images: - procesed_images = [] - for image in images: - path, image = auto_export_image( - image=image, - data=data, - seed=data["options"]["seed"], - latents_seed=data["options"]["latents_seed"] - ) - if path is not None: - self.set_status_label(f"Image exported to {path}") - procesed_images.append(image) - images = procesed_images - - self.generator_tab_widget.stop_progress_bar( - data["tab_section"], data["action"] - ) - # get max progressbar value - if nsfw_content_detected and self.settings_manager.nsfw_filter: - self.message_handler({ - "message": "Explicit content detected, try again.", - "code": MessageCode.ERROR - }) - - images = self.post_process_images(images) - - if data["options"][f"deterministic_generation"]: - self.deterministic_images = images - DeterministicGenerationWindow( - self.settings_manager, - app=self, - images=images, - data=data) - else: - self.image_data.emit({ - "images": images, - "path": path, - "data": data - }) - self.message_handler("") - self.ui.layer_widget.show_layers() - - self.image_generated.emit(True) - def post_process_images(self, images): #return self.automatic_filter_manager.apply_filters(images) return images @@ -1257,6 +1283,7 @@ def insert_into_prompt(self, text, negative_prompt=False): prompt_widget.setPlainText(text) def change_content_widget(self): + session = get_session() active_tab_obj = session.query(TabSection).filter( TabSection.panel == "center_tab" ).first() @@ -1382,6 +1409,7 @@ def set_all_image_generator_buttons(self): is_prompt_builder = self.settings_manager.generator_section == GeneratorSection.PROMPT_BUILDER.value def image_generators_toggled(self): + session = get_session() self.image_generation_toggled() self.settings_manager.set_value("mode", Mode.IMAGE.value) self.settings_manager.set_value("generator_section", GeneratorSection.TXT2IMG.value) @@ -1392,6 +1420,7 @@ def image_generators_toggled(self): self.change_content_widget() def text_to_video_toggled(self): + session = get_session() self.image_generation_toggled() self.settings_manager.set_value("mode", Mode.IMAGE.value) self.settings_manager.set_value("generator_section", GeneratorSection.TXT2VID.value) @@ -1402,6 +1431,7 @@ def text_to_video_toggled(self): self.change_content_widget() def toggle_prompt_builder(self, val): + session = get_session() if not val: self.image_generators_toggled() else: