Skip to content

Commit

Permalink
Add input validation for isosurface colour
Browse files Browse the repository at this point in the history
  • Loading branch information
MBartkowiakSTFC committed Feb 17, 2025
1 parent 4244dda commit 570c1a1
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 19 deletions.
20 changes: 9 additions & 11 deletions MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/MolecularViewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from typing import List, Tuple, Dict, Any
import copy

import numpy as np
from scipy.spatial import cKDTree as KDTree
Expand All @@ -35,6 +36,7 @@
from MDANSE_GUI.MolecularViewer.readers import hdf5wrapper
from MDANSE_GUI.MolecularViewer.Dummy import PyConnectivity
from MDANSE_GUI.MolecularViewer.Contents import TrajectoryAtomData
from MDANSE_GUI.MolecularViewer.TraceWidget import TRACE_PARAMETERS
from MDANSE_GUI.MolecularViewer.AtomProperties import (
AtomProperties,
ndarray_to_vtkarray,
Expand Down Expand Up @@ -245,16 +247,12 @@ def _draw_isosurface(self, index: int, params=None):
return

LOG.info("Computing isosurface ...")
if params is not None:
fine_sampling = params.get("fine_sampling", 5)
r, g, b = params.get("surface_colour", (0, 0.5, 0.75))
opacity = params.get("surface_opacity", 0.5)
trace_cutoff = params.get("trace_cutoff", 90)
else:
fine_sampling = 5
r, g, b = 0, 0.5, 0.75
opacity = 0.5
trace_cutoff = 90
if params is None:
params = copy.copy(TRACE_PARAMETERS)
fine_sampling = params.get("fine_sampling", 5)
rgb = params.get("surface_colour", (0, 0.5, 0.75))
opacity = params.get("surface_opacity", 0.5)
trace_cutoff = params.get("trace_cutoff", 90)

coords = self._reader.read_atom_trajectory(index)
element = self._reader._atom_types[index]
Expand Down Expand Up @@ -304,7 +302,7 @@ def _draw_isosurface(self, index: int, params=None):

new_surface = vtk.vtkActor()
new_surface.SetMapper(mapper)
new_surface.GetProperty().SetColor((r, g, b))
new_surface.GetProperty().SetColor(rgb)
new_surface.GetProperty().SetOpacity(opacity)
new_surface.PickableOff()

Expand Down
61 changes: 53 additions & 8 deletions MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/TraceWidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

from typing import TYPE_CHECKING
from contextlib import suppress
import copy

from qtpy.QtCore import Signal, Slot
from qtpy.QtGui import QValidator
from qtpy.QtWidgets import (
QWidget,
QVBoxLayout,
Expand All @@ -32,6 +34,45 @@
from MDANSE_GUI.MolecularViewer.MolecularViewer import MolecularViewer


TRACE_PARAMETERS = {
"atom_number": 0,
"fine_sampling": 3,
"surface_colour": (0, 0.5, 0.75),
"surface_opacity": 0.5,
"trace_cutoff": 5,
"surface_number": -1,
}


class RGBValidator(QValidator):

def __init__(self, parent=None):
super().__init__(parent)

def validate(self, input_string: str, position: int):
state = QValidator.State.Intermediate
comma_count = input_string.count(",")
if len(input_string) > 0:
try:
rgb = [int(x) for x in input_string.split(",")]
except (TypeError, ValueError):
if input_string[-1] == "," and comma_count < 3:
state = QValidator.State.Intermediate
else:
state = QValidator.State.Invalid
else:
if len(rgb) > 3:
state = QValidator.State.Invalid
elif len(rgb) == 3:
if all([(x >= 0) and (x < 256) for x in rgb]):
state = QValidator.State.Acceptable
elif any([x > 255 for x in rgb]):
state = QValidator.State.Invalid
else:
state = QValidator.State.Intermediate
return state, input_string, position


class TraceWidget(QWidget):

new_atom_trace = Signal(dict)
Expand Down Expand Up @@ -70,6 +111,10 @@ def populate_layout(self):
self._grid_spinbox = QSpinBox(self)
self._opacity_spinbox = QDoubleSpinBox(self)
self._colour_lineedit = QLineEdit("0,128,192", self)
self._colour_lineedit.setPlaceholderText("0,128,192 (red,green,blue)")
self._colour_validator = RGBValidator(self._colour_lineedit)
self._colour_lineedit.setValidator(self._colour_validator)
self._colour_lineedit.textChanged.connect(self.check_rgb)
for sbox in [
self._atom_spinbox,
self._surface_spinbox,
Expand Down Expand Up @@ -116,21 +161,21 @@ def update_limits(self):
self._surface_spinbox.setMaximum(max(len(self._molviewer._surfaces) - 1, 0))
self.enable_buttons()

@Slot(str)
def check_rgb(self, colour_string: str):
tokens = colour_string.split(",")
non_empty = all([len(token) > 0 for token in tokens])
colour_count = len(tokens)
self.add_trace_button.setEnabled(non_empty and colour_count == 3)

def enable_buttons(self):
if self._molviewer is None:
return
self.remove_trace_button.setEnabled(len(self._molviewer._surfaces) != 0)
self.add_trace_button.setEnabled(self._molviewer._n_atoms > 0)

def get_values(self):
params = {
"atom_number": 0,
"fine_sampling": 3,
"surface_colour": (0, 0.5, 0.75),
"surface_opacity": 0.5,
"trace_cutoff": 5,
"surface_number": -1,
}
params = copy.copy(TRACE_PARAMETERS)
params["atom_number"] = self._atom_spinbox.value()
params["surface_number"] = self._surface_spinbox.value()
with suppress(ValueError, TypeError):
Expand Down

0 comments on commit 570c1a1

Please sign in to comment.