Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create helper functions for analyzing results #492

Merged
merged 4 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 65 additions & 4 deletions notebooks/kbmod_analysis_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"outputs": [],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from kbmod.analysis.plot_results import *\n",
Expand Down Expand Up @@ -148,6 +149,66 @@
"ResultsVisualizer.plot_result_row(row0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Histograms of Results\n",
"\n",
"We can plot the histograms of individual attributes of the results. First we need to load a result set with more than one entry. Then we plot the histogram of number of observations in the trajectory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results2 = load_result_list_from_files(\"../data/fake_results_noisy/\", \"DEMO\")\n",
"print(f\"Loaded {results2.num_results()} results.\")\n",
"\n",
"counts = results2.get_result_values(\"obs_count\")\n",
"_ = plt.hist(counts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ResultsVisualizer` class provides a helper function to plotting a 2-d histogram of the trajectory's starting positions on the image. This requires knowing the image size. The demo data consists of 256 by 256 images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ResultsVisualizer.plot_starting_pixel_histogram(results2, 256, 256)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or we could manually create a histogram of the pixel velocities."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"xv_vals = results2.get_result_values(\"trajectory.vx\")\n",
"yv_vals = results2.get_result_values(\"trajectory.vy\")\n",
"_, _, _, img = ax.hist2d(xv_vals, yv_vals, bins=10)\n",
"fig.colorbar(img, ax=ax)\n",
"_ = ax.set_xlabel(\"vx (pixels per day)\")\n",
"_ = ax.set_ylabel(\"vy (pixels per day)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -159,9 +220,9 @@
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Jeremy's KBMOD",
"language": "python",
"name": "python3"
"name": "kbmod_jk"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -173,9 +234,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
23 changes: 21 additions & 2 deletions notebooks/kbmod_results_and_filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,25 @@
"results.sort(key=\"final_likelihood\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extracting Individual Attributes\n",
"\n",
"The `ResultList`class provides a helper function `get_result_values()` that allows a user to extract all of the values for a given attribute of the results. For example we could extract all of the flux values and create a histogram."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"flux_values = results.get_result_values(\"trajectory.flux\")\n",
"plt.hist(flux_values)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -300,9 +319,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
22 changes: 21 additions & 1 deletion src/kbmod/analysis/plot_results.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -156,3 +155,24 @@ def plot_result_row(row, times=None, title=None, fig=None):
else:
ax = fig_bot.add_axes([0, 0, 1, 1])
ax.text(0.5, 0.5, "No Individual Stamps")

@staticmethod
def plot_starting_pixel_histogram(results, height, width):
"""Plot a histogram of the starting pixels of each found trajectory.

Parameters
----------
results : `ResultList`
The results to analyze.
height : `int`
The image height in pixels
width : `int`
The image width in pixels
"""
fig, ax = plt.subplots()

x_vals = results.get_result_values("trajectory.x")
y_vals = results.get_result_values("trajectory.y")
_, _, _, img = ax.hist2d(x_vals, y_vals, bins=[height, width])
fig.colorbar(img, ax=ax)
ax.set_title("Histogram of Starting Pixel")
34 changes: 34 additions & 0 deletions src/kbmod/result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,40 @@ def sort(self, key="final_likelihood", reverse=True):
self.results.sort(key=lambda x: getattr(x, key), reverse=reverse)
return self

def get_result_values(self, attribute):
"""Return the values of the ResultRows' attribute as a list.
Subattributes can be extracted as "attribute.subattribute",
such as "trajectory.x"

Examples:
trj_list = res.get_trajectory_values("trajectory")
x_values = res.get_trajectory_values("trajectory.x")
stamps = res.get_trajectory_values("stamp")

Parameter
---------
attribute : `str`
The name of the attribute to extract.

Returns
-------
values : `list`
A list of the results' values.

Raises
------
Raises an ``AttributeError`` if the attribute does not exist.
"""
att_list = attribute.split(".")

values = []
for row in self.results:
obj = row
for att in att_list:
obj = getattr(obj, att)
values.append(obj)
return values

def compute_predicted_skypos(self, wcs):
"""Compute the predict sky position for each result's trajectory
at each time step.
Expand Down
36 changes: 36 additions & 0 deletions tests/test_result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,42 @@ def test_sort(self):
for i, val in enumerate(expected_order):
self.assertEqual(rs.results[i].trajectory.x, val)

def test_get_result_values(self):
rs = ResultList(self.times)
rs.append_result(ResultRow(make_trajectory(x=0, lh=1.0, obs_count=1), self.num_times))
rs.append_result(ResultRow(make_trajectory(x=1, lh=-1.0, obs_count=2), self.num_times))
rs.append_result(ResultRow(make_trajectory(x=2, lh=5.0, obs_count=3), self.num_times))
rs.append_result(ResultRow(make_trajectory(x=3, lh=4.0, obs_count=5), self.num_times))
rs.append_result(ResultRow(make_trajectory(x=4, lh=6.0, obs_count=4), self.num_times))

# Test getting a list of trajectories.
trjs = rs.get_result_values("trajectory")
self.assertEqual(len(trjs), 5)
for i in range(5):
self.assertTrue(type(trjs[i]) is Trajectory)

# Stamps should all be None
stamps = rs.get_result_values("stamp")
self.assertEqual(len(stamps), 5)
for i in range(5):
self.assertTrue(stamps[i] is None)

# We can extract sub-attributes
x_vals = rs.get_result_values("trajectory.x")
self.assertEqual(len(x_vals), 5)
for i in range(5):
self.assertEqual(x_vals[i], i)

vx_vals = rs.get_result_values("trajectory.vx")
self.assertEqual(len(vx_vals), 5)
for i in range(5):
self.assertEqual(vx_vals[i], 0.0)

# We get an error if we try to extract an attribute that doesn't exist.
self.assertRaises(AttributeError, rs.get_result_values, "")
self.assertRaises(AttributeError, rs.get_result_values, "Not There")
self.assertRaises(AttributeError, rs.get_result_values, "trajectory.z")

def test_filter(self):
rs = ResultList(self.times)
for i in range(10):
Expand Down
Loading