Skip to content

Commit

Permalink
Refactor rdf between species
Browse files Browse the repository at this point in the history
  • Loading branch information
stefsmeets committed Nov 11, 2024
1 parent a852705 commit 966e4a7
Showing 1 changed file with 57 additions and 22 deletions.
79 changes: 57 additions & 22 deletions src/gemdat/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from rich.progress import track

from ._plot_backend import plot_backend

if TYPE_CHECKING:
from typing import Collection

import matplotlib.figure
from pymatgen.core import Structure

from gemdat import Trajectory
from gemdat.transitions import Transitions


Expand Down Expand Up @@ -187,39 +192,69 @@ def radial_distribution(
return ret


def plot_rdf_between_species(
trajectory: Trajectory,
specie_1: str | Collection[str],
specie_2: str | Collection[str],
max_dist: float = 5.0,
resolution: float = 0.1,
) -> matplotlib.figure.Figure:
"""Calculate RDFs from specie_1 to specie_2.
import matplotlib.pyplot as plt
import numpy as np
from joblib import Parallel, delayed
Parameters
----------
trajectory: Trajectory
Input trajectory.
specie_1: str | list[str]
Name of specie or list of species
specie_2: str | list[str]
Name of specie or list of species
max_dist: float, optional
Max distance for rdf calculation
resolution: float, optional
Width of the bins
def calculate_rdf_parallelized(trajectory, specie_1, specie_2, resolution=1, max_distance=10.0):
'''Calculate RDFs from specie_1 to specie_2'''
Returns
-------
fig : matplotlib.figure.Figure
Output figure
"""
coords_1 = trajectory.filter(specie_1).coords
coords_2 = trajectory.filter(specie_2).coords
lattice = trajectory.get_lattice()

try:
num_time_steps, num_atoms, num_dimensions = coords_2.shape
except ValueError:
if coords_2.ndim == 2:
num_time_steps = 1
num_atoms, num_dimensions = coords_2.shape
else:
num_time_steps, num_atoms, num_dimensions = coords_2.shape

particle_vol = num_atoms / lattice.volume

def calculate_distances(t):
return lattice.get_all_distances(coords_1[t, :, :], coords_2[t, :, :])
all_dists = np.concatenate(
[
lattice.get_all_distances(coords_1[t, :, :], coords_2[t, :, :])
for t in range(num_time_steps)
]
)
distances = all_dists.flatten()

all_dists = Parallel(n_jobs=-1)(delayed(calculate_distances)(t) for t in range(num_time_steps))
distances = np.concatenate([dists[dists != 0].flatten() for dists in all_dists])

bins = np.arange(0, max_distance + resolution, resolution)
bins = np.arange(0, max_dist + resolution, resolution)
rdf, _ = np.histogram(distances, bins=bins, density=False)

norm = np.array([(4 / 3) * np.pi * ((r + resolution) ** 3 - r ** 3) * particle_vol for r in bins[:-1]])
rdf = np.array([rdf[i] / norm[i] for i in range(len(rdf))])

return bins, rdf


bins, rdf = calculate_rdf_parallelized(trajectory=trajectory, specie_1='Li', specie_2=['S','Cl'], resolution=0.1, max_distance=10)
plt.plot(bins[:-1], rdf)
def normalize(radius: np.ndarray) -> np.ndarray:
"""Normalize bin to volume."""
shell = (radius + resolution) ** 3 - radius**3
return particle_vol * (4 / 3) * np.pi * shell

norm = normalize(bins)[:-1]
rdf = rdf / norm

fig, ax = plt.subplots()
ax.plot(bins[:-1], rdf)
ax.set(
title=f'RDF between {specie_1} and {specie_2} per element',
xlabel='Displacement (Å)',
ylabel='Nr. of atoms',
)
return fig

0 comments on commit 966e4a7

Please sign in to comment.