Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concept statistics #679

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions modules/ui/ConceptWindow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import os
import random

Expand All @@ -6,6 +7,7 @@
from modules.util.enum.BalancingStrategy import BalancingStrategy
from modules.util.ui import components
from modules.util.ui.UIState import UIState
from scripts import concept_stats

from mgds.LoadingPipeline import LoadingPipeline
from mgds.OutputPipelineModule import OutputPipelineModule
Expand All @@ -24,6 +26,9 @@
from torchvision.transforms import functional

import customtkinter as ctk
from customtkinter import AppearanceModeTracker, ThemeManager
from matplotlib import pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from PIL import Image


Expand Down Expand Up @@ -83,6 +88,7 @@ def __init__(
self.general_tab = self.__general_tab(tabview.add("general"), concept)
self.image_augmentation_tab = self.__image_augmentation_tab(tabview.add("image augmentation"))
self.text_augmentation_tab = self.__text_augmentation_tab(tabview.add("text augmentation"))
self.concept_stats_tab = self.__concept_stats_tab(tabview.add("stats"))

components.button(self, 1, 0, "ok", self.__ok)

Expand Down Expand Up @@ -326,6 +332,137 @@ def __text_augmentation_tab(self, master):
frame.pack(fill="both", expand=1)
return frame

def __concept_stats_tab(self, master):
    """Build the "stats" tab and return its scrollable frame.

    Creates an underlined heading plus a "-" placeholder label for every
    statistic, an embedded matplotlib bar chart for the aspect buckets, and
    the two refresh buttons, then fills the labels from the stats stored on
    the concept (rescanning the directory if stored stats are incomplete).

    Args:
        master: parent widget (the tab) that the frame is packed into.

    Returns:
        The ctk.CTkScrollableFrame holding all stats widgets.
    """
    frame = ctk.CTkScrollableFrame(master, fg_color="transparent")
    for column in range(4):
        frame.grid_columnconfigure(column, weight=0, minsize=150)

    def heading(row, column, text, tooltip):
        # underlined caption label for one statistic
        label = components.label(frame, row, column, text, pad=0, tooltip=tooltip)
        label.configure(font=ctk.CTkFont(underline=True))
        return label

    def placeholder(row, column, **kwargs):
        # value label, shows "-" until stats are loaded
        return components.label(frame, row, column, pad=0, text="-", **kwargs)

    def to_hex(color):
        # winfo_rgb channels are scaled down by 256 (as the original code did,
        # i.e. assuming 16-bit channels). Each channel must be zero padded to
        # two hex digits: without padding "#0a0a0a" would collapse to "#aaa",
        # which is a different (or an invalid) colour.
        red, green, blue = (int(channel / 256) for channel in self.winfo_rgb(color))
        return f"#{red:02x}{green:02x}{blue:02x}"

    #file size
    self.file_size_label = heading(1, 0, "Total Size",
        "Total size of all image, mask, and caption files")
    self.file_size_preview = placeholder(2, 0)

    #subdirectory count
    self.dir_count_label = heading(1, 1, "Directories",
        "Total number of directories including and under (if 'include subdirectories' is enabled) main concept directory")
    self.dir_count_preview = placeholder(2, 1)

    #basic img stats
    self.image_count_label = heading(3, 0, "\nTotal Images",
        "Total number of image files, any of the extensions " + str(path_util.SUPPORTED_IMAGE_EXTENSIONS) + ", excluding '-masklabel.png'")
    self.image_count_preview = placeholder(4, 0)
    self.mask_count_label = heading(3, 1, "\nTotal Masks",
        "Total number of mask files, any file ending in '-masklabel.png'")
    self.mask_count_preview = placeholder(4, 1)
    self.caption_count_label = heading(3, 2, "\nTotal Captions",
        "Total number of caption files, any .txt file")
    self.caption_count_preview = placeholder(4, 2)

    #advanced img stats
    self.image_count_mask_label = heading(5, 0, "\nImages with Masks",
        "Total number of image files with an associated mask")
    self.image_count_mask_preview = placeholder(6, 0)
    self.mask_count_label_unpaired = heading(5, 1, "\nUnpaired Masks",
        "Total number of mask files which lack a corresponding image file - if >0, check your data set!")
    self.mask_count_preview_unpaired = placeholder(6, 1)
    self.image_count_caption_label = heading(5, 2, "\nImages with Captions",
        "Total number of image files with an associated caption")
    self.image_count_caption_preview = placeholder(6, 2)
    self.caption_count_label_unpaired = heading(5, 3, "\nUnpaired Captions",
        "Total number of caption files which lack a corresponding image file - if >0, check your data set!")
    self.caption_count_preview_unpaired = placeholder(6, 3)

    #resolution info
    self.pixel_max_label = heading(7, 0, "\nMax Pixels",
        "Largest image in the concept by number of pixels (width * height)")
    self.pixel_max_preview = placeholder(8, 0, wraplength=150)
    self.pixel_avg_label = heading(7, 1, "\nAvg Pixels",
        "Average size of images in the concept by number of pixels (width * height)")
    self.pixel_avg_preview = placeholder(8, 1, wraplength=150)
    self.pixel_min_label = heading(7, 2, "\nMin Pixels",
        "Smallest image in the concept by number of pixels (width * height)")
    self.pixel_min_preview = placeholder(8, 2, wraplength=150)

    #caption info
    self.caption_max_label = heading(9, 0, "\nMax Caption Length",
        "Largest caption in concept by character count")
    self.caption_max_preview = placeholder(10, 0, wraplength=150)
    self.caption_avg_label = heading(9, 1, "\nAvg Caption Length",
        "Average length of caption in concept by character count")
    self.caption_avg_preview = placeholder(10, 1, wraplength=150)
    self.caption_min_label = heading(9, 2, "\nMin Caption Length",
        "Smallest caption in concept by character count")
    self.caption_min_preview = placeholder(10, 2, wraplength=150)

    #bucket info (the aspect bucket heading has no value label, the chart below shows the data)
    self.aspect_bucket_label = heading(11, 0, "\nAspect Bucketing",
        "List of all possible buckets and the number of images in each one, defined as width/height. Buckets range from 0.25 (1:4 extremely tall) to 4 (4:1 extremely wide).")
    self.small_bucket_label = heading(11, 1, "\nSmallest Buckets",
        "Image buckets with the least nonzero total images - if 'batch size' is larger than this, these images will be dropped during training!")
    self.small_bucket_preview = placeholder(12, 1)

    # plot - colours are taken from the active customtkinter theme so the chart matches the window
    appearance_mode = AppearanceModeTracker.get_mode()
    background_color = to_hex(ThemeManager.theme["CTkToplevel"]["fg_color"][appearance_mode])
    text_color = to_hex(ThemeManager.theme["CTkLabel"]["text_color"][appearance_mode])

    plt.set_loglevel('WARNING')     #suppress errors about data type in bar chart
    self.bucket_fig, self.bucket_ax = plt.subplots(figsize=(7,2))
    self.canvas = FigureCanvasTkAgg(self.bucket_fig, master=frame)
    self.canvas.get_tk_widget().grid(row=13, column=0, columnspan=4, rowspan=2)
    self.bucket_fig.tight_layout()

    self.bucket_fig.set_facecolor(background_color)
    self.bucket_ax.set_facecolor(background_color)
    self.bucket_ax.spines['bottom'].set_color(text_color)
    self.bucket_ax.spines['left'].set_color(text_color)
    self.bucket_ax.spines['top'].set_visible(False)
    self.bucket_ax.spines['right'].set_color(text_color)
    self.bucket_ax.tick_params(axis='x', colors=text_color, which="both")
    self.bucket_ax.tick_params(axis='y', colors=text_color, which="both")
    self.bucket_ax.xaxis.label.set_color(text_color)
    self.bucket_ax.yaxis.label.set_color(text_color)

    #refresh stats - must be after all labels are defined or will give error
    components.button(master=frame, row=0, column=0, text="Refresh Basic", command=lambda: self.__update_concept_stats(True, False),
                      tooltip="Reload basic statistics for the concept directory")
    components.button(master=frame, row=0, column=1, text="Refresh Advanced", command=lambda: self.__update_concept_stats(True, True),
                      tooltip="Reload advanced statistics for the concept directory")
    components.label(frame, 0, 2, text="Warning!", tooltip="Will be slow for large folders, particularly ones on HDDs!")
    self.processing_time = components.label(frame, 0, 3, text="-", tooltip="Time taken to process concept directory")

    #automatically get basic stats if available
    try:
        try:
            self.__update_concept_stats(False, False)   #load stats from config if available
        except KeyError:
            # The rescan runs inside the outer try so that a FileNotFoundError
            # raised by it is also ignored; handlers of one try statement do
            # not cover exceptions raised inside a sibling handler.
            self.__update_concept_stats(True, False)    #force rescan if config is empty
    except FileNotFoundError:                           #avoid error when loading concept window without config path defined
        pass

    frame.pack(fill="both", expand=1)
    return frame

def __prev_image_preview(self):
self.image_preview_file_index = max(self.image_preview_file_index - 1, 0)
self.__update_image_preview()
Expand Down Expand Up @@ -439,5 +576,88 @@ def __get_preview_image(self):

return image

def __update_concept_stats(self, force_refresh : bool, advanced_checks : bool):
    """Refresh every widget of the stats tab from self.concept.concept_stats.

    Args:
        force_refresh: rescan the concept directory even if stats are
            already stored on the concept.
        advanced_checks: forwarded to the scan; resolution and caption
            statistics are plain strings when advanced stats were not taken.

    Raises:
        KeyError: if the stored stats dict lacks one of the expected keys.
    """
    #only runs scan if specifically requested, otherwise loads from concept config
    if force_refresh or len(self.concept.concept_stats) == 0:
        self.__get_concept_stats(advanced_checks)
    # NOTE(review): the reviewed source lost its indentation, so it is unclear
    # whether this line was inside the branch above. It is applied
    # unconditionally here so the time of the stored scan is shown as well.
    stats = self.concept.concept_stats
    self.processing_time.configure(text=str(stats["processing_time"]) + " s")

    #file size (bytes -> MB)
    self.file_size_preview.configure(text=str(int(stats["file_size"]/1048576)) + " MB")

    #directory count
    self.dir_count_preview.configure(text=stats["directory_count"])

    #image count
    self.image_count_preview.configure(text=stats["image_count"])
    self.image_count_mask_preview.configure(text=stats["image_with_mask_count"])
    self.image_count_caption_preview.configure(text=stats["image_with_caption_count"])

    #mask count
    self.mask_count_preview.configure(text=stats["mask_count"])
    self.mask_count_preview_unpaired.configure(text=stats["unpaired_masks"])

    #caption count
    self.caption_count_preview.configure(text=stats["caption_count"])
    self.caption_count_preview_unpaired.configure(text=stats["unpaired_captions"])

    #resolution info
    max_pixels = stats["max_pixels"]
    avg_pixels = stats["avg_pixels"]
    min_pixels = stats["min_pixels"]

    if any(isinstance(x, str) for x in (max_pixels, avg_pixels, min_pixels)):   #will be str if adv stats were not taken
        self.pixel_max_preview.configure(text=max_pixels)
        self.pixel_avg_preview.configure(text=avg_pixels)
        self.pixel_min_preview.configure(text=min_pixels)
    else:
        #formatted as (#pixels/1000000) MP, widthxheight, \n filename
        #the average is shown as the side length of a square with the same pixel count
        avg_side = int(math.sqrt(avg_pixels))
        self.pixel_max_preview.configure(text=f'{round(max_pixels[0]/1000000, 2)} MP, {max_pixels[2]}\n{max_pixels[1]}')
        self.pixel_avg_preview.configure(text=f'{round(avg_pixels/1000000, 2)} MP, ~{avg_side}x{avg_side}')
        self.pixel_min_preview.configure(text=f'{round(min_pixels[0]/1000000, 2)} MP, {min_pixels[2]}\n{min_pixels[1]}')

    #caption info
    max_caption_length = stats["max_caption_length"]
    avg_caption_length = stats["avg_caption_length"]
    min_caption_length = stats["min_caption_length"]

    if any(isinstance(x, str) for x in (max_caption_length, avg_caption_length, min_caption_length)):   #will be str if adv stats were not taken
        self.caption_max_preview.configure(text=max_caption_length)
        self.caption_avg_preview.configure(text=avg_caption_length)
        self.caption_min_preview.configure(text=min_caption_length)
    else:
        #formatted as #chars chars, #words words, \n third entry (presumably the caption filename - confirm against concept_stats)
        self.caption_max_preview.configure(text=f'{max_caption_length[0]} chars, {max_caption_length[2]} words\n{max_caption_length[1]}')
        self.caption_avg_preview.configure(text=f'{int(avg_caption_length[0])} chars, {int(avg_caption_length[1])} words')
        self.caption_min_preview.configure(text=f'{min_caption_length[0]} chars, {min_caption_length[2]} words\n{min_caption_length[1]}')

    #bucketing
    aspect_buckets = stats["aspect_buckets"]
    nonzero_counts = sorted({val for val in aspect_buckets.values() if val > 0})
    if nonzero_counts:  #aspect_bucket data exists and is not all zero
        smallest_counts = nonzero_counts[:2]    #smallest and, if present, second smallest distinct bucket size
        #joined without a trailing newline; the previous code called str.strip() but discarded its result
        min_bucket_str = "\n".join(
            f'aspect {key}: {val} img'
            for key, val in aspect_buckets.items()
            if val in smallest_counts
        )
        self.small_bucket_preview.configure(text=min_bucket_str)

    #redraw the bar chart from scratch, one bar per bucket
    self.bucket_ax.cla()
    aspects = [str(x) for x in aspect_buckets]
    counts = list(aspect_buckets.values())
    b = self.bucket_ax.bar(aspects, counts)
    self.bucket_ax.bar_label(b)
    self.canvas.draw()

def __get_concept_stats(self, advanced_checks : bool):
    """Scan the concept and merge the result into self.concept.concept_stats.

    A freshly scanned value replaces the stored one unless it is a null
    placeholder ("-" or an empty dict) and the key is already stored, so a
    scan that skipped some statistics does not wipe earlier results.
    """
    scanned = concept_stats.get_concept_stats(self.concept, advanced_checks)
    stored = self.concept.concept_stats
    for name, value in scanned.items():
        #keep the stored entry when the new value carries no information
        if name in stored and (value == "-" or value == {}):
            continue
        stored[name] = value

def __ok(self):
    """Close the concept window by destroying it (command of the "ok" button)."""
    self.destroy()
2 changes: 2 additions & 0 deletions modules/util/config/ConceptConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class ConceptConfig(BaseConfig):
text_variations: int
repeats: float
loss_weight: float
concept_stats: dict

image: ConceptImageConfig
text: ConceptTextConfig
Expand Down Expand Up @@ -182,5 +183,6 @@ def default_values():
data.append(("balancing", 1.0, float, False))
data.append(("balancing_strategy", BalancingStrategy.REPEATS, BalancingStrategy, False))
data.append(("loss_weight", 1.0, float, False))
data.append(("concept_stats", {}, dict, False))

return ConceptConfig(data)
Loading