Skip to content

Commit

Permalink
Merge pull request #492 from dirac-institute/analyze_attributes
Browse files Browse the repository at this point in the history
Create helper functions for analyzing results
  • Loading branch information
jeremykubica authored Feb 26, 2024
2 parents cb36d38 + a81e090 commit 1e8f893
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 7 deletions.
69 changes: 65 additions & 4 deletions notebooks/kbmod_analysis_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"outputs": [],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from kbmod.analysis.plot_results import *\n",
Expand Down Expand Up @@ -148,6 +149,66 @@
"ResultsVisualizer.plot_result_row(row0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Histograms of Results\n",
"\n",
"We can plot the histograms of individual attributes of the results. First we need to load a result set with more than one entry. Then we plot the histogram of number of observations in the trajectory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results2 = load_result_list_from_files(\"../data/fake_results_noisy/\", \"DEMO\")\n",
"print(f\"Loaded {results2.num_results()} results.\")\n",
"\n",
"counts = results2.get_result_values(\"obs_count\")\n",
"_ = plt.hist(counts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ResultsVisualizer` class provides a helper function to plotting a 2-d histogram of the trajectory's starting positions on the image. This requires knowing the image size. The demo data consists of 256 by 256 images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ResultsVisualizer.plot_starting_pixel_histogram(results2, 256, 256)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or we could manually create a histogram of the pixel velocities."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"xv_vals = results2.get_result_values(\"trajectory.vx\")\n",
"yv_vals = results2.get_result_values(\"trajectory.vy\")\n",
"_, _, _, img = ax.hist2d(xv_vals, yv_vals, bins=10)\n",
"fig.colorbar(img, ax=ax)\n",
"_ = ax.set_xlabel(\"vx (pixels per day)\")\n",
"_ = ax.set_ylabel(\"vy (pixels per day)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -159,9 +220,9 @@
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Jeremy's KBMOD",
"language": "python",
"name": "python3"
"name": "kbmod_jk"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -173,9 +234,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
23 changes: 21 additions & 2 deletions notebooks/kbmod_results_and_filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,25 @@
"results.sort(key=\"final_likelihood\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extracting Individual Attributes\n",
"\n",
"The `ResultList`class provides a helper function `get_result_values()` that allows a user to extract all of the values for a given attribute of the results. For example we could extract all of the flux values and create a histogram."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"flux_values = results.get_result_values(\"trajectory.flux\")\n",
"plt.hist(flux_values)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -300,9 +319,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
22 changes: 21 additions & 1 deletion src/kbmod/analysis/plot_results.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -156,3 +155,24 @@ def plot_result_row(row, times=None, title=None, fig=None):
else:
ax = fig_bot.add_axes([0, 0, 1, 1])
ax.text(0.5, 0.5, "No Individual Stamps")

@staticmethod
def plot_starting_pixel_histogram(results, height, width):
"""Plot a histogram of the starting pixels of each found trajectory.
Parameters
----------
results : `ResultList`
The results to analyze.
height : `int`
The image height in pixels
width : `int`
The image width in pixels
"""
fig, ax = plt.subplots()

x_vals = results.get_result_values("trajectory.x")
y_vals = results.get_result_values("trajectory.y")
_, _, _, img = ax.hist2d(x_vals, y_vals, bins=[height, width])
fig.colorbar(img, ax=ax)
ax.set_title("Histogram of Starting Pixel")
34 changes: 34 additions & 0 deletions src/kbmod/result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,40 @@ def sort(self, key="final_likelihood", reverse=True):
self.results.sort(key=lambda x: getattr(x, key), reverse=reverse)
return self

def get_result_values(self, attribute):
"""Return the values of the ResultRows' attribute as a list.
Subattributes can be extracted as "attribute.subattribute",
such as "trajectory.x"
Examples:
trj_list = res.get_trajectory_values("trajectory")
x_values = res.get_trajectory_values("trajectory.x")
stamps = res.get_trajectory_values("stamp")
Parameter
---------
attribute : `str`
The name of the attribute to extract.
Returns
-------
values : `list`
A list of the results' values.
Raises
------
Raises an ``AttributeError`` if the attribute does not exist.
"""
att_list = attribute.split(".")

values = []
for row in self.results:
obj = row
for att in att_list:
obj = getattr(obj, att)
values.append(obj)
return values

def compute_predicted_skypos(self, wcs):
"""Compute the predict sky position for each result's trajectory
at each time step.
Expand Down
36 changes: 36 additions & 0 deletions tests/test_result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,42 @@ def test_sort(self):
for i, val in enumerate(expected_order):
self.assertEqual(rs.results[i].trajectory.x, val)

def test_get_result_values(self):
rs = ResultList(self.times)
rs.append_result(ResultRow(make_trajectory(x=0, lh=1.0, obs_count=1), self.num_times))
rs.append_result(ResultRow(make_trajectory(x=1, lh=-1.0, obs_count=2), self.num_times))
rs.append_result(ResultRow(make_trajectory(x=2, lh=5.0, obs_count=3), self.num_times))
rs.append_result(ResultRow(make_trajectory(x=3, lh=4.0, obs_count=5), self.num_times))
rs.append_result(ResultRow(make_trajectory(x=4, lh=6.0, obs_count=4), self.num_times))

# Test getting a list of trajectories.
trjs = rs.get_result_values("trajectory")
self.assertEqual(len(trjs), 5)
for i in range(5):
self.assertTrue(type(trjs[i]) is Trajectory)

# Stamps should all be None
stamps = rs.get_result_values("stamp")
self.assertEqual(len(stamps), 5)
for i in range(5):
self.assertTrue(stamps[i] is None)

# We can extract sub-attributes
x_vals = rs.get_result_values("trajectory.x")
self.assertEqual(len(x_vals), 5)
for i in range(5):
self.assertEqual(x_vals[i], i)

vx_vals = rs.get_result_values("trajectory.vx")
self.assertEqual(len(vx_vals), 5)
for i in range(5):
self.assertEqual(vx_vals[i], 0.0)

# We get an error if we try to extract an attribute that doesn't exist.
self.assertRaises(AttributeError, rs.get_result_values, "")
self.assertRaises(AttributeError, rs.get_result_values, "Not There")
self.assertRaises(AttributeError, rs.get_result_values, "trajectory.z")

def test_filter(self):
rs = ResultList(self.times)
for i in range(10):
Expand Down

0 comments on commit 1e8f893

Please sign in to comment.