Skip to content

Commit

Permalink
Initial implemenation of Trace class, switch to using it
Browse files Browse the repository at this point in the history
Adds a Trace class which holds pairs of points for a trace using
a deque, which allows for fast appending at either end.

Also switches trace_lines() to use `Trace`s rather than the lists
it was previously using. A slight change from the previous behavior
is that the collection(s) of points returned from the traces are
arranged entirely in ascending order by y-coordinate, however,
given how they are used this shouldn't actually change anything
(besides making the output slightly easier to read).
  • Loading branch information
DBerke committed Apr 17, 2024
1 parent 485aee4 commit 12d9de1
Showing 1 changed file with 71 additions and 7 deletions.
78 changes: 71 additions & 7 deletions gempy/library/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Classes in this module:
Aperture: defines an aperture on an image
Trace: defines a trace on an image
Functions in this module:
average_along_slit: collapse a 2D spectral image in the wavelength direction
Expand All @@ -23,6 +24,7 @@
trace_lines: trace lines from a set of supplied starting positions
"""
from collections import deque
import warnings

import numpy as np
Expand Down Expand Up @@ -371,6 +373,68 @@ def extract(self, ext, width=None, aper_lower=None, aper_upper=None,
return ndd


class Trace:
"""A class describing a trace along columns. It has the following attributes:
starting_point: len-2 iterable
The starting point of the trace on the array, in (y, x) format.
trace: collections.deque
A deque holding points (len-2 iterables) found for the trace.
top_limit: float
The highest y-value that the trace has reached.
bottom_limit: float
The lowest y-value that the trace have reached.
"""
def __init__(self, starting_point):
self.starting_point = self._verify_point(starting_point)
self.trace = deque([self.starting_point])

# Set the initial limits of the trace.
self.top_limit = self.bottom_limit = self.starting_point[0]

def _as_list(self):
return list(self.trace)

def __iter__(self):
return iter(self.trace)

def __repr__(self):
return self.trace

def _verify_point(self, point):
"""Return a tuple from a len-2 iterable"""
if len(point) != 2:
raise RuntimeError(f"Point {point} should have 2 values, not "
f"len(point)")
if isinstance(point, tuple):
return point
else:
try:
return(tuple(point))
except:
raise RuntimeError(f"Something went wrong with point {point}")

def add_point(self, point):
"""Add a point to the deque, at either end as appropriate"""
point = self._verify_point(point)
y = point[0]

if y > self.top_limit:
self.trace.append(point)
self.top = y
elif y < self.bottom_limit:
self.trace.appendleft(point)
self.bottom = y
else:
# Should only add points at ends of range
raise RuntimeError("Trying to insert point in middle of trace,"
f"{point}, top: {self.top_limit}, "
f"bottom: {self.bottom_limit}")

def predict_location(self):
pass


###############################################################################
# FUNCTIONS RELATED TO PEAK-FINDING
@insert_descriptor_values("dispersion_axis")
Expand Down Expand Up @@ -1526,7 +1590,7 @@ def _slice(center):
if np.bincount((ext_mask[s] & DQ.not_signal).min(axis=1))[0] <= 1:
del step_centers[i]

coord_lists = [[(start, peak)] for peak in initial_peaks]
traces = [Trace((start, peak)) for peak in initial_peaks]
for direction in (1, -1):
ypos = start
last_coords = [[ypos, peak] for peak in initial_peaks]
Expand Down Expand Up @@ -1604,7 +1668,6 @@ def _slice(center):
# If it's gone for good, set the coord to NaN to avoid it
# picking up a different line if there's significant tilt
if steps_missed >= max_missed:
#coord_lists[i].append([ypos, np.nan])
last_coords[i] = [ypos, np.nan]
continue

Expand All @@ -1621,7 +1684,7 @@ def _slice(center):
else reversed(new_coord))))
viewer.line(origin=0, **kwargs)

coord_lists[i].append(new_coord)
traces[i].add_point(new_coord)
last_coords[i] = new_coord.copy()
try:
missing_but_not_lost = direction * min(
Expand All @@ -1637,14 +1700,15 @@ def _slice(center):
def keep_line(line, min_length):
positions = [element[0] for element in line]
return (max(positions) - min(positions)) > min_length * ext_data.shape[0]
final_coord_lists = [line for line in coord_lists if keep_line(line, min_line_length)]
final_peaks = [cl[0][1] for cl in final_coord_lists]
final_coord_lists = [line for line in traces if keep_line(line, min_line_length)]
final_peaks = [trace.starting_point[1] for trace in traces]

# List of traced peak positions
in_coords = np.array([c for coo in final_coord_lists for c in coo]).T
in_coords = np.array([c for coo in traces for c in coo]).T
# List of "reference" positions (i.e., the coordinate perpendicular to
# the line remains constant at its initial value
ref_coords = np.array([(ypos, ref) for coo, ref in zip(final_coord_lists, final_peaks) for (ypos, xpos) in coo]).T
ref_coords = np.array([(ypos, ref) for coo, ref in zip(traces, final_peaks)
for (ypos, xpos) in coo]).T

# Return the coordinate lists, in the form (x-coords, y-coords),
# regardless of the dispersion axis
Expand Down

0 comments on commit 12d9de1

Please sign in to comment.