diff --git a/src/colonyscanalyser/growth_curve.py b/src/colonyscanalyser/growth_curve.py index 0e4767f..0a55365 100644 --- a/src/colonyscanalyser/growth_curve.py +++ b/src/colonyscanalyser/growth_curve.py @@ -197,6 +197,7 @@ def fit_curve(self, initial_params: List[float] = None): """ from statistics import median from numpy import errstate, iinfo, intc, isinf, isnan, sqrt, diag, std + from .utilities import savgol_filter timestamps = [timestamp.total_seconds() for timestamp in sorted(self.data.keys())] measurements = [val for _, val in sorted(self.data.items())] @@ -211,10 +212,12 @@ def fit_curve(self, initial_params: List[float] = None): # The number of values must be at least the number of parameters if len(timestamps) >= 4 and len(measurements) >= 4: - # Calculate standard deviation if all(isinstance(m, Iterable) for m in measurements): + # Calculate standard deviation measurements_std = [std(m, axis = 0) for m in measurements] + # Use the filtered median measurements = [median(val) for val in measurements] + measurements = savgol_filter(measurements, window = 15, order = 2) # Try to estimate initial parameters, if unsuccessful pass None # None will result in scipy.optimize.curve_fit using its own default parameters diff --git a/src/colonyscanalyser/plots.py b/src/colonyscanalyser/plots.py index 3377d26..32dd761 100644 --- a/src/colonyscanalyser/plots.py +++ b/src/colonyscanalyser/plots.py @@ -326,6 +326,7 @@ def growth_curve( :param line_color: a Colormap color for the median """ from statistics import median + from .utilities import savgol_filter if line_color is None: line_color = scatter_color @@ -341,20 +342,13 @@ def growth_curve( alpha = 0.25 ) - # Plot the windowed median - from scipy.signal import savgol_filter + # Plot the smoothed median median = [median(val) for _, val in sorted(plate.growth_curve.data.items())] - window = 15 if len(median) > 15 else len(median) - # Window length must be odd and greater than polyorder for Savitzky-Golay filter - if window % 2 == 0: - window -= 1 - median_filtered = savgol_filter(median, window, 2) if window >= 3 else median - ax.plot( [td.total_seconds() / 3600 for td in sorted(plate.growth_curve.data.keys())], - median_filtered, + savgol_filter(median, 15, 2), color = line_color, - label = "Windowed median" if growth_params else f"Plate {plate.id}", + label = "Smoothed median" if growth_params else f"Plate {plate.id}", linewidth = 2 ) diff --git a/src/colonyscanalyser/utilities.py b/src/colonyscanalyser/utilities.py index 1053550..2d56668 100644 --- a/src/colonyscanalyser/utilities.py +++ b/src/colonyscanalyser/utilities.py @@ -1,4 +1,4 @@ -from typing import Collection, Tuple, List, Dict +from typing import Collection, Tuple, List, Dict, Any def round_tuple_floats(tuple_item: Tuple[float], precision: int = 2) -> Tuple[float]: @@ -15,6 +15,30 @@ def round_tuple_floats(tuple_item: Tuple[float], precision: int = 2) -> Tuple[fl return tuple(map(lambda x: isinstance(x, float) and round(x, precision) or x, tuple_item)) +def savgol_filter(measurements: Collection[Any], window: int = 15, order: int = 2) -> Collection[Any]: + """ + Smooth a one dimensional set of data with a Savitzky-Golay filter + + If no filtering can be performed the original data is returned unaltered + + :param measurements: a collection of values to filter + :param window: the window length used in the filter + :param order: the polynomial order used to fit the values + :returns: filtered values, if possible + """ + from scipy.signal import savgol_filter + + # Window length must be odd and greater than polyorder for Savitzky-Golay filter + if window > len(measurements): + window = len(measurements) + if window % 2 == 0: + window -= 1 + if window >= 1 and window > order: + measurements = savgol_filter(measurements, window, order) + + return measurements + + def progress_bar(bar_progress: float, bar_length: float = 30, message: str = ""): """ Output a simple progress bar to the console diff --git a/tests/unit/test_growth_curve.py b/tests/unit/test_growth_curve.py index 4fc3729..0a588cc 100644 --- a/tests/unit/test_growth_curve.py +++ b/tests/unit/test_growth_curve.py @@ -141,8 +141,9 @@ def test_fit_growth_curve(self, host): assert host.growth_curve._carrying_capacity_std >= 0 def test_fit_growth_curve_iter(self, host): - host.signals = [[0, signal] for signal in host.signals] - host.growth_curve.fit_curve() + host.signals = [[signal] for signal in host.signals] + + host.growth_curve.fit_curve(initial_params = [0.2, 0.4, 1, 1]) assert host.growth_curve._lag_time assert host.growth_curve._lag_time_std.total_seconds() >= 0 diff --git a/tests/unit/test_utilities.py b/tests/unit/test_utilities.py index 5f68370..01b9cbc 100644 --- a/tests/unit/test_utilities.py +++ b/tests/unit/test_utilities.py @@ -3,6 +3,7 @@ from colonyscanalyser.utilities import ( round_tuple_floats, progress_bar, + savgol_filter, dicts_merge, dicts_mean, dicts_median @@ -62,6 +63,28 @@ def test_message(self, capsys): assert captured.out[slice(-len(message) - 1, -1, 1)] == message +class TestSavGolFilter(): + @pytest.mark.parametrize("window_length", [3, 4, 5, 7, 10]) + def test_window(self, window_length): + measurements = [0] * window_length + + results = savgol_filter(measurements, window_length, 1) + + assert (results == measurements).all() + + @pytest.mark.parametrize("order", [1, 2, 3, 4, 6]) + def test_order(self, order): + measurements = [1] * 10 + window_length = 3 + + results = savgol_filter(measurements, window_length, order) + + if order >= window_length: + assert list(results) == measurements + else: + assert (results != measurements).any() + + class TestDictsMerge(): @pytest.mark.parametrize( "dicts, expected",