Skip to content

Commit

Permalink
Modifications to the plot_ank_with_edos method (#309)
Browse files Browse the repository at this point in the history
* Added some handles to control the ank and bqnu plotting

* Added some handles to control linewidth in Ank and Bqnu plots

* Added a bit more control to the Ank plots: DOS coloring, automatic xlims for DOS
  • Loading branch information
ezhique authored Feb 4, 2025
1 parent 54d7e18 commit 9cd89c5
Showing 1 changed file with 137 additions and 23 deletions.
160 changes: 137 additions & 23 deletions abipy/eph/varpeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from abipy.abio.robots import Robot
from abipy.eph.common import BaseEphReader

from scipy.interpolate import interp1d


#TODO Finalize the implementation. Look at Pedro's implementation for GFR
#from abipy.electrons.effmass_analyzer import EffMassAnalyzer
Expand Down Expand Up @@ -445,7 +447,7 @@ def get_a2_interpolator_state(self, interp_method) -> BzRegularGridInterpolator:
interp_method: The method of interpolation. Supported are “linear”, “nearest”,
“slinear”, “cubic”, “quintic” and “pchip”.
"""
a_data, ngkpt, shifts = self.insert_a_inbox()
a_data, ngkpt, shifts = self.insert_a_inbox(fill_value=0)

return [BzRegularGridInterpolator(self.structure, shifts, np.abs(a_data[pstate])**2, method=interp_method)
for pstate in range(self.nstates)]
Expand All @@ -458,7 +460,7 @@ def get_b2_interpolator_state(self, interp_method) -> BzRegularGridInterpolator:
interp_method: The method of interpolation. Supported are “linear”, “nearest”,
“slinear”, “cubic”, “quintic” and “pchip”.
"""
b_data, ngqpt, shifts = self.insert_b_inbox()
b_data, ngqpt, shifts = self.insert_b_inbox(fill_value=0)

return [BzRegularGridInterpolator(self.structure, shifts, np.abs(b_data[pstate])**2, method=interp_method)
for pstate in range(self.nstates)]
Expand Down Expand Up @@ -555,10 +557,12 @@ def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:

@add_fig_kwargs
def plot_ank_with_ebands(self, ebands_kpath,
ebands_kmesh=None, lpratio: int = 5,
ebands_kmesh=None, lpratio: int = 5, with_info = True, with_legend=True,
with_ibz_a2dos=True, method="gaussian", step: float = 0.05, width: float = 0.1,
nksmall: int = 20, normalize: bool = False, with_title=True, interp_method="linear",
ax_mat=None, ylims=None, scale=10, marker_color="gold", fontsize=12, **kwargs) -> Figure:
ax_mat=None, ylims=None, scale=50, marker_color="gold", marker_edgecolor="gray",
marker_alpha=0.5, fontsize=12, lw_bands=1.0, lw_dos=1.0,
filter_value=None, fill_dos=True, **kwargs) -> Figure:
"""
Plot electron bands with markers whose size is proportional to |A_nk|^2.
Expand All @@ -579,7 +583,13 @@ def plot_ank_with_ebands(self, ebands_kpath,
ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
or scalar e.g. ``left``. If left (right) is None, default values are used.
scale: Scaling factor for |A_nk|^2.
marker_color: Color for markers.
marker_color: Color for markers and ADOS.
marker_edgecolor: Color for marker edges.
marker_edgecolor: Marker transparency.
lw_bands: Electronic bands linewidth.
lw_dos: DOS linewidth.
filter_value: Energy filter value (eV) wrt the band edge.
fill_dos: True if the regions defined by ADOS (BZ and IBZ) should be colored.
fontsize: fontsize for legends and titles
"""
nrows, ncols = self.nstates, 2
Expand All @@ -601,36 +611,64 @@ def plot_ank_with_ebands(self, ebands_kpath,
# Plot electron bands with markers.
ebands_kpath = ElectronBands.as_ebands(ebands_kpath)
ymin, ymax = +np.inf, -np.inf

a_data, *_ = self.insert_a_inbox(fill_value=0)

pkind = self.varpeq.r.varpeq_pkind
vbm_or_cbm = "vbm" if pkind == "hole" else "cbm"
bm = self.ebands.get_edge_state(vbm_or_cbm, self.spin).eig
e0 = self.ebands.fermie

for pstate in range(self.nstates):
x, y, s = [], [], []

a2_max = np.max(np.abs(a_data[pstate]))**2
scale *= 1./a2_max

for ik, kpoint in enumerate(ebands_kpath.kpoints):
enes_n = ebands_kpath.eigens[self.spin, ik, self.bstart:self.bstop]
for e, a2 in zip(enes_n, a2_interp_state[pstate].eval_kpoint(kpoint), strict=True):
x.append(ik); y.append(e); s.append(scale * a2)
ymin, ymax = min(ymin, e), max(ymax, e)
# Handle filtering
allowed = True
if filter_value:
energy_window = filter_value*1.1
if pkind == "hole" and bm - e > energy_window:
allowed = False
elif pkind == "electron" and e - bm > energy_window:
allowed = False

if allowed:
x.append(ik); y.append(e); s.append(scale * a2)
ymin, ymax = min(ymin, e), max(ymax, e)

points = Marker(x, y, s, color=marker_color, edgecolors='gray', alpha=0.8, label=r'$|A_{n\mathbf{k}}|^2$')
ax = ax_mat[pstate, 0]
ebands_kpath.plot(ax=ax, points=points, show=False)

points = Marker(x, y, s, color=marker_color, edgecolors=marker_edgecolor,
alpha=marker_alpha, label=r'$|A_{n\mathbf{k}}|^2$')

ebands_kpath.plot(ax=ax, points=points, show=False, linewidth=lw_bands)

ax.legend(loc="best", shadow=True, fontsize=fontsize)

# Add energy and convergence status
with_info = True
if with_info:
data = (df[df["pstate"] == pstate]).to_dict(orient="list")
e_pol_ev, converged = float(data["E_pol"][0]), bool(data["converged"][0])
ax.set_title(f"Formation energy: {e_pol_ev:.3f} eV, {converged=}" , fontsize=8)

if pstate != self.nstates - 1:
if pstate != self.nstates - 1 or not with_legend:
set_visible(ax, False, *["legend", "xlabel"])

vertices_names = [(k.frac_coords, k.name) for k in ebands_kpath.kpoints]


if ebands_kmesh is None:
edos_ngkpt = self.structure.calc_ngkpt(nksmall)
print(f"Computing ebands_kmesh with star-function interpolation and {nksmall=} --> {edos_ngkpt=} ...")
r = self.ebands.interpolate(lpratio=lpratio, vertices_names=vertices_names, kmesh=edos_ngkpt)
ebands_kmesh = r.ebands_kmesh
else:
ebands_kmesh = ElectronBands.as_ebands(ebands_kmesh)

# Get electronic DOS from ebands_kmesh.
edos_kws = dict(method=method, step=step, width=width)
Expand All @@ -645,6 +683,7 @@ def plot_ank_with_ebands(self, ebands_kpath,
# Here we get the mapping BZ --> IBZ needed to obtain the KS eigenvalues e_nk from the IBZ for the DOS.
kdata = ebands_kmesh.get_bz2ibz_bz_points()

xmax = -np.inf
for pstate in range(self.nstates):

# Compute A^2(E) DOS with A_nk in the full BZ.
Expand All @@ -661,11 +700,17 @@ def plot_ank_with_ebands(self, ebands_kpath,

ax = ax_mat[pstate, 1]
edos_opts = {"color": "black",} if self.spin == 0 else {"color": "red"}
edos.plot_ax(ax, e0, spin=self.spin, normalize=normalize, exchange_xy=True, label="eDOS(E)", **edos_opts)
ank_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$A^2$(E)", color=marker_color)
lines_edos = edos.plot_ax(ax, e0, spin=self.spin, normalize=normalize, exchange_xy=True, label="eDOS(E)", **edos_opts,
linewidth=lw_dos, zorder=3)


lines_ados = ank_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$A^2$(E)", color=marker_color,
linewidth=lw_dos, zorder=2)


# Computes A2(E) using only k-points in the IBZ. This is just for testing.
# A2_IBZ(E) should be equal to A2(E) only if A_nk fullfills the lattice symmetries. See notes above.

if with_ibz_a2dos:
ank_dos = np.zeros(len(edos_mesh))
for ik_ibz, ibz_kpoint in enumerate(ebands_kmesh.kpoints):
Expand All @@ -676,22 +721,82 @@ def plot_ank_with_ebands(self, ebands_kpath,
ank_dos += weight * a2 * gaussian(edos_mesh, width, center=e-e0)

ank_dos = Function1D(edos_mesh, ank_dos)
ibz_dos_opts = {"color": "darkred",}
print(f"For {pstate=}, A2_IBZ(E) integrates to:", ank_dos.integral_value, " Ideally, it should be 1.")
ank_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$A^2_{IBZ}$(E)", color=marker_color, ls="--")
lines_ados_ibz = ank_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$A^2_{IBZ}$(E)", ls="--",
linewidth=lw_dos, **ibz_dos_opts, zorder=1)

set_grid_legend(ax, fontsize, xlabel="Arb. unit")
if pstate != self.nstates - 1:
if pstate != self.nstates - 1 or not with_legend:
set_visible(ax, False, *["legend", "xlabel"])


dos_lines = [lines_edos, lines_ados]
colors = [edos_opts["color"], marker_color]
span = ymax - ymin

if with_ibz_a2dos:
dos_lines.append(lines_ados_ibz)
colors.append(ibz_dos_opts["color"])
# determine max x value for auto xlims
for dos, c in zip(dos_lines, colors):
for line in dos:
x_data, y_data = line.get_xdata(), line.get_ydata()
mask = (y_data > ymin-e0-span*0.1) & (y_data < ymax-e0+span*0.1)
xmax = max(np.max(x_data[mask]), xmax)

# fill Ank dos in order
if fill_dos:
y_common = np.linspace(ymin-e0-span*0.1, ymax-e0+span*0.1, 100)
xleft = np.zeros_like(y_common)
# skip eDOS, fill only ADOS
for dos, c in zip(dos_lines[1:], colors[1:]):
for line in dos:
x_data, y_data = line.get_xdata(), line.get_ydata()
interp_x = interp1d(y_data, x_data, kind='linear', fill_value='extrapolate')
xright = interp_x(y_common)

mask = (xright - xleft) > 0
y, x0, x1 = y_common[mask], xleft[mask], xright[mask]

ax.fill_betweenx(y, x0, x1,
alpha=marker_alpha, color=c, linewidth=0)
xleft = xright

# Auto xlims for DOS
span = xmax
xmax += 0.1 * span
for ax in ax_mat[:,1]:
ax.set_xlim(0, xmax)

if ylims is None:
# Automatic ylims.
ymin -= 0.1 * abs(ymin)
ymax += 0.1 * abs(ymax)
span = ymax - ymin
ymin -= 0.1 * span
ymax += 0.1 * span
ylims = [ymin - e0, ymax - e0]


for ax in ax_mat.ravel():
set_axlims(ax, ylims, "y")


# if filtering is used, show the filtering region
for ax in ax_mat.ravel():
xmin, xmax = ax.get_xlim()
xrange = np.linspace(xmin,xmax,100)
shifted_bm = bm - e0
if filter_value:
if pkind == "hole":
fill_from, fill_to = shifted_bm - filter_value, shifted_bm
elif pkind == "electron":
fill_from, fill_to = shifted_bm, shifted_bm + filter_value

ax.fill_between(xrange, ylims[0], fill_from,
color='gray', linewidth=lw_dos, alpha=0.5, zorder=0)
ax.fill_between(xrange, fill_to, ylims[1],
color='gray', linewidth=lw_dos, alpha=0.5, zorder=0)

if with_title:
fig.suptitle(self.get_title(with_gaps=True))

Expand Down Expand Up @@ -721,10 +826,11 @@ def plot_bqnu_with_ddb(self, ddb, with_phdos=True, anaddb_kwargs=None, **kwargs)
ddb=ddb, **kwargs)

@add_fig_kwargs
def plot_bqnu_with_phbands(self, phbands_qpath,
def plot_bqnu_with_phbands(self, phbands_qpath, anaddb_file=None,
phdos_file=None, ddb=None, width=0.001, normalize: bool = True,
verbose=0, anaddb_kwargs=None, with_title=True, interp_method="linear",
ax_mat=None, scale=10, marker_color="gold", fontsize=12, **kwargs) -> Figure:
ax_mat=None, scale=10, marker_color="gold", marker_edgecolor='gray',
marker_alpha=0.8, fontsize=12, lw_bands=1.0, lw_dos=1.0, **kwargs) -> Figure:
"""
Plot phonon energies with markers whose size is proportional to |B_qnu|^2.
Expand Down Expand Up @@ -753,6 +859,10 @@ def plot_bqnu_with_phbands(self, phbands_qpath,

phbands_qpath = PhononBands.as_phbands(phbands_qpath)

# maybe this is unnecessary
if anaddb_file:
phbands_qpath.read_non_anal_from_file(anaddb_file)

# Get interpolators for B_qnu
b2_interp_state = self.get_b2_interpolator_state(interp_method)

Expand All @@ -765,8 +875,9 @@ def plot_bqnu_with_phbands(self, phbands_qpath,
x.append(iq); y.append(w); s.append(scale * b2)

ax = ax_mat[pstate, 0]
points = Marker(x, y, s, color=marker_color, edgecolors='gray', alpha=0.8, label=r'$|B_{\nu\mathbf{q}}|^2$')
phbands_qpath.plot(ax=ax, points=points, show=False)
points = Marker(x, y, s, color=marker_color, edgecolors=marker_edgecolor,
alpha=marker_alpha, label=r'$|B_{\nu\mathbf{q}}|^2$')
phbands_qpath.plot(ax=ax, points=points, show=False, linewidth=lw_bands)
ax.legend(loc="best", shadow=True, fontsize=fontsize)

if pstate != self.nstates - 1:
Expand Down Expand Up @@ -807,6 +918,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath,
# Compute B2(E) by looping over the full BZ.
bqnu_dos = np.zeros(len(phdos_mesh))
for iq_bz, qpoint in enumerate(phbands_qmesh.qpoints):
q_weight = 1./phdos_nqbz
freqs_nu = phbands_qmesh.phfreqs[iq_bz]
for w, b2 in zip(freqs_nu, b2_interp_state[pstate].eval_kpoint(qpoint), strict=True):
bqnu_dos += b2 * gaussian(phdos_mesh, width, center=w)
Expand All @@ -816,8 +928,10 @@ def plot_bqnu_with_phbands(self, phbands_qpath,
bqnu_dos = Function1D(phdos_mesh, bqnu_dos)

ax = ax_mat[pstate, 1]
phdos.plot_ax(ax, exchange_xy=True, normalize=normalize, label="phDOS(E)", color="black")
bqnu_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$B^2$(E)", color=marker_color)
phdos.plot_ax(ax, exchange_xy=True, normalize=normalize, label="phDOS(E)", color="black",
linewidth=lw_dos)
bqnu_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$B^2$(E)", color=marker_color,
linewidth=lw_dos)
set_grid_legend(ax, fontsize, xlabel="Arb. unit")

# Get mapping BZ --> IBZ needed to obtain the KS eigenvalues e_nk from the IBZ for the DOS
Expand Down

0 comments on commit 9cd89c5

Please sign in to comment.