diff --git a/MDANSE/Src/MDANSE/Trajectory/H5MDTrajectory.py b/MDANSE/Src/MDANSE/Trajectory/H5MDTrajectory.py index 2684e54bf..351de1208 100644 --- a/MDANSE/Src/MDANSE/Trajectory/H5MDTrajectory.py +++ b/MDANSE/Src/MDANSE/Trajectory/H5MDTrajectory.py @@ -52,19 +52,36 @@ def __init__(self, h5_filename): self._h5_filename = h5_filename self._h5_file = h5py.File(self._h5_filename, "r") - - # Load the chemical system - try: + particle_types = self._h5_file["/particles/all/species"] + particle_lookup = h5py.check_enum_dtype( + self._h5_file["/particles/all/species"].dtype + ) + if particle_lookup is None: + # Load the chemical system + try: + symbols = self._h5_file["/parameters/atom_symbols"] + except KeyError: + LOG.error( + f"No information about chemical elements in {self._h5_filename}" + ) + return + else: + chemical_elements = [byte.decode() for byte in symbols] + else: + reverse_lookup = {item: key for key, item in particle_lookup.items()} chemical_elements = [ - byte.decode() for byte in self._h5_file["/parameters/atom_symbols"] + reverse_lookup[type_number] for type_number in particle_types ] - except KeyError: - chemical_elements = self._h5_file["/particles/all/species"] self._chemical_system = ChemicalSystem( os.path.splitext(os.path.basename(self._h5_filename))[0] ) - self._chemical_system.initialise_atoms(chemical_elements) - + try: + self._chemical_system.initialise_atoms(chemical_elements) + except Exception: + LOG.error( + "It was not possible to read chemical element information from an H5MD file." + ) + return # Load all the unit cells self._load_unit_cells() @@ -72,10 +89,10 @@ def __init__(self, h5_filename): coords = self._h5_file["/particles/all/position/value"][0, :, :] try: pos_unit = self._h5_file["/particles/all/position/value"].attrs["unit"] - except: + except Exception: conv_factor = 1.0 else: - if pos_unit == "Ang": + if pos_unit == "Ang" or pos_unit == "Angstrom": pos_unit = "ang" conv_factor = measure(1.0, pos_unit).toval("nm") coords *= conv_factor @@ -94,6 +111,7 @@ def file_is_right(self, filename): temp["h5md"] except KeyError: result = False + temp.close() return result def close(self): @@ -117,7 +135,7 @@ def __getitem__(self, frame): except: conv_factor = 1.0 else: - if pos_unit == "Ang": + if pos_unit == "Ang" or pos_unit == "Angstrom": pos_unit = "ang" conv_factor = measure(1.0, pos_unit).toval("nm") configuration = {} @@ -195,7 +213,7 @@ def coordinates(self, frame): except: conv_factor = 1.0 else: - if pos_unit == "Ang": + if pos_unit == "Ang" or pos_unit == "Angstrom": pos_unit = "ang" conv_factor = measure(1.0, pos_unit).toval("nm") @@ -245,10 +263,10 @@ def _load_unit_cells(self): self._unit_cells = [] try: box_unit = self._h5_file["/particles/all/box/edges/value"].attrs["unit"] - except: - conv_factor = 1.0 + except (AttributeError, KeyError): + conv_factor = 0.1 else: - if box_unit == "Ang": + if box_unit == "Ang" or box_unit == "Angstrom": box_unit = "ang" conv_factor = measure(1.0, box_unit).toval("nm") try: @@ -258,9 +276,16 @@ def _load_unit_cells(self): else: if len(cells.shape) > 1: for cell in cells: - temp_array = np.array( - [[cell[0], 0.0, 0.0], [0.0, cell[1], 0.0], [0.0, 0.0, cell[2]]] - ) + if cell.shape == (3, 3): + temp_array = np.array(cell) + else: + temp_array = np.array( + [ + [cell[0], 0.0, 0.0], + [0.0, cell[1], 0.0], + [0.0, 0.0, cell[2]], + ] + ) uc = UnitCell(temp_array) self._unit_cells.append(uc) else: @@ -272,14 +297,17 @@ def _load_unit_cells(self): def time(self): try: time_unit = self._h5_file["/particles/all/position/time"].attrs["unit"] - except: + except KeyError: conv_factor = 1.0 else: conv_factor = measure(1.0, time_unit).toval("ps") try: time = self._h5_file["/particles/all/position/time"] * conv_factor - except: - time = [] + except TypeError: + try: + time = self._h5_file["/particles/all/position/time"][:] * conv_factor + except Exception: + time = [] return time def unit_cell(self, frame): @@ -371,7 +399,7 @@ def read_com_trajectory( except: conv_factor = 1.0 else: - if pos_unit == "Ang": + if pos_unit == "Ang" or pos_unit == "Angstrom": pos_unit = "ang" conv_factor = measure(1.0, pos_unit).toval("nm") @@ -469,7 +497,7 @@ def read_atomic_trajectory( except: conv_factor = 1.0 else: - if pos_unit == "Ang": + if pos_unit == "Ang" or pos_unit == "Angstrom": pos_unit = "ang" conv_factor = measure(1.0, pos_unit).toval("nm") coords = grp[first:last:step, index, :].astype(np.float64) * conv_factor diff --git a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/MoleculePreviewWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/MoleculePreviewWidget.py new file mode 100644 index 000000000..ef8ddbbfd --- /dev/null +++ b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/MoleculePreviewWidget.py @@ -0,0 +1,147 @@ +import numpy as np + +from qtpy.QtWidgets import QWidget, QSizePolicy, QVBoxLayout, QLabel, QDialog +from qtpy.Qt3DRender import QDirectionalLight +from qtpy.QtGui import QColor, QVector3D, QQuaternion, QFont +from qtpy.Qt3DExtras import ( + QPhongMaterial, + QCylinderMesh, + QSphereMesh, + Qt3DWindow, + QOrbitCameraController, +) +from qtpy.QtCore import Qt as _Qt +from qtpy.Qt3DCore import QEntity, QTransform + + +class MoleculePreviewWidget(QDialog): + def __init__(self, parent, molecule_information, molecule_name, atom_database): + super().__init__(parent) + self.setWindowTitle("Molecule Preview") + self.resize(800, 600) + self.view = Qt3DWindow() + self.view.defaultFrameGraph().setClearColor( + QColor(0x4D4D4F) + ) # molecular viewer mdansechemistry atoms.json + container = QWidget.createWindowContainer( + self.view + ) # from mdanse chemistry atoms database + screenSize = self.view.screen().size() + container.setMinimumSize(200, 100) + container.setMaximumSize(screenSize) + container.setSizePolicy( + QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding + ) + layout = QVBoxLayout() + layout.addWidget(container) + self.rootEntity = QEntity() + self.cuboidTransform = QTransform() + self.axes = [] + mass = [] + coords = [] + info_text = f"Molecule name: {molecule_name}\n" + for key, value in molecule_information["atom_number"].items(): + info_text += f"Number of {key} atoms: {value}\n" + + info_text += f"Number of such molecules in trajectory: {molecule_information['no_of_molecules']}\n" + + coordinates = molecule_information["atom_coordinates"] + indices = molecule_information["atom_indices"] + atom_symbols = molecule_information["atom_symbols"] + bonds = molecule_information["bond_list"] + for at_number in range(len(coordinates)): + x, y, z = coordinates[at_number] + x, y, z = (20 * x - 10, 20 * y - 10, 20 * z - 10) + symbol = atom_symbols[at_number] + colour = atom_database.get_atom_property(symbol, "color") + radius = atom_database.get_atom_property(symbol, "covalent_radius") + mass.append(atom_database.get_atom_property(symbol, "atomic_weight")) + coords.append(coordinates[at_number]) + r, g, b = [int(x) for x in colour.split(";")] + colour = QColor(r, g, b) + m_sphereEntity = QEntity(self.rootEntity) + sphereMesh = QSphereMesh() + sphereMesh.setRings(20) + sphereMesh.setSlices(20) + sphereMesh.setRadius(radius * 10) + sphereTransform = QTransform() + sphereTransform.setScale(1.0) + sphereTransform.setTranslation(QVector3D(x, y, z)) + sphereMaterial = QPhongMaterial() + sphereMaterial.setDiffuse(colour) + sphereMaterial.setAmbient(colour) # Set ambient to the same as diffuse. + # sphereMaterial.setSpecular(QColor(0, 0, 0)) # Eliminate specular reflection. + # sphereMaterial.setShininess(0.0) # Eliminate shininess. + m_sphereEntity.addComponent(sphereMesh) + m_sphereEntity.addComponent(sphereMaterial) + m_sphereEntity.addComponent(sphereTransform) + + atom_information = molecule_information["atom_information"] + for bond in bonds: + coord1, coord2 = bond[0], bond[1] + coord1 = (20 * coord1[0] - 10, 20 * coord1[1] - 10, 20 * coord1[2] - 10) + coord2 = (20 * coord2[0] - 10, 20 * coord2[1] - 10, 20 * coord2[2] - 10) + direction = QVector3D( + coord2[0] - coord1[0], coord2[1] - coord1[1], coord2[2] - coord1[2] + ) + length = direction.length() + direction.normalize() + # Compute rotation + up_vector = QVector3D(0, 1, 0) + axis = QVector3D.crossProduct(up_vector, direction) + angle = float( + np.degrees(np.arccos(QVector3D.dotProduct(up_vector, direction))) + ) + + # Create cylinder mesh + cylinder_mesh = QCylinderMesh() + cylinder_mesh.setRadius(radius) + cylinder_mesh.setLength(length) + + # Create material + material = QPhongMaterial() + # material.setDiffuse(QColor(color)) + + # Set transformation + transform = QTransform() + transform.setTranslation(QVector3D(*coord1) + direction * length / 2) + transform.setRotation(QQuaternion.fromAxisAndAngle(axis, angle)) + + # Create entity + entity = QEntity(self.rootEntity) + entity.addComponent(cylinder_mesh) + entity.addComponent(material) + entity.addComponent(transform) + + info_label = QLabel(info_text) + font = QFont("Arial", 12) + info_label.setFont(font) + info_label.setWordWrap(True) + layout.addWidget(info_label) + self.setLayout(layout) + mass = np.array(mass) + coords = np.array(coords) + com = np.einsum("i,ik->k", mass, coords) / np.sum(mass) + x_com, y_com, z_com = 20 * com - 10 + # Camera + self.camera = self.view.camera() + self.camera.lens().setPerspectiveProjection(45.0, 16.0 / 9.0, 0.1, 1000.0) + self.camera.setPosition(QVector3D(x_com + 5, y_com + 5, z_com + 10)) + self.camera.setViewCenter(QVector3D(x_com, y_com, z_com)) + self.camera.setUpVector(QVector3D(0.0, 0.0, 1.0)) + # add light + lightEntity = QEntity(self.rootEntity) + light = QDirectionalLight(lightEntity) + light.setColor(_Qt.white) + light.setIntensity(1) + lightEntity.addComponent(light) + lightTransform = QTransform(lightEntity) + lightTransform.setTranslation(self.camera.position()) + lightEntity.addComponent(lightTransform) + # For camera controls + camController = QOrbitCameraController(self.rootEntity) + camController.setLinearSpeed(-20) + camController.setLookSpeed(-90) + camController.setCamera(self.camera) + self.view.setRootEntity(self.rootEntity) + # # info_box.setStandardButtons(QMessageBox.Close) diff --git a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/MoleculeWidget.py b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/MoleculeWidget.py index 8d5a691eb..910ac1758 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/MoleculeWidget.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/InputWidgets/MoleculeWidget.py @@ -13,9 +13,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -from qtpy.QtWidgets import QComboBox +import numpy as np + +from qtpy.QtCore import Slot +from qtpy.QtWidgets import QComboBox, QPushButton from MDANSE_GUI.InputWidgets.WidgetBase import WidgetBase +from MDANSE_GUI.InputWidgets.MoleculePreviewWidget import MoleculePreviewWidget class MoleculeWidget(WidgetBase): @@ -36,24 +40,98 @@ def __init__(self, *args, **kwargs): else: option_list = configurator.choices default_option = configurator.default - field = QComboBox(self._base) - field.addItems(option_list) - field.setCurrentText(default_option) - field.currentTextChanged.connect(self.updateValue) + traj_config = self._configurator._configurable[ + self._configurator._dependencies["trajectory"] + ] + hdf_traj = traj_config["hdf_trajectory"] + unique_molecules = hdf_traj.chemical_system.unique_molecules() + traj_bond_list = hdf_traj.chemical_system._bonds + self.atom_database = hdf_traj + self.mol_dict = {} + coords_0 = hdf_traj.trajectory.coordinates(0) + for mol_name in unique_molecules: + no_of_molecules = len(hdf_traj.chemical_system._clusters[mol_name]) + atom_indices = hdf_traj.chemical_system._clusters[mol_name][0] + atom_symbols = [ + hdf_traj.chemical_system.atom_list[index] for index in atom_indices + ] + coordinates = coords_0[atom_indices] + unique_atoms, atom_counts = np.unique(atom_symbols, return_counts=True) + atom_counts = { + unique_atoms[n]: atom_counts[n] for n in range(len(unique_atoms)) + } + bonds = [ + (coords_0[bond[0]], coords_0[bond[1]]) + for bond in traj_bond_list + if bond[0] in atom_indices or bond[1] in atom_indices + ] + self.mol_dict[mol_name] = { + "no_of_molecules": no_of_molecules, + "atom_coordinates": coordinates, + "atom_number": atom_counts, + "atom_indices": atom_indices, + "atom_symbols": atom_symbols, + "bond_list": bonds, + } + + self.field = QComboBox(self._base) + self.field.addItems(option_list) + self.field.setCurrentText(default_option) + self.selected_name = self.field.currentText() + if self.selected_name in self.mol_dict.keys(): + self.selected_mol = self.mol_dict[self.selected_name] + else: + self.selected_mol = None + self.field.currentTextChanged.connect(self.updateValue) + self.field.currentTextChanged.connect(self.molecule_changed) + button = QPushButton(self._base) + button.setText("Molecule Preview") + button.clicked.connect(self.button_clicked) if self._tooltip: tooltip_text = self._tooltip else: tooltip_text = ( "A single option can be picked out of all the options listed." ) - field.setToolTip(tooltip_text) - self._field = field - self._layout.addWidget(field) + self.field.setToolTip(tooltip_text) + self._field = self.field + self._layout.addWidget(self.field) + self._layout.addWidget(button) self._configurator = configurator self.default_labels() self.update_labels() self.updateValue() + @Slot() + def molecule_changed(self): + """ + Change molecule preview and molecule information + """ + self.selected_name = self.field.currentText() + try: + self.selected_mol = self.mol_dict[self.selected_name] + except KeyError: + self.selected_mol = None + else: + self.window = MoleculePreviewWidget( + self._base, self.selected_mol, self.selected_name, self.atom_database + ) + + @Slot() + def button_clicked(self): + """ + Opens a window that shows a preview of selected molecule + """ + if self.selected_mol is None: + return + self.window = MoleculePreviewWidget( + self._base, self.selected_mol, self.selected_name, self.atom_database + ) + if self.window.isVisible(): + self.window.close() + else: + self.window.show() + def configure_using_default(self): """This is too complex to have a default value""" @@ -68,4 +146,6 @@ def default_labels(self): self._tooltip = "You only have one option. Choose wisely." def get_widget_value(self): - return self._field.currentText() + mol_key = self._field.currentText() + if mol_key in self.mol_dict.keys(): + return mol_key diff --git a/MDANSE_GUI/Src/MDANSE_GUI/TabbedWindow.py b/MDANSE_GUI/Src/MDANSE_GUI/TabbedWindow.py index af0e29d32..0a1232eec 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/TabbedWindow.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/TabbedWindow.py @@ -47,6 +47,7 @@ from MDANSE_GUI.Tabs.PlotSelectionTab import PlotSelectionTab from MDANSE_GUI.Tabs.PlotTab import PlotTab from MDANSE_GUI.Tabs.InstrumentTab import InstrumentTab +from MDANSE_GUI.Tabs.Views.PlotDataView import PlotDataView from MDANSE_GUI.Widgets.StyleDialog import StyleDialog, StyleDatabase from MDANSE_GUI.Widgets.NotificationTabWidget import NotificationTabWidget @@ -109,10 +110,10 @@ def __init__( self._tabs["Plot Creator"]._visualiser.create_new_text.connect( self._tabs["Plot Holder"]._visualiser.new_text ) - self._tabs["Instruments"]._visualiser.instrument_details_changed.connect( self._tabs["Actions"].update_action_after_instrument_change ) + self.tabs.currentChanged.connect(self.tabs.reset_current_color) def createCommonModels(self): @@ -326,6 +327,11 @@ def createPlotSelection(self): self._tabs[name] = plot_tab self._job_holder.results_for_loading.connect(plot_tab.load_results) self._job_holder.results_for_loading.connect(plot_tab.tab_notification) + plot_tab._view.fast_plotting_data.connect(self.accept_external_data) + + def accept_external_data(self, model): + self._tabs["Plot Creator"]._visualiser.new_plot() + self._tabs["Plot Holder"].accept_external_data(model) def createPlotHolder(self): name = "Plot Holder" diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/JobTab.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/JobTab.py index 937165f34..4ccea7aee 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/JobTab.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/JobTab.py @@ -28,6 +28,7 @@ from MDANSE_GUI.Tabs.Visualisers.TextInfo import TextInfo from MDANSE_GUI.Tabs.Models.JobTree import JobTree from MDANSE_GUI.Tabs.Views.ActionsTree import ActionsTree +from MDANSE_GUI.InputWidgets.MoleculeWidget import MoleculeWidget job_tab_label = """This is the list of jobs diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py index 5ea30bc5e..a05991230 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Models/PlottingContext.py @@ -57,11 +57,12 @@ def get_mpl_colours(): class SingleDataset: - def __init__(self, name: str, source: "h5py.File"): + def __init__(self, name: str, source: "h5py.File", linestyle: str = "-"): self._name = name self._filename = source.filename self._curves = {} self._curve_labels = {} + self._linestyle = linestyle self._planes = {} self._plane_labels = {} self._data_limits = None @@ -425,7 +426,7 @@ def add_dataset(self, new_dataset: SingleDataset): new_dataset.longest_axis()[-1], "", self.next_colour(), - "-", + new_dataset._linestyle, "", ] ] diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/PlotSelectionTab.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/PlotSelectionTab.py index 3d798ac44..b3b72c169 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/PlotSelectionTab.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/PlotSelectionTab.py @@ -33,6 +33,7 @@ data sets. Load the files and assign the data sets to a plot. The plotting interface will appear in a new tab of the interface. +DOUBLE CLICK FILE FOR FAST PLOTTING """ diff --git a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py index 4aff58457..3c58b69fd 100644 --- a/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py +++ b/MDANSE_GUI/Src/MDANSE_GUI/Tabs/Views/PlotDataView.py @@ -14,12 +14,13 @@ # along with this program. If not, see . # from typing import Union - +import time from qtpy.QtWidgets import QTreeView, QAbstractItemView, QApplication, QMenu from qtpy.QtCore import Signal, Slot, QModelIndex, Qt, QMimeData from qtpy.QtGui import QMouseEvent, QDrag, QContextMenuEvent, QStandardItem from MDANSE_GUI.Tabs.Visualisers.DataPlotter import DataPlotter +from MDANSE_GUI.Tabs.Models.PlottingContext import PlottingContext, SingleDataset from MDANSE_GUI.Tabs.Visualisers.PlotDataInfo import PlotDataInfo from MDANSE_GUI.Widgets.DataDialog import DataDialog @@ -29,6 +30,7 @@ class PlotDataView(QTreeView): execute_action = Signal(object) item_details = Signal(object) error = Signal(str) + fast_plotting_data = Signal(object) free_name = Signal(str) def __init__(self, *args, **kwargs): @@ -40,6 +42,30 @@ def __init__(self, *args, **kwargs): # self.data_dialog = DataDialog(self) self._data_packet = None + def mouseDoubleClickEvent(self, e: QMouseEvent) -> None: + self.click_position = e.position() + if self.model() is None: + return None + index = self.indexAt(e.pos()) + model = self.model() + mda_data_structure = model.inner_object(index) + model = PlottingContext() + for key in mda_data_structure._file.keys(): + try: + if "main" in mda_data_structure._file[key].attrs["tags"]: + if "partial" in mda_data_structure._file[key].attrs["tags"]: + dataset = SingleDataset( + key, mda_data_structure._file, linestyle="--" + ) + else: + dataset = SingleDataset(key, mda_data_structure._file) + + model.add_dataset(dataset) + except KeyError: + print(f"No attribute called Tag found in {key}, skipping") + self.fast_plotting_data.emit(model) + return super().mouseDoubleClickEvent(e) + def mousePressEvent(self, e: QMouseEvent) -> None: self.click_position = e.position() if self.model() is None: