diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index c963959..e39e04d 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -8,11 +8,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Full unit test coverage - Type annotations on all functions -## [0.4.0] - 2020-01-19 +## [0.4.0] - 2020-02-08 +### Added +- `plate_labels` command line argument +- `base` module with base classes to provide core functionality throughout the package +- `geometry` module with base shape classes +- `plate` module with `Plate` and `PlateCollection` classes +- `image_file` module with `ImageFile` and `ImageFileCollection` classes ### Changed - Cached data is now not used by default - `use_saved` command line argument renamed to `use_cached_data` - Compressed serialised data filename changed to `cached_data` +- `save_plots` command line argument renamed to `plots` +- Refactored most of the functions from `main` as static methods in the `plate` or `image_file` modules +### Fixed +- A rare error when opening images using skimage.io.imread ## [0.3.4] - 2020-01-18 ### Added diff --git a/docs/command_line_arguments.md b/docs/command_line_arguments.md index 7171892..23f4ad6 100644 --- a/docs/command_line_arguments.md +++ b/docs/command_line_arguments.md @@ -15,14 +15,6 @@ A full list of available arguments, along with their default values -h --help ``` -### Information output -The level of information output to the command line. Default level is `1`, increase to see more information. Output can be silenced with `0` - -- input: integer -``` --v ---verbose -``` ### Image density The image density your scanner uses, this can usually be found in your scanner settings. It is important to set this correctly as it enables the program to acurately convert the plate size in millimeters to pixels. @@ -31,12 +23,23 @@ The image density your scanner uses, this can usually be found in your scanner s -dpi --dots_per_inch ``` -### Plate size -The diameter of the plates used, in millimeters. It is important to set this correctly otherwise the plates may be located incorrectly in the images. +### Multiprocessing +This technique utilises all of the available processors that your computer has to analyse images in parallel. Since most computers now have at least 2 or 4 processors, this can greatly reduce the time needed to process a set of images. -- input: integer +This technique is however quite resource intensive for your computer so you may wish to disable it. + +- input: boolean ``` ---plate_size +-mp +--multiprocessing +``` +### Plot images output +The level of detail required when saving plot images after analysis. At the default level (`1`), a few summary plots are saved to give a quick overview of the data. If the output level is increased, individual plots for each plate will be saved. + +- input: boolean +``` +-p +--plots ``` ### Plate edge cut The radius, in pixels, to remove from the edge of the plate image. This ensures that the image is clear of reflections, shadows and writing that are typically present near the edge of the plate image. @@ -45,6 +48,19 @@ The radius, in pixels, to remove from the edge of the plate image. This ensures ``` --plate_edge_cut ``` +### Plate labels +A list of labels to identify each plate. The label is used in file names and the plate map. + +Plates are ordered from top left, in rows, and labels must be provided in that order. + +Labels are separated with spaces. To use a space within a label, wrap that label in quotes + +Example: `--plate_labels first second third "label with spaces" fifth sixth` + +- input: list +``` +--plate_labels +``` ### Plate holder shape The layout of the plates in the image in rows and columns. The default is `3` rows and `2` columns. @@ -54,14 +70,12 @@ A square grid of 9 plates would be entered as `--plate_lattice 3 3` ``` --plate_lattice ``` -### Plot images output -The level of detail required when saving plot images after analysis. At the default level (`1`), a few summary plots are saved to give a quick overview of the data. If the output level is increased, individual plots for each plate will be saved. - -Warning: increasing the number of plots can greatly increase the time taken for the image analysis +### Plate size +The diameter of the plates used, in millimeters. It is important to set this correctly otherwise the plates may be located incorrectly in the images. -- input: boolean +- input: integer ``` ---save_plots +--plate_size ``` ### Cached data The package saves a compressed serialised version of its output, along with the uncompressed CSV data. This allows it to quickly generate the CSV files and plot images again, without the need for reanalysing the original images. This is disabled by default to prevent confusing situation where outdated information is output from new or altered image sets. @@ -70,14 +84,12 @@ The package saves a compressed serialised version of its output, along with the ``` --use_cached_data ``` -### Multiprocessing -This technique utilises all of the available processors that your computer has to analyse images in parallel. Since most computers now have at least 2 or 4 processors, this can greatly reduce the time needed to process a set of images. - -This technique is however quite resource intensive for your computer so you may wish to disable it. +### Information output +The level of information output to the command line. Default level is `1`, increase to see more information. Output can be silenced with `0` -- input: boolean +- input: integer ``` --mp ---multiprocessing +-v +--verbose ``` diff --git a/docs/index.md b/docs/index.md index 1cc1e7a..5a10537 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ ColonyScanalyser can provide information on: * Colony appearance time * Colony growth over time * Growth and appearance time distribution -* Colony colour (e.g. staining or other visual +* Colony colour (e.g. staining or other visual indicator) ## Install ``` diff --git a/src/colonyscanalyser/base.py b/src/colonyscanalyser/base.py new file mode 100644 index 0000000..2102ec8 --- /dev/null +++ b/src/colonyscanalyser/base.py @@ -0,0 +1,242 @@ +from typing import Type, TypeVar, Optional, List +from collections.abc import Collection +from datetime import datetime, timedelta + + +class Identified: + """ + An object with a integer ID number + """ + def __init__(self, id: int): + self.id = id + + @property + def id(self) -> int: + return self.__id + + @id.setter + def id(self, val: int): + if self.__id_is_valid(val): + self.__id = val + else: + raise ValueError(f"'{val}' is not a valid id. An id must be a non-negative integer'") + + @staticmethod + def __id_exists(collection: Collection, id: int) -> bool: + """ + Verifies if an object in a collection matches the specified ID number + + :param collection: a collection of objects (List, Dict etc) + :param id: an ID number to locate + :returns: True if an object with id exists in the collection + """ + return any(id == existing.id for existing in collection) + + @staticmethod + def __id_is_valid(id: int) -> bool: + """ + Verifies if a value conforms to the requirements for an ID number + + An ID number is an integer with a value greater than zero + + :param id: an ID number to verify + :returns: True if the value conforms to the requirements for an ID number + """ + return isinstance(id, int) and id > 0 + + +class IdentifiedCollection: + """ + An collection of Identified objects with generic methods for modifying the collection + """ + T = TypeVar("T", bound = Identified) + + def __init__(self, items: Collection = None): + self.items = items + + @property + def count(self) -> int: + return len(self.items) + + @property + def items(self) -> List["T"]: + """ + Returns a sorted list of items from the collection + + A copy is returned, preventing direct changes to the collection + """ + return sorted(self.__items, key = lambda item: item.id) + + @items.setter + def items(self, val: Collection): + if isinstance(val, dict): + val = list(val.values()) + + if val is None: + self.__items = list() + elif isinstance(val, Collection) and not isinstance(val, str): + self.__items = val.copy() + else: + raise ValueError(f"Items must be supplied as a valid Collection, not {type(val)}") + + def add(self, id: int) -> "T": + """ + Create a new instance of T and append it to the collection + + :param id: a valid Identified ID number + :returns: a new instance of T + """ + item = Identified(id = id) + + self.append(item) + + return item + + def append(self, item: Type[T]): + """ + Append an item to the collection + + :param item: the object to append to the collection + """ + if not self.exists(item): + self.__items.append(item) + else: + raise ValueError(f"An item with ID #{item.id} already exists") + + def exists(self, item: Type[T]) -> bool: + """ + Check if an item exists in the item collection + + :param item: an instance of T + :returns: True if an item is found with matching ID + """ + return self.id_exists(item.id) + + def id_exists(self, id: int) -> bool: + """ + Check if an item with the specified ID number exists in the item collection + + :param id: a valid Identified id number + :returns: True if an item is found with matching ID + """ + return Identified._Identified__id_exists(self.items, id) + + def get_item(self, id: int) -> Optional["T"]: + """ + Returns an item with the specified ID number from the item collection + + :param id: a valid Identified ID number + :returns: an item from the collection, if found + """ + for item in self.items: + if item.id == id: + return item + + return None + + def remove(self, id: int): + """ + Remove an item from the collection + + :param id: a valid Identified ID number + """ + if self.id_exists(id): + for item in self.items: + if item.id == id: + self.__items.remove(item) + else: + raise KeyError(f"No item with ID #{id} could be found") + + +class Named: + """ + An object with a string identifier + """ + def __init__(self, name: str): + self.name = name + + @property + def name(self) -> str: + return self.__name + + @name.setter + def name(self, val: str): + self.__name = str(val) + + +class Unique(Identified): + """ + An object with a auto incremented integer ID number + """ + id_count = 0 + + def __init__(self): + self._Identified__id = self.id_increment() + + @Identified.id.setter + def id(self, val: int): + """ + Overrides base method to make id read-only + """ + pass + + def id_increment(self) -> int: + """ + Increments the built-in ID counter + + :returns: the auto incremented ID number + """ + Unique.id_count += 1 + + return Unique.id_count + + +class TimeStamped: + def __init__(self, timestamp: datetime = None): + if timestamp is None: + timestamp = datetime.now() + + self.timestamp = timestamp + + @property + def timestamp(self) -> datetime: + return self.__timestamp + + @timestamp.setter + def timestamp(self, val: datetime): + self.__timestamp = val + + +class TimeStampElapsed(TimeStamped): + def __init__(self, timestamp: datetime = None, timestamp_initial: datetime = None): + if timestamp is None: + timestamp = datetime.now() + if timestamp_initial is None: + timestamp_initial = timestamp + + self._TimeStamped__timestamp = timestamp + self.timestamp_initial = timestamp_initial + + @property + def timestamp_elapsed(self) -> timedelta: + return self.timestamp - self.timestamp_initial + + @property + def timestamp_elapsed_hours(self) -> float: + return (self.timestamp_elapsed_seconds / 60) / 60 + + @property + def timestamp_elapsed_minutes(self) -> int: + return int(self.timestamp_elapsed_seconds / 60) + + @property + def timestamp_elapsed_seconds(self) -> int: + return self.timestamp_elapsed.total_seconds() + + @property + def timestamp_initial(self) -> datetime: + return self.__timestamp_initial + + @timestamp_initial.setter + def timestamp_initial(self, val: datetime): + self.__timestamp_initial = val \ No newline at end of file diff --git a/src/colonyscanalyser/colony.py b/src/colonyscanalyser/colony.py index 966f2e3..529a160 100644 --- a/src/colonyscanalyser/colony.py +++ b/src/colonyscanalyser/colony.py @@ -1,12 +1,15 @@ from datetime import datetime, timedelta from math import pi, log from dataclasses import dataclass -from collections.abc import Iterable +from typing import Union, List, Tuple +from collections.abc import Collection +from numpy import ndarray +from .base import Identified, Named from .utilities import round_tuple_floats -from.imaging import rgb_to_name +from .imaging import rgb_to_name -class Colony: +class Colony(Identified, Named): """ An object to hold information on a single colony over time """ @@ -29,9 +32,9 @@ def __iter__(self): round(self.diameter, 2), round(self.perimeter, 2), round_tuple_floats(self.color_average, 2), - ]) + ]) - def __init__(self, id, timepoints = None): + def __init__(self, id: int, timepoints: Collection = None): self.id = id # Can't set argument default otherwise it is shared across all class instances if timepoints is None: @@ -58,86 +61,92 @@ def __iter__(self): round_tuple_floats(self.timepoint_last.center, 2), self.timepoint_last.area, round(self.timepoint_last.diameter, 2) - ]) + ]) @property - def timepoints(self): - if len(self.__timepoints) > 0: - return self.__timepoints - else: - raise ValueError("No time points are stored for this colony") - - @timepoints.setter - def timepoints(self, val): - if isinstance(val, dict): - self.__timepoints = val - elif isinstance(val, Iterable) and not isinstance(val, str): - self.__timepoints = {timepoint.date_time: timepoint for timepoint in val} - else: - raise ValueError("Timepoints must be supplied as a Dict or other iterable") - - @property - def timepoint_first(self): - return self.get_timepoint(min(self.timepoints.keys())) - - @property - def timepoint_last(self): - return self.get_timepoint(max(self.timepoints.keys())) - - @property - def center(self): + def center(self) -> Union[Tuple[float, float], Tuple[float, float, float]]: centers = [x.center for x in self.timepoints.values()] return tuple(sum(x) / len(self.timepoints) for x in zip(*centers)) @property - def color(self): + def color(self) -> Tuple[float, float, float]: color_averages = [x.color_average for x in self.timepoints.values()] return tuple(sum(x) / len(self.timepoints) for x in zip(*color_averages)) @property - def color_name(self): + def color_name(self) -> str: return rgb_to_name(self.color, color_spec = "css3") @property - def growth_rate(self): + def growth_rate(self) -> float: try: return (self.timepoint_last.area - self.timepoint_first.area) / self.timepoint_first.area except ZeroDivisionError: return 0 @property - def growth_rate_average(self): + def growth_rate_average(self) -> float: if self.growth_rate == 0: return 0 else: return ((self.timepoint_last.area - self.timepoint_first.area) ** (1 / len(self.timepoints))) - 1 @property - def time_of_appearance(self): - return self.timepoint_first.date_time + def timepoints(self): + if len(self.__timepoints) > 0: + return self.__timepoints + else: + raise ValueError("No time points are stored for this colony") - def get_timepoint(self, date_time): - if date_time in self.__timepoints: - return self.timepoints[date_time] + @timepoints.setter + def timepoints(self, val: Collection): + if isinstance(val, dict): + self.__timepoints = val + elif isinstance(val, Collection) and not isinstance(val, str): + self.__timepoints = {timepoint.date_time: timepoint for timepoint in val} else: - raise ValueError(f"The requested time point ({date_time}) does not exist") + raise ValueError("Timepoints must be supplied as a Dict or other Collection") - def append_timepoint(self, timepoint): + @property + def timepoint_first(self) -> "Timepoint": + return self.get_timepoint(min(self.timepoints.keys())) + + @property + def timepoint_last(self) -> "Timepoint": + return self.get_timepoint(max(self.timepoints.keys())) + + @property + def time_of_appearance(self) -> datetime: + return self.timepoint_first.date_time + + def append_timepoint(self, timepoint: Timepoint): + """ + Add a Timepoint to the Colony timepoints collection + + :param timepoint: a Timepoint object + """ if timepoint.date_time not in self.__timepoints: self.__timepoints[timepoint.date_time] = timepoint else: raise ValueError(f"This time point ({timepoint.date_time}) already exists") - def update_timepoint(self, timepoint_original, timepoint_new): - self.timepoints[timepoint_original.date_time] = timepoint_new - - def remove_timepoint(self, date_time): - del self.timepoints[date_time] + def get_circularity_at_timepoint(self, date_time: datetime) -> float: + """ + Calculate the circularity of the colony at a specified timepoint - def circularity_at_timepoint(self, date_time): + :param date_time: the datetime key for specific Timepoint in the Colony timepoints collection + :returns: the circularity of the colony as a float + """ return self.__circularity(self.get_timepoint(date_time).area, self.get_timepoint(date_time).perimeter) - def get_doubling_times(self, window = 10, elapsed_minutes = False): + def get_doubling_times(self, window: int = 10, elapsed_minutes: bool = False) -> Union[List[datetime], List[float]]: + """ + Calculate the colony area doubling times over the specified number of time points + + :param window: the number of time points to calculate the doubling times over + :param elapsed_minutes: return the timestamps in minutes since starting, instead of absolute DateTimes + :returns: a list of doubling times local to the specified timepoint window + """ timepoint_count = len(self.timepoints) if timepoint_count <= 1: return list() @@ -151,10 +160,17 @@ def get_doubling_times(self, window = 10, elapsed_minutes = False): x_pts = [value.date_time for key, value in self.timepoints.items()] y_pts = [value.area for key, value in self.timepoints.items()] - return [self.__local_doubling_time(i, x_pts, y_pts, window) for i in range(len(x_pts) - window)] + return [self.__local_doubling_time(i, x_pts, y_pts, window = window) for i in range(len(x_pts) - window)] - def get_doubling_time_average(self, window = 10, elapsed_minutes = False): - doubling_times = self.get_doubling_times(window, elapsed_minutes) + def get_doubling_time_average(self, window: int = 10, elapsed_minutes: bool = False) -> Union[timedelta, float]: + """ + Calculate an average of the colony area doubling times over the specified number of time points + + :param window: the number of time points to calculate the doubling times over + :param elapsed_minutes: return the timestamps in minutes since starting, instead of absolute DateTimes + :returns: the mean of doubling times local to the specified timepoint window + """ + doubling_times = self.get_doubling_times(window = window, elapsed_minutes = elapsed_minutes) if not len(doubling_times) > 0: return 0 @@ -166,11 +182,64 @@ def get_doubling_time_average(self, window = 10, elapsed_minutes = False): return time_sum / len(doubling_times) - def __circularity(self, area, perimeter): - return (4 * pi * area) / (perimeter * perimeter) + def get_timepoint(self, date_time: datetime) -> "Timepoint": + """ + Returns a Timepoint object from the Colony timepoints collection - def __local_doubling_time(self, index, x_pts, y_pts, window = 10): + :param date_time: the datetime key for specific Timepoint in the Colony timepoints collection + :returns: a Timepoint object from the Colony timepoints collection + """ + if date_time in self.__timepoints: + return self.timepoints[date_time] + else: + raise ValueError(f"The requested time point ({date_time}) does not exist") + + def remove_timepoint(self, date_time: datetime): + """ + Remove a specified Timepoint from the Colony timepoints collection + + :param date_time: the datetime key for specific Timepoint in the Colony timepoints collection + """ + del self.timepoints[date_time] + + def update_timepoint(self, timepoint_original: Timepoint, timepoint_new: Timepoint): + """ + Replace a Timepoint from the Colony timepoints collection with a new Timepoint + :param timepoint_original: a Timepoint that exists in the Colony timepoints collection + :param timepoint_new: a Timepoint object to replace the existing Timepoint + """ + self.timepoints[timepoint_original.date_time] = timepoint_new + + @staticmethod + def __circularity(area: float, perimeter: float) -> float: + """ + Calculate how closely the shape of an object approaches that of a mathematically perfect circle + + A mathematically perfect circle has a circularity of 1 + + :param area: the size of the region enclosed by the perimeter + :param perimeter: the total distance along the edge of a shape + :returns: a ratio of area to perimiter as a float + """ + return (4 * pi * area) / (perimeter * perimeter) + + @staticmethod + def __local_doubling_time( + index: int, + x_pts: list, + y_pts: List[float], + window: int = 10 + ) -> Union[List[datetime], List[float]]: + """ + Calculate the doubling times over the specified number of sequence points + + :param index: the index key for the starting point in both x_pts and y_pts + :param x_pts: a list of x-axis data points, usually DateTimes + :param y_pts: a list of y-axis data points + :param window: the number of time points to calculate the doubling times over + :returns: a list of doubling times local to the specified timepoint window + """ x1 = x_pts[index] y1 = y_pts[index] x2 = x_pts[index + window] @@ -182,7 +251,12 @@ def __local_doubling_time(self, index, x_pts, y_pts, window = 10): return 0 -def timepoints_from_image(image_segmented, time_point, elapsed_minutes, image = None): +def timepoints_from_image( + image_segmented: ndarray, + time_point: datetime, + elapsed_minutes: int, + image: ndarray = None +) -> List[Colony.Timepoint]: """ Create Timepoint objects from a segemented image @@ -234,7 +308,10 @@ def timepoints_from_image(image_segmented, time_point, elapsed_minutes, image = return colonies -def colonies_from_timepoints(timepoints, distance_tolerance = 1): +def colonies_from_timepoints( + timepoints: List["Timepoint"], + distance_tolerance: float = 1 +) -> List[Colony]: """ Create a dictionary of Colony objects from Timepoint data @@ -254,7 +331,7 @@ def colonies_from_timepoints(timepoints, distance_tolerance = 1): timepoints, max_distance = distance_tolerance, axis = 0 - ) + ) # Then split the groups further by column values for timepoint_group in center_groups: @@ -262,7 +339,7 @@ def colonies_from_timepoints(timepoints, distance_tolerance = 1): timepoint_group, max_distance = distance_tolerance, axis = 1 - ) + ) colony_centers.extend(group) # Create a colony object for each group of centres @@ -275,7 +352,11 @@ def colonies_from_timepoints(timepoints, distance_tolerance = 1): return colonies -def group_timepoints_by_center(timepoints, max_distance = 1, axis = 0): +def group_timepoints_by_center( + timepoints: List[Colony.Timepoint], + max_distance: float = 1, + axis: int = 0 +) -> List[List[Colony]]: """ Split a list of Timepoint objects into sub groups Compares difference in values along a specified axis diff --git a/src/colonyscanalyser/file_access.py b/src/colonyscanalyser/file_access.py index ce1dafb..507347d 100644 --- a/src/colonyscanalyser/file_access.py +++ b/src/colonyscanalyser/file_access.py @@ -1,4 +1,5 @@ -from enum import Enum +from typing import List +from enum import Enum def file_exists(file_path): @@ -19,6 +20,19 @@ def file_exists(file_path): return False +def file_safe_name(file_name: List[str], separator: str = "_") -> str: + """ + Converts a list of string to a safe file name + + :param file_name: a list of strings that make up the complete filename + :param separator: a character to place in between the items of file_name + :returns: a new filename string + """ + safe_names = [val.replace(" ", separator) for val in file_name] + + return separator.join(filter(None, safe_names)) + + def get_files_by_type(path, file_extensions = ["*"]): """ Get a list of path objects of a given filetype(s) @@ -207,7 +221,7 @@ def save_to_csv(data, headers, save_path, delimiter = ","): outfile, delimiter = delimiter, fieldnames = headers - ) + ) writer.writeheader() data = [data] else: diff --git a/src/colonyscanalyser/geometry.py b/src/colonyscanalyser/geometry.py new file mode 100644 index 0000000..a73ffa3 --- /dev/null +++ b/src/colonyscanalyser/geometry.py @@ -0,0 +1,102 @@ +from typing import Union, Tuple +from math import pi + + +class Shape: + """ + An abstract class to provide the fundamental properties of a surface + """ + @property + def area(self) -> float: + raise NotImplementedError("This property must be implemented in a derived class") + + @property + def center(self) -> Union[Tuple[float, float], Tuple[float, float, float]]: + try: + return self.__center + except AttributeError: + return None + + @center.setter + def center(self, val: Union[Tuple[float, float], Tuple[float, float, float]]): + self.__center = val + + @property + def depth(self) -> float: + try: + return self.__depth + except AttributeError: + return 0 + + @depth.setter + def depth(self, val: float): + self.__depth = val + + @property + def height(self) -> float: + try: + return self.__height + except AttributeError: + return 0 + + @height.setter + def height(self, val: float): + self.__height = val + + @property + def perimeter(self) -> float: + raise NotImplementedError("This property must be implemented in a derived class") + + @property + def width(self) -> float: + try: + return self.__width + except AttributeError: + return 0 + + @width.setter + def width(self, val: float): + self.__width = val + + +class Circle(Shape): + """ + An object to generate the properties of a circle + """ + def __init__(self, diameter: float): + self.diameter = diameter + + @property + def area(self) -> float: + return pi * self.radius * self.radius + + @property + def circumference(self) -> float: + return self.perimeter + + @property + def diameter(self) -> float: + return self.__diameter + + @diameter.setter + def diameter(self, val: float): + if val < 0: + raise ValueError("The diameter must be a number greater than zero") + + self.__diameter = val + + @property + def height(self) -> float: + return self.diameter + + @property + def perimeter(self) -> float: + return pi * self.diameter + + @property + def radius(self) -> float: + return self.diameter / 2 + + @property + def width(self) -> float: + return self.diameter \ No newline at end of file diff --git a/src/colonyscanalyser/image_file.py b/src/colonyscanalyser/image_file.py new file mode 100644 index 0000000..93632f1 --- /dev/null +++ b/src/colonyscanalyser/image_file.py @@ -0,0 +1,189 @@ +from typing import Optional, List +from pathlib import Path +from datetime import datetime +from re import search +from numpy import ndarray +from skimage.io import imread +from skimage.color import rgb2gray +from .base import IdentifiedCollection, Unique, TimeStampElapsed +from .file_access import file_exists + + +class ImageFile(Unique, TimeStampElapsed): + """ + An object to hold information about, and provide access to, a timestamped image file + """ + def __init__( + self, + file_path: Path, + timestamp: datetime = None, + timestamp_initial: datetime = None, + cache_image: bool = False + ): + super(ImageFile, self).__init__() + self.file_path = file_path + + if timestamp is None: + timestamp = self.timestamp_from_string(str(self.file_path.name)) + if timestamp_initial is None: + timestamp_initial = timestamp + + self.timestamp = timestamp + self.timestamp_initial = timestamp_initial + self.cache_image = cache_image + self.__image = None + if self.cache_image: + self.__image = ImageFile.__load_image(self.file_path) + + def __enter__(self): + # Load and cache image ready for use + if self.__image is None: + self.__image = ImageFile.__load_image(self.file_path) + + return self + + def __exit__(self, exception_type, exception_value, traceback): + # Remove cached images, unless required + if not self.cache_image: + self.__image = None + + @property + def cache_image(self) -> bool: + return self.__cache_image + + @cache_image.setter + def cache_image(self, val: bool): + self.__cache_image = val + + @property + def image(self) -> ndarray: + if self.cache_image and self.__image is not None: + return self.__image.copy() + else: + return ImageFile.__load_image(self.file_path) + + @property + def image_gray(self) -> ndarray: + return rgb2gray(self.image) + + @property + def file_path(self) -> Path: + return self.__file_path + + @file_path.setter + def file_path(self, val: Path): + if not isinstance(val, Path): + val = Path(val) + + if not file_exists(val): + raise FileNotFoundError(f"The image file could not be found: {val}") + + self.__file_path = val + + @staticmethod + def timestamp_from_exif(image_file: Path) -> Optional[datetime]: + raise NotImplementedError() + + @staticmethod + def timestamp_from_string( + search_string: str, + pattern: str = + "(?P\\d{4}).?(?P[0-1][0-9]).?(?P[0-3][0-9]).?(?P[0-2][0-9]).?(?P[0-5][0-9])" + ) -> Optional[datetime]: + """ + Attempts to read a datetime value from a string + + Requires a regex pattern with the following named pattern groups: + year, month, day, hour, minute + + :param search_string: a string to check against the regex pattern + :param pattern: a regex pattern used to match the datetime + :returns: a datetime parsed from the string, if successful + """ + if not len(search_string) > 0 or not len(pattern) > 0: + raise ValueError("The search string or pattern must not be empty") + + result = search(pattern, search_string) + if result: + return datetime( + year = int(result.groupdict()["year"]), + month = int(result.groupdict()["month"]), + day = int(result.groupdict()["day"]), + hour = int(result.groupdict()["hour"]), + minute = int(result.groupdict()["minute"]) + ) + else: + return None + + @staticmethod + def __load_image(file_path: Path, as_gray: bool = False, plugin: str = "pil") -> ndarray: + return imread(str(file_path), as_gray = as_gray, plugin = plugin) + + +class ImageFileCollection(IdentifiedCollection): + """ + Holds a collection of ImageFiles + """ + @IdentifiedCollection.items.getter + def items(self) -> List[ImageFile]: + return sorted(self._IdentifiedCollection__items, key = lambda item: item.timestamp) + + @property + def file_paths(self) -> List[datetime]: + return [image_file.file_path for image_file in self.items] + + @property + def timestamps(self) -> List[datetime]: + return [image_file.timestamp for image_file in self.items] + + @property + def timestamps_initial(self) -> List[datetime]: + return [image_file.timestamp_initial for image_file in self.items] + + @timestamps_initial.setter + def timestamps_initial(self, val: datetime): + for image_file in self.items: + image_file.timestamp_initial = val + + @property + def timestamps_elapsed(self) -> List[datetime]: + return [image_file.timestamp_elapsed for image_file in self.items] + + @property + def timestamps_elapsed_hours(self) -> List[float]: + return [image_file.timestamp_elapsed_hours for image_file in self.items] + + @property + def timestamps_elapsed_minutes(self) -> List[int]: + return [image_file.timestamp_elapsed_minutes for image_file in self.items] + + @property + def timestamps_elapsed_seconds(self) -> List[int]: + return [image_file.timestamp_elapsed_seconds for image_file in self.items] + + def add( + self, + file_path: Path, + timestamp: datetime = None, + timestamp_initial: datetime = None, + cache_image: bool = False + ) -> ImageFile: + """ + Create a new ImageFile and append it to the collection + + :param file_path: a Path object representing the image location + :param timestamp: a datetime associated with the image + :param timestamp_initial: a starting datetime used to calculate elapsed timestamps + :param cache_image: load the image dynamically from file, or store in memory + :returns: the new ImageFile instance + """ + image_file = ImageFile( + file_path = file_path, + timestamp = timestamp, + timestamp_initial = timestamp_initial, + cache_image = cache_image + ) + + self.append(image_file) + + return image_file diff --git a/src/colonyscanalyser/imaging.py b/src/colonyscanalyser/imaging.py index 510acfd..191df0f 100644 --- a/src/colonyscanalyser/imaging.py +++ b/src/colonyscanalyser/imaging.py @@ -73,7 +73,7 @@ def crop_image(image, crop_shape, center = None): if any(x < 0 for x in crop_shape) or any(not isinstance(x, int) for x in crop_shape) or len(image.shape) < len(crop_shape): raise ValueError( f"The crop shape ({crop_shape}) must be positive integers and the same dimensions as the image to crop" - ) + ) if crop_shape > img.shape: raise ValueError(f"The crop shape ({crop_shape}) cannot be larger than the image ({image.shape}) to crop") @@ -174,7 +174,7 @@ def get_image_circles(image, circle_radius, circle_count = -1, search_radius = 0 min_xdistance = circle_radius, min_ydistance = circle_radius # total_num_peaks = circle_count - ) + ) # Temporary helper function until hough_circle_peaks respects min distances cx, cy, radii = circles_radius_median(cx, cy, radii, circle_count) diff --git a/src/colonyscanalyser/main.py b/src/colonyscanalyser/main.py index 707a835..d8827da 100644 --- a/src/colonyscanalyser/main.py +++ b/src/colonyscanalyser/main.py @@ -1,6 +1,7 @@ # System modules import sys import argparse +from typing import Union, Dict, List, Tuple from pathlib import Path from datetime import datetime from distutils.util import strtobool @@ -9,7 +10,7 @@ from functools import partial # Third party modules -from skimage.io import imread +from numpy import ndarray # Local modules from colonyscanalyser import ( @@ -18,79 +19,17 @@ imaging, plots ) +from .image_file import ImageFile, ImageFileCollection +from .plate import Plate, PlateCollection from .colony import Colony, timepoints_from_image, colonies_from_timepoints, timepoints_from_image -def get_plate_directory(parent_path, row, col, create_dir = True): - """ - Determine the directory path for a specified plate - - Can create the directory if needed - - :param parent_path: a path object - :param row: a lattice co-ordinate row - :param col: a lattice co-ordinate column - :param create_dir: specify if the directory should be created - :returns: a path object for the specified plate - """ - - child_path = '_'.join(['row', str(row), 'col', str(col)]) - if create_dir: - return file_access.create_subdirectory(parent_path, child_path) - else: - return parent_path.joinpath(child_path) - - -def get_image_timestamps(image_paths, elapsed_minutes = False): - """ - Get timestamps from a list of images - - Assumes images have a file name with as timestamp - Timestamps should be in YYYYMMDD_HHMM format - - :param images: a list of image file path objects - :param elapsed_minutes: return timestamps as elapsed integer minutes - :returns: a list of timestamps - """ - time_points = list() - - # Get date and time information from filenames - dates = [str(image.name[:-8].split("_")[-2]) for image in image_paths] - times = [str(image.name[:-4].split("_")[-1]) for image in image_paths] - - # Convert string timestamps to Python datetime objects - for i, date in enumerate(dates): - time_points.append(datetime.combine(datetime.strptime(date, "%Y%m%d"), datetime.strptime(times[i], "%H%M").time())) - - if elapsed_minutes: - # Store time points as elapsed minutes since start - time_points_elapsed = list() - for time_point in time_points: - time_points_elapsed.append(int((time_point - time_points[0]).total_seconds() / 60)) - time_points = time_points_elapsed - - return time_points - - -def get_plate_images(image, plate_coordinates, edge_cut = 100): - """ - Split image into lattice subimages and delete background - - :param img: a black and white image as a numpy array - :param plate_coordinates: a list of centers and radii - :param edge_cut: a radius, in pixels, to remove from the outer edge of the plate - :returns: a list of plate images - """ - plates = [] - - for coordinate in plate_coordinates: - center, radius = coordinate - plates.append(imaging.cut_image_circle(image, center, radius - edge_cut)) - - return plates - - -def segment_image(plate_image, plate_mask, plate_noise_mask, area_min = 5): +def segment_image( + plate_image: ndarray, + plate_mask: ndarray, + plate_noise_mask: ndarray, + area_min: float = 5 +) -> ndarray: """ Attempts to find and label all colonies on a plate @@ -136,17 +75,20 @@ def segment_image(plate_image, plate_mask, plate_noise_mask, area_min = 5): return colonies -def image_file_to_timepoints(image_path, plate_coordinates, plate_images_mask, time_point, elapsed_minutes, edge_cut, plot_path = None): +def image_file_to_timepoints( + image_file: ndarray, + plates: PlateCollection, + plate_images_mask: List[ndarray], + plot_path: Path = None +) -> Dict[int, List[Colony.Timepoint]]: """ Get Timepoint object data from a plate image Lists the results in a dict with the plate number as the key - :param image_path: a Path object representing an image - :param plate_coordinates: a list of (row, column) tuple plate centres + :param image_file: an ImageFile object + :param plates: a PlateCollection of Plate instances :param plate_images_mask: a list of plate images to use as noise masks - :param time_point: a Datetime object - :param elapsed_minutes: the number of integer minutes since starting :param plot_path: a Path directory to save the segmented image plot :returns: a Dict of lists, each containing Timepoint objects """ @@ -155,21 +97,19 @@ def image_file_to_timepoints(image_path, plate_coordinates, plate_images_mask, t plate_timepoints = defaultdict(list) - # Load image - img = imread(str(image_path), as_gray = False) - # Split image into individual plates - plate_images = get_plate_images(img, plate_coordinates, edge_cut = edge_cut) + plate_images = plates.slice_plate_image(image_file.image) - for j, plate_image in enumerate(plate_images): + for plate_id, plate_image in plate_images.items(): plate_image_gray = rgb2gray(plate_image) # Segment each image - plate_images[j] = segment_image(plate_image_gray, plate_image_gray > 0, plate_images_mask[j], area_min = 8) + plate_images[plate_id] = segment_image(plate_image_gray, plate_image_gray > 0, plate_images_mask[plate_id], area_min = 8) # Create Timepoint objects for each plate - plate_timepoints[j + 1].extend(timepoints_from_image(plate_images[j], time_point, elapsed_minutes, image = plate_image)) + plate_timepoints[plate_id].extend(timepoints_from_image(plate_images[plate_id], image_file.timestamp, image_file.timestamp_elapsed_minutes, image = plate_image)) # Save segmented image plot, if required if plot_path is not None: - plots.plot_plate_segmented(plate_image_gray, plate_images[j], time_point, plot_path) + save_path = file_access.create_subdirectory(plot_path, f"plate{plate_id}") + plots.plot_plate_segmented(plate_image_gray, plate_images[plate_id], image_file.timestamp, save_path) return plate_timepoints @@ -179,35 +119,38 @@ def main(): parser = argparse.ArgumentParser( description = "An image analysis tool for measuring microorganism colony growth", formatter_class = argparse.ArgumentDefaultsHelpFormatter - ) + ) parser.add_argument("path", type = str, help = "Image files location", default = None) - parser.add_argument("-v", "--verbose", type = int, default = 1, - help = "Information output level") parser.add_argument("-dpi", "--dots_per_inch", type = int, default = 2540, help = "The image DPI (dots per inch) setting") - parser.add_argument("--plate_size", type = int, default = 100, - help = "The plate diameter, in millimetres") + parser.add_argument("-mp", "--multiprocessing", type = strtobool, default = True, + help = "Enables use of more CPU cores, faster but more resource intensive") + parser.add_argument("-p", "--plots", type = int, default = 1, + help = "The detail level of plot images to store on disk") parser.add_argument("--plate_edge_cut", type = int, default = 60, help = "The radius from the plate edge to remove, in pixels") + parser.add_argument("--plate_labels", type = str, nargs = "*", default = list(), + help = "A list of labels to identify each plate. Plates are ordered from top left, in rows. Example usage: --plate_labels plate1 plate2") parser.add_argument("--plate_lattice", type = int, nargs = 2, default = (3, 2), metavar = ("ROW", "COL"), help = "The row and column co-ordinate layout of plates. Example usage: --plate_lattice 3 3") - parser.add_argument("--save_plots", type = int, default = 1, - help = "The detail level of plot images to store on disk") + parser.add_argument("--plate_size", type = int, default = 100, + help = "The plate diameter, in millimetres") parser.add_argument("--use_cached_data", type = strtobool, default = False, help = "Allow use of previously calculated data") - parser.add_argument("-mp", "--multiprocessing", type = strtobool, default = True, - help = "Enables use of more CPU cores, faster but more resource intensive") + parser.add_argument("-v", "--verbose", type = int, default = 1, + help = "Information output level") args = parser.parse_args() BASE_PATH = args.path - VERBOSE = args.verbose - PLATE_SIZE = imaging.mm_to_pixels(args.plate_size, dots_per_inch = args.dots_per_inch) - PLATE_LATTICE = tuple(args.plate_lattice) + PLOTS = args.plots PLATE_EDGE_CUT = args.plate_edge_cut - SAVE_PLOTS = args.save_plots + PLATE_LABELS = {plate_id: label for plate_id, label in enumerate(args.plate_labels, start = 1)} + PLATE_LATTICE = tuple(args.plate_lattice) + PLATE_SIZE = imaging.mm_to_pixels(args.plate_size, dots_per_inch = args.dots_per_inch) USE_CACHED = args.use_cached_data + VERBOSE = args.verbose POOL_MAX = 1 if args.multiprocessing: POOL_MAX = cpu_count() @@ -225,78 +168,95 @@ def main(): if VERBOSE >= 1: print(f"Working directory: {BASE_PATH}") - # Find images in working directory - image_formats = ["tif", "tiff", "png"] - image_files = file_access.get_files_by_type(BASE_PATH, image_formats) - - # Check if images have been loaded - if len(image_files) > 0: - if VERBOSE >= 1: - print(f"{len(image_files)} images found") - else: - raise IOError(f"No images could be found in the supplied folder path." - " Images are expected in these formats: {image_formats}") - - # Get date and time information from filenames - time_points = get_image_timestamps(image_files) - time_points_elapsed = get_image_timestamps(image_files, elapsed_minutes = True) - if len(time_points) != len(image_files) or len(time_points) != len(image_files): - raise IOError("Unable to load timestamps from all image filenames." - " Please check that images have a filename with YYYYMMDD_HHMM timestamps") - # Check if processed image data is already stored and can be loaded segmented_image_data_filename = "cached_data" + plates = None if USE_CACHED: if VERBOSE >= 1: print("Attempting to load cached data") - plate_colonies = file_access.load_file( + plates = file_access.load_file( BASE_PATH.joinpath("data", segmented_image_data_filename), file_access.CompressionMethod.LZMA, pickle = True - ) + ) # Check that segmented image data has been loaded for all plates - if (VERBOSE >= 1 and plate_colonies is not None and - len(plate_colonies) == utilities.coordinate_to_index_number(PLATE_LATTICE)): + # Also that data is not from an older format (< v0.4.0) + if ( + VERBOSE >= 1 and plates is not None + and plates.count == PlateCollection.coordinate_to_index(PLATE_LATTICE) + and isinstance(plates.items[0], Plate) + ): print("Successfully loaded cached data") + image_files = None else: print("Unable to load cached data, starting image processing") - plate_colonies = None + plates = None + + if not USE_CACHED or plates is None: + # Find images in working directory + image_formats = ["tif", "tiff", "png"] + image_paths = file_access.get_files_by_type(BASE_PATH, image_formats) + + # Store images as ImageFile objects + # Timestamps are automatically read from filenames + image_files = ImageFileCollection() + for image_path in image_paths: + image_files.add( + file_path = image_path, + timestamp = None, + timestamp_initial = None, + cache_image = False + ) - # Process images to Timepoint data objects - plate_coordinates = None - if not USE_CACHED or plate_colonies is None: - plate_coordinates = None + # Check if images have been loaded and timestamps could be read + if image_files.count > 0: + if VERBOSE >= 1: + print(f"{image_files.count} images found") + else: + raise IOError(f"No images could be found in the supplied folder path." + " Images are expected in these formats: {image_formats}") + if image_files.count != len(image_files.timestamps): + raise IOError("Unable to load timestamps from all image filenames." + " Please check that images have a filename with YYYYMMDD_HHMM timestamps") + + # Set intial timestamp + image_files.timestamps_initial = image_files.timestamps[0] + + # Process images to Timepoint data objects plate_images_mask = None plate_timepoints = defaultdict(list) if VERBOSE >= 1: print("Preprocessing images to locate plates") - # Load the first image to get plate coordinate and mask - with image_files[0] as image_file: - # Load image - img = imread(str(image_file), as_gray = True) - + # Load the first image to get plate coordinates and mask + with image_files.items[0] as image_file: # Only find centers using first image. Assume plates do not move - if plate_coordinates is None: + if plates is None: if VERBOSE >= 2: - print(f"Locating plate centres in image: {image_file}") - plate_coordinates = imaging.get_image_circles( - img, - int(PLATE_SIZE / 2), - circle_count = utilities.coordinate_to_index_number(PLATE_LATTICE), - search_radius = 50 - ) + print(f"Locating plate centres in image: {image_file.file_path}") + + # Create new Plate instances to store the information + plates = PlateCollection.from_image( + shape = PLATE_LATTICE, + image = image_file.image_gray, + diameter = PLATE_SIZE, + search_radius = 50, + edge_cut = PLATE_EDGE_CUT, + labels = PLATE_LABELS + ) + + if not plates.count > 0: + print(f"Unable to locate plates in image: {image_file.file_path}") + print(f"Processing unable to continue") + sys.exit() + if VERBOSE >= 3: - for k, center in enumerate(plate_coordinates, start = 1): - print(f"Plate {k} center: {center[0]}") - - # Split image into individual plates - plate_images = get_plate_images(img, plate_coordinates, edge_cut = PLATE_EDGE_CUT) + for plate in plates.items: + print(f"Plate {plate.id} center: {plate.center}") # Use the first plate image as a noise mask - if plate_images_mask is None: - plate_images_mask = plate_images + plate_images_mask = plates.slice_plate_image(image_file.image_gray) if VERBOSE >= 1: print("Processing colony data from all images") @@ -307,17 +267,17 @@ def progress_update(result, progress): processes = list() with Pool(processes = POOL_MAX) as pool: - for i, image_file in enumerate(image_files): + for i, image_file in enumerate(image_files.items): # Allow args to be passed to callback function - callback_function = partial(progress_update, progress = ((i + 1) / len(image_files)) * 100) + callback_function = partial(progress_update, progress = ((i + 1) / image_files.count) * 100) # Create processes processes.append(pool.apply_async( image_file_to_timepoints, - args = (image_file, plate_coordinates, plate_images_mask, time_points[i], time_points_elapsed[i], PLATE_EDGE_CUT), + args = (image_file, plates, plate_images_mask), kwds = {"plot_path" : None}, callback = callback_function - )) + )) # Consolidate the results to a single dict for process in processes: @@ -335,25 +295,27 @@ def progress_update(result, progress): print("Calculating colony properties") # Group Timepoints by centres and create Colony objects - plate_colonies = dict() - for plate_id, plate in plate_timepoints.items(): - plate_colonies[plate_id] = {colony.id : colony for colony in colonies_from_timepoints(plate)} + for plate_id, plate_timepoints in plate_timepoints.items(): + plate = plates.get_item(plate_id) + plate.items = colonies_from_timepoints(plate_timepoints, distance_tolerance = 8) + if VERBOSE >= 3: + print(f"{plate.count} colonies located on plate {plate.id}, before filtering") # Filter colonies to remove noise, background objects and merged colonies - plate_colonies[plate_id] = dict(filter(lambda item: + plate.items = list(filter(lambda colony: # Remove objects that do not have sufficient data points, usually just noise - len(item[1].timepoints) > len(time_points) * 0.2 and + len(colony.timepoints) > image_files.count * 0.2 and # Remove object that do not show growth, these are not colonies - item[1].growth_rate > 1 and + colony.growth_rate > 1 and # Colonies that appear with a large initial area are most likely merged colonies, not new colonies - item[1].timepoint_first.area < 50, - plate_colonies[plate_id].items() - )) + colony.timepoint_first.area < 50, + plate.items + )) if VERBOSE >= 1: - print(f"Colony data stored for {len(plate_colonies[plate_id])} colonies on plate {plate_id}") + print(f"Colony data stored for {plate.count} colonies on plate {plate.id}") - if not any([len(plate) for plate in plate_colonies.values()]): + if not any([plate.count for plate in plates.items]): if VERBOSE >= 1: print("Unable to locate any colonies in the images provided") print(f"ColonyScanalyser analysis completed for: {BASE_PATH}") @@ -362,7 +324,7 @@ def progress_update(result, progress): # Store pickled data to allow quick re-use save_path = file_access.create_subdirectory(BASE_PATH, "data") save_path = save_path.joinpath(segmented_image_data_filename) - save_status = file_access.save_file(save_path, plate_colonies, file_access.CompressionMethod.LZMA) + save_status = file_access.save_file(save_path, plates, file_access.CompressionMethod.LZMA) if VERBOSE >= 1: if save_status: print(f"Cached data saved to {save_path}") @@ -374,82 +336,38 @@ def progress_update(result, progress): print("Saving data to CSV") save_path = BASE_PATH.joinpath("data") - for plate_id, plate in plate_colonies.items(): - headers = [ - "Colony ID", - "Time of appearance", - "Time of appearance (elapsed minutes)", - "Center point averaged (row, column)", - "Colour averaged name", - "Colour averaged (R,G,B)", - "Growth rate average", - "Growth rate", - "Doubling time average (minutes)", - "Doubling times (minutes)", - "First detection (elapsed minutes)", - "First center point (row, column)", - "First area (pixels)", - "First diameter (pixels)", - "Final detection (elapsed minutes)", - "Final center point (row, column)", - "Final area (pixels)", - "Final diameter (pixels)" - ] - + for plate in plates.items: # Save data for all colonies on one plate - file_access.save_to_csv( - plate.values(), - headers, - save_path.joinpath(f"plate{plate_id}_colonies") - ) + plate.colonies_to_csv(save_path) # Save data for each colony on a plate - headers = [ - "Colony ID", - "Date/Time", - "Elapsed time (minutes)", - "Area (pixels)", - "Center (row, column)", - "Diameter (pixels)", - "Perimeter (pixels)", - "Color average (R,G,B)" - ] - colony_timepoints = list() - for colony_id, colony in plate.items(): - for timepoint in colony.timepoints.values(): - # Unpack timepoint properties to a flat list - colony_timepoints.append([colony_id, *timepoint]) - - file_access.save_to_csv( - colony_timepoints, - headers, - save_path.joinpath(f"plate{plate_id}_colony_timepoints") - ) - - # Plot colony growth curves for all plates - if SAVE_PLOTS >= 1: + plate.colonies_timepoints_to_csv(save_path) + + # Only generate plots when working with original images + # Can't guarantee that the original images and full list of time points + # will be available when using cached data + if image_files is not None: + # Plots for all plates + if PLOTS >= 1: + if VERBOSE >= 1: + print("Saving plots") + save_path = file_access.create_subdirectory(BASE_PATH, "plots") + plots.plot_growth_curve(plates.items, image_files.timestamps_elapsed_minutes, save_path) + plots.plot_appearance_frequency(plates.items, image_files.timestamps_elapsed_minutes, save_path) + plots.plot_appearance_frequency(plates.items, image_files.timestamps_elapsed_minutes, save_path, bar = True) + plots.plot_doubling_map(plates.items, image_files.timestamps_elapsed_minutes, save_path) + plots.plot_colony_map(image_files.items[-1].image, plates.items, save_path) + + # Plot colony growth curves, ID map and time of appearance for each plate + if PLOTS >= 2: + for plate in plates.items: + save_path_plate = file_access.create_subdirectory(save_path, file_access.file_safe_name([f"plate{plate.id}", plate.name])) + plots.plot_growth_curve([plate], image_files.timestamps_elapsed_minutes, save_path_plate) + plots.plot_appearance_frequency([plate], image_files.timestamps_elapsed_minutes, save_path_plate) + plots.plot_appearance_frequency([plate], image_files.timestamps_elapsed_minutes, save_path_plate, bar = True) + else: if VERBOSE >= 1: - print("Saving plots") - save_path = file_access.create_subdirectory(BASE_PATH, "plots") - plots.plot_growth_curve(plate_colonies, time_points_elapsed, save_path) - plots.plot_appearance_frequency(plate_colonies, time_points_elapsed, save_path) - plots.plot_appearance_frequency(plate_colonies, time_points_elapsed, save_path, bar = True) - plots.plot_doubling_map(plate_colonies, time_points_elapsed, save_path) - - # Only generate plate map when working with original images - # Can't guarantee that the original image will be available when using cached data - if plate_coordinates is not None: - plots.plot_colony_map(imread(image_files[-1], as_gray = False), {k: v for k, v in enumerate(plate_coordinates, start = 1)}, plate_colonies, save_path, edge_cut = PLATE_EDGE_CUT ) - - # Plot colony growth curves, ID map and time of appearance for each plate - if SAVE_PLOTS >= 2: - for plate_id, plate in plate_colonies.items(): - row, col = utilities.index_number_to_coordinate(plate_id, PLATE_LATTICE) - save_path_plate = get_plate_directory(save_path, row, col, create_dir = True) - plate_item = {plate_id : plate} - plots.plot_growth_curve(plate_item, time_points_elapsed, save_path_plate) - plots.plot_appearance_frequency(plate_item, time_points_elapsed, save_path_plate) - plots.plot_appearance_frequency(plate_item, time_points_elapsed, save_path_plate, bar = True) + print("Unable to generate plots from cached data. Run analysis on original images to generate plot images") if VERBOSE >= 1: print(f"ColonyScanalyser analysis completed for: {BASE_PATH}") diff --git a/src/colonyscanalyser/plate.py b/src/colonyscanalyser/plate.py new file mode 100644 index 0000000..1ca8779 --- /dev/null +++ b/src/colonyscanalyser/plate.py @@ -0,0 +1,361 @@ +from __future__ import annotations +from typing import Union, Dict, List, Tuple +from collections.abc import Collection +from pathlib import Path, PurePath +from numpy import ndarray +from .base import Identified, IdentifiedCollection, Named +from .geometry import Circle +from .file_access import save_to_csv + + +class Plate(Identified, IdentifiedCollection, Named, Circle): + """ + An object to hold information about an agar plate and a collection of Colony objects + """ + def __init__( + self, + id: int, + diameter: float, + edge_cut: float = 0, + name: str = "", + center: Union[Tuple[float, float], Tuple[float, float, float]] = None, + colonies: list = None + ): + self.id = id + self.diameter = diameter + + # Can't set argument default otherwise it is shared across all class instances + if center is None: + center = tuple() + if colonies is None: + colonies = list() + + # Set property defaults + self.center = center + self.items = colonies + self.edge_cut = edge_cut + self.name = name + + def __iter__(self): + return iter([ + self.id, + self.name, + self.center, + self.diameter, + self.area, + self.edge_cut, + self.count + ]) + + @property + def center(self) -> Union[Tuple[float, float], Tuple[float, float, float]]: + return self.__center + + @center.setter + def center(self, val: Union[Tuple[float, float], Tuple[float, float, float]]): + self.__center = val + + @property + def edge_cut(self) -> float: + return self.__edge_cut + + @edge_cut.setter + def edge_cut(self, val: float): + self.__edge_cut = val + + def colonies_to_csv(self, save_path: Path, headers: List[str] = None) -> Path: + """ + Output the data from the colonies collection to a CSV file + + :param save_path: the location to save the CSV data file + :param headers: a list of strings to use as column headers + :returns: a Path representing the new file, if successful + """ + from .file_access import file_safe_name + + if headers is None: + headers = [ + "Colony ID", + "Time of appearance", + "Time of appearance (elapsed minutes)", + "Center point averaged (row, column)", + "Colour averaged name", + "Colour averaged (R,G,B)", + "Growth rate average", + "Growth rate", + "Doubling time average (minutes)", + "Doubling times (minutes)", + "First detection (elapsed minutes)", + "First center point (row, column)", + "First area (pixels)", + "First diameter (pixels)", + "Final detection (elapsed minutes)", + "Final center point (row, column)", + "Final area (pixels)", + "Final diameter (pixels)" + ] + + return self.__collection_to_csv( + save_path, + # "_".join(filter(None, [f"plate{str(self.id)}", self.name.replace(" ", "_"), "colonies"])), + file_safe_name([f"plate{str(self.id)}", self.name, "colonies"]), + self.items, + headers + ) + + def colonies_timepoints_to_csv(self, save_path: Path, headers: List[str] = None) -> Path: + """ + Output the data from the timepoints in the colonies collection to a CSV file + + :param save_path: the location to save the CSV data file + :param headers: a list of strings to use as column headers + :returns: a Path representing the new file, if successful + """ + from .file_access import file_safe_name + + if headers is None: + headers = [ + "Colony ID", + "Date/Time", + "Elapsed time (minutes)", + "Area (pixels)", + "Center (row, column)", + "Diameter (pixels)", + "Perimeter (pixels)", + "Color average (R,G,B)" + ] + + # Unpack timepoint properties to a flat list + colony_timepoints = list() + for colony in self.items: + for timepoint in colony.timepoints.values(): + colony_timepoints.append([colony.id, *timepoint]) + + return self.__collection_to_csv( + save_path, + # "_".join(filter(None, [f"plate{str(self.id)}", self.name.replace(" ", "_"), "colony", "timepoints"])), + file_safe_name([f"plate{str(self.id)}", self.name, "colony", "timepoints"]), + colony_timepoints, + headers + ) + + def colonies_rename_sequential(self, start: int = 1) -> int: + """ + Update the ID numbers of all colonies in the plate colony collection + + :param start: the new initial ID number + :returns: the final ID number of the renamed sequence + """ + for i, colony in enumerate(self.items, start = start): + colony.id = i + + return i + + @staticmethod + def __collection_to_csv(save_path: Path, file_name: str, data: Collection, headers: List[str] = None) -> Path: + """ + Output the data from the timepoints in the colonies collection to a CSV file + + :param save_path: the location to save the CSV data file + :param file_name: the name of the CSV data file + :param data: a collection of iterables to output as rows to the CSV file + :param headers: a list of strings to use as column headers + :returns: a Path representing the new file, if successful + """ + # Check that a path has been specified and can be found + if not isinstance(save_path, Path): + save_path = Path(save_path) + if not save_path.exists() or str(PurePath(save_path)) == ".": + raise FileNotFoundError(f"The path '{str(save_path)}' could not be found. Please specify a different save path") + + return save_to_csv( + data, + headers, + save_path.joinpath(file_name) + ) + + +class PlateCollection(IdentifiedCollection): + """ + Holds a collection of Plates + """ + def __init__(self, plates: Collection = None, shape: Tuple[int, int] = None): + super(PlateCollection, self).__init__(plates) + if shape is None: + shape = tuple() + + self.shape = shape + + @property + def centers(self) -> Union[List[Tuple[float, float]], List[Tuple[float, float, float]]]: + return [plate.center for plate in self.items] + + @property + def shape(self) -> Tuple[int, int]: + return self.__shape + + @shape.setter + def shape(self, val: Tuple[int, int]): + if not PlateCollection.__is_valid_shape(val): + raise ValueError(f"{val} is not a valid shape. All values must be non-negative integers") + + self.__shape = val + + def add( + self, + id: int, + diameter: float, + edge_cut: float = 0, + name: str = "", + center: Union[Tuple[float, float], Tuple[float, float, float]] = None, + colonies: list = None + ) -> Plate: + """ + Create a new Plate and append it to the collection + + :param id: the integer ID number for the plate + :param diameter: the physical diameter of the plate, in millimeters + :param colonies: a collection of Colony objects contained in the Plate + :returns: the new Plate instance + """ + + plate = Plate( + id = id, + diameter = diameter, + edge_cut = edge_cut, + name = name, + center = center, + colonies = colonies + ) + + self.append(plate) + + return plate + + @classmethod + def from_image(cls, shape: Tuple[int, int], image: ndarray, diameter: float, **kwargs) -> PlateCollection: + """ + Create a new instance of PlateCollection from an image + + :param shape: row and column boundaries + :param image: an image containing a set of plates, as a numpy array + :returns: a new instance of PlateCollection + """ + plates = cls(shape = shape) + plates.plates_from_image(image = image, diameter = diameter, **kwargs) + + return plates + + def plates_from_image( + self, + image: ndarray, + diameter: float, + search_radius: float = 50, + edge_cut: float = 0, + labels: Dict[int, str] = dict() + ) -> List[Plate]: + """ + Create a collection of Plate instances from an image + + :param image: a grayscale image as a numpy array + :param diameter: the expected plate diameter, in pixels + :param search_radius: the distance, in pixels to search around the expected plate diameter + :param edge_cut: the radius to exclude from imaging analysis, in pixels + :param labels: a dict of labels for each plate, with the plate ID as a key + :returns: a list of Plate instances + """ + from .imaging import get_image_circles + + if not self.shape: + raise ValueError("The PlateCollection shape property is required, but has not been set") + + plate_coordinates = get_image_circles( + image, + int(diameter / 2), + circle_count = PlateCollection.coordinate_to_index(self.shape), + search_radius = search_radius + ) + + plates = list() + for plate_id, coord in enumerate(plate_coordinates, start = 1): + center, radius = coord + name = "" + if plate_id in labels: + name = labels[plate_id] + + plates.append(self.add( + id = plate_id, + diameter = radius * 2, + edge_cut = edge_cut, + name = name, + center = center + )) + + return plates + + def slice_plate_image(self, image: ndarray) -> Dict[int, ndarray]: + """ + Split an image into individual plate subimages and delete background + + Slices according to the Plate instances in the current collection + + :param image: an image as a numpy array + :returns: a doct of plate images with the plate ID number as the key + """ + from .imaging import cut_image_circle + images = dict() + + for plate in self.items: + images[plate.id] = cut_image_circle(image, center = plate.center, radius = plate.radius - plate.edge_cut) + + return images + + @staticmethod + def coordinate_to_index(coordinate: Tuple[int, int]) -> int: + """ + Find a positional index for a coordinate + + Starting along columns and then progressing down rows + + :param coordinate: a row, column coordinate tuple + :returns: a positional index number + """ + from numpy import prod + + if not PlateCollection.__is_valid_shape(coordinate): + raise ValueError( + f"The supplied coordinates, {coordinate}, are not valid. All values must be non-negative integers" + ) + + return prod(coordinate) + + @staticmethod + def index_to_coordinate(index: int, shape: Tuple[int, int]) -> Tuple[int, int]: + """ + Calculate row and column numbers for an item index + + Lattice coordinate and item index numbers are 1-based + + :param index: item index number + :param shape: row and column boundaries + :returns: row and column coordinate tuple + """ + if index < 1 or not PlateCollection.__is_valid_shape(shape): + raise ValueError("The supplied index or shape is not valid. All values must be non-negative integers") + + shape_row, shape_col = shape + + row = ((index - 1) // shape_col) + 1 + col = ((index - 1) % shape_col) + 1 + + if row > shape_row or col > shape_col: + raise IndexError("Index number is greater than the supplied shape size") + + return (row, col) + + @staticmethod + def __is_valid_shape(shape: Tuple[int, int]) -> bool: + return ( + all(shape) and + not any([(not isinstance(val, int) or val < 1) for val in shape]) + ) \ No newline at end of file diff --git a/src/colonyscanalyser/plots.py b/src/colonyscanalyser/plots.py index f2f865c..78dc847 100644 --- a/src/colonyscanalyser/plots.py +++ b/src/colonyscanalyser/plots.py @@ -1,18 +1,23 @@ +from typing import List +from datetime import datetime +from pathlib import Path +from numpy import ndarray import matplotlib.pyplot as plt import matplotlib.cm as cm +from matplotlib.axes import Axes from .utilities import average_dicts_values_by_key from .plotting import rc_to_xy, axis_minutes_to_hours +from .plate import Plate -def plot_colony_map(plate_image, plate_coordinates, plate_colonies, save_path, edge_cut = 0): +def plot_colony_map(plate_image: ndarray, plates: List[Plate], save_path: Path) -> Path: """ Saves original plate image with overlaid plate and colony IDs :param plate_image: the final timepoint image of all plates - :param plate_coordinates: a dictionary of centre and radii tuples - :param plate_colonies: a dictionary of Colony objects - :param save_path: a path object + :param plates: a PlateCollection of Plate instances + :param save_path: the directory to save the plot image :returns: a file path object if the plot was saved sucessfully """ from matplotlib import rcParams @@ -25,21 +30,21 @@ def plot_colony_map(plate_image, plate_coordinates, plate_colonies, save_path, e # Create a figure that takes up the full size of the image fig = plt.figure(figsize = figsize) ax = fig.add_axes([0, 0, 1, 1]) - ax.axis('off') + ax.axis("off") ax.imshow(plate_image) - for plate_id, plate in plate_colonies.items(): - (center_y, center_x), plate_radius = plate_coordinates[plate_id] + for plate in plates: + center_y, center_x = plate.center # Colony coordinates are relative to individual plate images # Calculate a correction factor to allow plotting on the original image - offset_y = center_y - plate_radius + edge_cut - offset_x = center_x - plate_radius + edge_cut + offset_y = center_y - plate.radius + plate.edge_cut + offset_x = center_x - plate.radius + plate.edge_cut # Label plates ax.annotate( - f"Plate #{plate_id}".upper(), - (center_x, center_y - plate_radius - (edge_cut * 1.4)), + f"Plate #{plate.id}".upper(), + (center_x, center_y - plate.radius - (plate.edge_cut * 1.4)), xycoords = "data", horizontalalignment = "center", verticalalignment = "center", @@ -47,11 +52,22 @@ def plot_colony_map(plate_image, plate_coordinates, plate_colonies, save_path, e backgroundcolor = "black", color = "white" ) + if len(plate.name) > 0: + ax.annotate( + plate.name, + (center_x, center_y - plate.radius - (plate.edge_cut * 0.6)), + xycoords = "data", + horizontalalignment = "center", + verticalalignment = "center", + fontsize = "32", + backgroundcolor = "black", + color = "white" + ) # Mark the detected boundary of the plate plate_circle = plt.Circle( (center_x, center_y), - radius = plate_radius, + radius = plate.radius, facecolor = "none", edgecolor = "purple", linewidth = "2.5", @@ -63,7 +79,7 @@ def plot_colony_map(plate_image, plate_coordinates, plate_colonies, save_path, e # Mark the measured area of the plate plate_circle_measured = plt.Circle( (center_x, center_y), - radius = plate_radius - edge_cut, + radius = plate.radius - plate.edge_cut, facecolor = "none", edgecolor = "white", linewidth = "1.5", @@ -73,7 +89,7 @@ def plot_colony_map(plate_image, plate_coordinates, plate_colonies, save_path, e ax.add_artist(plate_circle_measured) # Mark colony centres and ID numbers - for colony in plate.values(): + for colony in plate.items: x, y = rc_to_xy(colony.center) x = offset_x + x y = offset_y + y @@ -122,8 +138,7 @@ def plot_colony_map(plate_image, plate_coordinates, plate_colonies, save_path, e fontsize = "18" ) - image_path = "plate_map.png" - save_path = save_path.joinpath(image_path) + save_path = save_path.joinpath("plate_map.png") try: plt.savefig(str(save_path), format = "png") except Exception: @@ -133,14 +148,19 @@ def plot_colony_map(plate_image, plate_coordinates, plate_colonies, save_path, e return save_path -def plot_plate_segmented(plate_image, segmented_image, date_time, save_path): +def plot_plate_segmented( + plate_image: ndarray, + segmented_image: ndarray, + date_time: datetime, + save_path: Path +) -> Path: """ Saves processed plate images and corresponding segmented data plots :param plate_image: a black and white image as a numpy array :param segmented_image: a segmented and labelled image as a numpy array :param date_time: a datetime object - :param save_path: a path object + :param save_path: the directory to save the plot image :returns: a file path object if the plot was saved sucessfully """ from skimage.measure import regionprops @@ -159,7 +179,7 @@ def plot_plate_segmented(plate_image, segmented_image, date_time, save_path): color = "red", horizontalalignment = "center", verticalalignment = "center" - ) + ) plt.suptitle(f"Plate time point {date_time.strftime('%Y/%m/%d %H:%M')}") image_path = f"time_point_{date_time.strftime('%Y%m%d')}_{date_time.strftime('%H%M')}.png" @@ -173,50 +193,69 @@ def plot_plate_segmented(plate_image, segmented_image, date_time, save_path): return save_path -def plot_growth_curve(plates_dict, time_points_elapsed, save_path): +def plot_growth_curve(plates: List[Plate], time_points_elapsed: List[int], save_path: Path) -> Path: """ Growth curves for either a single plate, or all plates on the lattice + + :param plates: a list of Plate instances + :param time_points_elapsed: a list of elapsed time values + :param save_path: the directory to save the plot image + :returns: a file path object if the plot was saved sucessfully """ _, ax = plt.subplots() colormap = cm.get_cmap("plasma") - for plate_item in plates_dict.items(): - if len(plates_dict) > 1: + for plate in plates: + if len(plates) > 1: # Get a color from the colourmap - cm_scatter = colormap(0.2 + (0.65 - 0.2) * (plate_item[0] / len(plates_dict))) + cm_scatter = colormap(0.2 + (0.65 - 0.2) * (plate.id / len(plates))) cm_line = None else: cm_scatter = "Mediumpurple" cm_line = "Purple" # Add the growth curve plot for this plate - growth_curve(ax, plate_item, time_points_elapsed, cm_scatter, cm_line) + growth_curve(ax, plate, time_points_elapsed, cm_scatter, cm_line) lgd = ax.legend(loc = 'center right', fontsize = 8, bbox_to_anchor = (1.25, 0.5)) save_params = { "format": "png", "bbox_extra_artists": (lgd,), "bbox_inches": "tight" - } - - plt.ylim(ymin = 0) - plt.title("Colony growth") - plt.savefig(str(save_path.joinpath("growth_curve.png")), **save_params) + } - plt.close() + save_path = save_path.joinpath("growth_curve.png") + try: + plt.savefig(str(save_path), **save_params) + except Exception: + save_path = None + finally: + plt.close() + return save_path -def growth_curve(ax, plate_item, time_points_elapsed, scatter_color, line_color = None): +def growth_curve( + ax: Axes, + plate: Plate, + time_points_elapsed: List[int], + scatter_color: str, + line_color: str = None +): """ Add a growth curve scatter plot, with mean, to an axis + + :param ax: a Matplotlib Axes object to add a plot to + :param plate: a Plate instance + :param time_points_elapsed: a list of elapsed time values + :param scatter_color: a Colormap color + :param line_color: a Colormap color """ - plate_id, plate = plate_item areas_average = list() if line_color is None: line_color = scatter_color - for colony in plate.values(): + for colony in plate.items: # Map areas to a dictionary of all timepoints time_points_dict = dict.fromkeys(time_points_elapsed) for timepoint in colony.timepoints.values(): @@ -232,16 +271,16 @@ def growth_curve(ax, plate_item, time_points_elapsed, scatter_color, line_color marker = "o", s = 1, alpha = 0.25 - ) + ) # Plot the mean areas_averages = average_dicts_values_by_key(areas_average) ax.plot( *zip(*sorted(areas_averages.items())), color = line_color, - label = f"Plate {plate_id}", + label = f"Plate {plate.id}", linewidth = 2 - ) + ) # Format x-axis labels as integer hours ax.set_xticklabels(axis_minutes_to_hours(ax.get_xticks())) @@ -249,32 +288,38 @@ def growth_curve(ax, plate_item, time_points_elapsed, scatter_color, line_color ax.set_ylabel("Colony area (pixels)") -def plot_appearance_frequency(plates_dict, time_points_elapsed, save_path, bar = False): +def plot_appearance_frequency(plates: List[Plate], time_points_elapsed, save_path, bar = False) -> Path: """ Time of appearance frequency for either a single plate, or all plates on the lattice + + :param plates: a list of Plate instances + :param time_points_elapsed: a list of elapsed time values + :param save_path: the directory to save the plot image + :param bar: if a bar plot should be used instead of the default line plot + :returns: a file path object if the plot was saved sucessfully """ _, ax = plt.subplots() colormap = cm.get_cmap("plasma") - for plate_id, plate_item in plates_dict.items(): - if len(plates_dict) > 1: + for plate in plates: + if len(plates) > 1: # Get a color from the colourmap - cm_plate = colormap(0.2 + (0.65 - 0.2) * (plate_id / len(plates_dict))) - plot_total = len(plates_dict) + cm_plate = colormap(0.2 + (0.65 - 0.2) * (plate.id / len(plates))) + plot_total = len(plates) else: cm_plate = "Purple" plot_total = None - if not len(plate_item) < 1: + if not plate.count < 1: # Plot frequency for each time point - time_of_appearance_frequency(ax, (plate_id, plate_item), time_points_elapsed, cm_plate, plot_total, bar = bar) + time_of_appearance_frequency(ax, plate, time_points_elapsed, cm_plate, plot_total, bar = bar) lgd = ax.legend(loc = 'center right', fontsize = 8, bbox_to_anchor = (1.25, 0.5)) save_params = { "format": "png", "bbox_extra_artists": (lgd,), "bbox_inches": "tight" - } + } plt.ylim(ymin = 0) plt.title("Time of appearance") @@ -282,19 +327,37 @@ def plot_appearance_frequency(plates_dict, time_points_elapsed, save_path, bar = save_name = "time_of_appearance_bar.png" else: save_name = "time_of_appearance.png" - plt.savefig(str(save_path.joinpath(save_name)), **save_params) + save_path = save_path.joinpath(save_name) - plt.close() + try: + plt.savefig(str(save_path), **save_params) + except Exception: + save_path = None + finally: + plt.close() + return save_path -def time_of_appearance_frequency(ax, plate_item, time_points_elapsed, plot_color, plot_total = None, bar = False): +def time_of_appearance_frequency( + ax: Axes, + plate: Plate, + time_points_elapsed: List[int], + plot_color: str, + plot_total: int = None, + bar: bool = False +): """ Add a time of appearance frequency bar or line plot to an axis - """ - plate_id, plate = plate_item + :param ax: a Matplotlib Axes object to add a plot to + :param plate: a Plate instance + :param time_points_elapsed: a list of elapsed time values + :param plot_color: a Colormap color + :param plot_total: the total number of plots on the Axes + :param bar: if a bar plot should be used instead of the default line plot + """ time_points_dict = dict() - for colony in plate.values(): + for colony in plate.items: key = colony.timepoint_first.elapsed_minutes if key not in time_points_dict: time_points_dict[key] = 0 @@ -307,14 +370,14 @@ def time_of_appearance_frequency(ax, plate_item, time_points_elapsed, plot_color ax.plot( *zip(*sorted(time_points_dict.items())), color = plot_color, - label = f"Plate {plate_id}", + label = f"Plate {plate.id}", alpha = 0.9 - ) + ) else: if plot_total is not None: width = plot_total + 1 # Offset x positions so bars aren't obscured - x = [x + ((plate_id - 1) * width) for x in sorted(time_points_dict.keys())] + x = [x + ((plate.id - 1) * width) for x in sorted(time_points_dict.keys())] else: width = 14 x = [x for x in sorted(time_points_dict.keys())] @@ -326,8 +389,8 @@ def time_of_appearance_frequency(ax, plate_item, time_points_elapsed, plot_color y, width = width, color = plot_color, - label = f"Plate {plate_id}" - ) + label = f"Plate {plate.id}" + ) # Format x-axis labels as integer hours ax.set_xticklabels(axis_minutes_to_hours(ax.get_xticks())) @@ -335,9 +398,14 @@ def time_of_appearance_frequency(ax, plate_item, time_points_elapsed, plot_color ax.set_ylabel("Frequency") -def plot_doubling_map(plates_dict, time_points_elapsed, save_path): +def plot_doubling_map(plates: List[Plate], time_points_elapsed: List[int], save_path: Path) -> Path: """ Heatmap of doubling time vs time of appearance + + :param plates: a list of Plate instances + :param time_points_elapsed: a list of elapsed time values + :param save_path: the directory to save the plot image + :returns: a file path object if the plot was saved sucessfully """ from numpy import histogram2d, zeros_like from numpy.ma import masked_where @@ -347,8 +415,8 @@ def plot_doubling_map(plates_dict, time_points_elapsed, save_path): x = [0] y = [0] - for plate in plates_dict.values(): - for colony in plate.values(): + for plate in plates: + for colony in plate.items: x.append(colony.timepoint_first.elapsed_minutes) y.append(colony.get_doubling_time_average(elapsed_minutes = True)) @@ -365,7 +433,7 @@ def plot_doubling_map(plates_dict, time_points_elapsed, save_path): cmap = "RdPu", extent = extent, origin = "lower" - ) + ) plt.xlim(xmin = 0) plt.ylim(ymin = 0) @@ -383,6 +451,12 @@ def plot_doubling_map(plates_dict, time_points_elapsed, save_path): ax.set_ylabel("Average doubling time (hours)") plt.tight_layout() - plt.savefig(str(save_path.joinpath("appearance_doubling_distribution.png")), format = "png") - plt.close() \ No newline at end of file + save_path = save_path.joinpath("appearance_doubling_distribution.png") + try: + plt.savefig(str(save_path), format = "png") + except Exception: + save_path = None + finally: + plt.close() + return save_path \ No newline at end of file diff --git a/src/colonyscanalyser/utilities.py b/src/colonyscanalyser/utilities.py index e363d46..e8fac36 100644 --- a/src/colonyscanalyser/utilities.py +++ b/src/colonyscanalyser/utilities.py @@ -1,42 +1,8 @@ def round_tuple_floats(tuple_item, precision = 2): if not isinstance(tuple_item, tuple): raise ValueError(f"The object must be of type 'tuple', not type '{type(tuple_item)}'") - return tuple(map(lambda x: isinstance(x, float) and round(x, precision) or x, tuple_item)) - - -def index_number_to_coordinate(index, lattice): - """ - Calculate row and column numbers for an item index - - Lattice co-ordinate and item index numbers are 1-based - - :param index: item index integer - :param lattice: row and column tuple boundaries - :returns: row and column co-ordinate tuple - :raises IndexError: if the returned index number would exceed the lattice size - """ - (lattice_row, lattice_col) = lattice - if index < 1 or lattice_row < 1 or lattice_col < 1: - raise ValueError("All supplied parameters must be greater than zero") - - row = ((index - 1) // lattice[1]) + 1 - col = ((index - 1) % lattice[1]) + 1 - if row > lattice_row or col > lattice_col: - raise IndexError("Index number is greater than the supplied lattice size") - - return (row, col) - - -def coordinate_to_index_number(coordinate): - """ - Find a positional index for a coordinate - - Starting along rows and then down columns - """ - import numpy as np - - return np.prod(coordinate) + return tuple(map(lambda x: isinstance(x, float) and round(x, precision) or x, tuple_item)) def progress_bar(bar_progress, bar_length = 30, message = ""): diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py new file mode 100644 index 0000000..1e245c5 --- /dev/null +++ b/tests/unit/test_base.py @@ -0,0 +1,243 @@ +import pytest +from unittest import mock +from datetime import datetime + +from colonyscanalyser.base import ( + Identified, + IdentifiedCollection, + Named, + Unique, + TimeStamped, + TimeStampElapsed +) + + +@pytest.fixture(params = [1, 2, 10000000]) +def id(request): + yield request.param + + +@pytest.fixture(params = [-1, 0, 0.5, "one"]) +def id_invalid(request): + yield request.param + + +class TestIdentified: + def test_init(self, id): + assert Identified(id).id == id + + def test_id(self, id): + obj = Identified(id) + obj.id = id * 3 + + assert obj.id == id * 3 + + def test_id_invalid(self, id_invalid): + with pytest.raises(ValueError): + Identified(id_invalid) + + +class TestIdentifiedCollection: + @staticmethod + def IdentifiedMock(id): + identified = mock.Mock(spec = Identified) + identified.id = id + + return identified + + @pytest.fixture + def identified_items(self): + items = list() + + for i in range(2, 10): + items.append(self.IdentifiedMock(i)) + + return items + + @pytest.fixture(scope = "function") + def item_rand_id(self, request, identified_items): + from random import randint + + yield randint(2, len(identified_items)) + + class TestInitialize: + def test_init(self): + assert IdentifiedCollection().items == list() + + def test_init_list(self, identified_items): + collection = IdentifiedCollection(identified_items) + + assert collection.items == identified_items + + class TestProperties: + def test_count(self, identified_items): + collection = IdentifiedCollection(identified_items) + + assert collection.count == len(identified_items) + + def test_items_none(self): + collection = IdentifiedCollection(None) + collection.items = None + + assert collection.items is not None + assert isinstance(collection.items, list) + + def test_items_sorted(self, identified_items): + from random import sample + + identified_items_shuffled = sample(identified_items, len(identified_items)) + collection = IdentifiedCollection(identified_items_shuffled) + + assert collection.items != identified_items_shuffled + assert collection.items == identified_items + + @pytest.mark.parametrize("items", [list(), dict()]) + def test_items_iterable(self, items): + collection = IdentifiedCollection() + if isinstance(items, dict): + items[1] = TestIdentifiedCollection.IdentifiedMock(1) + else: + items.append(TestIdentifiedCollection.IdentifiedMock(1)) + collection.items = items + + assert isinstance(collection.items, list) + assert len(collection.items) == len(items) + + @pytest.mark.parametrize("items", [1, "1"]) + def test_items_iterable_invalid(self, items): + collection = IdentifiedCollection() + + with pytest.raises((TypeError, ValueError)): + collection.items = items + + class TestMethods: + def test_add_item(self, identified_items): + collection = IdentifiedCollection(identified_items) + item_new = collection.add(id = 1) + + assert collection.count == len(identified_items) + 1 + assert item_new in collection.items + assert collection.items[0] == item_new + + def test_append_item(self): + collection = IdentifiedCollection() + identified_item = TestIdentifiedCollection.IdentifiedMock(1) + collection.append(identified_item) + + assert collection.count == 1 + assert any(identified_item.id == item.id for item in collection.items) + with pytest.raises(ValueError): + collection.append(identified_item) + + def test_exists(self, identified_items): + collection = IdentifiedCollection(identified_items) + + assert collection.exists(identified_items[0]) + + def test_id_exists(self, identified_items): + collection = IdentifiedCollection(identified_items) + + assert collection.id_exists(identified_items[0].id) + + def test_get_colony(self, identified_items, item_rand_id): + collection = IdentifiedCollection(identified_items) + + item = collection.get_item(item_rand_id) + + assert item is not None + assert item.id == item_rand_id + + def test_remove_item(self, identified_items, item_rand_id): + collection = IdentifiedCollection(identified_items) + + item = collection.get_item(item_rand_id) + assert item is not None + + collection.remove(item_rand_id) + item = collection.get_item(item_rand_id) + assert item is None + + def test_remove_item_invalid(self, identified_items): + collection = IdentifiedCollection(identified_items) + + with pytest.raises(KeyError): + collection.remove(-1) + + +class TestNamed: + @pytest.fixture(params = ["name", "name with spaces", 1, 0, -1, 1.1]) + def name(self, request): + yield request.param + + def test_init(self, name): + assert Named(name).name == str(name) + + +class TestUnique: + @pytest.fixture(scope = "class") + def unique(self): + yield Unique() + + def test_init(self, unique): + assert unique.id == 1 + + def test_unique(self, unique): + assert unique.id == 1 + unique = None + assert Unique().id != 1 + + def test_increment(self, unique): + id_count = unique.id_count + items_total = 10 + items = list() + for i in range(items_total): + items.append(Unique()) + + assert len(items) == items_total + for i in range(items_total): + assert items[i].id == i + 1 + id_count + for item in items: + assert (item.id != existing_item.id for existing_item in items) + + def test_id(self, id): + obj = Unique() + original_id = obj.id + obj.id = id + + assert obj.id == original_id + + +class TestTimeStamped: + def test_init(self): + timestamp = datetime(1, 1, 1) + timestamped = TimeStamped(timestamp) + + assert timestamped.timestamp == timestamp + + def test_init_auto(self): + + assert TimeStamped().timestamp is not None + + +class TestTimeStampElapsed: + def test_init(self): + timestamp = datetime(1, 1, 1) + timestampelapsed = TimeStampElapsed(timestamp, timestamp) + + assert timestampelapsed.timestamp == timestamp + assert timestampelapsed.timestamp_initial == timestamp + + def test_init_auto(self): + + assert TimeStampElapsed().timestamp is not None + assert TimeStampElapsed().timestamp_initial is not None + + def test_timestamp_elapsed(self): + timestamp_inital = datetime(1, 1, 1, 0, 0) + timestamp = datetime(1, 1, 1, 1, 1) + timestampelapsed = TimeStampElapsed(timestamp, timestamp_inital) + timestamp_diff = timestamp - timestamp_inital + + assert timestampelapsed.timestamp_elapsed == timestamp_diff + assert timestampelapsed.timestamp_elapsed_hours == timestamp_diff.total_seconds() / 3600 + assert timestampelapsed.timestamp_elapsed_minutes == int(timestamp_diff.total_seconds() / 60) \ No newline at end of file diff --git a/tests/unit/test_colony.py b/tests/unit/test_colony.py index 2d7b1a0..2c949e8 100644 --- a/tests/unit/test_colony.py +++ b/tests/unit/test_colony.py @@ -6,7 +6,7 @@ timepoints_from_image, colonies_from_timepoints, group_timepoints_by_center - ) +) centers = [ @@ -19,7 +19,7 @@ (4, 3.9), (3, 10), (0, 4.4) - ] +] @pytest.fixture(params = [centers]) @@ -133,13 +133,6 @@ def test_empty(self): colony.timepoints class TestProperties(): - @pytest.mark.parametrize("timepoints_iter", [list, None], indirect = True) - @pytest.mark.parametrize("id", [-1, 0, 2, 3.4, 1000000, "1"]) - def test_id(self, timepoints_iter, id): - colony = Colony(id, timepoints_iter) - - assert colony.id == id - def test_iterable(self, colony): assert len([*colony.__iter__()]) == 18 @@ -151,7 +144,7 @@ def test_timepoints(self, timepoints, colony): def test_center(self, timepoints): from statistics import mean - colony = Colony(id, timepoints) + colony = Colony(1, timepoints) for i, coord in enumerate(colony.center): assert round(coord, 4) == round(mean([t.center[i] for t in timepoints]), 4) @@ -164,7 +157,7 @@ def test_growth_rate(self, timepoints, timepoint_empty): assert colony_empty.growth_rate == 0 def test_growth_rate_average(self, timepoints, timepoint_empty): - colony = Colony(id, timepoints) + colony = Colony(1, timepoints) colony_empty = Colony(1, [timepoint_empty]) assert colony.growth_rate_average == ((timepoints[-1].area - timepoints[0].area) ** (1 / len(timepoints))) - 1 @@ -182,7 +175,6 @@ class TestMethods(): def test_get_timepoint(self, timepoints): colony = Colony(1, timepoints) - # Get timepoint assert colony.get_timepoint(timepoints[0].date_time) == timepoints[0] with pytest.raises(ValueError): colony.get_timepoint(None) @@ -210,7 +202,7 @@ def test_remove_timepoint(self, timepoints,): @pytest.mark.parametrize("timepoint_index, expected", [(0, 12.57), (-1, 1.4)]) def test_circularity(self, timepoints, timepoint_index, expected): colony = Colony(1, timepoints) - circularity = colony.circularity_at_timepoint(timepoints[timepoint_index].date_time) + circularity = colony.get_circularity_at_timepoint(timepoints[timepoint_index].date_time) assert round(circularity, 2) == expected diff --git a/tests/unit/test_file_access.py b/tests/unit/test_file_access.py index d08b913..f749eb1 100644 --- a/tests/unit/test_file_access.py +++ b/tests/unit/test_file_access.py @@ -4,6 +4,7 @@ from colonyscanalyser.file_access import ( file_exists, + file_safe_name, get_files_by_type, create_subdirectory, move_to_subdirectory, @@ -12,7 +13,7 @@ load_file, save_file, save_to_csv - ) +) FILE_NON_EXISTANT = Path("") SUB_DIR = Path("test_subdir") @@ -22,7 +23,7 @@ def create_temp_file(dir, extension = None): return Path(tempfile.mkstemp(dir = dir, suffix = extension)[1]) -class TestFileExists(): +class TestFileExists: def test_file_exists(self): assert file_exists(Path(__file__)) is True @@ -33,7 +34,24 @@ def test_file_string(self): assert file_exists("") is False -class TestGetFilesByType(): +class TestFileSafeName: + def test_join(self): + first = "first" + second = "" + third = "third" + safe_name = file_safe_name([first, second, third]) + + assert safe_name == first + "_" + third + + def test_replace(self): + first = "first second" + third = "third" + safe_name = file_safe_name([first, third]) + + assert safe_name == "first_second_third" + + +class TestGetFilesByType: def test_return_list(self): # Check that a populated list is returned result = get_files_by_type(Path(__file__).parent) @@ -62,7 +80,7 @@ def test_file_string(self): assert all(result) is True -class TestCreateSubdirectory(): +class TestCreateSubdirectory: @pytest.fixture def result(self, tmp_path): yield create_subdirectory(tmp_path, SUB_DIR) @@ -94,7 +112,7 @@ def test_readonly_exception(self, tmp_path): test_dir.chmod(test_dir_chmod) -class TestMoveToSubdirectory(): +class TestMoveToSubdirectory: def test_moved(self, tmp_path): temp_file = create_temp_file(tmp_path) result = move_to_subdirectory([temp_file], SUB_DIR) @@ -125,7 +143,7 @@ def test_write_exception(self, tmp_path): test_dir.chmod(test_dir_chmod) -class TestFileCompression(): +class TestFileCompression: @pytest.fixture(params=["r", "wb"]) def file_access_modes(self, request): yield request.param @@ -143,7 +161,7 @@ def test_compression_readable(self, tmp_path, file_access_modes): assert outfile.readable() is True -class TestLoadFile(): +class TestLoadFile: @pytest.fixture def data(self): return [0, 1, 2, 3, 4] @@ -161,7 +179,7 @@ def test_return_file(self, tmp_path, data): assert load_file(temp_file, method) is not None -class TestSaveFile(): +class TestSaveFile: @pytest.fixture def data(self): return [0, 1, 2, 3, 4] @@ -176,7 +194,7 @@ def test_save_compressed(self, tmp_path, data): assert file_exists(temp_file) is True -class TestSaveToCSV(): +class TestSaveToCSV: @pytest.fixture def headers(self): return ["one", "two", "three"] @@ -243,7 +261,7 @@ def __init__(self, prop): def __iter__(self): return iter([ self.prop - ]) + ]) data_iters = list() for row in data_list: diff --git a/tests/unit/test_geometry.py b/tests/unit/test_geometry.py new file mode 100644 index 0000000..874b816 --- /dev/null +++ b/tests/unit/test_geometry.py @@ -0,0 +1,100 @@ +import pytest + +from math import pi +from colonyscanalyser.geometry import ( + Shape, + Circle +) + + +class TestShape: + @pytest.fixture + def shape(self, request): + yield Shape() + + @pytest.fixture(params = [0, 1, 1.1, -1]) + def distance(self, request): + yield request.param + + def test_area(self, shape): + with pytest.raises(NotImplementedError): + shape.area + + @pytest.mark.parametrize("center", [(0, 0), (1, 1), (0, -1), (0, 0, 0), (1, 1, 1)]) + def test_center(self, shape, center): + assert shape.center is None + + shape.center = center + assert shape.center == center + + def test_depth(self, shape, distance): + assert shape.depth == 0 + + shape.depth = distance + assert shape.depth == distance + + def test_height(self, shape, distance): + assert shape.height == 0 + + shape.height = distance + assert shape.height == distance + + def test_perimeter(self, shape): + with pytest.raises(NotImplementedError): + shape.perimeter + + def test_width(self, shape, distance): + assert shape.width == 0 + + shape.width = distance + assert shape.width == distance + + +class TestCircle(): + @pytest.fixture(params = [0, 0.1, 1, 1.5, 1000]) + def diameter(self, request): + yield request.param + + @pytest.fixture(params = [-1, "1", "one"]) + def diameter_invalid(self, request): + yield request.param + + def test_area(self, diameter): + circle = Circle(diameter) + radius = diameter / 2 + + assert circle.area == pi * radius * radius + + def test_circumference(self, diameter): + circle = Circle(diameter) + + assert circle.circumference == pi * diameter + + def test_diameter(self, diameter): + circle = Circle(diameter) + + assert circle.diameter == diameter + + def test_height(self, diameter): + circle = Circle(diameter) + + assert circle.height == diameter + + def test_perimeter(self, diameter): + circle = Circle(diameter) + + assert circle.perimeter == pi * diameter + + def test_radius(self, diameter): + circle = Circle(diameter) + + assert circle.radius == diameter / 2 + + def test_width(self, diameter): + circle = Circle(diameter) + + assert circle.width == diameter + + def test_diameter_invalid(self, diameter_invalid): + with pytest.raises((TypeError, ValueError)): + Circle(diameter_invalid) \ No newline at end of file diff --git a/tests/unit/test_image_file.py b/tests/unit/test_image_file.py new file mode 100644 index 0000000..2edd3a3 --- /dev/null +++ b/tests/unit/test_image_file.py @@ -0,0 +1,225 @@ +import pytest +from unittest import mock +from datetime import datetime, timedelta +from pathlib import Path +from numpy import array + +from colonyscanalyser.image_file import ( + ImageFile, + ImageFileCollection +) + + +@pytest.fixture +def image(request): + # A single pixel png image + image_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89" + image_bytes += b"\x00\x00\x00\x0bIDAT\x08\x99c\xf8\x0f\x04\x00\t\xfb\x03\xfd\xe3U\xf2\x9c\x00\x00\x00\x00IEND\xaeB`\x82" + + return image_bytes + + +class TestImageFile: + @pytest.fixture(params = [ + (datetime(2019, 4, 2, 14, 23), "img_20190402_1423.tif"), + (datetime(1901, 12, 1, 23, 1), "1901-12-01-23:01.png"), + (datetime(2020, 1, 31, 0, 59), "image 2020 01 31 00 59.gif") + ]) + def timestamp_image(self, request): + yield request.param + + @staticmethod + def create_temp_file(tmp_path, file_name, suffix = "", file_data = b"1"): + from tempfile import mkstemp + + _, temp_file = mkstemp(prefix = file_name, suffix = suffix, dir = tmp_path) + + # Write some data to the file so it is not empty + with open(temp_file, "wb") as f: + f.write(file_data) + + return temp_file + + @pytest.fixture(params = [True, False], autouse = True) + def cache_image(self, request): + yield request.param + + @pytest.mark.usefixtures("cache_image") + class TestInitialize: + def test_init(self, tmp_path, timestamp_image, cache_image, image): + image_path = TestImageFile.create_temp_file(tmp_path, timestamp_image[1], suffix = "png", file_data = image) + imagefile = ImageFile(image_path, cache_image = cache_image) + + assert str(imagefile.file_path) == str(image_path) + assert imagefile.timestamp == imagefile.timestamp_initial + assert imagefile.timestamp_elapsed == timedelta() + assert imagefile.cache_image == cache_image + + def test_init_timestamp(self, tmp_path, timestamp_image, cache_image): + timestamp, image_name = timestamp_image + timestamp_diff = timedelta(hours = 1) + image_path = TestImageFile.create_temp_file(tmp_path, image_name) + imagefile = ImageFile( + image_path, + timestamp = timestamp, + timestamp_initial = timestamp - timestamp_diff + ) + + assert imagefile.timestamp == timestamp + assert imagefile.timestamp_initial == timestamp - timestamp_diff + assert imagefile.timestamp_elapsed == timestamp_diff + + def test_enter_exit(self, tmp_path, timestamp_image, cache_image, image): + image_path = TestImageFile.create_temp_file(tmp_path, timestamp_image[1], suffix = "png", file_data = image) + imagefile = ImageFile(image_path, cache_image = cache_image) + + with imagefile as image_file: + assert (image_file._ImageFile__image == array([[[255, 255, 255, 255]]])).all() + if imagefile.cache_image: + assert (imagefile._ImageFile__image == array([[[255, 255, 255, 255]]])).all() + else: + assert imagefile._ImageFile__image is None + + class TestProperties: + @pytest.mark.parametrize("image_path", ["", Path(), "."]) + def test_filepath_missing(self, image_path): + with pytest.raises(FileNotFoundError): + ImageFile(image_path) + + def test_image(self, tmp_path, timestamp_image, image, cache_image): + _, image_name = timestamp_image + image_path = TestImageFile.create_temp_file(tmp_path, image_name, suffix = "png", file_data = image) + imagefile = ImageFile(image_path, cache_image = cache_image) + + assert (imagefile.image == array([[[255, 255, 255, 255]]])).all() + assert (imagefile.image_gray == array([[1.]])).all() + + if cache_image: + assert (imagefile._ImageFile__image == array([[[255, 255, 255, 255]]])).all() + else: + assert imagefile._ImageFile__image is None + + class TestMethods: + def test_timestamp_from_exif(self, tmp_path, timestamp_image): + timestamp, image_name = timestamp_image + image_path = TestImageFile.create_temp_file(tmp_path, image_name) + imagefile = ImageFile(image_path) + + with pytest.raises(NotImplementedError): + imagefile.timestamp_from_exif(image_path) + + def test_timestamp_from_string(self, tmp_path, timestamp_image): + timestamp, image_name = timestamp_image + image_path = TestImageFile.create_temp_file(tmp_path, image_name) + imagefile = ImageFile(image_path) + + assert imagefile.timestamp_from_string(image_path) == timestamp + + def test_timestamp_from_string_invalid(self, tmp_path): + image_path = TestImageFile.create_temp_file(tmp_path, "test_image_123456789") + imagefile = ImageFile(image_path) + + with pytest.raises(ValueError): + imagefile.timestamp_from_string(image_path, pattern = "") + imagefile.timestamp_from_string("") + assert imagefile.timestamp_from_string(imagefile.file_path.name) is None + + +class TestImageFileCollection: + @staticmethod + def ImageFileMock(file_path, timestamp): + image_file = mock.Mock(spec = ImageFile) + image_file.file_path = file_path + image_file.timestamp = timestamp + image_file.timestamp_initial = image_file.timestamp - timedelta(hours = 1) + + return image_file + + @pytest.fixture + def image_files(self): + image_files = list() + timestamp = datetime.now() + + for i in range(10): + image_files.append(self.ImageFileMock(str(i), timestamp + timedelta(hours = i))) + + return image_files + + class TestInitialize: + def test_init(self): + imagefiles = ImageFileCollection() + + assert imagefiles.items == list() + + def test_init_list(self, image_files): + imagefiles = ImageFileCollection(image_files) + + assert imagefiles.items == image_files + + class TestProperties: + def test_image_files_sorted(self, image_files): + from random import sample + + image_files_shuffled = sample(image_files, len(image_files)) + imagefiles = ImageFileCollection(image_files_shuffled) + + assert imagefiles.items != image_files_shuffled + assert imagefiles.items == image_files + + def test_image_file_count(self, image_files): + imagefiles = ImageFileCollection(image_files) + + assert imagefiles.count == len(image_files) + + def test_file_paths(self, image_files): + imagefiles = ImageFileCollection(image_files) + + assert len(imagefiles.file_paths) == len(image_files) + assert imagefiles.file_paths == [image_file.file_path for image_file in image_files] + + def test_timestamps(self, image_files): + imagefiles = ImageFileCollection(image_files) + + assert len(imagefiles.timestamps) == len(image_files) + assert imagefiles.timestamps == [image_file.timestamp for image_file in image_files] + + def test_timestamps_initial(self, image_files): + imagefiles = ImageFileCollection(image_files) + + assert len(imagefiles.timestamps_initial) == len(image_files) + assert imagefiles.timestamps_initial == [image_file.timestamp_initial for image_file in image_files] + + timestamp_initial = datetime(1, 1, 1, 1, 1, 1) + imagefiles.timestamps_initial = timestamp_initial + + assert imagefiles.timestamps_initial == [image_file.timestamp_initial for image_file in image_files] + + def test_timestamps_elapsed(self, image_files): + imagefiles = ImageFileCollection(image_files) + + assert len(imagefiles.timestamps_elapsed) == len(image_files) + assert imagefiles.timestamps_elapsed == [image_file.timestamp_elapsed for image_file in image_files] + assert imagefiles.timestamps_elapsed_hours == [image_file.timestamp_elapsed_hours for image_file in image_files] + assert ( + imagefiles.timestamps_elapsed_minutes == [image_file.timestamp_elapsed_minutes for image_file in image_files] + ) + assert ( + imagefiles.timestamps_elapsed_seconds == [image_file.timestamp_elapsed_seconds for image_file in image_files] + ) + + class TestMethods: + @mock.patch("colonyscanalyser.image_file.file_exists", return_value = True) + def test_add_image_file(self, patch, image_files): + imagefiles = ImageFileCollection(image_files) + image_file_first = imagefiles.items[0] + new_image_file = imagefiles.add( + file_path = "", + timestamp = image_file_first.timestamp - timedelta(hours = 1), + timestamp_initial = image_file_first.timestamp_initial - timedelta(hours = 1), + cache_image = False + ) + + assert imagefiles.count == len(image_files) + 1 + assert new_image_file in imagefiles.items + assert image_file_first != new_image_file + assert imagefiles.items[0] == new_image_file \ No newline at end of file diff --git a/tests/unit/test_imaging.py b/tests/unit/test_imaging.py index 9796547..4a2b5c8 100644 --- a/tests/unit/test_imaging.py +++ b/tests/unit/test_imaging.py @@ -9,7 +9,7 @@ get_image_circles, remove_background_mask, watershed_separation - ) +) image_ref = np.array([ [0, 0, 0, 0, 0, 0, 0, 1, 0], @@ -254,7 +254,7 @@ def test_size_mismatch(self): remove_background_mask( np.ones((5, 5), dtype = np.uint8), np.ones((3, 5), dtype = np.uint8) - ) + ) def test_image_blank(self): image_blank = np.zeros((3, 3)) diff --git a/tests/unit/test_plate.py b/tests/unit/test_plate.py new file mode 100644 index 0000000..015939e --- /dev/null +++ b/tests/unit/test_plate.py @@ -0,0 +1,270 @@ +import pytest + +from colonyscanalyser.plate import ( + Plate, + PlateCollection +) + +invalid = [-1, -1.1, "one", None] + + +@pytest.fixture(params = [1, 2, 123456789]) +def id(request): + yield request.param + + +@pytest.fixture(params = [0, 1.0, 1 * 10**14]) +def diameter(request): + yield request.param + + +@pytest.fixture(scope = "function") +def plate(request, id, diameter): + yield Plate(id, diameter) + + +@pytest.fixture(scope = "function") +def colonies(request): + colonies = list() + for i in range(1, 10): + colonies.append(Colony(i)) + + yield colonies + + +class Colony(): + def __init__(self, id): + self.id = id + self.timepoints = {str(id): str(id)} + + def __iter__(self): + return iter([ + self.id + ]) + + +class TestPlate(): + class TestInitialize(): + def test_init(self, id, diameter): + plate = Plate(id, diameter) + + assert plate.id == id + assert plate.diameter == diameter + assert isinstance(plate.items, list) + assert plate.count == 0 + + @pytest.mark.parametrize("id", invalid) + @pytest.mark.parametrize("diameter", invalid) + def test_init_invalid(self, id, diameter): + with pytest.raises((TypeError, ValueError)): + Plate(id, diameter) + + class TestProperties(): + def test_iterable(self, plate): + assert len([*plate.__iter__()]) == 7 + + @pytest.mark.parametrize( + "colonies, edge_cut, name", + [ + (list(), 0, ""), + (list(), 1.0, "Test name"), + (list(), -1, "1") + ] + ) + def test_properties(self, plate, colonies, edge_cut, name): + + plate.items = colonies + plate.edge_cut = edge_cut + plate.name = name + + assert plate.items == colonies + assert plate.edge_cut == edge_cut + assert plate.name == name + + class TestMethods(): + def test_colonies_rename_sequential(self, plate, colonies): + seq_start = 11 + plate.items = colonies + plate.colonies_rename_sequential(start = seq_start) + + assert plate.count == len(colonies) + for i in range(seq_start, seq_start + len(colonies)): + assert any(colony.id == i for colony in plate.items) + + def test_colonies_to_csv(self, plate, colonies, tmp_path): + import csv + + plate.items = colonies + result = plate.colonies_to_csv(tmp_path) + + # Check all rows were written correctly + with open(result, 'r') as csvfile: + reader = csv.reader(csvfile) + for i, row in enumerate(reader): + # Skip headers row + if i != 0: + assert [str(x) for x in colonies[i - 1]] == row + + def test_colonies_timepoints_to_csv(self, plate, colonies, tmp_path): + import csv + + plate.items = colonies + result = plate.colonies_timepoints_to_csv(tmp_path) + + # Check all rows were written correctly + with open(result, 'r') as csvfile: + reader = csv.reader(csvfile) + for i, row in enumerate(reader): + # Skip headers row + if i != 0: + assert [str(x) for x in colonies[i - 1].timepoints.items()] == [str(tuple(row))] + + def test_collection_to_csv(self, plate, tmp_path, colonies): + file_name = "test" + result = plate._Plate__collection_to_csv(str(tmp_path), file_name, colonies, list()) + + assert result == tmp_path.joinpath(file_name).with_suffix(".csv") + + def test_collection_to_csv_path(self, plate): + with pytest.raises(FileNotFoundError): + plate._Plate__collection_to_csv("", "", list()) + + +class TestPlateCollection: + class TestInitialize: + @pytest.mark.parametrize("shape", [(3, 1), (5, 5), (2, 2)]) + def test_init(self, shape): + collection = PlateCollection(shape = shape) + + assert collection.shape == shape + + @pytest.mark.parametrize("shape", [(0, 1), (-1, 1), (1.1, 1)]) + def test_init_invalid(self, shape): + with pytest.raises(ValueError): + PlateCollection(shape = shape) + + class TestProperties: + def test_centers(self): + centers = list() + collection = PlateCollection() + for i in range(1, 10): + center = (i, i) + centers.append(center) + collection.add(id = i, diameter = 1, center = center) + + assert collection.count == len(centers) + assert collection.centers == centers + + class TestMethods: + @pytest.fixture + def image_circle(self, request): + from numpy import uint8, mgrid + + # Create a 200x200 array with a donut shaped circle around the centre + xx, yy = mgrid[:200, :200] + circle = (xx - 100) ** 2 + (yy - 100) ** 2 + img = ((circle < (6400 + 60)) & (circle > (6400 - 60))).astype(uint8) + img[img == circle] = 255 + + yield img + + def test_add(self, id, diameter): + collection = PlateCollection() + item_new = collection.add(id = id, diameter = diameter) + + assert collection.count == 1 + assert item_new in collection.items + + def test_from_image(self, image_circle): + plates = PlateCollection.from_image( + shape = (1, 1), + image = image_circle, + diameter = 180, + ) + + assert plates is not None + assert isinstance(plates, PlateCollection) + + def test_plates_from_image(self, image_circle): + label = "label" + plates = PlateCollection(shape = (1, 1)) + plates.plates_from_image( + image = image_circle, + diameter = 180, + labels = {1: label} + ) + + assert plates.count == 1 + assert plates.centers == [(102, 102)] + assert plates.items[0].diameter == 160 + assert plates.items[0].name == label + + def test_plates_from_image_invalid(self, image_circle): + plates = PlateCollection() + + print(plates.shape) + print(PlateCollection.coordinate_to_index(plates.shape)) + with pytest.raises(ValueError): + plates.plates_from_image( + image = image_circle, + diameter = 180 + ) + + def test_slice_plate_image(self, image_circle): + plates = PlateCollection(shape = (1, 1)) + plates.add( + id = 1, + diameter = 180, + edge_cut = 20, + center = (102, 102) + ) + + images = plates.slice_plate_image(image_circle) + + assert len(images) == 1 + assert images[1].shape == (141, 141) + + @pytest.mark.parametrize( + "index, shape, expected", + [ + (3, (3, 2), (2, 1)), + (5, (1, 8), (1, 5)), + (10, (5, 5), (2, 5)), + ]) + def test_index_to_coordinate(self, index, shape, expected): + result = PlateCollection.index_to_coordinate(index, shape) + + assert result == expected + + @pytest.mark.parametrize( + "index, shape", + [ + (-1, (1, 1)), + (0, (1, 1)), + (1, (0, 0)), + (1, (0, 1)), + (1, (-1, 1)), + ]) + def test_index_to_coordinate_invalid(self, index, shape): + with pytest.raises(ValueError): + PlateCollection.index_to_coordinate(index, shape) + + with pytest.raises(IndexError): + PlateCollection.index_to_coordinate(100, (1, 1)) + + @pytest.mark.parametrize( + "coordinate, expected", + [ + ((3, 2), 6), + ((1, 8), 8), + ((5, 5), 25), + ]) + def test_coordinate_to_index(self, coordinate, expected): + result = PlateCollection.coordinate_to_index(coordinate) + + assert result == expected + + @pytest.mark.parametrize("coordinate", [(0, 0), (-1, 1)]) + def test_coordinate_to_index_invalid(self, coordinate): + with pytest.raises(ValueError): + PlateCollection.coordinate_to_index(coordinate) diff --git a/tests/unit/test_utilities.py b/tests/unit/test_utilities.py index 2b6362d..fe75d78 100644 --- a/tests/unit/test_utilities.py +++ b/tests/unit/test_utilities.py @@ -2,12 +2,10 @@ from colonyscanalyser.utilities import ( round_tuple_floats, - index_number_to_coordinate, - coordinate_to_index_number, progress_bar, average_dicts_values_by_key, average_median_dicts_values_by_key - ) +) class TestRoundTupleFloats(): @@ -15,7 +13,7 @@ class TestRoundTupleFloats(): (1.3285, 1.00001), (-95840.3567, 0.0), (43.94387678, "string.", 2.567) - ]) + ]) def tuples(self, request): yield request.param @@ -38,60 +36,6 @@ def test_value_error(self): round_tuple_floats(1) -class TestIndexNumberToCoordinate(): - @pytest.fixture(params = [1, 2, 4, 10]) - def index_number(self, request): - yield request.param - - @pytest.fixture(params = [-1, 0]) - def index_number_invalid(self, request): - yield request.param - - @pytest.fixture(params = [(1, 1), (10, 1), (3, 2), (5, 5)]) - def lattice(self, request): - yield request.param - - @pytest.fixture(params=[(0, 0), (-1, 1), (0, 1)]) - def lattice_invalid(self, request): - yield request.param - - @pytest.mark.parametrize( - "index, lattice, expected", - [ - (3, (3, 2), (2, 1)), - (5, (1, 8), (1, 5)), - (10, (5, 5), (2, 5)), - ]) - def test_index_valid(self, index, lattice, expected): - result = index_number_to_coordinate(index, lattice) - assert result == expected - - def test_index_invalid(self, index_number_invalid, lattice): - with pytest.raises(ValueError): - index_number_to_coordinate(index_number_invalid, lattice) - - def test_lattice_invalid(self, index_number, lattice_invalid): - with pytest.raises(ValueError): - index_number_to_coordinate(index_number, lattice_invalid) - - def test_index_error(self, lattice): - with pytest.raises(IndexError): - index_number_to_coordinate(100, lattice) - - -class TestCoordinateToIndexNumber(): - @pytest.mark.parametrize( - "coordinate, expected", - [ - ((3, 2), 6), - ((1, 8), 8), - ((5, 5), 25), - ]) - def test_index_valid(self, coordinate, expected): - result = coordinate_to_index_number(coordinate) - assert result == expected - - class TestProgressBar(): def count_lines(self, text): return len(text.split('\n'))