diff --git a/sumo/cli/bandplot.py b/sumo/cli/bandplot.py index dfd12fc2..325ccbd7 100644 --- a/sumo/cli/bandplot.py +++ b/sumo/cli/bandplot.py @@ -295,7 +295,7 @@ def bandplot( if code == "vasp": for vr_file in filenames: vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected) - bs = vr.get_band_structure(line_mode=True) + bs = vr.get_band_structure(line_mode=True, efermi="smart") bandstructures.append(bs) bs = get_reconstructed_band_structure(bandstructures) elif code == "castep": diff --git a/sumo/cli/bandstats.py b/sumo/cli/bandstats.py index d5aa9371..1d91e6a7 100644 --- a/sumo/cli/bandstats.py +++ b/sumo/cli/bandstats.py @@ -94,24 +94,26 @@ def bandstats( bandstructures = [] for vr_file in filenames: vr = BSVasprun(vr_file, parse_projected_eigen=False) - bs = vr.get_band_structure(line_mode=True) + bs = vr.get_band_structure(line_mode=True, efermi="smart") bandstructures.append(bs) - bs = get_reconstructed_band_structure(bandstructures, force_kpath_branches=False) + bs, kpt_mapping = get_reconstructed_band_structure( + bandstructures, force_kpath_branches=True, return_forced_branch_kpt_map=True + ) if bs.is_metal(): logging.error("ERROR: System is metallic!") sys.exit() - _log_band_gap_information(bs) + _log_band_gap_information(bs, kpt_mapping=kpt_mapping) vbm_data = bs.get_vbm() cbm_data = bs.get_cbm() logging.info("\nValence band maximum:") - _log_band_edge_information(bs, vbm_data) + _log_band_edge_information(bs, vbm_data, kpt_mapping=kpt_mapping) logging.info("\nConduction band minimum:") - _log_band_edge_information(bs, cbm_data) + _log_band_edge_information(bs, cbm_data, kpt_mapping=kpt_mapping) if parabolic: logging.info("\nUsing parabolic fitting of the band edges") @@ -179,11 +181,14 @@ def bandstats( return {"hole_data": hole_data, "electron_data": elec_data} -def _log_band_gap_information(bs): +def _log_band_gap_information(bs, kpt_mapping=None): """Log data about the direct and indirect band gaps. Args: bs (:obj:`~pymatgen.electronic_structure.bandstructure.BandStructureSymmLine`): + kpt_mapping (:obj:`dict`, optional): A mapping of k-point indicies from the + band structure with forced branches to the original band structure. + """ bg_data = bs.get_band_gap() if not bg_data["direct"]: @@ -199,6 +204,7 @@ def _log_band_gap_information(bs): direct_kpoint = bs.kpoints[direct_kindex].frac_coords direct_kpoint = kpt_str.format(k=direct_kpoint) eq_kpoints = bs.get_equivalent_kpoints(direct_kindex) + eq_kpoints = _map_kpoints(eq_kpoints, kpt_mapping) k_indices = ", ".join(map(str, eq_kpoints)) # add 1 to band indices to be consistent with VASP band numbers. @@ -215,7 +221,9 @@ def _log_band_gap_information(bs): direct_kindex = direct_data[Spin.up]["kpoint_index"] direct_kpoint = kpt_str.format(k=bs.kpoints[direct_kindex].frac_coords) - k_indices = ", ".join(map(str, bs.get_equivalent_kpoints(direct_kindex))) + eq_kpoints = bs.get_equivalent_kpoints(direct_kindex) + eq_kpoints = _map_kpoints(eq_kpoints, kpt_mapping) + k_indices = ", ".join(map(str, eq_kpoints)) b_indices = ", ".join( [str(i + 1) for i in direct_data[Spin.up]["band_indices"]] ) @@ -225,7 +233,7 @@ def _log_band_gap_information(bs): logging.info(f" Band indices: {b_indices}") -def _log_band_edge_information(bs, edge_data): +def _log_band_edge_information(bs, edge_data, kpt_mapping=None): """Log data about the valence band maximum or conduction band minimum. Args: @@ -233,6 +241,8 @@ def _log_band_edge_information(bs, edge_data): The band structure. edge_data (dict): The :obj:`dict` from ``bs.get_vbm()`` or ``bs.get_cbm()`` + kpt_mapping (:obj:`dict`, optional): A mapping of k-point indicies from the + band structure with forced branches to the original band structure. """ if bs.is_spin_polarized: spins = edge_data["band_index"].keys() @@ -247,7 +257,9 @@ def _log_band_edge_information(bs, edge_data): kpoint = edge_data["kpoint"] kpoint_str = kpt_str.format(k=kpoint.frac_coords) - k_indices = ", ".join(map(str, edge_data["kpoint_index"])) + k_indices = ", ".join( + map(str, _map_kpoints(edge_data["kpoint_index"], kpt_mapping)) + ) k_degen = bs.get_kpoint_degeneracy(kpoint=kpoint.frac_coords) if kpoint.label: @@ -311,6 +323,13 @@ def _log_effective_mass_data(data, is_spin_polarized, mass_type="m_e"): logging.info(f" {mass_type}: {eff_mass:.3f} | {band_str} | {kpoint_str}") +def _map_kpoints(kpt_idxs, kpt_mapping): + """Map k-point indices to the original band structure.""" + if not kpt_mapping: + return kpt_idxs + return sorted(set([kpt_mapping.get(k, k) for k in kpt_idxs])) + + def _get_parser(): parser = argparse.ArgumentParser( description=""" diff --git a/sumo/electronic_structure/bandstructure.py b/sumo/electronic_structure/bandstructure.py index 988e930d..a0488029 100644 --- a/sumo/electronic_structure/bandstructure.py +++ b/sumo/electronic_structure/bandstructure.py @@ -190,7 +190,9 @@ def get_projections(bs, selection, normalise=None): return spec_proj -def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=True): +def get_reconstructed_band_structure( + list_bs, efermi=None, force_kpath_branches=True, return_forced_branch_kpt_map=False +): """Combine a list of band structures into a single band structure. This is typically very useful when you split non self consistent @@ -210,12 +212,17 @@ def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches= across all band structures is used. force_kpath_branches (bool): Force a linemode band structure to contain branches by adding repeated high-symmetry k-points in the path. + return_forced_branch_kpt_map (bool): If True, return a mapping of the + the new k-points to the original k-points. Returns: :obj:`pymatgen.electronic_structure.bandstructure.BandStructure` or \ :obj:`pymatgen.electronic_structure.bandstructureBandStructureSymmLine`: A band structure object. The type depends on the type of the band structures in ``list_bs``. + If return_forced_branch_kpt_map is True, then a tuple is returned + containing the band structure and the mapping from the new k-points + to the original k-points. """ if efermi is None: efermi = sum(b.efermi for b in list_bs) / len(list_bs) @@ -244,13 +251,17 @@ def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches= structure=list_bs[0].structure, projections=projections, ) - if force_kpath_branches: - return force_branches(bs) - else: - return bs + branch_bs, mapping = force_branches(bs, return_mapping=True) + if force_kpath_branches and return_forced_branch_kpt_map: + return branch_bs, mapping + elif force_kpath_branches: + return branch_bs + elif return_forced_branch_kpt_map: + return bs, mapping + return bs -def force_branches(bandstructure): +def force_branches(bandstructure, return_mapping=False): """Force a linemode band structure to contain branches. Branches give a specific portion of the path from one high-symmetry point @@ -262,9 +273,14 @@ def force_branches(bandstructure): Args: bandstructure: A band structure object. + return_mapping: If True, return a mapping of the new k-points (with branches) + to the original k-points. Returns: - A band structure with brnaches. + A band structure with branches. + If return_forced_branch_kpt_map is True, then a tuple is returned + containing the band structure and the mapping from the new k-points + to the original k-points. """ kpoints = np.array([k.frac_coords for k in bandstructure.kpoints]) labels_dict = {k: v.frac_coords for k, v in bandstructure.labels_dict.items()} @@ -275,6 +291,7 @@ def force_branches(bandstructure): # already. dup_ids = [] high_sym_kpoints = tuple(map(tuple, labels_dict.values())) + mapping = {} for i, k in enumerate(kpoints): dup_ids.append(i) if ( @@ -287,6 +304,7 @@ def force_branches(bandstructure): ) ): dup_ids.append(i) + mapping[len(dup_ids) - 1] = i kpoints = kpoints[dup_ids] @@ -297,7 +315,7 @@ def force_branches(bandstructure): if len(bandstructure.projections) != 0: projections[spin] = bandstructure.projections[spin][:, dup_ids] - return type(bandstructure)( + bs = type(bandstructure)( kpoints, eigenvals, bandstructure.lattice_rec, @@ -306,6 +324,9 @@ def force_branches(bandstructure): structure=bandstructure.structure, projections=projections, ) + if return_mapping: + return bs, mapping + return bs def string_to_spin(spin_string): diff --git a/sumo/electronic_structure/dos.py b/sumo/electronic_structure/dos.py index c10a8c34..b5acb59d 100644 --- a/sumo/electronic_structure/dos.py +++ b/sumo/electronic_structure/dos.py @@ -99,7 +99,7 @@ def load_dos( else: vr = vasprun - band = vr.get_band_structure() + band = vr.get_band_structure(efermi="smart") dos = vr.complete_dos dos, band = _scissor_dos(dos, band, scissor)