# coding: utf-8
"""This module contains the class describing a planewave wavefunction."""
from __future__ import annotations
#import copy
import numpy as np
from monty.termcolor import cprint
from abipy.core import Mesh3D
from abipy.core.structure import Structure
from abipy.core.kpoints import Kpoint
from abipy.iotools import Visualizer
from abipy.iotools.xsf import xsf_write_structure, xsf_write_data
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt
from abipy.tools.typing import Figure
# Names exported by ``from <this module> import *``.
__all__ = [
    "PWWaveFunction",
]
def latex_label_ispinor(ispinor: int, nspinor: int) -> str:
    r"""
    Return the LaTeX label used in plot legends for a spinor component.

    Args:
        ispinor: Index of the spinor component (0 or 1). Ignored if ``nspinor == 1``.
        nspinor: Number of spinor components (1 or 2).

    Returns:
        An empty string if ``nspinor == 1``, otherwise ``$\sigma=\uparrow$`` for
        ``ispinor == 0`` and ``$\sigma=\downarrow$`` for ``ispinor == 1``.

    Raises:
        ValueError: if ``nspinor`` is neither 1 nor 2.
        KeyError: if ``nspinor == 2`` and ``ispinor`` is not 0 or 1.
    """
    if nspinor == 1:
        return ""

    if nspinor == 2:
        # Raw strings keep ``\u`` from being parsed as a unicode escape,
        # so the labels can be written directly.
        return {0: r"$\sigma=\uparrow$", 1: r"$\sigma=\downarrow$"}[ispinor]

    raise ValueError("Wrong value for nspinor: %s" % nspinor)
class WaveFunction:
    """
    Abstract class defining base and abstract methods for wavefunction objects.

    The methods below read the attributes ``nspinor``, ``spin``, ``band``, ``_gsphere``
    and ``_ug`` that must be set by the subclass. The FFT mesh must be set with
    ``set_mesh`` before accessing ``ur`` or calling the FFT routines without a mesh.
    """

    def __eq__(self, other):
        """
        Two wavefunctions are equal if they have equal G-spheres and u(G) arrays
        with the same shape and values within the ``np.allclose`` tolerance.
        """
        if other is None:
            return False
        if not isinstance(other, WaveFunction):
            # Let Python try the reflected operation instead of failing with AttributeError.
            return NotImplemented
        if self.gsphere != other.gsphere:
            return False
        # np.allclose broadcasts: compare shapes first so that e.g. (1, npw) and
        # (2, npw) arrays are never considered equal.
        if self.ug.shape != other.ug.shape:
            return False
        return np.allclose(self.ug, other.ug)

    def __ne__(self, other):
        return not (self == other)

    def __iter__(self):
        """Yields G, ug[0:nspinor, G]"""
        return zip(self.gvecs, self.ug.T)

    def __getitem__(self, slice):
        """Return the tuple (gvecs[slice], ug[:, slice])."""
        return self.gvecs[slice], self.ug[:, slice]

    def __repr__(self):
        return str(self)

    def __str__(self):
        return self.to_string()

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of ug i.e. (nspinor, npw)"""
        return self.nspinor, self.npw

    @property
    def gsphere(self):
        """:class:`GSphere` object"""
        return self._gsphere

    @property
    def kpoint(self):
        """|Kpoint| object"""
        return self.gsphere.kpoint

    @property
    def gvecs(self):
        """G-vectors in reduced coordinates."""
        return self.gsphere.gvecs

    @property
    def npw(self) -> int:
        """Number of G-vectors."""
        return len(self.gsphere)

    @property
    def ecut(self) -> float:
        """Cutoff energy in Hartree."""
        return self.gsphere.ecut

    @property
    def isnc(self) -> bool:
        """True if norm-conserving wavefunction."""
        return isinstance(self, PWWaveFunction)

    @property
    def ispaw(self) -> bool:
        """True if PAW wavefunction."""
        return isinstance(self, PAW_WaveFunction)

    @property
    def ug(self):
        """Periodic part of the wavefunctions in G-space."""
        return self._ug

    def set_ug(self, ug):
        """
        Set the value of the u(nspinor, G) array.

        Raises:
            ValueError: if the shape of ``ug`` differs from ``self.shape``.
        """
        if ug.shape != self.shape:
            raise ValueError("Input ug shape %s differs from the one stored in self %s" % (ug.shape, self.shape))

        self._ug = ug
        # The cached u(r) was computed from the old u(G).
        self.delete_ur()

    @property
    def ur(self):
        """
        Periodic part of the wavefunction in real space.
        Computed with ``fft_ug`` on the first access and then cached.
        """
        try:
            return self._ur
        except AttributeError:
            pass

        # Computed outside the except clause so that an error raised by the FFT
        # is not chained to the (irrelevant) AttributeError of the cache miss.
        self._ur = self.fft_ug()
        return self._ur

    def delete_ur(self) -> None:
        """Delete _u(r) (if it has been computed)."""
        try:
            del self._ur
        except AttributeError:
            pass

    #@property
    #def ur_xyz(self):
    #    """
    #    Returns a copy with ur[nspinor, nx, ny, nz]. Mainly used for post-processing.
    #    """
    #    return self.mesh.reshape(self.ur).copy()

    #@property
    #def ur2_xyz(self):
    #    """
    #    Returns ur2[nx, ny, nz]. Mainly used for post-processing.
    #    """
    #    return self.mesh.reshape(self.ur2)

    @property
    def mesh(self):
        """The mesh used for the FFT."""
        return self._mesh

    def set_mesh(self, mesh) -> None:
        """Change the FFT mesh. `u(r)` will be computed on this box."""
        assert isinstance(mesh, Mesh3D)
        self._mesh = mesh
        # The cached u(r) refers to the previous mesh.
        self.delete_ur()

    #def deepcopy(self):
    #    """Deep copy of self."""
    #    return copy.deepcopy(self)

    def get_ug_mesh(self, mesh=None):
        """
        Returns u(G) on the FFT mesh

        Args:
            mesh: |Mesh3d| object. If mesh is None, the internal mesh is used.
        """
        mesh = self.mesh if mesh is None else mesh
        return self.gsphere.tofftmesh(mesh, self.ug)

    def get_ur_mesh(self, mesh, copy=True):
        """
        Returns u(r) on the FFT mesh. This routine is mainly used if we need u(r) on a
        different mesh. Data on the initial mesh is already available in `self.ur`

        Args:
            mesh: |Mesh3d| object.
            copy: By default, we return a copy of ur if mesh == self.mesh.
        """
        if mesh == self.mesh:
            return self.ur.copy() if copy else self.ur

        return self.fft_ug(mesh=mesh)

    def fft_ug(self, mesh=None):
        """
        Performs the FFT transform of :math:`u(g)` on mesh.

        Args:
            mesh: |Mesh3d| object. If mesh is None, self.mesh is used.

        Returns:
            :math:`u(r)` on the real space FFT box.
        """
        mesh = self.mesh if mesh is None else mesh
        ug_mesh = self.get_ug_mesh(mesh=mesh)
        return mesh.fft_g2r(ug_mesh, fg_ishifted=False)

    def to_string(self, verbose=0) -> str:
        """String representation."""
        lines = []
        app = lines.append

        app("%s: nspinor: %d, spin: %d, band: %d " % (
            self.__class__.__name__, self.nspinor, self.spin, self.band))

        # hasattr is False if the underlying private attribute has not been set yet.
        if hasattr(self, "gsphere"):
            app(self.gsphere.to_string(verbose=verbose))
        if hasattr(self, "mesh"):
            app(self.mesh.to_string(verbose=verbose))

        return "\n".join(lines)

    # TODO: get_ur2?
    @property
    def ur2(self):
        """
        Array with :math:`||u(r)||^2` in real space.
        Same shape as ``self.ur``: no sum over the spinor components is performed.
        """
        ur2 = (self.ur.conj() * self.ur).real.copy()
        #if self.nspinor == 2: ur2 = ur2.sum(axis=3)
        return ur2
class PWWaveFunction(WaveFunction):
    """
    This object describes a wavefunction expressed in a plane-wave basis set.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: PWWaveFunction
    """

    def __init__(self, structure: Structure, nspinor, spin, band, gsphere, ug):
        """
        Creation method.

        Args:
            structure: |Structure| object.
            nspinor: number of spinorial components.
            spin: spin index (only used if collinear-magnetism).
            band: band index (>=0)
            gsphere |GSphere| instance.
            ug: 2D array containing u[nspinor,G] for G in gsphere. A copy is stored.
        """
        self.structure = structure
        self.nspinor, self.spin, self.band = nspinor, spin, band

        # Convert first so that array-like input is accepted; np.array makes a copy.
        ug = np.array(ug)

        # Sanity check.
        assert ug.ndim == 2
        assert ug.shape[0] == nspinor
        assert ug.shape[1] == gsphere.npw

        self._gsphere = gsphere
        self._ug = ug

    #def kinetic_energy(self):
    #    """Computes the matrix element of the kinetic operator in reciprocal space."""
    #    tug = -0.5 * self.gsphere.kpg2 * self.ug
    #    return np.vdot(self.ug, tug).sum()

    def norm2(self, space="g") -> float:
        r"""
        Return :math:`||\psi||^2` computed in G- or r-space.

        Args:
            space: Integration space. Possible values ["g", "gsphere", "r"] (case-insensitive).
                "g" and "gsphere" both sum :math:`|u(G)|^2` over the G-vectors stored in ``ug``.
                "r" sums :math:`|u(r)|^2` over the FFT mesh and divides by ``mesh.size``.

        Returns:
            Real number in all cases.

        Raises:
            ValueError: if ``space`` is not one of the supported values.
        """
        space = space.lower()

        if space in ("g", "gsphere"):
            return np.real(np.vdot(self.ug, self.ug))

        if space == "r":
            # <u|u> is real: drop the (zero) imaginary part so that the return type
            # is the same as in the G-space branches.
            return np.real(np.vdot(self.ur, self.ur)) / self.mesh.size

        raise ValueError("Wrong space: %s" % str(space))

    def braket(self, other, space="g") -> complex:
        """
        Returns the scalar product <u1|u2> of the periodic part of two wavefunctions
        computed in G-space or r-space, depending on the value of space.
        Note that selection rules introduced by k-points are not taken into account.

        Args:
            other: Other wave (right-hand side)
            space: Integration space. Possible values ["g", "gsphere", "r"]
                if "g" or "r" the scalar product is computed in G- or R-space on the FFT box.
                if "gsphere" the integration is done on the G-sphere. Note that
                this option assumes that self and other have the same list of G-vectors.

        Raises:
            ValueError: if ``space`` is not one of the supported values.
        """
        space = space.lower()

        if space == "g":
            # Both waves are put on the FFT mesh of self.
            ug1_mesh = self.gsphere.tofftmesh(self.mesh, self.ug)
            ug2_mesh = other.gsphere.tofftmesh(self.mesh, other.ug) if other is not self else ug1_mesh
            return np.vdot(ug1_mesh, ug2_mesh)

        if space == "gsphere":
            return np.vdot(self.ug, other.ug)

        if space == "r":
            return np.vdot(self.ur, other.ur) / self.mesh.size

        raise ValueError("Wrong space: %s" % str(space))

    def get_interpolator(self):
        """
        Return an interpolator object that interpolates periodic functions in real space.
        """
        from abipy.tools.numtools import BlochRegularGridInterpolator
        return BlochRegularGridInterpolator(self.structure, self.ur)

    #def pww_translation(self, gvector, rprimd):
    #    """Returns the pwwave of the kpoint translated by one gvector."""
    #    gsph = self.gsphere.copy()
    #    wpww = PWWaveFunction(self.structure, self.nspinor, self.spin, self.band, gsph, self.ug.copy())
    #    wpww.mesh = self.mesh
    #    wpww.pww_translation_inplace(gvector, rprimd)
    #    return wpww

    #def pww_translation_inplace(self, gvector, rprimd):
    #    """Translates the pwwave from 1 kpoint by one gvector."""
    #    self.gsphere.kpoint = self.gsphere.kpoint + gvector
    #    self.gsphere.gvecs = self.gsphere.gvecs + gvector
    #    fft_ndivs = (self.mesh.shape[0] + 2, self.mesh.shape[1] + 2, self.mesh.shape[2] + 2)
    #    newmesh = Mesh3D(fft_ndivs, rprimd, pbc=True)
    #    self.mesh = newmesh

    #def pwwtows_inplace(self):
    #    """Wrap the kpoint to the interval ]-1/2,1/2] and update pwwave accordingly."""
    #    kpoint = Kpoint(self.gsphere.kpoint, self.gsphere.gprimd)
    #    wkpt = kpoint.wrap_to_ws()

    #    if np.allclose(wkpt.rcoord, kpoint.rcoord):
    #        return

    #    #@David FIXME this is wrong
    #    gvector = np.array(kpoint.rcoord - wkpt.rcoord, np.int)
    #    self.gsphere.gvecs = self.gsphere.gvecs + gvector
    #    self.gsphere.kpoint = wkpt.rcoord

    #def pwwtows(self):
    #    """Return the pwwave of the kpoint wrapped to the interval ]-1/2,1/2]."""
    #    gsph = self.gsphere.copy()
    #    wpww = PWWaveFunction(self.structure, self.nspinor, self.spin, self.band, gsph, self.ug.copy())
    #    wpww.pwwtows_inplace()
    #    wpww.mesh = self.mesh
    #    return wpww

    #def rotate(self, symmop, mesh=None):
    #    """
    #    Apply the symmetry operation `symmop` to the periodic part.

    #    Args:
    #        symmop: :class:`Symmetry` operation
    #        mesh: mesh for the FFT, if None the mesh of self is used.

    #    Returns:
    #        New wavefunction object.
    #    """
    #    if self.nspinor != 1:
    #        raise ValueError("Spinor rotation not available yet.")

    #    rot_gsphere = self.gsphere.rotate(symmop)
    #    #rot_istwfk = istwfk(rot_kpt)

    #    if not np.allclose(symmop.tau, np.zeros(3)):
    #        rot_ug = np.empty_like(self.ug)
    #        rot_gvecs = rot_gsphere.gvecs
    #        rot_kpt = rot_gsphere.kpoint.frac_coords

    #        ug = self._ug
    #        #phase = np.exp(-2j * np.pi * (np.dot(rot_gvecs + rot_kpt, symmop.tau)))
    #        for ig in range(self.npw):
    #            rot_ug[:, ig] = ug[:, ig] * np.exp(-2j * np.pi * (np.dot(rot_gvecs[ig] + rot_kpt, symmop.tau)))
    #    else:
    #        rot_ug = self.ug.copy()

    #    # Invert the collinear spin if we have an AFM operation.
    #    rot_spin = self.spin
    #    if self.nspinor == 1:
    #        rot_spin = self.spin if symmop.is_fm else (self.spin + 1) % 2

    #    # Build new wave and set the mesh.
    #    new = self.__class__(self.nspinor, rot_spin, self.band, rot_gsphere, rot_ug)
    #    new.set_mesh(mesh if mesh is not None else self.mesh)
    #    return new

    @add_fig_kwargs
    def plot_line(self, point1, point2, num=200, with_krphase=False, cartesian=False,
                  ax=None, fontsize=12, **kwargs) -> Figure:
        """
        Plot (interpolated) wavefunction in real space along a line defined by ``point1`` and ``point2``.

        Args:
            point1: First point of the line. Accepts 3d vector or integer.
                The vector is in reduced coordinates unless ``cartesian`` is True.
                If integer, the first point of the line is given by the i-th site of the structure
                e.g. ``point1=0, point2=1`` gives the line passing through the first two atoms.
            point2: Second point of the line. Same API as ``point1``.
            num: Number of points sampled along the line.
            with_krphase: True to include the :math:`e^{ikr}` phase-factor.
            cartesian: By default, ``point1`` and ``point2`` are interpreted as points in fractional
                coordinates (if not integers). Use True to pass points in cartesian coordinates.
            ax: |matplotlib-Axes| or None if a new figure should be created.
            fontsize: legend and title fontsize.

        Return: |matplotlib-Figure|
        """
        # Interpolate along line.
        interpolator = self.get_interpolator()
        r = interpolator.eval_line(point1, point2, num=num, cartesian=cartesian,
                                   kpoint=None if not with_krphase else self.kpoint)

        # Plot data.
        ax, fig, plt = get_ax_fig_plt(ax=ax)
        which = r"\psi(r)" if with_krphase else "u(r)"

        for ispinor in range(self.nspinor):
            spinor_label = latex_label_ispinor(ispinor, self.nspinor)
            ur = r.values[ispinor]
            ax.plot(r.dist, ur.real, label=r"$\Re %s$ %s" % (which, spinor_label))
            ax.plot(r.dist, ur.imag, label=r"$\Im %s$ %s" % (which, spinor_label))
            ax.plot(r.dist, ur.real**2 + ur.imag**2, label=r"$|\psi(r)|^2$ %s" % spinor_label)

        ax.grid(True)
        ax.set_xlabel("Distance from site1 [Angstrom]")
        ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    @add_fig_kwargs
    def plot_line_neighbors(self, site_index, radius, num=200, with_krphase=False,
                            max_nn=10, fontsize=12, **kwargs) -> Figure:
        """
        Plot (interpolated) wavefunction in real space along the lines connecting
        an atom specified by ``site_index`` and all neighbors within a sphere of given ``radius``.

        .. warning::

            This routine can produce lots of plots! Be careful with the value of ``radius``.
            See also ``max_nn``.

        Args:
            site_index: Index of the atom in the structure.
            radius: Radius of the sphere in Angstrom.
            num: Number of points sampled along the line.
            with_krphase: True to include the :math:`e^{ikr}` phase-factor.
            max_nn: By default, only the first ``max_nn`` neighbors are showed.
            fontsize: legend and label fontsize.

        Return: |matplotlib-Figure|, or None if no neighbor is found within ``radius``.
        """
        site = self.structure[site_index]
        nn_list = self.structure.get_neighbors_old(site, radius, include_index=True)
        if not nn_list:
            cprint("Zero neighbors found for radius %s Ang. Returning None." % radius, "yellow")
            return None

        # Sort sites by distance.
        # NOTE(review): item 1 of each neighbor entry is presumably the distance -- confirm.
        nn_list = sorted(nn_list, key=lambda t: t[1])
        if max_nn is not None and len(nn_list) > max_nn:
            cprint("For radius %s, found %s neighbors but only max_nn %s sites are show." %
                   (radius, len(nn_list), max_nn), "yellow")
            nn_list = nn_list[:max_nn]

        # Get grid of axes (one row for neighbor)
        # squeeze=False so that an array of axes is returned even for a single neighbor:
        # with squeeze=True, matplotlib returns a bare Axes for a 1x1 grid and ravel() fails.
        # NOTE(review): assumes get_axarray_fig_plt forwards squeeze to plt.subplots -- confirm.
        nrows, ncols = len(nn_list), 1
        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()

        interpolator = self.get_interpolator()
        kpoint = None if not with_krphase else self.kpoint
        which = r"\psi(r)" if with_krphase else "u(r)"

        # For each neighbor, plot psi along the line connecting site to nn.
        for i, (nn, ax) in enumerate(zip(nn_list, ax_list)):
            title = "%s, %s, dist=%.3f A" % (nn.species_string, str(nn.frac_coords), nn.nn_distance)

            r = interpolator.eval_line(site.frac_coords, nn.frac_coords, num=num, kpoint=kpoint)

            for ispinor in range(self.nspinor):
                spinor_label = latex_label_ispinor(ispinor, self.nspinor)
                ur = r.values[ispinor]
                ax.plot(r.dist, ur.real, label=r"$\Re %s$ %s" % (which, spinor_label))
                ax.plot(r.dist, ur.imag, label=r"$\Im %s$ %s" % (which, spinor_label))
                ax.plot(r.dist, ur.real**2 + ur.imag**2, label=r"$|\psi(r)|^2$ %s" % spinor_label)

            ax.set_title(title, fontsize=fontsize)
            ax.grid(True)

            # Only the last row gets the x-label and the legend.
            if i == nrows - 1:
                ax.set_xlabel("Distance from site_index %s [Angstrom]" % site_index)
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    def export_ur2(self, filename, visu=None):
        """
        Export :math:`|u(r)|^2` to file ``filename``.

        Args:
            filename: String specifying the file path and the file format.
                The format is defined by the file extension. filename="prefix.xsf", for example,
                will produce a file in XSF format. An *empty* prefix, e.g. ".xsf" makes the code use a temporary file.
            visu: :class:`Visualizer` subclass. By default, this method returns the first available
                visualizer that supports the given file format. If visu is not None, an
                instance of visu is returned. See :class:`Visualizer` for the list of
                applications and formats supported.

        Returns:
            Instance of :class:`Visualizer`

        Raises:
            ValueError: if the file extension cannot be detected.
            NotImplementedError: if the file extension is not supported (only "xsf" is).
        """
        import os

        # Detect prefix and extension from the last path component only, so that
        # paths such as "./out.xsf" or "../dir/out.xsf" are not mistaken for an empty prefix.
        basename = os.path.basename(filename.strip())
        prefix, dot, ext = basename.rpartition(".")
        if not dot:
            raise ValueError("Cannot detect file extension in: %s" % filename)

        # Validate the format before touching the filesystem so that an unsupported
        # extension does not leave an empty file behind.
        if ext != "xsf":
            raise NotImplementedError("extension %s is not supported." % ext)

        if not prefix:
            # fname == ".ext" ==> Create temporary file.
            # dir = os.getcwd() is needed when we invoke the method from a notebook.
            from abipy.core.globals import abinb_mkstemp
            fd, filename = abinb_mkstemp(suffix="." + ext, text=True)
            # NOTE(review): abinb_mkstemp presumably follows tempfile.mkstemp and returns an
            # open OS-level descriptor; close it since the file is reopened below -- confirm.
            os.close(fd)
            print("Creating temporary file: %s" % filename)

        # Compute |u(r)|2 (once: the property recomputes the array at each access)
        # and add a leading axis of length 1 before writing.
        ur2 = self.ur2
        ur2 = np.reshape(ur2, (1,) + ur2.shape)

        with open(filename, mode="wt") as fh:
            # xcrysden
            xsf_write_structure(fh, structures=self.structure)
            xsf_write_data(fh, self.structure, ur2, add_replicas=True)

        if visu is None:
            return Visualizer.from_file(filename)

        return visu(filename)

    def visualize_ur2(self, appname="vesta"):
        """
        Visualize :math:`|u(r)|^2`.

        See :class:`Visualizer` for the list of applications and formats supported.

        Raises:
            ``visu.Error`` if none of the formats supported by the visualizer can be exported.
        """
        # Get the Visualizer subclass from the string.
        visu = Visualizer.from_name(appname)

        # Try to export data to one of the formats supported by the visualizer
        # Use a temporary file (note "." + ext)
        for ext in visu.supported_extensions():
            try:
                return self.export_ur2("." + ext, visu=visu)
            except (visu.Error, NotImplementedError):
                # export_ur2 raises NotImplementedError for formats it cannot write:
                # move on to the next extension instead of aborting the search.
                continue

        raise visu.Error("Don't know how to export data for %s" % str(appname))

    #def mvplot_cutplanes(self, show=True):
    #    data = self.ur2
    #    from abipy.display import mvtk
    #    figure, mlab = mvtk.get_fig_mlab(figure=None)
    #    source = mlab.pipeline.scalar_field(data)
    #    mlab.pipeline.image_plane_widget(source, plane_orientation='x_axes', slice_index=data.shape[0]//2)
    #    mlab.pipeline.image_plane_widget(source, plane_orientation='y_axes', slice_index=data.shape[1]//2)
    #    mlab.pipeline.image_plane_widget(source, plane_orientation='z_axes', slice_index=data.shape[2]//2)
    #    #mlab.pipeline.iso_surface(source, contours=contours) #, opacity=0.1)
    #    #mlab.pipeline.iso_surface(source, contours=[data.min()+ 0.1 * data.ptp()], opacity=0.1)
    #    mlab.outline()
    #    if show: mlab.show()
    #    return figure

    #def mvplot_volume(self, figure=None, vmin=0.65, vmax=0.9, show=True):
    #    from abipy.display import mvtk
    #    figure, mlab = mvtk.get_fig_mlab(figure=figure)
    #    data = self.ur2
    #    source = mlab.pipeline.scalar_field(data)
    #    data_min, data_max = data.min(), data.max()
    #    mlab.pipeline.volume(source,
    #                         vmin=data_min + vmin * (data_max - data_min),
    #                         vmax=data_min + vmax * (data_max - data_min))
    #    if show: mlab.show()
    #    return figure
class PAW_WaveFunction(WaveFunction):
    """
    PAW wavefunction: ``WaveFunction.ispaw`` is True for instances of this class.

    All the methods that are related to the all-electron representation should start with ae.

    .. rubric:: Inheritance Diagram
    .. inheritance-diagram:: PAW_WaveFunction
    """