Skip to content

Commit

Permalink
Small bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Nov 30, 2023
1 parent 561980a commit 1704dd1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions src/kbmod/result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,13 @@ def from_yaml(cls, yaml_str):
The serialized string.
"""
yaml_dict = safe_load(yaml_str)
result = ResultList(yaml_dict["all_times"], yaml_dict["track_filtered"])
result.results = [ResultRow.from_yaml(row) for row in yaml_dict["results"]]
result_list = ResultList(yaml_dict["all_times"], yaml_dict["track_filtered"])
result_list.results = [ResultRow.from_yaml(row) for row in yaml_dict["results"]]

if result.track_filtered:
if result_list.track_filtered:
for key in yaml_dict["filtered"]:
results.filtered[key] = [ResultRow.from_yaml(row) for row in yaml_dict["filtered"][key]]
return result
result_list.filtered[key] = [ResultRow.from_yaml(row) for row in yaml_dict["filtered"][key]]
return result_list

def num_results(self):
"""Return the number of results in the list.
Expand Down Expand Up @@ -467,7 +467,7 @@ def to_yaml(self, serialize_filtered=False):
if serialize_filtered and self.track_filtered:
yaml_dict["track_filtered"] = True
for key in self.filtered:
yaml_dict["filtered"][key] = [row.to_yaml() for row in self.results]
yaml_dict["filtered"][key] = [row.to_yaml() for row in self.filtered[key]]

return dump(yaml_dict)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ def test_to_from_yaml(self):
rs_a = ResultList.from_yaml(yaml_str_a)
self.assertEqual(len(rs_a.results), len(inds))
for i in range(len(inds)):
self.assertAlmostEqual(rs_a.results[i].psi_curve[0], i)
self.assertAlmostEqual(rs_a.results[i].phi_curve[0], 0.01 * i)
self.assertAlmostEqual(rs_a.results[i].psi_curve[0], inds[i])
self.assertAlmostEqual(rs_a.results[i].phi_curve[0], 0.01 * inds[i])
self.assertFalse(rs_a.track_filtered)
self.assertEqual(len(rs_a.filtered), 0)

Expand All @@ -248,8 +248,8 @@ def test_to_from_yaml(self):
rs_b = ResultList.from_yaml(yaml_str_b)
self.assertEqual(len(rs_b.results), len(inds))
for i in range(len(inds)):
self.assertAlmostEqual(rs_b.results[i].psi_curve[0], i)
self.assertAlmostEqual(rs_b.results[i].phi_curve[0], 0.01 * i)
self.assertAlmostEqual(rs_b.results[i].psi_curve[0], inds[i])
self.assertAlmostEqual(rs_b.results[i].phi_curve[0], 0.01 * inds[i])
self.assertTrue(rs_b.track_filtered)
self.assertEqual(len(rs_b.filtered), 1)
self.assertEqual(len(rs_b.filtered["test"]), 10 - len(inds))
Expand Down

0 comments on commit 1704dd1

Please sign in to comment.