Skip to content

Commit

Permalink
add image_files_in_directory() function to image_tools.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bbean23 committed Jul 26, 2024
1 parent e087e96 commit 122abce
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 3 deletions.
42 changes: 41 additions & 1 deletion opencsp/common/lib/tool/image_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from PIL import Image

import opencsp.common.lib.tool.file_tools as ft
import opencsp.common.lib.tool.log_tools as lt

# disable auto formatting
Expand All @@ -20,7 +21,7 @@
"gd", "imt", "iptc", "naa", "mcidas", "mic", "mpo", "pcd", "pixar", "psd", "sun", "wal", "wmf", "emf", "xpm"]
""" A list of all image image formats that can be read by the Python Imaging Library (PIL). Note that not all of these formats can be written by PIL. """
pil_image_formats_writable = pil_image_formats_rw + ["palm", "pdf", "xv"]
""" A list of all image image formats that can be written by the Python Imaging Library (PIL). Note that not all of these formats can be ready by PIL. """
""" A list of all image image formats that can be written by the Python Imaging Library (PIL). Note that not all of these formats can be read by PIL. """
# fmt: on


Expand Down Expand Up @@ -216,3 +217,42 @@ def range_for_threshold(image: np.ndarray, threshold: int) -> tuple[int, int, in
ret.append(end)

return tuple(ret)


def image_files_in_directory(dir: str, allowable_extensions: list[str] = None) -> list[str]:
"""
Get a list of all image files in the given directory, as determined by the file extension.
Parameters
----------
dir : str
The directory to get files from.
allowable_extensions : list[str], optional
The allowed extensions, such as ["png"]. By default pil_image_formats_rw.
Returns
-------
image_file_names_exts: list[str]
A list of the name.ext for each image file in the given directory.
"""
# normalize input
if allowable_extensions is None:
allowable_extensions = pil_image_formats_rw
for i, ext in enumerate(allowable_extensions):
if ext.startswith("."):
continue
else:
allowable_extensions[i] = "." + ext

# get all matching files
files_per_ext = ft.files_in_directory_by_extension(dir, allowable_extensions)

# condense into a single list
files: list[str] = []
for ext in files_per_ext:
files += files_per_ext[ext]

# sort
files = sorted(files)

return files
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
29 changes: 27 additions & 2 deletions opencsp/common/lib/tool/test/test_image_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
class TestImageTools(unittest.TestCase):
def setUp(self) -> None:
path, _, _ = ft.path_components(__file__)
self.data_dir = os.path.join(path, "data", "input", "image_tools")
self.out_dir = os.path.join(path, "data", "output", "image_tools")
self.data_dir = ft.join(path, "data", "input", "image_tools")
self.out_dir = ft.join(path, "data", "output", "image_tools")

def test_numpy_to_image_truncate(self):
arr8i = np.array([[0, 125, 255]]).astype(np.int8)
Expand Down Expand Up @@ -63,6 +63,31 @@ def test_min_max(self):
nptest.assert_array_equal(it.min_max_colors(arr3)[0], np.array([0, 1, 2]))
nptest.assert_array_equal(it.min_max_colors(arr3)[1], np.array([5, 4, 3]))

def test_image_files_in_directory(self):
all_image_files = it.image_files_in_directory(self.data_dir)
self.assertIn("a.png", all_image_files)
self.assertIn("b.PNG", all_image_files)
self.assertIn("c.jpg", all_image_files)
self.assertNotIn("d.txt", all_image_files)

png_image_files = it.image_files_in_directory(self.data_dir, ["png"])
self.assertIn("a.png", png_image_files)
self.assertIn("b.PNG", png_image_files)
self.assertNotIn("c.jpg", png_image_files)
self.assertNotIn("d.txt", png_image_files)

jpg_image_files = it.image_files_in_directory(self.data_dir, ["jpg"])
self.assertNotIn("a.png", jpg_image_files)
self.assertNotIn("b.PNG", jpg_image_files)
self.assertIn("c.jpg", jpg_image_files)
self.assertNotIn("d.txt", jpg_image_files)

png_jpg_image_files = it.image_files_in_directory(self.data_dir, ["png", "jpg"])
self.assertIn("a.png", png_jpg_image_files)
self.assertIn("b.PNG", png_jpg_image_files)
self.assertIn("c.jpg", png_jpg_image_files)
self.assertNotIn("d.txt", png_jpg_image_files)


if __name__ == '__main__':
unittest.main()

0 comments on commit 122abce

Please sign in to comment.