Source code for abipy.tools.plotting

# coding: utf-8
"""
Utilities for generating matplotlib plots.

.. note::

    Avoid importing matplotlib or plotly in the module namespace otherwise startup is very slow.
"""
from __future__ import annotations

import os
import time
import itertools
import functools
import numpy as np
import pandas as pd
import matplotlib.collections as mcoll

from collections import namedtuple, OrderedDict
from typing import Any, Callable, Iterator
from monty.string import list_strings
from abipy.tools import duck
from abipy.tools.iotools import dataframe_from_filepath
from abipy.tools.typing import Figure, Axes, VectorLike
from abipy.tools.numtools import data_from_cplx_mode


__all__ = [
    "set_axlims",
    "add_fig_kwargs",
    "get_ax_fig_plt",
    "get_axarray_fig_plt",
    "get_ax3d_fig_plt",
    "plot_array",
    "ArrayPlotter",
    "data_from_cplx_mode",
    "Marker",
    "plot_unit_cell",
    "GenericDataFilePlotter",
    "GenericDataFilesPlotter",
    "add_plotly_fig_kwargs",
    "get_fig_plotly",
    "get_figs_plotly",
]


# https://matplotlib.org/gallery/lines_bars_and_markers/linestyles.html
linestyles = OrderedDict(
    [('solid',               (0, ())),
     ('loosely_dotted',      (0, (1, 10))),
     ('dotted',              (0, (1, 5))),
     ('densely_dotted',      (0, (1, 1))),
     #
     ('loosely_dashed',      (0, (5, 10))),
     ('dashed',              (0, (5, 5))),
     ('densely_dashed',      (0, (5, 1))),
     #
     ('loosely_dashdotted',  (0, (3, 10, 1, 10))),
     ('dashdotted',          (0, (3, 5, 1, 5))),
     ('densely_dashdotted',  (0, (3, 1, 1, 1))),
     #
     ('loosely_dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
     ('dashdotdotted',         (0, (3, 5, 1, 5, 1, 5))),
     ('densely_dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]
)



[docs] def add_fig_kwargs(func): """Decorator that adds keyword arguments for functions returning matplotlib figures. The function should return either a matplotlib figure or None to signal some sort of error/unexpected event. See doc string below for the list of supported options. """ @functools.wraps(func) def wrapper(*args, **kwargs): # pop the kwds used by the decorator. title = kwargs.pop("title", None) size_kwargs = kwargs.pop("size_kwargs", None) show = kwargs.pop("show", True) savefig = kwargs.pop("savefig", None) tight_layout = kwargs.pop("tight_layout", False) ax_grid = kwargs.pop("ax_grid", None) ax_annotate = kwargs.pop("ax_annotate", None) fig_close = kwargs.pop("fig_close", False) plotly = kwargs.pop("plotly", False) # Call func and return immediately if None is returned. fig = func(*args, **kwargs) if fig is None: return fig # Operate on matplotlib figure. if title is not None: fig.suptitle(title) if size_kwargs is not None: fig.set_size_inches(size_kwargs.pop("w"), size_kwargs.pop("h"), **size_kwargs) if ax_grid is not None: for ax in fig.axes: ax.grid(bool(ax_grid)) if ax_annotate: tags = ascii_letters if len(fig.axes) > len(tags): tags = (1 + len(ascii_letters) // len(fig.axes)) * ascii_letters for ax, tag in zip(fig.axes, tags): ax.annotate(f"({tag})", xy=(0.05, 0.95), xycoords="axes fraction") if tight_layout: try: fig.tight_layout() except Exception as exc: # For some unknown reason, this problem shows up only on travis. # https://stackoverflow.com/questions/22708888/valueerror-when-using-matplotlib-tight-layout print("Ignoring Exception raised by fig.tight_layout\n", str(exc)) if savefig: fig.savefig(savefig) if plotly: try: plotly_fig = mpl_to_ply(fig, latex=False) if show: plotly_fig.show() return plotly_fig except Exception as exc: print("Exception while convertig matplotlib figure to plotly. Returning mpl figure!") print(str(exc)) pass import matplotlib.pyplot as plt if show: plt.show() if fig_close: plt.close(fig=fig) return fig # Add docstring to the decorated method. doc_str = """\n\n Keyword arguments controlling the display of the figure: ================ ==================================================== kwargs Meaning ================ ==================================================== title Title of the plot (Default: None). show True to show the figure (default: True). savefig "abc.png" or "abc.eps" to save the figure to a file. size_kwargs Dictionary with options passed to fig.set_size_inches e.g. size_kwargs=dict(w=3, h=4) tight_layout True to call fig.tight_layout (default: False) ax_grid True (False) to add (remove) grid from all axes in fig. Default: None i.e. fig is left unchanged. ax_annotate Add labels to subplots e.g. (a), (b). Default: False fig_close Close figure. Default: False. plotly Try to convert mpl figure to plotly. ================ ==================================================== """ if wrapper.__doc__ is not None: # Add s at the end of the docstring. wrapper.__doc__ += f"\n{doc_str}" else: # Use s wrapper.__doc__ = doc_str return wrapper
class FilesPlotter: """ Use matplotlib to plot multiple png files on a grid. Example: FilesPlotter(["file1.png", file2.png"]).plot() """ def __init__(self, filepaths: list[str]): self.filepaths = list_strings(filepaths) @add_fig_kwargs def plot(self, **kwargs) -> Figure: """ Loop through the PNG files and display them in subplots. """ # Build grid of plots. num_plots, ncols, nrows = len(self.filepaths), 1, 1 if num_plots > 1: ncols = 2 nrows = (num_plots // ncols) + (num_plots % ncols) ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) ax_list = ax_list.ravel() # don't show the last ax if num_plots is odd. if num_plots % ncols != 0: ax_list[-1].axis("off") for i, (filepath, ax) in enumerate(zip(self.filepaths, ax_list)): ax.axis('off') ax.imshow(plt.imread(filepath)) return fig @functools.cache def get_color_symbol(style: str="VESTA") -> dict: """ Dictionary mapping chemical symbols to RGB color. Args: style: "VESTA" or "Jmol". """ from monty.serialization import loadfn from pymatgen import vis colors = loadfn(os.path.join(os.path.dirname(vis.__file__), "ElementColorSchemes.yaml")) if style not in colors: raise KeyError(f"Invalid {style=}. Should be in {colors.keys()}") color_symbol = {el: [j / 256.001 for j in colors[style][el]] for el in colors[style]} return color_symbol ################### # Matplotlib tools ###################
[docs] def get_ax_fig_plt(ax=None, **kwargs): """ Helper function used in plot functions supporting an optional Axes argument. If ax is None, we build the `matplotlib` figure and create the Axes else we return the current active figure. Args: ax (Axes, optional): Axes object. Defaults to None. kwargs: keyword arguments are passed to plt.figure if ax is not None. Returns: ax: :class:`Axes` object figure: matplotlib figure plt: matplotlib pyplot module. """ import matplotlib.pyplot as plt if ax is None: fig = plt.figure(**kwargs) ax = fig.gca() else: fig = plt.gcf() return ax, fig, plt
[docs] def get_ax3d_fig_plt(ax=None, **kwargs): """ Helper function used in plot functions supporting an optional Axes3D argument. If ax is None, we build the `matplotlib` figure and create the Axes3D else we return the current active figure. Args: ax (Axes3D, optional): Axes3D object. Defaults to None. kwargs: keyword arguments are passed to plt.figure if ax is not None. Returns: tuple[Axes3D, Figure]: matplotlib Axes3D and corresponding figure objects """ import matplotlib.pyplot as plt if ax is None: fig = plt.figure(**kwargs) ax = fig.add_subplot(projection="3d") else: fig = plt.gcf() return ax, fig, plt
[docs] def get_axarray_fig_plt( ax_array, nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw): """ Helper function used in plot functions that accept an optional array of Axes as argument. If ax_array is None, we build the `matplotlib` figure and create the array of Axes by calling plt.subplots else we return the current active figure. Returns: ax: Array of Axes objects figure: matplotlib figure plt: matplotlib pyplot module. """ import matplotlib.pyplot as plt if ax_array is None: fig, ax_array = plt.subplots( nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, squeeze=squeeze, subplot_kw=subplot_kw, gridspec_kw=gridspec_kw, **fig_kw, ) else: fig = plt.gcf() ax_array = np.reshape(np.array(ax_array), (nrows, ncols)) if squeeze: if ax_array.size == 1: ax_array = ax_array[0] elif any(s == 1 for s in ax_array.shape): ax_array = ax_array.ravel() return ax_array, fig, plt
def is_mpl_figure(obj: Any) -> bool: """Return True if obj is a matplotlib Figure.""" from matplotlib import pyplot as plt return isinstance(obj, plt.Figure) def ax_append_title(ax, title, loc="center", fontsize=None) -> str: """Add title to previous ax.title. Return new title.""" prev_title = ax.get_title(loc=loc) new_title = prev_title + title ax.set_title(new_title, loc=loc, fontsize=fontsize) return new_title def ax_share(xy_string: str, *ax_list) -> None: """ Share x- or y-axis of two or more subplots after they are created. Args: xy_string: "x" to share x-axis, "xy" for both ax_list: List of axes to share. Example: ax_share("y", ax0, ax1) ax_share("xy", *(ax0, ax1, ax2)) """ if "x" in xy_string: for ix, ax in enumerate(ax_list): others = [a for a in ax_list if a != ax] ax.get_shared_x_axes().join(*others) if "y" in xy_string: for ix, ax in enumerate(ax_list): others = [a for a in ax_list if a != ax] ax.get_shared_y_axes().join(*others)
[docs] def set_axlims(ax, lims: tuple, axname: str) -> tuple: """ Set the data limits for the axis ax. Args: lims: tuple(2) for (left, right), tuple(1) or scalar for left only. axname: "x" for x-axis, "y" for y-axis. Return: (left, right) """ left, right = None, None if lims is None: return left, right len_lims = None try: len_lims = len(lims) except TypeError: # Assume Scalar left = float(lims) if len_lims is not None: if len(lims) == 2: left, right = lims[0], lims[1] elif len(lims) == 1: left = lims[0] set_lim = getattr(ax, {"x": "set_xlim", "y": "set_ylim"}[axname]) if left != right: set_lim(left, right) return left, right
def set_ax_xylabels(ax, xlabel: str, ylabel: str, exchange_xy: bool = False) -> None: """ Set the x- and the y-label of axis ax, exchanging x and y if exchange_xy. """ if exchange_xy: xlabel, ylabel = ylabel, xlabel ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) def set_logscale(ax_or_axlist, xy_log) -> None: """ Activate logscale Args: ax_or_axlist: Axes or list of axes. xy_log: None or empty string for linear scale. "x" for log scale on x-axis. "xy" for log scale on x- and y-axis. "x:semilog" for semilog scale on x-axis. """ if not xy_log: return # Parse xy_log string. xy, log_type = xy_log, "log" if ":" in xy_log: xy, log_type = xy_log.split(":") ax_list = [ax_or_axlist] if not duck.is_listlike(ax_or_axlist) else ax_or_axlist for ix, ax in enumerate(ax_list): if "x" in xy: ax.set_xscale(log_type) if "y" in xy: ax.set_yscale(log_type) def set_ticks_fontsize(ax_or_axlist, fontsize: int, xy_string="xy", **kwargs) -> None: """ Set tick properties for one axis or a list of axis. Args: ax_or_axlist: Axes or list of axes. xy_string: "x" to share x-axis, "xy" for both. """ ax_list = [ax_or_axlist] if not duck.is_listlike(ax_or_axlist) else ax_or_axlist for ix, ax in enumerate(ax_list): if "x" in xy_string: ax.tick_params(axis='x', labelsize=fontsize, **kwargs) if "y" in xy_string: ax.tick_params(axis='y', labelsize=fontsize, **kwargs) def set_grid_legend(ax_or_axlist, fontsize: int, xlabel=None, ylabel=None, grid=True, legend=True, direction=None, title=None, legend_loc="best") -> None: """ Activate grid and legend for one axis or a list of axis. Args: grid: True to activate the grid. legend: True to activate the legend. direction: Use "x" ("y") if to add xlabel (ylabel) only to the last ax. title: Title string """ if duck.is_listlike(ax_or_axlist): for ix, ax in enumerate(ax_or_axlist): ax.grid(grid) # Check if there are artists with labels handles, labels = ax.get_legend_handles_labels() if legend and labels: # print("There are artists with labels:", labels) ax.legend(loc=legend_loc, fontsize=fontsize, shadow=True) if xlabel: doit = direction is None or (direction == "y" and ix == len(ax_or_axlist) -1) if doit: ax.set_xlabel(xlabel) if ylabel: doit = direction is None or (direction == "x" and ix == len(ax_or_axlist) -1) if doit: ax.set_ylabel(ylabel) if title: ax.set_title(title, fontsize=fontsize) else: ax = ax_or_axlist ax.grid(grid) # Check if there are artists with labels handles, labels = ax.get_legend_handles_labels() if legend and labels: ax.legend(loc=legend_loc, fontsize=fontsize, shadow=True) if xlabel: ax.set_xlabel(xlabel) if ylabel: ax.set_ylabel(ylabel) if title: ax.set_title(title, fontsize=fontsize) def set_visible(ax, boolean: bool, *args) -> None: """ Hide/Show the artists of axis ax listed in args. ax can be a single axis, a list or axis or a numpy arrays. """ if duck.is_listlike(ax): if isinstance(ax, np.ndarray): for _ in ax.ravel(): set_visible(_, *args) else: for _ in ax: set_visible(_, *args) return if "legend" in args and ax.legend(): ax.legend().set_visible(boolean) if "title" in args and ax.title: ax.title.set_visible(boolean) if "xlabel" in args and ax.xaxis.label: ax.xaxis.label.set_visible(boolean) if "ylabel" in args and ax.yaxis.label: ax.yaxis.label.set_visible(boolean) if "xticklabels" in args: for label in ax.get_xticklabels(): label.set_visible(boolean) if "yticklabels" in args: for label in ax.get_yticklabels(): label.set_visible(boolean) def rotate_ticklabels(ax, rotation: float, axname: str ="x") -> None: """Rotate the ticklables of axis ``ax``""" if "x" in axname: for tick in ax.get_xticklabels(): tick.set_rotation(rotation) if "y" in axname: for tick in ax.get_yticklabels(): tick.set_rotation(rotation) def hspan_ax_line(ax, line, abs_conv, hatch, alpha=0.2, with_label=True) -> None: """ Add hspan to ax showing the convergence region of width `abs_conv`. Use same color as line. Return immediately if abs_conv is None or x-values are strings. """ if abs_conv is None: return xs = line.get_xdata() ys = line.get_ydata() if duck.is_string(xs[0]): return color = line.get_color() span_style = dict(alpha=0.2, color=color, hatch=hatch) x_max = xs[-1] x_inds = np.where(xs == x_max)[0] # This to support the case in which we have multiple ys for the same x_max for i, ix in enumerate(x_inds): y_xmax = ys[ix] ax.axhspan(y_xmax - abs_conv, y_xmax + abs_conv, label=r"$|y-y(x_{max})| \leq %s$" % abs_conv if (with_label and i == 0) else None, **span_style) @add_fig_kwargs def plot_xy_with_hue(data: pd.DataFrame, x: str, y: str, hue: str, decimals=None, ax=None, xlims=None, ylims=None, fontsize=8, **kwargs) -> Figure: """ Plot y = f(x) relation for different values of `hue`. Useful for convergence tests done wrt two parameters. Args: data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`. x: Name of the column used as x-value y: Name of the column(s) used as y-value hue: Variable that define subsets of the data, which will be drawn on separate lines decimals: Number of decimal places to round `hue` columns. Ignore if None ax: |matplotlib-Axes| or None if a new figure should be created. xlims, ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)` or scalar e.g. `left`. If left (right) is None, default values are used fontsize: Legend fontsize. kwargs: Keyword arguments are passed to ax.plot method. Returns: |matplotlib-Figure| """ if isinstance(y, (list, tuple)): # Recursive call for each ax in ax_list. num_plots, ncols, nrows = len(y), 1, 1 if num_plots > 1: ncols = 2 nrows = (num_plots // ncols) + (num_plots % ncols) ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) ax_list = ax_list.ravel() if num_plots % ncols != 0: ax_list[-1].axis('off') for ykey, ax in zip(y, ax_list): plot_xy_with_hue(data, x, str(ykey), hue, decimals=decimals, ax=ax, xlims=xlims, ylims=ylims, fontsize=fontsize, show=False, **kwargs) return fig # Check here because pandas error messages are a bit criptic. miss = [k for k in (x, y, hue) if k not in data] if miss: raise ValueError("Cannot find `%s` in dataframe.\nAvailable keys are: %s" % (str(miss), str(data.keys()))) # Truncate values in hue column so that we can group. if decimals is not None: data = data.round({hue: decimals}) ax, fig, plt = get_ax_fig_plt(ax=ax) for key, grp in data.groupby(by=hue): # Sort xs and rearrange ys xy = np.array(sorted(zip(grp[x], grp[y]), key=lambda t: t[0])) xvals, yvals = xy[:, 0], xy[:, 1] label = "%s" % (str(key)) if not kwargs: ax.plot(xvals, yvals, 'o-', label=label) else: ax.plot(xvals, yvals, label=label, **kwargs) ax.grid(True) ax.set_xlabel(x) ax.set_ylabel(y) set_axlims(ax, xlims, "x") set_axlims(ax, ylims, "y") ax.legend(loc="best", fontsize=fontsize, shadow=True) return fig def linear_fit_ax(ax, xs, ys, fontsize, with_label=True, with_ideal_line=False, **kwargs) -> tuple[float]: """ Calculate a linear least-squares regression for two sets of measurements. kwargs are passed to ax.plot. """ from scipy.stats import linregress fit = linregress(xs, ys) label = r"Linear fit $\alpha={:.2f}$, $r^2$={:.2f}".format(fit.slope, fit.rvalue**2) if "color" not in kwargs: kwargs["color"] = "r" ax.plot(xs, fit.slope*xs + fit.intercept, label=label if with_label else None, **kwargs) if with_ideal_line: # Plot y = x line ax.plot([xs[0], xs[-1]], [ys[0], ys[-1]], color='k', linestyle='-', linewidth=1, label='Ideal' if with_label else None) return fit
[docs] @add_fig_kwargs def plot_array(array, color_map=None, cplx_mode="abs", **kwargs) -> Figure: """ Use imshow for plotting 2D or 1D arrays. Return: |matplotlib-Figure| Example:: plot_array(np.random.rand(10,10)) See <http://stackoverflow.com/questions/7229971/2d-grid-data-visualization-in-python> Args: array: Array-like object (1D or 2D). color_map: color map. cplx_mode: Flag defining how to handle complex arrays. Possible values in ("re", "im", "abs", "angle") "re" for the real part, "im" for the imaginary part. "abs" means that the absolute value of the complex number is shown. "angle" will display the phase of the complex number in radians. """ # Handle vectors array = np.atleast_2d(array) array = data_from_cplx_mode(cplx_mode, array) import matplotlib as mpl from matplotlib import pyplot as plt if color_map is None: # make a color map of fixed colors color_map = mpl.colors.LinearSegmentedColormap.from_list('my_colormap', ['blue', 'black', 'red'], 256) img = plt.imshow(array, interpolation='nearest', cmap=color_map, origin='lower') # Make a color bar plt.colorbar(img, cmap=color_map) # Set grid plt.grid(True, color='white') fig = plt.gcf() return fig
class ConvergenceAnalyzer: """ This object allows one to plot the convergence of an arbitrary list of quantities as a function of the same x. """ # Colors for the different convergence criteria. color_ilevel = ["red", "blue", "green"] # matplotlib option to fill convergence window. HATCH = "/" @classmethod def from_xy_label_vals(cls, xlabel, xs, ylabel, yvalues, tols) -> ConvergenceAnalyzer: """ Simplified interface to analyze a single list of values. """ yvals_dict = {ylabel: yvalues} ytols_dict = {ylabel: tols} return cls(xlabel, xs, yvals_dict, ytols_dict) @classmethod def from_file(cls, filepath: str, xkey: str, ytols_dict: dict, **kwargs) -> ConvergenceAnalyzer: """ High-level constructor to build the object from a file containing data that can be converted to pandas DataFrame. kwargs are passed to the pandas IO routines. Args: filepath: Filename. xkey: name of the x-variable. ytols_dict: dict mapping the name of the y-variable to absolute tolerance(s). """ df = dataframe_from_filepath(filepath, **kwargs) return cls.from_dataframe(df, xkey, ytols_dict) @classmethod def from_dataframe(cls, df: pd.DataFrame, xkey: str, ytols_dict: dict) -> ConvergenceAnalyzer: """ Build the object from a pandas dataframe. Args: df: DataFrame xkey: name of the x-variable. ytols_dict: dict mapping the name of the y-variable to tolerance(s). """ df = df.sort_values(xkey) xs = df[xkey].values yvals_dict = {k: df[k].values for k in ytols_dict} return cls(xkey, xs, yvals_dict, ytols_dict) def __init__(self, xkey: str, xs: VectorLike, yvals_dict: dict[str, VectorLike], ytols_dict: dict): """ Args: xkey: xs: yvals_dict: ytols_dict: dict mapping the name of the y-variable to absolute tolerance(s). Example:: plotter = ConvergencePlotter("ecut", ecut_value, yvals_dict, ytols_dict) plotter.plot() """ # Convert to numpy arrays and store data in self. self.xkey = self.xlabel = xkey self.xs = np.array(xs) if not np.all(self.xs[:-1] <= self.xs[1:]): raise ValueError("xs values should be in ascending order") self.yvals_dict = {k: np.array(v) for k, v in yvals_dict.items()} self.ykey2label = {k: k for k in yvals_dict} if len(self.yvals_dict) > len(self.color_ilevel): raise ValueError(f"Not programmed for more than {len(self.color_ilevel)} convergence levels") # Handle ytols_dict. self.ytols_dict = {} for ykey, ytols in ytols_dict.items(): if not duck.is_listlike(ytols): ytols = [ytols] if any(yt <= 0 for yt in ytols): raise ValueError(f"tolerances cannot be negative: {ytols}") # Sort input tolerances just to be on the safe side. self.ytols_dict[ykey] = np.sort(np.array(ytols))[::-1] # Compute the first index in xs that gives value within the convergence window. # -1 or None indicates that convergence has not been achieved. self.ykey_ixs = {} self.ykey_best_xs = {} for ykey, ys in self.yvals_dict.items(): if len(ys) != len(xs): raise ValueError(f"len(ys) != len(xs): {len(ys)} and {len(xs)}") tol_levels = self.ytols_dict[ykey] # Init values assuming no convergence achieved. self.ykey_ixs[ykey] = [-1] * len(tol_levels) self.ykey_best_xs[ykey] = [None] * len(tol_levels) # For each y-tolerance. num_x = len(self.xs) y_xmax = ys[-1] for il, ytol in enumerate(tol_levels): for _, xx in enumerate(self.xs[::-1]): ix = -_ + num_x - 1 if abs(y_xmax - ys[ix]) > ytol: self.ykey_ixs[ykey][il] = ix + 1 break ix = self.ykey_ixs[ykey][il] if ix != -1: # If converged, use linear interpolation to get a better estimate of the # converged xx. This is useful especially if the xs grid is too coarse. best_xx = self.xs[ix] if ix - 1 >= 0: x0, y0 = xs[ix-1], ys[ix-1] x1, y1 = xs[ix], ys[ix] alpha = (y1 - y0) / (x1 - x0) # y(x) = alpha * (x - x0) + y0 #print("best_xx 1", best_xx) if (y0 - y_xmax) >= 0: best_xx = x0 + ( ytol + y_xmax - y0) / alpha if (y0 - y_xmax) < 0: best_xx = x0 + (-ytol + y_xmax - y0) / alpha #print("best_xx 2", best_xx) self.ykey_best_xs[ykey][il] = best_xx # Here we change the x-y labels for the plots using an hard-coded mapping # in order to add additional info on units and normalization. auto_key_label = dict( ecut=r"$E_{cut}$ (Ha)", energy_per_atom=r"$E/N_{at}$ (eV)", pressure="P (GPa)", ) for key, label in auto_key_label.items(): self.set_label(key, label, ignore_exc=True) def set_label(self, key: str, label: str, ignore_exc=False) -> None: """ Set the label for `key` to be used in the plot. Dont't raise exception if `ignore_exc` is True. """ if key in self.ykey2label: self.ykey2label[key] = label elif key == self.xkey: self.xlabel = label else: if not ignore_exc: raise ValueError( f"key:`{key}` should be either in {list(self.ykey2label.keys())} or {self.xname}") def get_ylabel(self, ykey: str) -> str: """Return the ylabel to be used for `ykey` in the plot.""" return self.ykey2label[ykey] def ytol_ix_xx(self, ykey) -> Iterator[tuple]: """ Iterate over (ytols, ixs, and xs) for the given ``ykey`. """ return zip(self.ytols_dict[ykey], self.ykey_ixs[ykey] ,self.ykey_best_xs[ykey]) def get_dataframe_ykey(self, ykey: str) -> pd.DataFrame: """Return dataframe with convergence params for `ykey`.""" rows = [] for ytol, ix, xx in self.ytol_ix_xx(ykey): rows.append(dict(ytol=ytol, ix=ix, xx=xx, ykey=ykey)) #rows.append(dict(ytol=ytol, ix=ix, xx=xx), xx_best=xx_best, ykey=ykey) return pd.DataFrame(rows) def to_string(self, verbose=0) -> str: """ String representation with verbosity level `verbose`. """ lines = []; app = lines.append app(f"Number of points for x-axis: {len(self.xs)}") for ykey in self.yvals_dict: app("ykey: %s" % ykey) df = self.get_dataframe_ykey(ykey) app(str(df)) return "\n".join(lines) def __str__(self) -> str: return self.to_string() def _decorate_ax(self, ax, ykey, ys, yscale) -> None: """ Decorate axis ax by adding patches showing the convergence window and vertical lines where convergence is achieved. Args: ax: matplotlib axes. ykey: y-name ys: y-values yscale: "linear" or "log" """ # Precompute y-limits of the converge window for each tolerance. y_xmax = ys[-1] ytols = self.ytols_dict[ykey] ntols = len(ytols) ylims = np.empty((ntols, 2)) ylims_log = np.empty((ntols, 2)) for il, ytol in enumerate(ytols): # Absolute tolerance. y0, y1 = y_xmax - ytol, y_xmax + ytol y1_log = ytol ylims[il] = [y0, y1] ylims_log[il] = [0, y1_log] # Loop again as ylimits are known. for il, ytol in enumerate(ytols): label = r"$|y-y_\infty| \leq %s$" % ytol span_style = dict(alpha=0.2, color=self.color_ilevel[il], zorder=abs(ytol), hatch=self.HATCH) y0, y1 = ylims[il] y0_log, y1_log = ylims_log[il] if il == ntols - 1: if yscale == "linear": ax.axhspan(y0, y1, label=label, **span_style) elif yscale == "log": ax.axhspan(y0_log, y1_log, label=label, **span_style) else: raise ValueError(f"Invalid yscale: {yscale}") else: # Use limits of the next window to avoid overlapping patches. if yscale == "linear": ax.axhspan(y0, ylims[il+1,0], label=label, **span_style) ax.axhspan(ylims[il+1,1], y1, **span_style) elif yscale == "log": ax.axhspan(y0_log, ylims_log[il+1,0], label=label, **span_style) ax.axhspan(ylims_log[il+1,1], y1_log, **span_style) else: raise ValueError(f"Invalid yscale: {yscale}") # Add vertical line to show best_xx. best_xx = self.ykey_best_xs[ykey][il] line_style = dict(lw=1, color=self.color_ilevel[il], ls=":") if best_xx is not None: ax.axvline(best_xx, **line_style) @add_fig_kwargs def plot(self, ax_mat=None, fontsize=8, **kwargs) -> Figure: """ Plot convergence profile. A new grid is build if `ax_mat` is None: """ nrows, ncols = len(self.yvals_dict), 2 ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) # TODO #for icol in range(ncols): # ax_share("x", ax_mat[0,icol], ax_mat[1,icol]) for irow, ((ykey, ys), ax_row) in enumerate(zip(self.yvals_dict.items(), ax_mat)): # Plot y(x) ax1, ax2 = ax_row ax1.plot(self.xs, ys, marker="o", color="k") ax1.set_ylabel(self.get_ylabel(ykey)) self._decorate_ax(ax1, ykey, ys, "linear") # Plot |y(x) - y_xmax| on log scale. abs_diffs = np.abs(ys - ys[-1]) ax2.plot(self.xs, abs_diffs, marker="o", color="k") ax2.set_yscale("log") self._decorate_ax(ax2, ykey, ys, "log") ax2.set_xlim(self.xs[0] - 1, self.xs[-2] + 1) title = "" for i, (ytol, ix, xx) in enumerate(self.ytol_ix_xx(ykey)): pre_str = "" if i == 0 else ", " ytol_string = str(ytol) #print("ytol_string:", ytol_string, "pre_str:", pre_str, "ytol_string:", ytol_string, "xx:", xx) if xx is not None: s = r"x: %.1f for $\Delta$: %s" % (xx, ytol_string) else: s = r"x: ?? for $\Delta$: %s" % (ytol_string) title += pre_str + s ax2.set_title(title, fontsize=fontsize) ax2.set_ylabel(r"$|y-y(x_{max})|$", fontsize=fontsize) set_grid_legend(ax_row, fontsize, xlabel=self.xlabel if irow == (nrows - 1) else None, grid=False, legend=True) fig.tight_layout() return fig
[docs] class ArrayPlotter: def __init__(self, *labels_and_arrays): """ Args: labels_and_arrays: list [("label1", arr1), ("label2", arr2)] """ self._arr_dict = {} for label, array in labels_and_arrays: self.add_array(label, array) def __len__(self) -> int: return len(self._arr_dict) def __iter__(self): return self._arr_dict.__iter__()
[docs] def keys(self): return self._arr_dict.keys()
[docs] def items(self): return self._arr_dict.items()
[docs] def add_array(self, label: str, array) -> None: """Add array with the given name.""" if label in self._arr_dict: raise ValueError("%s is already in %s" % (label, list(self._arr_dict.keys()))) self._arr_dict[label] = array
[docs] def add_arrays(self, labels: list, arr_list: list) -> None: """ Add a list of arrays Args: labels: List of labels. arr_list: List of arrays. """ assert len(labels) == len(arr_list) for label, arr in zip(labels, arr_list): self.add_array(label, arr)
[docs] @add_fig_kwargs def plot(self, cplx_mode="abs", colormap="jet", fontsize=8, **kwargs) -> Figure: """ Args: cplx_mode: "abs" for absolute value, "re", "im", "angle" colormap: matplotlib colormap. fontsize: legend and label fontsize. """ # Build grid of plots. num_plots, ncols, nrows = len(self), 1, 1 if num_plots > 1: ncols = 2 nrows = num_plots // ncols + (num_plots % ncols) import matplotlib.pyplot as plt fig, ax_mat = plt.subplots(nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) # Don't show the last ax if num_plots is odd. if num_plots % ncols != 0: ax_mat[-1, -1].axis("off") from mpl_toolkits.axes_grid1 import make_axes_locatable from matplotlib.ticker import MultipleLocator for ax, (label, arr) in zip(ax_mat.flat, self.items()): data = data_from_cplx_mode(cplx_mode, arr) # Use origin to place the [0, 0] index of the array in the lower left corner of the axes. img = ax.matshow(data, interpolation='nearest', cmap=colormap, origin='lower', aspect="auto") ax.set_title("(%s) %s" % (cplx_mode, label), fontsize=fontsize) # Make a color bar for this ax # Create divider for existing axes instance # http://stackoverflow.com/questions/18266642/multiple-imshow-subplots-each-with-colorbar divider3 = make_axes_locatable(ax) # Append axes to the right of ax, with 10% width of ax cax3 = divider3.append_axes("right", size="10%", pad=0.05) # Create colorbar in the appended axes # Tick locations can be set with the kwarg `ticks` # and the format of the ticklabels with kwarg `format` cbar3 = plt.colorbar(img, cax=cax3, ticks=MultipleLocator(0.2), format="%.2f") # Remove xticks from ax ax.xaxis.set_visible(False) # Manually set ticklocations #ax.set_yticks([0.0, 2.5, 3.14, 4.0, 5.2, 7.0]) # Set grid ax.grid(True, color='white') fig.tight_layout() return fig
#TODO Rename it to ScatterData?
[docs] class Marker: """ Stores the position and the size of the marker. A marker is a list of tuple(x, y, s) where x, and y are the position in the graph and s is the size of the marker. Used for plotting purpose e.g. QP data, energy derivatives... Example:: x, y, s = [1, 2, 3], [4, 5, 6], [0.1, 0.2, -0.3] marker = Marker(x, y, s) """ def __init__(self, x, y, s, **scatter_kwargs): #marker: str = "o", color: str = "y", alpha: float = 1.0, label=None, self.edgecolors=None): self.x, self.y, self.s = np.array(x), np.array(y), np.array(s) if len(self.x) != len(self.y): raise ValueError("len(self.x) != len(self.y)") if len(self.y) != len(self.s): raise ValueError("len(self.y) != len(self.s)") #self.marker = marker #self.color = color #self.alpha = alpha #self.label = label #self.edgecolors = edgecolors self.scatter_kwargs = scatter_kwargs # Step 1: Normalize sizes to a suitable range for plotting #min_size = 10 # Minimum size for points #max_size = 100 # Maximum size for points #normalized_s = min_size + (max_size - min_size) * (self.s - np.min(self.s)) / (np.max(self.s) - np.min(self.s)) #self.s = normalized_s def __bool__(self): return bool(len(self.s)) __nonzero__ = __bool__
[docs] def posneg_marker(self) -> tuple[Marker, Marker]: """ Split data into two sets: the first one contains all the points with positive size. The first set contains all the points with negative size. """ pos_x, pos_y, pos_s = [], [], [] neg_x, neg_y, neg_s = [], [], [] for x, y, s in zip(self.x, self.y, self.s): if s >= 0.0: pos_x.append(x) pos_y.append(y) pos_s.append(s) else: neg_x.append(x) neg_y.append(y) neg_s.append(s) return self.__class__(pos_x, pos_y, pos_s), self.__class__(neg_x, neg_y, neg_s)
class Exposer: """ Base class for Exposer objects. Example: plot_kws = dict(show=False) with Exposer.as_exposer("panel") as e: e(obj.plot1(**plot_kws)) e(obj.plot2(**plot_kws)) """ @classmethod def as_exposer(cls, exposer, **kwargs) -> Exposer: """ Return an instance of Exposer, usually from a string with then name. Args: exposer: "mpl" for MplExposer, "panel" for PanelExposer. """ if isinstance(exposer, cls): return exposer # Assume string. exposer_cls = dict( mpl=MplExposer, panel=PanelExposer, )[exposer] return exposer_cls(**kwargs) def add_obj_with_yield_figs(self, obj: Any) -> None: """ Add an object implementing a `yield_figs` method to the Exposer. """ if not hasattr(obj, "yield_figs"): raise TypeError(f"object of type {type(obj)} does not implement `yield_figs` method") for fig in obj.yield_figs(): self.add_fig(fig) def __call__(self, obj: Any): """ Add an object to the Exposer Support mpl figure, list of figures or generator yielding figures. """ import types if isinstance(obj, (types.GeneratorType, list, tuple)): for fig in obj: self.add_fig(fig) else: self.add_fig(obj) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): """Activated at the end of the with statement. """ if exc_type is not None: return self.expose() class MplExposer(Exposer): # pragma: no cover """ Context manager used to produce several matplotlib figures and show all of them at once so that users do not have to close the window to visualize the next one. Example: plot_args = dict(show=False) with MplExposer() as e: e(obj.plot1(**plot_args)) e(obj.plot2(**plot_args)) """ def __init__(self, slide_mode=False, slide_timeout=None, verbose=1, **kwargs): """ Args: slide_mode: If True, iterate over figures. Default: Expose all figures at once. slide_timeout: Close figure after slide-timeout seconds. Block if None. verbose: verbosity level """ self.figures = [] self.slide_mode = bool(slide_mode) self.timeout_ms = slide_timeout self.verbose = verbose if self.timeout_ms is not None: self.timeout_ms = int(self.timeout_ms * 1000) assert self.timeout_ms >= 0 if self.verbose: if self.slide_mode: print("\nSliding matplotlib figures with slide timeout: %s [s]" % slide_timeout) else: print("\nLoading all matplotlib figures before showing them. It may take some time...") self.start_time = time.time() def add_fig(self, fig: Figure) -> None: """ Add a matplotlib figure. """ if fig is None: return if not self.slide_mode: self.figures.append(fig) else: import matplotlib.pyplot as plt if self.timeout_ms is not None: # Creating a timer object # timer calls plt.close after interval milliseconds to close the window. timer = fig.canvas.new_timer(interval=self.timeout_ms) timer.add_callback(plt.close, fig) timer.start() plt.show() if hasattr(fig, "clear"): fig.clear() def expose(self) -> None: """ Show all figures. Clear figures if needed. """ if not self.slide_mode: print("All figures in memory, elapsed time: %.3f s" % (time.time() - self.start_time)) import matplotlib.pyplot as plt plt.show() for fig in self.figures: if hasattr(fig, "clear"): fig.clear() class PanelExposer(Exposer): # pragma: no cover """ Context manager used to produce several matplotlib/plotly figures and show all of them inside the web browser using a panel template. Example: with PanelExposer() as e: e(obj.plot1(show=False)) e(obj.plot2(show=False)) """ def __init__(self, title=None, dpi=92, verbose=1, **kwargs): """ Args: title: String to be show in the header. verbose: verbosity level """ self.title = title self.figures = [] self.verbose = verbose self.dpi = int(dpi) if self.verbose: print("\nLoading all figures before showing them. It may take some time...") self.start_time = time.time() def add_fig(self, fig: Figure) -> None: """Add a matplotlib figure.""" if fig is None: return self.figures.append(fig) def expose(self): """Show all figures. Clear figures if needed.""" import panel as pn pn.config.sizing_mode = 'stretch_width' from abipy.panels.core import get_template_cls_from_name cls = get_template_cls_from_name("FastGridTemplate") template = cls( title=self.title if self.title is not None else self.__class__.__name__, header_background="#ff8c00 ", # Dark orange ) #pn.config.sizing_mode = 'stretch_width' from abipy.panels.core import mpl, ply for i, fig in enumerate(self.figures): row, col = divmod(i, 2) if is_plotly_figure(fig): p = ply(fig, with_divider=False) elif is_mpl_figure(fig): p = mpl(fig, with_divider=False, dpi=self.dpi) else: raise TypeError(f"Don't know how to handle type: `{type(fig)}`") if hasattr(template.main, "append"): template.main.append(p) else: # Assume .main area acts like a GridSpec row_slice = slice(3 * row, 3 * (row + 1)) if col == 0: template.main[row_slice, :6] = p if col == 1: template.main[row_slice, 6:] = p return template.show()
[docs] def plot_unit_cell(lattice, ax=None, **kwargs) -> tuple[Figure, Axes]: """ Adds the unit cell of the lattice to a matplotlib Axes3D Args: lattice: Lattice object ax: matplotlib :class:`Axes3D` or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black and linewidth to 3. Returns: matplotlib figure and ax """ ax, fig, plt = get_ax3d_fig_plt(ax) if "color" not in kwargs: kwargs["color"] = "k" if "linewidth" not in kwargs: kwargs["linewidth"] = 3 v = 8 * [None] v[0] = lattice.get_cartesian_coords([0.0, 0.0, 0.0]) v[1] = lattice.get_cartesian_coords([1.0, 0.0, 0.0]) v[2] = lattice.get_cartesian_coords([1.0, 1.0, 0.0]) v[3] = lattice.get_cartesian_coords([0.0, 1.0, 0.0]) v[4] = lattice.get_cartesian_coords([0.0, 1.0, 1.0]) v[5] = lattice.get_cartesian_coords([1.0, 1.0, 1.0]) v[6] = lattice.get_cartesian_coords([1.0, 0.0, 1.0]) v[7] = lattice.get_cartesian_coords([0.0, 0.0, 1.0]) for i, j in ((0, 1), (1, 2), (2, 3), (0, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 4), (0, 7), (1, 6), (2, 5), (3, 4)): ax.plot(*zip(v[i], v[j]), **kwargs) # Plot cartesian frame ax_add_cartesian_frame(ax) return fig, ax
def ax_add_cartesian_frame(ax, start=(0, 0, 0)) -> Axes: """ Add cartesian frame to 3d axis at point `start`. """ # https://stackoverflow.com/questions/22867620/putting-arrowheads-on-vectors-in-matplotlibs-3d-plot from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d arrow_opts = {"color": "k"} arrow_opts.update(dict(lw=1, arrowstyle="-|>",)) class Arrow3D(FancyArrowPatch): def __init__(self, xs, ys, zs, *args, **kwargs): FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) self._verts3d = xs, ys, zs def draw(self, renderer): xs3d, ys3d, zs3d = self._verts3d xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) FancyArrowPatch.draw(self, renderer) start = np.array(start) for end in ((1, 0, 0), (0, 1, 0), (0, 0, 1)): end = start + np.array(end) xs, ys, zs = list(zip(start, end)) p = Arrow3D(xs, ys, zs, connectionstyle='arc3', mutation_scale=20, alpha=0.8, **arrow_opts) ax.add_artist(p) return ax def plot_structure(structure, ax=None, to_unit_cell=False, alpha=0.7, style="points+labels", color_scheme="VESTA", **kwargs) -> Figure: """ Plot structure with matplotlib (minimalistic version). Args: structure: |Structure| object ax: matplotlib :class:`Axes3D` or None if a new figure should be created. alpha: The alpha blending value, between 0 (transparent) and 1 (opaque) to_unit_cell: True if sites should be wrapped into the first unit cell. style: "points+labels" to show atoms sites with labels. color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA") Returns: |matplotlib-Figure| """ fig, ax = plot_unit_cell(structure.lattice, ax=ax, linewidth=1) from pymatgen.analysis.molecule_structure_comparator import CovalentRadius from pymatgen.vis.structure_vtk import EL_COLORS xyzs, colors = np.empty((len(structure), 4)), [] for i, site in enumerate(structure): symbol = site.specie.symbol color = tuple(i / 255 for i in EL_COLORS[color_scheme][symbol]) radius = CovalentRadius.radius[symbol] if to_unit_cell and hasattr(site, "to_unit_cell"): site = site.to_unit_cell() # Use cartesian coordinates. x, y, z = site.coords xyzs[i] = (x, y, z, radius) colors.append(color) if "labels" in style: ax.text(x, y, z, symbol) # The definition of sizes is not optimal because matplotlib uses points # wherease we would like something that depends on the radius (5000 seems to give reasonable plots) # For possibile approaches, see # https://stackoverflow.com/questions/9081553/python-scatter-plot-size-and-style-of-the-marker/24567352#24567352 # https://gist.github.com/syrte/592a062c562cd2a98a83 if "points" in style: x, y, z, s = xyzs.T.copy() s = 5000 * s ** 2 ax.scatter(x, y, zs=z, s=s, c=colors, alpha=alpha) #facecolors="white", #edgecolors="blue" ax.set_title(structure.composition.formula) ax.set_axis_off() return fig def _generic_parser_fh(fh) -> dict: """ Parse file with data in tabular format. Supports multi datasets a la gnuplot. Mainly used for files without any schema, not even CSV Args: fh: File object Returns: dict title --> numpy array where title is taken from the first (non-empty) line preceding the dataset """ arr_list = [None] data = [] head_list = [] count = -1 last_header = None for l in fh: l = l.strip() if not l or l.startswith("#"): count = -1 last_header = l if arr_list[-1] is not None: arr_list.append(None) continue count += 1 if count == 0: head_list.append(last_header) if arr_list[-1] is None: arr_list[-1] = [] data = arr_list[-1] data.append(list(map(float, l.split()))) if len(head_list) != len(arr_list): raise RuntimeError("len(head_list) != len(arr_list), %d != %d" % (len(head_list), len(arr_list))) od = {} for key, data in zip(head_list, arr_list): key = " ".join(key.split()) if key in od: print("Header %s already in dictionary. Using new key %s" % (key, 2 * key)) key = 2 * key od[key] = np.array(data).T.copy() return od
[docs] class GenericDataFilePlotter: """ Extract data from a generic text file with results in tabular format and plot data with matplotlib. Multiple datasets are supported. No attempt is made to handle metadata (e.g. column name) Mainly used to handle text files written without any schema. """ def __init__(self, filepath: str): with open(filepath, "rt") as fh: self.od = _generic_parser_fh(fh) def __str__(self) -> str: return self.to_string()
[docs] def to_string(self, verbose: int = 0) -> str: """String representation with verbosity level `verbose`.""" lines = [] for key, arr in self.od.items(): lines.append("key: `%s` --> array shape: %s" % (key, str(arr.shape))) return "\n".join(lines)
[docs] @add_fig_kwargs def plot(self, use_index=False, fontsize=8, **kwargs) -> Figure: """ Plot all arrays. Use multiple axes if datasets. Args: use_index: By default, the x-values are taken from the first column. If use_index is False, the x-values are the row index. fontsize: fontsize for title. kwargs: options passed to ``ax.plot``. Return: |matplotlib-figure| """ # build grid of plots. num_plots, ncols, nrows = len(self.od), 1, 1 if num_plots > 1: ncols = 2 nrows = (num_plots // ncols) + (num_plots % ncols) ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) ax_list = ax_list.ravel() # Don't show the last ax if num_plots is odd. if num_plots % ncols != 0: ax_list[-1].axis("off") for ax, (key, arr) in zip(ax_list, self.od.items()): ax.set_title(key, fontsize=fontsize) ax.grid(True) xs = arr[0] if not use_index else list(range(len(arr[0]))) for ys in arr[1:] if not use_index else arr: ax.plot(xs, ys) return fig
[docs] class GenericDataFilesPlotter:
[docs] @classmethod def from_files(cls, filepaths: list[str]) -> GenericDataFilesPlotter: """ Build object from a list of `filenames`. """ new = cls() for filepath in filepaths: new.add_file(filepath) return new
def __init__(self): self.odlist = [] self.filepaths = [] def __str__(self) -> str: return self.to_string()
[docs] def to_string(self, verbose: int = 0) -> str: lines = [] app = lines.append for od, filepath in zip(self.odlist, self.filepaths): app("File: %s" % filepath) for key, arr in od.items(): lines.append("\tkey: `%s` --> array shape: %s" % (key, str(arr.shape))) return "\n".join(lines)
[docs] def add_file(self, filepath: str) -> None: """Add data from `filepath`""" with open(filepath, "rt") as fh: self.odlist.append(_generic_parser_fh(fh)) self.filepaths.append(filepath)
[docs] @add_fig_kwargs def plot(self, use_index=False, fontsize=8, colormap="viridis", **kwargs) -> Figure: """ Plot all arrays. Use multiple axes if datasets. Args: use_index: By default, the x-values are taken from the first column. If use_index is False, the x-values are the row index. fontsize: fontsize for title. colormap: matplotlib color map. kwargs: options passed to ``ax.plot``. Return: |matplotlib-figure| """ if not self.odlist: return None # Compute intersection of all keys. # Here we loose the initial ordering in the dict but oh well! klist = [list(d.keys()) for d in self.odlist] keys = set(klist[0]).intersection(*klist) if not keys: print("Warning: cannot find common keys in files. Check input data") return None # Build grid of plots. num_plots, ncols, nrows = len(keys), 1, 1 if num_plots > 1: ncols = 2 nrows = (num_plots // ncols) + (num_plots % ncols) ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) ax_list = ax_list.ravel() # Don't show the last ax if num_plots is odd. if num_plots % ncols != 0: ax_list[-1].axis("off") cmap = plt.get_cmap(colormap) line_cycle = itertools.cycle(["-", ":", "--", "-.",]) # One ax for key, each ax may show multiple arrays # so we need different line styles that are consistent with input data. # Figure may be crowded but it's difficult to do better without metadata # so I'm not gonna spend time to implement more complicated logic. for ax, key in zip(ax_list, keys): ax.set_title(key, fontsize=fontsize) ax.grid(True) for iod, (od, filepath) in enumerate(zip(self.odlist, self.filepaths)): if key not in od: continue arr = od[key] color = cmap(iod / len(self.odlist)) xvals = arr[0] if not use_index else list(range(len(arr[0]))) arr_list = arr[1:] if not use_index else arr for iarr, (ys, linestyle) in enumerate(zip(arr_list, line_cycle)): ax.plot(xvals, ys, color=color, linestyle=linestyle, label=os.path.relpath(filepath) if iarr == 0 else None) ax.legend(loc="best", fontsize=fontsize, shadow=True) return fig
########################## # Plotly helper functions ########################## _LATEX_GREEK_TO_UNICODE = dict( alpha="α", beta="β", gamma="ɣ", delta="δ", epsilon="ε", zeta="ζ", eta="η", theta="θ", iota="ι", kappa="κ", #lambda="λ", mu="μ", nu="ν", xi="ξ", omicron="ο", pi="π", rho="ρ", sigma="σ", tau="τ", upsilon="υ", phi="φ", chi="χ", psi="ψ", omega="ω", # Capital case: Alpha="Α", Beta="Β", Gamma="Γ", Delta="Δ", Epsilon="Ε", Zeta="Ζ", Eta="Η", Theta="Θ", Iota="Ι", Kappa="Κ", Lambda="Λ", Mu="Μ", Nu="Ν", Xi="Ξ", Omicron="Ο", Po="Π", Rho="Ρ", Sigma="Σ", Tau="Τ", Upsilon="Υ", Phi="Φ", Chi="Χ", Psi="Ψ", Omega="Ω", ) _LATEX_GREEK_TO_UNICODE["lambda"] = "λ" def latex_greek_2unicode(latex: str) -> str: """ Convert a single greek letter in latex notation into unicode """ s = latex.replace("$", "").replace("\\", "").strip() return _LATEX_GREEK_TO_UNICODE[s] def is_plotly_figure(obj: Any) -> bool: """Return True if obj is a plotly Figure.""" import plotly.graph_objs as go return isinstance(obj, go.Figure) #return isinstance(obj, (go.Figure, go.FigureWidget)) class PlotlyRowColDesc: """ This object specifies the position of a plotly subplot inside a grid. rcd: PlotlyRowColDesc object used when fig is not None to specify the (row, col) of the subplot in the grid. """ @classmethod def from_object(cls, obj: Any) -> PlotlyRowColDesc: """ Build an instance for a generic object. If oject is None, a simple descriptor corresponding to a (1,1) grid is returned. """ if obj is None: return cls(0, 0, 1, 1) if isinstance(obj, cls): return obj # Assume list with 4 integers try: return cls(*obj) except Exception as exc: raise TypeError(f"Dont know how to convert `{type(obj)}` into `{cls}`") def __init__(self, py_row: int, py_col: int, nrows: int, ncols: int): """ Args: py_row, py_col: python index of the subplot in the grid (starts from 0) nrows, ncols: Number of rows/cols in the grid. """ self.py_row, self.py_col = (py_row, py_col) self.nrows, self.ncols = (nrows, ncols) self.iax = 1 + self.py_col + self.py_row * self.ncols # Note that plotly col and row start from 1. if nrows == 1 and ncols == 1: self.ply_row, self.ply_col = (None, None) else: self.ply_row, self.ply_col = (self.py_row + 1, self.py_col + 1) def __str__(self) -> str: lines = [] app = lines.append app("py_rowcol: (%d, %d) in grid: (%d, %d)" % (self.py_row, self.py_col, self.nrows, self.ncols)) app("plotly_rowcol: (%s, %s)" % (self.ply_row, self.ply_col)) return "\n".join(lines) #@lazy_property #def rowcol_dict(self): # if self.nrows == 1 and self.ncols == 1: return {} # return dict(row=self.ply_row, col=self.ply_col)
[docs] def get_figs_plotly(nrows=1, ncols=1, subplot_titles=(), sharex=False, sharey=False, **fig_kw): """ Helper function used in plot functions that build the `plotly` figure by calling plotly.subplots. Returns: figure: plotly graph_objects figure go: plotly graph_objects module. """ from plotly.subplots import make_subplots import plotly.graph_objects as go fig = make_subplots(rows=int(nrows), cols=int(ncols), subplot_titles=subplot_titles, shared_xaxes=sharex, shared_yaxes=sharey, **fig_kw) return fig, go
[docs] def get_fig_plotly(fig=None, **fig_kw): """ Helper function used in plot functions that build the `plotly` figure by calling plotly.graph_objects.Figure if fig is None else return fig Returns: figure: plotly graph_objects figure go: plotly graph_objects module. """ import plotly.graph_objects as go if fig is None: fig = go.Figure(**fig_kw) #fig = go.FigureWidget(**fig_kw) return fig, go
def plotly_set_lims(fig, lims, axname, iax=None) -> tuple: """ Set the data limits for the axis ax. Args: fig: Plotly Figure. lims: tuple(2) for (left, right), if tuple(1) or scalar for left only, none is set. axname: "x" for x-axis, "y" for y-axis. iax: An int, use iax=n to decorate the nth axis when the fig has subplots. Return: (left, right) """ left, right = None, None if lims is None: return (left, right) # iax = kwargs.pop("iax", 1) # xaxis = 'xaxis%u' % iax #fig.layout[xaxis].title.text = "Wave Vector" axis = dict(x=fig.layout.xaxis, y=fig.layout.yaxis)[axname] len_lims = None try: len_lims = len(lims) except TypeError: # Assume Scalar left = float(lims) if len_lims is not None: if len(lims) == 2: left, right = lims[0], lims[1] elif len(lims) == 1: left = lims[0] ax_range = axis.range if ax_range is None and (left is None or right is None): return None, None #if left is not None: ax_range[0] = left #if right is not None: ax_range[1] = right # Example: fig.update_layout(yaxis_range=[-4,4]) k = dict(x="xaxis", y="yaxis")[axname] if iax: k= k + str(iax) fig.layout[k].range = [left, right] return left, right _PLOTLY_DEFAULT_SHOW = [True] def set_plotly_default_show(true_or_false: bool) -> None: """ Set the default value of show in the add_plotly_fig_kwargs decorator. Useful for instance when generating the sphinx gallery of plotly plots. """ _PLOTLY_DEFAULT_SHOW[0] = true_or_false
[docs] def add_plotly_fig_kwargs(func: Callable) -> Callable: """ Decorator that adds keyword arguments for functions returning plotly figures. The function should return either a plotly figure or None to signal some sort of error/unexpected event. See doc string below for the list of supported options. """ @functools.wraps(func) def wrapper(*args, **kwargs): # pop the kwds used by the decorator. title = kwargs.pop("title", None) show = kwargs.pop("show", _PLOTLY_DEFAULT_SHOW[0]) hovermode = kwargs.pop("hovermode", False) savefig = kwargs.pop("savefig", None) write_json = kwargs.pop("write_json", None) config = kwargs.pop("config", None) renderer = kwargs.pop("renderer", None) chart_studio = kwargs.pop("chart_studio", False) template = kwargs.pop("template", None) # Allow users to specify the renderer via shell env. if renderer is not None and os.getenv("PLOTLY_RENDERER", default=None) is not None: renderer = None # Call func and return immediately if None is returned. fig = func(*args, **kwargs) if fig is None: return fig # Operate on plotly figure. if title is not None: fig.update_layout(title_text=title, title_x=0.5) if template is not None: fig.update_layout(template=template) if savefig: # https://plotly.github.io/plotly.py-docs/generated/plotly.io.write_image.html if savefig.endswith("html"): from plotly.offline import plot as show_plotly show_plotly(fig, include_mathjax="cdn", filename=savefig, auto_open=False) else: try: import kaleido except ImportError: kaleido = False if kaleido is None: raise ValueError( "kaleido package required to save static ploty images\n" "please install it using:\npip install kaleido" ) fig.write_image(savefig, engine="kaleido", scale=5, width=750, height=750) #fig.write_image(savefig) if write_json: import plotly.io as pio pio.write_json(fig, write_json) fig.layout.hovermode = hovermode if show: # and _PLOTLY_DEFAULT_SHOW: my_config = dict( responsive=True, #showEditInChartStudio=True, showLink=True, plotlyServerURL="https://chart-studio.plotly.com", ) if config is not None: my_config.update(config) #add_template_buttons(fig) fig.show(renderer=renderer, config=my_config) if chart_studio: push_to_chart_studio(fig) return fig # Add docstring to the decorated method. doc_str = """\n\n Keyword arguments controlling the display of the figure: ================ ==================================================================== kwargs Meaning ================ ==================================================================== title Title of the plot (Default: None). show True to show the figure (default: True). hovermode True to show the hover info (default: False) savefig "abc.png" , "abc.jpeg" or "abc.webp" to save the figure to a file. write_json Write plotly figure to `write_json` JSON file. Inside jupyter-lab, one can right-click the `write_json` file from the file menu and open with "Plotly Editor". Make some changes to the figure, then use the file menu to save the customized plotly plot. Requires `jupyter labextension install jupyterlab-chart-editor`. See https://github.com/plotly/jupyterlab-chart-editor renderer (str or None (default None)) – A string containing the names of one or more registered renderers (separated by ‘+’ characters) or None. If None, then the default renderers specified in plotly.io.renderers.default are used. See https://plotly.com/python-api-reference/generated/plotly.graph_objects.Figure.html config (dict) A dict of parameters to configure the figure. The defaults are set in plotly.js. chart_studio True to push figure to chart_studio server. Requires authenticatios. Default: False. template Plotly template. See https://plotly.com/python/templates/ ["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none"] Default is None that is the default template is used. ================ ==================================================================== """ if wrapper.__doc__ is not None: # Add s at the end of the docstring. wrapper.__doc__ += f"\n{doc_str}" else: # Use s wrapper.__doc__ = doc_str return wrapper
def plotlyfigs_to_browser(figs, filename=None, browser=None): """ Save a list of plotly figures in an HTML file and open it the browser. Useful to display multiple figures generated by different AbiPy methods without having to construct a plotly subplot grid. Args: figs: List of plotly figures. filename: File name to save in. Use temporary filename if filename is None. browser: Open webpage in ``browser``. Use $BROWSER if None. Example: fig1 = plotter.combiplotly(renderer="browser", title="foo", show=False) fig2 = plotter.combiplotly(renderer="browser", title="bar", show=False) from abipy.tools.plotting import plotlyfigs_to_browser plotlyfigs_to_browser([fig1, fig2]) Return: path to HTML file. """ if filename is None: import tempfile fd, filename = tempfile.mkstemp(text=True, suffix=".html") if not isinstance(figs, (list, tuple)): figs = [figs] # Based on https://stackoverflow.com/questions/46821554/multiple-plotly-plots-on-1-page-without-subplot with open(filename, "wt") as fp: for i, fig in enumerate(figs): first = True if i == 0 else False fig.write_html(fp, include_plotlyjs=first, include_mathjax="cdn" if first else False) import webbrowser print("Opening HTML file:", filename) webbrowser.get(browser).open_new_tab("file://" + filename) return filename def plotly_klabels(labels: list, allow_dupes=False) -> list: """ This helper function polish a list of k-points labels before calling plotly by: - Checking if we have two equivalent consecutive labels (only the first one is shown and the second one is set to "") - Replacing particular Latex tokens with unicode as plotly support for Latex is far from optimal. Return: New list labels, same length as input labels. """ new_labels = labels.copy() if not allow_dupes: # Don't show label if previous k-point is the same. for il in range(1, len(new_labels)): if new_labels[il] == new_labels[il - 1]: new_labels[il] = "" replace = { r"$\Gamma$": "Γ", } for il in range(len(new_labels)): if new_labels[il] in replace: new_labels[il] = replace[new_labels[il]] return new_labels def plotly_set_xylabels(fig, xlabel, ylabel, exchange_xy): """ Set the x- and the y-label of axis ax, exchanging x and y if exchange_xy """ if exchange_xy: xlabel, ylabel = ylabel, xlabel fig.layout.xaxis.title.text = xlabel fig.layout.yaxis.title.text = ylabel _PLOTLY_AUTHEHTICATED = False def plotly_chartstudio_authenticate(): """ Authenticate the user on the chart studio portal by reading `PLOTLY_USERNAME` and `PLOTLY_API_KEY` from the pymatgen configuration file located in $HOME/.pmgrc.yaml. PLOTLY_USERNAME: johndoe PLOTLY_API_KEY: XXXXXXXXXXXXXXXXXXXX """ global _PLOTLY_AUTHEHTICATED if _PLOTLY_AUTHEHTICATED: return try: from pymatgen.core import SETTINGS #from pymatgen.settings import SETTINGS except ImportError: from pymatgen import SETTINGS example = """ Add it to $HOME/.pmgrc.yaml using the follow syntax: PLOTLY_USERNAME: john_doe PLOTLY_API_KEY: secret # to get your api_key go to profile > settings > regenerate key """ username = SETTINGS.get("PLOTLY_USERNAME") if username is None: raise RuntimeError(f"Cannot find PLOTLY_USERNAME in pymatgen settings.\n{example}") api_key = SETTINGS.get("PLOTLY_API_KEY") if api_key is None: raise RuntimeError(f"Cannot find PLOTLY_API_KEY in pymatgen settings.\n{example}") import chart_studio # https://towardsdatascience.com/how-to-create-a-plotly-visualization-and-embed-it-on-websites-517c1a78568b chart_studio.tools.set_credentials_file(username=username, api_key=api_key) _PLOTLY_AUTHEHTICATED = True def push_to_chart_studio(figs) -> None: """ Push a plotly figure or a list of figures to the chart studio cloud. """ plotly_chartstudio_authenticate() import chart_studio.plotly as py if not isinstance(figs, (list, tuple)): figs = [figs] for fig in figs: py.plot(fig, auto_open=True) #################################################### # This code is shamelessy taken from Adam's package #################################################### def go_points(points, size=4, color="black", labels=None, **kwargs): #textposition = 'top right', #textfont = dict(color='#E58606'), mode = "markers" if labels is None else "markers+text" #text = labels if labels is not None: labels = plotly_klabels(labels, allow_dupes=True) import plotly.graph_objects as go return go.Scatter3d( x=[v[0] for v in points], y=[v[1] for v in points], z=[v[2] for v in points], marker=dict(size=size, color=color), mode=mode, text=labels, **kwargs ) def _add_if_not_in(d, key, value): if key not in d: d[key] = value def go_line(v1, v2, color="black", width=2, mode="lines", **kwargs): _add_if_not_in(kwargs, "line_color", "black") _add_if_not_in(kwargs, "line_width", 2) import plotly.graph_objects as go return go.Scatter3d( mode=mode, x=[v1[0], v2[0]], y=[v1[1], v2[1]], z=[v1[2], v2[2]], #line=dict(color=color, width=width), **kwargs ) def go_lines(V, name=None, color="black", width=2, **kwargs): import plotly.graph_objects as go gen = ((v1, v2) for (v1, v2) in V) v1, v2 = next(gen) out = [ go_line(v1, v2, width=width, color=color, name=name, legendgroup=name, **kwargs) ] out.extend( go_line( v1, v2, width=width, color=color, showlegend=False, legendgroup=name, **kwargs ) for (v1, v2) in gen ) return out def vectors(lattice, name=None, color="black", width=4, **kwargs): gen = zip(lattice, ["a", "b", "c"]) v, label = next(gen) out = [ go_line( [0, 0, 0], v, text=["", label], width=width, color=color, name=name, legendgroup=name, mode="lines+text", **kwargs ) ] out.extend( go_line( [0, 0, 0], v, text=["", label], width=width, color=color, showlegend=False, legendgroup=name, mode="lines+text", **kwargs ) for (v, label) in gen ) return out def get_vectors(lattice_mat, name=None, color="black", width=2, **kwargs): return go_lines([[[0, 0, 0], v] for v in lattice_mat], **kwargs) def get_box(lattice_mat, **kwargs): a, b, c = lattice_mat segments = [ [[0, 0, 0], a], [[0, 0, 0], b], [[0, 0, 0], c], [a, a + b], [a, a + c], [b, b + a], [b, b + c], [c, c + a], [c, c + b], [a + b, a + b + c], [a + c, a + b + c], [b + c, a + b + c], ] return go_lines(segments, **kwargs) def plot_fcc_conv(): fcc_conv = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) fcc_vectors = vectors( fcc_conv, name="conv lattice vectors", color="darkblue", width=6 ) fcc_box = get_box(fcc_conv, name="conv lattice") atoms = go_points( [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], size=10, color="orange", name="atoms", legendgroup="atoms", ) import plotly.graph_objects as go fig = go.Figure(data=[*fcc_box, *fcc_vectors, atoms]) return fig def plot_fcc_prim(): fcc_prim = np.array([[0.5, 0.5, 0], [0, 0.5, 0.5], [0.5, 0, 0.5]]) fcc_prim_vectors = vectors( fcc_prim, name="prim lattice vectors", color="green", width=6 ) fcc_prim_box = get_box(fcc_prim, name="prim lattice", color="green") atoms = go_points( [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], size=10, color="orange", name="atoms", legendgroup="atoms", ) fcc_conv = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) fcc_conv_box = get_box(fcc_conv, name="conv lattice") import plotly.graph_objects as go fig = go.Figure(data=[*fcc_prim_box, *fcc_prim_vectors, *fcc_conv_box, atoms]) return fig def plot_fcc_100(): # fcc_100_cell = np.array([[0, 0.5, -0.5], [0, 0.5, 0.5], [1.0, 0.0, 0]]) fcc_100_cell = np.array([[0.5, -0.5, 0], [0.5, 0.5, 0], [0.0, 0, 1.0]]) fcc_100_vectors = vectors( fcc_100_cell, name="100 lattice vectors", color="red", width=6 ) fcc_100_box = get_box(fcc_100_cell, name="100 lattice", color="red") fig = plot_fcc_conv() fig.add_traces([*fcc_100_box, *fcc_100_vectors]) return fig def plot_fcc_110(): fcc_110_cell = np.array([[0, 0.0, 1.0], [0.5, -0.5, 0], [0.5, 0.5, 0.0]]) fcc_110_vectors = vectors( fcc_110_cell, name="reduced lattice vectors", color="red", width=6 ) fcc_110_box = get_box(fcc_110_cell, name="reduced lattice", color="red") fig = plot_fcc_conv() fig.add_traces([*fcc_110_box, *fcc_110_vectors]) return fig def plot_fcc_111(): fcc_111_cell = np.array([[0.5, 0, -0.5], [0, 0.5, -0.5], [1, 1, 1]]) fcc_111_vectors = vectors( fcc_111_cell, name="reduced lattice vectors", color="red", width=6 ) fcc_111_box = get_box(fcc_111_cell, name="reduced lattice", color="red") fig = plot_fcc_conv() fig.add_traces([*fcc_111_box, *fcc_111_vectors]) return fig def plotly_structure(structure, ax=None, to_unit_cell=False, alpha=0.7, style="points+labels", color_scheme="VESTA", **kwargs): """ Plot structure with plotly (minimalistic version). Args: structure: |Structure| object ax: matplotlib :class:`Axes3D` or None if a new figure should be created. alpha: The alpha blending value, between 0 (transparent) and 1 (opaque) to_unit_cell: True if sites should be wrapped into the first unit cell. style: "points+labels" to show atoms sites with labels. color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA") Returns: |matplotlib-Figure| """ #fig, ax = plot_unit_cell(structure.lattice, ax=ax, linewidth=1) box = get_box(structure.lattice.matrix) #, **kwargs): from pymatgen.analysis.molecule_structure_comparator import CovalentRadius from pymatgen.vis.structure_vtk import EL_COLORS #symb2data = {} #for symbol in structure.symbol_set: # symb2data[symbol] = d = {} # d["color"] = color = tuple(i / 255 for i in EL_COLORS[color_scheme][symbol]) # d["radius"] = CovalentRadius.radius[symbol] # inds = structure.indices_from_symbol(symbol) # sites = [structure[i] for i in inds] # d["xyz"] = [] # for site in sites: # if to_unit_cell and hasattr(site, "to_unit_cell"): site = site.to_unit_cell() # Use cartesian coordinates. # x, y, z = site.coords # d["xyz"].append((x, y ,z) xyz, sizes, colors = np.empty((len(structure), 3)), [], [] for i, site in enumerate(structure): symbol = site.specie.symbol color = tuple(i / 255 for i in EL_COLORS[color_scheme][symbol]) radius = CovalentRadius.radius[symbol] if to_unit_cell and hasattr(site, "to_unit_cell"): site = site.to_unit_cell() # Use cartesian coordinates. x, y, z = site.coords xyz[i] = (x, y, z) # , radius) sizes.append(radius) colors.append(color) #if "labels" in style: # ax.text(x, y, z, symbol) atoms = go_points( #[[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], xyz, size=10, color="orange", name="atoms", legendgroup="atoms", ) #marker = [dict(size=size, color=color) for (size, color) in zip(sizes, colors)] #atoms = go.Scatter3d( # x=[v[0] for v in xyz], # y=[v[1] for v in xyz], # z=[v[2] for v in xyz], # #marker=dict(size=size, color=color), # marker=marker, # mode="markers", # #**kwargs #) # The definition of sizes is not optimal because matplotlib uses points # whereas we would like something that depends on the radius (5000 seems to give reasonable plots) # For possibile approaches, see # https://stackoverflow.com/questions/9081553/python-scatter-plot-size-and-style-of-the-marker/24567352#24567352 # https://gist.github.com/syrte/592a062c562cd2a98a83 #if "points" in style: # x, y, z, s = xyzs.T.copy() # s = 5000 * s ** 2 # ax.scatter(x, y, zs=z, s=s, c=colors, alpha=alpha) #facecolors="white", #edgecolors="blue" #ax.set_title(structure.composition.formula) #ax.set_axis_off() #fig = go.Figure(data=[*box, *vectors, atoms]) import plotly.graph_objects as go fig = go.Figure(data=[*box, atoms]) return fig # This is the matplotlib API to plot the BZ. def plotly_wigner_seitz(lattice, fig=None, **kwargs): """ Adds the skeleton of the Wigner-Seitz cell of the lattice to a plotly figure. Args: lattice: Lattice object fig: plotly figure or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black and linewidth to 1. Returns: Plotly figure """ #ax, fig, plt = get_ax3d_fig_plt(ax) fig, go = get_fig_plotly(fig=fig) #, **fig_kw) if "line_color" not in kwargs: kwargs["line_color"] = "black" if "line_width" not in kwargs: kwargs["line_width"] = 1 bz = lattice.get_wigner_seitz_cell() #ax, fig, plt = get_ax3d_fig_plt(ax) for iface in range(len(bz)): # pylint: disable=C0200 for line in itertools.combinations(bz[iface], 2): for jface in range(len(bz)): if (iface < jface and any(np.all(line[0] == x) for x in bz[jface]) and any(np.all(line[1] == x) for x in bz[jface])): #ax.plot(*zip(line[0], line[1]), **kwargs) fig.add_trace(go_line(line[0], line[1], showlegend=False, **kwargs)) return fig def plotly_lattice_vectors(lattice, fig=None, **kwargs): """ Adds the basis vectors of the lattice provided to a plotly figure. Args: lattice: Lattice object fig: plotly figure or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to green and linewidth to 3. Returns: plotly figure """ fig, go = get_fig_plotly(fig=fig) if "line_color" not in kwargs: kwargs["line_color"] = "green" if "line_width" not in kwargs: kwargs["line_width"] = 3 if "showlegend" not in kwargs: kwargs["showlegend"] = False vertex1 = lattice.get_cartesian_coords([0.0, 0.0, 0.0]) vertex2 = lattice.get_cartesian_coords([1.0, 0.0, 0.0]) fig.add_trace(go_line(vertex1, vertex2, name="a", **kwargs)) vertex2 = lattice.get_cartesian_coords([0.0, 1.0, 0.0]) fig.add_trace(go_line(vertex1, vertex2, name="b", **kwargs)) vertex2 = lattice.get_cartesian_coords([0.0, 0.0, 1.0]) fig.add_trace(go_line(vertex1, vertex2, name="c", **kwargs)) return fig def plotly_path(line, lattice=None, coords_are_cartesian=False, fig=None, **kwargs): """ Adds a line passing through the coordinates listed in 'line' to a plotly figure. Args: line: list of coordinates. lattice: Lattice object used to convert from reciprocal to cartesian coordinates coords_are_cartesian: Set to True if you are providing coordinates in cartesian coordinates. Defaults to False. Requires lattice if False. fig: plotly figure or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to red and linewidth to 3. Returns: plotly figure """ fig, go = get_fig_plotly(fig=fig) if "line_color" not in kwargs: kwargs["line_color"] = "red" if "line_width" not in kwargs: kwargs["line_width"] = 3 for k in range(1, len(line)): vertex1 = line[k - 1] vertex2 = line[k] if not coords_are_cartesian: if lattice is None: raise ValueError("coords_are_cartesian False requires the lattice") vertex1 = lattice.get_cartesian_coords(vertex1) vertex2 = lattice.get_cartesian_coords(vertex2) fig.add_trace(go_line(vertex1, vertex2, showlegend=False, **kwargs)) return fig #def plotly_labels(labels, lattice=None, coords_are_cartesian=False, ax=None, **kwargs): # """ # Adds labels to a matplotlib Axes # # Args: # labels: dict containing the label as a key and the coordinates as value. # lattice: Lattice object used to convert from reciprocal to cartesian coordinates # coords_are_cartesian: Set to True if you are providing. # coordinates in cartesian coordinates. Defaults to False. # Requires lattice if False. # ax: matplotlib :class:`Axes` or None if a new figure should be created. # kwargs: kwargs passed to the matplotlib function 'text'. Color defaults to blue # and size to 25. # # Returns: # matplotlib figure and matplotlib ax # """ # ax, fig, plt = get_ax3d_fig_plt(ax) # # if "color" not in kwargs: # kwargs["color"] = "b" # if "size" not in kwargs: # kwargs["size"] = 25 # # for k, coords in labels.items(): # label = k # if k.startswith("\\") or k.find("_") != -1: # label = "$" + k + "$" # off = 0.01 # if coords_are_cartesian: # coords = np.array(coords) # else: # if lattice is None: # raise ValueError("coords_are_cartesian False requires the lattice") # coords = lattice.get_cartesian_coords(coords) # ax.text(*(coords + off), s=label, **kwargs) # # return fig, ax def plotly_points(points, lattice=None, coords_are_cartesian=False, fold=False, labels=None, fig=None, **kwargs): """ Adds points to a plotly figure. Args: points: list of coordinates lattice: Lattice object used to convert from reciprocal to cartesian coordinates coords_are_cartesian: Set to True if you are providing coordinates in cartesian coordinates. Defaults to False. Requires lattice if False. fold: whether the points should be folded inside the first Brillouin Zone. Defaults to False. Requires lattice if True. fig: plotly figure or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'scatter'. Color defaults to blue Returns: plotly figure """ fig, go = get_fig_plotly(fig=fig) #, **fig_kw) if "marker_color" not in kwargs: kwargs["marker_color"] = "blue" if (not coords_are_cartesian or fold) and lattice is None: raise ValueError("coords_are_cartesian False or fold True require the lattice") from pymatgen.electronic_structure.plotter import fold_point vecs = [] for p in points: if fold: p = fold_point(p, lattice, coords_are_cartesian=coords_are_cartesian) elif not coords_are_cartesian: p = lattice.get_cartesian_coords(p) vecs.append(p) kws = dict(textposition="top right", showlegend=False) #, textfont=dict(color='#E58606')) kws.update(kwargs) fig.add_trace(go_points(vecs, labels=labels, **kws)) return fig @add_plotly_fig_kwargs def plotly_brillouin_zone_from_kpath(kpath, fig=None, **kwargs): """ Gives the plot (as a matplotlib object) of the symmetry line path in the Brillouin Zone. Args: kpath (HighSymmKpath): a HighSymmKPath object ax: matplotlib :class:`Axes` or None if a new figure should be created. **kwargs: provided by add_fig_kwargs decorator Returns: plotly figure. """ lines = [[kpath.kpath["kpoints"][k] for k in p] for p in kpath.kpath["path"]] return plotly_brillouin_zone( bz_lattice=kpath.prim_rec, lines=lines, fig=fig, labels=kpath.kpath["kpoints"], show=False, **kwargs, ) @add_plotly_fig_kwargs def plotly_brillouin_zone( bz_lattice, lines=None, labels=None, kpoints=None, fold=False, coords_are_cartesian=False, fig=None, **kwargs, ): """ Plots a 3D representation of the Brillouin zone of the structure. Can add to the plot paths, labels and kpoints Args: bz_lattice: Lattice object of the Brillouin zone lines: list of lists of coordinates. Each list represent a different path labels: dict containing the label as a key and the coordinates as value. kpoints: list of coordinates fold: whether the points should be folded inside the first Brillouin Zone. Defaults to False. Requires lattice if True. coords_are_cartesian: Set to True if you are providing coordinates in cartesian coordinates. Defaults to False. ax: matplotlib :class:`Axes` or None if a new figure should be created. kwargs: provided by add_fig_kwargs decorator Returns: plotly figure """ fig = plotly_lattice_vectors(bz_lattice, fig=fig) plotly_wigner_seitz(bz_lattice, fig=fig) if lines is not None: for line in lines: plotly_path(line, bz_lattice, coords_are_cartesian=coords_are_cartesian, fig=fig) if labels is not None: # TODO #plotly_labels(labels, bz_lattice, coords_are_cartesian=coords_are_cartesian, ax=ax) plotly_points( labels.values(), lattice=bz_lattice, coords_are_cartesian=coords_are_cartesian, fold=False, labels=list(labels.keys()), fig=fig, ) if kpoints is not None: plotly_points( kpoints, lattice=bz_lattice, coords_are_cartesian=coords_are_cartesian, fold=fold, fig=fig, ) return fig def add_colorscale_dropwdowns(fig): """ Add dropdown widgets to change/reverse the colorscale. Based on: https://plotly.com/python/dropdowns/#update-several-data-attributes """ button_layer_1_height = 1.30 # Create list of buttons # A single button has the form: # # dict( # args=["colorscale", "Viridis"], # label="Viridis", # method="restyle" # ), colorscales = ["Viridis", "Cividis", "Blues", "Greens"] colorscale_buttons = [] for cscale in colorscales: colorscale_buttons.append(dict( args=["colorscale", cscale], label=cscale, method="restyle", )) fig.update_layout( updatemenus=[ dict( buttons=colorscale_buttons, direction="down", pad={"r": 10, "t": 10}, showactive=True, x=0.1, xanchor="left", y=button_layer_1_height, yanchor="top" ), dict( buttons=list([ dict( args=["reversescale", False], label="False", method="restyle" ), dict( args=["reversescale", True], label="True", method="restyle" ) ]), direction="down", pad={"r": 10, "t": 10}, showactive=True, x=0.37, xanchor="left", y=button_layer_1_height, yanchor="top" ), ] ) y = button_layer_1_height - 0.02 fig.update_layout( annotations=[ dict(text="colorscale", x=0, xref="paper", y=y, yref="paper", align="left", showarrow=False), dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=y, yref="paper", showarrow=False), ]) return fig def mpl_to_ply(fig: Figure, latex: bool= False): """ Nasty workaround for plotly latex rendering in legend/breaking exception """ if is_plotly_figure(fig): return fig def parse_latex(label): # Remove latex symobols new_label = label.replace("$", "") new_label = new_label.replace("\\", "") if not latex else new_label new_label = new_label.replace("{", "") if not latex else new_label new_label = new_label.replace("}", "") if not latex else new_label # plotly latex needs an extra \ for parsing python strings # new_label = new_label.replace(" ", "\\ ") if latex else new_label # Wrap the label in dollar signs for LaTeX, if needed unless empty`` new_label = f"${new_label}$" if latex and len(new_label) > 0 else new_label return new_label for ax in fig.get_axes(): # TODO improve below logic to add new scatter plots? # Loop backwards through the collections to avoid modifying the list as we iterate for coll in ax.collections[::-1]: if isinstance(coll, mcoll.PathCollection): # Use the remove() method to remove the scatter plot collection from the axes coll.remove() # Process the axis title, x-label, and y-label for label in [ax.get_title(), ax.get_xlabel(), ax.get_ylabel()]: # Few differences in how mpl and ply parse/encode symbols new_label = parse_latex(label) # Set the new label if label == ax.get_title(): ax.set_title(new_label) elif label == ax.get_xlabel(): ax.set_xlabel(new_label) elif label == ax.get_ylabel(): ax.set_ylabel(new_label) # Check if the axis has a legend if ax.get_legend(): legend = ax.get_legend() # Get the legend's text entries for text in legend.get_texts(): label = text.get_text() # Remove any existing dollar signs new_label = parse_latex(label) # Set the new label text.set_text(new_label) # Convert to plotly figure from plotly.tools import mpl_to_plotly plotly_fig = mpl_to_plotly(fig) plotly_fig.update_layout(template = "plotly_white", title = { "xanchor": "center", "yanchor": "top", "x": 0.5, "font": { "size": 14 }, }) # Iterate over the axes in the figure to retrieve the custom line attributes for ax in fig.get_axes(): if hasattr(ax, '_custom_rc_lines'): for rc, color in ax._custom_rc_lines: # Add vertical lines to the Plotly figure plotly_fig.add_vline( x=rc, line_width=2, line_dash="dash", line_color=color ) # # Loop through each trace and update the hover labels to remove $ for trace in plotly_fig.data: # Retrieve the current label and remove any $ signs new_label = trace.name.replace("$", "") # Update the trace's name (which is used for the legend label) trace.name = new_label return plotly_fig class PolyfitPlotter: """ Fit data with polynomals of different degrees and visualize the results. """ def __init__(self, xs, ys): self.xs, self.ys = np.array(xs), np.array(ys) @add_fig_kwargs def plot(self, deg_list: list[int], num=100, ax=None, xlabel=None, ylabel=None, fontsize=8, **kwargs) -> Figure: """ Args: deg_list: List with degrees of the fitting polynomial. num: Number of samples to generate. Default is 100. Must be non-negative. ax: |matplotlib-Axes| or None if a new figure should be created. fontsize: Legend fontsize. """ xs, ys = self.xs, self.ys ax, fig, plt = get_ax_fig_plt(ax=ax) for i, deg in enumerate(deg_list): # Fit a ndeg polynomial to the data points and get the polynomial function. coefficients = np.polyfit(xs, ys, deg) polynomial = np.poly1d(coefficients) #print("Coefficients:", coefficients); print("Polynomial:", polynomial) if i == 0: # Plot the original data points ax.scatter(xs, ys, color='red', marker="o", label='Data Points') # Generate (x, y) values for plotting the fit x_fit = np.linspace(min(xs), max(xs), num) y_fit = polynomial(x_fit) ax.plot(x_fit, y_fit, label=f"{deg}-order fit") if xlabel is not None: ax.set_xlabel(xlabel) if ylabel is not None: ax.set_ylabel(ylabel) ax.legend(loc="best", fontsize=fontsize, shadow=True) return fig #class PolyExtrapolator: # """ # Fit data with polynomals, extrapolate to zero and visualize the results. # """ # def __init__(self, xs, ys): # self.xs, self.ys = np.array(xs), np.array(ys) # # def extrapolate_to_zero(self, deg: int):