Source code for abipy.tools.plotting

# coding: utf-8
"""
Utilities for generating matplotlib plots.

.. note::

    Avoid importing matplotlib or plotly in the module namespace otherwise startup is very slow.
"""
import os
import time
import itertools
import numpy as np

from collections import OrderedDict, namedtuple
from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig_plt, get_ax3d_fig_plt, get_axarray_fig_plt
from .numtools import data_from_cplx_mode


__all__ = [
    "set_axlims",
    "get_ax_fig_plt",
    "get_ax3d_fig_plt",
    "plot_array",
    "ArrayPlotter",
    "data_from_cplx_mode",
    "Marker",
    "plot_unit_cell",
    "GenericDataFilePlotter",
    "GenericDataFilesPlotter",
]


# https://matplotlib.org/gallery/lines_bars_and_markers/linestyles.html
linestyles = OrderedDict(
    [('solid',               (0, ())),
     ('loosely_dotted',      (0, (1, 10))),
     ('dotted',              (0, (1, 5))),
     ('densely_dotted',      (0, (1, 1))),

     ('loosely_dashed',      (0, (5, 10))),
     ('dashed',              (0, (5, 5))),
     ('densely_dashed',      (0, (5, 1))),

     ('loosely_dashdotted',  (0, (3, 10, 1, 10))),
     ('dashdotted',          (0, (3, 5, 1, 5))),
     ('densely_dashdotted',  (0, (3, 1, 1, 1))),

     ('loosely_dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
     ('dashdotdotted',         (0, (3, 5, 1, 5, 1, 5))),
     ('densely_dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]
)


###################
# Matplotlib tools
###################

def is_mpl_figure(obj):
    """Return True if obj is a matplotlib Figure."""
    from matplotlib import pyplot as plt
    return isinstance(obj, plt.Figure)


def ax_append_title(ax, title, loc="center", fontsize=None):
    """Add title to previous ax.title. Return new title."""
    prev_title = ax.get_title(loc=loc)
    new_title = prev_title + title
    ax.set_title(new_title, loc=loc, fontsize=fontsize)
    return new_title


def ax_share(xy_string, *ax_list):
    """
    Share x- or y-axis of two or more subplots after they are created

    Args:
        xy_string: "x" to share x-axis, "xy" for both
        ax_list: List of axes to share.

    Example:

        ax_share("y", ax0, ax1)
        ax_share("xy", *(ax0, ax1, ax2))
    """
    if "x" in xy_string:
        for ix, ax in enumerate(ax_list):
            others = [a for a in ax_list if a != ax]
            ax.get_shared_x_axes().join(*others)

    if "y" in xy_string:
        for ix, ax in enumerate(ax_list):
            others = [a for a in ax_list if a != ax]
            ax.get_shared_y_axes().join(*others)


#def set_grid(fig, boolean):
#    if hasattr(fig, "axes"):
#        for ax in fig.axes:
#            if ax.grid: ax.grid.set_visible(boolean)
#    else:
#            if ax.grid: ax.grid.set_visible(boolean)


[docs]def set_axlims(ax, lims, axname): """ Set the data limits for the axis ax. Args: lims: tuple(2) for (left, right), tuple(1) or scalar for left only. axname: "x" for x-axis, "y" for y-axis. Return: (left, right) """ left, right = None, None if lims is None: return (left, right) len_lims = None try: len_lims = len(lims) except TypeError: # Assume Scalar left = float(lims) if len_lims is not None: if len(lims) == 2: left, right = lims[0], lims[1] elif len(lims) == 1: left = lims[0] set_lim = getattr(ax, {"x": "set_xlim", "y": "set_ylim"}[axname]) if left != right: set_lim(left, right) return left, right
def set_ax_xylabels(ax, xlabel, ylabel, exchange_xy): """ Set the x- and the y-label of axis ax, exchanging x and y if exchange_xy """ if exchange_xy: xlabel, ylabel = ylabel, xlabel ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) def set_visible(ax, boolean, *args): """ Hide/Show the artists of axis ax listed in args. """ if "legend" in args and ax.legend(): ax.legend().set_visible(boolean) if "title" in args and ax.title: ax.title.set_visible(boolean) if "xlabel" in args and ax.xaxis.label: ax.xaxis.label.set_visible(boolean) if "ylabel" in args and ax.yaxis.label: ax.yaxis.label.set_visible(boolean) if "xticklabels" in args: for label in ax.get_xticklabels(): label.set_visible(boolean) if "yticklabels" in args: for label in ax.get_yticklabels(): label.set_visible(boolean) def rotate_ticklabels(ax, rotation, axname="x"): """Rotate the ticklables of axis ``ax``""" if "x" in axname: for tick in ax.get_xticklabels(): tick.set_rotation(rotation) if "y" in axname: for tick in ax.get_yticklabels(): tick.set_rotation(rotation) @add_fig_kwargs def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None, xlims=None, ylims=None, fontsize=12, **kwargs): """ Plot y = f(x) relation for different values of `hue`. Useful for convergence tests done wrt to two parameters. Args: data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`. x: Name of the column used as x-value y: Name of the column(s) used as y-value hue: Variable that define subsets of the data, which will be drawn on separate lines decimals: Number of decimal places to round `hue` columns. Ignore if None ax: |matplotlib-Axes| or None if a new figure should be created. xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)` or scalar e.g. `left`. If left (right) is None, default values are used fontsize: Legend fontsize. kwargs: Keywork arguments are passed to ax.plot method. Returns: |matplotlib-Figure| """ if isinstance(y, (list, tuple)): # Recursive call for each ax in ax_list. num_plots, ncols, nrows = len(y), 1, 1 if num_plots > 1: ncols = 2 nrows = (num_plots // ncols) + (num_plots % ncols) ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) ax_list = ax_list.ravel() if num_plots % ncols != 0: ax_list[-1].axis('off') for yname, ax in zip(y, ax_list): plot_xy_with_hue(data, x, str(yname), hue, decimals=decimals, ax=ax, xlims=xlims, ylims=ylims, fontsize=fontsize, show=False, **kwargs) return fig # Check here because pandas error messages are a bit criptic. miss = [k for k in (x, y, hue) if k not in data] if miss: raise ValueError("Cannot find `%s` in dataframe.\nAvailable keys are: %s" % (str(miss), str(data.keys()))) # Truncate values in hue column so that we can group. if decimals is not None: data = data.round({hue: decimals}) ax, fig, plt = get_ax_fig_plt(ax=ax) for key, grp in data.groupby(hue): # Sort xs and rearrange ys xy = np.array(sorted(zip(grp[x], grp[y]), key=lambda t: t[0])) xvals, yvals = xy[:, 0], xy[:, 1] #label = "{} = {}".format(hue, key) label = "%s" % (str(key)) if not kwargs: ax.plot(xvals, yvals, 'o-', label=label) else: ax.plot(xvals, yvals, label=label, **kwargs) ax.grid(True) ax.set_xlabel(x) ax.set_ylabel(y) set_axlims(ax, xlims, "x") set_axlims(ax, ylims, "y") ax.legend(loc="best", fontsize=fontsize, shadow=True) return fig
[docs]@add_fig_kwargs def plot_array(array, color_map=None, cplx_mode="abs", **kwargs): """ Use imshow for plotting 2D or 1D arrays. Example:: plot_array(np.random.rand(10,10)) See <http://stackoverflow.com/questions/7229971/2d-grid-data-visualization-in-python> Args: array: Array-like object (1D or 2D). color_map: color map. cplx_mode: Flag defining how to handle complex arrays. Possible values in ("re", "im", "abs", "angle") "re" for the real part, "im" for the imaginary part. "abs" means that the absolute value of the complex number is shown. "angle" will display the phase of the complex number in radians. Returns: |matplotlib-Figure| """ # Handle vectors array = np.atleast_2d(array) array = data_from_cplx_mode(cplx_mode, array) import matplotlib as mpl from matplotlib import pyplot as plt if color_map is None: # make a color map of fixed colors color_map = mpl.colors.LinearSegmentedColormap.from_list('my_colormap', ['blue', 'black', 'red'], 256) img = plt.imshow(array, interpolation='nearest', cmap=color_map, origin='lower') # Make a color bar plt.colorbar(img, cmap=color_map) # Set grid plt.grid(True, color='white') fig = plt.gcf() return fig
[docs]class ArrayPlotter(object): def __init__(self, *labels_and_arrays): """ Args: labels_and_arrays: List [("label1", arr1), ("label2", arr2")] """ self._arr_dict = OrderedDict() for label, array in labels_and_arrays: self.add_array(label, array) def __len__(self): return len(self._arr_dict) def __iter__(self): return self._arr_dict.__iter__()
[docs] def keys(self): return self._arr_dict.keys()
[docs] def items(self): return self._arr_dict.items()
[docs] def add_array(self, label, array): """Add array with the given name.""" if label in self._arr_dict: raise ValueError("%s is already in %s" % (label, list(self._arr_dict.keys()))) self._arr_dict[label] = array
[docs] def add_arrays(self, labels, arr_list): """ Add a list of arrays Args: labels: List of labels. arr_list: List of arrays. """ assert len(labels) == len(arr_list) for label, arr in zip(labels, arr_list): self.add_array(label, arr)
[docs] @add_fig_kwargs def plot(self, cplx_mode="abs", colormap="jet", fontsize=8, **kwargs): """ Args: cplx_mode: "abs" for absolute value, "re", "im", "angle" colormap: matplotlib colormap. fontsize: legend and label fontsize. Returns: |matplotlib-Figure| """ # Build grid of plots. num_plots, ncols, nrows = len(self), 1, 1 if num_plots > 1: ncols = 2 nrows = num_plots // ncols + (num_plots % ncols) import matplotlib.pyplot as plt fig, ax_mat = plt.subplots(nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) # Don't show the last ax if num_plots is odd. if num_plots % ncols != 0: ax_mat[-1, -1].axis("off") from mpl_toolkits.axes_grid1 import make_axes_locatable from matplotlib.ticker import MultipleLocator for ax, (label, arr) in zip(ax_mat.flat, self.items()): data = data_from_cplx_mode(cplx_mode, arr) # Use origin to place the [0, 0] index of the array in the lower left corner of the axes. img = ax.matshow(data, interpolation='nearest', cmap=colormap, origin='lower', aspect="auto") ax.set_title("(%s) %s" % (cplx_mode, label), fontsize=fontsize) # Make a color bar for this ax # Create divider for existing axes instance # http://stackoverflow.com/questions/18266642/multiple-imshow-subplots-each-with-colorbar divider3 = make_axes_locatable(ax) # Append axes to the right of ax, with 10% width of ax cax3 = divider3.append_axes("right", size="10%", pad=0.05) # Create colorbar in the appended axes # Tick locations can be set with the kwarg `ticks` # and the format of the ticklabels with kwarg `format` cbar3 = plt.colorbar(img, cax=cax3, ticks=MultipleLocator(0.2), format="%.2f") # Remove xticks from ax ax.xaxis.set_visible(False) # Manually set ticklocations #ax.set_yticks([0.0, 2.5, 3.14, 4.0, 5.2, 7.0]) # Set grid ax.grid(True, color='white') fig.tight_layout() return fig
#TODO use object and introduce c for color, client code should be able to customize it. # Rename it to ScatterData
[docs]class Marker(namedtuple("Marker", "x y s")): """ Stores the position and the size of the marker. A marker is a list of tuple(x, y, s) where x, and y are the position in the graph and s is the size of the marker. Used for plotting purpose e.g. QP data, energy derivatives... Example:: x, y, s = [1, 2, 3], [4, 5, 6], [0.1, 0.2, -0.3] marker = Marker(x, y, s) marker.extend((x, y, s)) """ def __new__(cls, *xys): """Extends the base class adding consistency check.""" if not xys: xys = ([], [], []) return super().__new__(cls, *xys) if len(xys) != 3: raise TypeError("Expecting 3 entries in xys got %d" % len(xys)) x = np.asarray(xys[0]) y = np.asarray(xys[1]) s = np.asarray(xys[2]) xys = (x, y, s) for s in xys[-1]: if np.iscomplex(s): raise ValueError("Found ambiguous complex entry %s" % str(s)) return super().__new__(cls, *xys) def __bool__(self): return bool(len(self.s)) __nonzero__ = __bool__
[docs] def extend(self, xys): """ Extend the marker values. """ if len(xys) != 3: raise TypeError("Expecting 3 entries in xys got %d" % len(xys)) self.x.extend(xys[0]) self.y.extend(xys[1]) self.s.extend(xys[2]) lens = np.array((len(self.x), len(self.y), len(self.s))) if np.any(lens != lens[0]): raise TypeError("x, y, s vectors should have same lengths but got %s" % str(lens))
[docs] def posneg_marker(self): """ Split data into two sets: the first one contains all the points with positive size. The first set contains all the points with negative size. """ pos_x, pos_y, pos_s = [], [], [] neg_x, neg_y, neg_s = [], [], [] for x, y, s in zip(self.x, self.y, self.s): if s >= 0.0: pos_x.append(x) pos_y.append(y) pos_s.append(s) else: neg_x.append(x) neg_y.append(y) neg_s.append(s) return self.__class__(pos_x, pos_y, pos_s), Marker(neg_x, neg_y, neg_s)
class MplExpose(object): # pragma: no cover """ Context manager used to produce several matplotlib figures and then show all them at the end so that the user does not need to close the window to visualize to the next one. Example: with MplExpose() as e: e(obj.plot1(show=False)) e(obj.plot2(show=False)) """ def __init__(self, slide_mode=False, slide_timeout=None, verbose=1): """ Args: slide_mode: If Rrue, iterate over figures. Default: Expose all figures at once. slide_timeout: Close figure after slide-timeout seconds. Block if None. verbose: verbosity level """ self.figures = [] self.slide_mode = bool(slide_mode) self.timeout_ms = slide_timeout self.verbose = verbose if self.timeout_ms is not None: self.timeout_ms = int(self.timeout_ms * 1000) assert self.timeout_ms >= 0 if self.verbose: if self.slide_mode: print("\nSliding matplotlib figures with slide timeout: %s [s]" % slide_timeout) else: print("\nLoading all matplotlib figures before showing them. It may take some time...") self.start_time = time.time() def __call__(self, obj): """ Add an object to MplExpose. Support mpl figure, list of figures or generator yielding figures. """ import types if isinstance(obj, (types.GeneratorType, list, tuple)): for fig in obj: self.add_fig(fig) else: self.add_fig(obj) def add_fig(self, fig): """Add a matplotlib figure.""" if fig is None: return if not self.slide_mode: self.figures.append(fig) else: #print("Printing and closing", fig) import matplotlib.pyplot as plt if self.timeout_ms is not None: # Creating a timer object # timer calls plt.close after interval milliseconds to close the window. timer = fig.canvas.new_timer(interval=self.timeout_ms) timer.add_callback(plt.close, fig) timer.start() plt.show() fig.clear() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): """Activated at the end of the with statement. """ if exc_type is not None: return self.expose() def expose(self): """Show all figures. Clear figures if needed.""" if not self.slide_mode: print("All figures in memory, elapsed time: %.3f s" % (time.time() - self.start_time)) import matplotlib.pyplot as plt plt.show() for fig in self.figures: fig.clear() class PanelExpose(object): # pragma: no cover """ Context manager used to produce several matplotlib/plotly figures and then show all them inside the Browser using a panel template. Example: with PanelExpose() as e: e(obj.plot1(show=False)) e(obj.plot2(show=False)) """ def __init__(self, title=None, verbose=1): """ Args: title: String to be show in the header. verbose: verbosity level """ self.title = title self.figures = [] self.verbose = verbose if self.verbose: print("\nLoading all figures before showing them. It may take some time...") self.start_time = time.time() def __call__(self, obj): """ Add an object to MplPanelExpose. Support mpl figure, list of figures or generator yielding figures. """ import types if isinstance(obj, (types.GeneratorType, list, tuple)): for fig in obj: self.add_fig(fig) else: self.add_fig(obj) def add_fig(self, fig): """Add a matplotlib figure.""" if fig is None: return self.figures.append(fig) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): """Activated at the end of the with statement. """ if exc_type is not None: return self.expose() def expose(self): """Show all figures. Clear figures if needed.""" import panel as pn pn.config.sizing_mode = 'stretch_width' from abipy.panels.core import get_template_cls_from_name cls = get_template_cls_from_name("FastGridTemplate") template = cls( title=self.title if self.title is not None else self.__class__.__name__, header_background="#ff8c00 ", # Dark orange ) #pn.config.sizing_mode = 'stretch_width' from abipy.panels.core import mpl, ply for i, fig in enumerate(self.figures): row, col = divmod(i, 2) if is_plotly_figure(fig): p = ply(fig, with_divider=False) elif is_mpl_figure(fig): p = mpl(fig, with_divider=False) else: raise TypeError(f"Don't know how to handle type: `{type(fig)}`") if hasattr(template.main, "append"): template.main.append(p) else: # Assume .main area acts like a GridSpec row_slice = slice(3 * row, 3 * (row + 1)) if col == 0: template.main[row_slice, :6] = p if col == 1: template.main[row_slice, 6:] = p return template.show()
[docs]def plot_unit_cell(lattice, ax=None, **kwargs): """ Adds the unit cell of the lattice to a matplotlib Axes3D Args: lattice: Lattice object ax: matplotlib :class:`Axes3D` or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black and linewidth to 3. Returns: matplotlib figure and ax """ ax, fig, plt = get_ax3d_fig_plt(ax) if "color" not in kwargs: kwargs["color"] = "k" if "linewidth" not in kwargs: kwargs["linewidth"] = 3 v = 8 * [None] v[0] = lattice.get_cartesian_coords([0.0, 0.0, 0.0]) v[1] = lattice.get_cartesian_coords([1.0, 0.0, 0.0]) v[2] = lattice.get_cartesian_coords([1.0, 1.0, 0.0]) v[3] = lattice.get_cartesian_coords([0.0, 1.0, 0.0]) v[4] = lattice.get_cartesian_coords([0.0, 1.0, 1.0]) v[5] = lattice.get_cartesian_coords([1.0, 1.0, 1.0]) v[6] = lattice.get_cartesian_coords([1.0, 0.0, 1.0]) v[7] = lattice.get_cartesian_coords([0.0, 0.0, 1.0]) for i, j in ((0, 1), (1, 2), (2, 3), (0, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 4), (0, 7), (1, 6), (2, 5), (3, 4)): ax.plot(*zip(v[i], v[j]), **kwargs) # Plot cartesian frame ax_add_cartesian_frame(ax) return fig, ax
def ax_add_cartesian_frame(ax, start=(0, 0, 0)): """ Add cartesian frame to 3d axis at point `start`. """ # https://stackoverflow.com/questions/22867620/putting-arrowheads-on-vectors-in-matplotlibs-3d-plot from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d arrow_opts = {"color": "k"} arrow_opts.update(dict(lw=1, arrowstyle="-|>",)) class Arrow3D(FancyArrowPatch): def __init__(self, xs, ys, zs, *args, **kwargs): FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) self._verts3d = xs, ys, zs def draw(self, renderer): xs3d, ys3d, zs3d = self._verts3d xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) FancyArrowPatch.draw(self, renderer) start = np.array(start) for end in ((1, 0, 0), (0, 1, 0), (0, 0, 1)): end = start + np.array(end) xs, ys, zs = list(zip(start, end)) p = Arrow3D(xs, ys, zs, connectionstyle='arc3', mutation_scale=20, alpha=0.8, **arrow_opts) ax.add_artist(p) return ax def plot_structure(structure, ax=None, to_unit_cell=False, alpha=0.7, style="points+labels", color_scheme="VESTA", **kwargs): """ Plot structure with matplotlib (minimalistic version). Args: structure: |Structure| object ax: matplotlib :class:`Axes3D` or None if a new figure should be created. alpha: The alpha blending value, between 0 (transparent) and 1 (opaque) to_unit_cell: True if sites should be wrapped into the first unit cell. style: "points+labels" to show atoms sites with labels. color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA") Returns: |matplotlib-Figure| """ fig, ax = plot_unit_cell(structure.lattice, ax=ax, linewidth=1) from pymatgen.analysis.molecule_structure_comparator import CovalentRadius from pymatgen.vis.structure_vtk import EL_COLORS xyzs, colors = np.empty((len(structure), 4)), [] for i, site in enumerate(structure): symbol = site.specie.symbol color = tuple(i / 255 for i in EL_COLORS[color_scheme][symbol]) radius = CovalentRadius.radius[symbol] if to_unit_cell and hasattr(site, "to_unit_cell"): site = site.to_unit_cell() # Use cartesian coordinates. x, y, z = site.coords xyzs[i] = (x, y, z, radius) colors.append(color) if "labels" in style: ax.text(x, y, z, symbol) # The definition of sizes is not optimal because matplotlib uses points # wherease we would like something that depends on the radius (5000 seems to give reasonable plots) # For possibile approaches, see # https://stackoverflow.com/questions/9081553/python-scatter-plot-size-and-style-of-the-marker/24567352#24567352 # https://gist.github.com/syrte/592a062c562cd2a98a83 if "points" in style: x, y, z, s = xyzs.T.copy() s = 5000 * s ** 2 ax.scatter(x, y, zs=z, s=s, c=colors, alpha=alpha) #facecolors="white", #edgecolors="blue" ax.set_title(structure.composition.formula) ax.set_axis_off() return fig def _generic_parser_fh(fh): """ Parse file with data in tabular format. Supports multi datasets a la gnuplot. Mainly used for files without any schema, not even CSV Args: fh: File object Returns: OrderedDict title --> numpy array where title is taken from the first (non-empty) line preceding the dataset """ arr_list = [None] data = [] head_list = [] count = -1 last_header = None for l in fh: l = l.strip() if not l or l.startswith("#"): count = -1 last_header = l if arr_list[-1] is not None: arr_list.append(None) continue count += 1 if count == 0: head_list.append(last_header) if arr_list[-1] is None: arr_list[-1] = [] data = arr_list[-1] data.append(list(map(float, l.split()))) if len(head_list) != len(arr_list): raise RuntimeError("len(head_list) != len(arr_list), %d != %d" % (len(head_list), len(arr_list))) od = OrderedDict() for key, data in zip(head_list, arr_list): key = " ".join(key.split()) if key in od: print("Header %s already in dictionary. Using new key %s" % (key, 2 * key)) key = 2 * key od[key] = np.array(data).T.copy() return od
[docs]class GenericDataFilePlotter(object): """ Extract data from a generic text file with results in tabular format and plot data with matplotlib. Multiple datasets are supported. No attempt is made to handle metadata (e.g. column name) Mainly used to handle text files written without any schema. """ def __init__(self, filepath): with open(filepath, "rt") as fh: self.od = _generic_parser_fh(fh) def __str__(self): return self.to_string()
[docs] def to_string(self, verbose=0): """String representation with verbosity level `verbose`.""" lines = [] for key, arr in self.od.items(): lines.append("key: `%s` --> array shape: %s" % (key, str(arr.shape))) return "\n".join(lines)
[docs] @add_fig_kwargs def plot(self, use_index=False, fontsize=8, **kwargs): """ Plot all arrays. Use multiple axes if datasets. Args: use_index: By default, the x-values are taken from the first column. If use_index is False, the x-values are the row index. fontsize: fontsize for title. kwargs: options passed to ``ax.plot``. Return: |matplotlib-figure| """ # build grid of plots. num_plots, ncols, nrows = len(self.od), 1, 1 if num_plots > 1: ncols = 2 nrows = (num_plots // ncols) + (num_plots % ncols) ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) ax_list = ax_list.ravel() # Don't show the last ax if num_plots is odd. if num_plots % ncols != 0: ax_list[-1].axis("off") for ax, (key, arr) in zip(ax_list, self.od.items()): ax.set_title(key, fontsize=fontsize) ax.grid(True) xs = arr[0] if not use_index else list(range(len(arr[0]))) for ys in arr[1:] if not use_index else arr: ax.plot(xs, ys) return fig
[docs]class GenericDataFilesPlotter(object):
[docs] @classmethod def from_files(cls, filepaths): """ Build object from a list of `filenames`. """ new = cls() for filepath in filepaths: new.add_file(filepath) return new
def __init__(self): self.odlist = [] self.filepaths = [] def __str__(self): return self.to_string()
[docs] def to_string(self, verbose=0): lines = [] app = lines.append for od, filepath in zip(self.odlist, self.filepaths): app("File: %s" % filepath) for key, arr in od.items(): lines.append("\tkey: `%s` --> array shape: %s" % (key, str(arr.shape))) return "\n".join(lines)
[docs] def add_file(self, filepath): """Add data from `filepath`""" with open(filepath, "rt") as fh: self.odlist.append(_generic_parser_fh(fh)) self.filepaths.append(filepath)
[docs] @add_fig_kwargs def plot(self, use_index=False, fontsize=8, colormap="viridis", **kwargs): """ Plot all arrays. Use multiple axes if datasets. Args: use_index: By default, the x-values are taken from the first column. If use_index is False, the x-values are the row index. fontsize: fontsize for title. colormap: matplotlib color map. kwargs: options passed to ``ax.plot``. Return: |matplotlib-figure| """ if not self.odlist: return None # Compute intersection of all keys. # Here we loose the initial ordering in the dict but oh well! klist = [list(d.keys()) for d in self.odlist] keys = set(klist[0]).intersection(*klist) if not keys: print("Warning: cannot find common keys in files. Check input data") return None # Build grid of plots. num_plots, ncols, nrows = len(keys), 1, 1 if num_plots > 1: ncols = 2 nrows = (num_plots // ncols) + (num_plots % ncols) ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) ax_list = ax_list.ravel() # Don't show the last ax if num_plots is odd. if num_plots % ncols != 0: ax_list[-1].axis("off") cmap = plt.get_cmap(colormap) line_cycle = itertools.cycle(["-", ":", "--", "-.",]) # One ax for key, each ax may show multiple arrays # so we need different line styles that are consistent with input data. # Figure may be crowded but it's difficult to do better without metadata # so I'm not gonna spend time to implement more complicated logic. for ax, key in zip(ax_list, keys): ax.set_title(key, fontsize=fontsize) ax.grid(True) for iod, (od, filepath) in enumerate(zip(self.odlist, self.filepaths)): if key not in od: continue arr = od[key] color = cmap(iod / len(self.odlist)) xvals = arr[0] if not use_index else list(range(len(arr[0]))) arr_list = arr[1:] if not use_index else arr for iarr, (ys, linestyle) in enumerate(zip(arr_list, line_cycle)): ax.plot(xvals, ys, color=color, linestyle=linestyle, label=os.path.relpath(filepath) if iarr == 0 else None) ax.legend(loc="best", fontsize=fontsize, shadow=True) return fig
########################## # Plotly helper functions ########################## _LATEX_GREEK_TO_UNICODE = dict( alpha="α", beta="β", gamma="ɣ", delta="δ", epsilon="ε", zeta="ζ", eta="η", theta="θ", iota="ι", kappa="κ", #lambda="λ", mu="μ", nu="ν", xi="ξ", omicron="ο", pi="π", rho="ρ", sigma="σ", tau="τ", upsilon="υ", phi="φ", chi="χ", psi="ψ", omega="ω", # Capital case: Alpha="Α", Beta="Β", Gamma="Γ", Delta="Δ", Epsilon="Ε", Zeta="Ζ", Eta="Η", Theta="Θ", Iota="Ι", Kappa="Κ", Lambda="Λ", Mu="Μ", Nu="Ν", Xi="Ξ", Omicron="Ο", Po="Π", Rho="Ρ", Sigma="Σ", Tau="Τ", Upsilon="Υ", Phi="Φ", Chi="Χ", Psi="Ψ", Omega="Ω", ) _LATEX_GREEK_TO_UNICODE["lambda"] = "λ" def latex_greek_2unicode(latex): """ Convert a single greek letter in latex notation into unicode """ s = latex.replace("$", "").replace("\\", "").strip() return _LATEX_GREEK_TO_UNICODE[s] def is_plotly_figure(obj): """Return True if obj is a plotly Figure.""" import plotly.graph_objs as go return isinstance(obj, go.Figure) #return isinstance(obj, (go.Figure, go.FigureWidget)) class PlotlyRowColDesc(object): """ This object specifies the position of a plotly subplot inside a grid. rcd: PlotlyRowColDesc object used when fig is not None to specify the (row, col) of the subplot in the grid. """ @classmethod def from_object(cls, obj): """ Build an instance for a generic object. If oject is None, a simple descriptor corresponding to a (1,1) grid is returned. """ if obj is None: return cls(0, 0, 1, 1) if isinstance(obj, cls): return obj # Assume list with 4 integers try: return cls(*obj) except Exception as exc: raise TypeError(f"Dont know how to convert `{type(obj)}` into `{cls}`") def __init__(self, py_row, py_col, nrows, ncols): """ Args: py_row, py_col: python index of the subplot in the grid (starts from 0) nrows, ncols: Number of rows/cols in the grid. """ self.py_row, self.py_col = (py_row, py_col) self.nrows, self.ncols = (nrows, ncols) self.iax = 1 + self.py_col + self.py_row * self.ncols # Note that plotly col and row start from 1. if nrows == 1 and ncols == 1: self.ply_row, self.ply_col = (None, None) else: self.ply_row, self.ply_col = (self.py_row + 1, self.py_col + 1) def __str__(self): lines = [] app = lines.append app("py_rowcol: (%d, %d) in grid: (%d, %d)" % (self.py_row, self.py_col, self.nrows, self.ncols)) app("plotly_rowcol: (%s, %s)" % (self.ply_row, self.ply_col)) return "\n".join(lines) #@lazy_property #def rowcol_dict(self): # if self.nrows == 1 and self.ncols == 1: return {} # return dict(row=self.ply_row, col=self.ply_col) def get_figs_plotly(nrows=1, ncols=1, subplot_titles=(), sharex=False, sharey=False, **fig_kw): """ Helper function used in plot functions that build the `plotly` figure by calling plotly.subplots. Returns: figure: plotly graph_objects figure go: plotly graph_objects module. """ from plotly.subplots import make_subplots import plotly.graph_objects as go fig = make_subplots(rows=int(nrows), cols=int(ncols), subplot_titles=subplot_titles, shared_xaxes=sharex, shared_yaxes=sharey, **fig_kw) return fig, go def get_fig_plotly(fig=None, **fig_kw): """ Helper function used in plot functions that build the `plotly` figure by calling plotly.graph_objects.Figure if fig is None else return fig Returns: figure: plotly graph_objects figure go: plotly graph_objects module. """ import plotly.graph_objects as go if fig is None: fig = go.Figure(**fig_kw) #fig = go.FigureWidget(**fig_kw) return fig, go def plotly_set_lims(fig, lims, axname, iax=None): """ Set the data limits for the axis ax. Args: lims: tuple(2) for (left, right), if tuple(1) or scalar for left only, none is set. axname: "x" for x-axis, "y" for y-axis. iax: An int, use iax=n to decorate the nth axis when the fig has subplots. Return: (left, right) """ left, right = None, None if lims is None: return (left, right) # iax = kwargs.pop("iax", 1) # xaxis = 'xaxis%u' % iax #fig.layout[xaxis].title.text = "Wave Vector" axis = dict(x=fig.layout.xaxis, y=fig.layout.yaxis)[axname] len_lims = None try: len_lims = len(lims) except TypeError: # Assume Scalar left = float(lims) if len_lims is not None: if len(lims) == 2: left, right = lims[0], lims[1] elif len(lims) == 1: left = lims[0] ax_range = axis.range if ax_range is None and (left is None or right is None): return None, None #if left is not None: ax_range[0] = left #if right is not None: ax_range[1] = right # Example: fig.update_layout(yaxis_range=[-4,4]) k = dict(x="xaxis", y="yaxis")[axname] if iax: k= k + str(iax) fig.layout[k].range = [left, right] return left, right _PLOTLY_DEFAULT_SHOW = [True] def set_plotly_default_show(true_or_false): """ Set the default value of show in the add_plotly_fig_kwargs decorator. Usefule for instance when generating the sphinx gallery of plotly plots. """ _PLOTLY_DEFAULT_SHOW[0] = true_or_false def add_plotly_fig_kwargs(func): """ Decorator that adds keyword arguments for functions returning plotly figures. The function should return either a plotly figure or None to signal some sort of error/unexpected event. See doc string below for the list of supported options. """ from functools import wraps @wraps(func) def wrapper(*args, **kwargs): # pop the kwds used by the decorator. title = kwargs.pop("title", None) show = kwargs.pop("show", _PLOTLY_DEFAULT_SHOW[0]) hovermode = kwargs.pop("hovermode", False) savefig = kwargs.pop("savefig", None) write_json = kwargs.pop("write_json", None) config = kwargs.pop("config", None) renderer = kwargs.pop("renderer", None) chart_studio = kwargs.pop("chart_studio", False) # Allow users to specify the renderer via shell env. if renderer is not None and os.getenv("PLOTLY_RENDERER", default=None) is not None: renderer = None # Call func and return immediately if None is returned. fig = func(*args, **kwargs) if fig is None: return fig # Operate on plotly figure. if title is not None: fig.update_layout(title_text=title, title_x=0.5) if savefig: fig.write_image(savefig) if write_json: import plotly.io as pio pio.write_json(fig, write_json) fig.layout.hovermode = hovermode if show: # and _PLOTLY_DEFAULT_SHOW: fig.show(renderer=renderer, config=config) if chart_studio: push_to_chart_studio(fig) return fig # Add docstring to the decorated method. s = ( "\n\n" + """\ Keyword arguments controlling the display of the figure: ================ ==================================================================== kwargs Meaning ================ ==================================================================== title Title of the plot (Default: None). show True to show the figure (default: True). hovormode True to show the hover info (default: False) savefig "abc.png" , "abc.jpeg" or "abc.webp" to save the figure to a file. write_json Write plotly figure to `write_json` JSON file. Inside jupyter-lab, one can right-click the `write_json` file from the file menu and open with "Plotly Editor". Make some changes to the figure, then use the file menu to save the customized plotly plot. Requires `jupyter labextension install jupyterlab-chart-editor`. See https://github.com/plotly/jupyterlab-chart-editor renderer (str or None (default None)) – A string containing the names of one or more registered renderers (separated by ‘+’ characters) or None. If None, then the default renderers specified in plotly.io.renderers.default are used. See https://plotly.com/python-api-reference/generated/plotly.graph_objects.Figure.html config (dict) A dict of parameters to configure the figure. The defaults are set in plotly.js. chart_studio True to push figure to chart_studio server. Requires authenticatios. Default: False. ================ ==================================================================== """ ) if wrapper.__doc__ is not None: # Add s at the end of the docstring. wrapper.__doc__ += "\n" + s else: # Use s wrapper.__doc__ = s return wrapper def plotlyfigs_to_browser(figs, filename=None, browser=None): """ Save a list of plotly figures in an HTML file and open it the browser. Useful to display multiple figures generated by different AbiPy methods without having to construct a plotly subplot grid. Args: figs: List of plotly figures. filename: File name to save in. Use temporary filename if filename is None. browser: Open webpage in ``browser``. Use $BROWSER if None. Example: fig1 = plotter.combiplotly(renderer="browser", title="foo", show=False) fig2 = plotter.combiplotly(renderer="browser", title="bar", show=False) from abipy.tools.plotting import plotlyfigs_to_browser plotlyfigs_to_browser([fig1, fig2]) Return: path to HTML file. """ if filename is None: import tempfile fd, filename = tempfile.mkstemp(text=True, suffix=".html") if not isinstance(figs, (list, tuple)): figs = [figs] # Based on https://stackoverflow.com/questions/46821554/multiple-plotly-plots-on-1-page-without-subplot with open(filename, "wt") as fp: for i, fig in enumerate(figs): first = True if i == 0 else False fig.write_html(fp, include_plotlyjs=first, include_mathjax="cdn" if first else False) import webbrowser print("Opening HTML file:", filename) webbrowser.get(browser).open_new_tab("file://" + filename) return filename def plotly_klabels(labels, allow_dupes=False): """ This helper function polish a list of k-points labels before calling plotly by: - Checking if we have two equivalent consecutive labels (only the first one is shown and the second one is set to "") - Replacing particular Latex tokens with unicode as plotly support for Latex is far from optimal. Return: New list labels, same length as input labels. """ new_labels = labels.copy() if not allow_dupes: # Don't show label if previous k-point is the same. for il in range(1, len(new_labels)): if new_labels[il] == new_labels[il - 1]: new_labels[il] = "" replace = { r"$\Gamma$": "Γ", } for il in range(len(new_labels)): if new_labels[il] in replace: new_labels[il] = replace[new_labels[il]] return new_labels def plotly_set_xylabels(fig, xlabel, ylabel, exchange_xy): """ Set the x- and the y-label of axis ax, exchanging x and y if exchange_xy """ if exchange_xy: xlabel, ylabel = ylabel, xlabel fig.layout.xaxis.title.text = xlabel fig.layout.yaxis.title.text = ylabel _PLOTLY_AUTHEHTICATED = False def plotly_chartstudio_authenticate(): """ Authenticate the user on the chart studio portal by reading `PLOTLY_USERNAME` and `PLOTLY_API_KEY` from the pymatgen configuration file located in $HOME/.pmgrc.yaml. PLOTLY_USERNAME: johndoe PLOTLY_API_KEY: XXXXXXXXXXXXXXXXXXXX """ global _PLOTLY_AUTHEHTICATED if _PLOTLY_AUTHEHTICATED: return try: from pymatgen.core import SETTINGS #from pymatgen.settings import SETTINGS except ImportError: from pymatgen import SETTINGS example = """ Add it to $HOME/.pmgrc.yaml using the follow syntax: PLOTLY_USERNAME: john_doe PLOTLY_API_KEY: secret # to get your api_key go to profile > settings > regenerate key """ username = SETTINGS.get("PLOTLY_USERNAME") if username is None: raise RuntimeError(f"Cannot find PLOTLY_USERNAME in pymatgen settings.\n{example}") api_key = SETTINGS.get("PLOTLY_API_KEY") if api_key is None: raise RuntimeError(f"Cannot find PLOTLY_API_KEY in pymatgen settings.\n{example}") import chart_studio # https://towardsdatascience.com/how-to-create-a-plotly-visualization-and-embed-it-on-websites-517c1a78568b chart_studio.tools.set_credentials_file(username=username, api_key=api_key) _PLOTLY_AUTHEHTICATED = True def push_to_chart_studio(figs): """ Push a plotly figure or a list of figures to the chart studio cloud. """ plotly_chartstudio_authenticate() import chart_studio.plotly as py if not isinstance(figs, (list, tuple)): figs = [figs] for fig in figs: py.plot(fig, auto_open=True) #################################################### # This code is shamelessy taken from Adam's package #################################################### import plotly.graph_objects as go def go_points(points, size=4, color="black", labels=None, **kwargs): #textposition = 'top right', #textfont = dict(color='#E58606'), mode = "markers" if labels is None else "markers+text" #text = labels if labels is not None: labels = plotly_klabels(labels, allow_dupes=True) return go.Scatter3d( x=[v[0] for v in points], y=[v[1] for v in points], z=[v[2] for v in points], marker=dict(size=size, color=color), mode=mode, text=labels, **kwargs ) def _add_if_not_in(d, key, value): if key not in d: d[key] = value def go_line(v1, v2, color="black", width=2, mode="lines", **kwargs): _add_if_not_in(kwargs, "line_color", "black") _add_if_not_in(kwargs, "line_width", 2) return go.Scatter3d( mode=mode, x=[v1[0], v2[0]], y=[v1[1], v2[1]], z=[v1[2], v2[2]], #line=dict(color=color, width=width), **kwargs ) def go_lines(V, name=None, color="black", width=2, **kwargs): gen = ((v1, v2) for (v1, v2) in V) v1, v2 = next(gen) out = [ go_line(v1, v2, width=width, color=color, name=name, legendgroup=name, **kwargs) ] out.extend( go_line( v1, v2, width=width, color=color, showlegend=False, legendgroup=name, **kwargs ) for (v1, v2) in gen ) return out def vectors(lattice, name=None, color="black", width=4, **kwargs): gen = zip(lattice, ["a", "b", "c"]) v, label = next(gen) out = [ go_line( [0, 0, 0], v, text=["", label], width=width, color=color, name=name, legendgroup=name, mode="lines+text", **kwargs ) ] out.extend( go_line( [0, 0, 0], v, text=["", label], width=width, color=color, showlegend=False, legendgroup=name, mode="lines+text", **kwargs ) for (v, label) in gen ) return out def get_vectors(lattice_mat, name=None, color="black", width=2, **kwargs): return go_lines([[[0, 0, 0], v] for v in lattice_mat], **kwargs) def get_box(lattice_mat, **kwargs): a, b, c = lattice_mat segments = [ [[0, 0, 0], a], [[0, 0, 0], b], [[0, 0, 0], c], [a, a + b], [a, a + c], [b, b + a], [b, b + c], [c, c + a], [c, c + b], [a + b, a + b + c], [a + c, a + b + c], [b + c, a + b + c], ] return go_lines(segments, **kwargs) def plot_fcc_conv(): fcc_conv = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) fcc_vectors = vectors( fcc_conv, name="conv lattice vectors", color="darkblue", width=6 ) fcc_box = get_box(fcc_conv, name="conv lattice") atoms = go_points( [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], size=10, color="orange", name="atoms", legendgroup="atoms", ) fig = go.Figure(data=[*fcc_box, *fcc_vectors, atoms]) return fig def plot_fcc_prim(): fcc_prim = np.array([[0.5, 0.5, 0], [0, 0.5, 0.5], [0.5, 0, 0.5]]) fcc_prim_vectors = vectors( fcc_prim, name="prim lattice vectors", color="green", width=6 ) fcc_prim_box = get_box(fcc_prim, name="prim lattice", color="green") atoms = go_points( [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], size=10, color="orange", name="atoms", legendgroup="atoms", ) fcc_conv = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) fcc_conv_box = get_box(fcc_conv, name="conv lattice") fig = go.Figure(data=[*fcc_prim_box, *fcc_prim_vectors, *fcc_conv_box, atoms]) return fig def plot_fcc_100(): # fcc_100_cell = np.array([[0, 0.5, -0.5], [0, 0.5, 0.5], [1.0, 0.0, 0]]) fcc_100_cell = np.array([[0.5, -0.5, 0], [0.5, 0.5, 0], [0.0, 0, 1.0]]) fcc_100_vectors = vectors( fcc_100_cell, name="100 lattice vectors", color="red", width=6 ) fcc_100_box = get_box(fcc_100_cell, name="100 lattice", color="red") fig = plot_fcc_conv() fig.add_traces([*fcc_100_box, *fcc_100_vectors]) return fig def plot_fcc_110(): fcc_110_cell = np.array([[0, 0.0, 1.0], [0.5, -0.5, 0], [0.5, 0.5, 0.0]]) fcc_110_vectors = vectors( fcc_110_cell, name="reduced lattice vectors", color="red", width=6 ) fcc_110_box = get_box(fcc_110_cell, name="reduced lattice", color="red") fig = plot_fcc_conv() fig.add_traces([*fcc_110_box, *fcc_110_vectors]) return fig def plot_fcc_111(): fcc_111_cell = np.array([[0.5, 0, -0.5], [0, 0.5, -0.5], [1, 1, 1]]) fcc_111_vectors = vectors( fcc_111_cell, name="reduced lattice vectors", color="red", width=6 ) fcc_111_box = get_box(fcc_111_cell, name="reduced lattice", color="red") fig = plot_fcc_conv() fig.add_traces([*fcc_111_box, *fcc_111_vectors]) return fig def plotly_structure(structure, ax=None, to_unit_cell=False, alpha=0.7, style="points+labels", color_scheme="VESTA", **kwargs): """ Plot structure with plotly (minimalistic version). Args: structure: |Structure| object ax: matplotlib :class:`Axes3D` or None if a new figure should be created. alpha: The alpha blending value, between 0 (transparent) and 1 (opaque) to_unit_cell: True if sites should be wrapped into the first unit cell. style: "points+labels" to show atoms sites with labels. color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA") Returns: |matplotlib-Figure| """ #fig, ax = plot_unit_cell(structure.lattice, ax=ax, linewidth=1) box = get_box(structure.lattice.matrix) #, **kwargs): from pymatgen.analysis.molecule_structure_comparator import CovalentRadius from pymatgen.vis.structure_vtk import EL_COLORS #symb2data = {} #for symbol in structure.symbol_set: # symb2data[symbol] = d = {} # d["color"] = color = tuple(i / 255 for i in EL_COLORS[color_scheme][symbol]) # d["radius"] = CovalentRadius.radius[symbol] # inds = structure.indices_from_symbol(symbol) # sites = [structure[i] for i in inds] # d["xyz"] = [] # for site in sites: # if to_unit_cell and hasattr(site, "to_unit_cell"): site = site.to_unit_cell() # Use cartesian coordinates. # x, y, z = site.coords # d["xyz"].append((x, y ,z) xyz, sizes, colors = np.empty((len(structure), 3)), [], [] for i, site in enumerate(structure): symbol = site.specie.symbol color = tuple(i / 255 for i in EL_COLORS[color_scheme][symbol]) radius = CovalentRadius.radius[symbol] if to_unit_cell and hasattr(site, "to_unit_cell"): site = site.to_unit_cell() # Use cartesian coordinates. x, y, z = site.coords xyz[i] = (x, y, z) # , radius) sizes.append(radius) colors.append(color) #if "labels" in style: # ax.text(x, y, z, symbol) atoms = go_points( #[[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], xyz, size=10, color="orange", name="atoms", legendgroup="atoms", ) #marker = [dict(size=size, color=color) for (size, color) in zip(sizes, colors)] #atoms = go.Scatter3d( # x=[v[0] for v in xyz], # y=[v[1] for v in xyz], # z=[v[2] for v in xyz], # #marker=dict(size=size, color=color), # marker=marker, # mode="markers", # #**kwargs #) # The definition of sizes is not optimal because matplotlib uses points # wherease we would like something that depends on the radius (5000 seems to give reasonable plots) # For possibile approaches, see # https://stackoverflow.com/questions/9081553/python-scatter-plot-size-and-style-of-the-marker/24567352#24567352 # https://gist.github.com/syrte/592a062c562cd2a98a83 #if "points" in style: # x, y, z, s = xyzs.T.copy() # s = 5000 * s ** 2 # ax.scatter(x, y, zs=z, s=s, c=colors, alpha=alpha) #facecolors="white", #edgecolors="blue" #ax.set_title(structure.composition.formula) #ax.set_axis_off() #fig = go.Figure(data=[*box, *vectors, atoms]) fig = go.Figure(data=[*box, atoms]) return fig # This is the matplotlib API to plot the BZ. def plotly_wigner_seitz(lattice, fig=None, **kwargs): """ Adds the skeleton of the Wigner-Seitz cell of the lattice to a plotly figure. Args: lattice: Lattice object fig: plotly figure or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black and linewidth to 1. Returns: Plotly figure """ #ax, fig, plt = get_ax3d_fig_plt(ax) fig, go = get_fig_plotly(fig=fig) #, **fig_kw) if "line_color" not in kwargs: kwargs["line_color"] = "black" if "line_width" not in kwargs: kwargs["line_width"] = 1 bz = lattice.get_wigner_seitz_cell() #ax, fig, plt = get_ax3d_fig_plt(ax) for iface in range(len(bz)): # pylint: disable=C0200 for line in itertools.combinations(bz[iface], 2): for jface in range(len(bz)): if ( iface < jface and any(np.all(line[0] == x) for x in bz[jface]) and any(np.all(line[1] == x) for x in bz[jface]) ): #ax.plot(*zip(line[0], line[1]), **kwargs) fig.add_trace(go_line(line[0], line[1], showlegend=False, **kwargs)) return fig def plotly_lattice_vectors(lattice, fig=None, **kwargs): """ Adds the basis vectors of the lattice provided to a matplotlib Axes Args: lattice: Lattice object fig: plotly figure or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to green and linewidth to 3. Returns: matplotlib figure and matplotlib ax """ #ax, fig, plt = get_ax3d_fig_plt(ax) fig, go = get_fig_plotly(fig=fig) #, **fig_kw) if "line_color" not in kwargs: kwargs["line_color"] = "green" if "line_width" not in kwargs: kwargs["line_width"] = 3 if "showlegend" not in kwargs: kwargs["showlegend"] = False vertex1 = lattice.get_cartesian_coords([0.0, 0.0, 0.0]) vertex2 = lattice.get_cartesian_coords([1.0, 0.0, 0.0]) #ax.plot(*zip(vertex1, vertex2), **kwargs) fig.add_trace(go_line(vertex1, vertex2, name="a", **kwargs)) vertex2 = lattice.get_cartesian_coords([0.0, 1.0, 0.0]) #ax.plot(*zip(vertex1, vertex2), **kwargs) fig.add_trace(go_line(vertex1, vertex2, name="b", **kwargs)) vertex2 = lattice.get_cartesian_coords([0.0, 0.0, 1.0]) #ax.plot(*zip(vertex1, vertex2), **kwargs) fig.add_trace(go_line(vertex1, vertex2, name="c", **kwargs)) return fig def plotly_path(line, lattice=None, coords_are_cartesian=False, fig=None, **kwargs): """ Adds a line passing through the coordinates listed in 'line' to a matplotlib Axes Args: line: list of coordinates. lattice: Lattice object used to convert from reciprocal to cartesian coordinates coords_are_cartesian: Set to True if you are providing coordinates in cartesian coordinates. Defaults to False. Requires lattice if False. fig: plotly figure or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to red and linewidth to 3. Returns: matplotlib figure and matplotlib ax """ #ax, fig, plt = get_ax3d_fig_plt(ax) fig, go = get_fig_plotly(fig=fig) #, **fig_kw) if "line_color" not in kwargs: kwargs["line_color"] = "red" if "line_width" not in kwargs: kwargs["line_width"] = 3 for k in range(1, len(line)): vertex1 = line[k - 1] vertex2 = line[k] if not coords_are_cartesian: if lattice is None: raise ValueError("coords_are_cartesian False requires the lattice") vertex1 = lattice.get_cartesian_coords(vertex1) vertex2 = lattice.get_cartesian_coords(vertex2) #ax.plot(*zip(vertex1, vertex2), **kwargs) fig.add_trace(go_line(vertex1, vertex2, showlegend=False, **kwargs)) return fig def plotly_labels(labels, lattice=None, coords_are_cartesian=False, ax=None, **kwargs): """ Adds labels to a matplotlib Axes Args: labels: dict containing the label as a key and the coordinates as value. lattice: Lattice object used to convert from reciprocal to cartesian coordinates coords_are_cartesian: Set to True if you are providing. coordinates in cartesian coordinates. Defaults to False. Requires lattice if False. ax: matplotlib :class:`Axes` or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'text'. Color defaults to blue and size to 25. Returns: matplotlib figure and matplotlib ax """ ax, fig, plt = get_ax3d_fig_plt(ax) if "color" not in kwargs: kwargs["color"] = "b" if "size" not in kwargs: kwargs["size"] = 25 for k, coords in labels.items(): label = k if k.startswith("\\") or k.find("_") != -1: label = "$" + k + "$" off = 0.01 if coords_are_cartesian: coords = np.array(coords) else: if lattice is None: raise ValueError("coords_are_cartesian False requires the lattice") coords = lattice.get_cartesian_coords(coords) ax.text(*(coords + off), s=label, **kwargs) return fig, ax def plotly_points(points, lattice=None, coords_are_cartesian=False, fold=False, labels=None, fig=None, **kwargs): """ Adds points to a matplotlib Axes Args: points: list of coordinates lattice: Lattice object used to convert from reciprocal to cartesian coordinates coords_are_cartesian: Set to True if you are providing coordinates in cartesian coordinates. Defaults to False. Requires lattice if False. fold: whether the points should be folded inside the first Brillouin Zone. Defaults to False. Requires lattice if True. fig: plotly figure or None if a new figure should be created. kwargs: kwargs passed to the matplotlib function 'scatter'. Color defaults to blue Returns: matplotlib figure and matplotlib ax """ #ax, fig, plt = get_ax3d_fig_plt(ax) fig, go = get_fig_plotly(fig=fig) #, **fig_kw) if "marker_color" not in kwargs: kwargs["marker_color"] = "blue" if (not coords_are_cartesian or fold) and lattice is None: raise ValueError("coords_are_cartesian False or fold True require the lattice") from pymatgen.electronic_structure.plotter import fold_point vecs = [] for p in points: if fold: p = fold_point(p, lattice, coords_are_cartesian=coords_are_cartesian) elif not coords_are_cartesian: p = lattice.get_cartesian_coords(p) vecs.append(p) #ax.scatter(*p, **kwargs) kws = dict(textposition="top right", showlegend=False) #, textfont=dict(color='#E58606')) kws.update(kwargs) fig.add_trace(go_points(vecs, labels=labels, **kws)) return fig @add_plotly_fig_kwargs def plotly_brillouin_zone_from_kpath(kpath, fig=None, **kwargs): """ Gives the plot (as a matplotlib object) of the symmetry line path in the Brillouin Zone. Args: kpath (HighSymmKpath): a HighSymmKPath object ax: matplotlib :class:`Axes` or None if a new figure should be created. **kwargs: provided by add_fig_kwargs decorator Returns: plotly figure. """ lines = [[kpath.kpath["kpoints"][k] for k in p] for p in kpath.kpath["path"]] return plotly_brillouin_zone( bz_lattice=kpath.prim_rec, lines=lines, fig=fig, labels=kpath.kpath["kpoints"], **kwargs, ) @add_plotly_fig_kwargs def plotly_brillouin_zone( bz_lattice, lines=None, labels=None, kpoints=None, fold=False, coords_are_cartesian=False, fig=None, **kwargs, ): """ Plots a 3D representation of the Brillouin zone of the structure. Can add to the plot paths, labels and kpoints Args: bz_lattice: Lattice object of the Brillouin zone lines: list of lists of coordinates. Each list represent a different path labels: dict containing the label as a key and the coordinates as value. kpoints: list of coordinates fold: whether the points should be folded inside the first Brillouin Zone. Defaults to False. Requires lattice if True. coords_are_cartesian: Set to True if you are providing coordinates in cartesian coordinates. Defaults to False. ax: matplotlib :class:`Axes` or None if a new figure should be created. kwargs: provided by add_fig_kwargs decorator Returns: plotly figure """ fig = plotly_lattice_vectors(bz_lattice, fig=fig) plotly_wigner_seitz(bz_lattice, fig=fig) if lines is not None: for line in lines: plotly_path(line, bz_lattice, coords_are_cartesian=coords_are_cartesian, fig=fig) if labels is not None: # TODO #plotly_labels(labels, bz_lattice, coords_are_cartesian=coords_are_cartesian, ax=ax) plotly_points( labels.values(), lattice=bz_lattice, coords_are_cartesian=coords_are_cartesian, fold=False, labels=list(labels.keys()), fig=fig, ) if kpoints is not None: plotly_points( kpoints, lattice=bz_lattice, coords_are_cartesian=coords_are_cartesian, fold=fold, fig=fig, ) #ax.set_xlim3d(-1, 1) #ax.set_ylim3d(-1, 1) #ax.set_zlim3d(-1, 1) # ax.set_aspect('equal') #ax.axis("off") return fig