Skip to content

Commit

Permalink
Fix all mypy issues (#341)
Browse files Browse the repository at this point in the history
* Fix all mypy issues

* Fix kaleido version

0.4.1 contains an issue that breaks the tests: plotly/Kaleido#223

* Fix tests

* Fix version limit
  • Loading branch information
stefsmeets authored Nov 18, 2024
1 parent c1888d8 commit 6f8375e
Show file tree
Hide file tree
Showing 14 changed files with 63 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ changelog = "https://github.com/GEMDAT-repos/GEMDAT/releases"

[project.optional-dependencies]
develop = [
"kaleido",
"kaleido < 0.4", # 0.4: https://github.com/plotly/Kaleido/issues/223
"bump-my-version",
"coverage[toml]",
"mypy",
Expand Down
2 changes: 1 addition & 1 deletion scripts/analyse_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def analyse_md(
"""
trajectory = Trajectory.from_vasprun(vasp_xml)

equilibration_steps = round(equil_time / trajectory.time_step)
equilibration_steps = round(equil_time / trajectory.time_step) # type: ignore

trajectory = trajectory[equilibration_steps:]

Expand Down
2 changes: 1 addition & 1 deletion src/gemdat/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def write_cif(structure: Structure, filename: Path | str):
filename : Path | str
Filename to write to
"""
filename = Path(filename).with_suffix('.cif')
filename = str(Path(filename).with_suffix('.cif'))
structure.to_file(filename)


Expand Down
2 changes: 1 addition & 1 deletion src/gemdat/jumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def site_pairs(self) -> list[tuple[str, str]]:
"""Return list of all unique site pairs."""
labels = self.sites.labels
site_pairs = product(labels, repeat=2)
return [pair for pair in site_pairs]
return [pair for pair in site_pairs] # type: ignore

@property
def jump_names(self) -> list[str]:
Expand Down
3 changes: 3 additions & 0 deletions src/gemdat/orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __post_init__(self, in_vectors: np.ndarray | None = None):
@property
def _time_step(self) -> float:
"""Return the time step of the trajectory."""
assert self.trajectory.time_step
return self.trajectory.time_step

@property
Expand All @@ -75,7 +76,9 @@ def _distances(self) -> np.ndarray:
"""Calculate distances between every central atom and all satellite
atoms."""
central_start_coord = self._trajectory_cent.base_positions
assert central_start_coord is not None
satellite_start_coord = self._trajectory_sat.base_positions
assert satellite_start_coord is not None
lattice = self.trajectory.get_lattice()
distance = np.array(
[
Expand Down
3 changes: 2 additions & 1 deletion src/gemdat/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@ def total_length(self, lattice: Lattice) -> FloatWithUnit:
length : FloatWithUnit
Total distance in Ångstrom
"""
length = 0
length = 0.0
for a, b in pairwise(self.frac_sites()):
dist, _ = lattice.get_distance_and_image(a, b)
assert dist
length += dist
return FloatWithUnit(length, 'ang')

Expand Down
3 changes: 3 additions & 0 deletions src/gemdat/plots/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
from pymatgen.core import Element, Species
from scipy.optimize import curve_fit
from scipy.stats import skewnorm

Expand All @@ -25,6 +26,8 @@ def _mean_displacements_per_element(

grouped = defaultdict(list)
for sp, distances in zip(species, trajectory.distances_from_base_position()):
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

grouped[sp.symbol].append(distances)

means = {}
Expand Down
7 changes: 5 additions & 2 deletions src/gemdat/plots/matplotlib/_msd_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import matplotlib.pyplot as plt
import numpy as np
from pymatgen.core import Element, Species

if TYPE_CHECKING:
import matplotlib.figure
Expand Down Expand Up @@ -41,6 +42,8 @@ def msd_per_element(
t_values = np.arange(len(trajectory)) * time_ps

for sp in species:
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

traj = trajectory.filter(sp.symbol)
msd = traj.mean_squared_displacement()

Expand All @@ -52,9 +55,9 @@ def msd_per_element(
last_color = ax.lines[-1].get_color()

if show_traces:
for i, traj in enumerate(msd):
for i, y_values in enumerate(msd):
label = f'{sp.symbol} trajectories' if (i == 0) else None
ax.plot(t_values, traj, lw=0.1, c=last_color, label=label)
ax.plot(t_values, y_values, lw=0.1, c=last_color, label=label)

if show_shaded:
ax.fill_between(
Expand Down
3 changes: 3 additions & 0 deletions src/gemdat/plots/plotly/_msd_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import plotly.graph_objects as go
from pymatgen.core import Element, Species

from gemdat.plots._shared import hex2rgba

Expand Down Expand Up @@ -31,6 +32,8 @@ def msd_per_element(*, trajectory: Trajectory) -> go.Figure:
species = list(set(trajectory.species))

for i, sp in enumerate(species):
assert isinstance(sp, (Species, Element)), f'got {type(sp)}'

color_hex = fig.layout['template']['layout']['colorway'][i]
color_rgba = hex2rgba(color_hex, opacity=0.3)

Expand Down
27 changes: 15 additions & 12 deletions src/gemdat/plots/plotly/_plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,35 +246,34 @@ def plot_jumps(jumps: Jumps, *, fig: go.Figure):
fig : plotly.graph_objects.Figure
Plotly figure to add traces too
"""
coords = jumps.sites.frac_coords
site_coords = jumps.sites.frac_coords
lattice = jumps.trajectory.get_lattice()

for i, j in zip(*np.triu_indices(len(coords), k=1)):
for i, j in zip(*np.triu_indices(len(site_coords), k=1)):
count = jumps.matrix()[i, j] + jumps.matrix()[j, i]
if count == 0:
continue

coord_i = tuple(coords[i].tolist())
coord_j = tuple(coords[j].tolist())
site_coord_i = tuple(site_coords[i].tolist())
site_coord_j = tuple(site_coords[j].tolist())

lw = 1 + np.log(count)

length, image = lattice.get_distance_and_image(coord_i, coord_j)
length, image = lattice.get_distance_and_image(site_coord_i, site_coord_j)

if np.any(image != 0):
lines = [(coord_i, coord_j + image), (coord_i - image, coord_j)]
lines = [(site_coord_i, site_coord_j + image), (site_coord_i - image, site_coord_j)]
else:
lines = [(coord_i, coord_j)]
lines = [(site_coord_i, site_coord_j)]

for line in lines:
line = lattice.get_cartesian_coords(line)
line_t = [_ for _ in zip(*line)] # transpose, but pythonic
x, y, z = lattice.get_cartesian_coords(line).T

fig.add_trace(
go.Scatter3d(
x=line_t[0],
y=line_t[1],
z=line_t[2],
x=x,
y=y,
z=z,
mode='lines',
showlegend=False,
line_dash='dashdot' if any(image) != 0 else 'solid',
Expand Down Expand Up @@ -356,6 +355,10 @@ def plot_3d(
lattice = structure.lattice
elif jumps:
lattice = jumps.trajectory.get_lattice()
else:
raise ValueError(
'Lattice cannot be determined form volume, structure, or jumps object.'
)
else:
raise ValueError('Cannot derive lattice from input.')

Expand Down
4 changes: 2 additions & 2 deletions src/gemdat/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def radial_distribution(
coords = trajectory.positions
sp_coords = trajectory.filter(floating_specie).positions

states2str = _get_states(sites.labels)
states_array = _get_states_array(transitions, sites.labels)
states2str = _get_states(sites.labels) # type: ignore
states_array = _get_states_array(transitions, sites.labels) # type: ignore
symbol_indices = _get_symbol_indices(base_structure)

bins = np.arange(0, max_dist + resolution, resolution)
Expand Down
5 changes: 3 additions & 2 deletions src/gemdat/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .utils import warn_lattice_not_close

if TYPE_CHECKING:
from pymatgen.symmetry.analyzer import SpacegroupOperations
from pymatgen.symmetry.groups import SpaceGroup
from pymatgen.symmetry.structure import SymmetrizedStructure

Expand Down Expand Up @@ -95,7 +96,7 @@ def __init__(
*,
sites: Collection[PeriodicSite],
lattice: Lattice,
spacegroup: SpaceGroup,
spacegroup: SpaceGroup | SpacegroupOperations,
):
"""Set up shape analyzer from a collection of unique periodic sites,
the lattice, and spacegroup.
Expand Down Expand Up @@ -400,7 +401,7 @@ def to_structure(self) -> Structure:
sg=self.spacegroup.int_number,
lattice=self.lattice,
species=[site.specie for site in self.sites],
coords=[site.frac_coords for site in self.sites],
coords=[site.frac_coords for site in self.sites], # type: ignore
labels=[site.label for site in self.sites],
)
return structure
29 changes: 22 additions & 7 deletions src/gemdat/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import TYPE_CHECKING, Collection, Optional

import numpy as np
from pymatgen.core import Element, Lattice
from pymatgen.core import Element, Lattice, Species
from pymatgen.core.trajectory import Trajectory as PymatgenTrajectory
from pymatgen.io import vasp

Expand Down Expand Up @@ -133,16 +133,19 @@ def to_volume(self, resolution: float = 0.2) -> Volume:
@property
def time_step_ps(self) -> float:
"""Return time step in picoseconds."""
assert self.time_step
return self.time_step * 1e12

@property
def total_time(self) -> float:
"""Return total time for trajectory."""
assert self.time_step
return len(self) * self.time_step

@property
def sampling_frequency(self) -> float:
"""Return number of time steps per second."""
assert self.time_step
return 1 / self.time_step

@property
Expand Down Expand Up @@ -469,9 +472,9 @@ def get_lattice(self, idx: int | None = None) -> Lattice:
Pymatgen Lattice object
"""
if self.constant_lattice:
return Lattice(self.lattice)
return Lattice(self.lattice) # type: ignore

latt = self.lattices[idx]
latt = self.lattices[idx] # type: ignore
return Lattice(latt)

@property
Expand Down Expand Up @@ -503,7 +506,10 @@ def distances_from_base_position(self) -> np.ndarray:

def center_of_mass(self) -> Trajectory:
"""Return trajectory with center of mass for positions."""
weights = [s.atomic_mass for s in self.species]
weights = []
for s in self.species:
assert isinstance(s, (Species, Element)), f'got {type(s)=}'
weights.append(s.atomic_mass)

positions_no_pbc = self.base_positions + self.cumulative_displacements

Expand Down Expand Up @@ -547,8 +553,13 @@ def drift(
if fixed_species:
displacements = self.filter(species=fixed_species).displacements
elif floating_species:
species = {sp.symbol for sp in self.species if sp.symbol not in floating_species}
displacements = self.filter(species=species).displacements
species = set()
for sp in self.species:
assert isinstance(sp, Species), f'got {type(sp)=}'
if sp.symbol not in floating_species:
species.add(sp)

displacements = self.filter(species=species).displacements # type: ignore
else:
displacements = self.displacements

Expand Down Expand Up @@ -609,7 +620,11 @@ def filter(self, species: str | Collection[str]) -> Trajectory:
if isinstance(species, str):
species = [species]

idx = [sp.symbol in species for sp in self.species]
idx = []
for sp in self.species:
assert isinstance(sp, (Species, Element))
idx.append(sp.symbol in species)

new_coords = self.positions[:, idx]
new_species = list(compress(self.species, idx))

Expand Down
3 changes: 1 addition & 2 deletions src/gemdat/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def from_volumetric_data(cls, volume: VolumetricData):
Input volumetric data
"""
return cls(
data=volume.data,
data=volume.data['total'],
lattice=volume.structure.lattice,
)

Expand Down Expand Up @@ -506,5 +506,4 @@ def trajectory_to_volume(
data=data,
lattice=lattice,
label='trajectory',
units=None,
)

0 comments on commit 6f8375e

Please sign in to comment.