Skip to content

Commit

Permalink
Improve smoothing with Savitzky–Golay filter (#74)
Browse files Browse the repository at this point in the history
* Smooth measurements when calculating growth curves

* Update growth curve tests for SavGol filtering

* Add tests for savgol_filter
  • Loading branch information
Erik-White authored Jan 25, 2021
1 parent 04f7ebd commit da6dea2
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 14 deletions.
5 changes: 4 additions & 1 deletion src/colonyscanalyser/growth_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def fit_curve(self, initial_params: List[float] = None):
"""
from statistics import median
from numpy import errstate, iinfo, intc, isinf, isnan, sqrt, diag, std
from .utilities import savgol_filter

timestamps = [timestamp.total_seconds() for timestamp in sorted(self.data.keys())]
measurements = [val for _, val in sorted(self.data.items())]
Expand All @@ -211,10 +212,12 @@ def fit_curve(self, initial_params: List[float] = None):

# The number of values must be at least the number of parameters
if len(timestamps) >= 4 and len(measurements) >= 4:
# Calculate standard deviation
if all(isinstance(m, Iterable) for m in measurements):
# Calculate standard deviation
measurements_std = [std(m, axis = 0) for m in measurements]
# Use the filtered median
measurements = [median(val) for val in measurements]
measurements = savgol_filter(measurements, window = 15, order = 2)

# Try to estimate initial parameters, if unsuccessful pass None
# None will result in scipy.optimize.curve_fit using its own default parameters
Expand Down
14 changes: 4 additions & 10 deletions src/colonyscanalyser/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def growth_curve(
:param line_color: a Colormap color for the median
"""
from statistics import median
from .utilities import savgol_filter

if line_color is None:
line_color = scatter_color
Expand All @@ -341,20 +342,13 @@ def growth_curve(
alpha = 0.25
)

# Plot the windowed median
from scipy.signal import savgol_filter
# Plot the smoothed median
median = [median(val) for _, val in sorted(plate.growth_curve.data.items())]
window = 15 if len(median) > 15 else len(median)
# Window length must be odd and greater than polyorder for Savitzky-Golay filter
if window % 2 == 0:
window -= 1
median_filtered = savgol_filter(median, window, 2) if window >= 3 else median

ax.plot(
[td.total_seconds() / 3600 for td in sorted(plate.growth_curve.data.keys())],
median_filtered,
savgol_filter(median, 15, 2),
color = line_color,
label = "Windowed median" if growth_params else f"Plate {plate.id}",
label = "Smoothed median" if growth_params else f"Plate {plate.id}",
linewidth = 2
)

Expand Down
26 changes: 25 additions & 1 deletion src/colonyscanalyser/utilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Collection, Tuple, List, Dict
from typing import Collection, Tuple, List, Dict, Any


def round_tuple_floats(tuple_item: Tuple[float], precision: int = 2) -> Tuple[float]:
Expand All @@ -15,6 +15,30 @@ def round_tuple_floats(tuple_item: Tuple[float], precision: int = 2) -> Tuple[fl
return tuple(map(lambda x: isinstance(x, float) and round(x, precision) or x, tuple_item))


def savgol_filter(measurements: Collection[Any], window: int = 15, order: int = 2) -> Collection[Any]:
"""
Smooth a one dimensional set of data with a Savitzky-Golay filter
If no filtering can be performed the original data is returned unaltered
:param measurements: a collection of values to filter
:param window: the window length used in the filter
:param order: the polynomial order used to fit the values
:returns: filtered values, if possible
"""
from scipy.signal import savgol_filter

# Window length must be odd and greater than polyorder for Savitzky-Golay filter
if window > len(measurements):
window = len(measurements)
if window % 2 == 0:
window -= 1
if window >= 1 and window > order:
measurements = savgol_filter(measurements, window, order)

return measurements


def progress_bar(bar_progress: float, bar_length: float = 30, message: str = ""):
"""
Output a simple progress bar to the console
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_growth_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ def test_fit_growth_curve(self, host):
assert host.growth_curve._carrying_capacity_std >= 0

def test_fit_growth_curve_iter(self, host):
host.signals = [[0, signal] for signal in host.signals]
host.growth_curve.fit_curve()
host.signals = [[signal] for signal in host.signals]

host.growth_curve.fit_curve(initial_params = [0.2, 0.4, 1, 1])

assert host.growth_curve._lag_time
assert host.growth_curve._lag_time_std.total_seconds() >= 0
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from colonyscanalyser.utilities import (
round_tuple_floats,
progress_bar,
savgol_filter,
dicts_merge,
dicts_mean,
dicts_median
Expand Down Expand Up @@ -62,6 +63,28 @@ def test_message(self, capsys):
assert captured.out[slice(-len(message) - 1, -1, 1)] == message


class TestSavGolFilter():
@pytest.mark.parametrize("window_length", [3, 4, 5, 7, 10])
def test_window(self, window_length):
measurements = [0] * window_length

results = savgol_filter(measurements, window_length, 1)

assert (results == measurements).all()

@pytest.mark.parametrize("order", [1, 2, 3, 4, 6])
def test_order(self, order):
measurements = [1] * 10
window_length = 3

results = savgol_filter(measurements, window_length, order)

if order >= window_length:
assert list(results) == measurements
else:
assert (results != measurements).any()


class TestDictsMerge():
@pytest.mark.parametrize(
"dicts, expected",
Expand Down

0 comments on commit da6dea2

Please sign in to comment.