diff --git a/.gitignore b/.gitignore index 188226b..a1d792a 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,5 @@ KPOINTS PROCAR *.json *.DS_Store + +local_tests/ diff --git a/README.md b/README.md index 072fd91..9b8c698 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,11 @@ package. For the methodology of supercell band unfolding, see [here](https://link.aps.org/doi/10.1103/PhysRevB.85.085201). +### Example Outputs +Cs₂(Sn/Ti)Br₆ Vacancy-Ordered Perovskite Alloys | Symmetry-broken Si Supercell +:-------------------------:|:------------------------------------: + | + ## Usage To generate an unfolded band structure, one typically needs to perform the following steps: diff --git a/docs/examples/example_mgo.md b/docs/examples/example_mgo.md index 6a9ba65..10fac88 100644 --- a/docs/examples/example_mgo.md +++ b/docs/examples/example_mgo.md @@ -38,9 +38,15 @@ easyunfold unfold plot-projections --procar MgO_super/PROCAR --atoms="Mg,O" --co Note that the path of the `PROCAR` is passed along with the desired atom projections (`Mg` and `O` here). +:::{tip} +If the _k_-points have been split into multiple calculations (e.g. hybrid DFT band structures), the `--procar` option +should be passed multiple times to specify the path to each split `PROCAR` file (i.e. +`--procar calc1/PROCAR --procar cal2/PROCAR ...`). +::: + :::{note} -The atomic projections are not stored in the `easyunfold.json` data file, so the `PROCAR` file should be -kept for replotting in the future. +The atomic projections are not stored in the `easyunfold.json` data file, so the `PROCAR` file(s) should be kept for +replotting in the future. ::: The `--combined` option creates a combined plot with different colour maps for each atomic grouping. @@ -97,3 +103,8 @@ easyunfold unfold plot-projections --procar MgO_super/PROCAR --atoms="Mg,O" --em Unfolded MgO band structure with atomic projections plotted separately. ``` + +:::{tip} +There are _many_ customisation options available for the plotting functions in `easyunfold`. See `easyunfold plot -h` or +`easyunfold unfold plot-projections -h` for more details! +::: diff --git a/docs/examples/example_nabis2.md b/docs/examples/example_nabis2.md index ac4c551..2993345 100644 --- a/docs/examples/example_nabis2.md +++ b/docs/examples/example_nabis2.md @@ -88,7 +88,7 @@ When plotting the unfolded band, the `plot-projections` subcommand is used with `--atoms` options: ```bash -easyunfold unfold plot-projections --atoms="Na,Bi,S" --intensity 2 --combined +easyunfold unfold plot-projections --atoms="Na,Bi,S" --intensity 3 --combined ``` ```{figure} ../../examples/NaBiS2/NaBiS2_unfold-plot_proj.png @@ -98,6 +98,12 @@ easyunfold unfold plot-projections --atoms="Na,Bi,S" --intensity 2 --combined Unfolded band structure of NaBiS2 with atomic contributions. ``` +:::{tip} +If the _k_-points have been split into multiple calculations (e.g. hybrid DFT band structures), the `--procar` option +should be passed multiple times to specify the path to each split `PROCAR` file (i.e. +`--procar calc1/PROCAR --procar cal2/PROCAR ...`). +::: + From this plot, we can see that sulfur anions dominate the valence band, while bismuth cations dominate the conduction band, with minimal contributions from the sodium cations as expected. @@ -126,7 +132,7 @@ Unfolded band structure of NaBiS2 with atomic contributions plotted s An alternative option here is also to just plot only the contributions of `Na` and `Bi` cations, with no S projections: ```bash -easyunfold unfold plot-projections --atoms="Na,Bi" --intensity 2 --combined +easyunfold unfold plot-projections --atoms="Na,Bi" --intensity 3 --combined ``` ```{figure} ../../examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png @@ -143,7 +149,7 @@ While this plot isn't the most aesthetic, it clearly shows that Bi (green) contr ### Atom-projected Unfolded Band Structure with DOS We can also combine the atom projections with the DOS plotting, using the `--dos` option as before: ```bash -easyunfold unfold plot-projections --atoms "Na,Bi,S" --intensity 2 --combined --dos vasprun.xml.gz --zero-line \ +easyunfold unfold plot-projections --atoms "Na,Bi,S" --intensity 3 --combined --dos vasprun.xml.gz --zero-line \ --dos-label "DOS" --gaussian 0.1 --no-total --scale 2 ``` @@ -193,7 +199,7 @@ For example, if we want to see the contributions of the Bi $s$, $p$ and S $s$ or we can use the following command: ```bash -easyunfold unfold plot-projections --atoms "Bi,Bi,S" --orbitals="s|p|s" --intensity 2 --combined \ +easyunfold unfold plot-projections --atoms "Bi,Bi,S" --orbitals="s|p|s" --intensity 3 --combined \ --dos vasprun.xml.gz --zero-line --dos-label "DOS" --gaussian 0.1 --no-total --scale 5 ``` @@ -216,7 +222,7 @@ see the contributions of the Bi and S $p_x$, $p_y$ and $p_z$ orbitals to the unf following command: ```bash -easyunfold unfold plot-projections --atoms "Na,Bi,S" --orbitals="all|px,py,pz|px,py,pz" --intensity 2 --combined \ +easyunfold unfold plot-projections --atoms "Na,Bi,S" --orbitals="all|px,py,pz|px,py,pz" --intensity 3 --combined \ --dos vasprun.xml.gz --zero-line --dos-label "DOS" --gaussian 0.1 --no-total --scale 6 ``` @@ -232,6 +238,10 @@ structure, due to the cubic symmetry of the NaBiS2 crystal structure. for the $d$ orbitals of transition metals in octahedral/tetrahedral environments, we would expect to see significant differences in the contributions of different $lm$-decomposed orbitals to the electronic structure. +:::{tip} +There are _many_ customisation options available for the plotting functions in `easyunfold`. See `easyunfold plot -h` or +`easyunfold unfold plot-projections -h` for more details! +::: [^1]: [Huang, YT., Kavanagh, S.R., Righetto, M. et al. Strong absorption and ultrafast localisation in NaBiS2 nanocrystals with slow charge-carrier recombination. Nat Commun 13, 4960 (2022)](https://www.nature.com/articles/s41467-022-32669-3) [^2]: [Wang, Y., Kavanagh, S.R., Burgués-Ceballos, I. et al. Cation disorder engineering yields AgBiS2 nanocrystals with enhanced optical absorption for efficient ultrathin solar cells. Nat. Photon. 16, 235–241 (2022).](https://www.nature.com/articles/s41566-021-00950-4) \ No newline at end of file diff --git a/docs/examples/example_si211_castep.md b/docs/examples/example_si211_castep.md index bbcef90..bfaaf64 100644 --- a/docs/examples/example_si211_castep.md +++ b/docs/examples/example_si211_castep.md @@ -143,4 +143,9 @@ Supercell band structure of Si in a 2x1x1 supercell which, as expected, is the same as the primitive cell if it is folded back along the midpoint between $\Gamma$ and $X$, corresponding to the 2x expansion of the primitive cell along the $x$ direction in -generating the 2x1x1 supercell. \ No newline at end of file +generating the 2x1x1 supercell. + +:::{tip} +There are _many_ customisation options available for the plotting functions in `easyunfold`. See `easyunfold plot -h` or +`easyunfold unfold plot-projections -h` for more details! +::: diff --git a/docs/examples/example_si222.md b/docs/examples/example_si222.md index a109f78..8a15a87 100644 --- a/docs/examples/example_si222.md +++ b/docs/examples/example_si222.md @@ -131,7 +131,6 @@ customising and prettifying the unfolded band structure plot. Here we have also option to increase the spectral function intensity. ::: - Note the appearance of extra branches compared to the band structure of the primitive cell (below), due to symmetry breaking from the displaced atom. diff --git a/docs/img/CSTB_easyunfold.gif b/docs/img/CSTB_easyunfold.gif new file mode 100644 index 0000000..4e8dffd Binary files /dev/null and b/docs/img/CSTB_easyunfold.gif differ diff --git a/docs/img/Si_222_unfold_tall.png b/docs/img/Si_222_unfold_tall.png new file mode 120000 index 0000000..9fd31a0 --- /dev/null +++ b/docs/img/Si_222_unfold_tall.png @@ -0,0 +1 @@ +../../examples/Si222/unfold_tall.png \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index c87d232..db688fb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,6 +20,11 @@ package. For the methodology of supercell band unfolding, see [here](https://link.aps.org/doi/10.1103/PhysRevB.85.085201). +### Example Outputs +Cs₂(Sn/Ti)Br₆ Vacancy-Ordered Perovskite Alloys | Symmetry-broken Si Supercell +:-------------------------:|:------------------------------------: + | + ## Usage To generate an unfolded band structure, one typically needs to perform the following steps: diff --git a/easyunfold/__init__.py b/easyunfold/__init__.py index 87d18a3..caf92c2 100644 --- a/easyunfold/__init__.py +++ b/easyunfold/__init__.py @@ -2,4 +2,4 @@ Collection of code for band unfolding """ -__version__ = '0.2.0' +__version__ = '0.3.0' diff --git a/easyunfold/cli.py b/easyunfold/cli.py index e7114ff..b248c5c 100644 --- a/easyunfold/cli.py +++ b/easyunfold/cli.py @@ -1,6 +1,8 @@ """ Commandline interface """ + +import contextlib import warnings import functools from pathlib import Path @@ -11,7 +13,7 @@ from easyunfold.unfold import parse_atoms, parse_atoms_idx -# pylint:disable=import-outside-toplevel, too-many-locals, too-many-arguments, too-many-nested-blocks, too-many-branches +# pylint:disable=import-outside-toplevel, too-many-locals, too-many-arguments, too-many-nested-blocks, too-many-branches, fixme SUPPORTED_DFT_CODES = ('vasp', 'castep') @@ -105,11 +107,8 @@ def generate(pc_file, code, sc_file, matrix, kpoints, time_reversal, out_file, n dft_code=code, symprec=symprec) unfoldset.kpoint_labels = labels - try: + with contextlib.suppress(KeyError): print_symmetry_data(unfoldset) - except KeyError: - pass - out_file = Path(out_file) if code == 'vasp': out_kpt_name = f'KPOINTS_{out_file.stem}' @@ -142,7 +141,7 @@ def generate(pc_file, code, sc_file, matrix, kpoints, time_reversal, out_file, n Path(out_file).write_text(unfoldset.to_json(), encoding='utf-8') - click.echo('Unfolding settings written to ' + str(out_file)) + click.echo(f'Unfolding settings written to {str(out_file)}') @easyunfold.group('unfold') @@ -211,7 +210,7 @@ def unfold_calculate(ctx, wavefunc, save_as, gamma, ncl): out_path = save_as if save_as else ctx.obj['fname'] Path(out_path).write_text(unfoldset.to_json(), encoding='utf-8') - click.echo('Unfolding data written to ' + out_path) + click.echo(f'Unfolding data written to {out_path}') def add_mpl_style_option(func): @@ -331,30 +330,32 @@ def print_data(entries, tag='me'): ext = Path(out_file).suffix for carrier in ['electrons', 'holes']: for idx, _ in enumerate(output[carrier]): - plotter.plot_effective_mass_fit(efm=efm, - npoints=npoints, - carrier=carrier, - idx=int(idx), - save=fname + f'_fit_{carrier}_{idx}' + ext) + plotter.plot_effective_mass_fit( + efm=efm, + npoints=npoints, + carrier=carrier, + idx=int(idx), + save=f'{fname}_fit_{carrier}_{idx}{ext}', + ) def add_plot_options(func): """ Decorator that adds common plotting options to a function """ - click.option('--gamma', is_flag=True, help='Is the calculation a gamma only one?', show_default=True)(func) - click.option('--ncl', is_flag=True, help='Is the calculation with non-colinear spin?', show_default=True)(func) click.option('--npoints', type=int, default=2000, help='Number of bins for the energy.', show_default=True)(func) click.option('--sigma', type=float, default=0.02, help='Smearing width for the energy in eV.', show_default=True)(func) click.option('--eref', type=float, help='Reference energy in eV.')(func) click.option('--emin', type=float, default=-5., help='Minimum energy in eV relative to the reference.', show_default=True)(func) click.option('--emax', type=float, default=5., help='Maximum energy in eV relative to the reference.', show_default=True)(func) click.option('--intensity', default=1.0, help='Scaling factor for the colour intensity.', type=float, show_default=True)(func) - click.option('--vscale', - type=float, - help='A normalisation/scaling factor for the colour mapping. Equivalent to (1/intensity).', - default=1.0, - show_default=True)(func) + click.option( + '--vscale', + type=float, + help='A normalisation/scaling factor for the colour mapping. Equivalent to (1/intensity). ' + 'Will be deprecated in future versions.', # TODO: deprecate + default=1.0, + show_default=True)(func) click.option('--out-file', '-o', default='unfold.png', help='Name of the output file.', show_default=True)(func) click.option('--cmap', default='PuRd', help='Name of the colour map to use.', show_default=True)(func) click.option('--show', is_flag=True, default=False, help='Show the plot interactively.')(func) @@ -421,7 +422,7 @@ def add_plot_options(func): @click.pass_context @add_plot_options @add_mpl_style_option -def unfold_plot(ctx, gamma, npoints, sigma, eref, out_file, show, emin, emax, cmap, ncl, no_symm_average, vscale, dos, dos_label, zero_line, +def unfold_plot(ctx, npoints, sigma, eref, out_file, show, emin, emax, cmap, no_symm_average, vscale, dos, dos_label, zero_line, dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, scale, procar, atoms, poscar, atoms_idx, orbitals, title, width, height, dpi, intensity): """ @@ -429,9 +430,9 @@ def unfold_plot(ctx, gamma, npoints, sigma, eref, out_file, show, emin, emax, cm This command uses the stored unfolding data to plot the effective bands structure (EBS) using the spectral function. """ - _unfold_plot(ctx, gamma, npoints, sigma, eref, out_file, show, emin, emax, cmap, ncl, no_symm_average, vscale, dos, dos_label, - zero_line, dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, scale, procar, atoms, - poscar, atoms_idx, orbitals, title, width, height, dpi, intensity) + _unfold_plot(ctx, npoints, sigma, eref, out_file, show, emin, emax, cmap, no_symm_average, vscale, dos, dos_label, zero_line, + dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, scale, procar, atoms, poscar, + atoms_idx, orbitals, title, width, height, dpi, intensity) def process_dos(dos, dos_elements, dos_orbitals, dos_atoms, gaussian, total_only, atoms, orbitals, poscar, no_total, legend_cutoff, scale): @@ -481,10 +482,8 @@ def process_dos(dos, dos_elements, dos_orbitals, dos_atoms, gaussian, total_only dos_elements[atom] = () for orbital in orbital_tuple: if orbital != 'all' and orbital[:1] not in dos_elements[atom]: - if orbital[:1] == 'x': # special case in VASP PROCAR labelling, set to 'd' - dos_elements[atom] += ('d',) - else: - dos_elements[atom] += (orbital[:1],) + # special case in VASP PROCAR labelling, set to 'd' if x, else just first letter + dos_elements[atom] += ('d',) if orbital[:1] == 'x' else (orbital[:1],) dos, pdos = load_dos( dos, @@ -511,11 +510,19 @@ def process_dos(dos, dos_elements, dos_orbitals, dos_atoms, gaussian, total_only @click.pass_context @add_plot_options @add_mpl_style_option -@click.option('--combined/--no-combined', is_flag=True, default=False, help='Plot all projections in a combined graph.') -@click.option('--colours', help='Colours to be used for combined plot, comma separated.', default='r,g,b,purple', show_default=True) -def unfold_plot_projections(ctx, gamma, npoints, sigma, eref, out_file, show, emin, emax, cmap, ncl, no_symm_average, vscale, dos, - dos_label, zero_line, dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, - scale, procar, atoms, poscar, atoms_idx, orbitals, title, combined, colours, width, height, dpi, intensity): +@click.option('--combined/--no-combined', is_flag=True, default=False, help='Plot all projections in a combined graph.', show_default=True) +@click.option('--colours', + help='Colours to be used for combined plot, comma separated (e.g. "r,b,y"). ' + 'Default is pastel red, green, blue if <=3 projections, else red, green, blue, purple, orange, yellow.', + default=None) +@click.option('--colourspace', + help='Colourspace in which to perform interpolation for combined plot.', + default='lab', + show_default=True, + type=click.Choice(['rgb', 'hsv', 'lab', 'luvlch', 'lablch', 'xyz'])) +def unfold_plot_projections(ctx, npoints, sigma, eref, out_file, show, emin, emax, cmap, no_symm_average, vscale, dos, dos_label, zero_line, + dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, scale, procar, atoms, + poscar, atoms_idx, orbitals, title, combined, colours, colourspace, width, height, dpi, intensity): """ Plot the effective band structure with atomic projections. """ @@ -534,13 +541,11 @@ def unfold_plot_projections(ctx, gamma, npoints, sigma, eref, out_file, show, em dos_label=dos_label, dos_options=dos_options, zero_line=zero_line, - gamma=gamma, npoints=npoints, sigma=sigma, eref=eref, ylim=(emin, emax), cmap=cmap, - ncl=ncl, symm_average=not no_symm_average, atoms=atoms, atoms_idx=atoms_idx, @@ -552,7 +557,8 @@ def unfold_plot_projections(ctx, gamma, npoints, sigma, eref, out_file, show, em intensity=intensity, figsize=(width, height), dpi=dpi, - colours=colours.split(',') if colours is not None else None) + colours=colours.split(',') if colours is not None else None, + colorspace=colourspace) if out_file: fig.savefig(out_file, dpi=dpi, bbox_inches='tight') @@ -563,7 +569,6 @@ def unfold_plot_projections(ctx, gamma, npoints, sigma, eref, out_file, show, em def _unfold_plot(ctx, - gamma, npoints, sigma, eref, @@ -572,7 +577,6 @@ def _unfold_plot(ctx, emin, emax, cmap, - ncl, no_symm_average, vscale, dos, @@ -612,12 +616,12 @@ def _unfold_plot(ctx, eref = unfoldset.calculated_quantities.get('vbm', 0.0) click.echo(f'Using a reference energy of {eref:.3f} eV') - # Setup the atoms_idx and orbitals + # Set up the atoms_idx and orbitals if atoms or atoms_idx: - # Process the PROCAR + # Process the PROCARs click.echo(f'Loading projections from: {procar}') try: - unfoldset.load_procar(procar) + unfoldset.load_procars(procar) except FileNotFoundError as exc: click.echo(f'Could not find and parse the --procar file: {procar} – needed for atomic projections!') raise click.Abort() from exc @@ -655,10 +659,8 @@ def _unfold_plot(ctx, # Collect spectral functions and scale all_sf = [] for this_idx, this_orbitals in zip(atoms_idx_subplots, orbitals_subplots): - eng, spectral_function = unfoldset.get_spectral_function(gamma=gamma, - npoints=npoints, + eng, spectral_function = unfoldset.get_spectral_function(npoints=npoints, sigma=sigma, - ncl=ncl, atoms_idx=this_idx, orbitals=this_orbitals, symm_average=not no_symm_average) @@ -702,26 +704,20 @@ def _unfold_plot(ctx, def print_symmetry_data(kset): """Print the symmetry information""" - # Print space group information - sc_spg = kset.metadata['symmetry_dataset_sc'] - click.echo('Supercell cell information:') - click.echo(' ' * 8 + f'Space group number: {sc_spg["number"]}') - click.echo(' ' * 8 + f'International symbol: {sc_spg["international"]}') - click.echo(' ' * 8 + f'Point group: {sc_spg["pointgroup"]}') + def _print_symmetry_data_from_kset(kset, dataset_key, dataset_title): + # Print space group information + sc_spg = kset.metadata[dataset_key] + click.echo(dataset_title) + click.echo(' ' * 8 + f'Space group number: {sc_spg["number"]}') + click.echo(' ' * 8 + f'International symbol: {sc_spg["international"]}') + click.echo(' ' * 8 + f'Point group: {sc_spg["pointgroup"]}') - pc_spg = kset.metadata['symmetry_dataset_pc'] - click.echo('\nPrimitive cell information:') - click.echo(' ' * 8 + f'Space group number: {pc_spg["number"]}') - click.echo(' ' * 8 + f'International symbol: {pc_spg["international"]}') - click.echo(' ' * 8 + f'Point group: {pc_spg["pointgroup"]}') + _print_symmetry_data_from_kset(kset, 'symmetry_dataset_sc', 'Supercell cell information:') + _print_symmetry_data_from_kset(kset, 'symmetry_dataset_pc', '\nPrimitive cell information:') def matrix_from_string(string): - """Parse transform matrix from a string""" + """Parse transformation matrix from a string""" elems = [float(x) for x in string.split()] - # Try gussing the transform matrix - if len(elems) == 3: - transform_matrix = np.diag(elems) - else: - transform_matrix = np.array(elems).reshape((3, 3)) - return transform_matrix + # try guessing the transormation matrix + return np.diag(elems) if len(elems) == 3 else np.array(elems).reshape((3, 3)) diff --git a/easyunfold/plotting.py b/easyunfold/plotting.py index 0ff4cb2..352aad3 100644 --- a/easyunfold/plotting.py +++ b/easyunfold/plotting.py @@ -1,6 +1,7 @@ """ Plotting utilities """ +import os from typing import Union, Sequence import warnings import itertools @@ -85,7 +86,15 @@ def _get_orbital_colour_dict(index, colour_list): if orbital in key: sumo_colours[atom][key] = orbital_colour_dict[key] else: - sumo_colours = None + from pkg_resources import Requirement, resource_filename + try: + import configparser + except ImportError: + import ConfigParser as configparser + + config_path = resource_filename(Requirement.parse('sumo'), 'sumo/plotting/orbital_colours.conf') + sumo_colours = configparser.ConfigParser() + sumo_colours.read(os.path.abspath(config_path)) # don't use first 4 colours; these are the band structure line colours: cycle = cycler('color', rcParams['axes.prop_cycle'].by_key()['color'][4:]) @@ -117,9 +126,6 @@ def _get_orbital_colour_dict(index, colour_list): lines = plot_data['lines'] spins = [Spin.up] if len(lines[0][0]['dens']) == 1 else [Spin.up, Spin.down] - # disable y ticks for DOS panel - ax.tick_params(axis='y', which='both', right=False) - for line_set in plot_data['lines']: for line, spin in itertools.product(line_set, spins): if spin == Spin.up or len(spins) == 1: @@ -153,8 +159,8 @@ def _get_orbital_colour_dict(index, colour_list): if dos_label is not None: ax.set_xlabel(dos_label) - ax.set_xticklabels([]) - ax.set_yticklabels([]) + ax.set_yticks([]) # no y ticks + ax.set_xticks([]) # no x ticks ax.legend(loc=2, frameon=False, ncol=1, bbox_to_anchor=(1.0, 1.0), fontsize=9) return ax @@ -217,19 +223,15 @@ def plot_spectral_function( if nspin > 1: warnings.warn('DOS plotter is not supported for spin-separated plots. Reverting to non spin-polarised plotting.') nspin = 1 - else: - if ax is None: - if nspin == 1: - fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) - axes = [ax] - else: - fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi) + elif ax is None: + if nspin == 1: + fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) + axes = [ax] else: - if not isinstance(ax, list): - axes = [ax] - else: - axes = ax - fig = axes[0].figure + fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi) + else: + axes = [ax] if not isinstance(ax, list) else ax + fig = axes[0].figure # Shift the kdist so the pcolormesh draw the pixel centred on the original point X, Y = np.meshgrid(kdist, engs - eref) @@ -330,10 +332,7 @@ def _plot_spectral_function_rgba( else: fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi) else: - if not isinstance(ax, list): - axes = [ax] - else: - axes = ax + axes = [ax] if not isinstance(ax, list) else ax fig = axes[0].figure mask = (engs < (ylim[1] + eref)) & (engs > (ylim[0] + eref)) @@ -379,10 +378,7 @@ def _add_kpoint_labels(self, ax: plt.Axes, x_is_kidx=False): tick_locs = [] tick_labels = [] for index, label in labels: - if x_is_kidx: - xloc = index - else: - xloc = kdist[index] + xloc = index if x_is_kidx else kdist[index] ax.axvline(x=xloc, lw=0.5, color='k', ls=':', alpha=0.8) tick_locs.append(xloc) tick_labels.append(clean_latex_string(label)) @@ -511,10 +507,7 @@ def plot_spectral_weights( else: fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi) else: - if not isinstance(ax, list): - axes = [ax] - else: - axes = ax + axes = [ax] if not isinstance(ax, list) else ax fig = axes[0].figure kweights = unfold.expansion_results['weights'] @@ -558,37 +551,37 @@ def plot_spectral_weights( return fig def plot_projected( - self, - procar: Union[str, list] = 'PROCAR', - dos_plotter=None, - dos_label=None, - dos_options=None, - zero_line=False, - eref=None, - gamma=False, - npoints=2000, - sigma=0.2, - ncl=False, - symm_average=True, - figsize=(4, 3), - ylim=(-5, 5), - dpi=300, - vscale=1.0, - contour_plot=False, - alpha=1.0, - save=False, - ax=None, - cmap='PuRd', - show=False, - title=None, - atoms=None, - poscar='POSCAR', - atoms_idx=None, - orbitals=None, - use_subplot=False, - colours=('r', 'g', 'b', 'purple'), - colorspace='lab', - intensity=1.0, + self, + procar: Union[str, list] = 'PROCAR', + dos_plotter=None, + dos_label=None, + dos_options=None, + zero_line=False, + eref=None, + gamma=False, + npoints=2000, + sigma=0.2, + ncl=False, + symm_average=True, + figsize=(4, 3), + ylim=(-5, 5), + dpi=300, + vscale=1.0, + contour_plot=False, + alpha=1.0, + save=False, + ax=None, + cmap='PuRd', + show=False, + title=None, + atoms=None, + poscar='POSCAR', + atoms_idx=None, + orbitals=None, + use_subplot=False, + colours=None, + colorspace='lab', + intensity=1.0, ): """ Plot projected spectral function onto multiple subplots or a single plot with color mapping. @@ -598,10 +591,13 @@ def plot_projected( :param procar: Name of names of the `PROCAR` files. + :param colours: Default is pastel red, green, blue if <=3 projections, else red, green, + blue, purple, orange, yellow. + :returns: Generated plot. """ unfoldset = self.unfold - unfoldset.load_procar(procar) + unfoldset.load_procars(procar) nspin = unfoldset.calculated_quantities['spectral_weights_per_set'][0].shape[0] if atoms_idx is not None: @@ -636,12 +632,11 @@ def plot_projected( # Collect spectral functions and scale for this_idx, this_orbitals in zip(atoms_idx_subplots, orbitals_subplots): - # Setup the atoms_idx and orbitals + # Set up the atoms_idx and orbitals if isinstance(this_idx, str): this_idx, this_orbitals = process_projection_options(this_idx, this_orbitals) - else: # list of integers; pre-processed by specifying atoms - if this_orbitals != 'all': - this_orbitals = [token.strip() for token in this_orbitals.split(',')] + elif this_orbitals != 'all': + this_orbitals = [token.strip() for token in this_orbitals.split(',')] eng, spectral_function = unfoldset.get_spectral_function(gamma=gamma, npoints=npoints, @@ -701,6 +696,18 @@ def plot_projected( stacked_sf = np.stack(all_sf, axis=-1).reshape(np.prod(sf_size), len(all_sf)) # Construct the colour basis + if colours is None: + if len(all_sf) <= 3: + colours = ['#CC33A7', '#A7CC33', '#33A7CC'] + else: + colours = [ + (1, 0, 0), # red + (0, 1, 0), # green + (0, 0, 1), # blue + (152 / 255, 78 / 255, 163 / 255), # purple + (1, 127 / 255, 0), # orange + (1, 1, 51 / 255), # yellow + ] colours = colours[:len(all_sf)] # Compute spectral weight data with RGB reshape it back into the shape (nengs, nk, 3) sf_rgb = interpolate_colors(colours, stacked_sf, colorspace, normalize=True).reshape(sf_size + (3,)) @@ -761,9 +768,7 @@ def plot_projected( warnings.warn('zero_line option requires sumo to be installed!') if atoms is not None: # add figure legend with atoms and colors - legend_elements = [] - for i, atom in enumerate(atoms): - legend_elements.append(Patch(facecolor=colours[i], label=atom, alpha=0.7)) + legend_elements = [Patch(facecolor=colours[i], label=atom, alpha=0.7) for i, atom in enumerate(atoms)] fig.axes[0].legend(handles=legend_elements, bbox_to_anchor=(1.025, 1), fontsize=9) fig.subplots_adjust(right=0.78) # ensure legend is not cut off @@ -814,7 +819,7 @@ def interpolate_colors(colours: Sequence, weights: list, colorspace='lab', norma :param weights: A list of weights with the shape (n, N). Where the N values of the last axis give the amount of N colours supplied in `colours`. :param colorspace: The colorspace in which to perform the interpolation. The - allowed values are rgb, hsv, lab, luvlc, lablch, and xyz. + allowed values are rgb, hsv, lab, luvlch, lablch, and xyz. :returns: A list of colours, specified in the rgb format as a (n, 3) array. """ @@ -844,7 +849,7 @@ def interpolate_colors(colours: Sequence, weights: list, colorspace='lab', norma # Normalise the weights if needed if normalize: - weights = weights / np.linalg.norm(weights, axis=1)[:, None] + weights = weights / np.sum(weights, axis=1)[:, None] # each row sums to 1 # perform the interpolation in the colorspace basis interpolated_colors = colors_basis[0] * weights[:, 0][:, None] @@ -853,10 +858,24 @@ def interpolate_colors(colours: Sequence, weights: list, colorspace='lab', norma # convert the interpolated colors back to RGB rgb_colors = [convert_color(colorspace(*c), sRGBColor).get_value_tuple() for c in interpolated_colors] - rgb_colors = np.stack(rgb_colors, axis=0) - # ensure all rgb values are less than 1 (sometimes issues in interpolation gives - np.clip(rgb_colors, 0, 1, rgb_colors) + # ensure all rgb values are less than 1 (sometimes issues in interpolation) + normalised_rgb_colors = [] + for rgb_color_tuple in rgb_colors: + if np.max(rgb_color_tuple) > 1: + normalised_rgb_color = np.array(rgb_color_tuple) / np.max(rgb_color_tuple) + else: + normalised_rgb_color = np.array(rgb_color_tuple) + + normalised_rgb_color = np.clip(normalised_rgb_color, 0, 1) # ensure all rgb values are between 0 and 1 + # if too white, darken: + if np.linalg.norm(normalised_rgb_color) > 1: # white af + normalised_rgb_color *= (1 / np.linalg.norm(normalised_rgb_color)**(1 / 2)) + + normalised_rgb_colors.append(normalised_rgb_color) + + rgb_colors = np.stack(normalised_rgb_colors, axis=0) + return rgb_colors diff --git a/easyunfold/procar.py b/easyunfold/procar.py index 53218cc..192302b 100644 --- a/easyunfold/procar.py +++ b/easyunfold/procar.py @@ -6,15 +6,21 @@ import re import numpy as np +from monty.json import MSONable, MontyDecoder -class Procar: +from easyunfold import __version__ + +# pylint:disable=too-many-locals, + + +class Procar(MSONable): """Reader for PROCAR file""" - def __init__(self, fobj_or_path=None, is_soc=False): + def __init__(self, fobjs_or_paths=None, is_soc=False): """ Read the PROCAR file from a handle or path - :param fobj_or_path: A file-like obj or a path + :param fobjs_or_paths: Either a string or list of file-like objs or paths :param is_soc: Whether the PROCAR is from a calculation with spin-orbit coupling """ self._is_soc = is_soc @@ -32,85 +38,197 @@ def __init__(self, fobj_or_path=None, is_soc=False): self.proj_xyz = None # Read the PROCAR - if isinstance(fobj_or_path, (str, Path)): - with open(fobj_or_path, encoding='utf-8') as fhandle: - self._read(fhandle) - else: - self._read(fobj_or_path) + if isinstance(fobjs_or_paths, (str, Path)): + fobjs_or_paths = [fobjs_or_paths] + self.read(fobjs_or_paths) - def _read(self, fobj): + def _read(self, fobj, parsed_kpoints=None): """Main function for reading in the data""" + if parsed_kpoints is None: + parsed_kpoints = set() - # First sweep - found the number of kpoints and the number of bands + # First sweep - find the number of kpoints and the number of bands fobj.seek(0) - self.header = fobj.readline() + _header = fobj.readline() # Read the NK, NB and NIONS that are integers - self.nkpts, self.nbands, self.nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()] - # Number of projects and their names - nproj = None - self.proj_names = None - - for line in fobj: - if re.match(r'^ion.*tot', line): # only the first "ion" line, in case of LORBIT >= 12 - nproj = len(line.strip().split()) - 2 - self.proj_names = line.strip().split()[1:-1] - break + _total_nkpts, nbands, nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()] + if nion != self.nion: + raise ValueError(f'Mismatch in number of ions in PROCARs supplied: ({nion} vs {self.nion})!') # Count the number of data lines, these lines do not have any alphabets - proj_data = [] - energies = [] - occs = [] - kvecs = [] - kweights = [] + proj_data, energies, kvecs, kweights, occs = [], [], [], [], [] + tot_count = 0 # count the instances of lines starting with "tot" -> (4 + 1) * nbands * nkpts for SOC calcs fobj.seek(0) - for line in fobj: - if not re.search(r'[a-zA-Z]', line) and line.strip() and len(line.strip().split()) - 2 == nproj: - # only parse data if nproj is expected length, in case of LORBIT >= 12 + + line = fobj.readline() + while line: + if line.startswith(' k-point'): + line = re.sub(r'(\d)-', r'\1 -', line) + tokens = line.strip().split() + kvec = tuple(round(float(val), 5) for val in # tuple to make it hashable + tokens[-6:-3]) # round to 5 decimal places to ensure proper kpoint matching + if kvec not in parsed_kpoints: + parsed_kpoints.add(kvec) + kvecs.append(list(kvec)) + kweights.append(float(tokens[-1])) + else: + # skip ahead to the next instance of two blank lines in a row + while line.strip() or fobj.readline().strip(): + line = fobj.readline() + continue + + elif not re.search(r'[a-zA-Z]', line) and line.strip() and len(line.strip().split()) - 2 == len(self.proj_names): + # only parse data if line is expected length, in case of LORBIT >= 12 proj_data.append([float(token) for token in line.strip().split()[1:-1]]) + elif line.startswith('band'): tokens = line.strip().split() energies.append(float(tokens[4])) occs.append(float(tokens[-1])) - elif line.startswith(' k-point'): - line = re.sub(r'(\d)-', r'\1 -', line) - tokens = line.strip().split() - kvecs.append([float(val) for val in tokens[-6:-3]]) - kweights.append(float(tokens[-1])) - self.occs = np.array(occs) - self.kvecs = np.array(kvecs) - self.kweights = np.array(kweights) - self.eigenvalues = np.array(energies) + + elif line.startswith('tot'): + tot_count += 1 + + line = fobj.readline() + + # dynamically determine whether PROCARs are SOC or not + if tot_count == 4 * len(occs): + self._is_soc = True + elif tot_count == len(occs): + self._is_soc = False + else: + raise ValueError(f"Number of lines starting with 'tot' ({tot_count}) in PROCAR does not match expected " + f'values ({4*len(occs)} or {len(occs)})!') + + occs = np.array(occs) + kvecs = np.array(kvecs) + kweights = np.array(kweights) + eigenvalues = np.array(energies) proj_data = np.array(proj_data, dtype=float) - self.nspins = proj_data.shape[0] // (self.nion * self.nbands * self.nkpts) + # redetermine nkpts in case some were skipped due to already being parsed + nkpts = len(kvecs) + + self.nspins = proj_data.shape[0] // (self.nion * nbands * nkpts) self.nspins //= 4 if self._is_soc else 1 # Reshape - self.occs.resize((self.nspins, self.nkpts, self.nbands)) - self.kvecs.resize((self.nspins, self.nkpts, 3)) - self.kweights.resize((self.nspins, self.nkpts)) - self.eigenvalues.resize((self.nspins, self.nkpts, self.nbands)) + occs.resize((self.nspins, nkpts, nbands)) + kvecs.resize((self.nspins, nkpts, 3)) + kweights.resize((self.nspins, nkpts)) + eigenvalues.resize((self.nspins, nkpts, nbands)) # Reshape the array if self._is_soc is False: - self.proj_data = proj_data.reshape((self.nspins, self.nkpts, self.nbands, self.nion, nproj)) + proj_data = proj_data.reshape((self.nspins, nkpts, nbands, self.nion, len(self.proj_names))) + proj_xyz = None else: - self.proj_data = proj_data.reshape((self.nspins, self.nkpts, self.nbands, 4, self.nion, nproj)) + proj_data = proj_data.reshape((self.nspins, nkpts, nbands, 4, self.nion, len(self.proj_names))) # Split the data into xyz projection and total - self.proj_xyz = self.proj_data[:, :, :, 1:, :, :] - self.proj_data = self.proj_data[:, :, :, 0, :, :] + proj_xyz = proj_data[:, :, :, 1:, :, :] + proj_data = proj_data[:, :, :, 0, :, :] + + # normalise: (for each nspin, nkpt, nband, the sum of the projections over nion and proj_names should be 1) + proj_sum = np.sum(proj_data, axis=(-2, -1), keepdims=True) + proj_sum[proj_sum == 0] = 1 # just in case, avoid division by zero + proj_data /= proj_sum + + if proj_xyz is not None: + proj_sum = np.sum(proj_xyz, axis=(-3, -2, -1), keepdims=True) + proj_sum[proj_sum == 0] = 1 + proj_xyz /= proj_sum + + return self.nspins, occs, kvecs, kweights, eigenvalues, proj_data, proj_xyz, parsed_kpoints + + def _read_header_nion_proj_names(self, fobj): + """Read the header, nion and proj_names from the PROCAR""" + fobj.seek(0) + self.header = fobj.readline() + # Read the NK, NB and NIONS that are integers + _nkpts, _nbands, self.nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()] + self.proj_names = None # projection names + + for line in fobj: + if re.match(r'^ion.*tot', line): # only the first "ion" line, in case of LORBIT >= 12 + self.proj_names = line.strip().split()[1:-1] + break + + def read(self, fobjs_or_paths): + """Read and amalgamate the data from a list of PROCARs""" + + def open_file(fobj_or_path): + if isinstance(fobj_or_path, (str, Path)): + return open(fobj_or_path, encoding='utf-8') # closed later + return fobj_or_path # already a file-like object, just return it + + parsed_kpoints = None + occs_list, kvecs_list, kweights_list = [], [], [] + eigenvalues_list, proj_data_list, proj_xyz_list = [], [], [] + for i, fobj_or_path in enumerate(fobjs_or_paths): + # Note: If PROCAR parsing becomes a significant bottleneck for people (e.g. with several HSE06+SOC PROCARs), + # this could be parallelized (somewhat) with multiprocessing. The actual file parsing in _read() is currently + # serial so no easy wins there, but could at least parallelise over the list of PROCARs + fobj = open_file(fobj_or_path) + if self.header is None: # first file; read header, nion, proj_names + self._read_header_nion_proj_names(fobj) + + current_nspins = self.nspins # check spin consistency between PROCARs + nspins, occs, kvecs, kweights, eigenvalues, proj_data, proj_xyz, parsed_kpoints = self._read(fobj, + parsed_kpoints=parsed_kpoints) + if current_nspins is not None and current_nspins != nspins: + raise ValueError(f'Mismatch in number of spins in PROCARs supplied: ({nspins} vs {current_nspins})!') + + if isinstance(fobj_or_path, (str, Path)): + fobj.close() # if file was opened in this loop, close it + + # Append to respective lists + occs_list.append(occs) + kvecs_list.append(kvecs) + kweights_list.append(kweights) + eigenvalues_list.append(eigenvalues) + proj_data_list.append(proj_data) + proj_xyz_list.append(proj_xyz) + if len(fobjs_or_paths) > 1: # print progress if reading multiple files + print(f'Finished parsing PROCAR {i + 1}/{len(fobjs_or_paths)}') + + # Combine along the nkpts axis: + # for occs, eigenvalues, proj_data and proj_xyz, nbands (axis = 2) could differ, so set missing values to zero: + max_nbands = max(arr.shape[2] for arr in eigenvalues_list) + for array_list in [occs_list, eigenvalues_list, proj_data_list, proj_xyz_list]: + for i, arr in enumerate(array_list): + if arr is not None and arr.shape[2] < max_nbands: + if len(arr.shape) == 3: # occs_list, eigenvalues_list + array_list[i] = np.pad(arr, ((0, 0), (0, 0), (0, max_nbands - arr.shape[2])), mode='constant') + elif len(arr.shape) == 5: # proj_xyz_list + array_list[i] = np.pad(arr, ((0, 0), (0, 0), (0, max_nbands - arr.shape[2]), (0, 0), (0, 0)), mode='constant') + elif len(arr.shape) == 6: # proj_xyz_list + array_list[i] = np.pad(arr, ((0, 0), (0, 0), (0, max_nbands - arr.shape[2]), (0, 0), (0, 0), (0, 0)), + mode='constant') + else: + raise ValueError('Unexpected array shape encountered!') + + self.nbands = max_nbands + self.occs = np.concatenate(occs_list, axis=1) + self.eigenvalues = np.concatenate(eigenvalues_list, axis=1) + self.kvecs = np.concatenate(kvecs_list, axis=1) + self.kweights = np.concatenate(kweights_list, axis=1) + self.proj_data = np.concatenate(proj_data_list, axis=1) + if all(arr is not None for arr in proj_xyz_list): + self.proj_xyz = np.concatenate(proj_xyz_list, axis=1) + + self.nkpts = self.kvecs.shape[1] def get_projection(self, atom_idx: List[int], proj: Union[List[str], str], weight_by_k=False): """ - Get project for specific atoms and specific projectors + Get projection for specific atoms and specific projectors :param atom_idx: A list of index of the atoms to be selected :param proj: A list of the projector names to be selected :param weight_by_k: Apply k weighting or not. - :returns: The project summed over the selected atoms and the projectors + :returns: The projection summed over the selected atoms and the projectors """ atom_mask = [iatom in atom_idx for iatom in range(self.nion)] assert any(atom_mask) @@ -143,3 +261,41 @@ def _replace_p_d(single_proj): for kidx in range(self.nkpts): out[:, kidx, :] *= self.kweights[kidx] return out + + def as_dict(self) -> dict: + """Convert the object into a dictionary representation (so it can be saved to json)""" + output = {'@module': self.__class__.__module__, '@class': self.__class__.__name__, '@version': __version__} + for key in [ + '_is_soc', 'eigenvalues', 'kvecs', 'kweights', 'nbands', 'nkpts', 'nspins', 'nion', 'occs', 'proj_names', 'proj_data', + 'header', 'proj_xyz' + ]: + output[key] = getattr(self, key) + return output + + @classmethod + def from_dict(cls, d): + """ + Reconstructs Procar object from a dict representation, without calling __init__(). + + Args: + d (dict): dict representation of Procar + + Returns: + Procar object + """ + + def decode_dict(subdict): + if isinstance(subdict, dict) and '@module' in subdict: + return MontyDecoder().process_decoded(subdict) + return subdict + + instance = cls.__new__(cls) # create a new instance without calling __init__() + d_decoded = {k: decode_dict(v) for k, v in d.items()} + + # set the instance variables directly from the dictionary + for key, value in d_decoded.items(): + if key in ['@module', '@class', '@version']: + continue + setattr(instance, key, value) + + return instance diff --git a/easyunfold/unfold.py b/easyunfold/unfold.py index 37adcef..fe08d1b 100644 --- a/easyunfold/unfold.py +++ b/easyunfold/unfold.py @@ -2,6 +2,9 @@ """ The main module for unfolding workflow and algorithm """ + +import contextlib +import itertools # pylint: disable=invalid-name,protected-access,too-many-locals ############################################################ @@ -317,9 +320,7 @@ def generate_sc_kpoints(self) -> None: # Collect from form nested list containing the mappings to the reduce sc kpoints reduced_sc_map = [] for sc_set in expended_sc: - map_indx = [] - for _ in sc_set: - map_indx.append(sc_kpts_map.pop(0)) + map_indx = [sc_kpts_map.pop(0) for _ in sc_set] reduced_sc_map.append(map_indx) self.expansion_results['reduced_sckpts'] = reduced_sckpts @@ -411,17 +412,14 @@ def _read_weights(self, wavefunction: Union[str, List[str]], gamma: bool, ncl: b return averaged_weights, weights_per_set - def load_procar(self, procar: Union[str, List[str]], force: bool = False): + def load_procars(self, procars: Union[str, List[str]]): """Read in PROCAR for band-based projection""" - if self.procars and not force: - pass - - if not isinstance(procar, (tuple, list)): - procar = [procar] + if not isinstance(procars, (tuple, list)): + procars = [procars] # list of PROCAR files # Load the procars # Note that this method should be generalised for non-VASP as well. - self.transient_quantities['procars'] = [Procar(path) for path in procar] + self.transient_quantities['procars'] = Procar(procars) # Construct mapping from the primitive cell kpoints to those in the PROCAR self.transient_quantities['procars_kmap'] = self._construct_procar_kmap() @@ -436,20 +434,17 @@ def _construct_procar_kmap(self) -> list: K_super, _ = find_K_from_k(kpoint, self.M) # Search for kpoints in the procar found = False - for iprocar, procar in enumerate(self.procars): - for ikpt, kprocar in enumerate(procar.kvecs[0]): - if kpoints_equal(K_super, kprocar, time_reversal=self.time_reversal): - kidx_procar_sets[-1].append([iprocar, ikpt]) - found = True - break - if found: + for ikpt, kprocar in enumerate(self.procar.kvecs[0]): + if kpoints_equal(K_super, kprocar, time_reversal=self.time_reversal): + kidx_procar_sets[-1].append(ikpt) + found = True break if found is False: raise ValueError(f'Cannot found kpoint {K_super} in PROCAR files') return kidx_procar_sets @property - def procars(self) -> Union[None, Procar]: + def procar(self) -> Union[None, Procar]: """Loaded PROCARS""" return self.transient_quantities.get('procars') @@ -481,9 +476,8 @@ def _get_spectral_weights(self, # If wave function is given - reload from the data if wavefunction: self._read_weights(wavefunction, gamma=gamma, ncl=ncl, gamma_half=gamma_half) - else: - if not self.is_calculated: - raise RuntimeWarning('The spectral weights need to be calculated first - please pass the wave function file(s).') + elif not self.is_calculated: + raise RuntimeWarning('The spectral weights need to be calculated first - please pass the wave function file(s).') # Use existing results if symm_average: @@ -493,11 +487,11 @@ def _get_spectral_weights(self, # No averaging - we just return the first item of each set, which is the weight of the original set sws = [item[:, :1, :, :] for item in self.calculated_quantities['spectral_weights_per_set']] # Recreate the full weights array - kweight_sets = [[1.0] for i in range(len(sws))] + kweight_sets = [[1.0] for _ in range(len(sws))] if also_spectral_function: if atoms_idx is not None: - # Read in the project weights + # Read in the projected weights band_weight_sets = self.get_band_weight_sets(atoms_idx, orbitals) else: band_weight_sets = None @@ -531,18 +525,18 @@ def get_band_weight_sets(self, :returns: A list of weights for each band at each expanded kpoint """ if procars: - self.load_procar(procars) - if self.procars is None: + self.load_procars(procars) + if self.procar is None: raise RuntimeError('PROCAR files needs to be loaded') - projs = [procar.get_projection(atoms_idx, orbitals) for procar in self.procars] + proj = self.procar.get_projection(atoms_idx, orbitals) # Construct band weighting, same structure as o band_weight_sets = [] for kset in self.transient_quantities['procars_kmap']: band_weight_sets.append([]) # Search - for iprocar, kidx in kset: - band_weight = projs[iprocar][:, kidx] + for kidx in kset: + band_weight = proj[:, kidx] band_weight_sets[-1].append(band_weight) return band_weight_sets @@ -673,12 +667,11 @@ def clean_latex_string(label: str): :returns: Cleaned tag string """ if label == 'G': - label = r'$\mathrm{\mathsf{\Gamma}}$' - elif label.startswith('\\'): ## This is a latex formatted label already - label = f'$\\mathrm{{\\mathsf{{{label}}}}}$' - else: - label = r'$\mathrm{\mathsf{' + label + r'}}$' - return label + return r'$\mathrm{\mathsf{\Gamma}}$' + if label.startswith('\\'): ## This is a latex formatted label already + return f'$\\mathrm{{\\mathsf{{{label}}}}}$' + + return r'$\mathrm{\mathsf{' + label + r'}}$' def spectral_function_from_weight_sets(spectral_weight_sets: np.ndarray, @@ -710,18 +703,17 @@ def spectral_function_from_weight_sets(spectral_weight_sets: np.ndarray, emax = spectral_weight_sets[0][:, :, :, 0].max() if emax is None else emax e0 = np.linspace(emin - 5 * sigma, emax + 5 * sigma, nedos) - for ispin in range(ns): - for ii in range(nk): # Iterate through kpoint sets (of primitive cell kpoints) - for jj in range(spectral_weight_sets[ii].shape[1]): - kweight = kweight_sets[ii][jj] - E_Km = spectral_weight_sets[ii][ispin, jj, :, 0] - P_Km = spectral_weight_sets[ii][ispin, jj, :, 1] - if band_weight_sets is not None: - P_Km = P_Km * band_weight_sets[ii][jj][ispin, :P_Km.shape[0]] - # Take weighted average spectral functions - spectral_function[ispin, :, - ii] += np.sum(LorentzSmearing(e0[:, np.newaxis], E_Km[np.newaxis, :], sigma=sigma) * P_Km[np.newaxis, :], - axis=1) * kweight + for ispin, ii in itertools.product(range(ns), range(nk)): # Iterate through kpoint sets (of primitive cell kpoints) + for jj in range(spectral_weight_sets[ii].shape[1]): + kweight = kweight_sets[ii][jj] + E_Km = spectral_weight_sets[ii][ispin, jj, :, 0] + P_Km = spectral_weight_sets[ii][ispin, jj, :, 1] + if band_weight_sets is not None: + P_Km = P_Km * band_weight_sets[ii][jj][ispin, :P_Km.shape[0]] + # Take weighted average spectral functions + spectral_function[ispin, :, + ii] += np.sum(LorentzSmearing(e0[:, np.newaxis], E_Km[np.newaxis, :], sigma=sigma) * P_Km[np.newaxis, :], + axis=1) * kweight return e0, spectral_function @@ -1009,11 +1001,9 @@ def spectral_weight_multiple_source(kpoints: list, unfold_objs: List[Unfold], tr assert ns == obj.wfc.nspins # When reading from multiple wave function files (e.g. WAVECAR for VASP), - # it is possible that each of them may have differnt number of bands. + # it is possible that each of them may have different number of bands. # Ff so, we take only the first N bands, where N is the minimum values of bands - nb = [] - for source in unfold_objs: - nb.append(source.bands.shape[2]) + nb = [source.bands.shape[2] for source in unfold_objs] nbands = min(nb) spectral_weights = [] @@ -1101,9 +1091,8 @@ def parse_atoms_idx(atoms_idx: str) -> List[int]: out.extend(range(int(match.group(1)), int(match.group(2)) + 1)) else: out.append(int(item)) - # Expect passing 1-based indexing - out = [x - 1 for x in out] - return out + + return [x - 1 for x in out] # Expect passing 1-based indexing def process_projection_options(atoms_idx: str, orbitals: str) -> Tuple[list, list]: @@ -1153,7 +1142,7 @@ def parse_atoms(atoms_to_project: str, orbitals: str, poscar: str): """ atoms_to_project = re.split(', *', atoms_to_project) ase_atoms = read_poscar_contcar_if_present(poscar) - try: # check POTCAR if possible, to check the POSCAR-POTCARs match + with contextlib.suppress(FileNotFoundError): atom_types = get_atomtypes('POTCAR') def _check_order(smaller_list, larger_list): @@ -1162,8 +1151,6 @@ def _check_order(smaller_list, larger_list): if not _check_order(atom_types, ase_atoms.get_chemical_symbols()): warnings.warn('The order of atoms in the POSCAR/CONTCAR and POTCAR do not match!') - except FileNotFoundError: - pass atoms_idx = [ [i for i, atom in enumerate(ase_atoms) if projected_atom_symbol in atom.symbol] for projected_atom_symbol in atoms_to_project ] @@ -1175,7 +1162,7 @@ def _check_order(smaller_list, larger_list): # Special case: if only one set is passed, apply it to all atomic specifications if len(orbitals_subplots) == 1: - orbitals_subplots = orbitals_subplots * len(atoms_idx) + orbitals_subplots *= len(atoms_idx) orbitals_list = [] for orbital_sublist in orbitals_subplots: diff --git a/examples/MgO/unfold_project.png b/examples/MgO/unfold_project.png index 46ed044..39073de 100644 Binary files a/examples/MgO/unfold_project.png and b/examples/MgO/unfold_project.png differ diff --git a/examples/MgO/unfold_project_rb.png b/examples/MgO/unfold_project_rb.png index 88287b0..64eb689 100644 Binary files a/examples/MgO/unfold_project_rb.png and b/examples/MgO/unfold_project_rb.png differ diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj.png index 5139b80..e4440fa 100644 Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj.png differ diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj_dos.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj_dos.png index 6664fd2..90550cd 100644 Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj_dos.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj_dos.png differ diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png index eb3a13f..f5bc4e8 100644 Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png differ diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj_orbital_lm_dos.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj_orbital_lm_dos.png index a802af4..50b4d50 100644 Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj_orbital_lm_dos.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj_orbital_lm_dos.png differ diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj_sps_dos.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj_sps_dos.png index 7043ca4..5ee426e 100644 Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj_sps_dos.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj_sps_dos.png differ diff --git a/examples/Si222/unfold_tall.png b/examples/Si222/unfold_tall.png new file mode 100644 index 0000000..e73153c Binary files /dev/null and b/examples/Si222/unfold_tall.png differ diff --git a/tests/test_cli.py b/tests/test_cli.py index b76f020..16f3038 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -116,36 +116,43 @@ def test_plot_projection(mgo_project_dir): runner = CliRunner() output = runner.invoke(easyunfold, ['unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms-idx', '1,2|3-4', '--procar', 'PROCAR']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _plot_projection_check(output) output = runner.invoke( easyunfold, ['unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms-idx', '1,2|3-4', '--procar', 'PROCAR', '--combined']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _plot_projection_check(output) output = runner.invoke(easyunfold, [ 'unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms-idx', '1,2|3-4', '--procar', 'PROCAR', '--combined', '--orbitals', 'px,py|pz' ]) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _plot_projection_check(output) # test --atoms option with --poscar specification output = runner.invoke(easyunfold, ['unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms', 'Mg,O', '--poscar', 'POSCAR.mgo']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _plot_projection_check(output) # test parsing PROCAR from LORBIT = 14 calculation output = runner.invoke(easyunfold, [ 'unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms', 'Mg,O', '--poscar', 'POSCAR.mgo', '--procar', 'PROCAR_LORBIT_14.mgo' ]) + _plot_projection_check(output) + + # test options + output = runner.invoke(easyunfold, [ + 'unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms', 'Mg,O', '--poscar', 'POSCAR.mgo', '--eref', '2', + '--no-symm-average', '--cmap', 'PuBu', '--orbitals', 's,p', '--colours', 'r,g', '--orbitals', 's,p', '--colours', 'r,g', + '--colourspace', 'luvlch', '--intensity', '0.5', '--emin', '-2', '--emax', '5', '--dpi', '500', '--vscale', '2.0', '--cmap', 'PuBu', + '--npoints', '200', '--sigma', '0.15', '--title', 'Test', '--no-combined', '--height', '2', '--width', '3', '-o', 'test.png' + ]) + assert output.exit_code == 0 + assert Path('test.png').is_file() # different file name this time + Path('test.png').unlink() + + +def _plot_projection_check(output): assert output.exit_code == 0 assert Path('unfold.png').is_file() Path('unfold.png').unlink() @@ -199,17 +206,13 @@ def test_dos_atom_orbital_plots(nabis2_project_dir): '--dos-elements', 'Bi.s.p', ]) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, [ 'unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--orbitals', 's|px,py,pz|p', '--vscale', '0.5', '--combined', '--dos', 'vasprun.xml.gz', '--zero-line', '--dos-label', 'DOS', '--gaussian', '0.1', '--no-total', '--scale', '2' ]) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke( easyunfold, @@ -217,84 +220,64 @@ def test_dos_atom_orbital_plots(nabis2_project_dir): 'unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--orbitals', 's|px,py,pz|p', '--intensity', '2', '--combined', '--dos', 'vasprun.xml.gz', '--zero-line', '--dos-label', 'DOS', '--gaussian', '0.1', '--no-total', '--scale', '2' ]) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, [ 'unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--vscale', '0.5', '--combined', '--dos', 'vasprun.xml.gz', '--zero-line', '--dos-label', 'DOS', '--gaussian', '0.1', '--no-total', '--scale', '2' ]) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, [ 'unfold', 'plot-projections', '--atoms', 'Na,Bi', '--vscale', '0.5', '--combined', '--dos', 'vasprun.xml.gz', '--zero-line', '--dos-label', 'DOS', '--gaussian', '0.1', '--no-total', '--scale', '2' ]) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, ['unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--dos', 'vasprun.xml.gz']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, ['unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--combined', '--dos', 'vasprun.xml.gz']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, ['unfold', 'plot', '--atoms-idx', '1-20|21-40', '--orbitals', 's|p', '--dos', 'vasprun.xml.gz']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, ['unfold', 'plot', '--atoms', 'Na,Bi', '--orbitals', 's|p', '--dos', 'vasprun.xml.gz']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, ['unfold', 'plot-projections', '--atoms', 'Na,Bi', '--combined', '--orbitals', 's', '--dos', 'vasprun.xml.gz']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, [ 'unfold', 'plot-projections', '--atoms', 'Na,Bi', '--combined', '--orbitals', 's', '--dos', 'vasprun.xml.gz', '--dos-elements', 'Bi.s' ]) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, [ 'unfold', 'plot-projections', '--atoms-idx', '1-20,21,22,33', '--combined', '--orbitals', 's', '--dos', 'vasprun.xml.gz', '--dos-elements', 'Bi.s' ]) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke( easyunfold, ['unfold', 'plot', '--atoms-idx', '1-20,21,22,33', '--orbitals', 's', '--dos', 'vasprun.xml.gz', '--dos-elements', 'Bi.s']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, ['unfold', 'plot', '--dos', 'vasprun.xml.gz']) - assert output.exit_code == 0 - assert Path('unfold.png').is_file() - Path('unfold.png').unlink() + _check_dos_atom_orbital_plots(output) output = runner.invoke(easyunfold, [ 'unfold', 'plot-projections', '--atoms', 'Na,Bi', '--orbitals', 's', '--combined', '--dos', 'vasprun.xml.gz', '--dos-elements', 'Bi.s' ]) + _check_dos_atom_orbital_plots(output) + + +def _check_dos_atom_orbital_plots(output): assert output.exit_code == 0 assert Path('unfold.png').is_file() Path('unfold.png').unlink() diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 7b3ef4e..bcb910c 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -41,8 +41,7 @@ def silicon_unfolded(si_project_dir) -> UnfoldKSet: @pytest.fixture(scope='module') def unfold_obj() -> UnfoldKSet: """Return an unfolding object""" - obj = loadfn(Path(__file__).parent / 'test_data/mgo.json') - return obj + return loadfn(Path(__file__).parent / 'test_data/mgo.json') def test_plotting(unfold_obj: UnfoldKSet): @@ -77,6 +76,35 @@ def test_plotting_projection(unfold_obj: UnfoldKSet): fig = plotter.plot_projected(procar_path, atoms_idx='0,1|2,3', npoints=200, use_subplot=True) assert isinstance(fig, Figure) + # test options + poscar_path = Path(__file__).parent / 'test_data/POSCAR.mgo' + fig = plotter.plot_projected(procar_path, + eref=2, + gamma=False, + ncl=False, + npoints=200, + sigma=0.15, + symm_average=False, + figsize=(2, 3), + ylim=(-2, 5), + dpi=500, + vscale=2.0, + contour_plot=True, + alpha=0.4, + save='test.png', + ax=None, + cmap='PuBu', + show=True, + title='Test', + atoms='Mg,O', + poscar=poscar_path, + orbitals='s,p', + use_subplot=True, + colours='r,g', + colorspace='luvlch', + intensity=0.5) + assert isinstance(fig, Figure) + def test_color_interpolation(): """Test interpolating colours""" diff --git a/tests/test_procar.py b/tests/test_procar.py index ed9a094..94fe6a1 100644 --- a/tests/test_procar.py +++ b/tests/test_procar.py @@ -4,7 +4,6 @@ from pathlib import Path import numpy as np -import pytest from easyunfold.procar import Procar @@ -17,12 +16,12 @@ def test_procar(): procar = Procar(datapath / 'PROCAR') assert procar.nion == 2 - assert procar.eigenvalues.shape == (1, 48, 20) - assert procar.kvecs.shape == (1, 48, 3) - assert procar.kweights.shape == (1, 48) + assert procar.eigenvalues.shape == (1, 47, 20) + assert procar.kvecs.shape == (1, 47, 3) + assert procar.kweights.shape == (1, 47) assert np.all(procar.kvecs[0][0] == 0.) - assert procar.occs.shape == (1, 48, 20) - assert procar.get_projection([0], 'all').sum() == 310.418 - assert procar.get_projection([0], ['s', 'px']).sum() == 59.32300000000001 - assert procar.get_projection([0], ['s']).sum() + procar.get_projection([0], 'px').sum() == 59.32300000000001 + assert procar.occs.shape == (1, 47, 20) + assert procar.get_projection([0], 'all').sum() == 618.2850603903657 + assert procar.get_projection([0], ['s', 'px']).sum() == 124.10684519359549 + assert procar.get_projection([0], ['s']).sum() + procar.get_projection([0], 'px').sum() == 124.10684519359546 assert procar.proj_names == ['s', 'py', 'pz', 'px', 'dxy', 'dyz', 'dz2', 'dxz', 'x2-y2'] diff --git a/tests/test_unfold.py b/tests/test_unfold.py index 37dda49..eab9af2 100644 --- a/tests/test_unfold.py +++ b/tests/test_unfold.py @@ -232,7 +232,7 @@ def test_unfold_projection(si_project_dir, tag, nspin, ncl, nbands_expected): assert unfolder.is_calculated - unfolder.load_procar(si_project_dir / f'{folder_name}/PROCAR') + unfolder.load_procars(si_project_dir / f'{folder_name}/PROCAR') assert 'procars' in unfolder.transient_quantities assert 'procars_kmap' in unfolder.transient_quantities