diff --git a/documentation/DATALOADER.md b/documentation/DATALOADER.md index 1c3230d4..b80d462e 100644 --- a/documentation/DATALOADER.md +++ b/documentation/DATALOADER.md @@ -14,6 +14,8 @@ Here is the most basic example of a dataloader configuration file, as `multidata "resolution": 1024, "minimum_image_size": 768, "maximum_image_size": 2048, + "minimum_aspect_ratio": 0.50, + "maximum_aspect_ratio": 3.00, "target_downsample_size": 1024, "resolution_type": "pixel_area", "prepend_instance_prompt": false, @@ -119,6 +121,20 @@ Both `textfile` and `parquet` support multi-captions: - When `resolution` is measured in pixels, you should use the same unit here (eg. `1024` to exclude images under 1024px **shorter edge length**) - **Recommendation**: Keep `minimum_image_size` equal to `resolution` unless you want to risk training on poorly-upsized images. +### `minimum_aspect_ratio` + +- **Description:** The minimum aspect ratio of the image. If the image's aspect ratio is less than this value, it will be excluded from training. +- **Note**: If the number of images qualifying for exclusion is excessive, this might waste time at startup as the trainer will try to scan them and bucket if they are missing from the bucket lists. + +> **Note**: Once the aspect and metadata lists are built for your dataset, using `skip_file_discovery="vae aspect metadata"` will prevent the trainer from scanning the dataset on startup, saving a lot of time. + +### `maximum_aspect_ratio` + +- **Description:** The maximum aspect ratio of the image. If the image's aspect ratio is greater than this value, it will be excluded from training. +- **Note**: If the number of images qualifying for exclusion is excessive, this might waste time at startup as the trainer will try to scan them and bucket if they are missing from the bucket lists. + +> **Note**: Once the aspect and metadata lists are built for your dataset, using `skip_file_discovery="vae aspect metadata"` will prevent the trainer from scanning the dataset on startup, saving a lot of time. + #### Examples ##### Configuration diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index d08d976f..127d2304 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -804,6 +804,12 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize minimum_image_size=backend.get( "minimum_image_size", args.minimum_image_size ), + minimum_aspect_ratio=backend.get( + "minimum_aspect_ratio", None + ), + maximum_aspect_ratio=backend.get( + "maximum_aspect_ratio", None + ), resolution_type=backend.get("resolution_type", args.resolution_type), batch_size=args.train_batch_size, metadata_update_interval=backend.get( diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index fd0443d3..1e6c2d0f 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -42,6 +42,8 @@ def __init__( delete_unwanted_images: bool = False, metadata_update_interval: int = 3600, minimum_image_size: int = None, + minimum_aspect_ratio: int = None, + maximum_aspect_ratio: int = None, cache_file_suffix: str = None, repeats: int = 0, ): @@ -74,6 +76,12 @@ def __init__( self.minimum_image_size = ( float(minimum_image_size) if minimum_image_size else None ) + self.minimum_aspect_ratio = ( + float(minimum_aspect_ratio) if minimum_aspect_ratio else None + ) + self.maximum_aspect_ratio = ( + float(maximum_aspect_ratio) if maximum_aspect_ratio else None + ) self.image_metadata_loaded = False self.vae_output_scaling_factor = 8 self.metadata_semaphor = Semaphore() @@ -448,6 +456,9 @@ def _enforce_min_bucket_size(self): """ Remove buckets that have fewer samples than batch_size and enforce minimum image size constraints. """ + if self.minimum_image_size is None: + return + logger.info( f"Enforcing minimum image size of {self.minimum_image_size}." " This could take a while for very-large datasets." @@ -464,6 +475,50 @@ def _enforce_min_bucket_size(self): # We do this twice in case there were any new contenders for being too small. self._prune_small_buckets(bucket) + def _enforce_min_aspect_ratio(self): + """ + Remove buckets that have an aspect ratio outside the specified range. + """ + if self.minimum_aspect_ratio is None or self.minimum_aspect_ratio == 0.0: + return + + logger.info( + f"Enforcing minimum aspect ratio of {self.minimum_aspect_ratio}." + " This could take a while for very-large datasets." + ) + for bucket in tqdm( + list(self.aspect_ratio_bucket_indices.keys()), + leave=False, + desc="Enforcing minimum aspect ratio", + ): # Safe iteration over keys + if float(bucket) < self.minimum_aspect_ratio: + logger.info( + f"Removing bucket {bucket} due to aspect ratio being less than {self.minimum_aspect_ratio}." + ) + del self.aspect_ratio_bucket_indices[bucket] + + def _enforce_max_aspect_ratio(self): + """ + Remove buckets that have an aspect ratio outside the specified range. + """ + if self.maximum_aspect_ratio is None or self.maximum_aspect_ratio == 0.0: + return + + logger.info( + f"Enforcing maximum aspect ratio of {self.maximum_aspect_ratio}." + " This could take a while for very-large datasets." + ) + for bucket in tqdm( + list(self.aspect_ratio_bucket_indices.keys()), + leave=False, + desc="Enforcing maximum aspect ratio", + ): # Safe iteration over keys + if float(bucket) > self.maximum_aspect_ratio: + logger.info( + f"Removing bucket {bucket} due to aspect ratio being greater than {self.maximum_aspect_ratio}." + ) + del self.aspect_ratio_bucket_indices[bucket] + def _prune_small_buckets(self, bucket): """ Remove buckets with fewer images than the batch size. diff --git a/helpers/metadata/backends/discovery.py b/helpers/metadata/backends/discovery.py index 7e3959bf..823149b4 100644 --- a/helpers/metadata/backends/discovery.py +++ b/helpers/metadata/backends/discovery.py @@ -36,6 +36,8 @@ def __init__( delete_unwanted_images: bool = False, metadata_update_interval: int = 3600, minimum_image_size: int = None, + minimum_aspect_ratio: int = None, + maximum_aspect_ratio: int = None, cache_file_suffix: str = None, repeats: int = 0, ): @@ -53,6 +55,8 @@ def __init__( delete_unwanted_images=delete_unwanted_images, metadata_update_interval=metadata_update_interval, minimum_image_size=minimum_image_size, + minimum_aspect_ratio=minimum_aspect_ratio, + maximum_aspect_ratio=maximum_aspect_ratio, cache_file_suffix=cache_file_suffix, repeats=repeats, ) @@ -159,6 +163,8 @@ def save_cache(self, enforce_constraints: bool = False): # Prune any buckets that have fewer samples than batch_size if enforce_constraints: self._enforce_min_bucket_size() + self._enforce_min_aspect_ratio() + self._enforce_max_aspect_ratio() if self.read_only: logger.debug("Skipping cache update on storage backend, read-only mode.") return diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 9293f961..beb58113 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -39,6 +39,8 @@ def __init__( delete_unwanted_images: bool = False, metadata_update_interval: int = 3600, minimum_image_size: int = None, + minimum_aspect_ratio: int = None, + maximum_aspect_ratio: int = None, cache_file_suffix: str = None, repeats: int = 0, ): @@ -60,6 +62,8 @@ def __init__( delete_unwanted_images=delete_unwanted_images, metadata_update_interval=metadata_update_interval, minimum_image_size=minimum_image_size, + minimum_aspect_ratio=minimum_aspect_ratio, + maximum_aspect_ratio=maximum_aspect_ratio, cache_file_suffix=cache_file_suffix, repeats=repeats, ) @@ -295,6 +299,9 @@ def save_cache(self, enforce_constraints: bool = False): # Prune any buckets that have fewer samples than batch_size if enforce_constraints: self._enforce_min_bucket_size() + self._enforce_min_aspect_ratio() + self._enforce_max_aspect_ratio() + if self.read_only: logger.debug("Metadata backend is read-only, skipping cache save.") return diff --git a/tests/test_metadata_backend.py b/tests/test_metadata_backend.py index 16584329..02f8ebaa 100644 --- a/tests/test_metadata_backend.py +++ b/tests/test_metadata_backend.py @@ -101,6 +101,46 @@ def test_save_cache(self): self.metadata_backend.save_cache() mock_write.assert_called_once() + def test_minimum_aspect_size(self): + # when metadata_backend.minimum_aspect_ratio is not None and > 0.0 it will remove buckets from the list. + # this test ensures that the bucket is removed when the value is set correctly. + self.metadata_backend.aspect_ratio_bucket_indices = { + "1.0": ["image1", "image2"], + "1.5": ["image3"], + } + self.metadata_backend.minimum_aspect_ratio = 1.25 + self.metadata_backend._enforce_min_aspect_ratio() + self.assertEqual( + self.metadata_backend.aspect_ratio_bucket_indices, {"1.5": ["image3"]} + ) + + def test_maximum_aspect_size(self): + # when metadata_backend.maximum_aspect_ratio is not None and > 0.0 it will remove buckets from the list. + # this test ensures that the bucket is removed when the value is set correctly. + self.metadata_backend.aspect_ratio_bucket_indices = { + "1.0": ["image1", "image2"], + "1.5": ["image3"], + } + self.metadata_backend.maximum_aspect_ratio = 1.25 + self.metadata_backend._enforce_max_aspect_ratio() + self.assertEqual( + self.metadata_backend.aspect_ratio_bucket_indices, {"1.0": ["image1", "image2"]} + ) + + def test_unbound_aspect_list(self): + # when metadata_backend.maximum_aspect_ratio is None and metadata_backend.minimum_aspect_ratio is None + # the aspect_ratio_bucket_indices should not be modified. + self.metadata_backend.aspect_ratio_bucket_indices = { + "1.0": ["image1", "image2"], + "1.5": ["image3"], + } + self.metadata_backend._enforce_min_aspect_ratio() + self.metadata_backend._enforce_max_aspect_ratio() + self.assertEqual( + self.metadata_backend.aspect_ratio_bucket_indices, + {"1.0": ["image1", "image2"], "1.5": ["image3"]}, + ) + if __name__ == "__main__": unittest.main()