Skip to content

Commit

Permalink
Replace for loops with vtk helper functions in MolecularViewer
Browse files Browse the repository at this point in the history
  • Loading branch information
MBartkowiakSTFC committed Mar 8, 2024
1 parent f29beef commit 8e92cd8
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 61 deletions.
120 changes: 72 additions & 48 deletions MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/AtomProperties.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
#
# **************************************************************************

from typing import List

import numpy as np
from vtk.util.numpy_support import numpy_to_vtk
import vtk
from qtpy.QtGui import QStandardItemModel, QStandardItem, QColor
from qtpy.QtCore import Signal, Slot
from qtpy.QtCore import Signal, Slot, QObject

RGB_COLOURS = []
RGB_COLOURS.append((1.00, 0.20, 1.00)) # selection
RGB_COLOURS.append((1.00, 0.90, 0.90)) # background


def ndarray_to_vtkarray(colors, scales, n_atoms):
def ndarray_to_vtkarray(colors, scales, indices):
"""Convert the colors and scales NumPy arrays to vtk arrays.
Args:
Expand All @@ -31,30 +34,15 @@ def ndarray_to_vtkarray(colors, scales, n_atoms):
n_atoms (int): the number of atoms
"""
# define the colours
color_scalars = vtk.vtkFloatArray()
color_scalars.SetNumberOfValues(len(colors))
# print("colors")
for i, c in enumerate(colors):
# print(i,c)
color_scalars.SetValue(i, c)
color_scalars = numpy_to_vtk(colors)
color_scalars.SetName("colors")

# some scales
scales_scalars = vtk.vtkFloatArray()
scales_scalars.SetNumberOfValues(len(scales))
# print("scales")
for i, r in enumerate(scales):
scales_scalars.SetValue(i, r)
# print(i,r)
scales_scalars = numpy_to_vtk(scales)
scales_scalars.SetName("scales")

# the original index
index_scalars = vtk.vtkIntArray()
index_scalars.SetNumberOfValues(n_atoms)
# print("index")
for i in range(n_atoms):
index_scalars.SetValue(i, i)
# print(i,i)
index_scalars = numpy_to_vtk(indices)
index_scalars.SetName("index")

scalars = vtk.vtkFloatArray()
Expand All @@ -67,6 +55,33 @@ def ndarray_to_vtkarray(colors, scales, n_atoms):
return scalars


class AtomEntry(QObject):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._index = None
self._indices = []
self._items = []

def set_values(
self, name: str, indices: List[int], colour: QColor, size: float
) -> List[QStandardItem]:
self._indices = indices
name_item = QStandardItem(str(name))
colour_item = QStandardItem(str(colour))
size_item = QStandardItem(str(size))
self._items = [name_item, colour_item, size_item]
return self._items

def colour(self) -> QColor:
return QColor(self._items[1].text())

def size(self) -> float:
return float(self._items[2].text())

def indices(self) -> List[int]:
return self._indices


class AtomProperties(QStandardItemModel):
new_atom_properties = Signal(object)

Expand All @@ -79,8 +94,10 @@ def __init__(self, *args, init_colours: list = None, **kwargs):
else:
self._colour_list = init_colours
self.rebuild_colours()
self.setHorizontalHeaderLabels(["Index", "Element", "Radius", "Colour"])
self.setHorizontalHeaderLabels(["Element", "Colour", "Radius"])
self.itemChanged.connect(self.onNewValues)
self._groups = []
self._total_length = 0

def clear_table(self):
"""This was meant to be used for cleaning up,
Expand Down Expand Up @@ -138,37 +155,44 @@ def reinitialise_from_database(
list[int] -- a list of indices of colours, with one numbed per atom.
"""
self.removeRows(0, self.rowCount())

index_list = []
for nat, atom in enumerate(atoms):
row = []
self._groups = []
self._total_length = 0

all_atoms = np.array(atoms)
unique_atoms = np.unique(all_atoms)
indices = np.arange(len(all_atoms))
groups = {}
for unique in unique_atoms:
groups[unique] = indices[np.where(all_atoms == unique)]

colour_index_list = []
for atom in unique_atoms:
atom_entry = AtomEntry()
rgb = [int(x) for x in element_database[atom]["color"].split(";")]
index_list.append(self.add_colour(rgb))
row.append(QStandardItem(str(nat + 1))) # atom number
row.append(QStandardItem(atom)) # chemical element name
row.append(
QStandardItem(str(round(element_database[atom]["atomic_radius"], 2)))
colour_index_list.append(self.add_colour(rgb))
item_row = atom_entry.set_values(
atom,
groups[atom],
QColor(rgb[0], rgb[1], rgb[2]).name(QColor.NameFormat.HexRgb),
round(element_database[atom]["vdw_radius"], 2),
)
row.append(
QStandardItem(
QColor(rgb[0], rgb[1], rgb[2]).name(QColor.NameFormat.HexRgb)
)
)
self.appendRow(row)
return index_list
self.appendRow(item_row)
self._groups.append(atom_entry)
self._total_length = len(all_atoms)
return colour_index_list

@Slot()
def onNewValues(self):
colours = []
radii = []
numbers = []
for row in range(self.rowCount()):
colour = QColor(self.item(row, 3).text())
colours = np.empty(self._total_length, dtype=int)
radii = np.empty(self._total_length, dtype=float)
numbers = np.arange(self._total_length)
for entry in self._groups:
colour = entry.colour()
red, green, blue = colour.red(), colour.green(), colour.blue()
colours.append(self.add_colour((red, green, blue)))
radius = float(self.item(row, 2).text())
radii.append(radius)
numbers.append(int(self.item(row, 0).text()))

scalars = ndarray_to_vtkarray(colours, radii, len(numbers))
vtk_colour = self.add_colour((red, green, blue))
radius = entry.size()
indices = entry.indices()
radii[indices] = radius
colours[indices] = vtk_colour
scalars = ndarray_to_vtkarray(colours, radii, numbers)
self.new_atom_properties.emit(scalars)
4 changes: 2 additions & 2 deletions MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/Controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
QPushButton,
QStyle,
QSizePolicy,
QTableView,
QTreeView,
QDoubleSpinBox,
QColorDialog,
QGroupBox,
Expand Down Expand Up @@ -200,7 +200,7 @@ def createSidePanel(self):
# the table of chemical elements
wrapper1 = QGroupBox("Atom properties", base)
layout1 = QHBoxLayout(wrapper1)
self._atom_details = QTableView(base)
self._atom_details = QTreeView(base)
self._atom_details.setSizePolicy(
QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Maximum
)
Expand Down
13 changes: 2 additions & 11 deletions MDANSE_GUI/Src/MDANSE_GUI/MolecularViewer/MolecularViewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,18 +636,9 @@ def set_reader(self, reader, frame=0):
)
# this returs a list of indices, mapping colours to atoms

self._atom_scales = np.array(
[CHEMICAL_ELEMENTS[at]["vdw_radius"] for at in self._atoms]
).astype(np.float32)
self._current_frame = frame

scalars = ndarray_to_vtkarray(
self._atom_colours, self._atom_scales, self._n_atoms
)

self._polydata = vtk.vtkPolyData()
self._polydata.GetPointData().SetScalars(scalars)

self.set_coordinates(frame)
self._colour_manager.onNewValues()

@Slot(object)
def take_atom_properties(self, scalars):
Expand Down

0 comments on commit 8e92cd8

Please sign in to comment.