diff --git a/MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/MolecularViewer.py b/MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/MolecularViewer.py index 469a1d48e..7ff10d9ab 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/MolecularViewer.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/MolecularViewer.py @@ -14,6 +14,7 @@ # along with this program. If not, see . # from typing import List, Tuple, Dict, Any +import copy import numpy as np from scipy.spatial import cKDTree as KDTree @@ -35,6 +36,7 @@ from MDANSE_GUI.MolecularViewer.readers import hdf5wrapper from MDANSE_GUI.MolecularViewer.Dummy import PyConnectivity from MDANSE_GUI.MolecularViewer.Contents import TrajectoryAtomData +from MDANSE_GUI.MolecularViewer.TraceWidget import TRACE_PARAMETERS from MDANSE_GUI.MolecularViewer.AtomProperties import ( AtomProperties, ndarray_to_vtkarray, @@ -245,16 +247,12 @@ def _draw_isosurface(self, index: int, params=None): return LOG.info("Computing isosurface ...") - if params is not None: - fine_sampling = params.get("fine_sampling", 5) - r, g, b = params.get("surface_colour", (0, 0.5, 0.75)) - opacity = params.get("surface_opacity", 0.5) - trace_cutoff = params.get("trace_cutoff", 90) - else: - fine_sampling = 5 - r, g, b = 0, 0.5, 0.75 - opacity = 0.5 - trace_cutoff = 90 + if params is None: + params = copy.copy(TRACE_PARAMETERS) + fine_sampling = params.get("fine_sampling", 5) + rgb = params.get("surface_colour", (0, 0.5, 0.75)) + opacity = params.get("surface_opacity", 0.5) + trace_cutoff = params.get("trace_cutoff", 90) coords = self._reader.read_atom_trajectory(index) element = self._reader._atom_types[index] @@ -304,7 +302,7 @@ def _draw_isosurface(self, index: int, params=None): new_surface = vtk.vtkActor() new_surface.SetMapper(mapper) - new_surface.GetProperty().SetColor((r, g, b)) + new_surface.GetProperty().SetColor(rgb) new_surface.GetProperty().SetOpacity(opacity) new_surface.PickableOff() diff --git a/MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/TraceWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/TraceWidget.py index 36ca81569..588820541 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/TraceWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/TraceWidget.py @@ -16,8 +16,10 @@ from typing import TYPE_CHECKING from contextlib import suppress +import copy from qtpy.QtCore import Signal, Slot +from qtpy.QtGui import QValidator from qtpy.QtWidgets import ( QWidget, QVBoxLayout, @@ -32,6 +34,45 @@ from MDANSE_GUI.MolecularViewer.MolecularViewer import MolecularViewer +TRACE_PARAMETERS = { + "atom_number": 0, + "fine_sampling": 3, + "surface_colour": (0, 0.5, 0.75), + "surface_opacity": 0.5, + "trace_cutoff": 5, + "surface_number": -1, +} + + +class RGBValidator(QValidator): + + def __init__(self, parent=None): + super().__init__(parent) + + def validate(self, input_string: str, position: int): + state = QValidator.State.Intermediate + comma_count = input_string.count(",") + if len(input_string) > 0: + try: + rgb = [int(x) for x in input_string.split(",")] + except (TypeError, ValueError): + if input_string[-1] == "," and comma_count < 3: + state = QValidator.State.Intermediate + else: + state = QValidator.State.Invalid + else: + if len(rgb) > 3: + state = QValidator.State.Invalid + elif len(rgb) == 3: + if all([(x >= 0) and (x < 256) for x in rgb]): + state = QValidator.State.Acceptable + elif any([x > 255 for x in rgb]): + state = QValidator.State.Invalid + else: + state = QValidator.State.Intermediate + return state, input_string, position + + class TraceWidget(QWidget): new_atom_trace = Signal(dict) @@ -70,6 +111,10 @@ def populate_layout(self): self._grid_spinbox = QSpinBox(self) self._opacity_spinbox = QDoubleSpinBox(self) self._colour_lineedit = QLineEdit("0,128,192", self) + self._colour_lineedit.setPlaceholderText("0,128,192 (red,green,blue)") + self._colour_validator = RGBValidator(self._colour_lineedit) + self._colour_lineedit.setValidator(self._colour_validator) + self._colour_lineedit.textChanged.connect(self.check_rgb) for sbox in [ self._atom_spinbox, self._surface_spinbox, @@ -116,6 +161,13 @@ def update_limits(self): self._surface_spinbox.setMaximum(max(len(self._molviewer._surfaces) - 1, 0)) self.enable_buttons() + @Slot(str) + def check_rgb(self, colour_string: str): + tokens = colour_string.split(",") + non_empty = all([len(token) > 0 for token in tokens]) + colour_count = len(tokens) + self.add_trace_button.setEnabled(non_empty and colour_count == 3) + def enable_buttons(self): if self._molviewer is None: return @@ -123,14 +175,7 @@ def enable_buttons(self): self.add_trace_button.setEnabled(self._molviewer._n_atoms > 0) def get_values(self): - params = { - "atom_number": 0, - "fine_sampling": 3, - "surface_colour": (0, 0.5, 0.75), - "surface_opacity": 0.5, - "trace_cutoff": 5, - "surface_number": -1, - } + params = copy.copy(TRACE_PARAMETERS) params["atom_number"] = self._atom_spinbox.value() params["surface_number"] = self._surface_spinbox.value() with suppress(ValueError, TypeError):