Skip to content

Commit

Permalink
Merge pull request #347 from ISISNeutronMuon/gui-atom-tree
Browse files Browse the repository at this point in the history
Atom Properties speed-up in 3D view
  • Loading branch information
ChiCheng45 authored Mar 11, 2024
2 parents 905aed1 + 341212b commit cbd8d7d
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 136 deletions.
121 changes: 73 additions & 48 deletions MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/AtomProperties.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
#
# **************************************************************************

from typing import List

import numpy as np
from vtk.util.numpy_support import numpy_to_vtk
import vtk
from qtpy.QtGui import QStandardItemModel, QStandardItem, QColor
from qtpy.QtCore import Signal, Slot
from qtpy.QtCore import Signal, Slot, QObject

RGB_COLOURS = []
RGB_COLOURS.append((1.00, 0.20, 1.00)) # selection
RGB_COLOURS.append((1.00, 0.90, 0.90)) # background


def ndarray_to_vtkarray(colors, scales, n_atoms):
def ndarray_to_vtkarray(colors, scales, indices):
"""Convert the colors and scales NumPy arrays to vtk arrays.
Args:
Expand All @@ -31,30 +34,15 @@ def ndarray_to_vtkarray(colors, scales, n_atoms):
n_atoms (int): the number of atoms
"""
# define the colours
color_scalars = vtk.vtkFloatArray()
color_scalars.SetNumberOfValues(len(colors))
# print("colors")
for i, c in enumerate(colors):
# print(i,c)
color_scalars.SetValue(i, c)
color_scalars = numpy_to_vtk(colors)
color_scalars.SetName("colors")

# some scales
scales_scalars = vtk.vtkFloatArray()
scales_scalars.SetNumberOfValues(len(scales))
# print("scales")
for i, r in enumerate(scales):
scales_scalars.SetValue(i, r)
# print(i,r)
scales_scalars = numpy_to_vtk(scales)
scales_scalars.SetName("scales")

# the original index
index_scalars = vtk.vtkIntArray()
index_scalars.SetNumberOfValues(n_atoms)
# print("index")
for i in range(n_atoms):
index_scalars.SetValue(i, i)
# print(i,i)
index_scalars = numpy_to_vtk(indices)
index_scalars.SetName("index")

scalars = vtk.vtkFloatArray()
Expand All @@ -67,6 +55,34 @@ def ndarray_to_vtkarray(colors, scales, n_atoms):
return scalars


class AtomEntry(QObject):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._index = None
self._indices = []
self._items = []

def set_values(
self, name: str, indices: List[int], colour: QColor, size: float
) -> List[QStandardItem]:
self._indices = indices
name_item = QStandardItem(str(name))
name_item.setEditable(False)
colour_item = QStandardItem(str(colour))
size_item = QStandardItem(str(size))
self._items = [name_item, colour_item, size_item]
return self._items

def colour(self) -> QColor:
return QColor(self._items[1].text())

def size(self) -> float:
return float(self._items[2].text())

def indices(self) -> List[int]:
return self._indices


class AtomProperties(QStandardItemModel):
new_atom_properties = Signal(object)

Expand All @@ -79,8 +95,10 @@ def __init__(self, *args, init_colours: list = None, **kwargs):
else:
self._colour_list = init_colours
self.rebuild_colours()
self.setHorizontalHeaderLabels(["Index", "Element", "Radius", "Colour"])
self.setHorizontalHeaderLabels(["Element", "Colour", "Radius"])
self.itemChanged.connect(self.onNewValues)
self._groups = []
self._total_length = 0

def clear_table(self):
"""This was meant to be used for cleaning up,
Expand Down Expand Up @@ -138,37 +156,44 @@ def reinitialise_from_database(
list[int] -- a list of indices of colours, with one numbed per atom.
"""
self.removeRows(0, self.rowCount())

index_list = []
for nat, atom in enumerate(atoms):
row = []
self._groups = []
self._total_length = 0

all_atoms = np.array(atoms)
unique_atoms = np.unique(all_atoms)
indices = np.arange(len(all_atoms))
groups = {}
for unique in unique_atoms:
groups[unique] = indices[np.where(all_atoms == unique)]

colour_index_list = []
for atom in unique_atoms:
atom_entry = AtomEntry()
rgb = [int(x) for x in element_database[atom]["color"].split(";")]
index_list.append(self.add_colour(rgb))
row.append(QStandardItem(str(nat + 1))) # atom number
row.append(QStandardItem(atom)) # chemical element name
row.append(
QStandardItem(str(round(element_database[atom]["atomic_radius"], 2)))
colour_index_list.append(self.add_colour(rgb))
item_row = atom_entry.set_values(
atom,
groups[atom],
QColor(rgb[0], rgb[1], rgb[2]).name(QColor.NameFormat.HexRgb),
round(element_database[atom]["vdw_radius"], 2),
)
row.append(
QStandardItem(
QColor(rgb[0], rgb[1], rgb[2]).name(QColor.NameFormat.HexRgb)
)
)
self.appendRow(row)
return index_list
self.appendRow(item_row)
self._groups.append(atom_entry)
self._total_length = len(all_atoms)
return colour_index_list

@Slot()
def onNewValues(self):
colours = []
radii = []
numbers = []
for row in range(self.rowCount()):
colour = QColor(self.item(row, 3).text())
colours = np.empty(self._total_length, dtype=int)
radii = np.empty(self._total_length, dtype=float)
numbers = np.arange(self._total_length)
for entry in self._groups:
colour = entry.colour()
red, green, blue = colour.red(), colour.green(), colour.blue()
colours.append(self.add_colour((red, green, blue)))
radius = float(self.item(row, 2).text())
radii.append(radius)
numbers.append(int(self.item(row, 0).text()))

scalars = ndarray_to_vtkarray(colours, radii, len(numbers))
vtk_colour = self.add_colour((red, green, blue))
radius = entry.size()
indices = entry.indices()
radii[indices] = radius
colours[indices] = vtk_colour
scalars = ndarray_to_vtkarray(colours, radii, numbers)
self.new_atom_properties.emit(scalars)
31 changes: 28 additions & 3 deletions MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/Controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
QPushButton,
QStyle,
QSizePolicy,
QTableView,
QTreeView,
QDoubleSpinBox,
QColorDialog,
QGroupBox,
QCheckBox,
)
from qtpy.QtGui import (
QDoubleValidator,
Expand All @@ -39,7 +40,6 @@
)

from MDANSE_GUI.MolecularViewer.MolecularViewer import MolecularViewer
from MDANSE_GUI.MolecularViewer.Contents import TrajectoryAtomData

button_lookup = {
"start": QStyle.StandardPixmap.SP_MediaSkipBackward,
Expand Down Expand Up @@ -125,6 +125,7 @@ def __init__(self, *args, **kwargs):
self._frame_step = 1
self._time_per_frame = 80 # in ms
self._frame_factor = 1 # just a scalar multiplication factor
self._visibility = [True, True, True, True]
self.createSlider()
self.createButtons(Qt.Orientation.Horizontal)
self.createSidePanel()
Expand Down Expand Up @@ -200,7 +201,7 @@ def createSidePanel(self):
# the table of chemical elements
wrapper1 = QGroupBox("Atom properties", base)
layout1 = QHBoxLayout(wrapper1)
self._atom_details = QTableView(base)
self._atom_details = QTreeView(base)
self._atom_details.setSizePolicy(
QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Maximum
)
Expand Down Expand Up @@ -244,10 +245,34 @@ def createSidePanel(self):
layout4.addWidget(size_factor)
wrapper4.setLayout(layout4)
layout.addWidget(wrapper4)
wrapper5 = QGroupBox("Visible Objects", base)
layout5 = QHBoxLayout(wrapper5)
atoms_visible = QCheckBox("atoms", wrapper5)
bonds_visible = QCheckBox("bonds", wrapper5)
axes_visible = QCheckBox("axes", wrapper5)
cell_visible = QCheckBox("cell", wrapper5)
self._visibility_checkboxes = [
atoms_visible,
bonds_visible,
axes_visible,
cell_visible,
]
for nw, box in enumerate(self._visibility_checkboxes):
box.setTristate(False)
box.setChecked(self._visibility[nw])
box.stateChanged.connect(self.setVisibility)
layout5.addWidget(box)
layout.addWidget(wrapper5)
# the database of atom types
# self._database = TrajectoryAtomData()
self.layout().addWidget(base, 0, 2, 2, 1) # row, column, rowSpan, columnSpan

@Slot()
def setVisibility(self):
for nw, box in enumerate(self._visibility_checkboxes):
self._visibility[nw] = box.isChecked()
self._viewer._new_visibility(self._visibility)

@Slot(int)
def setTimeStep(self, new_value: int):
self._time_per_frame = new_value
Expand Down
Loading

0 comments on commit cbd8d7d

Please sign in to comment.