diff --git a/MDANSE/Src/MDANSE/IO/MinimalPDBReader.py b/MDANSE/Src/MDANSE/IO/MinimalPDBReader.py index a2cb8e3ed..1d56f844d 100644 --- a/MDANSE/Src/MDANSE/IO/MinimalPDBReader.py +++ b/MDANSE/Src/MDANSE/IO/MinimalPDBReader.py @@ -14,56 +14,50 @@ # along with this program. If not, see . # +from collections.abc import Generator, Iterable from typing import List import numpy as np from ase.io import read as ase_read - -from MDANSE.MLogging import LOG -from MDANSE.Framework.Units import measure from MDANSE.Chemistry import ATOMS_DATABASE from MDANSE.Chemistry.ChemicalSystem import ChemicalSystem -from MDANSE.MolecularDynamics.Configuration import ( - RealConfiguration, - PeriodicRealConfiguration, -) -from MDANSE.MolecularDynamics.UnitCell import UnitCell - -atom_line_records = { - "record_name": (0, 6), - "atom_number": (6, 11), - "atom_name": (12, 16), - "location": (16, 17), - "residue_name": (17, 20), - "chain_id": (21, 22), - "residue_number": (22, 26), - "insertion_code": (26, 27), - "pos_x": (30, 38), - "pos_y": (38, 46), - "pos_z": (46, 54), - "occupancy": (54, 60), - "temperature_factor": (60, 66), - "element_symbol": (76, 78), - "charge": (78, 80), +from MDANSE.MLogging import LOG + +ATOM_LINE_RECORDS = { + "record_name": slice(0, 6), + "atom_number": slice(6, 11), + "atom_name": slice(12, 16), + "location": slice(16, 17), + "residue_name": slice(17, 20), + "chain_id": slice(21, 22), + "residue_number": slice(22, 26), + "insertion_code": slice(26, 27), + "pos_x": slice(30, 38), + "pos_y": slice(38, 46), + "pos_z": slice(46, 54), + "occupancy": slice(54, 60), + "temperature_factor": slice(60, 66), + "element_symbol": slice(76, 78), + "charge": slice(78, 80), } -def atom_line_slice(keyword: str) -> slice: - try: - limits = atom_line_records[keyword] - except KeyError: - return slice() - else: - return slice(limits[0], limits[1]) +class MinimalPDBReader: + """ + Basic parser for PDB format files. + Parameters + ---------- + filename : str + Input file to read. + """ -class MinimalPDBReader: def __init__(self, filename: str): self._unit_cell = None cell_params = self.find_unit_cell(filename) - if len(cell_params) == 0: - self.periodic = False - else: + self.periodic = False + + if cell_params: try: ase_atoms = ase_read(filename, format="pdb", index=0) cell = ase_atoms.get_cell() @@ -72,113 +66,121 @@ def __init__(self, filename: str): else: self.periodic = True self._unit_cell = np.vstack(cell) + self._chemical_system = ChemicalSystem(filename) atom_lines = self.find_atoms(filename) self.build_chemical_system(atom_lines) - def find_unit_cell(self, filename: str, frame_number: int = 0): - cell_found = False - cell_params = [] - cell_line = "" + def find_unit_cell(self, filename: str, frame_number: int = 0) -> List[float]: + """ + Find unit cell in PDB file. + + Parameters + ---------- + filename : str + File to parse. + frame_number : int + Unused, for interface compatibility. + + Returns + ------- + list[float] + Unit cell as a,b,c,α,β,γ. + """ fail_count = 0 + with open(filename, "r") as source: for line in source: - if "CRYST" in line[0:5]: - cell_found = True - cell_line += line + if line.startswith("CRYST"): + cell_line = line break - if "ENDMDL" in line[0:6]: + + if line.startswith("ENDMDL"): fail_count += 1 - if fail_count > 2: - cell_line = "" - cell_found = False - break - if not cell_found: - return cell_params - cell_params = [float(x) for x in cell_line.split()[1:7]] - return cell_params - def find_atoms(self, filename: str, frame_number: int = 0): - fail_count = 0 - result = [] + if fail_count > 2: + return [] + else: + return [] + + return [float(x) for x in cell_line.split()[1:7]] + + def find_atoms( + self, filename: str, _frame_number: int = 0 + ) -> Generator[str, None, None]: + """ + Get all items from PDB file. + + Parameters + ---------- + filename : str + File to read. + + Yields + ------ + str + Each atom line in PDB file. + """ with open(filename, "r") as source: for line in source: - if "ATOM" in line[0:5]: - result.append(line) - elif "HETATM" in line[0:6]: - result.append(line) - if "END" in line[0:3]: - fail_count += 1 - if fail_count > 0: + if line.startswith(("ATOM", "HETATM")): + yield line + + if line.startswith("END"): break - return result - def build_chemical_system(self, atom_lines: List[str]): - """Build the chemical system. + def build_chemical_system(self, atom_lines: Iterable[str]) -> None: + """ + Build a :class:`~MDANSE.Chemistry.ChemicalSystem.ChemicalSystem`. - Returns: - MDANSE.Chemistry.ChemicalSystem.ChemicalSystem: the chemical system + Parameters + ---------- + atom_lines : Iterable[str] + Lines to read. """ - coordinates = [] - element_slice = atom_line_slice("element_symbol") - name_slice = atom_line_slice("atom_name") - posx_slice = atom_line_slice("pos_x") - posy_slice = atom_line_slice("pos_y") - posz_slice = atom_line_slice("pos_z") - residue_slice = atom_line_slice("residue_name") - residue_number_slice = atom_line_slice("residue_number") + element_slice = ATOM_LINE_RECORDS["element_symbol"] + name_slice = ATOM_LINE_RECORDS["atom_name"] + pos_slice = (ATOM_LINE_RECORDS[x] for x in ("pos_x", "pos_y", "pos_z")) + residue_slice = ATOM_LINE_RECORDS["residue_name"] + residue_number_slice = ATOM_LINE_RECORDS["residue_number"] + coordinates = [] element_list = [] name_list = [] label_dict = {} clusters = {} for atom_number, atom_line in enumerate(atom_lines): - chemical_element = atom_line[element_slice].strip() + chemical_element = atom_line[element_slice].strip().capitalize() atom_name = atom_line[name_slice] processed_atom_name = atom_name[:2].strip() - if len(processed_atom_name) == 2: - if processed_atom_name[0].isnumeric(): - processed_atom_name = processed_atom_name[1].upper() - else: - processed_atom_name = ( - processed_atom_name[0].upper() + processed_atom_name[1].lower() - ) - if len(chemical_element) == 2: - chemical_element = ( - chemical_element[0].upper() + chemical_element[1].lower() - ) - backup_element = atom_line.rstrip().split()[-1] - backup_element2 = atom_line.split()[-2].strip() + + if len(processed_atom_name) == 2 and processed_atom_name[0].isnumeric(): + processed_atom_name = processed_atom_name[1] + processed_atom_name = processed_atom_name.capitalize() + + backup_elements = atom_line.split()[-2:] + if atom_name[-2:].isnumeric(): - backup_element3 = atom_name[0] - else: - backup_element3 = "fail" - if backup_element in ATOMS_DATABASE.atoms: - element_list.append(backup_element) - elif backup_element2 in ATOMS_DATABASE.atoms: - element_list.append(backup_element2) - elif backup_element3 in ATOMS_DATABASE.atoms: - element_list.append(backup_element3) - elif chemical_element in ATOMS_DATABASE.atoms: - element_list.append(chemical_element) - elif processed_atom_name in ATOMS_DATABASE.atoms: - element_list.append(processed_atom_name) + backup_elements.append(atom_name[0]) + + for trial in (*backup_elements, chemical_element, processed_atom_name): + if trial in ATOMS_DATABASE.atoms: + element_list.append(trial) + break else: LOG.warning(f"Dummy atom introduced from line {atom_line}") element_list.append("Du") - x, y, z = ( - atom_line[posx_slice], - atom_line[posy_slice], - atom_line[posz_slice], - ) - coordinates.append([float(aaa) for aaa in [x, y, z]]) + + coordinates.append([float(atom_line[pos]) for pos in pos_slice]) + residue_name = atom_line[residue_slice] - if residue_name not in label_dict.keys(): - label_dict[residue_name] = [] + label_dict.setdefault(residue_name, []) + label_dict[residue_name].append(atom_number) name_list.append(atom_name.strip()) + residue_number_string = atom_line[residue_number_slice] try: residue_number = int(residue_number_string) @@ -187,10 +189,11 @@ def build_chemical_system(self, atom_lines: List[str]): residue_number = int(residue_number_string, base=16) except ValueError: continue - if (residue_name, residue_number) in clusters.keys(): - clusters[(residue_name, residue_number)].append(atom_number) - else: - clusters[(residue_name, residue_number)] = [atom_number] + + idx = (residue_name, residue_number) + clusters.setdefault(idx, []) + clusters[idx].append(atom_number) + self._chemical_system.initialise_atoms(element_list, name_list) self._chemical_system.add_labels(label_dict) self._chemical_system.add_clusters(clusters.values())