diff --git a/.gitignore b/.gitignore
index 188226b..a1d792a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -47,3 +47,5 @@ KPOINTS
PROCAR
*.json
*.DS_Store
+
+local_tests/
diff --git a/README.md b/README.md
index 072fd91..9b8c698 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,11 @@ package.
For the methodology of supercell band unfolding, see
[here](https://link.aps.org/doi/10.1103/PhysRevB.85.085201).
+### Example Outputs
+Cs₂(Sn/Ti)Br₆ Vacancy-Ordered Perovskite Alloys | Symmetry-broken Si Supercell
+:-------------------------:|:------------------------------------:
+ |
+
## Usage
To generate an unfolded band structure, one typically needs to perform the following steps:
diff --git a/docs/examples/example_mgo.md b/docs/examples/example_mgo.md
index 6a9ba65..10fac88 100644
--- a/docs/examples/example_mgo.md
+++ b/docs/examples/example_mgo.md
@@ -38,9 +38,15 @@ easyunfold unfold plot-projections --procar MgO_super/PROCAR --atoms="Mg,O" --co
Note that the path of the `PROCAR` is passed along with the desired atom projections (`Mg` and `O` here).
+:::{tip}
+If the _k_-points have been split into multiple calculations (e.g. hybrid DFT band structures), the `--procar` option
+should be passed multiple times to specify the path to each split `PROCAR` file (i.e.
+`--procar calc1/PROCAR --procar cal2/PROCAR ...`).
+:::
+
:::{note}
-The atomic projections are not stored in the `easyunfold.json` data file, so the `PROCAR` file should be
-kept for replotting in the future.
+The atomic projections are not stored in the `easyunfold.json` data file, so the `PROCAR` file(s) should be kept for
+replotting in the future.
:::
The `--combined` option creates a combined plot with different colour maps for each atomic grouping.
@@ -97,3 +103,8 @@ easyunfold unfold plot-projections --procar MgO_super/PROCAR --atoms="Mg,O" --em
Unfolded MgO band structure with atomic projections plotted separately.
```
+
+:::{tip}
+There are _many_ customisation options available for the plotting functions in `easyunfold`. See `easyunfold plot -h` or
+`easyunfold unfold plot-projections -h` for more details!
+:::
diff --git a/docs/examples/example_nabis2.md b/docs/examples/example_nabis2.md
index ac4c551..2993345 100644
--- a/docs/examples/example_nabis2.md
+++ b/docs/examples/example_nabis2.md
@@ -88,7 +88,7 @@ When plotting the unfolded band, the `plot-projections` subcommand is used with
`--atoms` options:
```bash
-easyunfold unfold plot-projections --atoms="Na,Bi,S" --intensity 2 --combined
+easyunfold unfold plot-projections --atoms="Na,Bi,S" --intensity 3 --combined
```
```{figure} ../../examples/NaBiS2/NaBiS2_unfold-plot_proj.png
@@ -98,6 +98,12 @@ easyunfold unfold plot-projections --atoms="Na,Bi,S" --intensity 2 --combined
Unfolded band structure of NaBiS2 with atomic contributions.
```
+:::{tip}
+If the _k_-points have been split into multiple calculations (e.g. hybrid DFT band structures), the `--procar` option
+should be passed multiple times to specify the path to each split `PROCAR` file (i.e.
+`--procar calc1/PROCAR --procar cal2/PROCAR ...`).
+:::
+
From this plot, we can see that sulfur anions dominate the valence band, while bismuth cations dominate the conduction
band, with minimal contributions from the sodium cations as expected.
@@ -126,7 +132,7 @@ Unfolded band structure of NaBiS2 with atomic contributions plotted s
An alternative option here is also to just plot only the contributions of `Na` and `Bi` cations, with no S projections:
```bash
-easyunfold unfold plot-projections --atoms="Na,Bi" --intensity 2 --combined
+easyunfold unfold plot-projections --atoms="Na,Bi" --intensity 3 --combined
```
```{figure} ../../examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png
@@ -143,7 +149,7 @@ While this plot isn't the most aesthetic, it clearly shows that Bi (green) contr
### Atom-projected Unfolded Band Structure with DOS
We can also combine the atom projections with the DOS plotting, using the `--dos` option as before:
```bash
-easyunfold unfold plot-projections --atoms "Na,Bi,S" --intensity 2 --combined --dos vasprun.xml.gz --zero-line \
+easyunfold unfold plot-projections --atoms "Na,Bi,S" --intensity 3 --combined --dos vasprun.xml.gz --zero-line \
--dos-label "DOS" --gaussian 0.1 --no-total --scale 2
```
@@ -193,7 +199,7 @@ For example, if we want to see the contributions of the Bi $s$, $p$ and S $s$ or
we can use the following command:
```bash
-easyunfold unfold plot-projections --atoms "Bi,Bi,S" --orbitals="s|p|s" --intensity 2 --combined \
+easyunfold unfold plot-projections --atoms "Bi,Bi,S" --orbitals="s|p|s" --intensity 3 --combined \
--dos vasprun.xml.gz --zero-line --dos-label "DOS" --gaussian 0.1 --no-total --scale 5
```
@@ -216,7 +222,7 @@ see the contributions of the Bi and S $p_x$, $p_y$ and $p_z$ orbitals to the unf
following command:
```bash
-easyunfold unfold plot-projections --atoms "Na,Bi,S" --orbitals="all|px,py,pz|px,py,pz" --intensity 2 --combined \
+easyunfold unfold plot-projections --atoms "Na,Bi,S" --orbitals="all|px,py,pz|px,py,pz" --intensity 3 --combined \
--dos vasprun.xml.gz --zero-line --dos-label "DOS" --gaussian 0.1 --no-total --scale 6
```
@@ -232,6 +238,10 @@ structure, due to the cubic symmetry of the NaBiS2 crystal structure.
for the $d$ orbitals of transition metals in octahedral/tetrahedral environments, we would expect to see significant
differences in the contributions of different $lm$-decomposed orbitals to the electronic structure.
+:::{tip}
+There are _many_ customisation options available for the plotting functions in `easyunfold`. See `easyunfold plot -h` or
+`easyunfold unfold plot-projections -h` for more details!
+:::
[^1]: [Huang, YT., Kavanagh, S.R., Righetto, M. et al. Strong absorption and ultrafast localisation in NaBiS2 nanocrystals with slow charge-carrier recombination. Nat Commun 13, 4960 (2022)](https://www.nature.com/articles/s41467-022-32669-3)
[^2]: [Wang, Y., Kavanagh, S.R., Burgués-Ceballos, I. et al. Cation disorder engineering yields AgBiS2 nanocrystals with enhanced optical absorption for efficient ultrathin solar cells. Nat. Photon. 16, 235–241 (2022).](https://www.nature.com/articles/s41566-021-00950-4)
\ No newline at end of file
diff --git a/docs/examples/example_si211_castep.md b/docs/examples/example_si211_castep.md
index bbcef90..bfaaf64 100644
--- a/docs/examples/example_si211_castep.md
+++ b/docs/examples/example_si211_castep.md
@@ -143,4 +143,9 @@ Supercell band structure of Si in a 2x1x1 supercell
which, as expected, is the same as the primitive cell if it is folded back along the midpoint between
$\Gamma$ and $X$, corresponding to the 2x expansion of the primitive cell along the $x$ direction in
-generating the 2x1x1 supercell.
\ No newline at end of file
+generating the 2x1x1 supercell.
+
+:::{tip}
+There are _many_ customisation options available for the plotting functions in `easyunfold`. See `easyunfold plot -h` or
+`easyunfold unfold plot-projections -h` for more details!
+:::
diff --git a/docs/examples/example_si222.md b/docs/examples/example_si222.md
index a109f78..8a15a87 100644
--- a/docs/examples/example_si222.md
+++ b/docs/examples/example_si222.md
@@ -131,7 +131,6 @@ customising and prettifying the unfolded band structure plot. Here we have also
option to increase the spectral function intensity.
:::
-
Note the appearance of extra branches compared to the band structure of the primitive cell (below), due
to symmetry breaking from the displaced atom.
diff --git a/docs/img/CSTB_easyunfold.gif b/docs/img/CSTB_easyunfold.gif
new file mode 100644
index 0000000..4e8dffd
Binary files /dev/null and b/docs/img/CSTB_easyunfold.gif differ
diff --git a/docs/img/Si_222_unfold_tall.png b/docs/img/Si_222_unfold_tall.png
new file mode 120000
index 0000000..9fd31a0
--- /dev/null
+++ b/docs/img/Si_222_unfold_tall.png
@@ -0,0 +1 @@
+../../examples/Si222/unfold_tall.png
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index c87d232..db688fb 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -20,6 +20,11 @@ package.
For the methodology of supercell band unfolding, see
[here](https://link.aps.org/doi/10.1103/PhysRevB.85.085201).
+### Example Outputs
+Cs₂(Sn/Ti)Br₆ Vacancy-Ordered Perovskite Alloys | Symmetry-broken Si Supercell
+:-------------------------:|:------------------------------------:
+ |
+
## Usage
To generate an unfolded band structure, one typically needs to perform the following steps:
diff --git a/easyunfold/__init__.py b/easyunfold/__init__.py
index 87d18a3..caf92c2 100644
--- a/easyunfold/__init__.py
+++ b/easyunfold/__init__.py
@@ -2,4 +2,4 @@
Collection of code for band unfolding
"""
-__version__ = '0.2.0'
+__version__ = '0.3.0'
diff --git a/easyunfold/cli.py b/easyunfold/cli.py
index e7114ff..b248c5c 100644
--- a/easyunfold/cli.py
+++ b/easyunfold/cli.py
@@ -1,6 +1,8 @@
"""
Commandline interface
"""
+
+import contextlib
import warnings
import functools
from pathlib import Path
@@ -11,7 +13,7 @@
from easyunfold.unfold import parse_atoms, parse_atoms_idx
-# pylint:disable=import-outside-toplevel, too-many-locals, too-many-arguments, too-many-nested-blocks, too-many-branches
+# pylint:disable=import-outside-toplevel, too-many-locals, too-many-arguments, too-many-nested-blocks, too-many-branches, fixme
SUPPORTED_DFT_CODES = ('vasp', 'castep')
@@ -105,11 +107,8 @@ def generate(pc_file, code, sc_file, matrix, kpoints, time_reversal, out_file, n
dft_code=code,
symprec=symprec)
unfoldset.kpoint_labels = labels
- try:
+ with contextlib.suppress(KeyError):
print_symmetry_data(unfoldset)
- except KeyError:
- pass
-
out_file = Path(out_file)
if code == 'vasp':
out_kpt_name = f'KPOINTS_{out_file.stem}'
@@ -142,7 +141,7 @@ def generate(pc_file, code, sc_file, matrix, kpoints, time_reversal, out_file, n
Path(out_file).write_text(unfoldset.to_json(), encoding='utf-8')
- click.echo('Unfolding settings written to ' + str(out_file))
+ click.echo(f'Unfolding settings written to {str(out_file)}')
@easyunfold.group('unfold')
@@ -211,7 +210,7 @@ def unfold_calculate(ctx, wavefunc, save_as, gamma, ncl):
out_path = save_as if save_as else ctx.obj['fname']
Path(out_path).write_text(unfoldset.to_json(), encoding='utf-8')
- click.echo('Unfolding data written to ' + out_path)
+ click.echo(f'Unfolding data written to {out_path}')
def add_mpl_style_option(func):
@@ -331,30 +330,32 @@ def print_data(entries, tag='me'):
ext = Path(out_file).suffix
for carrier in ['electrons', 'holes']:
for idx, _ in enumerate(output[carrier]):
- plotter.plot_effective_mass_fit(efm=efm,
- npoints=npoints,
- carrier=carrier,
- idx=int(idx),
- save=fname + f'_fit_{carrier}_{idx}' + ext)
+ plotter.plot_effective_mass_fit(
+ efm=efm,
+ npoints=npoints,
+ carrier=carrier,
+ idx=int(idx),
+ save=f'{fname}_fit_{carrier}_{idx}{ext}',
+ )
def add_plot_options(func):
"""
Decorator that adds common plotting options to a function
"""
- click.option('--gamma', is_flag=True, help='Is the calculation a gamma only one?', show_default=True)(func)
- click.option('--ncl', is_flag=True, help='Is the calculation with non-colinear spin?', show_default=True)(func)
click.option('--npoints', type=int, default=2000, help='Number of bins for the energy.', show_default=True)(func)
click.option('--sigma', type=float, default=0.02, help='Smearing width for the energy in eV.', show_default=True)(func)
click.option('--eref', type=float, help='Reference energy in eV.')(func)
click.option('--emin', type=float, default=-5., help='Minimum energy in eV relative to the reference.', show_default=True)(func)
click.option('--emax', type=float, default=5., help='Maximum energy in eV relative to the reference.', show_default=True)(func)
click.option('--intensity', default=1.0, help='Scaling factor for the colour intensity.', type=float, show_default=True)(func)
- click.option('--vscale',
- type=float,
- help='A normalisation/scaling factor for the colour mapping. Equivalent to (1/intensity).',
- default=1.0,
- show_default=True)(func)
+ click.option(
+ '--vscale',
+ type=float,
+ help='A normalisation/scaling factor for the colour mapping. Equivalent to (1/intensity). '
+ 'Will be deprecated in future versions.', # TODO: deprecate
+ default=1.0,
+ show_default=True)(func)
click.option('--out-file', '-o', default='unfold.png', help='Name of the output file.', show_default=True)(func)
click.option('--cmap', default='PuRd', help='Name of the colour map to use.', show_default=True)(func)
click.option('--show', is_flag=True, default=False, help='Show the plot interactively.')(func)
@@ -421,7 +422,7 @@ def add_plot_options(func):
@click.pass_context
@add_plot_options
@add_mpl_style_option
-def unfold_plot(ctx, gamma, npoints, sigma, eref, out_file, show, emin, emax, cmap, ncl, no_symm_average, vscale, dos, dos_label, zero_line,
+def unfold_plot(ctx, npoints, sigma, eref, out_file, show, emin, emax, cmap, no_symm_average, vscale, dos, dos_label, zero_line,
dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, scale, procar, atoms, poscar,
atoms_idx, orbitals, title, width, height, dpi, intensity):
"""
@@ -429,9 +430,9 @@ def unfold_plot(ctx, gamma, npoints, sigma, eref, out_file, show, emin, emax, cm
This command uses the stored unfolding data to plot the effective bands structure (EBS) using the spectral function.
"""
- _unfold_plot(ctx, gamma, npoints, sigma, eref, out_file, show, emin, emax, cmap, ncl, no_symm_average, vscale, dos, dos_label,
- zero_line, dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, scale, procar, atoms,
- poscar, atoms_idx, orbitals, title, width, height, dpi, intensity)
+ _unfold_plot(ctx, npoints, sigma, eref, out_file, show, emin, emax, cmap, no_symm_average, vscale, dos, dos_label, zero_line,
+ dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, scale, procar, atoms, poscar,
+ atoms_idx, orbitals, title, width, height, dpi, intensity)
def process_dos(dos, dos_elements, dos_orbitals, dos_atoms, gaussian, total_only, atoms, orbitals, poscar, no_total, legend_cutoff, scale):
@@ -481,10 +482,8 @@ def process_dos(dos, dos_elements, dos_orbitals, dos_atoms, gaussian, total_only
dos_elements[atom] = ()
for orbital in orbital_tuple:
if orbital != 'all' and orbital[:1] not in dos_elements[atom]:
- if orbital[:1] == 'x': # special case in VASP PROCAR labelling, set to 'd'
- dos_elements[atom] += ('d',)
- else:
- dos_elements[atom] += (orbital[:1],)
+ # special case in VASP PROCAR labelling, set to 'd' if x, else just first letter
+ dos_elements[atom] += ('d',) if orbital[:1] == 'x' else (orbital[:1],)
dos, pdos = load_dos(
dos,
@@ -511,11 +510,19 @@ def process_dos(dos, dos_elements, dos_orbitals, dos_atoms, gaussian, total_only
@click.pass_context
@add_plot_options
@add_mpl_style_option
-@click.option('--combined/--no-combined', is_flag=True, default=False, help='Plot all projections in a combined graph.')
-@click.option('--colours', help='Colours to be used for combined plot, comma separated.', default='r,g,b,purple', show_default=True)
-def unfold_plot_projections(ctx, gamma, npoints, sigma, eref, out_file, show, emin, emax, cmap, ncl, no_symm_average, vscale, dos,
- dos_label, zero_line, dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only,
- scale, procar, atoms, poscar, atoms_idx, orbitals, title, combined, colours, width, height, dpi, intensity):
+@click.option('--combined/--no-combined', is_flag=True, default=False, help='Plot all projections in a combined graph.', show_default=True)
+@click.option('--colours',
+ help='Colours to be used for combined plot, comma separated (e.g. "r,b,y"). '
+ 'Default is pastel red, green, blue if <=3 projections, else red, green, blue, purple, orange, yellow.',
+ default=None)
+@click.option('--colourspace',
+ help='Colourspace in which to perform interpolation for combined plot.',
+ default='lab',
+ show_default=True,
+ type=click.Choice(['rgb', 'hsv', 'lab', 'luvlch', 'lablch', 'xyz']))
+def unfold_plot_projections(ctx, npoints, sigma, eref, out_file, show, emin, emax, cmap, no_symm_average, vscale, dos, dos_label, zero_line,
+ dos_elements, dos_orbitals, dos_atoms, legend_cutoff, gaussian, no_total, total_only, scale, procar, atoms,
+ poscar, atoms_idx, orbitals, title, combined, colours, colourspace, width, height, dpi, intensity):
"""
Plot the effective band structure with atomic projections.
"""
@@ -534,13 +541,11 @@ def unfold_plot_projections(ctx, gamma, npoints, sigma, eref, out_file, show, em
dos_label=dos_label,
dos_options=dos_options,
zero_line=zero_line,
- gamma=gamma,
npoints=npoints,
sigma=sigma,
eref=eref,
ylim=(emin, emax),
cmap=cmap,
- ncl=ncl,
symm_average=not no_symm_average,
atoms=atoms,
atoms_idx=atoms_idx,
@@ -552,7 +557,8 @@ def unfold_plot_projections(ctx, gamma, npoints, sigma, eref, out_file, show, em
intensity=intensity,
figsize=(width, height),
dpi=dpi,
- colours=colours.split(',') if colours is not None else None)
+ colours=colours.split(',') if colours is not None else None,
+ colorspace=colourspace)
if out_file:
fig.savefig(out_file, dpi=dpi, bbox_inches='tight')
@@ -563,7 +569,6 @@ def unfold_plot_projections(ctx, gamma, npoints, sigma, eref, out_file, show, em
def _unfold_plot(ctx,
- gamma,
npoints,
sigma,
eref,
@@ -572,7 +577,6 @@ def _unfold_plot(ctx,
emin,
emax,
cmap,
- ncl,
no_symm_average,
vscale,
dos,
@@ -612,12 +616,12 @@ def _unfold_plot(ctx,
eref = unfoldset.calculated_quantities.get('vbm', 0.0)
click.echo(f'Using a reference energy of {eref:.3f} eV')
- # Setup the atoms_idx and orbitals
+ # Set up the atoms_idx and orbitals
if atoms or atoms_idx:
- # Process the PROCAR
+ # Process the PROCARs
click.echo(f'Loading projections from: {procar}')
try:
- unfoldset.load_procar(procar)
+ unfoldset.load_procars(procar)
except FileNotFoundError as exc:
click.echo(f'Could not find and parse the --procar file: {procar} – needed for atomic projections!')
raise click.Abort() from exc
@@ -655,10 +659,8 @@ def _unfold_plot(ctx,
# Collect spectral functions and scale
all_sf = []
for this_idx, this_orbitals in zip(atoms_idx_subplots, orbitals_subplots):
- eng, spectral_function = unfoldset.get_spectral_function(gamma=gamma,
- npoints=npoints,
+ eng, spectral_function = unfoldset.get_spectral_function(npoints=npoints,
sigma=sigma,
- ncl=ncl,
atoms_idx=this_idx,
orbitals=this_orbitals,
symm_average=not no_symm_average)
@@ -702,26 +704,20 @@ def _unfold_plot(ctx,
def print_symmetry_data(kset):
"""Print the symmetry information"""
- # Print space group information
- sc_spg = kset.metadata['symmetry_dataset_sc']
- click.echo('Supercell cell information:')
- click.echo(' ' * 8 + f'Space group number: {sc_spg["number"]}')
- click.echo(' ' * 8 + f'International symbol: {sc_spg["international"]}')
- click.echo(' ' * 8 + f'Point group: {sc_spg["pointgroup"]}')
+ def _print_symmetry_data_from_kset(kset, dataset_key, dataset_title):
+ # Print space group information
+ sc_spg = kset.metadata[dataset_key]
+ click.echo(dataset_title)
+ click.echo(' ' * 8 + f'Space group number: {sc_spg["number"]}')
+ click.echo(' ' * 8 + f'International symbol: {sc_spg["international"]}')
+ click.echo(' ' * 8 + f'Point group: {sc_spg["pointgroup"]}')
- pc_spg = kset.metadata['symmetry_dataset_pc']
- click.echo('\nPrimitive cell information:')
- click.echo(' ' * 8 + f'Space group number: {pc_spg["number"]}')
- click.echo(' ' * 8 + f'International symbol: {pc_spg["international"]}')
- click.echo(' ' * 8 + f'Point group: {pc_spg["pointgroup"]}')
+ _print_symmetry_data_from_kset(kset, 'symmetry_dataset_sc', 'Supercell cell information:')
+ _print_symmetry_data_from_kset(kset, 'symmetry_dataset_pc', '\nPrimitive cell information:')
def matrix_from_string(string):
- """Parse transform matrix from a string"""
+ """Parse transformation matrix from a string"""
elems = [float(x) for x in string.split()]
- # Try gussing the transform matrix
- if len(elems) == 3:
- transform_matrix = np.diag(elems)
- else:
- transform_matrix = np.array(elems).reshape((3, 3))
- return transform_matrix
+ # try guessing the transormation matrix
+ return np.diag(elems) if len(elems) == 3 else np.array(elems).reshape((3, 3))
diff --git a/easyunfold/plotting.py b/easyunfold/plotting.py
index 0ff4cb2..352aad3 100644
--- a/easyunfold/plotting.py
+++ b/easyunfold/plotting.py
@@ -1,6 +1,7 @@
"""
Plotting utilities
"""
+import os
from typing import Union, Sequence
import warnings
import itertools
@@ -85,7 +86,15 @@ def _get_orbital_colour_dict(index, colour_list):
if orbital in key:
sumo_colours[atom][key] = orbital_colour_dict[key]
else:
- sumo_colours = None
+ from pkg_resources import Requirement, resource_filename
+ try:
+ import configparser
+ except ImportError:
+ import ConfigParser as configparser
+
+ config_path = resource_filename(Requirement.parse('sumo'), 'sumo/plotting/orbital_colours.conf')
+ sumo_colours = configparser.ConfigParser()
+ sumo_colours.read(os.path.abspath(config_path))
# don't use first 4 colours; these are the band structure line colours:
cycle = cycler('color', rcParams['axes.prop_cycle'].by_key()['color'][4:])
@@ -117,9 +126,6 @@ def _get_orbital_colour_dict(index, colour_list):
lines = plot_data['lines']
spins = [Spin.up] if len(lines[0][0]['dens']) == 1 else [Spin.up, Spin.down]
- # disable y ticks for DOS panel
- ax.tick_params(axis='y', which='both', right=False)
-
for line_set in plot_data['lines']:
for line, spin in itertools.product(line_set, spins):
if spin == Spin.up or len(spins) == 1:
@@ -153,8 +159,8 @@ def _get_orbital_colour_dict(index, colour_list):
if dos_label is not None:
ax.set_xlabel(dos_label)
- ax.set_xticklabels([])
- ax.set_yticklabels([])
+ ax.set_yticks([]) # no y ticks
+ ax.set_xticks([]) # no x ticks
ax.legend(loc=2, frameon=False, ncol=1, bbox_to_anchor=(1.0, 1.0), fontsize=9)
return ax
@@ -217,19 +223,15 @@ def plot_spectral_function(
if nspin > 1:
warnings.warn('DOS plotter is not supported for spin-separated plots. Reverting to non spin-polarised plotting.')
nspin = 1
- else:
- if ax is None:
- if nspin == 1:
- fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
- axes = [ax]
- else:
- fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
+ elif ax is None:
+ if nspin == 1:
+ fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
+ axes = [ax]
else:
- if not isinstance(ax, list):
- axes = [ax]
- else:
- axes = ax
- fig = axes[0].figure
+ fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
+ else:
+ axes = [ax] if not isinstance(ax, list) else ax
+ fig = axes[0].figure
# Shift the kdist so the pcolormesh draw the pixel centred on the original point
X, Y = np.meshgrid(kdist, engs - eref)
@@ -330,10 +332,7 @@ def _plot_spectral_function_rgba(
else:
fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
else:
- if not isinstance(ax, list):
- axes = [ax]
- else:
- axes = ax
+ axes = [ax] if not isinstance(ax, list) else ax
fig = axes[0].figure
mask = (engs < (ylim[1] + eref)) & (engs > (ylim[0] + eref))
@@ -379,10 +378,7 @@ def _add_kpoint_labels(self, ax: plt.Axes, x_is_kidx=False):
tick_locs = []
tick_labels = []
for index, label in labels:
- if x_is_kidx:
- xloc = index
- else:
- xloc = kdist[index]
+ xloc = index if x_is_kidx else kdist[index]
ax.axvline(x=xloc, lw=0.5, color='k', ls=':', alpha=0.8)
tick_locs.append(xloc)
tick_labels.append(clean_latex_string(label))
@@ -511,10 +507,7 @@ def plot_spectral_weights(
else:
fig, axes = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
else:
- if not isinstance(ax, list):
- axes = [ax]
- else:
- axes = ax
+ axes = [ax] if not isinstance(ax, list) else ax
fig = axes[0].figure
kweights = unfold.expansion_results['weights']
@@ -558,37 +551,37 @@ def plot_spectral_weights(
return fig
def plot_projected(
- self,
- procar: Union[str, list] = 'PROCAR',
- dos_plotter=None,
- dos_label=None,
- dos_options=None,
- zero_line=False,
- eref=None,
- gamma=False,
- npoints=2000,
- sigma=0.2,
- ncl=False,
- symm_average=True,
- figsize=(4, 3),
- ylim=(-5, 5),
- dpi=300,
- vscale=1.0,
- contour_plot=False,
- alpha=1.0,
- save=False,
- ax=None,
- cmap='PuRd',
- show=False,
- title=None,
- atoms=None,
- poscar='POSCAR',
- atoms_idx=None,
- orbitals=None,
- use_subplot=False,
- colours=('r', 'g', 'b', 'purple'),
- colorspace='lab',
- intensity=1.0,
+ self,
+ procar: Union[str, list] = 'PROCAR',
+ dos_plotter=None,
+ dos_label=None,
+ dos_options=None,
+ zero_line=False,
+ eref=None,
+ gamma=False,
+ npoints=2000,
+ sigma=0.2,
+ ncl=False,
+ symm_average=True,
+ figsize=(4, 3),
+ ylim=(-5, 5),
+ dpi=300,
+ vscale=1.0,
+ contour_plot=False,
+ alpha=1.0,
+ save=False,
+ ax=None,
+ cmap='PuRd',
+ show=False,
+ title=None,
+ atoms=None,
+ poscar='POSCAR',
+ atoms_idx=None,
+ orbitals=None,
+ use_subplot=False,
+ colours=None,
+ colorspace='lab',
+ intensity=1.0,
):
"""
Plot projected spectral function onto multiple subplots or a single plot with color mapping.
@@ -598,10 +591,13 @@ def plot_projected(
:param procar: Name of names of the `PROCAR` files.
+ :param colours: Default is pastel red, green, blue if <=3 projections, else red, green,
+ blue, purple, orange, yellow.
+
:returns: Generated plot.
"""
unfoldset = self.unfold
- unfoldset.load_procar(procar)
+ unfoldset.load_procars(procar)
nspin = unfoldset.calculated_quantities['spectral_weights_per_set'][0].shape[0]
if atoms_idx is not None:
@@ -636,12 +632,11 @@ def plot_projected(
# Collect spectral functions and scale
for this_idx, this_orbitals in zip(atoms_idx_subplots, orbitals_subplots):
- # Setup the atoms_idx and orbitals
+ # Set up the atoms_idx and orbitals
if isinstance(this_idx, str):
this_idx, this_orbitals = process_projection_options(this_idx, this_orbitals)
- else: # list of integers; pre-processed by specifying atoms
- if this_orbitals != 'all':
- this_orbitals = [token.strip() for token in this_orbitals.split(',')]
+ elif this_orbitals != 'all':
+ this_orbitals = [token.strip() for token in this_orbitals.split(',')]
eng, spectral_function = unfoldset.get_spectral_function(gamma=gamma,
npoints=npoints,
@@ -701,6 +696,18 @@ def plot_projected(
stacked_sf = np.stack(all_sf, axis=-1).reshape(np.prod(sf_size), len(all_sf))
# Construct the colour basis
+ if colours is None:
+ if len(all_sf) <= 3:
+ colours = ['#CC33A7', '#A7CC33', '#33A7CC']
+ else:
+ colours = [
+ (1, 0, 0), # red
+ (0, 1, 0), # green
+ (0, 0, 1), # blue
+ (152 / 255, 78 / 255, 163 / 255), # purple
+ (1, 127 / 255, 0), # orange
+ (1, 1, 51 / 255), # yellow
+ ]
colours = colours[:len(all_sf)]
# Compute spectral weight data with RGB reshape it back into the shape (nengs, nk, 3)
sf_rgb = interpolate_colors(colours, stacked_sf, colorspace, normalize=True).reshape(sf_size + (3,))
@@ -761,9 +768,7 @@ def plot_projected(
warnings.warn('zero_line option requires sumo to be installed!')
if atoms is not None: # add figure legend with atoms and colors
- legend_elements = []
- for i, atom in enumerate(atoms):
- legend_elements.append(Patch(facecolor=colours[i], label=atom, alpha=0.7))
+ legend_elements = [Patch(facecolor=colours[i], label=atom, alpha=0.7) for i, atom in enumerate(atoms)]
fig.axes[0].legend(handles=legend_elements, bbox_to_anchor=(1.025, 1), fontsize=9)
fig.subplots_adjust(right=0.78) # ensure legend is not cut off
@@ -814,7 +819,7 @@ def interpolate_colors(colours: Sequence, weights: list, colorspace='lab', norma
:param weights: A list of weights with the shape (n, N). Where the N values of
the last axis give the amount of N colours supplied in `colours`.
:param colorspace: The colorspace in which to perform the interpolation. The
- allowed values are rgb, hsv, lab, luvlc, lablch, and xyz.
+ allowed values are rgb, hsv, lab, luvlch, lablch, and xyz.
:returns: A list of colours, specified in the rgb format as a (n, 3) array.
"""
@@ -844,7 +849,7 @@ def interpolate_colors(colours: Sequence, weights: list, colorspace='lab', norma
# Normalise the weights if needed
if normalize:
- weights = weights / np.linalg.norm(weights, axis=1)[:, None]
+ weights = weights / np.sum(weights, axis=1)[:, None] # each row sums to 1
# perform the interpolation in the colorspace basis
interpolated_colors = colors_basis[0] * weights[:, 0][:, None]
@@ -853,10 +858,24 @@ def interpolate_colors(colours: Sequence, weights: list, colorspace='lab', norma
# convert the interpolated colors back to RGB
rgb_colors = [convert_color(colorspace(*c), sRGBColor).get_value_tuple() for c in interpolated_colors]
- rgb_colors = np.stack(rgb_colors, axis=0)
- # ensure all rgb values are less than 1 (sometimes issues in interpolation gives
- np.clip(rgb_colors, 0, 1, rgb_colors)
+ # ensure all rgb values are less than 1 (sometimes issues in interpolation)
+ normalised_rgb_colors = []
+ for rgb_color_tuple in rgb_colors:
+ if np.max(rgb_color_tuple) > 1:
+ normalised_rgb_color = np.array(rgb_color_tuple) / np.max(rgb_color_tuple)
+ else:
+ normalised_rgb_color = np.array(rgb_color_tuple)
+
+ normalised_rgb_color = np.clip(normalised_rgb_color, 0, 1) # ensure all rgb values are between 0 and 1
+ # if too white, darken:
+ if np.linalg.norm(normalised_rgb_color) > 1: # white af
+ normalised_rgb_color *= (1 / np.linalg.norm(normalised_rgb_color)**(1 / 2))
+
+ normalised_rgb_colors.append(normalised_rgb_color)
+
+ rgb_colors = np.stack(normalised_rgb_colors, axis=0)
+
return rgb_colors
diff --git a/easyunfold/procar.py b/easyunfold/procar.py
index 53218cc..192302b 100644
--- a/easyunfold/procar.py
+++ b/easyunfold/procar.py
@@ -6,15 +6,21 @@
import re
import numpy as np
+from monty.json import MSONable, MontyDecoder
-class Procar:
+from easyunfold import __version__
+
+# pylint:disable=too-many-locals,
+
+
+class Procar(MSONable):
"""Reader for PROCAR file"""
- def __init__(self, fobj_or_path=None, is_soc=False):
+ def __init__(self, fobjs_or_paths=None, is_soc=False):
"""
Read the PROCAR file from a handle or path
- :param fobj_or_path: A file-like obj or a path
+ :param fobjs_or_paths: Either a string or list of file-like objs or paths
:param is_soc: Whether the PROCAR is from a calculation with spin-orbit coupling
"""
self._is_soc = is_soc
@@ -32,85 +38,197 @@ def __init__(self, fobj_or_path=None, is_soc=False):
self.proj_xyz = None
# Read the PROCAR
- if isinstance(fobj_or_path, (str, Path)):
- with open(fobj_or_path, encoding='utf-8') as fhandle:
- self._read(fhandle)
- else:
- self._read(fobj_or_path)
+ if isinstance(fobjs_or_paths, (str, Path)):
+ fobjs_or_paths = [fobjs_or_paths]
+ self.read(fobjs_or_paths)
- def _read(self, fobj):
+ def _read(self, fobj, parsed_kpoints=None):
"""Main function for reading in the data"""
+ if parsed_kpoints is None:
+ parsed_kpoints = set()
- # First sweep - found the number of kpoints and the number of bands
+ # First sweep - find the number of kpoints and the number of bands
fobj.seek(0)
- self.header = fobj.readline()
+ _header = fobj.readline()
# Read the NK, NB and NIONS that are integers
- self.nkpts, self.nbands, self.nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()]
- # Number of projects and their names
- nproj = None
- self.proj_names = None
-
- for line in fobj:
- if re.match(r'^ion.*tot', line): # only the first "ion" line, in case of LORBIT >= 12
- nproj = len(line.strip().split()) - 2
- self.proj_names = line.strip().split()[1:-1]
- break
+ _total_nkpts, nbands, nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()]
+ if nion != self.nion:
+ raise ValueError(f'Mismatch in number of ions in PROCARs supplied: ({nion} vs {self.nion})!')
# Count the number of data lines, these lines do not have any alphabets
- proj_data = []
- energies = []
- occs = []
- kvecs = []
- kweights = []
+ proj_data, energies, kvecs, kweights, occs = [], [], [], [], []
+ tot_count = 0 # count the instances of lines starting with "tot" -> (4 + 1) * nbands * nkpts for SOC calcs
fobj.seek(0)
- for line in fobj:
- if not re.search(r'[a-zA-Z]', line) and line.strip() and len(line.strip().split()) - 2 == nproj:
- # only parse data if nproj is expected length, in case of LORBIT >= 12
+
+ line = fobj.readline()
+ while line:
+ if line.startswith(' k-point'):
+ line = re.sub(r'(\d)-', r'\1 -', line)
+ tokens = line.strip().split()
+ kvec = tuple(round(float(val), 5) for val in # tuple to make it hashable
+ tokens[-6:-3]) # round to 5 decimal places to ensure proper kpoint matching
+ if kvec not in parsed_kpoints:
+ parsed_kpoints.add(kvec)
+ kvecs.append(list(kvec))
+ kweights.append(float(tokens[-1]))
+ else:
+ # skip ahead to the next instance of two blank lines in a row
+ while line.strip() or fobj.readline().strip():
+ line = fobj.readline()
+ continue
+
+ elif not re.search(r'[a-zA-Z]', line) and line.strip() and len(line.strip().split()) - 2 == len(self.proj_names):
+ # only parse data if line is expected length, in case of LORBIT >= 12
proj_data.append([float(token) for token in line.strip().split()[1:-1]])
+
elif line.startswith('band'):
tokens = line.strip().split()
energies.append(float(tokens[4]))
occs.append(float(tokens[-1]))
- elif line.startswith(' k-point'):
- line = re.sub(r'(\d)-', r'\1 -', line)
- tokens = line.strip().split()
- kvecs.append([float(val) for val in tokens[-6:-3]])
- kweights.append(float(tokens[-1]))
- self.occs = np.array(occs)
- self.kvecs = np.array(kvecs)
- self.kweights = np.array(kweights)
- self.eigenvalues = np.array(energies)
+
+ elif line.startswith('tot'):
+ tot_count += 1
+
+ line = fobj.readline()
+
+ # dynamically determine whether PROCARs are SOC or not
+ if tot_count == 4 * len(occs):
+ self._is_soc = True
+ elif tot_count == len(occs):
+ self._is_soc = False
+ else:
+ raise ValueError(f"Number of lines starting with 'tot' ({tot_count}) in PROCAR does not match expected "
+ f'values ({4*len(occs)} or {len(occs)})!')
+
+ occs = np.array(occs)
+ kvecs = np.array(kvecs)
+ kweights = np.array(kweights)
+ eigenvalues = np.array(energies)
proj_data = np.array(proj_data, dtype=float)
- self.nspins = proj_data.shape[0] // (self.nion * self.nbands * self.nkpts)
+ # redetermine nkpts in case some were skipped due to already being parsed
+ nkpts = len(kvecs)
+
+ self.nspins = proj_data.shape[0] // (self.nion * nbands * nkpts)
self.nspins //= 4 if self._is_soc else 1
# Reshape
- self.occs.resize((self.nspins, self.nkpts, self.nbands))
- self.kvecs.resize((self.nspins, self.nkpts, 3))
- self.kweights.resize((self.nspins, self.nkpts))
- self.eigenvalues.resize((self.nspins, self.nkpts, self.nbands))
+ occs.resize((self.nspins, nkpts, nbands))
+ kvecs.resize((self.nspins, nkpts, 3))
+ kweights.resize((self.nspins, nkpts))
+ eigenvalues.resize((self.nspins, nkpts, nbands))
# Reshape the array
if self._is_soc is False:
- self.proj_data = proj_data.reshape((self.nspins, self.nkpts, self.nbands, self.nion, nproj))
+ proj_data = proj_data.reshape((self.nspins, nkpts, nbands, self.nion, len(self.proj_names)))
+ proj_xyz = None
else:
- self.proj_data = proj_data.reshape((self.nspins, self.nkpts, self.nbands, 4, self.nion, nproj))
+ proj_data = proj_data.reshape((self.nspins, nkpts, nbands, 4, self.nion, len(self.proj_names)))
# Split the data into xyz projection and total
- self.proj_xyz = self.proj_data[:, :, :, 1:, :, :]
- self.proj_data = self.proj_data[:, :, :, 0, :, :]
+ proj_xyz = proj_data[:, :, :, 1:, :, :]
+ proj_data = proj_data[:, :, :, 0, :, :]
+
+ # normalise: (for each nspin, nkpt, nband, the sum of the projections over nion and proj_names should be 1)
+ proj_sum = np.sum(proj_data, axis=(-2, -1), keepdims=True)
+ proj_sum[proj_sum == 0] = 1 # just in case, avoid division by zero
+ proj_data /= proj_sum
+
+ if proj_xyz is not None:
+ proj_sum = np.sum(proj_xyz, axis=(-3, -2, -1), keepdims=True)
+ proj_sum[proj_sum == 0] = 1
+ proj_xyz /= proj_sum
+
+ return self.nspins, occs, kvecs, kweights, eigenvalues, proj_data, proj_xyz, parsed_kpoints
+
+ def _read_header_nion_proj_names(self, fobj):
+ """Read the header, nion and proj_names from the PROCAR"""
+ fobj.seek(0)
+ self.header = fobj.readline()
+ # Read the NK, NB and NIONS that are integers
+ _nkpts, _nbands, self.nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()]
+ self.proj_names = None # projection names
+
+ for line in fobj:
+ if re.match(r'^ion.*tot', line): # only the first "ion" line, in case of LORBIT >= 12
+ self.proj_names = line.strip().split()[1:-1]
+ break
+
+ def read(self, fobjs_or_paths):
+ """Read and amalgamate the data from a list of PROCARs"""
+
+ def open_file(fobj_or_path):
+ if isinstance(fobj_or_path, (str, Path)):
+ return open(fobj_or_path, encoding='utf-8') # closed later
+ return fobj_or_path # already a file-like object, just return it
+
+ parsed_kpoints = None
+ occs_list, kvecs_list, kweights_list = [], [], []
+ eigenvalues_list, proj_data_list, proj_xyz_list = [], [], []
+ for i, fobj_or_path in enumerate(fobjs_or_paths):
+ # Note: If PROCAR parsing becomes a significant bottleneck for people (e.g. with several HSE06+SOC PROCARs),
+ # this could be parallelized (somewhat) with multiprocessing. The actual file parsing in _read() is currently
+ # serial so no easy wins there, but could at least parallelise over the list of PROCARs
+ fobj = open_file(fobj_or_path)
+ if self.header is None: # first file; read header, nion, proj_names
+ self._read_header_nion_proj_names(fobj)
+
+ current_nspins = self.nspins # check spin consistency between PROCARs
+ nspins, occs, kvecs, kweights, eigenvalues, proj_data, proj_xyz, parsed_kpoints = self._read(fobj,
+ parsed_kpoints=parsed_kpoints)
+ if current_nspins is not None and current_nspins != nspins:
+ raise ValueError(f'Mismatch in number of spins in PROCARs supplied: ({nspins} vs {current_nspins})!')
+
+ if isinstance(fobj_or_path, (str, Path)):
+ fobj.close() # if file was opened in this loop, close it
+
+ # Append to respective lists
+ occs_list.append(occs)
+ kvecs_list.append(kvecs)
+ kweights_list.append(kweights)
+ eigenvalues_list.append(eigenvalues)
+ proj_data_list.append(proj_data)
+ proj_xyz_list.append(proj_xyz)
+ if len(fobjs_or_paths) > 1: # print progress if reading multiple files
+ print(f'Finished parsing PROCAR {i + 1}/{len(fobjs_or_paths)}')
+
+ # Combine along the nkpts axis:
+ # for occs, eigenvalues, proj_data and proj_xyz, nbands (axis = 2) could differ, so set missing values to zero:
+ max_nbands = max(arr.shape[2] for arr in eigenvalues_list)
+ for array_list in [occs_list, eigenvalues_list, proj_data_list, proj_xyz_list]:
+ for i, arr in enumerate(array_list):
+ if arr is not None and arr.shape[2] < max_nbands:
+ if len(arr.shape) == 3: # occs_list, eigenvalues_list
+ array_list[i] = np.pad(arr, ((0, 0), (0, 0), (0, max_nbands - arr.shape[2])), mode='constant')
+ elif len(arr.shape) == 5: # proj_xyz_list
+ array_list[i] = np.pad(arr, ((0, 0), (0, 0), (0, max_nbands - arr.shape[2]), (0, 0), (0, 0)), mode='constant')
+ elif len(arr.shape) == 6: # proj_xyz_list
+ array_list[i] = np.pad(arr, ((0, 0), (0, 0), (0, max_nbands - arr.shape[2]), (0, 0), (0, 0), (0, 0)),
+ mode='constant')
+ else:
+ raise ValueError('Unexpected array shape encountered!')
+
+ self.nbands = max_nbands
+ self.occs = np.concatenate(occs_list, axis=1)
+ self.eigenvalues = np.concatenate(eigenvalues_list, axis=1)
+ self.kvecs = np.concatenate(kvecs_list, axis=1)
+ self.kweights = np.concatenate(kweights_list, axis=1)
+ self.proj_data = np.concatenate(proj_data_list, axis=1)
+ if all(arr is not None for arr in proj_xyz_list):
+ self.proj_xyz = np.concatenate(proj_xyz_list, axis=1)
+
+ self.nkpts = self.kvecs.shape[1]
def get_projection(self, atom_idx: List[int], proj: Union[List[str], str], weight_by_k=False):
"""
- Get project for specific atoms and specific projectors
+ Get projection for specific atoms and specific projectors
:param atom_idx: A list of index of the atoms to be selected
:param proj: A list of the projector names to be selected
:param weight_by_k: Apply k weighting or not.
- :returns: The project summed over the selected atoms and the projectors
+ :returns: The projection summed over the selected atoms and the projectors
"""
atom_mask = [iatom in atom_idx for iatom in range(self.nion)]
assert any(atom_mask)
@@ -143,3 +261,41 @@ def _replace_p_d(single_proj):
for kidx in range(self.nkpts):
out[:, kidx, :] *= self.kweights[kidx]
return out
+
+ def as_dict(self) -> dict:
+ """Convert the object into a dictionary representation (so it can be saved to json)"""
+ output = {'@module': self.__class__.__module__, '@class': self.__class__.__name__, '@version': __version__}
+ for key in [
+ '_is_soc', 'eigenvalues', 'kvecs', 'kweights', 'nbands', 'nkpts', 'nspins', 'nion', 'occs', 'proj_names', 'proj_data',
+ 'header', 'proj_xyz'
+ ]:
+ output[key] = getattr(self, key)
+ return output
+
+ @classmethod
+ def from_dict(cls, d):
+ """
+ Reconstructs Procar object from a dict representation, without calling __init__().
+
+ Args:
+ d (dict): dict representation of Procar
+
+ Returns:
+ Procar object
+ """
+
+ def decode_dict(subdict):
+ if isinstance(subdict, dict) and '@module' in subdict:
+ return MontyDecoder().process_decoded(subdict)
+ return subdict
+
+ instance = cls.__new__(cls) # create a new instance without calling __init__()
+ d_decoded = {k: decode_dict(v) for k, v in d.items()}
+
+ # set the instance variables directly from the dictionary
+ for key, value in d_decoded.items():
+ if key in ['@module', '@class', '@version']:
+ continue
+ setattr(instance, key, value)
+
+ return instance
diff --git a/easyunfold/unfold.py b/easyunfold/unfold.py
index 37adcef..fe08d1b 100644
--- a/easyunfold/unfold.py
+++ b/easyunfold/unfold.py
@@ -2,6 +2,9 @@
"""
The main module for unfolding workflow and algorithm
"""
+
+import contextlib
+import itertools
# pylint: disable=invalid-name,protected-access,too-many-locals
############################################################
@@ -317,9 +320,7 @@ def generate_sc_kpoints(self) -> None:
# Collect from form nested list containing the mappings to the reduce sc kpoints
reduced_sc_map = []
for sc_set in expended_sc:
- map_indx = []
- for _ in sc_set:
- map_indx.append(sc_kpts_map.pop(0))
+ map_indx = [sc_kpts_map.pop(0) for _ in sc_set]
reduced_sc_map.append(map_indx)
self.expansion_results['reduced_sckpts'] = reduced_sckpts
@@ -411,17 +412,14 @@ def _read_weights(self, wavefunction: Union[str, List[str]], gamma: bool, ncl: b
return averaged_weights, weights_per_set
- def load_procar(self, procar: Union[str, List[str]], force: bool = False):
+ def load_procars(self, procars: Union[str, List[str]]):
"""Read in PROCAR for band-based projection"""
- if self.procars and not force:
- pass
-
- if not isinstance(procar, (tuple, list)):
- procar = [procar]
+ if not isinstance(procars, (tuple, list)):
+ procars = [procars] # list of PROCAR files
# Load the procars
# Note that this method should be generalised for non-VASP as well.
- self.transient_quantities['procars'] = [Procar(path) for path in procar]
+ self.transient_quantities['procars'] = Procar(procars)
# Construct mapping from the primitive cell kpoints to those in the PROCAR
self.transient_quantities['procars_kmap'] = self._construct_procar_kmap()
@@ -436,20 +434,17 @@ def _construct_procar_kmap(self) -> list:
K_super, _ = find_K_from_k(kpoint, self.M)
# Search for kpoints in the procar
found = False
- for iprocar, procar in enumerate(self.procars):
- for ikpt, kprocar in enumerate(procar.kvecs[0]):
- if kpoints_equal(K_super, kprocar, time_reversal=self.time_reversal):
- kidx_procar_sets[-1].append([iprocar, ikpt])
- found = True
- break
- if found:
+ for ikpt, kprocar in enumerate(self.procar.kvecs[0]):
+ if kpoints_equal(K_super, kprocar, time_reversal=self.time_reversal):
+ kidx_procar_sets[-1].append(ikpt)
+ found = True
break
if found is False:
raise ValueError(f'Cannot found kpoint {K_super} in PROCAR files')
return kidx_procar_sets
@property
- def procars(self) -> Union[None, Procar]:
+ def procar(self) -> Union[None, Procar]:
"""Loaded PROCARS"""
return self.transient_quantities.get('procars')
@@ -481,9 +476,8 @@ def _get_spectral_weights(self,
# If wave function is given - reload from the data
if wavefunction:
self._read_weights(wavefunction, gamma=gamma, ncl=ncl, gamma_half=gamma_half)
- else:
- if not self.is_calculated:
- raise RuntimeWarning('The spectral weights need to be calculated first - please pass the wave function file(s).')
+ elif not self.is_calculated:
+ raise RuntimeWarning('The spectral weights need to be calculated first - please pass the wave function file(s).')
# Use existing results
if symm_average:
@@ -493,11 +487,11 @@ def _get_spectral_weights(self,
# No averaging - we just return the first item of each set, which is the weight of the original set
sws = [item[:, :1, :, :] for item in self.calculated_quantities['spectral_weights_per_set']]
# Recreate the full weights array
- kweight_sets = [[1.0] for i in range(len(sws))]
+ kweight_sets = [[1.0] for _ in range(len(sws))]
if also_spectral_function:
if atoms_idx is not None:
- # Read in the project weights
+ # Read in the projected weights
band_weight_sets = self.get_band_weight_sets(atoms_idx, orbitals)
else:
band_weight_sets = None
@@ -531,18 +525,18 @@ def get_band_weight_sets(self,
:returns: A list of weights for each band at each expanded kpoint
"""
if procars:
- self.load_procar(procars)
- if self.procars is None:
+ self.load_procars(procars)
+ if self.procar is None:
raise RuntimeError('PROCAR files needs to be loaded')
- projs = [procar.get_projection(atoms_idx, orbitals) for procar in self.procars]
+ proj = self.procar.get_projection(atoms_idx, orbitals)
# Construct band weighting, same structure as o
band_weight_sets = []
for kset in self.transient_quantities['procars_kmap']:
band_weight_sets.append([])
# Search
- for iprocar, kidx in kset:
- band_weight = projs[iprocar][:, kidx]
+ for kidx in kset:
+ band_weight = proj[:, kidx]
band_weight_sets[-1].append(band_weight)
return band_weight_sets
@@ -673,12 +667,11 @@ def clean_latex_string(label: str):
:returns: Cleaned tag string
"""
if label == 'G':
- label = r'$\mathrm{\mathsf{\Gamma}}$'
- elif label.startswith('\\'): ## This is a latex formatted label already
- label = f'$\\mathrm{{\\mathsf{{{label}}}}}$'
- else:
- label = r'$\mathrm{\mathsf{' + label + r'}}$'
- return label
+ return r'$\mathrm{\mathsf{\Gamma}}$'
+ if label.startswith('\\'): ## This is a latex formatted label already
+ return f'$\\mathrm{{\\mathsf{{{label}}}}}$'
+
+ return r'$\mathrm{\mathsf{' + label + r'}}$'
def spectral_function_from_weight_sets(spectral_weight_sets: np.ndarray,
@@ -710,18 +703,17 @@ def spectral_function_from_weight_sets(spectral_weight_sets: np.ndarray,
emax = spectral_weight_sets[0][:, :, :, 0].max() if emax is None else emax
e0 = np.linspace(emin - 5 * sigma, emax + 5 * sigma, nedos)
- for ispin in range(ns):
- for ii in range(nk): # Iterate through kpoint sets (of primitive cell kpoints)
- for jj in range(spectral_weight_sets[ii].shape[1]):
- kweight = kweight_sets[ii][jj]
- E_Km = spectral_weight_sets[ii][ispin, jj, :, 0]
- P_Km = spectral_weight_sets[ii][ispin, jj, :, 1]
- if band_weight_sets is not None:
- P_Km = P_Km * band_weight_sets[ii][jj][ispin, :P_Km.shape[0]]
- # Take weighted average spectral functions
- spectral_function[ispin, :,
- ii] += np.sum(LorentzSmearing(e0[:, np.newaxis], E_Km[np.newaxis, :], sigma=sigma) * P_Km[np.newaxis, :],
- axis=1) * kweight
+ for ispin, ii in itertools.product(range(ns), range(nk)): # Iterate through kpoint sets (of primitive cell kpoints)
+ for jj in range(spectral_weight_sets[ii].shape[1]):
+ kweight = kweight_sets[ii][jj]
+ E_Km = spectral_weight_sets[ii][ispin, jj, :, 0]
+ P_Km = spectral_weight_sets[ii][ispin, jj, :, 1]
+ if band_weight_sets is not None:
+ P_Km = P_Km * band_weight_sets[ii][jj][ispin, :P_Km.shape[0]]
+ # Take weighted average spectral functions
+ spectral_function[ispin, :,
+ ii] += np.sum(LorentzSmearing(e0[:, np.newaxis], E_Km[np.newaxis, :], sigma=sigma) * P_Km[np.newaxis, :],
+ axis=1) * kweight
return e0, spectral_function
@@ -1009,11 +1001,9 @@ def spectral_weight_multiple_source(kpoints: list, unfold_objs: List[Unfold], tr
assert ns == obj.wfc.nspins
# When reading from multiple wave function files (e.g. WAVECAR for VASP),
- # it is possible that each of them may have differnt number of bands.
+ # it is possible that each of them may have different number of bands.
# Ff so, we take only the first N bands, where N is the minimum values of bands
- nb = []
- for source in unfold_objs:
- nb.append(source.bands.shape[2])
+ nb = [source.bands.shape[2] for source in unfold_objs]
nbands = min(nb)
spectral_weights = []
@@ -1101,9 +1091,8 @@ def parse_atoms_idx(atoms_idx: str) -> List[int]:
out.extend(range(int(match.group(1)), int(match.group(2)) + 1))
else:
out.append(int(item))
- # Expect passing 1-based indexing
- out = [x - 1 for x in out]
- return out
+
+ return [x - 1 for x in out] # Expect passing 1-based indexing
def process_projection_options(atoms_idx: str, orbitals: str) -> Tuple[list, list]:
@@ -1153,7 +1142,7 @@ def parse_atoms(atoms_to_project: str, orbitals: str, poscar: str):
"""
atoms_to_project = re.split(', *', atoms_to_project)
ase_atoms = read_poscar_contcar_if_present(poscar)
- try: # check POTCAR if possible, to check the POSCAR-POTCARs match
+ with contextlib.suppress(FileNotFoundError):
atom_types = get_atomtypes('POTCAR')
def _check_order(smaller_list, larger_list):
@@ -1162,8 +1151,6 @@ def _check_order(smaller_list, larger_list):
if not _check_order(atom_types, ase_atoms.get_chemical_symbols()):
warnings.warn('The order of atoms in the POSCAR/CONTCAR and POTCAR do not match!')
- except FileNotFoundError:
- pass
atoms_idx = [
[i for i, atom in enumerate(ase_atoms) if projected_atom_symbol in atom.symbol] for projected_atom_symbol in atoms_to_project
]
@@ -1175,7 +1162,7 @@ def _check_order(smaller_list, larger_list):
# Special case: if only one set is passed, apply it to all atomic specifications
if len(orbitals_subplots) == 1:
- orbitals_subplots = orbitals_subplots * len(atoms_idx)
+ orbitals_subplots *= len(atoms_idx)
orbitals_list = []
for orbital_sublist in orbitals_subplots:
diff --git a/examples/MgO/unfold_project.png b/examples/MgO/unfold_project.png
index 46ed044..39073de 100644
Binary files a/examples/MgO/unfold_project.png and b/examples/MgO/unfold_project.png differ
diff --git a/examples/MgO/unfold_project_rb.png b/examples/MgO/unfold_project_rb.png
index 88287b0..64eb689 100644
Binary files a/examples/MgO/unfold_project_rb.png and b/examples/MgO/unfold_project_rb.png differ
diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj.png
index 5139b80..e4440fa 100644
Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj.png differ
diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj_dos.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj_dos.png
index 6664fd2..90550cd 100644
Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj_dos.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj_dos.png differ
diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png
index eb3a13f..f5bc4e8 100644
Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj_noS.png differ
diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj_orbital_lm_dos.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj_orbital_lm_dos.png
index a802af4..50b4d50 100644
Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj_orbital_lm_dos.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj_orbital_lm_dos.png differ
diff --git a/examples/NaBiS2/NaBiS2_unfold-plot_proj_sps_dos.png b/examples/NaBiS2/NaBiS2_unfold-plot_proj_sps_dos.png
index 7043ca4..5ee426e 100644
Binary files a/examples/NaBiS2/NaBiS2_unfold-plot_proj_sps_dos.png and b/examples/NaBiS2/NaBiS2_unfold-plot_proj_sps_dos.png differ
diff --git a/examples/Si222/unfold_tall.png b/examples/Si222/unfold_tall.png
new file mode 100644
index 0000000..e73153c
Binary files /dev/null and b/examples/Si222/unfold_tall.png differ
diff --git a/tests/test_cli.py b/tests/test_cli.py
index b76f020..16f3038 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -116,36 +116,43 @@ def test_plot_projection(mgo_project_dir):
runner = CliRunner()
output = runner.invoke(easyunfold,
['unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms-idx', '1,2|3-4', '--procar', 'PROCAR'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _plot_projection_check(output)
output = runner.invoke(
easyunfold, ['unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms-idx', '1,2|3-4', '--procar', 'PROCAR', '--combined'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _plot_projection_check(output)
output = runner.invoke(easyunfold, [
'unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms-idx', '1,2|3-4', '--procar', 'PROCAR', '--combined', '--orbitals',
'px,py|pz'
])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _plot_projection_check(output)
# test --atoms option with --poscar specification
output = runner.invoke(easyunfold,
['unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms', 'Mg,O', '--poscar', 'POSCAR.mgo'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _plot_projection_check(output)
# test parsing PROCAR from LORBIT = 14 calculation
output = runner.invoke(easyunfold, [
'unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms', 'Mg,O', '--poscar', 'POSCAR.mgo', '--procar',
'PROCAR_LORBIT_14.mgo'
])
+ _plot_projection_check(output)
+
+ # test options
+ output = runner.invoke(easyunfold, [
+ 'unfold', '--data-file', 'mgo.json', 'plot-projections', '--atoms', 'Mg,O', '--poscar', 'POSCAR.mgo', '--eref', '2',
+ '--no-symm-average', '--cmap', 'PuBu', '--orbitals', 's,p', '--colours', 'r,g', '--orbitals', 's,p', '--colours', 'r,g',
+ '--colourspace', 'luvlch', '--intensity', '0.5', '--emin', '-2', '--emax', '5', '--dpi', '500', '--vscale', '2.0', '--cmap', 'PuBu',
+ '--npoints', '200', '--sigma', '0.15', '--title', 'Test', '--no-combined', '--height', '2', '--width', '3', '-o', 'test.png'
+ ])
+ assert output.exit_code == 0
+ assert Path('test.png').is_file() # different file name this time
+ Path('test.png').unlink()
+
+
+def _plot_projection_check(output):
assert output.exit_code == 0
assert Path('unfold.png').is_file()
Path('unfold.png').unlink()
@@ -199,17 +206,13 @@ def test_dos_atom_orbital_plots(nabis2_project_dir):
'--dos-elements',
'Bi.s.p',
])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, [
'unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--orbitals', 's|px,py,pz|p', '--vscale', '0.5', '--combined', '--dos',
'vasprun.xml.gz', '--zero-line', '--dos-label', 'DOS', '--gaussian', '0.1', '--no-total', '--scale', '2'
])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(
easyunfold,
@@ -217,84 +220,64 @@ def test_dos_atom_orbital_plots(nabis2_project_dir):
'unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--orbitals', 's|px,py,pz|p', '--intensity', '2', '--combined', '--dos',
'vasprun.xml.gz', '--zero-line', '--dos-label', 'DOS', '--gaussian', '0.1', '--no-total', '--scale', '2'
])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, [
'unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--vscale', '0.5', '--combined', '--dos', 'vasprun.xml.gz', '--zero-line',
'--dos-label', 'DOS', '--gaussian', '0.1', '--no-total', '--scale', '2'
])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, [
'unfold', 'plot-projections', '--atoms', 'Na,Bi', '--vscale', '0.5', '--combined', '--dos', 'vasprun.xml.gz', '--zero-line',
'--dos-label', 'DOS', '--gaussian', '0.1', '--no-total', '--scale', '2'
])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, ['unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--dos', 'vasprun.xml.gz'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, ['unfold', 'plot-projections', '--atoms', 'Na,Bi,S', '--combined', '--dos', 'vasprun.xml.gz'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, ['unfold', 'plot', '--atoms-idx', '1-20|21-40', '--orbitals', 's|p', '--dos', 'vasprun.xml.gz'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, ['unfold', 'plot', '--atoms', 'Na,Bi', '--orbitals', 's|p', '--dos', 'vasprun.xml.gz'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold,
['unfold', 'plot-projections', '--atoms', 'Na,Bi', '--combined', '--orbitals', 's', '--dos', 'vasprun.xml.gz'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, [
'unfold', 'plot-projections', '--atoms', 'Na,Bi', '--combined', '--orbitals', 's', '--dos', 'vasprun.xml.gz', '--dos-elements',
'Bi.s'
])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, [
'unfold', 'plot-projections', '--atoms-idx', '1-20,21,22,33', '--combined', '--orbitals', 's', '--dos', 'vasprun.xml.gz',
'--dos-elements', 'Bi.s'
])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(
easyunfold,
['unfold', 'plot', '--atoms-idx', '1-20,21,22,33', '--orbitals', 's', '--dos', 'vasprun.xml.gz', '--dos-elements', 'Bi.s'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, ['unfold', 'plot', '--dos', 'vasprun.xml.gz'])
- assert output.exit_code == 0
- assert Path('unfold.png').is_file()
- Path('unfold.png').unlink()
+ _check_dos_atom_orbital_plots(output)
output = runner.invoke(easyunfold, [
'unfold', 'plot-projections', '--atoms', 'Na,Bi', '--orbitals', 's', '--combined', '--dos', 'vasprun.xml.gz', '--dos-elements',
'Bi.s'
])
+ _check_dos_atom_orbital_plots(output)
+
+
+def _check_dos_atom_orbital_plots(output):
assert output.exit_code == 0
assert Path('unfold.png').is_file()
Path('unfold.png').unlink()
diff --git a/tests/test_plotting.py b/tests/test_plotting.py
index 7b3ef4e..bcb910c 100644
--- a/tests/test_plotting.py
+++ b/tests/test_plotting.py
@@ -41,8 +41,7 @@ def silicon_unfolded(si_project_dir) -> UnfoldKSet:
@pytest.fixture(scope='module')
def unfold_obj() -> UnfoldKSet:
"""Return an unfolding object"""
- obj = loadfn(Path(__file__).parent / 'test_data/mgo.json')
- return obj
+ return loadfn(Path(__file__).parent / 'test_data/mgo.json')
def test_plotting(unfold_obj: UnfoldKSet):
@@ -77,6 +76,35 @@ def test_plotting_projection(unfold_obj: UnfoldKSet):
fig = plotter.plot_projected(procar_path, atoms_idx='0,1|2,3', npoints=200, use_subplot=True)
assert isinstance(fig, Figure)
+ # test options
+ poscar_path = Path(__file__).parent / 'test_data/POSCAR.mgo'
+ fig = plotter.plot_projected(procar_path,
+ eref=2,
+ gamma=False,
+ ncl=False,
+ npoints=200,
+ sigma=0.15,
+ symm_average=False,
+ figsize=(2, 3),
+ ylim=(-2, 5),
+ dpi=500,
+ vscale=2.0,
+ contour_plot=True,
+ alpha=0.4,
+ save='test.png',
+ ax=None,
+ cmap='PuBu',
+ show=True,
+ title='Test',
+ atoms='Mg,O',
+ poscar=poscar_path,
+ orbitals='s,p',
+ use_subplot=True,
+ colours='r,g',
+ colorspace='luvlch',
+ intensity=0.5)
+ assert isinstance(fig, Figure)
+
def test_color_interpolation():
"""Test interpolating colours"""
diff --git a/tests/test_procar.py b/tests/test_procar.py
index ed9a094..94fe6a1 100644
--- a/tests/test_procar.py
+++ b/tests/test_procar.py
@@ -4,7 +4,6 @@
from pathlib import Path
import numpy as np
-import pytest
from easyunfold.procar import Procar
@@ -17,12 +16,12 @@ def test_procar():
procar = Procar(datapath / 'PROCAR')
assert procar.nion == 2
- assert procar.eigenvalues.shape == (1, 48, 20)
- assert procar.kvecs.shape == (1, 48, 3)
- assert procar.kweights.shape == (1, 48)
+ assert procar.eigenvalues.shape == (1, 47, 20)
+ assert procar.kvecs.shape == (1, 47, 3)
+ assert procar.kweights.shape == (1, 47)
assert np.all(procar.kvecs[0][0] == 0.)
- assert procar.occs.shape == (1, 48, 20)
- assert procar.get_projection([0], 'all').sum() == 310.418
- assert procar.get_projection([0], ['s', 'px']).sum() == 59.32300000000001
- assert procar.get_projection([0], ['s']).sum() + procar.get_projection([0], 'px').sum() == 59.32300000000001
+ assert procar.occs.shape == (1, 47, 20)
+ assert procar.get_projection([0], 'all').sum() == 618.2850603903657
+ assert procar.get_projection([0], ['s', 'px']).sum() == 124.10684519359549
+ assert procar.get_projection([0], ['s']).sum() + procar.get_projection([0], 'px').sum() == 124.10684519359546
assert procar.proj_names == ['s', 'py', 'pz', 'px', 'dxy', 'dyz', 'dz2', 'dxz', 'x2-y2']
diff --git a/tests/test_unfold.py b/tests/test_unfold.py
index 37dda49..eab9af2 100644
--- a/tests/test_unfold.py
+++ b/tests/test_unfold.py
@@ -232,7 +232,7 @@ def test_unfold_projection(si_project_dir, tag, nspin, ncl, nbands_expected):
assert unfolder.is_calculated
- unfolder.load_procar(si_project_dir / f'{folder_name}/PROCAR')
+ unfolder.load_procars(si_project_dir / f'{folder_name}/PROCAR')
assert 'procars' in unfolder.transient_quantities
assert 'procars_kmap' in unfolder.transient_quantities