Source code for abipy.flowtk.abiinspect

# pylint: disable=no-member, chained-comparison, unnecessary-comprehension, not-callable
"""
This module provides objects to inspect the status of the Abinit tasks at run-time
by extracting information from the main output file (text format).
"""
from __future__ import annotations

import os
from collections import OrderedDict
from collections.abc import Iterable, Iterator, Mapping
from typing import Union

import numpy as np
from monty.collections import AttrDict
from monty.functools import lazy_property
from tabulate import tabulate
from abipy.tools.iotools import yaml_safe_load
from abipy.tools.plotting import (add_fig_kwargs, get_axarray_fig_plt,
    get_figs_plotly, plotly_set_lims, add_plotly_fig_kwargs)
from abipy.tools.typing import Figure


def straceback() -> str:
    """
    Return the traceback of the exception currently being handled as a string.

    The text is produced by ``traceback.format_exc``.
    """
    from traceback import format_exc

    formatted = format_exc()
    return formatted
def _magic_parser(stream, magic: str) -> dict: """ Parse the section with the SCF cycle Returns: dict where the key are the name of columns and the values are list of numbers. Note if no section was found. .. warning:: The parser is very fragile and should be replaced by YAML. """ # Example (SCF cycle, similar format is used for phonons): # # iter Etot(hartree) deltaE(h) residm vres2 # ETOT 1 -8.8604027880849 -8.860E+00 2.458E-02 3.748E+00 # At SCF step 5 vres2 = 3.53E-08 < tolvrs= 1.00E-06 =>converged. in_doc, fields = 0, None for line in stream: line = line.strip() if line.startswith(magic): # print("Found magic token in line:", line) keys = line.split() fields = OrderedDict((k, []) for k in keys) if fields is not None: # print(line) in_doc += 1 if in_doc == 1: continue # End of the section or empty SCF cycle if not line or line.startswith("prteigrs"): break # print("Try to parse line:", line) tokens = list(map(float, line.split()[1:])) assert len(tokens) == len(keys) for l, v in zip(fields.values(), tokens): l.append(v) # Convert values to numpy arrays. if fields: return OrderedDict([(k, np.array(v)) for k, v in fields.items()]) return None
def plottable_from_outfile(filepath: str):
    """
    Factory function building a plottable object from the main output file of abinit.

    Args:
        filepath: Path to the main output file.

    Returns:
        The object produced by the ``from_file`` constructor of the selected class,
        None if the class to instantiate cannot be detected.
    """
    # TODO
    # Figure out how to detect the type of calculations
    # without having to parse the input. Possible approach: YAML doc
    # with YamlTokenizer(filepath) as r:
    #     doc = r.next_doc_with_tag("!CalculationType")
    #     d = yaml.safe_load(doc.text_notag)
    #     calc_type = d["calculation_type"]

    # ctype2class = {
    #    "Ground State": GroundStateScfCycle,
    #    "Phonon": PhononScfCycle,
    #    "Relaxation": Relaxation,
    # }
    # obj = ctype2class.get(calc_type, None)

    # Until the calculation type can be detected, a ground-state SCF cycle is assumed.
    plottable_cls = GroundStateScfCycle
    if plottable_cls is None:
        return None

    return plottable_cls.from_file(filepath)
# Use log scale for these variables. _VARS_SUPPORTING_LOGSCALE = set(["residm", "vres2", "nres2"]) # Hard-coded y-range for selected variables. _VARS_WITH_YRANGE = { "deltaE(h)": (-1e-3, +1e-3), "deltaE(Ha)": (-1e-3, +1e-3), }
class ScfCycle(Mapping):
    """
    Read-only mapping: column label --> sequence of floats with the values
    of that quantity at the different iterations.

    .. attribute:: num_iterations

        Number of iterations performed (0 if the mapping has no columns).
    """

    # Header string used to locate the section in the output file.
    MAGIC = "Must be defined by the subclass."

    def __init__(self, fields: dict) -> None:
        """
        Args:
            fields: Dictionary with label --> list of numerical values.
                All the lists must have the same length.
        """
        self.fields = fields

        all_lens = [len(lst) for lst in self.values()]
        # A mapping without columns has, by definition, no iterations
        # (indexing all_lens[0] would raise IndexError here).
        self.num_iterations = all_lens[0] if all_lens else 0
        assert all(n == self.num_iterations for n in all_lens)

    def __getitem__(self, slice):
        return self.fields.__getitem__(slice)

    def __iter__(self):
        return self.fields.__iter__()

    def __len__(self) -> int:
        return len(self.fields)

    def __str__(self) -> str:
        return self.to_string()

    def to_string(self, verbose=0) -> str:
        """
        String representation: a table with one row per iteration.

        Args:
            verbose: Verbosity level (currently unused).
        """
        columns = list(self.values())
        rows = [
            [step + 1] + [str(column[step]) for column in columns]
            for step in range(self.num_iterations)
        ]

        return tabulate(rows, headers=["Iter"] + list(self.keys()))

    @property
    def last_iteration(self) -> dict:
        """Returns a dictionary with the values of the last iteration."""
        return {k: v[-1] for k, v in self.items()}

    @classmethod
    def from_file(cls, filepath: str) -> ScfCycle:
        """Read the first occurrence of ScfCycle from file."""
        with open(filepath, "rt") as stream:
            return cls.from_stream(stream)

    @classmethod
    def from_stream(cls, stream) -> Union[ScfCycle, None]:
        """
        Read the first occurrence of ScfCycle from stream.

        Returns:
            None if no `ScfCycle` entry is found.
        """
        fields = _magic_parser(stream, magic=cls.MAGIC)

        if fields:
            # The iteration index is implicit in the position inside the arrays.
            fields.pop("iter")
            return cls(fields)

        return None

    @add_fig_kwargs
    def plot(self, ax_list=None, fontsize=8, **kwargs) -> Figure:
        """
        Uses matplotlib to plot the evolution of the SCF cycle.

        Args:
            ax_list: List of axes. If None a new figure is produced.
            fontsize: legend fontsize.
            kwargs: keyword arguments are passed to ax.plot

        Returns: matplotlib figure
        """
        # Build grid of plots.
        num_plots, ncols, nrows = len(self), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + num_plots % ncols

        ax_list, fig, _ = get_axarray_fig_plt(
            ax_list, nrows=nrows, ncols=ncols, sharex=True, sharey=False, squeeze=False
        )
        ax_list = np.array(ax_list).ravel()

        iter_num = np.arange(1, self.num_iterations + 1)
        label = kwargs.pop("label", None)

        for i, ((key, values), ax) in enumerate(zip(self.items(), ax_list)):
            ax.grid(True)
            ax.set_xlabel("Iteration Step")
            ax.set_xticks(iter_num, minor=False)
            ax.set_ylabel(key)

            # Coerce to ndarray so that plain lists support the comparisons below.
            xx, yy = iter_num, np.asarray(values)
            if self.num_iterations > 1:
                # Don't show the first iteration since it's not very useful.
                xx, yy = xx[1:], yy[1:]

            if not kwargs and label is None:
                ax.plot(xx, yy, "-o", lw=2.0)
            else:
                # Only the first subplot carries the label to avoid duplicated legend entries.
                ax.plot(xx, yy, label=label if i == 0 else None, **kwargs)

            if key in _VARS_SUPPORTING_LOGSCALE and np.all(yy > 1e-22):
                ax.set_yscale("log")

            # np.min/np.max raise on empty arrays, hence the size check.
            if key in _VARS_WITH_YRANGE and yy.size:
                ymin, ymax = _VARS_WITH_YRANGE[key]
                val_min, val_max = np.min(yy), np.max(yy)
                if abs(val_max - val_min) > abs(ymax - ymin):
                    ax.set_ylim(ymin, ymax)

            if label is not None:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        # Get around a bug in matplotlib.
        if num_plots % ncols != 0:
            ax_list[-1].plot(xx, yy, lw=0.0)
            ax_list[-1].axis("off")

        return fig

    @add_plotly_fig_kwargs
    def plotly(self, fontsize=12, **kwargs):
        """
        Uses plotly to plot the evolution of the SCF cycle.

        Args:
            fontsize: legend fontsize.
            kwargs: keyword arguments are passed to go.Scatter.

        Returns: plotly figure
        """
        # Build grid of plots.
        num_plots, ncols, nrows = len(self), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + num_plots % ncols

        # Build fig.
        fig, go = get_figs_plotly(nrows=nrows, ncols=ncols, sharex=True, sharey=False)

        iter_num = np.arange(1, self.num_iterations + 1)

        for i, (key, values) in enumerate(self.items()):
            # plotly uses 1-based (row, col) indices.
            row, col = divmod(i, ncols)
            row += 1
            col += 1

            # Coerce to ndarray so that plain lists support the comparisons below.
            xx, yy = iter_num, np.asarray(values)
            if self.num_iterations > 1:
                # Don't show the first iteration since it's not very useful.
                xx, yy = xx[1:], yy[1:]

            trace = go.Scatter(x=xx, y=yy, mode="lines+markers", showlegend=False, **kwargs)
            fig.add_trace(trace, row=row, col=col)

            fig.update_xaxes(title_text="Iteration Step", row=row, col=col)
            fig.update_yaxes(title_text=key, row=row, col=col)

            if key in _VARS_SUPPORTING_LOGSCALE and np.all(yy > 1e-22):
                fig.update_yaxes(type="log", row=row, col=col)

            # np.min/np.max raise on empty arrays, hence the size check.
            if key in _VARS_WITH_YRANGE and yy.size:
                ymin, ymax = _VARS_WITH_YRANGE[key]
                val_min, val_max = np.min(yy), np.max(yy)
                if abs(val_max - val_min) > abs(ymax - ymin):
                    plotly_set_lims(fig, (ymin, ymax), "y", iax=i + 1)

        fig.layout.legend.font.size = fontsize
        fig.layout.title.font.size = fontsize

        return fig
class GroundStateScfCycle(ScfCycle):
    """SCF cycle of a Ground State calculation."""

    # Header of the SCF section in the main output file.
    MAGIC = "iter Etot(hartree)"

    @property
    def last_etotal(self) -> float:
        """Total energy at the last iteration (last entry of the energy column)."""
        etotals = self["Etot(hartree)"]
        return etotals[-1]
class D2DEScfCycle(ScfCycle):
    """SCF cycle of a Phonon calculation."""

    # Header of the SCF section in the main output file.
    MAGIC = "iter 2DEtotal(Ha)"

    @property
    def last_etotal(self) -> float:
        """Second-order derivative of the energy at the last iteration."""
        d2e_values = self["2DEtotal(Ha)"]
        return d2e_values[-1]
class PhononScfCycle(D2DEScfCycle):
    """
    Iterations of the DFPT SCF cycle for phonons.

    Inherits everything from :class:`D2DEScfCycle` without changes.
    """
class CyclesPlotter:
    """Relies on the plot method of cycle objects to build multiple subfigures."""

    def __init__(self):
        """Initialize object."""
        # Parallel lists: labels[i] is the label associated to cycles[i].
        self.labels = []
        self.cycles = []

    def items(self):
        """To iterate over (label, cycle)."""
        return zip(self.labels, self.cycles)

    def add_label_cycle(self, label, cycle):
        """Add new cycle to the plotter with label `label`."""
        self.labels.append(label)
        self.cycles.append(cycle)

    @add_fig_kwargs
    def combiplot(self, fontsize=8, **kwargs) -> Figure:
        """
        Compare multiple cycles on a grid: one subplot per quantity,
        all cycles on the same subplot.

        Args:
            fontsize: Legend fontsize.

        Raises:
            ValueError: if no cycle has been added to the plotter.
        """
        if not self.cycles:
            # Without this check the function would fail with UnboundLocalError on `fig`.
            raise ValueError("Cannot produce combiplot: no cycle has been added to the plotter.")

        ax_list = None
        for label, cycle in self.items():
            fig = cycle.plot(
                ax_list=ax_list,
                label=label,
                fontsize=fontsize,
                lw=2.0,
                marker="o",
                linestyle="-",
                show=False,
            )
            # Reuse the axes of the first figure so that all cycles share the same subplots.
            ax_list = fig.axes

        return fig

    def slideshow(self, **kwargs):
        """
        Produce slides show of the different cycles. One plot per cycle.
        """
        for label, cycle in self.items():
            cycle.plot(title=label, tight_layout=True)
class Relaxation(Iterable):
    """
    A list of :class:`GroundStateScfCycle` objects.

    .. attribute:: num_iterations

        Number of iterations performed.

    .. note::

        Forces, stresses and crystal structures are missing.
        This object is mainly used to analyze the behavior of the Scf cycles
        during the structural relaxation. A more powerful and detailed analysis
        can be obtained by using the HIST.nc file.
    """

    def __init__(self, cycles: list[GroundStateScfCycle]) -> None:
        """
        Args:
            cycles: list of `GroundStateScfCycle` objects.
        """
        self.cycles = cycles
        self.num_iterations = len(self.cycles)

    def __iter__(self) -> Iterator[GroundStateScfCycle]:
        return self.cycles.__iter__()

    def __len__(self) -> int:
        return self.cycles.__len__()

    def __getitem__(self, slice) -> Union[GroundStateScfCycle, list[GroundStateScfCycle]]:
        return self.cycles[slice]

    def __str__(self) -> str:
        return self.to_string()

    def to_string(self, verbose: int = 0) -> str:
        """
        String representation: the table of each SCF cycle preceded by the relaxation step.

        Args:
            verbose: Verbosity level, passed to the `to_string` method of each cycle.
        """
        lines = []
        for i, cycle in enumerate(self):
            lines.append("")
            lines.append("RELAXATION STEP: %d" % (i + 1))
            lines.append(cycle.to_string(verbose=verbose))

        return "\n".join(lines)

    @classmethod
    def from_file(cls, filepath: str) -> Union[Relaxation, None]:
        """Initialize the object from the Abinit main output file."""
        with open(filepath, "rt") as stream:
            return cls.from_stream(stream)

    @classmethod
    def from_stream(cls, stream) -> Union[Relaxation, None]:
        """
        Extract data from stream. Returns None if no SCF cycle is found.
        """
        cycles = []
        # Each call consumes the stream up to the end of the next SCF section.
        while True:
            scf_cycle = GroundStateScfCycle.from_stream(stream)
            if scf_cycle is None:
                break
            cycles.append(scf_cycle)

        return cls(cycles) if cycles else None

    @lazy_property
    def history(self) -> dict:
        """
        dictionary of numpy arrays with the evolution of the data as function of
        the relaxation step i.e. the values at the last iteration of each SCF cycle.
        """
        history = {}
        for cycle in self:
            for k, v in cycle.last_iteration.items():
                history.setdefault(k, []).append(v)

        # Convert to numpy arrays.
        return {k: np.array(v) for k, v in history.items()}

    def slideshow(self, **kwargs):
        """
        Uses matplotlib to plot the SCF cycle of each relaxation step, one figure per step.

        Args:
            kwargs: `tight_layout` (default: True) and `show` (default: True)
                are passed to the plot method of each cycle.
        """
        # Extract the options once, before the loop: popping inside the loop would
        # apply the values given by the caller only to the first cycle.
        tight_layout = kwargs.pop("tight_layout", True)
        show = kwargs.pop("show", True)

        for i, cycle in enumerate(self.cycles):
            cycle.plot(
                title="Relaxation step %s" % (i + 1),
                tight_layout=tight_layout,
                show=show,
            )

    @add_fig_kwargs
    def plot(self, ax_list=None, fontsize=8, **kwargs) -> Figure:
        """
        Plot relaxation history i.e. the results of the last iteration of each SCF cycle.

        Args:
            ax_list: List of axes. If None a new figure is produced.
            fontsize: legend fontsize.
            kwargs: keyword arguments are passed to ax.plot

        Returns: matplotlib figure
        """
        history = self.history

        # Build grid of plots.
        num_plots, ncols, nrows = len(history), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + num_plots % ncols

        ax_list, fig, _ = get_axarray_fig_plt(
            ax_list, nrows=nrows, ncols=ncols, sharex=True, sharey=False, squeeze=False
        )
        ax_list = np.array(ax_list).ravel()

        iter_num = np.arange(1, self.num_iterations + 1)
        label = kwargs.pop("label", None)

        for i, ((key, values), ax) in enumerate(zip(history.items(), ax_list)):
            ax.grid(True)
            ax.set_xlabel("Relaxation Step")
            ax.set_xticks(iter_num, minor=False)
            ax.set_ylabel(key)

            xx, yy = iter_num, np.asarray(values)

            if not kwargs and label is None:
                ax.plot(xx, yy, "-o", lw=2.0)
            else:
                # Only the first subplot carries the label to avoid duplicated legend entries.
                ax.plot(xx, yy, label=label if i == 0 else None, **kwargs)

            if key in _VARS_SUPPORTING_LOGSCALE and np.all(yy > 1e-22):
                ax.set_yscale("log")

            if key in _VARS_WITH_YRANGE:
                ymin, ymax = _VARS_WITH_YRANGE[key]
                val_min, val_max = np.min(yy), np.max(yy)
                if abs(val_max - val_min) > abs(ymax - ymin):
                    ax.set_ylim(ymin, ymax)

            if label is not None:
                ax.legend(loc="best", fontsize=fontsize, shadow=True)

        # Get around a bug in matplotlib.
        if num_plots % ncols != 0:
            ax_list[-1].plot(xx, yy, lw=0.0)
            ax_list[-1].axis("off")

        return fig
# TODO # class HaydockIterations(Iterable): # """This object collects info on the different steps of the Haydock technique used in the Bethe-Salpeter code""" # @classmethod # def from_file(cls, filepath): # """Initialize the object from file.""" # with open(filepath, "rt") as stream: # return cls.from_stream(stream) # # @classmethod # def from_stream(cls, stream): # """Extract data from stream. Returns None if some error occurred.""" # cycles = [] # while True: # scf_cycle = GroundStateScfCycle.from_stream(stream) # if scf_cycle is None: break # cycles.append(scf_cycle) # # return cls(cycles) if cycles else None # # #def __init__(self): # # def plot(self, **kwargs): # """ # Uses matplotlib to plot the evolution of the structural relaxation. # ============== ============================================================== # kwargs Meaning # ============== ============================================================== # title Title of the plot (Default: None). # how True to show the figure (Default). # savefig 'abc.png' or 'abc.eps'* to save the figure to a file. # ============== ============================================================== # Returns: # `matplotlib` figure # """ # import matplotlib.pyplot as plt # title = kwargs.pop("title", None) # show = kwargs.pop("show", True) # savefig = kwargs.pop("savefig", None) # if title: fig.suptitle(title) # if savefig is not None: fig.savefig(savefig) # if show: plt.show() # return fig ################## # Yaml parsers. ##################
class YamlTokenizerError(Exception):
    """Exceptions raised by :class:`YamlTokenizer`."""


class YamlTokenizer(Iterator):
    """
    Iterator over the YAML documents found in a text file.

    Provides context-manager support so you can use it in a with statement.
    """

    Error = YamlTokenizerError

    def __init__(self, filename: str):
        """
        Args:
            filename: Filename

        Raises:
            IOError: if the file cannot be opened. If a file with the same root and
                extension `.err` exists, its content is printed before re-raising.
        """
        # The position inside the file.
        self.linepos = 0
        self.filename = filename

        try:
            self.stream = open(filename, "rt")  # pylint: disable=R1732
        except IOError:
            # Look for associated error file.
            root, _ = os.path.splitext(self.filename)
            errfile = root + ".err"
            if os.path.exists(errfile) and errfile != self.filename:
                print("Found error file: %s" % errfile)
                with open(errfile, "rt") as fh:
                    print(fh.read())
            raise

    def __iter__(self):
        return self

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def close(self) -> None:
        """Close the stream. No-op if the stream was never opened."""
        # __del__ calls close even when open() failed in __init__:
        # in this case the attribute does not exist and there is nothing to close.
        stream = getattr(self, "stream", None)
        if stream is None:
            return

        try:
            stream.close()
        except Exception:
            # Best-effort: report the problem without propagating it (close is called by __del__).
            print("Exception in YAMLTokenizer.close()")
            print("Python traceback:")
            print(straceback())

    def seek(self, offset, whence=0):
        """
        seek(offset[, whence]) -> None.  Move to new file position.

        Argument offset is a byte count.  Optional argument whence defaults to
        0 (offset from start of file, offset should be >= 0); other values are 1
        (move relative to current position, positive or negative), and 2 (move
        relative to end of file, usually negative, although many platforms allow
        seeking beyond the end of a file).  If the file is opened in text mode,
        only offsets returned by tell() are legal.  Use of other offsets causes
        undefined behavior.
        Note that not all file objects are seekable.

        .. note::

            Only `offset` == 0 is accepted by this implementation
            since the line counter is reset to zero.
        """
        assert offset == 0
        self.linepos = 0
        return self.stream.seek(offset, whence)

    # Python 3 compatibility
    def __next__(self):
        return self.next()

    def next(self):
        """
        Returns the next YAML document in stream.

        Raises:
            StopIteration: if no other document is found.

        .. warning::

            Assume that the YAML document are closed explicitly with the sentinel '...'
        """
        in_doc, lines, doc_tag = None, [], None

        for line in self.stream:
            self.linepos += 1

            if line.startswith("---"):
                # Include only lines in the form:
                #  "--- !tag"
                #  "---"
                # Other lines are spurious.
                in_doc = False
                marker = line[3:].strip()

                if marker.startswith("!"):
                    # "--- !tag"
                    doc_tag = marker
                    in_doc = True

                elif not marker:
                    #  "---"
                    in_doc = True
                    doc_tag = None

                if in_doc:
                    lineno = self.linepos
                    # A new document starts here: drop the lines of a previous
                    # document that was not closed by '...', so that the text
                    # is consistent with lineno and tag.
                    lines = []

            if in_doc:
                lines.append(line)

            if in_doc and line.startswith("..."):
                return YamlDoc(text="".join(lines), lineno=lineno, tag=doc_tag)

        raise StopIteration("Cannot find next YAML document in %s" % self.filename)

    def all_yaml_docs(self) -> list:
        """
        Returns a list with all the YAML docs found in stream.
        Seek the stream before returning.

        .. warning::

            Assume that all the YAML docs (with the exception of the last one)
            are closed explicitly with the sentinel '...'
        """
        docs = list(self)
        self.seek(0)
        return docs

    def next_doc_with_tag(self, doc_tag: str):
        """
        Returns the next document with the specified tag.

        Raises:
            StopIteration: if no document with this tag is found.
        """
        while True:
            # StopIteration raised by next() propagates to the caller.
            doc = next(self)
            if doc.tag == doc_tag:
                return doc

    def all_docs_with_tag(self, doc_tag: str) -> list:
        """
        Returns all the documents with the specified tag.
        Seek the stream before returning.
        """
        docs = []

        while True:
            try:
                doc = self.next_doc_with_tag(doc_tag)
                docs.append(doc)
            except StopIteration:
                break

        self.seek(0)
        return docs
def yaml_read_kpoints(filename: str, doc_tag: str = "!Kpoints") -> np.ndarray:
    """
    Read the K-points from file.

    Args:
        filename: Path of the file containing the YAML document.
        doc_tag: Tag of the YAML document to be read.

    Returns:
        numpy array built from the `reduced_coordinates_of_qpoints` entry
        of the first document with tag `doc_tag`.
    """
    with YamlTokenizer(filename) as tokenizer:
        yaml_text = tokenizer.next_doc_with_tag(doc_tag).text_notag
        data = yaml_safe_load(yaml_text)
        coords = data["reduced_coordinates_of_qpoints"]

        return np.array(coords)
def yaml_read_irred_perts(filename: str, doc_tag="!IrredPerts") -> list[AttrDict]:
    """
    Read the list of irreducible perturbations from file.

    Args:
        filename: Path of the file containing the YAML document.
        doc_tag: Tag of the YAML document to be read.

    Returns:
        List with one AttrDict for each entry of `irred_perts`
        in the first document with tag `doc_tag`.
    """
    with YamlTokenizer(filename) as tokenizer:
        yaml_text = tokenizer.next_doc_with_tag(doc_tag).text_notag
        data = yaml_safe_load(yaml_text)

        perts = []
        for pert in data["irred_perts"]:
            perts.append(AttrDict(**pert))

        return perts
class YamlDoc:
    """
    Handy object that stores that YAML document, its main tag and the
    position inside the file.
    """

    __slots__ = [
        "text",
        "lineno",
        "tag",
    ]

    def __init__(self, text: str, lineno: int, tag=None):
        """
        Args:
            text: String with the YAML document.
            lineno: The line number where the document is located.
            tag: The YAML tag associate to the document.
        """
        # Sanitize strings: use "ignore" to skip invalid characters in .encode/.decode like
        if isinstance(text, bytes):
            text = text.decode("utf-8", "ignore")
        self.text = text.strip()
        self.lineno = lineno

        if isinstance(tag, bytes):
            tag = tag.decode("utf-8", "ignore")
        self.tag = tag

    def __str__(self):
        return self.text

    def __eq__(self, other):
        if other is None:
            return False
        if not isinstance(other, YamlDoc):
            # Let Python try the reflected comparison instead of
            # failing with AttributeError on unrelated objects.
            return NotImplemented
        return self.text == other.text and self.lineno == other.lineno and self.tag == other.tag

    def __ne__(self, other):
        return not (self == other)

    @property
    def text_notag(self) -> str:
        """
        Returns the YAML text without the tag.
        Useful if we don't have any constructor registered for the tag
        (we used the tag just to locate the document).
        """
        if self.tag is not None:
            # Remove only the first occurrence: later occurrences of the same
            # string belong to the body of the document and must be preserved.
            return self.text.replace(self.tag, "", 1)

        return self.text

    def as_dict(self) -> dict:
        """Use Yaml to parse the text (without the tag) and returns a dictionary."""
        return yaml_safe_load(self.text_notag)