From 9197bf1806bfa81931f93c636cf8c319c63c6d79 Mon Sep 17 00:00:00 2001 From: David Waterman Date: Mon, 16 Dec 2024 15:10:12 +0000 Subject: [PATCH] Fast feedback indexer (#2717) Add a Fast Feedback Indexer implementation of the TORO algorithm for snapshot serial indexing on a GPU with Cuda. --------- Co-authored-by: Hans-Christian Stadler --- .conda-envs/linux.txt | 1 + AUTHORS | 1 + newsfragments/2717.feature | 1 + setup.py | 1 + .../indexing/lattice_search/__init__.py | 3 +- .../indexing/lattice_search/ffb_indexer.py | 215 ++++++++++++++++++ .../indexing/lattice_search/pinkindexer.py | 2 +- .../algorithms/indexing/ssx/processing.py | 1 + src/dials/command_line/ssx_index.py | 2 +- tests/algorithms/indexing/test_index.py | 53 ++++- 10 files changed, 276 insertions(+), 4 deletions(-) create mode 100644 newsfragments/2717.feature create mode 100644 src/dials/algorithms/indexing/lattice_search/ffb_indexer.py diff --git a/.conda-envs/linux.txt b/.conda-envs/linux.txt index cc61e8f0e2..733a53e51b 100644 --- a/.conda-envs/linux.txt +++ b/.conda-envs/linux.txt @@ -8,6 +8,7 @@ cxx-compiler dials-data>=2.4.72 docutils eigen +ffbidx flexparser<0.4 future gemmi>=0.6.5,<0.7 diff --git a/AUTHORS b/AUTHORS index ded15654bb..9d94b294b3 100644 --- a/AUTHORS +++ b/AUTHORS @@ -12,6 +12,7 @@ David Waterman Derek Mendez Graeme Winter Huw Jenkins +Hans-Christian Stadler Ian Rees Iris Young James Beilsten-Edmands diff --git a/newsfragments/2717.feature b/newsfragments/2717.feature new file mode 100644 index 0000000000..b54ba8b4e9 --- /dev/null +++ b/newsfragments/2717.feature @@ -0,0 +1 @@ +``dials.index`` and ``dials.ssx_index``: Add the CUDA-accelerated fast-feedback-indexer to DIALS as a lattice search algorithm. See https://github.com/paulscherrerinstitute/fast-feedback-indexer for more details. 
diff --git a/setup.py b/setup.py index 5ec84a6f02..6b2615a63b 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ "dials.index.lattice_search": [ "low_res_spot_match = dials.algorithms.indexing.lattice_search:LowResSpotMatch", "pink_indexer = dials.algorithms.indexing.lattice_search:PinkIndexer", + "ffbidx = dials.algorithms.indexing.lattice_search:FfbIndexer", ], "dials.integration.background": [ "Auto = dials.extensions.auto_background_ext:AutoBackgroundExt", diff --git a/src/dials/algorithms/indexing/lattice_search/__init__.py b/src/dials/algorithms/indexing/lattice_search/__init__.py index e192763f90..3ee8b855bf 100644 --- a/src/dials/algorithms/indexing/lattice_search/__init__.py +++ b/src/dials/algorithms/indexing/lattice_search/__init__.py @@ -15,11 +15,12 @@ from dials.algorithms.indexing import indexer from dials.algorithms.indexing.basis_vector_search import combinations, optimise +from .ffb_indexer import FfbIndexer from .low_res_spot_match import LowResSpotMatch from .pinkindexer import PinkIndexer from .strategy import Strategy -__all__ = ["Strategy", "LowResSpotMatch", "PinkIndexer"] +__all__ = ["Strategy", "LowResSpotMatch", "PinkIndexer", "FfbIndexer"] logger = logging.getLogger(__name__) diff --git a/src/dials/algorithms/indexing/lattice_search/ffb_indexer.py b/src/dials/algorithms/indexing/lattice_search/ffb_indexer.py new file mode 100644 index 0000000000..09f15ee9ce --- /dev/null +++ b/src/dials/algorithms/indexing/lattice_search/ffb_indexer.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import logging + +import numpy + +import iotbx.phil +from cctbx.sgtbx import space_group +from dxtbx import flumpy +from dxtbx.model import Crystal + +from dials.algorithms.indexing import DialsIndexError + +from .strategy import Strategy + +# Import fast feedback indexer package (CUDA implementation of the TORO algorithm) +# https://github.com/paulscherrerinstitute/fast-feedback-indexer/tree/main/python +try: + import ffbidx +except 
ModuleNotFoundError: + ffbidx = None + + +logger = logging.getLogger(__name__) + +ffbidx_phil_str = """ +ffbidx + .expert_level = 1 +{ + max_output_cells = 32 + .type = int(value_min=1) + .help = "Maximum number of output cells" + max_spots = 300 + .type = int(value_min=8) + .help = "Maximum number of reciprocal spots taken into account" + num_candidate_vectors = 32 + .type = int(value_min=1) + .help = "Number of candidate cell vectors" + redundant_computations = True + .type = bool + .help = "Calculate candidates for all three cell vectors" + dist1 = 0.3 + .type = float(value_min=0.001, value_max=0.5) + .help = "Reciprocal spots within this threshold contribute to the score for vector sampling" + dist3 = 0.15 + .type = float(value_min=0.001, value_max=0.8) + .help = "Reciprocal spots within this threshold contribute to the score for cell sampling" + num_halfsphere_points = 32768 + .type = int(value_min=8000) + .help = "Number of sampling points on the half sphere" + max_dist = 0.00075 + .type = float(value_min=0.0) + .help = "Maximum final distance between measured and calculated reciprocal spot" + min_spots = 8 + .type = int(value_min=6) + .help = "Minimum number of reciprocal spots within distance max_dist" + method = *ifssr ifss ifse raw + .type = choice + .help = "Refinement method (consult algorithm description)" + triml = 0.001 + .type = float(value_min=0, value_max=0.5) + .help = "lower trimming value for intermediate score calculations" + trimh = 0.3 + .type = float(value_min=0, value_max=0.5) + .help = "higher trimming value for intermediate score calculations" + delta = 0.1 + .type = float(value_min=0.000001) + .help = "log2 curve position for intermediate score calculations, lower values will be more selective in choosing close spots" + simple_data_filename = None + .type = path + .help = "Optional filename for the output of a simple data file for debugging" + .expert_level = 2 +} +""" + + +def write_simple_data_file(filename, rlp, cell): + """Write a
simple data file for debugging.""" + with open(filename, "w") as f: + f.write(" ".join(map(str, cell.ravel())) + "\n") + for r in rlp: + f.write(" ".join(map(str, r.ravel())) + "\n") + + +class FfbIndexer(Strategy): + """ + A lattice search strategy using a Cuda-accelerated implementation of the TORO algorithm. + For more info, see: + [Gasparotto P, et al. TORO Indexer: a PyTorch-based indexing algorithm for kilohertz serial crystallography. J. Appl. Cryst. 2024 57(4)](https://doi.org/10.1107/S1600576724003182) + """ + + phil_help = ( + "A lattice search strategy for very fast indexing using Cuda acceleration" + ) + + phil_scope = iotbx.phil.parse(ffbidx_phil_str) + + def __init__( + self, target_symmetry_primitive, max_lattices, params=None, *args, **kwargs + ): + """Construct FfbIndexer object. + + Args: + target_symmetry_primitive (cctbx.crystal.symmetry): The target + crystal symmetry and unit cell + max_lattices (int): The maximum number of lattice models to find + params (phil,optional): Phil params + + Returns: + None + """ + super().__init__(params=None, *args, **kwargs) + + if ffbidx is None: + raise DialsIndexError( + "ffbidx requires the fast feedback indexer package. 
See (https://github.com/paulscherrerinstitute/fast-feedback-indexer)" + ) + + self._target_symmetry_primitive = target_symmetry_primitive + self._max_lattices = max_lattices + + if target_symmetry_primitive is None: + raise DialsIndexError("Target unit cell must be provided for ffbidx") + + target_cell = target_symmetry_primitive.unit_cell() + if target_cell is None: + raise ValueError("Please specify known_symmetry.unit_cell") + + self.params = params + + # Need the real space cell as numpy float32 array with all x vector coordinates, followed by y and z coordinates consecutively in memory + self.input_cell = numpy.reshape( + numpy.array(target_cell.orthogonalization_matrix(), dtype="float32"), (3, 3) + ) + + # Create fast feedback indexer object (on default CUDA device) + try: + self.indexer = ffbidx.Indexer( + max_output_cells=params.max_output_cells, + max_spots=params.max_spots, + num_candidate_vectors=params.num_candidate_vectors, + redundant_computations=params.redundant_computations, + ) + except RuntimeError as e: + raise DialsIndexError( + "The ffbidx package is not correctly configured for this system. See (https://github.com/paulscherrerinstitute/fast-feedback-indexer). Error: " + + str(e) + ) + + def find_crystal_models(self, reflections, experiments): + """Find a list of candidate crystal models. + + Args: + reflections (dials.array_family.flex.reflection_table): + The found spots centroids and associated data + + experiments (dxtbx.model.experiment_list.ExperimentList): + The experimental geometry models + + Returns: + A list of candidate crystal models. 
+ """ + + rlp = numpy.array(flumpy.to_numpy(reflections["rlp"]), dtype="float32") + + if self.params.simple_data_filename is not None: + write_simple_data_file( + self.params.simple_data_filename, rlp, self.input_cell + ) + + # Need the reciprocal lattice points as numpy float32 array with all x coordinates, followed by y and z coordinates consecutively in memory + rlp = rlp.transpose().copy() + + output_cells, scores = self.indexer.run( + rlp, + self.input_cell, + dist1=self.params.dist1, + dist3=self.params.dist3, + num_halfsphere_points=self.params.num_halfsphere_points, + max_dist=self.params.max_dist, + min_spots=self.params.min_spots, + n_output_cells=self.params.max_output_cells, + method=self.params.method, + triml=self.params.triml, + trimh=self.params.trimh, + delta=self.params.delta, + ) + + cell_indices = self.indexer.crystals( + output_cells, + rlp, + scores, + threshold=self.params.max_dist, + min_spots=self.params.min_spots, + method=self.params.method, + ) + + candidate_crystal_models = [] + if cell_indices is None: + return candidate_crystal_models + + for index in cell_indices: + j = 3 * index + real_a = output_cells[:, j] + real_b = output_cells[:, j + 1] + real_c = output_cells[:, j + 2] + crystal = Crystal( + real_a.tolist(), + real_b.tolist(), + real_c.tolist(), + space_group=space_group("P1"), + ) + candidate_crystal_models.append(crystal) + + return candidate_crystal_models diff --git a/src/dials/algorithms/indexing/lattice_search/pinkindexer.py b/src/dials/algorithms/indexing/lattice_search/pinkindexer.py index b653e23767..5b66022686 100644 --- a/src/dials/algorithms/indexing/lattice_search/pinkindexer.py +++ b/src/dials/algorithms/indexing/lattice_search/pinkindexer.py @@ -365,7 +365,7 @@ def __init__( if target_symmetry_primitive is None: raise DialsIndexError( - "Target unit cell and space group must be provided for small_cell" + "Target unit cell and space group must be provided for pink_indexer" ) target_cell = 
target_symmetry_primitive.unit_cell() diff --git a/src/dials/algorithms/indexing/ssx/processing.py b/src/dials/algorithms/indexing/ssx/processing.py index 86537f392b..177826bb70 100644 --- a/src/dials/algorithms/indexing/ssx/processing.py +++ b/src/dials/algorithms/indexing/ssx/processing.py @@ -66,6 +66,7 @@ class IndexingResult: "dials.algorithms.indexing.indexer", "dials.algorithms.indexing.lattice_search", "dials.algorithms.indexing.lattice_search.low_res_spot_match", + "dials.algorithms.indexing.lattice_search.ffb_indexer", ] debug_loggers_to_disable = [ "dials.algorithms.indexing.symmetry", diff --git a/src/dials/command_line/ssx_index.py b/src/dials/command_line/ssx_index.py index 1945da1fe9..ee322c251b 100644 --- a/src/dials/command_line/ssx_index.py +++ b/src/dials/command_line/ssx_index.py @@ -79,7 +79,7 @@ phil_scope = phil.parse( """ -method = *fft1d *real_space_grid_search pink_indexer low_res_spot_match +method = *fft1d *real_space_grid_search pink_indexer low_res_spot_match ffbidx .type = choice(multi=True) nproc = Auto .type = int diff --git a/tests/algorithms/indexing/test_index.py b/tests/algorithms/indexing/test_index.py index 30b3064564..3bc9e7594d 100644 --- a/tests/algorithms/indexing/test_index.py +++ b/tests/algorithms/indexing/test_index.py @@ -819,7 +819,58 @@ def test_pink_indexer( "min_lattices=5", "percent_bandwidth=2", 'known_symmetry.space_group="P 21 3"', - "known_symmetry.unit_cell=96.410, 96.410,96.410,90.0,90.0,90.0", + "known_symmetry.unit_cell=96.410,96.410,96.410,90.0,90.0,90.0", + ] + + expected_unit_cell = uctbx.unit_cell((96.41, 96.41, 96.41, 90, 90, 90)) + expected_rmsds = (0.200, 0.200, 0.000) + expected_hall_symbol = " P 2ac 2ab 3" + + run_indexing( + "combined.expt", + "combined.refl", + tmp_path, + extra_args, + expected_unit_cell, + expected_rmsds, + expected_hall_symbol, + n_expected_lattices=5, + ) + + +def test_ffbidx( + dials_data, + tmp_path, +): + try: + import ffbidx # noqa: F401 + except ModuleNotFoundError: + 
pytest.skip("ffbidx not installed") + try: + ffbidx.Indexer() + except RuntimeError: + pytest.skip("ffbidx installed but not functional on this system") + + data_dir = dials_data("cunir_serial_processed", pathlib=True) + expt_file = data_dir / "imported_with_ref_5.expt" + refl_file = data_dir / "strong_5.refl" + + command = [shutil.which("dials.split_experiments"), expt_file, refl_file] + result = subprocess.run(command, cwd=tmp_path) + assert not result.returncode and not result.stderr + + command = [shutil.which("dials.combine_experiments")] + for i in range(5): + command.append(f"split_{i}.expt") + command.append(f"split_{i}.refl") + result = subprocess.run(command, cwd=tmp_path) + assert not result.returncode and not result.stderr + + extra_args = [ + "joint_indexing=False", + "indexing.method=ffbidx", + 'known_symmetry.space_group="P 21 3"', + "known_symmetry.unit_cell=96.410,96.410,96.410,90.0,90.0,90.0", ] expected_unit_cell = uctbx.unit_cell((96.41, 96.41, 96.41, 90, 90, 90))