Updated run browser for new relational database #342

Draft · wants to merge 13 commits into main
16 changes: 12 additions & 4 deletions src/conftest.py
@@ -280,12 +280,20 @@ async def filters(sim_registry):
),
},
metadata={
"plan_name": "xafs_scan",
"start": {
"plan_name": "xafs_scan",
"esaf_id": "1337",
"proposal_id": "158839",
"beamline_id": "255-ID-Z",
"sample_name": "NMC-532",
"sample_formula": "LiNi0.5Mn0.3Co0.2O2",
"edge": "Ni-K",
"uid": "7d1daf1d-60c7-4aa7-a668-d1cd97e5335f",
"hints": {"dimensions": [[["energy_energy"], "primary"]]},
},
"stop": {
"exit_status": "success",
},
},
),
"9d33bf66-9701-4ee3-90f4-3be730bc226c": MapAdapter(
@@ -357,6 +365,7 @@ async def filters(sim_registry):

mapping = {
"255id_testing": MapAdapter(bluesky_mapping),
"255bm_testing": MapAdapter(bluesky_mapping),
}

tree = MapAdapter(mapping)
@@ -367,13 +376,12 @@ def tiled_client():
app = build_app(tree)
with Context.from_app(app) as context:
client = from_context(context)
yield client["255id_testing"]
yield client


@pytest.fixture()
def catalog(tiled_client):
cat = Catalog(client=tiled_client)
# cat = mock.AsyncMock()
cat = Catalog(client=tiled_client["255id_testing"])
return cat


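With this change the `tiled_client` fixture yields the root Tiled client, with one key per catalog, and the `catalog` fixture picks "255id_testing" out of it. A minimal sketch of a test against that layout (fixture and key names taken from this diff):

def test_catalog_keys(tiled_client):
    # Both beamline catalogs are mounted at the top of the Tiled tree.
    assert {"255id_testing", "255bm_testing"} <= set(tiled_client.keys())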
19 changes: 10 additions & 9 deletions src/firefly/controller.py
@@ -15,7 +15,7 @@
from qtpy.QtGui import QIcon, QKeySequence
from qtpy.QtWidgets import QAction, QErrorMessage

from haven import beamline, load_config
from haven import beamline, load_config, tiled_client
from haven.exceptions import ComponentNotFound, InvalidConfiguration
from haven.utils import titleize

@@ -333,11 +333,18 @@ async def finalize_new_window(self, action):
# Send the current devices to the window
await action.window.update_devices(self.registry)

def finalize_run_browser_window(self, action):
"""Connect up signals that are specific to the run browser window."""
@asyncSlot(QAction)
async def finalize_run_browser_window(self, action):
"""Connect up run browser signals and load initial data."""
display = action.display
self.run_updated.connect(display.update_running_scan)
self.run_stopped.connect(display.update_running_scan)
# Set initial state for the run_browser
client = tiled_client(catalog=None)
config = load_config()["database"]["tiled"]
await display.setup_database(
tiled_client=client, catalog_name=config["default_catalog"]
)

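# Note: the lookup above assumes the Haven config carries a
# ``database.tiled.default_catalog`` entry (key names taken from this
# diff; e.g. default_catalog = "255id_testing" in the beamline's
# iconfig file, which is not shown here).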
def finalize_status_window(self, action):
"""Connect up signals that are specific to the voltmeters window."""
@@ -652,12 +659,6 @@ async def add_queue_item(self, item):
if getattr(self, "_queue_client", None) is not None:
await self._queue_client.add_queue_item(item)

@QtCore.Slot()
def show_sample_viewer_window(self):
return self.show_window(
FireflyMainWindow, ui_dir / "sample_viewer.ui", name="sample_viewer"
)

@QtCore.Slot(bool)
def set_open_environment_action_state(self, is_open: bool):
"""Update the readback value for opening the queueserver environment."""
3 changes: 2 additions & 1 deletion src/firefly/kafka_client.py
@@ -1,6 +1,7 @@
import asyncio
import logging
import warnings
from uuid import uuid4

import msgpack
from aiokafka import AIOKafkaConsumer
@@ -40,7 +41,7 @@ async def consumer_loop(self):
self.kafka_consumer = AIOKafkaConsumer(
config["queueserver"]["kafka_topic"],
bootstrap_servers="fedorov.xray.aps.anl.gov:9092",
group_id="my-group",
group_id=str(uuid4()),
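# A unique group_id per process gives broadcast semantics: each
# Firefly instance receives every document, rather than splitting
# a shared Kafka consumer group's partitions.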
value_deserializer=msgpack.loads,
)
consumer = self.kafka_consumer
118 changes: 78 additions & 40 deletions src/firefly/run_browser/client.py
@@ -1,65 +1,97 @@
import asyncio
import datetime as dt
import logging
import warnings
from collections import OrderedDict
from functools import partial
from typing import Mapping, Sequence

import numpy as np
import pandas as pd
from qasync import asyncSlot
from tiled import queries

from haven import exceptions
from haven.catalog import Catalog
from haven.catalog import Catalog, run_in_executor

log = logging.getLogger(__name__)


class DatabaseWorker:
selected_runs: Sequence = []
catalog: Catalog = None

def __init__(self, catalog=None, *args, **kwargs):
if catalog is None:
catalog = Catalog()
self.catalog = catalog
def __init__(self, tiled_client, *args, **kwargs):
self.client = tiled_client
super().__init__(*args, **kwargs)

@asyncSlot(str)
async def change_catalog(self, catalog_name: str):
"""Change the catalog being used for pulling data.

*catalog_name* should be a key of the worker's Tiled client.
"""

def get_catalog(name):
return Catalog(self.client[name])

loop = asyncio.get_running_loop()
self.catalog = await loop.run_in_executor(None, get_catalog, catalog_name)

@run_in_executor
def catalog_names(self):
return list(self.client.keys())

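# Intended flow, as a sketch (names from this diff; assumes a running
# asyncio/qasync event loop):
#
#     worker = DatabaseWorker(tiled_client=client)
#     await worker.change_catalog("255id_testing")
#     names = await worker.catalog_names()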
async def stream_names(self):
awaitables = [scan.stream_names() for scan in self.selected_runs]
all_streams = await asyncio.gather(*awaitables)
# Flatten the lists
streams = [stream for streams in all_streams for stream in streams]
return list(set(streams))

async def filtered_nodes(self, filters: Mapping):
case_sensitive = False
log.debug(f"Filtering nodes: {filters}")
filter_params = [
# (filter_name, query type, metadata key)
("user", queries.Contains, "start.proposal_users"),
("proposal", queries.Eq, "start.proposal_id"),
("esaf", queries.Eq, "start.esaf_id"),
("sample", queries.Contains, "start.sample_name"),
("exit_status", queries.Eq, "stop.exit_status"),
("plan", queries.Eq, "start.plan_name"),
("edge", queries.Contains, "start.edge"),
]
filter_params = {
# filter_name: (query type, metadata key)
"plan": (queries.Eq, "start.plan_name"),
"sample": (queries.Contains, "start.sample_name"),
"formula": (queries.Contains, "start.sample_formula"),
"edge": (queries.Contains, "start.edge"),
"exit_status": (queries.Eq, "stop.exit_status"),
"user": (queries.Contains, "start.proposal_users"),
"proposal": (queries.Eq, "start.proposal_id"),
"esaf": (queries.Eq, "start.esaf_id"),
"beamline": (queries.Eq, "start.beamline_id"),
"before": (partial(queries.Comparison, "le"), "end.time"),
"after": (partial(queries.Comparison, "ge"), "start.time"),
"full_text": (queries.FullText, ""),
"standards_only": (queries.Eq, "start.is_standard"),
}
# Apply filters
runs = self.catalog
for filter_name, Query, md_name in filter_params:
val = filters.get(filter_name, "")
if val != "":
runs = await runs.search(Query(md_name, val))
full_text = filters.get("full_text", "")
if full_text != "":
runs = await runs.search(
queries.FullText(full_text, case_sensitive=case_sensitive)
)
for filter_name, filter_value in filters.items():
if filter_name not in filter_params:
continue
Query, md_name = filter_params[filter_name]
if Query is queries.FullText:
runs = await runs.search(Query(filter_value, case_sensitive=case_sensitive))
else:
runs = await runs.search(Query(md_name, filter_value))
return runs

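# Worked example (illustrative values): filters={"plan": "xafs_scan",
# "before": 1700000000.0} chains two searches over the catalog:
#     runs.search(queries.Eq("start.plan_name", "xafs_scan"))
#     runs.search(queries.Comparison("le", "end.time", 1700000000.0))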
async def load_distinct_fields(self):
"""Get distinct metadata fields for filterable metadata."""
new_fields = {}
target_fields = [
"sample_name",
"proposal_users",
"proposal_id",
"esaf_id",
"sample_name",
"plan_name",
"edge",
"start.plan_name",
"start.sample_name",
"start.sample_formula",
"start.edge",
"stop.exit_status",
"start.proposal_id",
"start.esaf_id",
"start.beamline_id",
]
# Get fields from the database
response = await self.catalog.distinct(*target_fields)
@@ -118,11 +150,13 @@ async def load_all_runs(self, filters: Mapping = {}):
all_runs.append(run_data)
return all_runs

async def signal_names(self, hinted_only: bool = False):
async def signal_names(self, stream: str, *, hinted_only: bool = False):
"""Get a list of valid signal names (data columns) for selected runs.

Parameters
==========
stream
The Tiled stream name to fetch.
hinted_only
If true, only signals with kind="hinted" are included.
@@ -131,9 +165,9 @@ async def signal_names(self, hinted_only: bool = False):
xsignals, ysignals = [], []
for run in self.selected_runs:
if hinted_only:
xsig, ysig = await run.hints()
xsig, ysig = await run.hints(stream=stream)
else:
df = await run.data()
df = await run.data(stream=stream)
xsig = ysig = df.columns
xsignals.extend(xsig)
ysignals.extend(ysig)
@@ -160,32 +194,34 @@ async def load_selected_runs(self, uids):
self.selected_runs = runs
return runs

async def images(self, signal):
async def images(self, signal: str, stream: str):
"""Load the selected runs as 2D or 3D images suitable for plotting."""
images = OrderedDict()
for idx, run in enumerate(self.selected_runs):
# Load datasets from the database
try:
image = await run[signal]
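# __getitem__ is spelled out because plain run[signal] cannot
# forward the extra *stream* keyword.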
image = await run.__getitem__(signal, stream=stream)
except KeyError as exc:
log.exception(exc)
else:
images[run.uid] = image
return images

async def all_signals(self, hinted_only=False) -> dict:
async def all_signals(self, stream: str, *, hinted_only=False) -> dict:
"""Produce dataframes with all signals for each run.

The keys of the dictionary are the labels for each curve, and
the corresponding value is a pandas dataframe with the scan data.

"""
xsignals, ysignals = await self.signal_names(hinted_only=hinted_only)
xsignals, ysignals = await self.signal_names(
hinted_only=hinted_only, stream=stream
)
# Build the dataframes
dfs = OrderedDict()
for run in self.selected_runs:
# Get data from the database
df = await run.data(signals=xsignals + ysignals)
df = await run.data(signals=xsignals + ysignals, stream=stream)
dfs[run.uid] = df
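# Keyed by run UID so downstream plotting code can label each
# curve by the run it came from.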
return dfs

@@ -194,6 +230,8 @@ async def signals(
x_signal,
y_signal,
r_signal=None,
*,
stream: str,
use_log=False,
use_invert=False,
use_grad=False,
@@ -233,7 +271,7 @@ async def signals(
if uids is not None and run.uid not in uids:
break
# Get data from the database
df = await run.data(signals=signals)
df = await run.data(signals=signals, stream=stream)
# Check for missing signals
missing_x = x_signal not in df.columns and df.index.name != x_signal
missing_y = y_signal not in df.columns