Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add electronCount for single frame #320

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion python/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ py::array_t<T> vectorToPyArray(std::vector<T>&& v)
auto deleter = [](void* v) { delete reinterpret_cast<std::vector<T>*>(v); };
auto* ptr = new std::vector<T>(std::move(v));
auto capsule = py::capsule(ptr, deleter);
return py::array(ptr->size(), ptr->data(), capsule);
py::array_t<T> arr({ ptr->size() }, { sizeof(T) }, ptr->data(), capsule);
return arr;
}

struct ElectronCountedDataPyArray
Expand Down Expand Up @@ -214,6 +215,44 @@ ElectronCountedDataPyArray electronCount(Reader* reader,
return electronCount(reader, options.toCpp());
}

// Function to process individual frames
template <typename FrameType>
py::array_t<uint32_t> electronCount(
py::array_t<FrameType>& frame, Dimensions2D frameDimensions,
const ElectronCountOptionsClassicPy& options)
{
py::buffer_info frameBufferInfo = frame.request();

if (frameBufferInfo.ndim != 2 ||
frameBufferInfo.format != py::format_descriptor<FrameType>::format()) {
throw std::runtime_error(
"Input frame must be a 2D array of the correct type.");
}

const ElectronCountOptionsClassic cppOptions = options.toCpp();

// Convert the buffer to a std::vector
// TODO: is there a way to avoid this copy? For e.g span impl:
// https://github.com/pybind/pybind11/issues/1042#issuecomment-663154709
std::vector<FrameType> frameVec(static_cast<FrameType*>(frameBufferInfo.ptr),
static_cast<FrameType*>(frameBufferInfo.ptr) +
frameBufferInfo.size);

// Call the electronCount function with the std::vector
if (!cppOptions.darkReference) {
return vectorToPyArray(
electronCount<FrameType, false>(frameVec, frameDimensions, cppOptions));
} else {
return vectorToPyArray(
electronCount<FrameType, true>(frameVec, frameDimensions, cppOptions));
}

std::vector<uint32_t> result =
electronCount<FrameType, true>(frameVec, frameDimensions, cppOptions);

return vectorToPyArray(std::move(result));
}

// Explicitly instantiate version for py::array_t
template std::vector<STEMImage> createSTEMImages(
const std::vector<std::vector<py::array_t<uint32_t>>>& sparseData,
Expand Down Expand Up @@ -490,6 +529,12 @@ PYBIND11_MODULE(_image, m)
electronCount,
py::call_guard<py::gil_scoped_release>());

// Count individual frame
m.def("electron_count_frame",
(py::array_t<uint32_t>(*)(py::array_t<uint16_t>&, Dimensions2D,
const ElectronCountOptionsClassicPy&)) &
electronCount);

// Calculate thresholds, with gain
m.def(
"calculate_thresholds",
Expand Down
68 changes: 68 additions & 0 deletions python/stempy/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,74 @@ def electron_count(reader, darkreference=None, number_of_samples=40,

return array


def electron_count_frame(
frame: np.ndarray,
options=None,
darkreference=None,
background_threshold=4.0,
xray_threshold=2000.0,
gain=None,
):
"""Generate a list of coordinates of electron hits for a single 2D numpy array.

:param frame: the frame.
:type frame: numpy.ndarray
:param options: the options to use for electron counting. If set, all other
parameters are ignored.
:type options: stempy.image.ElectronCountOptionsClassic
:param darkreference: the dark reference to subtract, potentially generated
via stempy.image.calculate_average().
:type darkreference: stempy.image.ImageArray or numpy.ndarray
:param background_threshold: the threshold for background
:type background_threshold: float
:param xray_threshold: the threshold for x-rays
:type xray_threshold: float
:param gain: the gain mask to apply. Must match the frame dimensions.
:type gain: numpy.ndarray (2D)

:return: the coordinates of the electron hits for the frame.
:rtype: numpy.ndarray
"""

if gain is not None:
# Invert, as we will multiply in C++
# It also must be a float32
gain = np.power(gain, -1)
gain = _safe_cast(gain, np.float32, "gain")

if options is None:
if isinstance(darkreference, np.ndarray):
# Must be float32 for correct conversions
darkreference = _safe_cast(darkreference, np.float32, "dark reference")

options = _image.ElectronCountOptionsClassic()

options.dark_reference = darkreference
options.background_threshold = background_threshold
options.x_ray_threshold = xray_threshold
options.gain = gain
options.apply_row_dark_subtraction = False
options.optimized_mean = 0.0
options.apply_row_dark_use_mean = False
else:
if options.apply_row_dark_subtraction:
print("Warning: apply_row_dark_subtraction is not supported "
"for single frame electron counting. Ignoring this option.")
options.apply_row_dark_subtraction = False
options.gain = gain

electron_counts = _image.electron_count_frame(frame, frame.shape, options)
np_data = np.empty((1, 1), dtype=object)
np_data[0, 0] = np.array(electron_counts, copy=False)
kwargs = {
"data": np_data,
"scan_shape": (1, 1),
"frame_shape": frame.shape,
}
return SparseArray(**kwargs)


def radial_sum(reader, center=(-1, -1), scan_dimensions=(0, 0)):
"""Generate a radial sum from which STEM images can be generated.

Expand Down
26 changes: 26 additions & 0 deletions stempy/electron.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,24 @@ std::vector<uint32_t> electronCount(
return maximalPoints<FrameType>(frame, frameDimensions);
}

template <typename FrameType, bool dark>
std::vector<uint32_t> electronCount(std::vector<FrameType>& frame,
Dimensions2D frameDimensions,
const ElectronCountOptionsClassic& options)
{
auto* darkReference = options.darkReference;
auto backgroundThreshold = options.backgroundThreshold;
auto xRayThreshold = options.xRayThreshold;
auto* gain = options.gain;
auto applyRowDarkSubtraction = options.applyRowDarkSubtraction;
auto applyRowDarkUseMean = options.applyRowDarkUseMean;
auto optimizedMean = options.optimizedMean;

return electronCount<FrameType, dark>(
frame, frameDimensions, darkReference, backgroundThreshold, xRayThreshold,
gain, applyRowDarkSubtraction, optimizedMean, applyRowDarkUseMean);
}

template <typename Reader, typename FrameType, bool dark>
ElectronCountedData electronCount(Reader* reader,
const ElectronCountOptions& options)
Expand Down Expand Up @@ -874,4 +892,12 @@ template ElectronCountedData electronCount(
SectorStreamMultiPassThreadedReader* reader,
const ElectronCountOptions& options);

template std::vector<uint32_t> electronCount<uint16_t, true>(
std::vector<uint16_t>& frame, Dimensions2D frameDimensions,
const ElectronCountOptionsClassic& options);

template std::vector<uint32_t> electronCount<uint16_t, false>(
std::vector<uint16_t>& frame, Dimensions2D frameDimensions,
const ElectronCountOptionsClassic& options);

} // end namespace stempy
5 changes: 5 additions & 0 deletions stempy/electron.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ template <typename InputIt>
ElectronCountedData electronCount(InputIt first, InputIt last,
const ElectronCountOptionsClassic& options);

template <typename FrameType, bool dark>
std::vector<uint32_t> electronCount(std::vector<FrameType>& frame,
Dimensions2D frameDimensions,
const ElectronCountOptionsClassic& options);

template <typename Reader>
ElectronCountedData electronCount(Reader* reader,
const ElectronCountOptions& options);
Expand Down
69 changes: 68 additions & 1 deletion tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from stempy.image import com_dense, com_sparse, radial_sum_sparse
from stempy.image import com_dense, com_sparse, electron_count_frame, radial_sum_sparse
from stempy.io.sparse_array import SparseArray


Expand Down Expand Up @@ -61,3 +61,70 @@ def test_com_sparse_parameters(simulate_sparse_array):
# No counts will be in the center so all positions will be np.nan
com2 = com_sparse(sp, crop_to=(10,10), init_center=(1,1))
assert np.isnan(com2[0,0,0])


def test_electron_count_frame():
# Create a synthetic 2D numpy array (frame)
frame = np.array(
[
[2000, 0, 1000, 0, 0],
[0, 0, 0, 200, 0],
[0, 0, 1000, 0, 0],
[0, 200, 0, 200, 0],
[0, 0, 1000, 0, 0],
],
dtype=np.uint16,
)

dark = np.ones_like(frame) * 100

# Define expected electron hits (coordinates)
expected_hits = np.array(
[
[
[
[1, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
]
]
]
)

# Electron
electron_hits = electron_count_frame(
frame, xray_threshold=10000, background_threshold=1, darkreference=dark
)
assert np.array_equal(
electron_hits.to_dense(), expected_hits
), f"Expected {expected_hits}, but got {electron_hits}"

# Test with no dark reference
electron_hits = electron_count_frame(
frame, xray_threshold=10000, background_threshold=1
)

# Test where dark reference removes some points
dark = np.ones_like(frame) * 1000
electron_hits = electron_count_frame(
frame, xray_threshold=10000, background_threshold=1, darkreference=dark
)
expected_hits = np.array(
[
[
[
[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
]
]
]
)

assert np.array_equal(
electron_hits.to_dense(), expected_hits
), f"Expected {expected_hits}, but got {electron_hits}"
Loading