Skip to content

Commit

Permalink
Merge pull request #1303 from bghira/feature/minmax-aspect-bucket-bounds
Browse files Browse the repository at this point in the history
(#1113) add minimum and maximum aspect ratio bucket parameter and associated tests
  • Loading branch information
bghira authored Jan 27, 2025
2 parents 9165da4 + a105d55 commit 4bc85b8
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 0 deletions.
16 changes: 16 additions & 0 deletions documentation/DATALOADER.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Here is the most basic example of a dataloader configuration file, as `multidata
"resolution": 1024,
"minimum_image_size": 768,
"maximum_image_size": 2048,
"minimum_aspect_ratio": 0.50,
"maximum_aspect_ratio": 3.00,
"target_downsample_size": 1024,
"resolution_type": "pixel_area",
"prepend_instance_prompt": false,
Expand Down Expand Up @@ -119,6 +121,20 @@ Both `textfile` and `parquet` support multi-captions:
- When `resolution` is measured in pixels, you should use the same unit here (eg. `1024` to exclude images under 1024px **shorter edge length**)
- **Recommendation**: Keep `minimum_image_size` equal to `resolution` unless you want to risk training on poorly-upsized images.

### `minimum_aspect_ratio`

- **Description:** The minimum aspect ratio of the image. If the image's aspect ratio is less than this value, it will be excluded from training.
- **Note**: If an excessive number of images qualifies for exclusion, startup may be slow, as the trainer will try to scan and bucket any images that are missing from the bucket lists.

> **Note**: Once the aspect and metadata lists are built for your dataset, using `skip_file_discovery="vae aspect metadata"` will prevent the trainer from scanning the dataset on startup, saving a lot of time.

### `maximum_aspect_ratio`

- **Description:** The maximum aspect ratio of the image. If the image's aspect ratio is greater than this value, it will be excluded from training.
- **Note**: If an excessive number of images qualifies for exclusion, startup may be slow, as the trainer will try to scan and bucket any images that are missing from the bucket lists.

> **Note**: Once the aspect and metadata lists are built for your dataset, using `skip_file_discovery="vae aspect metadata"` will prevent the trainer from scanning the dataset on startup, saving a lot of time.

#### Examples

##### Configuration
Expand Down
6 changes: 6 additions & 0 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,12 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
minimum_image_size=backend.get(
"minimum_image_size", args.minimum_image_size
),
minimum_aspect_ratio=backend.get(
"minimum_aspect_ratio", None
),
maximum_aspect_ratio=backend.get(
"maximum_aspect_ratio", None
),
resolution_type=backend.get("resolution_type", args.resolution_type),
batch_size=args.train_batch_size,
metadata_update_interval=backend.get(
Expand Down
55 changes: 55 additions & 0 deletions helpers/metadata/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(
delete_unwanted_images: bool = False,
metadata_update_interval: int = 3600,
minimum_image_size: int = None,
minimum_aspect_ratio: int = None,
maximum_aspect_ratio: int = None,
cache_file_suffix: str = None,
repeats: int = 0,
):
Expand Down Expand Up @@ -74,6 +76,12 @@ def __init__(
self.minimum_image_size = (
float(minimum_image_size) if minimum_image_size else None
)
self.minimum_aspect_ratio = (
float(minimum_aspect_ratio) if minimum_aspect_ratio else None
)
self.maximum_aspect_ratio = (
float(maximum_aspect_ratio) if maximum_aspect_ratio else None
)
self.image_metadata_loaded = False
self.vae_output_scaling_factor = 8
self.metadata_semaphor = Semaphore()
Expand Down Expand Up @@ -448,6 +456,9 @@ def _enforce_min_bucket_size(self):
"""
Remove buckets that have fewer samples than batch_size and enforce minimum image size constraints.
"""
if self.minimum_image_size is None:
return

logger.info(
f"Enforcing minimum image size of {self.minimum_image_size}."
" This could take a while for very-large datasets."
Expand All @@ -464,6 +475,50 @@ def _enforce_min_bucket_size(self):
# We do this twice in case there were any new contenders for being too small.
self._prune_small_buckets(bucket)

def _enforce_min_aspect_ratio(self):
    """
    Remove every aspect bucket whose ratio is below ``self.minimum_aspect_ratio``.

    No-op when the minimum is unset (``None``) or ``0.0``, so datasets
    configured without the option keep their full bucket list.
    """
    min_ratio = self.minimum_aspect_ratio
    if min_ratio is None or min_ratio == 0.0:
        return

    logger.info(
        f"Enforcing minimum aspect ratio of {min_ratio}."
        " This could take a while for very-large datasets."
    )
    # Snapshot the keys so buckets can be deleted while iterating.
    for bucket in tqdm(
        list(self.aspect_ratio_bucket_indices.keys()),
        leave=False,
        desc="Enforcing minimum aspect ratio",
    ):
        # Bucket keys are stringified aspect ratios, e.g. "0.75".
        if float(bucket) >= min_ratio:
            continue
        logger.info(
            f"Removing bucket {bucket} due to aspect ratio being less than {min_ratio}."
        )
        del self.aspect_ratio_bucket_indices[bucket]

def _enforce_max_aspect_ratio(self):
    """
    Remove every aspect bucket whose ratio is above ``self.maximum_aspect_ratio``.

    No-op when the maximum is unset (``None``) or ``0.0``, so datasets
    configured without the option keep their full bucket list.
    """
    max_ratio = self.maximum_aspect_ratio
    if max_ratio is None or max_ratio == 0.0:
        return

    logger.info(
        f"Enforcing maximum aspect ratio of {max_ratio}."
        " This could take a while for very-large datasets."
    )
    # Snapshot the keys so buckets can be deleted while iterating.
    for bucket in tqdm(
        list(self.aspect_ratio_bucket_indices.keys()),
        leave=False,
        desc="Enforcing maximum aspect ratio",
    ):
        # Bucket keys are stringified aspect ratios, e.g. "1.78".
        if float(bucket) <= max_ratio:
            continue
        logger.info(
            f"Removing bucket {bucket} due to aspect ratio being greater than {max_ratio}."
        )
        del self.aspect_ratio_bucket_indices[bucket]

def _prune_small_buckets(self, bucket):
"""
Remove buckets with fewer images than the batch size.
Expand Down
6 changes: 6 additions & 0 deletions helpers/metadata/backends/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __init__(
delete_unwanted_images: bool = False,
metadata_update_interval: int = 3600,
minimum_image_size: int = None,
minimum_aspect_ratio: int = None,
maximum_aspect_ratio: int = None,
cache_file_suffix: str = None,
repeats: int = 0,
):
Expand All @@ -53,6 +55,8 @@ def __init__(
delete_unwanted_images=delete_unwanted_images,
metadata_update_interval=metadata_update_interval,
minimum_image_size=minimum_image_size,
minimum_aspect_ratio=minimum_aspect_ratio,
maximum_aspect_ratio=maximum_aspect_ratio,
cache_file_suffix=cache_file_suffix,
repeats=repeats,
)
Expand Down Expand Up @@ -159,6 +163,8 @@ def save_cache(self, enforce_constraints: bool = False):
# Prune any buckets that have fewer samples than batch_size
if enforce_constraints:
self._enforce_min_bucket_size()
self._enforce_min_aspect_ratio()
self._enforce_max_aspect_ratio()
if self.read_only:
logger.debug("Skipping cache update on storage backend, read-only mode.")
return
Expand Down
7 changes: 7 additions & 0 deletions helpers/metadata/backends/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(
delete_unwanted_images: bool = False,
metadata_update_interval: int = 3600,
minimum_image_size: int = None,
minimum_aspect_ratio: int = None,
maximum_aspect_ratio: int = None,
cache_file_suffix: str = None,
repeats: int = 0,
):
Expand All @@ -60,6 +62,8 @@ def __init__(
delete_unwanted_images=delete_unwanted_images,
metadata_update_interval=metadata_update_interval,
minimum_image_size=minimum_image_size,
minimum_aspect_ratio=minimum_aspect_ratio,
maximum_aspect_ratio=maximum_aspect_ratio,
cache_file_suffix=cache_file_suffix,
repeats=repeats,
)
Expand Down Expand Up @@ -295,6 +299,9 @@ def save_cache(self, enforce_constraints: bool = False):
# Prune any buckets that have fewer samples than batch_size
if enforce_constraints:
self._enforce_min_bucket_size()
self._enforce_min_aspect_ratio()
self._enforce_max_aspect_ratio()

if self.read_only:
logger.debug("Metadata backend is read-only, skipping cache save.")
return
Expand Down
40 changes: 40 additions & 0 deletions tests/test_metadata_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,46 @@ def test_save_cache(self):
self.metadata_backend.save_cache()
mock_write.assert_called_once()

def test_minimum_aspect_size(self):
    """Buckets whose aspect ratio falls below the configured minimum are pruned."""
    backend = self.metadata_backend
    # Two buckets: "1.0" sits below the 1.25 floor and should be dropped;
    # "1.5" is above it and should survive.
    backend.aspect_ratio_bucket_indices = {
        "1.0": ["image1", "image2"],
        "1.5": ["image3"],
    }
    backend.minimum_aspect_ratio = 1.25
    backend._enforce_min_aspect_ratio()
    self.assertEqual(
        backend.aspect_ratio_bucket_indices,
        {"1.5": ["image3"]},
    )

def test_maximum_aspect_size(self):
    """Buckets whose aspect ratio exceeds the configured maximum are pruned."""
    backend = self.metadata_backend
    # Two buckets: "1.5" exceeds the 1.25 ceiling and should be dropped;
    # "1.0" is below it and should survive.
    backend.aspect_ratio_bucket_indices = {
        "1.0": ["image1", "image2"],
        "1.5": ["image3"],
    }
    backend.maximum_aspect_ratio = 1.25
    backend._enforce_max_aspect_ratio()
    self.assertEqual(
        backend.aspect_ratio_bucket_indices,
        {"1.0": ["image1", "image2"]},
    )

def test_unbound_aspect_list(self):
    """With no min/max aspect ratio configured, both enforcers leave the buckets untouched."""
    backend = self.metadata_backend
    original_buckets = {
        "1.0": ["image1", "image2"],
        "1.5": ["image3"],
    }
    backend.aspect_ratio_bucket_indices = dict(original_buckets)
    # Neither bound is set on this backend, so both calls must be no-ops.
    backend._enforce_min_aspect_ratio()
    backend._enforce_max_aspect_ratio()
    self.assertEqual(backend.aspect_ratio_bucket_indices, original_buckets)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4bc85b8

Please sign in to comment.