Skip to content

Commit

Permalink
Fixed tests and linting.
Browse files Browse the repository at this point in the history
  • Loading branch information
canismarko committed Jan 28, 2025
1 parent b769cc6 commit 44c6219
Show file tree
Hide file tree
Showing 17 changed files with 93 additions and 260 deletions.
10 changes: 3 additions & 7 deletions src/firefly/run_browser/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime as dt
import logging
import warnings
from collections import OrderedDict, ChainMap
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Mapping, Sequence

Expand Down Expand Up @@ -162,7 +162,7 @@ async def load_all_runs(self, filters: Mapping = {}):
all_runs.append(run_data)
return all_runs

async def hints(self, stream: str="primary") -> tuple[list, list]:
async def hints(self, stream: str = "primary") -> tuple[list, list]:
"""Get hints for this stream, as two lists.
(*independent_hints*, *dependent_hints*)
Expand All @@ -178,10 +178,7 @@ async def hints(self, stream: str="primary") -> tuple[list, list]:
dhints = [hint for hints in dhints for hint in hints]
return ihints, dhints



async def signal_names(self, stream: str, *, hinted_only: bool = False):

"""Get a list of valid signal names (data columns) for selected runs.
Parameters
Expand Down Expand Up @@ -258,7 +255,7 @@ async def all_signals(self, stream: str, *, hinted_only=False) -> dict:

async def dataset(
self,
dataset_name: str,
dataset_name: str,
*,
stream: str,
uids: Sequence[str] | None = None,
Expand All @@ -284,7 +281,6 @@ async def dataset(
arrays[run.uid] = arr
return arrays


async def signals(
self,
x_signal,
Expand Down
27 changes: 18 additions & 9 deletions src/firefly/run_browser/display.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import asyncio
import datetime as dt
import logging
from collections import Counter, ChainMap
from collections import ChainMap, Counter
from contextlib import contextmanager
from functools import partial, wraps
from typing import Mapping, Optional, Sequence

import qtawesome as qta
import yaml
import numpy as np
import qtawesome as qta
from ophyd import Device as ThreadedDevice
from ophyd_async.core import Device
from pydm import PyDMChannel
Expand Down Expand Up @@ -95,7 +94,9 @@ async def change_catalog(self, catalog_name: str):
)

@asyncSlot(str)
async def retrieve_dataset(self, dataset_name: str, callback, task_name: str) -> np.ndarray:
async def retrieve_dataset(
self, dataset_name: str, callback, task_name: str
) -> np.ndarray:
"""Retrieve a dataset from disk, and provide it to the slot.
Parameters
Expand All @@ -106,9 +107,11 @@ async def retrieve_dataset(self, dataset_name: str, callback, task_name: str) ->
Will be called with the retrieved dataset.
task_name
For handling parallel database tasks.
"""
"""
# Retrieve data from the database
data = await self.db_task(self.db.dataset(dataset_name, stream=self.stream), task_name)
data = await self.db_task(
self.db.dataset(dataset_name, stream=self.stream), task_name
)
# Pass it back to the slot
callback(data)

Expand Down Expand Up @@ -414,7 +417,9 @@ async def update_data_keys(self, *args):
self.db_task(self.db.hints(stream), "update data hints"),
)
independent_hints, dependent_hints = hints
self.data_keys_changed.emit(data_keys, set(independent_hints), set(dependent_hints))
self.data_keys_changed.emit(
data_keys, set(independent_hints), set(dependent_hints)
)

@asyncSlot()
@cancellable
Expand All @@ -425,8 +430,12 @@ async def update_data_frames(self):
assert False
log.info("Not loading data frames for empty stream.")
else:
with self.busy_hints(run_widgets=True, run_table=False, filter_widgets=False):
data_frames = await self.db_task(self.db.data_frames(stream), "update data frames")
with self.busy_hints(
run_widgets=True, run_table=False, filter_widgets=False
):
data_frames = await self.db_task(
self.db.data_frames(stream), "update data frames"
)
self.data_frames_changed.emit(data_frames)

@asyncSlot()
Expand Down
87 changes: 17 additions & 70 deletions src/firefly/run_browser/gridplot_view.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import logging
from itertools import count
from pathlib import Path
from typing import Mapping, Sequence

from scipy.interpolate import griddata
import numpy as np
import pandas as pd
import yaml
from matplotlib.colors import TABLEAU_COLORS
from pandas.api.types import is_numeric_dtype
from pyqtgraph import GraphicsLayoutWidget, ImageView, PlotItem, PlotWidget
import pyqtgraph
import qtawesome as qta
from qtpy import QtCore, QtWidgets, uic
from qtpy.QtCore import Qt, Signal, Slot
from qtpy.QtWidgets import QFileDialog, QWidget
from matplotlib.colors import TABLEAU_COLORS
from pyqtgraph import ImageView, PlotItem
from qtpy import QtWidgets, uic
from qtpy.QtCore import Slot
from scipy.interpolate import griddata

log = logging.getLogger(__name__)
colors = list(TABLEAU_COLORS.values())
Expand All @@ -32,6 +27,7 @@ def __init__(self, *args, view=None, **kwargs):

class GridplotView(QtWidgets.QWidget):
"""Handles the plotting of tabular data that was taken on a grid."""

ui_file = Path(__file__).parent / "gridplot_view.ui"
shape = ()
extent = ()
Expand All @@ -46,7 +42,7 @@ def __init__(self, parent=None):
self.ui = uic.loadUi(self.ui_file, self)
# Prepare plotting style
vbox = self.ui.plot_widget.ui.roiPlot.getPlotItem().getViewBox()
vbox.setBackgroundColor('k')
vbox.setBackgroundColor("k")
# Connect internal signals/slots
self.ui.use_hints_checkbox.stateChanged.connect(self.update_signal_widgets)
self.ui.regrid_checkbox.stateChanged.connect(self.update_signal_widgets)
Expand All @@ -67,8 +63,8 @@ def set_image_dimensions(self, metadata: Sequence):
return
md = list(metadata.values())[0]
try:
self.shape = md['start']['shape']
self.extent = md['start']['extents']
self.shape = md["start"]["shape"]
self.extent = md["start"]["extents"]
except KeyError as exc:
self.shape = ()
self.extent = ()
Expand Down Expand Up @@ -116,7 +112,9 @@ def update_signal_widgets(
self.ui.value_signal_combobox,
self.ui.r_signal_combobox,
]
for combobox, new_cols in zip(comboboxes, [new_xcols, new_xcols, new_ycols, new_ycols]):
for combobox, new_cols in zip(
comboboxes, [new_xcols, new_xcols, new_ycols, new_ycols]
):
old_cols = [combobox.itemText(idx) for idx in range(combobox.count())]
if old_cols != new_cols:
old_value = combobox.currentText()
Expand Down Expand Up @@ -173,13 +171,11 @@ def prepare_plotting_data(self, df: pd.DataFrame) -> tuple[np.ndarray, np.ndarra
return img

def regrid(self, points: np.ndarray, values: np.ndarray):
"""Calculate a new image with a shape based on metadata.
"""
"""Calculate a new image with a shape based on metadata."""
# Prepare new regular grid to interpolate to
(ymin, ymax), (xmin, xmax) = self.extent
ystep, xstep = (npts * 1j for npts in self.shape)
yy, xx = np.mgrid[ymin:ymax:ystep,xmin:xmax:xstep]
yy, xx = np.mgrid[ymin:ymax:ystep, xmin:xmax:xstep]
xi = np.c_[yy.flatten(), xx.flatten()]
# Interpolate
new_values = griddata(points, values, xi, method="cubic")
Expand Down Expand Up @@ -218,7 +214,9 @@ def plot(self, dataframes: Mapping | None = None):
try:
ylabel, xlabel = self.independent_hints
except ValueError:
log.warning(f"Could not determine grid labels from hints: {self.independent_hints}")
log.warning(
f"Could not determine grid labels from hints: {self.independent_hints}"
)
else:
view = self.ui.plot_widget.view
view.setLabels(left=ylabel, bottom=xlabel)
Expand All @@ -233,54 +231,3 @@ def plot(self, dataframes: Mapping | None = None):
def clear_plot(self):
self.ui.plot_widget.getImageItem().clear()
self.data_items = {}


# class Browser2DPlotWidget(ImageView):
# """A plot widget for 2D maps."""

# def __init__(self, *args, view=None, **kwargs):
# if view is None:
# view = PlotItem()
# super().__init__(*args, view=view, **kwargs)

# def plot_runs(
# self, runs: Mapping, xlabel: str = "", ylabel: str = "", extents=None
# ):
# """Take loaded 2D or 3D mapping data and plot it.

# Parameters
# ==========
# runs
# Dictionary with pandas series for each curve. The keys
# should be the curve labels, the series' indexes are the x
# values and the series' values are the y data.
# xlabel
# The label for the horizontal axis.
# ylabel
# The label for the vertical axis.
# extents
# Spatial extents for the map as ((-y, +y), (-x, +x)).

# """
# images = np.asarray(list(runs.values()))
# # Combine the different runs into one image
# # To-do: make this respond to the combobox selection
# image = np.mean(images, axis=0)
# # To-do: Apply transformations

# # # Plot the image
# if 2 <= image.ndim <= 3:
# self.setImage(image.T, autoRange=False)
# else:
# log.info(f"Could not plot image of dataset with shape {image.shape}.")
# return
# # Determine the axes labels
# self.view.setLabel(axis="bottom", text=xlabel)
# self.view.setLabel(axis="left", text=ylabel)
# # Set axes extent
# yextent, xextent = extents
# x = xextent[0]
# y = yextent[0]
# w = xextent[1] - xextent[0]
# h = yextent[1] - yextent[0]
# self.getImageItem().setRect(x, y, w, h)
14 changes: 5 additions & 9 deletions src/firefly/run_browser/lineplot_view.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import logging
from itertools import count
from pathlib import Path
from typing import Mapping, Sequence

import numpy as np
import pandas as pd
import yaml
from matplotlib.colors import TABLEAU_COLORS
from pandas.api.types import is_numeric_dtype
from pyqtgraph import GraphicsLayoutWidget, ImageView, PlotItem, PlotWidget
import qtawesome as qta
from qtpy import QtCore, QtWidgets, uic
from qtpy.QtCore import Qt, Signal, Slot
from qtpy.QtWidgets import QFileDialog, QWidget
from matplotlib.colors import TABLEAU_COLORS
from pyqtgraph import PlotItem, PlotWidget
from qtpy import QtWidgets, uic
from qtpy.QtCore import Signal, Slot

log = logging.getLogger(__name__)
colors = list(TABLEAU_COLORS.values())
Expand Down Expand Up @@ -148,7 +144,7 @@ def axis_labels(self):
use_reference = self.ui.r_signal_checkbox.checkState()
inverted = self.ui.invert_checkbox.checkState()
logarithm = self.ui.logarithm_checkbox.checkState()
gradient = self.ui.gradient_checkbox.checkState()
gradient = self.ui.gradient_checkbox.checkState()
if use_reference and inverted:
ylabel = f"{rlabel}/{ylabel}"
elif use_reference:
Expand Down
3 changes: 1 addition & 2 deletions src/firefly/run_browser/metadata_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from typing import Mapping

import yaml
from qtpy import QtCore, QtWidgets, uic
from qtpy.QtCore import Signal
from qtpy import QtWidgets, uic


class MetadataView(QtWidgets.QWidget):
Expand Down
16 changes: 4 additions & 12 deletions src/firefly/run_browser/multiplot_view.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
import logging
from itertools import count
from typing import Mapping, Sequence
from pathlib import Path
from typing import Mapping, Sequence

import yaml
from qtpy import QtCore, QtWidgets, uic
from qtpy.QtCore import Signal, Slot
import numpy as np
from matplotlib.colors import TABLEAU_COLORS
from pandas.api.types import is_numeric_dtype
from pyqtgraph import GraphicsLayoutWidget, ImageView, PlotItem, PlotWidget
from qtpy.QtCore import Qt, Signal
from qtpy.QtWidgets import QFileDialog, QWidget


from qtpy import QtWidgets, uic
from qtpy.QtCore import Slot

log = logging.getLogger(__name__)


class MultiplotView(QtWidgets.QWidget):
_multiplot_items: Mapping
ui_file = Path(__file__).parent / "multiplot_view.ui"
Expand Down Expand Up @@ -155,4 +148,3 @@ def multiplot_items(self, n_cols: int = 3):
# view.resize(int(width), int(plot_width * row))
view.setFixedHeight(1200)
yield new_item

1 change: 1 addition & 0 deletions src/firefly/run_browser/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def test_data_frames(worker):
# Check the results
assert uids[0] in data_keys.keys()


@pytest.mark.asyncio
async def test_hints(worker):
uids = (await worker.catalog.client).keys()
Expand Down
33 changes: 0 additions & 33 deletions src/firefly/run_browser/tests/test_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from functools import partial
from unittest.mock import AsyncMock, MagicMock

import numpy as np
import pytest
import time_machine
from ophyd.sim import instantiate_fake_device
from pyqtgraph import ImageItem, ImageView, PlotItem, PlotWidget
from qtpy.QtWidgets import QFileDialog

from firefly.run_browser.display import RunBrowserDisplay
Expand Down Expand Up @@ -103,37 +101,6 @@ async def test_metadata(display, qtbot):
await display.update_selected_runs()


@pytest.mark.xfail
async def test_update_2d_plot(catalog, display):
display.plot_2d_item.setRect = MagicMock()
# Load test data
run = await catalog["85573831-f4b4-4f64-b613-a6007bf03a8d"]
display.db.selected_runs = [run]
await display.update_1d_signals()
# Set the controls to describe the data we want to test
val_combobox = display.ui.signal_value_combobox
val_combobox.addItem("It_net_counts")
val_combobox.setCurrentText("It_net_counts")
display.ui.logarithm_checkbox_2d.setChecked(True)
display.ui.invert_checkbox_2d.setChecked(True)
display.ui.gradient_checkbox_2d.setChecked(True)
# Update the plots
await display.update_2d_plot()
# Determine what the image data should look like
expected_data = await run.__getitem__("It_net_counts", stream="primary")
expected_data = expected_data.reshape((5, 21)).T
# Check that the data were added
image = display.plot_2d_item.image
np.testing.assert_almost_equal(image, expected_data)
# Check that the axes were formatted correctly
axes = display.plot_2d_view.view.axes
xaxis = axes["bottom"]["item"]
yaxis = axes["left"]["item"]
assert xaxis.labelText == "aerotech_horiz"
assert yaxis.labelText == "aerotech_vert"
display.plot_2d_item.setRect.assert_called_with(-100, -80, 200, 160)


def test_busy_hints_run_widgets(display):
"""Check that the display widgets get disabled during DB hits."""
with display.busy_hints(run_widgets=True, run_table=False):
Expand Down
Loading

0 comments on commit 44c6219

Please sign in to comment.