-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Re-factored the run browser into its own sub-package within Firefly.
- Loading branch information
1 parent
6af6fe8
commit 7e53ad2
Showing
11 changed files
with
384 additions
and
983 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ asyncio_mode = auto | |
testpaths = | ||
src/haven/tests | ||
src/firefly/tests | ||
src/firefly/run_browser/tests |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,27 +3,21 @@ | |
from collections import Counter | ||
from contextlib import contextmanager | ||
from functools import wraps | ||
from itertools import count | ||
from typing import Mapping, Optional, Sequence | ||
|
||
import numpy as np | ||
import qtawesome as qta | ||
import yaml | ||
from matplotlib.colors import TABLEAU_COLORS | ||
from pandas.api.types import is_numeric_dtype | ||
from pyqtgraph import GraphicsLayoutWidget, ImageView, PlotItem, PlotWidget | ||
from qasync import asyncSlot | ||
from qtpy.QtCore import Qt, Signal | ||
from qtpy.QtGui import QStandardItem, QStandardItemModel | ||
from qtpy.QtWidgets import QFileDialog, QWidget | ||
from qtpy.QtWidgets import QWidget | ||
|
||
from firefly import display | ||
from firefly.run_client import DatabaseWorker | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
from .client import DatabaseWorker | ||
from .widgets import ExportDialog | ||
|
||
colors = list(TABLEAU_COLORS.values()) | ||
log = logging.getLogger(__name__) | ||
|
||
|
||
def cancellable(fn): | ||
|
@@ -37,220 +31,6 @@ async def inner(*args, **kwargs): | |
return inner | ||
|
||
|
||
class ExportDialog(QFileDialog): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.setFileMode(QFileDialog.FileMode.AnyFile) | ||
self.setAcceptMode(QFileDialog.AcceptSave) | ||
|
||
def ask(self, mimetypes: Optional[Sequence[str]] = None): | ||
"""Get the name of the file to save for exporting.""" | ||
self.setMimeTypeFilters(mimetypes) | ||
# Show the file dialog | ||
if self.exec_() == QFileDialog.Accepted: | ||
return self.selectedFiles() | ||
else: | ||
return None | ||
|
||
|
||
class FiltersWidget(QWidget): | ||
returnPressed = Signal() | ||
|
||
def keyPressEvent(self, event): | ||
super().keyPressEvent(event) | ||
# Check for return keys pressed | ||
if event.key() in [Qt.Key_Enter, Qt.Key_Return]: | ||
self.returnPressed.emit() | ||
|
||
|
||
class Browser1DPlotItem(PlotItem): | ||
hover_coords_changed = Signal(str) | ||
|
||
def hoverEvent(self, event): | ||
super().hoverEvent(event) | ||
if event.isExit(): | ||
self.hover_coords_changed.emit("NaN") | ||
return | ||
# Get data coordinates from event | ||
pos = event.scenePos() | ||
data_pos = self.vb.mapSceneToView(pos) | ||
pos_str = f"({data_pos.x():.3f}, {data_pos.y():.3f})" | ||
self.hover_coords_changed.emit(pos_str) | ||
|
||
|
||
class BrowserMultiPlotWidget(GraphicsLayoutWidget): | ||
_multiplot_items: Mapping | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self._multiplot_items = {} | ||
|
||
def multiplot_items(self, n_cols: int = 3): | ||
view = self | ||
item0 = None | ||
for idx in count(): | ||
row = int(idx / n_cols) | ||
col = idx % n_cols | ||
# Make a new plot item if one doesn't exist | ||
if (row, col) not in self._multiplot_items: | ||
self._multiplot_items[(row, col)] = view.addPlot(row=row, col=col) | ||
new_item = self._multiplot_items[(row, col)] | ||
# Link the X-axes together | ||
if item0 is None: | ||
item0 = new_item | ||
else: | ||
new_item.setXLink(item0) | ||
# Resize the viewing area to fit the contents | ||
width = view.width() | ||
plot_width = width / n_cols | ||
# view.resize(int(width), int(plot_width * row)) | ||
view.setFixedHeight(1200) | ||
yield new_item | ||
|
||
def plot_runs(self, runs: Mapping, xsignal: str): | ||
"""Take loaded run data and plot small multiples. | ||
Parameters | ||
========== | ||
runs | ||
Dictionary with pandas series for each curve. The keys | ||
should be the curve labels, the series' indexes are the x | ||
values and the series' values are the y data. | ||
xsignal | ||
The name of the signal to use for the common horizontal | ||
axis. | ||
""" | ||
# Use all the data columns as y signals | ||
ysignals = [] | ||
for run in runs.values(): | ||
ysignals.extend(run.columns) | ||
# Remove the x-signal from the list of y signals | ||
ysignals = sorted(list(dict.fromkeys(ysignals))) | ||
# Plot the runs | ||
self.clear() | ||
self._multiplot_items = {} | ||
for label, data in runs.items(): | ||
# Figure out which signals to plot | ||
try: | ||
xdata = data[xsignal] | ||
except KeyError: | ||
log.warning(f"Cannot plot x='{xsignal}' for {list(data.keys())}") | ||
continue | ||
# Plot each y signal on a separate plot | ||
for ysignal, plot_item in zip(ysignals, self.multiplot_items()): | ||
try: | ||
if is_numeric_dtype(data[ysignal]): | ||
plot_item.plot(xdata, data[ysignal]) | ||
except KeyError: | ||
log.warning(f"No signal {ysignal} in data.") | ||
else: | ||
log.debug(f"Plotted {ysignal} vs. {xsignal} for {data}") | ||
plot_item.setTitle(ysignal) | ||
|
||
|
||
class Browser1DPlotWidget(PlotWidget): | ||
cursor_needed: bool | ||
data_items: dict | ||
|
||
def __init__(self, parent=None, background="default", plotItem=None, **kargs): | ||
plot_item = Browser1DPlotItem(**kargs) | ||
super().__init__(parent=parent, background=background, plotItem=plot_item) | ||
self.clear_runs() | ||
|
||
def clear_runs(self): | ||
self.getPlotItem().clear() | ||
self.cursor_needed = True | ||
self.data_items = {} | ||
|
||
def plot_runs(self, runs: Mapping, ylabel="", xlabel=""): | ||
"""Take loaded run data and plot it. | ||
Parameters | ||
========== | ||
runs | ||
Dictionary with pandas series for each curve. The keys | ||
should be the curve labels, the series' indexes are the x | ||
values and the series' values are the y data. | ||
""" | ||
plot_item = self.getPlotItem() | ||
# Plot this run's data | ||
for idx, (label, series) in enumerate(runs.items()): | ||
color = colors[idx % len(colors)] | ||
if label in self.data_items.keys(): | ||
# We've plotted this item before, so reuse it | ||
data_item = self.data_items[label] | ||
data_item.setData(series.index, series.values) | ||
else: | ||
self.data_items[label] = plot_item.plot( | ||
x=series.index, | ||
y=series.values, | ||
pen=color, | ||
name=label, | ||
clear=False, | ||
) | ||
# Cursor to drag around on the data | ||
if self.cursor_needed: | ||
plot_item.addLine( | ||
x=np.median(series.index), movable=True, label="{value:.3f}" | ||
) | ||
self.cursor_needed = False | ||
# Axis formatting | ||
plot_item.setLabels(left=ylabel, bottom=xlabel) | ||
|
||
|
||
class Browser2DPlotWidget(ImageView): | ||
"""A plot widget for 2D maps.""" | ||
|
||
def __init__(self, *args, view=None, **kwargs): | ||
if view is None: | ||
view = PlotItem() | ||
super().__init__(*args, view=view, **kwargs) | ||
|
||
def plot_runs( | ||
self, runs: Mapping, xlabel: str = "", ylabel: str = "", extents=None | ||
): | ||
"""Take loaded 2D or 3D mapping data and plot it. | ||
Parameters | ||
========== | ||
runs | ||
Dictionary with pandas series for each curve. The keys | ||
should be the curve labels, the series' indexes are the x | ||
values and the series' values are the y data. | ||
xlabel | ||
The label for the horizontal axis. | ||
ylabel | ||
The label for the vertical axis. | ||
extents | ||
Spatial extents for the map as ((-y, +y), (-x, +x)). | ||
""" | ||
images = np.asarray(list(runs.values())) | ||
# Combine the different runs into one image | ||
# To-do: make this respond to the combobox selection | ||
image = np.mean(images, axis=0) | ||
# To-do: Apply transformations | ||
|
||
# # Plot the image | ||
if 2 <= image.ndim <= 3: | ||
self.setImage(image.T, autoRange=False) | ||
else: | ||
log.info(f"Could not plot image of dataset with shape {image.shape}.") | ||
return | ||
# Determine the axes labels | ||
self.view.setLabel(axis="bottom", text=xlabel) | ||
self.view.setLabel(axis="left", text=ylabel) | ||
# Set axes extent | ||
yextent, xextent = extents | ||
x = xextent[0] | ||
y = yextent[0] | ||
w = xextent[1] - xextent[0] | ||
h = yextent[1] - yextent[0] | ||
self.getImageItem().setRect(x, y, w, h) | ||
|
||
|
||
class RunBrowserDisplay(display.FireflyDisplay): | ||
runs_model: QStandardItemModel | ||
_run_col_names: Sequence = [ | ||
|
@@ -651,6 +431,7 @@ async def update_metadata(self, *args): | |
text += yaml.dump(md) | ||
text += f"\n\n{'=' * 20}\n\n" | ||
# Update the widget with the rendered metadata | ||
print(text) | ||
self.ui.metadata_textedit.document().setPlainText(text) | ||
|
||
def clear_plots(self): | ||
|
@@ -659,39 +440,36 @@ def clear_plots(self): | |
If a *uid* is provided, only the plots matching the scan with | ||
*uid* will be updated. | ||
""" | ||
self.plot_1d_view.clear_plots() | ||
self.plot_1d_view.clear_runs() | ||
|
||
@asyncSlot() | ||
@cancellable | ||
async def update_plots(self, uid: str = ""): | ||
async def update_plots(self): | ||
"""Get new data, and update all the plots. | ||
If a *uid* is provided, only the plots matching the scan with | ||
*uid* will be updated. | ||
""" | ||
|
||
asyncio.gather( | ||
await asyncio.gather( | ||
self.update_metadata(), | ||
self.update_1d_plot(uid=uid), | ||
self.update_1d_plot(), | ||
self.update_2d_plot(), | ||
self.update_multi_plot(), | ||
) | ||
|
||
@asyncSlot() | ||
@cancellable | ||
async def update_selected_runs(self, uid=None, *args): | ||
|
||
"""Get the current runs from the database and stash them. | ||
""" | ||
async def update_selected_runs(self, *args): | ||
"""Get the current runs from the database and stash them.""" | ||
# Get UID's from the selection | ||
col_idx = self._run_col_names.index("UID") | ||
indexes = self.ui.run_tableview.selectedIndexes() | ||
uids = [i.siblingAtColumn(col_idx).data() for i in indexes] | ||
# Get selected runs from the database | ||
with self.busy_hints(run_widgets=True, run_table=False, filter_widgets=False): | ||
task = self.db_task( | ||
self.db.load_selected_runs(uids), "update selected runs" | ||
self.db.load_selected_runs(uids=uids), "update selected runs" | ||
) | ||
self.selected_runs = await task | ||
# Update the necessary UI elements | ||
|
@@ -731,30 +509,4 @@ def load_models(self): | |
self.ui.run_tableview.setModel(self.runs_model) | ||
|
||
def ui_filename(self): | ||
return "run_browser.ui" | ||
|
||
|
||
# ----------------------------------------------------------------------------- | ||
# :author: Mark Wolfman | ||
# :email: [email protected] | ||
# :copyright: Copyright © 2023, UChicago Argonne, LLC | ||
# | ||
# Distributed under the terms of the 3-Clause BSD License | ||
# | ||
# The full license is in the file LICENSE, distributed with this software. | ||
# | ||
# DISCLAIMER | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | ||
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | ||
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | ||
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | ||
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | ||
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | ||
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | ||
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | ||
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
# | ||
# ----------------------------------------------------------------------------- | ||
return "run_browser/run_browser.ui" |
Oops, something went wrong.