# coding: utf-8
"""Scissors operator."""
import os
import numpy as np
import pickle
from collections import OrderedDict
from monty.collections import AttrDict
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt
__all__ = [
"Scissors",
"ScissorsBuilder",
]
class ScissorsError(Exception):
"""Base class for the exceptions raised by :class:`Scissors`"""
[docs]
class Scissors(object):
"""
This object represents an energy-dependent scissors operator.
The operator is defined by a list of domains (energy intervals)
and a list of functions defined in these domains.
The domains should fulfill the constraints documented in the main constructor.
.. note::
eV units are assumed.
The standard way to create this object is via the methods provided by the factory class :class:`ScissorBuilder`.
Once the instance has been created, one can correct the band structure by calling the `apply` method.
"""
Error = ScissorsError
def __init__(self, func_list, domains, residues, bounds=None):
"""
Args:
func_list: List of callable objects. Each function takes an eigenvalue and returns
the corrected value.
domains: Domains of each function. List of tuples [(emin1, emax1), (emin2, emax2), ...]
bounds: Specify how to handle energies that do not fall inside one of the domains.
At present, only constant boundaries are implemented.
residues: A list of the residues of the fitting per domain
.. note::
#. Domains should not overlap, cover e0mesh, and given in increasing order.
#. Holes are permitted but the interpolation will raise an exception if the
eigenvalue falls inside the hole.
#. Errors contains a list of the fitting errors per domain
"""
# TODO Add consistency check.
self.func_list = func_list
self.domains = np.atleast_2d(domains)
self.residues = residues
assert len(self.func_list) == len(self.domains)
# Treat the out-of-boundary conditions. func_low and func_high are used to handle energies
# that are below or above the min/max energy given in domains.
blow, bhigh = "c", "c"
if bounds is not None:
blow, bhigh = bounds[0][0], bounds[0][1]
if blow.lower() == "c":
try:
self.func_low = lambda x: float(bounds[0][1])
except Exception:
x_low = self.domains[0,0]
fx_low = func_list[0](x_low)
self.func_low = lambda x: fx_low
else:
raise NotImplementedError("Only constant boundaries are implemented")
if bhigh.lower() == "c":
try:
self.func_high = lambda x: float(bounds[1][1])
except Exception:
x_high = self.domains[1, -1]
fx_high = func_list[-1](x_high)
self.func_high = lambda x: fx_high
else:
raise NotImplementedError("Only constant boundaries are implemented")
# This counter stores the number of points that are out of bounds.
self.out_bounds = np.zeros(3, dtype=int)
[docs]
def apply(self, eig):
"""Correct the eigenvalue eig (eV units)."""
# Get the list of domains.
domains = self.domains
if eig < domains[0,0]:
# Eig is below the first point of the first domain.
# Call func_low
print("left ", eig, " < ", domains[0,0])
self.out_bounds[0] += 1
return self.func_low(eig)
if eig > domains[-1,1]:
# Eig is above the last point of the last domain.
# Call func_high
print("right ", eig, " > ", domains[-1,1])
self.out_bounds[1] += 1
return self.func_high(eig)
# eig is inside the domains: find the domain
# and call the corresponding function.
for idx, dms in enumerate(domains):
if dms[1] >= eig >= dms[0]:
return self.func_list[idx](eig)
self.out_bounds[2] += 1
raise self.Error("Cannot find location of eigenvalue %s in domains:\n%s" % (eig, domains))
[docs]
class ScissorsBuilder(object):
"""
This object facilitates the creation of :class:`Scissors` instances.
Usage:
builder = ScissorsBuilder.from_file("out_SIGRES.nc")
# To plot the QP results as function of the KS energy:
builder.plot_qpe_vs_e0()
# To select the domains esplicitly (optional but highly recommended)
builder.build(domains_spin=[[-10, 6.02], [6.1, 20]])
# To compare the fitted results with the ab-initio data:
builder.plot_fit()
# To plot the corrected bands:
builder.plot_qpbands(abidata.ref_file("si_nscf_WFK.nc"))
"""
[docs]
@classmethod
def from_file(cls, filepath):
"""
Generate object from (SIGRES.nc) file. Main entry point for client code.
"""
from abipy.abilab import abiopen
with abiopen(filepath) as ncfile:
return cls(qps_spin=ncfile.qplist_spin, sigres_ebands=ncfile.ebands)
[docs]
@classmethod
def pickle_load(cls, filepath):
"""Load the object from a pickle file."""
with open(filepath, "rb") as fh:
d = AttrDict(pickle.load(fh))
# Costruct the object and compute the scissors.
new = cls(d.qps_spin, d.sigres_ebands)
new.build(d.domains_spin, d.bounds_spin)
return new
[docs]
def pickle_dump(self, filepath, protocol=-1):
"""Save the object in Pickle format"""
assert all(s1 == s2 for s1, s2 in zip(self.domains_spin.keys(), self.bounds_spin.keys()))
assert all(s1 == s2 for s1, s2 in zip(self.domains_spin.keys(), range(self.nsppol)))
bounds_spin = None
if any(v is not None for v in self.bounds_spin.values()):
bounds_spin = [a.tolist() for a in self.bounds_spin.values()]
# This trick is needed because we cannot pickle bound methods of the scissors operator.
d = dict(qps_spin=self._qps_spin,
sigres_ebands=self.sigres_ebands,
domains_spin=[a for a in self.domains_spin.values()],
bounds_spin=bounds_spin)
with open(filepath, "wb") as fh:
pickle.dump(d, fh, protocol=protocol)
def __init__(self, qps_spin, sigres_ebands):
"""
Args:
qps_spin: List of :class:`QPlist`, for each spin.
sigres_ebands: |ElectronBands| obtained from the SIGRES file
"""
# Sort quasiparticle data by e0.
self._qps_spin = tuple([qps.sort_by_e0() for qps in qps_spin])
# Compute the boundaries of the E0 mesh.
e0min, e0max = np.inf, -np.inf
for qps in self._qps_spin:
e0mesh = qps.get_e0mesh()
e0min = min(e0min, e0mesh[0])
e0max = max(e0max, e0mesh[-1])
self._e0min, self._e0max = e0min, e0max
# The KS bands stored in the sigres file (used to compute automatically the boundaries)
self.sigres_ebands = sigres_ebands
# Start with default values for domains.
self.build()
@property
def nsppol(self):
"""Number of spins."""
return len(self._qps_spin)
@property
def e0min(self):
"""Minimum KS energy in eV (takes into account spin)"""
return self._e0min
@property
def e0max(self):
"""Maximum KS energy in eV (takes into account spin)"""
return self._e0max
@property
def scissors_spin(self):
"""Returns a tuple of :class:`Scissors` indexed by the spin value."""
try:
return self._scissors_spin
except AttributeError:
raise AttributeError("Call self.build to create the scissors operator")
[docs]
def build(self, domains_spin=None, bounds_spin=None, k=3):
"""
Build the scissors operator.
Args:
domains_spin: list of domains in eV for each spin. If domains is None,
domains are computed automatically from the sigres bands
(two domains separated by the middle of the gap).
bounds_spin: Options specifying the boundary conditions (not used at present)
k: Parameter defining the order of the fit.
"""
nsppol = self.nsppol
# The parameters defining the scissors operator
self.domains_spin = OrderedDict()
self.bounds_spin = OrderedDict()
if domains_spin is None:
# Use sigres_ebands and the position of the homo, lumo to compute the domains.
domains_spin = nsppol * [None]
e_bands = self.sigres_ebands
for spin in e_bands.spins:
gap_mid = (e_bands.homos[spin].eig + e_bands.lumos[spin].eig) / 2
domains_spin[spin] = [[self.e0min - 0.2 * abs(self.e0min), gap_mid],
[gap_mid, self.e0max + 0.2 * abs(self.e0max)]]
#print("domains", domains_spin[spin])
else:
if nsppol == 1:
domains_spin = np.reshape(domains_spin, (1, -1, 2))
elif nsppol == 2:
assert len(domains_spin) == nsppol
if bounds_spin is not None: assert len(bounds_spin) == nsppol
else:
raise ValueError("Wrong number of spins %d" % nsppol)
#if len(domains_spin) != nsppol:
# raise ValueError("len(domains_spin) == %s != nsppol %s" % (len(domains_spin), nsppol))
# Construct the scissors operator for each spin.
scissors_spin = nsppol * [None]
for spin, qps in enumerate(self._qps_spin):
bounds = None if not bounds_spin else bounds_spin[spin]
scissors_spin[spin] = qps.build_scissors(domains_spin[spin], bounds=bounds, k=k, plot=False)
# Save input so that we can reconstruct Scissors.
self.domains_spin[spin] = domains_spin[spin]
self.bounds_spin[spin] = bounds
self._scissors_spin = scissors_spin
return domains_spin
[docs]
@add_fig_kwargs
def plot_qpe_vs_e0(self, with_fields="all", **kwargs):
"""Plot the quasiparticle corrections as function of the KS energy."""
ax_list = None
for spin, qps in enumerate(self._qps_spin):
kwargs["title"] = "spin %s" % spin
fig = qps.plot_qps_vs_e0(with_fields=with_fields, ax_list=ax_list, show=False, **kwargs)
ax_list = fig.axes
return fig
[docs]
@add_fig_kwargs
def plot_fit(self, ax=None, fontsize=8, **kwargs):
"""
Compare fit functions with input quasi-particle corrections.
Args:
ax: |matplotlib-Axes| or None if a new figure should be created.
fontsize: fontsize for titles and legend.
Return: |matplotlib-Figure|
"""
ax, fig, plt = get_ax_fig_plt(ax=ax)
for spin in range(self.nsppol):
qps = self._qps_spin[spin]
e0mesh, qpcorrs = qps.get_e0mesh(), qps.get_qpeme0().real
ax.scatter(e0mesh, qpcorrs, label="Input QP corrections, spin %s" % spin)
scissors = self._scissors_spin[spin]
intp_qpc = [scissors.apply(e0) for e0 in e0mesh]
ax.plot(e0mesh, intp_qpc, label="Scissors operator, spin %s" % spin)
ax.grid(True)
ax.set_xlabel('KS energy (eV)')
ax.set_ylabel('QP-KS (eV)')
ax.legend(loc="best", fontsize=fontsize, shadow=True)
return fig
[docs]
def plot_qpbands(self, bands_filepath, bands_label=None, dos_filepath=None, dos_args=None, **kwargs):
"""
Correct the energies found in the netcdf file bands_filepath and plot the band energies (both the initial
and the corrected ones) with matplotlib. The plot contains the KS and the QP DOS if dos_filepath is not None.
Args:
bands_filepath: Path to the netcdf file containing the initial KS energies to be corrected.
bands_label String used to label the KS bands in the plot.
dos_filepath: Optional path to a netcdf file with the initial KS energies on a homogeneous k-mesh
(used to compute the KS and the QP dos)
dos_args: Dictionary with the arguments passed to get_dos to compute the DOS
Used if dos_filepath is not None.
kwargs: Options passed to the plotter.
Return: |matplotlib-Figure|
"""
from abipy.abilab import abiopen, ElectronBandsPlotter
# Read the KS band energies from bands_filepath and apply the scissors operator.
with abiopen(bands_filepath) as ncfile:
ks_bands = ncfile.ebands
#structure = ncfile.structure
qp_bands = ks_bands.apply_scissors(self._scissors_spin)
# Read the band energies computed on the Monkhorst-Pack (MP) mesh and compute the DOS.
ks_dos, qp_dos = None, None
if dos_filepath is not None:
with abiopen(dos_filepath) as ncfile:
ks_mpbands = ncfile.ebands
dos_args = {} if not dos_args else dos_args
ks_dos = ks_mpbands.get_edos(**dos_args)
# Compute the DOS with the modified QPState energies.
qp_mpbands = ks_mpbands.apply_scissors(self._scissors_spin)
qp_dos = qp_mpbands.get_edos(**dos_args)
# Plot the LDA and the QPState band structure with matplotlib.
plotter = ElectronBandsPlotter()
bands_label = bands_label if bands_label is not None else os.path.basename(bands_filepath)
plotter.add_ebands(bands_label, ks_bands, edos=ks_dos)
plotter.add_ebands(bands_label + " + scissors", qp_bands, edos=qp_dos)
#qp_marker: if int > 0, markers for the ab-initio QP energies are displayed. e.g qp_marker=50
#qp_marker = 50
#if qp_marker is not None:
# # Compute correspondence between the k-points in qp_list and the k-path in qp_bands.
# # TODO
# # WARNING: strictly speaking one should check if qp_kpoint is in the star of k-point.
# # but compute_star is too slow if written in pure python.
# x, y, s = [], [], []
# for ik_path, kpoint in enumerate(qp_bands.kpoints):
# #kstar = kpoint.compute_star(structure.fm_symmops)
# for spin in range(self.nsppol):
# for ik_qp, qp in enumerate(self._qps_spin[spin]):
# #if qp.kpoint in kstar:
# if qp.kpoint == kpoint:
# x.append(ik_path)
# y.append(np.real(qp.qpe))
# s.append(qp_marker)
# plotter.set_marker("ab-initio QP", [x, y, s])
return plotter.combiplot(**kwargs)