Skip to content

Commit

Permalink
Fast feedback indexer (dials#2717)
Browse files Browse the repository at this point in the history
Add a Fast Feedback Indexer implementation of the TORO algorithm for snapshot serial indexing on a GPU with CUDA.

---------

Co-authored-by: Hans-Christian Stadler <[email protected]>
  • Loading branch information
dagewa and hcstadler authored Dec 16, 2024
1 parent d14566f commit 9197bf1
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 4 deletions.
1 change: 1 addition & 0 deletions .conda-envs/linux.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ cxx-compiler
dials-data>=2.4.72
docutils
eigen
ffbidx
flexparser<0.4
future
gemmi>=0.6.5,<0.7
Expand Down
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ David Waterman
Derek Mendez
Graeme Winter
Huw Jenkins
Hans-Christian Stadler
Ian Rees
Iris Young
James Beilsten-Edmands
Expand Down
1 change: 1 addition & 0 deletions newsfragments/2717.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``dials.index`` and ``dials.ssx_index``: Add the CUDA-accelerated fast-feedback-indexer to DIALS as a lattice search algorithm. See https://github.com/paulscherrerinstitute/fast-feedback-indexer for more details.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"dials.index.lattice_search": [
"low_res_spot_match = dials.algorithms.indexing.lattice_search:LowResSpotMatch",
"pink_indexer = dials.algorithms.indexing.lattice_search:PinkIndexer",
"ffbidx = dials.algorithms.indexing.lattice_search:FfbIndexer",
],
"dials.integration.background": [
"Auto = dials.extensions.auto_background_ext:AutoBackgroundExt",
Expand Down
3 changes: 2 additions & 1 deletion src/dials/algorithms/indexing/lattice_search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from dials.algorithms.indexing import indexer
from dials.algorithms.indexing.basis_vector_search import combinations, optimise

from .ffb_indexer import FfbIndexer
from .low_res_spot_match import LowResSpotMatch
from .pinkindexer import PinkIndexer
from .strategy import Strategy

__all__ = ["Strategy", "LowResSpotMatch", "PinkIndexer"]
__all__ = ["Strategy", "LowResSpotMatch", "PinkIndexer", "FfbIndexer"]


logger = logging.getLogger(__name__)
Expand Down
215 changes: 215 additions & 0 deletions src/dials/algorithms/indexing/lattice_search/ffb_indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from __future__ import annotations

import logging

import numpy

import iotbx.phil
from cctbx.sgtbx import space_group
from dxtbx import flumpy
from dxtbx.model import Crystal

from dials.algorithms.indexing import DialsIndexError

from .strategy import Strategy

# Import fast feedback indexer package (CUDA implementation of the TORO algorithm)
# https://github.com/paulscherrerinstitute/fast-feedback-indexer/tree/main/python
try:
import ffbidx
except ModuleNotFoundError:
ffbidx = None


logger = logging.getLogger(__name__)

# PHIL definition of the user-tunable ffbidx parameters. It is parsed into
# FfbIndexer.phil_scope; the values are forwarded to ffbidx.Indexer and its
# run()/crystals() calls.
ffbidx_phil_str = """
ffbidx
    .expert_level = 1
{
    max_output_cells = 32
        .type = int(value_min=1)
        .help = "Maximum number of output cells"
    max_spots = 300
        .type = int(value_min=8)
        .help = "Maximum number of reciprocal spots taken into account"
    num_candidate_vectors = 32
        .type = int(value_min=1)
        .help = "Number of candidate cell vectors"
    redundant_computations = True
        .type = bool
        .help = "Calculate candidates for all three cell vectors"
    dist1 = 0.3
        .type = float(value_min=0.001, value_max=0.5)
        .help = "Reciprocal spots within this threshold contribute to the score for vector sampling"
    dist3 = 0.15
        .type = float(value_min=0.001, value_max=0.8)
        .help = "Reciprocal spots within this threshold contribute to the score for cell sampling"
    num_halfsphere_points = 32768
        .type = int(value_min=8000)
        .help = "Number of sampling points on the half sphere"
    max_dist = 0.00075
        .type = float(value_min=0.0)
        .help = "Maximum final distance between measured and calculated reciprocal spot"
    min_spots = 8
        .type = int(value_min=6)
        .help = "Minimum number of reciprocal spots within distance max_dist"
    method = *ifssr ifss ifse raw
        .type = choice
        .help = "Refinement method (consult algorithm description)"
    triml = 0.001
        .type = float(value_min=0, value_max=0.5)
        .help = "lower trimming value for intermediate score calculations"
    trimh = 0.3
        .type = float(value_min=0, value_max=0.5)
        .help = "higher trimming value for intermediate score calculations"
    delta = 0.1
        .type = float(value_min=0.000001)
        .help = "log2 curve position for intermediate score calculations, lower values will be more selective in choosing close spots"
    simple_data_filename = None
        .type = path
        .help = "Optional filename for the output of a simple data file for debugging"
        .expert_level = 2
}
"""


def write_simple_data_file(filename, rlp, cell):
    """Dump the cell and the reciprocal lattice points to a text file.

    The first line holds the flattened values of ``cell``; each subsequent
    line holds the flattened values of one entry of ``rlp``. Values on a line
    are separated by single spaces. Intended purely as a debugging aid.

    Args:
        filename: Path of the file to create (an existing file is overwritten).
        rlp: Iterable of array-like entries, each providing ``ravel()``.
        cell: Array-like object providing ``ravel()``.
    """

    def as_text_row(values):
        # One space-separated line per array, terminated by a newline.
        return " ".join(str(v) for v in values.ravel()) + "\n"

    with open(filename, "w") as handle:
        handle.write(as_text_row(cell))
        handle.writelines(as_text_row(point) for point in rlp)


class FfbIndexer(Strategy):
    """
    A lattice search strategy using a Cuda-accelerated implementation of the TORO algorithm.

    The heavy lifting is delegated to the external ``ffbidx`` package: an
    ``ffbidx.Indexer`` is created once in the constructor and reused for every
    call to :meth:`find_crystal_models`.

    For more info, see:
    [Gasparotto P, et al. TORO Indexer: a PyTorch-based indexing algorithm for kilohertz serial crystallography. J. Appl. Cryst. 2024 57(4)](https://doi.org/10.1107/S1600576724003182)
    """

    phil_help = (
        "A lattice search strategy for very fast indexing using Cuda acceleration"
    )

    phil_scope = iotbx.phil.parse(ffbidx_phil_str)

    def __init__(
        self, target_symmetry_primitive, max_lattices, params=None, *args, **kwargs
    ):
        """Construct FfbIndexer object.

        Args:
            target_symmetry_primitive (cctbx.crystal.symmetry): The target
                crystal symmetry and unit cell
            max_lattices (int): The maximum number of lattice models to find
            params (phil,optional): Phil params (the extracted ``ffbidx``
                scope). If omitted, the defaults declared in ``phil_scope``
                are used.

        Raises:
            DialsIndexError: If the ffbidx package is not installed, if no
                target symmetry is given, or if the ffbidx indexer cannot be
                created on this system.
            ValueError: If the target symmetry carries no unit cell.
        """
        # NOTE(review): the base class is handed params=None rather than the
        # actual params; kept as in the original since the base class is not
        # visible here - confirm whether it should receive `params`.
        super().__init__(params=None, *args, **kwargs)

        if ffbidx is None:
            raise DialsIndexError(
                "ffbidx requires the fast feedback indexer package. See (https://github.com/paulscherrerinstitute/fast-feedback-indexer)"
            )

        self._target_symmetry_primitive = target_symmetry_primitive
        self._max_lattices = max_lattices

        if target_symmetry_primitive is None:
            raise DialsIndexError("Target unit cell must be provided for ffbidx")

        target_cell = target_symmetry_primitive.unit_cell()
        if target_cell is None:
            raise ValueError("Please specify known_symmetry.unit_cell")

        # `params` is optional in the signature but is dereferenced below and
        # in find_crystal_models; fall back to the declared defaults instead
        # of failing with an opaque AttributeError on None.
        if params is None:
            params = self.phil_scope.extract().ffbidx
        self.params = params

        # Need the real space cell as numpy float32 array with all x vector coordinates, followed by y and z coordinates consecutively in memory
        self.input_cell = numpy.reshape(
            numpy.array(target_cell.orthogonalization_matrix(), dtype="float32"), (3, 3)
        )

        # Create fast feedback indexer object (on default CUDA device)
        try:
            self.indexer = ffbidx.Indexer(
                max_output_cells=params.max_output_cells,
                max_spots=params.max_spots,
                num_candidate_vectors=params.num_candidate_vectors,
                redundant_computations=params.redundant_computations,
            )
        except RuntimeError as e:
            # Chain the original error so the underlying traceback is kept.
            raise DialsIndexError(
                "The ffbidx package is not correctly configured for this system. See (https://github.com/paulscherrerinstitute/fast-feedback-indexer). Error: "
                + str(e)
            ) from e

    def find_crystal_models(self, reflections, experiments):
        """Find a list of candidate crystal models.

        Args:
            reflections (dials.array_family.flex.reflection_table):
                The found spots centroids and associated data
            experiments (dxtbx.model.experiment_list.ExperimentList):
                The experimental geometry models (not used by this strategy)

        Returns:
            A list of candidate crystal models (P1 ``Crystal`` objects); empty
            if the indexer reports no viable cells.
        """

        rlp = numpy.array(flumpy.to_numpy(reflections["rlp"]), dtype="float32")

        # Optional debugging dump, written before the layout change below so
        # that the file holds one reciprocal lattice point per line.
        if self.params.simple_data_filename is not None:
            write_simple_data_file(
                self.params.simple_data_filename, rlp, self.input_cell
            )

        # Need the reciprocal lattice points as numpy float32 array with all x coordinates, followed by y and z coordinates consecutively in memory
        rlp = rlp.transpose().copy()

        output_cells, scores = self.indexer.run(
            rlp,
            self.input_cell,
            dist1=self.params.dist1,
            dist3=self.params.dist3,
            num_halfsphere_points=self.params.num_halfsphere_points,
            max_dist=self.params.max_dist,
            min_spots=self.params.min_spots,
            n_output_cells=self.params.max_output_cells,
            method=self.params.method,
            triml=self.params.triml,
            trimh=self.params.trimh,
            delta=self.params.delta,
        )

        cell_indices = self.indexer.crystals(
            output_cells,
            rlp,
            scores,
            threshold=self.params.max_dist,
            min_spots=self.params.min_spots,
            method=self.params.method,
        )

        candidate_crystal_models = []
        if cell_indices is None:
            return candidate_crystal_models

        for index in cell_indices:
            # Each cell occupies three consecutive columns of output_cells,
            # one column per real space basis vector.
            j = 3 * index
            real_a = output_cells[:, j]
            real_b = output_cells[:, j + 1]
            real_c = output_cells[:, j + 2]
            crystal = Crystal(
                real_a.tolist(),
                real_b.tolist(),
                real_c.tolist(),
                space_group=space_group("P1"),
            )
            candidate_crystal_models.append(crystal)

        return candidate_crystal_models
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def __init__(

if target_symmetry_primitive is None:
raise DialsIndexError(
"Target unit cell and space group must be provided for small_cell"
"Target unit cell and space group must be provided for pink_indexer"
)

target_cell = target_symmetry_primitive.unit_cell()
Expand Down
1 change: 1 addition & 0 deletions src/dials/algorithms/indexing/ssx/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class IndexingResult:
"dials.algorithms.indexing.indexer",
"dials.algorithms.indexing.lattice_search",
"dials.algorithms.indexing.lattice_search.low_res_spot_match",
"dials.algorithms.indexing.lattice_search.ffb_indexer",
]
debug_loggers_to_disable = [
"dials.algorithms.indexing.symmetry",
Expand Down
2 changes: 1 addition & 1 deletion src/dials/command_line/ssx_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@

phil_scope = phil.parse(
"""
method = *fft1d *real_space_grid_search pink_indexer low_res_spot_match
method = *fft1d *real_space_grid_search pink_indexer low_res_spot_match ffbidx
.type = choice(multi=True)
nproc = Auto
.type = int
Expand Down
53 changes: 52 additions & 1 deletion tests/algorithms/indexing/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,58 @@ def test_pink_indexer(
"min_lattices=5",
"percent_bandwidth=2",
'known_symmetry.space_group="P 21 3"',
"known_symmetry.unit_cell=96.410, 96.410,96.410,90.0,90.0,90.0",
"known_symmetry.unit_cell=96.410,96.410,96.410,90.0,90.0,90.0",
]

expected_unit_cell = uctbx.unit_cell((96.41, 96.41, 96.41, 90, 90, 90))
expected_rmsds = (0.200, 0.200, 0.000)
expected_hall_symbol = " P 2ac 2ab 3"

run_indexing(
"combined.expt",
"combined.refl",
tmp_path,
extra_args,
expected_unit_cell,
expected_rmsds,
expected_hall_symbol,
n_expected_lattices=5,
)


def test_ffbidx(
dials_data,
tmp_path,
):
try:
import ffbidx # noqa: F401
except ModuleNotFoundError:
pytest.skip("ffbidx not installed")
try:
ffbidx.Indexer()
except RuntimeError:
pytest.skip("ffbidx installed but not functional on this system")

data_dir = dials_data("cunir_serial_processed", pathlib=True)
expt_file = data_dir / "imported_with_ref_5.expt"
refl_file = data_dir / "strong_5.refl"

command = [shutil.which("dials.split_experiments"), expt_file, refl_file]
result = subprocess.run(command, cwd=tmp_path)
assert not result.returncode and not result.stderr

command = [shutil.which("dials.combine_experiments")]
for i in range(5):
command.append(f"split_{i}.expt")
command.append(f"split_{i}.refl")
result = subprocess.run(command, cwd=tmp_path)
assert not result.returncode and not result.stderr

extra_args = [
"joint_indexing=False",
"indexing.method=ffbidx",
'known_symmetry.space_group="P 21 3"',
"known_symmetry.unit_cell=96.410,96.410,96.410,90.0,90.0,90.0",
]

expected_unit_cell = uctbx.unit_cell((96.41, 96.41, 96.41, 90, 90, 90))
Expand Down

0 comments on commit 9197bf1

Please sign in to comment.