From 70477d0981f5126f4455c21a8c1eb42abbbec18e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 19 Jan 2024 13:04:20 -0500 Subject: [PATCH 1/2] Allow the results to be saved to a file as a table --- src/kbmod/configuration.py | 2 + src/kbmod/result_list.py | 234 +++++++++++++++++++++++++++++++++- src/kbmod/run_search.py | 4 +- tests/test_regression_test.py | 46 ++++--- tests/test_result_list.py | 115 +++++++++++++++-- tests/test_work_unit.py | 3 +- 6 files changed, 373 insertions(+), 31 deletions(-) diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index 6ff84eb47..9b1323388 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -51,6 +51,7 @@ def __init__(self): "encode_num_bytes": -1, "flag_keys": default_flag_keys, "gpu_filter": False, + "individual_output_files": True, "im_filepath": None, "known_obj_obs": 3, "known_obj_thresh": None, @@ -72,6 +73,7 @@ def __init__(self): "psf_file": None, "repeated_flag_keys": default_repeated_flag_keys, "res_filepath": None, + "result_filename": None, "sigmaG_lims": [25, 75], "stamp_radius": 10, "stamp_type": "sum", diff --git a/src/kbmod/result_list.py b/src/kbmod/result_list.py index b59ac1ef0..cd4024e0b 100644 --- a/src/kbmod/result_list.py +++ b/src/kbmod/result_list.py @@ -7,18 +7,44 @@ import multiprocessing as mp import numpy as np import os.path as ospath +from pathlib import Path from astropy.table import Table from yaml import dump, safe_load from kbmod.file_utils import * from kbmod.trajectory_utils import ( + make_trajectory, trajectory_from_yaml, trajectory_predict_skypos, trajectory_to_yaml, ) +def _check_optional_allclose(arr1, arr2): + """Check whether toward optional numpy arrays have the same information. + + Parameters + ---------- + arr1 : `numpy.ndarray` or `None` + The first array. + arr1 : `numpy.ndarray` or `None` + The second array. + + Returns + ------- + result : `bool` + Indicates whether the arrays are the same. + """ + if arr1 is None and arr2 is None: + return True + if arr1 is not None and arr2 is None: + return False + if arr1 is None and arr2 is not None: + return False + return np.allclose(arr1, arr2) + + class ResultRow: """This class stores a collection of related data from a single kbmod result. In order to maintain a consistent internal state, the class uses private variables @@ -86,6 +112,104 @@ def __init__(self, trj, num_times): self.trajectory = trj self._valid_indices = [i for i in range(num_times)] + @classmethod + def from_table_row(cls, data, num_times=None): + """Create a ResultRow object directly from an AstroPy Table row. + + Parameters + ---------- + data : 'astropy.table.row.Row' + The incoming row. + all_times : `int`, optional + The number of total times in the data. If ``None`` tries + to extract from a "num_times" or "all_times" column. + + Raises + ------ + KeyError if a column is missing. + """ + if num_times is None: + if "num_times" in data.columns: + num_times = data["num_times"] + elif "all_times" in data.columns: + num_times = len(data["all_times"]) + else: + raise KeyError("Number of times is not specified.") + + # Create the Trajectory object from the correct fields. + trj = make_trajectory( + data["trajectory_x"], + data["trajectory_y"], + data["trajectory_vx"], + data["trajectory_vy"], + data["flux"], + data["likelihood"], + data["obs_count"], + ) + + # Manually fill in all the rest of the values. We let the stamp related columns + # be empty to save space. + row = ResultRow(trj, num_times) + row._final_likelihood = data["likelihood"] + row._phi_curve = data["phi_curve"] + row.pred_dec = data["pred_dec"] + row.pred_ra = data["pred_ra"] + row._psi_curve = data["psi_curve"] + row._valid_indices = data["valid_indices"] + + if "all_stamps" in data.columns: + row.all_stamps = data["all_stamps"] + else: + row.all_stamps = None + + if "stamp" in data.columns: + row.stamp = data["stamp"] + else: + row.stamp = None + + return row + + def __eq__(self, other): + """Test if two result rows are equal.""" + if not isinstance(other, ResultRow): + return False + + # Check the attributes of the trajectory first. + if ( + self.trajectory.x != other.trajectory.x + or self.trajectory.y != other.trajectory.y + or self.trajectory.vx != other.trajectory.vx + or self.trajectory.vy != other.trajectory.vy + or self.trajectory.lh != other.trajectory.lh + or self.trajectory.flux != other.trajectory.flux + or self.trajectory.obs_count != other.trajectory.obs_count + ): + return False + + # Check the simple attributes. + if not self._num_times == other._num_times: + return False + if not self._final_likelihood == other._final_likelihood: + return False + + # Check the curves and stamps. + if not _check_optional_allclose(self.all_stamps, other.all_stamps): + return False + if not _check_optional_allclose(self._phi_curve, other._phi_curve): + return False + if not _check_optional_allclose(self._psi_curve, other._psi_curve): + return False + if not _check_optional_allclose(self.stamp, other.stamp): + return False + if not _check_optional_allclose(self._valid_indices, other._valid_indices): + return False + if not _check_optional_allclose(self.pred_dec, other.pred_dec): + return False + if not _check_optional_allclose(self.pred_ra, other.pred_ra): + return False + + return True + @property def final_likelihood(self): return self._final_likelihood @@ -399,6 +523,59 @@ def from_yaml(cls, yaml_str): result_list.filtered[key] = [ResultRow.from_yaml(row) for row in yaml_dict["filtered"][key]] return result_list + @classmethod + def from_table(self, data, all_times=None, track_filtered=False): + """Extract the ResultList from an astropy Table. + + Parameters + ---------- + data : `astropy.table.Table` + The input data. + all_times : `List` or `numpy.ndarray` or None + The list of all time stamps. Must either be set or there + must be an all_times column in the Table. + track_filtered : `bool` + Indicates whether the ResultList should track future filtered points. + + Raises + ------ + KeyError if any columns are missing or if is ``all_times`` is None and there + is no all_times column in the data. + """ + # Check that we have some list of time stamps and place it in all_times. + if all_times is None: + if "all_times" not in data.columns: + raise KeyError(f"No time stamps provided.") + else: + all_times = data["all_times"][0] + num_times = len(all_times) + + result_list = ResultList(all_times, track_filtered) + for i in range(len(data)): + row = ResultRow.from_table_row(data[i], num_times) + result_list.append_result(row) + return result_list + + @classmethod + def read_table(self, filename): + """Read the ResultList from a table file. + + Parameters + ---------- + filename : `str` + The name of the file to load. + + + Raises + ------ + FileNotFoundError if the file is not found. + KeyError if any of the columns are missing. + """ + if not Path(filename).is_file(): + raise FileNotFoundError + data = Table.read(filename) + return ResultList.from_table(data) + def num_results(self): """Return the number of results in the list. @@ -413,6 +590,37 @@ def __len__(self): """Return the number of results in the list.""" return len(self.results) + def __eq__(self, other): + """Test if two ResultLists are equal. Includes both ordering and values.""" + if not isinstance(other, ResultList): + return False + if not np.allclose(self._all_times, other._all_times): + return False + if self.track_filtered != other.track_filtered: + return False + + num_results = len(self.results) + if num_results != len(other.results): + return False + for i in range(num_results): + if self.results[i] != other.results[i]: + return False + + if len(self.filtered) != len(other.filtered): + return False + for key in self.filtered.keys(): + if key not in other.filtered: + return False + + num_filtered = len(self.filtered[key]) + if num_filtered != len(other.filtered): + return False + for i in range(num_filtered): + if self.filtered[key][i] != other.filtered[key][i]: + return False + + return True + def clear(self): """Clear the list of results.""" self.results.clear() @@ -648,7 +856,7 @@ def revert_filter(self, label=None): return self - def to_table(self, filtered_label=None): + def to_table(self, filtered_label=None, append_times=False): """Extract the results into an astropy table. Parameters @@ -656,6 +864,8 @@ def to_table(self, filtered_label=None): filtered_label : `str`, optional The filtering label to extract. If None then extracts the unfiltered rows. (default=None) + append_times : `bool` + Append the list of all times as a column in the data. Returns ------- @@ -690,12 +900,34 @@ def to_table(self, filtered_label=None): "pred_ra": [], "pred_dec": [], } + if append_times: + table_dict["all_times"] = [] # Use a (slow) linear scan to do the transformation. for row in list_ref: row.append_to_dict(table_dict, True) + if append_times: + table_dict["all_times"].append(self._all_times) + return Table(table_dict) + def write_table(self, filename, overwrite=True): + """Write the unfiltered results to a single (ecsv) file. + + Parameter + --------- + filename : `str` + The name of the result file. + overwrite : `bool` + Overwrite the file if it already exists. + """ + table_version = self.to_table(append_times=True) + + # Drop the all stamps column as this is often too large to write in a CSV entry. + table_version.remove_column("all_stamps") + + table_version.write(filename, overwrite=True) + def to_yaml(self, serialize_filtered=False): """Serialize the ResultList as a YAML string. diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index f2d740408..8ec738ed0 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -191,11 +191,13 @@ def run_search(self, config, stack): # Save the results and the configuration information used. print(f"Found {keep.num_results()} potential trajectories.") - if config["res_filepath"] is not None: + if config["res_filepath"] is not None and config["individual_output_files"]: keep.save_to_files(config["res_filepath"], config["output_suffix"]) config_filename = os.path.join(config["res_filepath"], f"config_{config['output_suffix']}.yml") config.to_file(config_filename, overwrite=True) + if config["result_filename"] is not None: + keep.write_table(config["result_filename"]) end = time.time() print("Time taken for patch: ", end - start) diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 3794b17f8..8708f0253 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -14,6 +14,7 @@ from kbmod.fake_data_creator import add_fake_object from kbmod.file_utils import * +from kbmod.result_list import ResultList from kbmod.run_search import SearchRunner from kbmod.search import * from kbmod.trajectory_utils import make_trajectory @@ -312,18 +313,23 @@ def load_trajectories_from_file(filename): return trjs -def perform_search(im_filepath, time_file, psf_file, res_filepath, results_suffix, default_psf): +def perform_search(im_filepath, time_file, psf_file, res_filename, default_psf): """ Run the core search algorithm. - Arguments: - im_filepath - The file path (directory) for the image files. - time_file - The path and file name of the file of timestamps. - psf_file - The path and file name of the psf values. - res_filepath - The path (directory) for the new result files. - results_suffix - The file suffix to use for the new results. - default_psf - The default PSF value to use when nothing is provided - in the PSF file. + Parameters + ---------- + im_filepath : `str` + The file path (directory) for the image files. + time_file : `str` + The path and file name of the file of timestamps. + psf_file : `str` + The path and file name of the psf values. + res_filename : `str` + The path (directory) for the new result files. + default_psf : `float` + The default PSF value to use when nothing is provided + in the PSF file. """ v_min = 92.0 # Pixels/day v_max = 550.0 @@ -370,11 +376,12 @@ def perform_search(im_filepath, time_file, psf_file, res_filepath, results_suffi input_parameters = { "im_filepath": im_filepath, - "res_filepath": res_filepath, + "res_filepath": None, + "result_filename": res_filename, "time_file": time_file, "psf_file": psf_file, "psf_val": default_psf, - "output_suffix": results_suffix, + "output_suffix": "", "v_arr": v_arr, "ang_arr": ang_arr, "num_obs": num_obs, @@ -466,17 +473,20 @@ def run_full_test(): # Do the search. print("Running search with data in %s/" % dir_name) + result_filename = os.path.join(dir_name, "results.ecsv") perform_search( - dir_name + "/imgs", - dir_name + "/times.dat", - dir_name + "/psf_vals.dat", - dir_name, - "tmp", + os.path.join(dir_name, "imgs"), + os.path.join(dir_name, "times.dat"), + os.path.join(dir_name, "psf_vals.dat"), + result_filename, default_psf, ) - # Load the results from the results file. - found = load_trajectories_from_file(dir_name + "/results_tmp.txt") + # Load the results from the results file and extract a list of trajectories. + found = [] + loaded_data = ResultList.read_table(result_filename) + for row in loaded_data.results: + found.append(row.trajectory) print("Found %i trajectories vs %i used." % (len(found), len(trjs))) # Determine which trajectories we did not recover. diff --git a/tests/test_result_list.py b/tests/test_result_list.py index 44990321f..e903465fe 100644 --- a/tests/test_result_list.py +++ b/tests/test_result_list.py @@ -1,3 +1,4 @@ +import copy import numpy as np import os import numpy as np @@ -34,6 +35,26 @@ def test_get_boolean_valid_indices(self): self.rdr.filter_indices([1, 2]) self.assertEqual(self.rdr.valid_indices_as_booleans(), [False, True, True, False]) + def test_equal(self): + row_copy = copy.deepcopy(self.rdr) + self.assertTrue(self.rdr == row_copy) + + # Change something in the trajectory + row_copy.trajectory.x = 20 + self.assertFalse(self.rdr == row_copy) + row_copy.trajectory.x = self.rdr.trajectory.x + self.assertTrue(self.rdr == row_copy) + + # Change a value in the psi array + row_copy.psi_curve[2] = 1.9 + self.assertFalse(self.rdr == row_copy) + row_copy.psi_curve[2] = self.rdr.psi_curve[2] + self.assertTrue(self.rdr == row_copy) + + # None out all all stamps + row_copy.all_stamps = None + self.assertFalse(self.rdr == row_copy) + def test_filter(self): self.assertEqual(self.rdr.valid_indices, [0, 1, 2, 3]) self.assertTrue(np.allclose(self.rdr.valid_times(self.times), [1.0, 2.0, 3.0, 4.0])) @@ -82,11 +103,54 @@ def test_to_from_yaml(self): self.assertEqual(row2.all_stamps.shape[0], 4) self.assertEqual(row2.all_stamps.shape[1], 5) self.assertEqual(row2.all_stamps.shape[2], 5) + self.assertEqual(self.rdr, row2) self.assertIsNotNone(row2.trajectory) self.assertAlmostEqual(row2.trajectory.flux, 1.15) self.assertAlmostEqual(row2.trajectory.lh, 2.3) + def test_from_table_row(self): + test_dict = { + "trajectory_x": [], + "trajectory_y": [], + "trajectory_vx": [], + "trajectory_vy": [], + "obs_count": [], + "flux": [], + "likelihood": [], + "stamp": [], + "all_stamps": [], + "valid_indices": [], + "psi_curve": [], + "phi_curve": [], + "pred_ra": [], + "pred_dec": [], + } + self.rdr.append_to_dict(test_dict, expand_trajectory=True) + + trjB = make_trajectory(0, 1, 2.0, -3.0, 10.0, 21.0, 3) + rowB = ResultRow(trjB, 4) + rowB.append_to_dict(test_dict, expand_trajectory=True) + + # Test that we can extract them from a row. + data = Table(test_dict) + self.assertEqual(self.rdr, ResultRow.from_table_row(data[0], 4)) + self.assertEqual(rowB, ResultRow.from_table_row(data[1], 4)) + + # We fail if no number of times is given. Unless the table has an + # appropriate column with that information. + with self.assertRaises(KeyError): + _ = ResultRow.from_table_row(data[0]) + + test_dict["all_times"] = [self.times, self.times] + data = Table(test_dict) + self.assertEqual(self.rdr, ResultRow.from_table_row(data[0])) + + # Test that we can still extract the data without the stamp or all_stamps columns + del test_dict["stamp"] + del test_dict["all_stamps"] + self.assertIsNotNone(ResultRow.from_table_row(data[0], 4)) + def test_compute_predicted_skypos(self): self.assertIsNone(self.rdr.pred_ra) self.assertIsNone(self.rdr.pred_dec) @@ -368,7 +432,7 @@ def test_to_from_yaml(self): self.assertEqual(len(rs_b.filtered), 1) self.assertEqual(len(rs_b.filtered["test"]), 10 - len(inds)) - def test_to_table(self): + def test_to_from_table(self): """Check that we correctly dump the data to a astropy Table""" rs = ResultList(self.times, track_filtered=True) for i in range(10): @@ -405,6 +469,14 @@ def test_to_table(self): self.assertEqual(table["psi_curve"][i][j], i) self.assertEqual(table["phi_curve"][i][j], 0.01 * i) + # Check that we can extract from the table + rs2 = ResultList.from_table(table, self.times, track_filtered=True) + self.assertEqual(rs, rs2) + + # We cannot reconstruct without a list of times. + with self.assertRaises(KeyError): + _ = ResultList.from_table(table) + # Filter the result list. inds = [1, 2, 5, 6, 7, 8, 9] rs.filter_results(inds, "test") @@ -424,7 +496,30 @@ def test_to_table(self): with self.assertRaises(KeyError): rs.to_table(filtered_label="test2") + def test_to_from_table_file(self): + rs = ResultList(self.times, track_filtered=False) + for i in range(10): + # Flux and likelihood will be auto calculated during set_psi_phi() + trj = make_trajectory(x=i, y=2 * i, vx=100.0 - i, vy=-i, obs_count=self.num_times - i) + row = ResultRow(trj, self.num_times) + row.set_psi_phi(np.array([i] * self.num_times), np.array([0.01 * i] * self.num_times)) + row.stamp = np.ones((10, 10)) + row.all_stamps = None + rs.append_result(row) + + # Test read/write to file. + with tempfile.TemporaryDirectory() as dir_name: + file_path = os.path.join(dir_name, "results.ecsv") + self.assertFalse(Path(file_path).is_file()) + + rs.write_table(file_path) + self.assertTrue(Path(file_path).is_file()) + + rs2 = ResultList.read_table(file_path) + self.assertEqual(rs, rs2) + def test_save_results(self): + """Test the legacy save into a bunch of individual files.""" times = [0.0, 1.0, 2.0] # Fill the ResultList with 3 fake rows. @@ -440,13 +535,13 @@ def test_save_results(self): rs.save_to_files(dir_name, "tmp") # Check the results_ file. - fname = f"{dir_name}/results_tmp.txt" + fname = os.path.join(dir_name, "results_tmp.txt") self.assertTrue(Path(fname).is_file()) data = FileUtils.load_results_file_as_trajectories(fname) self.assertEqual(len(data), 3) # Check the psi_ file. - fname = f"{dir_name}/psi_tmp.txt" + fname = os.path.join(dir_name, "psi_tmp.txt") self.assertTrue(Path(fname).is_file()) data = FileUtils.load_csv_to_list(fname, use_dtype=float) self.assertEqual(len(data), 3) @@ -454,7 +549,7 @@ def test_save_results(self): self.assertEqual(d.tolist(), [0.1, 0.2, 0.3]) # Check the phi_ file. - fname = f"{dir_name}/phi_tmp.txt" + fname = os.path.join(dir_name, "phi_tmp.txt") self.assertTrue(Path(fname).is_file()) data = FileUtils.load_csv_to_list(fname, use_dtype=float) self.assertEqual(len(data), 3) @@ -462,7 +557,7 @@ def test_save_results(self): self.assertEqual(d.tolist(), [1.0, 1.0, 0.5]) # Check the lc_ file. - fname = f"{dir_name}/lc_tmp.txt" + fname = os.path.join(dir_name, "lc_tmp.txt") self.assertTrue(Path(fname).is_file()) data = FileUtils.load_csv_to_list(fname, use_dtype=float) self.assertEqual(len(data), 3) @@ -470,7 +565,7 @@ def test_save_results(self): self.assertEqual(d.tolist(), [0.1, 0.2, 0.6]) # Check the lc__index_ file. - fname = f"{dir_name}/lc_index_tmp.txt" + fname = os.path.join(dir_name, "lc_index_tmp.txt") self.assertTrue(Path(fname).is_file()) data = FileUtils.load_csv_to_list(fname, use_dtype=int) self.assertEqual(len(data), 3) @@ -479,7 +574,7 @@ def test_save_results(self): self.assertEqual(data[2].tolist(), [0, 1, 2]) # Check the times_ file. - fname = f"{dir_name}/times_tmp.txt" + fname = os.path.join(dir_name, "times_tmp.txt") self.assertTrue(Path(fname).is_file()) data = FileUtils.load_csv_to_list(fname, use_dtype=float) self.assertEqual(len(data), 3) @@ -488,9 +583,9 @@ def test_save_results(self): self.assertEqual(data[2].tolist(), [0.0, 1.0, 2.0]) # Check that the other files exist. - self.assertTrue(Path(f"{dir_name}/filtered_likes_tmp.txt").is_file()) - self.assertTrue(Path(f"{dir_name}/ps_tmp.txt").is_file()) - self.assertTrue(Path(f"{dir_name}/all_ps_tmp.npy").is_file()) + self.assertTrue(Path(os.path.join(dir_name, "filtered_likes_tmp.txt")).is_file()) + self.assertTrue(Path(os.path.join(dir_name, "ps_tmp.txt")).is_file()) + self.assertTrue(Path(os.path.join(dir_name, "all_ps_tmp.npy")).is_file()) def test_save_and_load_results(self): times = [0.0, 10.0, 21.0, 30.5] diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index 0e1c2f324..23af0d806 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -2,6 +2,7 @@ from astropy.table import Table from astropy.wcs import WCS import numpy as np +import os from pathlib import Path import tempfile import unittest @@ -147,7 +148,7 @@ def test_create_from_dict(self): def test_save_and_load_fits(self): with tempfile.TemporaryDirectory() as dir_name: - file_path = f"{dir_name}/test_workunit.fits" + file_path = os.path.join(dir_name, "test_workunit.fits") self.assertFalse(Path(file_path).is_file()) # Unable to load non-existent file. From 43366471c266e2bdbeb656498e8bb378aa7dfa8c Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 19 Jan 2024 13:11:02 -0500 Subject: [PATCH 2/2] Update config documentation --- docs/source/user_manual/search_params.rst | 10 +++++++++- src/kbmod/configuration.py | 2 +- src/kbmod/run_search.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/source/user_manual/search_params.rst b/docs/source/user_manual/search_params.rst index 549d84fcd..32efcce6c 100644 --- a/docs/source/user_manual/search_params.rst +++ b/docs/source/user_manual/search_params.rst @@ -72,6 +72,9 @@ This document serves to provide a quick overview of the existing parameters and | | | directory with multiple FITS files | | | | (one for each exposure). | +------------------------+-----------------------------+----------------------------------------+ +| ``ind_output_files`` | True | Output results to a series of | +| | | individual files (legacy format) | ++------------------------+-----------------------------+----------------------------------------+ | ``known_obj_obs`` | 3 | The minimum number of observations | | | | needed to count a known object match. | +------------------------+-----------------------------+----------------------------------------+ @@ -145,7 +148,12 @@ This document serves to provide a quick overview of the existing parameters and | | | mask. See :ref:`Masking`. | +------------------------+-----------------------------+----------------------------------------+ | ``res_filepath`` | None | The path of the directory in which to | -| | | store the results files. | +| | | store the individual results files. | ++------------------------+-----------------------------+----------------------------------------+ +| ``result_filename`` | None | Full filename and path for a single | +| | | tabular result saves as ecsv. | +| | | Can be use used in addition to | +| | | outputting individual result files. | +------------------------+-----------------------------+----------------------------------------+ | ``sigmaG_lims`` | [25, 75] | The percentiles to use in sigmaG | | | | filtering, if | diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index 9b1323388..c6fdeeba4 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -51,7 +51,7 @@ def __init__(self): "encode_num_bytes": -1, "flag_keys": default_flag_keys, "gpu_filter": False, - "individual_output_files": True, + "ind_output_files": True, "im_filepath": None, "known_obj_obs": 3, "known_obj_thresh": None, diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 8ec738ed0..0065ccae3 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -191,7 +191,7 @@ def run_search(self, config, stack): # Save the results and the configuration information used. print(f"Found {keep.num_results()} potential trajectories.") - if config["res_filepath"] is not None and config["individual_output_files"]: + if config["res_filepath"] is not None and config["ind_output_files"]: keep.save_to_files(config["res_filepath"], config["output_suffix"]) config_filename = os.path.join(config["res_filepath"], f"config_{config['output_suffix']}.yml")