Move from mongodb to tiled-backed catalog #339

Merged · 16 commits · Jan 8, 2025
environment.yml (2 additions, 1 deletion)

@@ -57,7 +57,8 @@ dependencies:
- bluesky-queueserver-api
- bluesky-widgets
- bluesky-adaptive
- bluesky >=1.8.1
# - bluesky >=1.13.1
- git+https://github.com/bluesky/[email protected] # Replace with pypi version once released
- ophyd >=1.6.3
- ophyd-async >=0.9.0a1
- apstools == 1.6.20 # Leave at 1.6.20 until this is fixed: https://github.com/BCDA-APS/apstools/issues/1022
src/conftest.py (34 additions, 24 deletions)

@@ -10,7 +10,7 @@
from ophyd import Kind
from ophyd.sim import instantiate_fake_device, make_fake_device
from tiled.adapters.mapping import MapAdapter
from tiled.adapters.xarray import DatasetAdapter
from tiled.adapters.table import TableAdapter
from tiled.client import Context, from_context
from tiled.server.app import build_app

@@ -260,10 +260,11 @@ def filters(sim_registry):
run1 = pd.DataFrame(
{
"energy_energy": np.linspace(8300, 8400, num=100),
"energy_id_energy_readback": np.linspace(8.3, 8.4, num=100),
"It_net_counts": np.abs(np.sin(np.linspace(0, 4 * np.pi, num=100))),
"I0_net_counts": np.linspace(1, 2, num=100),
}
).to_xarray()
)

grid_scan = pd.DataFrame(
{
@@ -272,7 +273,7 @@ def filters(sim_registry):
"aerotech_horiz": np.linspace(0, 104, num=105),
"aerotech_vert": np.linspace(0, 104, num=105),
}
).to_xarray()
)

hints = {
"energy": {"fields": ["energy_energy", "energy_id_energy_readback"]},
@@ -283,9 +284,13 @@ def filters(sim_registry):
{
"primary": MapAdapter(
{
"data": DatasetAdapter.from_dataset(run1),
"internal": MapAdapter(
{
"events": TableAdapter.from_pandas(run1),
}
),
},
metadata={"descriptors": [{"hints": hints}]},
metadata={"hints": hints},
),
},
metadata={
@@ -301,42 +306,47 @@ def filters(sim_registry):
{
"primary": MapAdapter(
{
"data": DatasetAdapter.from_dataset(run1),
"internal": MapAdapter(
{
"events": TableAdapter.from_pandas(run1),
}
),
},
metadata={"descriptors": [{"hints": hints}]},
metadata={"hints": hints},
),
},
metadata={
"plan_name": "rel_scan",
"start": {
"plan_name": "rel_scan",
"uid": "9d33bf66-9701-4ee3-90f4-3be730bc226c",
"hints": {"dimensions": [[["pitch2"], "primary"]]},
}
},
},
),
# 2D grid scan map data
"85573831-f4b4-4f64-b613-a6007bf03a8d": MapAdapter(
{
"primary": MapAdapter(
{
"data": DatasetAdapter.from_dataset(grid_scan),
"internal": MapAdapter(
{
"events": TableAdapter.from_pandas(grid_scan),
},
),
},
metadata={
"descriptors": [
{
"hints": {
"Ipreslit": {"fields": ["Ipreslit_net_counts"]},
"CdnIPreKb": {"fields": ["CdnIPreKb_net_counts"]},
"I0": {"fields": ["I0_net_counts"]},
"CdnIt": {"fields": ["CdnIt_net_counts"]},
"aerotech_vert": {"fields": ["aerotech_vert"]},
"aerotech_horiz": {"fields": ["aerotech_horiz"]},
"Ipre_KB": {"fields": ["Ipre_KB_net_counts"]},
"CdnI0": {"fields": ["CdnI0_net_counts"]},
"It": {"fields": ["It_net_counts"]},
}
}
]
"hints": {
"Ipreslit": {"fields": ["Ipreslit_net_counts"]},
"CdnIPreKb": {"fields": ["CdnIPreKb_net_counts"]},
"I0": {"fields": ["I0_net_counts"]},
"CdnIt": {"fields": ["CdnIt_net_counts"]},
"aerotech_vert": {"fields": ["aerotech_vert"]},
"aerotech_horiz": {"fields": ["aerotech_horiz"]},
"Ipre_KB": {"fields": ["Ipre_KB_net_counts"]},
"CdnI0": {"fields": ["CdnI0_net_counts"]},
"It": {"fields": ["It_net_counts"]},
},
},
),
},
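For orientation, a minimal sketch of the catalog layout these fixtures now emulate: each run serves its event table at `primary/internal/events` via `TableAdapter.from_pandas`, and the stream hints live directly in the stream node's metadata instead of a `descriptors` list. The uid, plan name, and column names below are hypothetical placeholders.

```python
import numpy as np
import pandas as pd
from tiled.adapters.mapping import MapAdapter
from tiled.adapters.table import TableAdapter
from tiled.client import Context, from_context
from tiled.server.app import build_app

# Hypothetical event data; column names are illustrative only.
events = pd.DataFrame(
    {
        "energy_energy": np.linspace(8300, 8400, num=100),
        "It_net_counts": np.linspace(1, 2, num=100),
    }
)

# One run: <uid>/primary/internal/events, with hints on the stream node.
run = MapAdapter(
    {
        "primary": MapAdapter(
            {"internal": MapAdapter({"events": TableAdapter.from_pandas(events)})},
            metadata={"hints": {"energy": {"fields": ["energy_energy"]}}},
        ),
    },
    metadata={"start": {"plan_name": "xafs_scan", "uid": "example-uid"}},
)

tree = MapAdapter({"example-uid": run})
with Context.from_app(build_app(tree)) as context:
    client = from_context(context)
    df = client["example-uid"]["primary"]["internal"]["events"].read()
```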
src/firefly/run_browser/client.py (15 additions, 18 deletions)

@@ -28,22 +28,20 @@ async def filtered_nodes(self, filters: Mapping):
log.debug(f"Filtering nodes: {filters}")
filter_params = [
# (filter_name, query type, metadata key)
("user", queries.Regex, "proposal_users"),
("proposal", queries.Regex, "proposal_id"),
("esaf", queries.Regex, "esaf_id"),
("sample", queries.Regex, "sample_name"),
# ('exit_status', queries.Regex, "exit_status"),
("plan", queries.Regex, "plan_name"),
("edge", queries.Regex, "edge"),
("user", queries.Contains, "start.proposal_users"),
("proposal", queries.Eq, "start.proposal_id"),
("esaf", queries.Eq, "start.esaf_id"),
("sample", queries.Contains, "start.sample_name"),
("exit_status", queries.Eq, "stop.exit_status"),
("plan", queries.Eq, "start.plan_name"),
("edge", queries.Contains, "start.edge"),
]
# Apply filters
runs = self.catalog
for filter_name, Query, md_name in filter_params:
val = filters.get(filter_name, "")
if val != "":
runs = await runs.search(
Query(md_name, val, case_sensitive=case_sensitive)
)
runs = await runs.search(Query(md_name, val))
full_text = filters.get("full_text", "")
if full_text != "":
runs = await runs.search(
@@ -135,7 +133,7 @@ async def signal_names(self, hinted_only: bool = False):
if hinted_only:
xsig, ysig = await run.hints()
else:
df = await run.to_dataframe()
df = await run.data()
xsig = ysig = df.columns
xsignals.extend(xsig)
ysignals.extend(ysig)
@@ -156,7 +154,6 @@ async def metadata(self):
async def load_selected_runs(self, uids):
# Prepare the query for finding the runs
uids = list(dict.fromkeys(uids))
print(f"Loading runs: {uids}")
# Retrieve runs from the database
runs = [await self.catalog[uid] for uid in uids]
# runs = await asyncio.gather(*run_coros)
@@ -170,14 +167,14 @@ async def images(self, signal):
# Load datasets from the database
try:
image = await run[signal]
except KeyError:
log.warning(f"Signal {signal} not found in run {run}.")
except KeyError as exc:
log.exception(exc)
else:
images[run.uid] = image
return images

async def all_signals(self, hinted_only=False):
"""Produce dataframe with all signals for each run.
async def all_signals(self, hinted_only=False) -> dict:
"""Produce dataframes with all signals for each run.

The keys of the dictionary are the labels for each curve, and
the corresponding value is a pandas dataframe with the scan data.
Expand All @@ -188,7 +185,7 @@ async def all_signals(self, hinted_only=False):
dfs = OrderedDict()
for run in self.selected_runs:
# Get data from the database
df = await run.to_dataframe(signals=xsignals + ysignals)
df = await run.data(signals=xsignals + ysignals)
dfs[run.uid] = df
return dfs

@@ -236,7 +233,7 @@ async def signals(
if uids is not None and run.uid not in uids:
break
# Get data from the database
df = await run.to_dataframe(signals=signals)
df = await run.data(signals=signals)
# Check for missing signals
missing_x = x_signal not in df.columns and df.index.name != x_signal
missing_y = y_signal not in df.columns
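The filtering rework replaces broad `Regex` matches with structured queries against dotted paths into the bluesky start/stop documents. A minimal sketch of the same queries against a synchronous tiled client (the run browser awaits them through its async wrapper); the URI and metadata values are hypothetical:

```python
from tiled.client import from_uri
from tiled.queries import Contains, Eq

catalog = from_uri("http://localhost:8000")  # hypothetical URI
# Exact match on a scalar start-document field
runs = catalog.search(Eq("start.plan_name", "xafs_scan"))
# Membership test against a list-valued field
runs = runs.search(Contains("start.proposal_users", "esther"))
# Narrow further by how the run ended
runs = runs.search(Eq("stop.exit_status", "success"))
```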
src/firefly/run_browser/display.py (14 additions, 11 deletions)

@@ -180,10 +180,12 @@ def customize_ui(self):
self.ui.plot_1d_hints_checkbox.stateChanged.connect(self.update_1d_signals)
self.ui.autorange_1d_button.clicked.connect(self.auto_range)
# Respond to changes in displaying the 2d plot
self.ui.plot_multi_hints_checkbox.stateChanged.connect(
self.update_multi_signals
)
self.ui.plot_multi_hints_checkbox.stateChanged.connect(self.update_multi_plot)
for signal in [
self.ui.plot_multi_hints_checkbox.stateChanged,
self.ui.multi_signal_x_combobox.currentTextChanged,
]:
signal.connect(self.update_multi_signals)
signal.connect(self.update_multi_plot)
# Respond to changes in displaying the 2d plot
self.ui.signal_value_combobox.currentTextChanged.connect(self.update_2d_plot)
self.ui.logarithm_checkbox_2d.stateChanged.connect(self.update_2d_plot)
@@ -288,10 +290,12 @@ async def update_multi_signals(self, *args):
)
xcols, ycols = await signals_task
# Update the comboboxes with new signals
combobox.clear()
combobox.addItems(xcols)
# Restore previous value
combobox.setCurrentText(old_value)
old_cols = [combobox.itemText(idx) for idx in range(combobox.count())]
if xcols != old_cols:
combobox.clear()
combobox.addItems(xcols)
# Restore previous value
combobox.setCurrentText(old_value)

@asyncSlot()
@cancellable
@@ -376,7 +380,7 @@ async def export_runs(self):
@asyncSlot(str)
@cancellable
async def update_running_scan(self, uid: str):
print(f"Updating running scan: {uid=}")
log.debug(f"Updating running scan: {uid=}")
await self.update_1d_plot(uids=[uid])

@asyncSlot()
@@ -432,7 +436,6 @@ async def update_1d_plot(self, *args, uids: Sequence[str] = None):
if use_grad:
ylabel = f"∇ {ylabel}"
# Do the plotting
print("RUNNING", self.ui.plot_1d_view.plot_runs)
self.ui.plot_1d_view.plot_runs(runs, xlabel=xlabel, ylabel=ylabel)
if self.ui.autorange_1d_checkbox.isChecked():
self.ui.plot_1d_view.autoRange()
@@ -448,7 +451,7 @@ async def update_2d_plot(self):
use_grad = self.ui.gradient_checkbox_2d.isChecked()
images = await self.db_task(self.db.images(value_signal), "2D plot")
# Get axis labels
# Eventually this will be replaced with robus choices for plotting multiple images
# Eventually this will be replaced with robust choices for plotting multiple images
metadata = await self.db_task(self.db.metadata(), "2D plot")
metadata = list(metadata.values())[0]
dimensions = metadata["start"]["hints"]["dimensions"]
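The `update_multi_signals` change repopulates the x-signal combobox only when the signal list has actually changed, so periodic updates during a running scan no longer clobber the user's selection. The same pattern in isolation, a sketch assuming qtpy widgets (the function name is hypothetical):

```python
from qtpy.QtWidgets import QComboBox


def repopulate_combobox(combobox: QComboBox, new_items: list[str]) -> None:
    """Replace a combobox's items only if they changed, keeping the selection."""
    old_items = [combobox.itemText(idx) for idx in range(combobox.count())]
    if new_items == old_items:
        return  # No change; leave the current selection untouched
    old_value = combobox.currentText()
    combobox.clear()
    combobox.addItems(new_items)
    # setCurrentText() is a no-op if old_value is no longer in the list
    combobox.setCurrentText(old_value)
```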
src/firefly/run_browser/tests/test_display.py (2 additions, 1 deletion)

@@ -30,7 +30,8 @@ async def display(qtbot, catalog, mocker):
run = [run async for run in catalog.values()][0]
display.db.selected_runs = [run]
await display.update_1d_signals()
run_data = await run.to_dataframe()
run_data = await run.data()
expected_xdata = run_data.energy_energy
expected_ydata = np.log(run_data.I0_net_counts / run_data.It_net_counts)
expected_ydata = np.gradient(expected_ydata, expected_xdata)
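The fixture now exercises the reworked async accessor, `run.data()`, which replaces `to_dataframe()`. A short sketch of the assumed call pattern, reusing a fixture uid from conftest.py; `catalog` is a hypothetical haven Catalog wrapper and the signal names are illustrative:

```python
import numpy as np

# Inside an async test: look up a run by uid, then fetch selected columns.
run = await catalog["9d33bf66-9701-4ee3-90f4-3be730bc226c"]
df = await run.data(signals=["energy_energy", "I0_net_counts", "It_net_counts"])
ydata = np.log(df["I0_net_counts"] / df["It_net_counts"])
```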
src/firefly/run_browser/widgets.py (0 additions, 1 deletion)

@@ -172,7 +172,6 @@ def plot_runs(self, runs: Mapping, ylabel="", xlabel=""):
)
# Cursor to drag around on the data
if self.cursor_line is None:
print("CURSOR LINE: ", np.median(series.index), series.index)
self.cursor_line = plot_item.addLine(
x=np.median(series.index), movable=True, label="{value:.3f}"
)
src/haven/catalog.py (25 additions, 28 deletions)

@@ -6,10 +6,10 @@
import warnings
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Sequence

import databroker
import numpy as np
import pandas as pd
from tiled.client import from_uri
from tiled.client.cache import Cache

@@ -172,24 +172,25 @@ def write_safe(self):


def tiled_client(
entry_node=None, uri=None, cache_filepath=None, structure_clients="dask"
entry_node=None, uri=None, cache_filepath=None, structure_clients="numpy"
):
config = load_config()
tiled_config = config["database"].get("tiled", {})
# Create a cache for saving local copies
if cache_filepath is None:
cache_filepath = config["database"].get("tiled", {}).get("cache_filepath", "")
cache_filepath = cache_filepath or None
cache_filepath = tiled_config.get("cache_filepath", "")
if os.access(cache_filepath, os.W_OK):
cache = ThreadSafeCache(filepath=cache_filepath)
else:
warnings.warn(f"Cache file is not writable: {cache_filepath}")
cache = None
# Create the client
if uri is None:
uri = config["database"]["tiled"]["uri"]
client_ = from_uri(uri, structure_clients)
uri = tiled_config["uri"]
api_key = tiled_config.get("api_key")
client_ = from_uri(uri, structure_clients, api_key=api_key)
if entry_node is None:
entry_node = config["database"]["tiled"]["entry_node"]
entry_node = tiled_config["entry_node"]
client_ = client_[entry_node]
return client_

@@ -206,10 +207,17 @@ def __init__(self, container, executor=None):
self.container = container
self.executor = executor

def _read_data(self, signals, dataset="primary/data"):
# Fetch data if needed
def _read_data(
self, signals: Sequence | None, dataset: str = "primary/internal/events"
):
data = self.container[dataset]
return data.read(signals)
if signals is None:
return data.read()
# Remove duplicates and missing signals
signals = set(signals)
available_signals = set(data.columns)
signals = signals & available_signals
return data.read(list(signals))

def _read_metadata(self, keys=None):
container = self.container
@@ -232,31 +240,20 @@ async def export(self, filename: str, format: str):
def formats(self):
return self.container.formats

async def data(self, stream="primary"):
async def data(self, signals=None, stream="primary"):
return await self.loop.run_in_executor(
None, self._read_data, None, f"{stream}/data"
None, self._read_data, signals, f"{stream}/internal/events"
)

async def to_dataframe(self, signals=None):
"""Convert the dataset into a pandas dataframe."""
xarray = await self.run(self._read_data, signals)
if len(xarray) > 0:
df = xarray.to_dataframe()
# Add a copy of the index to the dataframe itself
if df.index.name is not None:
df[df.index.name] = df.index
else:
df = pd.DataFrame()
return df

@property
def loop(self):
return asyncio.get_running_loop()

def _data_keys(self, stream):
return self.container[stream]["internal/events"].columns

async def data_keys(self, stream="primary"):
stream_md = await self.loop.run_in_executor(None, self._read_metadata, stream)
# Assumes the 0-th descriptor is for the primary stream
return stream_md["descriptors"][0]["data_keys"]
return await self.run(self._data_keys, stream)

async def hints(self):
"""Retrieve the data hints for this scan.
Expand All @@ -277,7 +274,7 @@ async def hints(self):
# Get hints for the dependent (X)
dependent = []
primary_metadata = await self.run(self._read_metadata, "primary")
hints = primary_metadata["descriptors"][0]["hints"]
hints = primary_metadata["hints"]
for device, dev_hints in hints.items():
dependent.extend(dev_hints["fields"])
return independent, dependent
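`tiled_client()` now reads the whole `database.tiled` config section, passes along an optional API key, and defaults to plain `"numpy"` structure clients. A minimal sketch of the equivalent direct usage with the new `primary/internal/events` layout; the URI, API key, entry node, uid, and column name stand in for values that would come from the Haven config file:

```python
from tiled.client import from_uri

client = from_uri(
    "http://localhost:8000",     # database.tiled.uri (hypothetical)
    "numpy",                     # structure_clients: concrete numpy/pandas objects
    api_key="hypothetical-key",  # database.tiled.api_key (optional)
)
runs = client["scans"]           # database.tiled.entry_node (hypothetical)
events = runs["9d33bf66-9701-4ee3-90f4-3be730bc226c"]["primary"]["internal"]["events"]
df = events.read(["energy_energy"])  # read only the requested columns
```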