"""
Tools to analyze MD trajectories and compute diffusion coefficients.
"""
from __future__ import annotations
import dataclasses
import json
import os
import warnings
from functools import cached_property
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib.offsetbox import AnchoredText
from monty.bisect import find_le
from monty.collections import AttrDict # , dict2namedtuple
from monty.string import list_strings, marquee
from monty.termcolor import cprint
from pymatgen.core import units
from pymatgen.core.composition import Composition
from pymatgen.core.lattice import Lattice
from pymatgen.util.string import latexify
from scipy import optimize
from scipy.stats import linregress
# from abipy.core.mixins import TextFile # , NotebookWriter
from abipy.core.structure import Structure
from abipy.dynamics.cpx import EvpFile, parse_file_with_blocks
from abipy.tools.context_managers import Timer
from abipy.tools.iotools import try_files # , file_with_ext_indir
from abipy.tools.parallel import pool_nprocs_pmode # , get_max_nprocs
from abipy.tools.plotting import (
add_fig_kwargs,
get_ax_fig_plt,
get_axarray_fig_plt,
get_color_symbol,
set_axlims,
set_logscale,
)
from abipy.tools.serialization import HasPickleIO
from abipy.tools.typing import Figure, PathLike
__author__ = "Giuliana Materzanini, Tommaso Chiarotti, Matteo Giantomassi"
# Conversion factor from Ang^2/ps to cm^2/s (1e-16 cm^2 / 1e-12 s).
# Used below to convert the slope of MSD(t) into a diffusion coefficient in cm^2/s.
Ang2PsTocm2S = 0.0001
# NOTE(review): CODATA elementary charge is 1.602176634e-19 C; confirm the 1.602188 prefactor.
e2s = 1.602188**2 # electron charge in Coulomb scaled by 10.d-19**2
kbs = 1.38066 # Boltzmann constant in Joule/K scaled by 10.d-23
# Presumably the Boltzmann constant in eV/K (the value matches 8.617333e-5 eV/K).
kBoltzEv = 8.617333e-05
# nCar = 56 # FIXME: Hardcoded
[docs]
def common_oxidation_states() -> dict:
    """
    Return a dictionary mapping chemical symbol --> oxidation state.

    Only two groups of symbols are covered: those tabulated with state +1
    (Li, Na, K, Rb, Cs, Fr) and those tabulated with state +2 (Be, Mg, Ca, Sr, Ba, Ra).
    """
    plus_one = ("Li", "Na", "K", "Rb", "Cs", "Fr")
    plus_two = ("Be", "Mg", "Ca", "Sr", "Ba", "Ra")

    # Insertion order is kept: first the +1 symbols, then the +2 ones.
    symb2oxi = dict.fromkeys(plus_one, +1)
    symb2oxi.update(dict.fromkeys(plus_two, +2))
    return symb2oxi
[docs]
def read_structure_postac_ucmats(
    traj_filepath: PathLike, step_skip: int
) -> tuple[Structure, np.ndarray, np.ndarray, int]:
    """
    Read an ASE trajectory file and keep one configuration every `step_skip` frames.

    Args:
        traj_filepath: File path.
        step_skip: Sampling frequency.
            time_step should be multiplied by this number to get the real time between measurements.

    Returns:
        tuple with: Structure built from the first frame,
        array with the Cartesian positions of the sampled frames (one entry per sampled frame),
        array with the cell vectors of the sampled frames,
        total number of frames found in the file (before sampling).
    """
    from ase.io import read

    frames = read(traj_filepath, index=":")
    traj_len = len(frames)
    structure = Structure.as_structure(frames[0])

    # Keep only the sampled frames, then release the full trajectory.
    sampled = [frames[idx] for idx in range(0, traj_len, step_skip)]
    del frames

    pos_tac = np.array([atoms.positions for atoms in sampled], dtype=float)
    ucmats = np.array([atoms.cell.array for atoms in sampled], dtype=float)
    return structure, pos_tac, ucmats, traj_len
[docs]
class MdAnalyzer(HasPickleIO):
"""
High-level interface to read MD trajectories and metadata from external files,
compute the MSQD and plot the results.
"""
[docs]
@classmethod
def from_abiml_dir(cls, directory: PathLike, step_skip: int = 1) -> MdAnalyzer:
    """
    Build an instance from a directory containing an ASE trajectory file (md.traj),
    a JSON file with the MD parameters (md.json) and an ASE log file (md.aselog or md.log),
    as produced by the `abiml.py md` script.

    Args:
        directory: Directory with the files.
        step_skip: Sampling frequency applied to both the trajectory and the time mesh.
    """
    workdir = Path(str(directory))
    structure, pos_tac, ucmats, traj_len = read_structure_postac_ucmats(workdir / "md.traj", step_skip)

    # Metadata of the MD run.
    with open(workdir / "md.json") as fh:
        meta = json.load(fh)

    temperature = meta["temperature"]
    # ASE timestep is in fs, times are stored in ps.
    timestep_ps = meta["timestep"] * 1e-3
    loginterval = meta["loginterval"]
    engine = meta["nn_name"]

    full_mesh = np.arange(0, traj_len) * timestep_ps * loginterval
    times = full_mesh[::step_skip].copy()

    log_path = try_files([workdir / "md.aselog", workdir / "md.log"])
    from abipy.ml.aseml import AseMdLog

    with AseMdLog(log_path) as log:
        evp_df = log.df.copy()

    return cls(structure, temperature, times, pos_tac, ucmats, engine, evp_df=evp_df)
[docs]
@classmethod
def from_hist_file(cls, hist_filepath: PathLike, step_skip: int = 1) -> MdAnalyzer:
    """
    Build an instance from an ABINIT HIST.nc file.

    NOTE: this constructor is not implemented yet. It reads the Cartesian positions and
    the lattice vectors from the file (converted from Bohr to Ang), applies `step_skip`
    and then raises NotImplementedError.
    """
    from abipy.dynamics.hist import HistFile

    with HistFile(hist_filepath) as hist:
        structure = hist.structure.copy()
        # hist.r.read_dimvalue("time")
        reader = hist.r
        pos_tac = reader.read_value("xcart") * units.bohr_to_ang
        ucmats = reader.read_value("rprimd") * units.bohr_to_ang

    # Still missing: temperature, evp_df and the time mesh, e.g.
    # times = (np.arange(0, traj_len) * r.timestep * r.loginterval)[::step_skip].copy()
    if step_skip != 1:
        ucmats = ucmats[::step_skip].copy()
        pos_tac = pos_tac[::step_skip].copy()

    raise NotImplementedError
    # return cls(structure, temperature, times, pos_tac, ucmats, "abinit", evp_df=evp_df)
[docs]
@classmethod
def from_vaspruns(cls, filepaths: list) -> MdAnalyzer:
    """
    Build an instance from a list of Vasprun files (must be ordered in sequence of MD simulation).

    The temperature is taken from TEEND and the timestep from POTIM (fs) of the first run.
    Times are converted to ps, the unit expected by the constructor.

    Args:
        filepaths: List of paths to vasprun.xml files (a single string is also accepted).

    Raises:
        ValueError: if no file or no ionic step is found, or if the initial structure
            of a run does not match the final structure of the previous one.
    """
    from pymatgen.io.vasp.outputs import Vasprun
    from pymatgen.util.coord import pbc_diff

    def get_structures(vaspruns):
        # This piece of code is adapted from
        # https://github.com/materialsvirtuallab/pymatgen-analysis-diffusion/blob/master/pymatgen/analysis/diffusion/analyzer.py
        # The first item yielded is the tuple (step_skip, temperature, timestep),
        # all the following items are the structures of the ionic steps.
        for i, vr in enumerate(vaspruns):
            if i == 0:
                step_skip = vr.ionic_step_skip or 1
                final_structure = vr.initial_structure
                temperature = vr.parameters["TEEND"]
                timestep = vr.parameters["POTIM"]  # fs
                yield step_skip, temperature, timestep

            # Check that the runs are continuous.
            # pbc_diff can be negative, so compare the absolute value with the tolerance.
            fdist = pbc_diff(vr.initial_structure.frac_coords, final_structure.frac_coords)
            if np.any(np.abs(fdist) > 0.001):
                raise ValueError("initial and final structures do not match.")
            final_structure = vr.final_structure

            assert (vr.ionic_step_skip or 1) == step_skip
            for s in vr.ionic_steps:
                yield s["structure"]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        vaspruns = [Vasprun(path) for path in list_strings(filepaths)]

    structures = get_structures(vaspruns)
    try:
        step_skip, temperature, timestep = next(structures)
    except StopIteration:
        raise ValueError("Empty list of vasprun files") from None

    # Extract Cartesian positions and lattice vectors for each ionic step.
    structure, pos_tac, ucmats = None, [], []
    for strc in structures:
        if structure is None:
            structure = strc
        pos_tac.append(strc.cart_coords)
        ucmats.append(strc.lattice.matrix)

    if structure is None:
        raise ValueError("Cannot find ionic steps in the vasprun files")

    # Same conversion used in this module for the first frame of ASE trajectories.
    structure = Structure.as_structure(structure)

    nsteps, natom = len(pos_tac), len(structure)
    pos_tac = np.reshape(pos_tac, (nsteps, natom, 3))
    ucmats = np.reshape(ucmats, (nsteps, 3, 3))

    # POTIM is in fs: convert to ps.
    times = np.arange(0, nsteps) * timestep * step_skip * 1e-3

    return cls(structure, temperature, times, pos_tac, ucmats, "vasp", evp_df=None)
# @classmethod
# def from_qe_dir(cls, directory: PathLike, step_skip: int=1):
# traj_filepath = file_with_ext_indir(".lammpstrj", directory)
# return cls.from_qe_input(filepath: PathLike, step_skip=step_skip):
[docs]
@classmethod
def from_lammps_dir(cls, directory: PathLike, step_skip: int = 1, basename="in.lammps") -> MdAnalyzer:
    """
    Build an instance from a directory containing a LAMMPS input file.

    Args:
        directory: Directory with the LAMMPS input file.
        step_skip: Sampling frequency, forwarded as is.
        basename: Name of the LAMMPS input file inside `directory`.
    """
    input_path = Path(str(directory)) / basename
    # NOTE(review): no `from_lammp_input` classmethod is visible in this class.
    # Confirm it is defined elsewhere, otherwise this call raises AttributeError.
    return cls.from_lammp_input(input_path, step_skip=step_skip)
[docs]
@classmethod
def from_lammpstrj(cls, traj_filepath: PathLike, input_filepath: PathLike, step_skip: int = 1) -> MdAnalyzer:
    """
    Build an instance from a LAMMPS trajectory file and a file providing the time mesh.

    Args:
        traj_filepath: Path to the trajectory file.
        input_filepath: Path to the file with the time mesh.
            Only CP EVP files (extension ".evp") are supported at present.
        step_skip: Sampling frequency applied to both the trajectory and the time mesh.

    Raises:
        NotImplementedError: if `input_filepath` does not end with ".evp".
    """
    structure, pos_tac, ucmats, traj_len = read_structure_postac_ucmats(traj_filepath, step_skip)

    # PathLike can be a pathlib.Path that has no endswith method: work with the string form.
    input_filepath = str(input_filepath)
    if not input_filepath.endswith(".evp"):
        raise NotImplementedError(f"Don't know how to extract the time mesh from: {input_filepath}")

    # Extract times from CP EVP file (Giuliana's way)
    temperature = None
    with EvpFile(input_filepath) as evp:
        evp_df = evp.df.copy()
        timestep = evp.times[1] - evp.times[0]
        loginterval = 1

    times = (np.arange(0, traj_len) * timestep * loginterval)[::step_skip].copy()

    return cls(structure, temperature, times, pos_tac, ucmats, "lammps", evp_df=evp_df)
def __init__(
    self,
    structure: Structure,
    temperature: float,
    times: np.ndarray,
    cart_positions: np.ndarray,
    ucmats: np.ndarray,
    engine: str,
    pos_order: str = "tac",
    evp_df: pd.DataFrame | None = None,
):
    """
    Args:
        structure: Structure object (first geometry of the MD run).
        temperature: Temperature in Kelvin.
        times: Array with times in ps units.
        cart_positions: Cartesian positions in Ang. Default shape: (nt, natom, 3).
        ucmats: Array of lattice matrix of every step. Used for NPT.
            For NVT-AIMD, the lattice at each time step is set to the lattice in the "structure" argument.
        engine: String defining the engine used to produce the MD trajectory.
        pos_order: "tac" if cart_positions has shape (nt, natom, 3).
            "atc" if cart_positions has shape (natom, nt, 3).
        evp_df: Optional pandas DataFrame stored as is. It is sliced along the rows
            when the trajectory is resampled. None if not available.

    Raises:
        ValueError: if pos_order is invalid or the internal consistency check fails.
    """
    self.structure = structure
    self.times = times
    self.engine = engine

    # Positions are stored internally with shape (natom, nt, 3).
    if pos_order == "tac":
        self.pos_atc = cart_positions.transpose(1, 0, 2).copy()
    elif pos_order == "atc":
        self.pos_atc = cart_positions
    else:
        raise ValueError(f"Invalid {pos_order=}")

    self.evp_df = evp_df

    # self.lattices is None if the cell does not change along the trajectory.
    # The comparison is done on the whole array: np.all on a generator is always True.
    ucmats = np.asarray(ucmats, dtype=float)
    if ucmats.ndim < 3 or len(ucmats) == 0 or np.all(ucmats == ucmats[0]):
        self.lattices = None
    else:
        # Fill an object array item by item so that numpy does not try to unpack the Lattice objects.
        lattices = np.empty(len(ucmats), dtype=object)
        for it, mat in enumerate(ucmats):
            lattices[it] = Lattice(mat)
        self.lattices = lattices

    self.latex_formula = self.structure.latex_formula
    self.temperature = temperature
    self.set_color_symbol("VESTA")
    self.verbose = 0

    self.consistency_check()
[docs]
def consistency_check(self) -> None:
"""
Perform internal consistency check.
"""
if self.pos_atc.shape != (self.natom, self.nt, 3):
raise ValueError(f"Invalid shape {self.pos_atc.shape=}, expecting: {(self.natom, self.nt, 3)}")
if len(self.times) != self.nt:
raise ValueError(f"{len(self.times)=} != {self.nt=}")
# Check times mesh.
ierr = 0
for it in range(self.nt - 1):
dt = self.times[it + 1] - self.times[it]
if abs(dt - self.timestep) > 1e-3:
ierr += 1
if ierr < 10:
print(f"{dt=} != {self.timestep=}")
if ierr:
raise ValueError(f"Time-mesh is not linear. There are {ierr} points with wrong timestep")
if self.lattices is not None and len(self.lattices) != self.nt:
raise ValueError(f"{len(self.lattices)=} != {self.nt=}")
[docs]
def get_params_dict(self) -> dict:
"""Dictionary with the most important parameters."""
attr_names = [
"latex_formula",
"temperature",
"timestep",
"nt",
"max_time",
"natom",
"avg_volume",
"engine",
]
d = {aname: getattr(self, aname) for aname in attr_names}
return d
[docs]
def deepcopy(self) -> MdAnalyzer:
"""Deep copy of the object."""
import copy
return copy.deepcopy(self)
[docs]
def iter_structures(self):
    """Generate one Structure for each point of the trajectory."""
    species = self.structure.species
    frames = self.pos_atc.transpose(1, 0, 2).copy()

    if self.lattices is not None:
        # One lattice per frame.
        for coords, lattice in zip(frames, self.lattices):
            yield Structure(lattice.copy(), species, coords, coords_are_cartesian=True)
        return

    # Same lattice for all the frames.
    fixed_lattice = self.structure.lattice
    for coords in frames:
        yield Structure(fixed_lattice, species, coords, coords_are_cartesian=True)
# def iter_atoms(self):
# """Generate ASE atoms"""
# Atoms(symbols=None,
# positions=None, numbers=None,
# tags=None, momenta=None, masses=None,
# magmoms=None, charges=None,
# scaled_positions=None,
# cell=None, pbc=None, celldisp=None,
# constraint=None,
# calculator=None,
# info=None,
# velocities=None)
[docs]
def resample_step(self, start_at_step: int, take_every: int) -> MdAnalyzer:
"""
Resample the trajectory. Start at iteration start_at_step and increase
the timestep by taking every `take_every` iteration.
"""
return self.resample_time(start_time=start_at_step * self.timestep, new_timestep=take_every * self.timestep)
[docs]
def resample_time(self, start_time: float, new_timestep: float) -> MdAnalyzer:
"""
Resample the trajectory. Start at time `start_time` and use new timestep `new_timestep`.
"""
# TODO: This is not true anymore!
# NB: Cannot change the object in place as SigmaBerend and DiffusionData keep a reference to self.
new = self.deepcopy()
old_timestep = new.times[1] - new.times[0]
if not (new.times[-1] > start_time > new.times[0]):
raise ValueError(f"Invalid start_time should be between {new.times[0]} and {new.times[-1]})")
it0 = int(start_time / old_timestep)
if it0 != 0:
new.pos_atc = new.pos_atc[:, it0:, :]
new.times = new.times[it0:] - new.times[it0]
if new.lattices is not None:
new.lattices = new.lattices[it0:]
if self.evp_df is not None:
self.evp_df = self.evp_df.iloc[it0:]
if new_timestep < old_timestep:
raise ValueError(f"Invalid {new_timestep=} should be >= {old_timestep}")
istep = int(new_timestep / old_timestep)
if istep != 1:
new.pos_atc = new.pos_atc[:, ::istep, :].copy()
new.times = new.times[::istep] - new.times[0]
if new.lattices is not None:
new.lattices = new.lattices[::istep].copy()
if self.evp_df is not None:
self.evp_df = self.evp_df.iloc[::istep]
new.consistency_check()
return new
@property
def timestep(self) -> float:
"""Timestep in ps."""
return self.times[1] - self.times[0]
@property
def max_time(self) -> float:
"""Maximum simulation time in ps."""
return self.times[-1]
@property
def nt(self) -> int:
"""Number of points in the MD trajectory."""
return self.pos_atc.shape[1]
[docs]
@cached_property
def natom(self) -> int:
"""Number of atoms."""
return len(self.structure)
@property
def temperature(self) -> float:
"""Temperature in Kelvin."""
return self._temperature
@temperature.setter
def temperature(self, value):
"""Set temperature in Kelvin."""
self._temperature = value
@property
def verbose(self) -> int:
"""Verbosity level."""
return self._verbose
@verbose.setter
def verbose(self, value: int):
"""Set temperature in Kelvin."""
self._verbose = value
@property
def engine(self) -> str:
"""The engine used to produce the MD trajectory."""
return self._engine
@engine.setter
def engine(self, value):
"""Set engine string."""
self._engine = value
@property
def formula(self) -> str:
"""Returns the formula as a string."""
return self.structure.formula
@property
def latex_formula(self) -> str:
"""LaTeX formatted formula."""
return self._latex_formula
@latex_formula.setter
def latex_formula(self, value):
"""LaTeX formatted formula. E.g., Fe2O3 is transformed to Fe$_{2}$O$_{3}$."""
self._latex_formula = latexify(value)
@property
def latex_formula_n_temp(self) -> str:
"""LaTeX formatted formula and temperature."""
return f"{self.latex_formula}\nT = {self.temperature} K"
@property
def latex_avg_volume(self) -> str:
"""LaTeX formatted average volume."""
return r"V$_{\mathrm{ave}}$ = " + f"{self.avg_volume:.2f}" + r"$\mathrm{{\AA}^3}$"
@property
def avg_volume(self) -> float:
"""Average unit cell volume in Ang^3."""
if self.lattices is None:
return self.structure.lattice.volume
return np.mean([lat.volume for lat in self.lattices])
[docs]
def set_color_symbol(self, dict_or_string: dict | str) -> None:
"""
Set the dictionary mapping chemical_symbol --> color
used in the matplotlib plots.
Args:
dict_or_string: "VESTA", "Jmol"
"""
if isinstance(dict_or_string, dict):
self.color_symbol = dict_or_string
else:
self.color_symbol = get_color_symbol(style=dict_or_string)
for symbol in self.structure.symbol_set:
if symbol not in self.color_symbol:
raise KeyError(f"Cannot find {symbol=} in color_symbol dictionary!")
[docs]
def get_it_ts(self, t0: float) -> tuple[int, np.ndarray]:
    """
    Locate time t0 in self.times with find_le.

    Returns:
        tuple with the index found and the time values from that index on, shifted to start at zero.

    Raises:
        ValueError: if t0 is outside the time mesh.
    """
    outside = t0 < self.times[0] or t0 > self.times[-1]
    if outside:
        raise ValueError(f"Invalid {t0=}. It should be between {self.times[0]} and {self.times[-1]}")

    start = find_le(self.times, t0)
    shifted_times = self.times[start:] - self.times[start]
    return start, shifted_times
def __str__(self) -> str:
return self.to_string()
[docs]
def to_string(self, verbose: int = 0) -> str:
    """
    String representation with verbosity level verbose.

    Shows the MD parameters. If verbose is non-zero, the summary of the structure is added.
    """
    sections = [
        marquee("MD PARAMETERS", mark="="),
        pd.Series(self.get_params_dict()).to_string(),
    ]
    if verbose:
        sections.append(self.structure.spget_summary(verbose=verbose))
    sections.append("\n")

    return "\n".join(sections)
[docs]
def iatoms_with_symbol(self, symbol: str, atom_inds=None) -> np.ndarray:
"""
Array with the index of the atoms with the given chemical symbol.
If atom_inds is not None, filter sites accordingly.
"""
iatoms = [iat for iat in range(len(self.structure)) if self.structure[iat].specie.symbol == symbol]
if atom_inds is not None:
iatoms = [iat for iat in iatoms if iat in atom_inds]
if not iatoms:
raise ValueError(f"Empty list of iatoms indices for {symbol=} and {atom_inds=}")
return np.array(iatoms)
def _select_symbols(self, symbols) -> list[str]:
if symbols == "all":
return sorted(self.structure.symbol_set)
return list_strings(symbols)
[docs]
def get_sqdt_iatom(self, iatom: int, it0: int = 0) -> np.array:
"""
Compute the square displacement vs time for a given atomic index
starting from time index it0.
"""
return ((self.pos_atc[iatom, it0:] - self.pos_atc[iatom, it0]) ** 2).sum(axis=1)
[docs]
def get_sqdt_symbol(self, symbol: str, it0: int = 0, atom_inds=None) -> np.array:
"""
Compute the square displacement vs time averaged over atoms with the same chemical symbol
starting from time index it0. atoms_inds adds an additional filter on the site index.
"""
for count, iatom in enumerate(self.iatoms_with_symbol(symbol, atom_inds=atom_inds)):
if count == 0:
sqdt = self.get_sqdt_iatom(iatom, it0=it0)
else:
sqdt += self.get_sqdt_iatom(iatom, it0=it0)
sqdt /= count + 1
return sqdt
[docs]
def get_dw_symbol(self, symbol, t0: float = 0.0, tmax=None, atom_inds=None):
    """
    Compute diffusion coefficient by performing a naive linear regression of the raw MSD
    in the time interval [t0, tmax].

    Args:
        symbol: Chemical symbol.
        t0: Initial time in ps.
        tmax: Final time in ps (same origin as self.times). None to use all the points after t0.
        atom_inds: List of atom indices to include. None to disable filtering.

    Returns:
        AttrDict with naive_d (diffusion coefficient in cm^2/s), ts (times used in the fit,
        measured from t0), fit (result of linregress) and label (string for plots).

    Raises:
        ValueError: if t0 or tmax are outside the time mesh or tmax is not greater than t0.
    """
    it0, ts = self.get_it_ts(t0)
    sqdt = self.get_sqdt_symbol(symbol, it0=it0, atom_inds=atom_inds)

    if tmax is not None:
        it1, _ = self.get_it_ts(tmax)
        if it1 <= it0:
            raise ValueError(f"Invalid {tmax=}: it should be greater than {t0=}")
        # it1 is an index in self.times whereas ts and sqdt start at it0.
        npts = it1 - it0 + 1
        ts, sqdt = ts[:npts], sqdt[:npts]

    fit = linregress(ts, sqdt)
    naive_d = fit.slope * Ang2PsTocm2S / 6
    label = rf"Linear fit: D={naive_d:.2E} cm$^2$/s, $r^2$={fit.rvalue**2:.2f}"

    return AttrDict(naive_d=naive_d, ts=ts, fit=fit, label=label)
[docs]
def get_msdtt0_symbol_tmax(self, symbol: str, tmax: float, atom_inds=None, nprocs=None) -> Msdtt0:
    r"""
    Calculates the MSD for every possible pair of time points using the formula:

    $$MSD(t,t_0) = \frac{1}{N} \sum_{i=1}^{N} (\vec{r}_i(t+t_0) - \vec{r}_i(t_0))^2$$

    where $N$ is the number of particles with the given symbol, and $\vec{r}_i(t)$ is the position vector.

    Args:
        symbol: Chemical symbol.
        tmax: Max time in ps. Converted to a time index with get_it_ts.
        atom_inds: List of atom indices to include. None to disable filtering.
        nprocs: Number of procs to use. Passed to msd_tt0_from_tac.
    """
    index_tmax, _ = self.get_it_ts(tmax)

    # Positions of the selected atoms with shape (nt, natom_selected, 3).
    selected = self.iatoms_with_symbol(symbol, atom_inds=atom_inds)
    pos_tac = self.pos_atc[selected].transpose(1, 0, 2).copy()

    return Msdtt0(
        arr_tt0=msd_tt0_from_tac(pos_tac, index_tmax, nprocs=nprocs),
        mda=self,
        index_tmax=index_tmax,
        symbol=symbol,
    )
[docs]
@add_fig_kwargs
def plot_sqdt_atoms(
    self, symbols="all", t0: float = 0.0, atom_inds=None, ax=None, xy_log=None, fontsize=8, xlims=None, **kwargs
) -> Figure:
    """
    Plot the square displacement of each selected atom as a function of time.

    Args:
        symbols: List of chemical symbols to consider.
        t0: Initial time in ps.
        atom_inds: List of atom indices to include. None to disable filtering.
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xy_log: None or empty string for linear scale. "x" for log scale on x-axis.
            "xy" for log scale on x- and y-axis. "x:semilog" for semilog scale on x-axis.
        fontsize: fontsize for legends and titles.
        xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
            or scalar e.g. ``left``. If left (right) is None, default values are used.
    """
    it0, ts = self.get_it_ts(t0)
    ax, fig, plt = get_ax_fig_plt(ax=ax)

    # One curve per atom.
    for symbol in self._select_symbols(symbols):
        for iatom in self.iatoms_with_symbol(symbol, atom_inds=atom_inds):
            ax.plot(ts, self.get_sqdt_iatom(iatom, it0=it0), label=f"${symbol}_{{{iatom}}}$")

    ax.legend(fontsize=fontsize, loc="upper left")
    ax.set_xlabel("t (ps)", fontsize=fontsize)
    ax.set_ylabel(r"square displacement ($\mathrm{{\AA}^2}$)", fontsize=fontsize)
    set_axlims(ax, xlims, "x")
    # set_ticks_fontsize(ax, fontsize)
    set_logscale(ax, xy_log)

    # Box with formula, temperature, volume and initial time.
    header = f"{self.latex_formula_n_temp}\n{self.latex_avg_volume}\n" "sd(t, $t_0$ ="
    box_text = header + str(int(self.times[it0])) + " ps)"
    ax.add_artist(AnchoredText(box_text, loc="upper right", prop=dict(size=fontsize)))

    return fig
[docs]
@add_fig_kwargs
def plot_sqdt_symbols(
    self,
    symbols,
    t0: float = 0.0,
    atom_inds=None,
    with_dw=0,
    ax=None,
    xy_log=None,
    fontsize=8,
    xlims=None,
    **kwargs,
) -> Figure:
    """
    Plot the square displacement averaged over all atoms of the same specie vs time.

    Args:
        symbols: List of chemical symbols to consider. "all" for all symbols in structure.
        t0: Initial time in ps.
        atom_inds: List of atom indices to include. None to disable filtering.
        with_dw: If != 0 compute diffusion coefficient via least-squares fit in the time-interval [t0, with_dw].
            If with_dw < 0 e.g. -1 the time-interval is set to [t0, tmax].
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xy_log: None or empty string for linear scale. "x" for log scale on x-axis.
            "xy" for log scale on x- and y-axis. "x:semilog" for semilog scale on x-axis.
        fontsize: fontsize for legends and titles.
        xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
            or scalar e.g. ``left``. If left (right) is None, default values are used.
    """
    it0, ts = self.get_it_ts(t0)
    ax, fig, plt = get_ax_fig_plt(ax=ax)

    for symbol in self._select_symbols(symbols):
        ax.plot(
            ts,
            self.get_sqdt_symbol(symbol, it0=it0, atom_inds=atom_inds),
            label=symbol + " msd($t, t_0$ = " + str(t0) + " ps)",
            color=self.color_symbol[symbol],
        )

        if with_dw != 0:
            # A negative with_dw means: fit up to the end of the trajectory.
            # Pass the resolved tmax, not with_dw, else negative values are rejected as invalid times.
            tmax = self.times[-1] if with_dw < 0 else with_dw
            dw = self.get_dw_symbol(symbol, t0=t0, tmax=tmax, atom_inds=atom_inds)
            ax.plot(
                dw.ts,
                dw.fit.slope * dw.ts + dw.fit.intercept,
                label=symbol + " " + dw.label,
                color=self.color_symbol[symbol],
            )

    ax.legend(fontsize=fontsize, loc="upper left")
    ax.set_xlabel("t (ps)", fontsize=fontsize)
    ax.set_ylabel(r"mean square displacement ($\mathrm{{\AA}^2}$)", fontsize=fontsize)
    set_axlims(ax, xlims, "x")
    # set_ticks_fontsize(ax, fontsize)
    set_logscale(ax, xy_log)
    ax.add_artist(
        AnchoredText(
            f"{self.latex_formula_n_temp}\n{self.latex_avg_volume}", loc="upper right", prop=dict(size=fontsize)
        )
    )

    return fig
[docs]
@add_fig_kwargs
def plot_sqdt_symbols_tmax(
    self, symbols, tmax: float, atom_inds=None, nprocs=None, ax=None, xy_log=None, fontsize=8, xlims=None, **kwargs
) -> Figure:
    """
    Plot MSD(t, t0) averaged over the initial time t0 for each chemical symbol.

    Args:
        symbols: List of chemical symbols to consider. "all" for all symbols in structure.
        tmax: Max time in ps.
        atom_inds: List of atom indices to include. None to disable filtering.
        nprocs: Number of procs to use.
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xy_log: None or empty string for linear scale. "x" for log scale on x-axis.
            "xy" for log scale on x- and y-axis. "x:semilog" for semilog scale on x-axis.
        fontsize: fontsize for legends and titles
        xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
            or scalar e.g. ``left``. If left (right) is None, default values are used.
    """
    index_tmax, _ = self.get_it_ts(tmax)
    ax, fig, plt = get_ax_fig_plt(ax=ax)

    for symbol in self._select_symbols(symbols):
        selected = self.iatoms_with_symbol(symbol, atom_inds=atom_inds)
        pos_tac = self.pos_atc[selected].transpose(1, 0, 2).copy()

        # Average MSD(t, t0) over the second axis.
        avg_msd = np.mean(msd_tt0_from_tac(pos_tac, index_tmax, nprocs=nprocs), axis=1)

        t_start = self.nt - index_tmax
        ts = self.times[t_start:] - self.times[t_start]
        curve_label = (
            symbol + r" <msd($t$, $t_0$)>$\{$t_0$\}$, $t$ = [0, " + str(int(self.times[index_tmax])) + " ps]"
        )
        ax.plot(ts, avg_msd, label=curve_label, color=self.color_symbol[symbol])

    set_axlims(ax, xlims, "x")
    ax.legend(fontsize=fontsize, loc="upper left")
    ax.set_xlabel("t (ps)", fontsize=fontsize)
    ax.set_ylabel(r"average mean square displacement ($\mathrm{{\AA}^2}$)", fontsize=fontsize)
    # set_ticks_fontsize(ax, fontsize)
    set_logscale(ax, xy_log)

    box_text = f"{self.latex_formula_n_temp}\n{self.latex_avg_volume}"
    ax.add_artist(AnchoredText(box_text, loc="upper right", prop=dict(size=fontsize)))

    return fig
[docs]
@add_fig_kwargs
def plot_lattices(
    self, what_list=("abc", "angles", "volume"), ax_list=None, xy_log=None, fontsize=8, xlims=None, **kwargs
) -> Figure:
    """
    Plot lattice lengths/angles/volume as a function of time.
    One subplot is used for each quantity, always in the order: abc, angles, volume.

    Args:
        what_list: List of strings specifying the quantities to plot. Default all
        ax_list: List of axis or None if a new figure should be created.
        xy_log: None or empty string for linear scale. "x" for log scale on x-axis.
            "xy" for log scale on x- and y-axis. "x:semilog" for semilog scale on x-axis.
        fontsize: fontsize for legends and titles
        xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
            or scalar e.g. ``left``. If left (right) is None, default values are used.

    Returns:
        The figure, or None if the MD run has been performed with fixed lattice.
    """
    if self.lattices is None:
        cprint("MD simulation has been performed with fixed lattice!", color="red")
        return None

    what_list = list_strings(what_list)
    ax_list, fig, plt = get_axarray_fig_plt(
        ax_list, nrows=1, ncols=len(what_list), sharex=True, sharey=False, squeeze=False
    )
    # With squeeze=False the axes come in a 2D array: flatten it to index the subplots.
    ax_list = np.ravel(ax_list)

    markers = ["o", "^", "v"]
    # Index of the current subplot: incremented once per quantity, not once per curve.
    cnt = -1

    if "abc" in what_list:
        # plot lattice parameters.
        cnt += 1
        ax = ax_list[cnt]
        for i, label in enumerate(["a", "b", "c"]):
            ax.plot(self.times, [lattice.abc[i] for lattice in self.lattices], label=label, marker=markers[i])
        ax.set_ylabel("abc (A)")

    if "angles" in what_list:
        # plot lattice angles.
        cnt += 1
        ax = ax_list[cnt]
        for i, label in enumerate(["alpha", "beta", "gamma"]):
            ax.plot(self.times, [lattice.angles[i] for lattice in self.lattices], label=label, marker=markers[i])
        ax.set_ylabel(r"$\alpha\beta\gamma$ (degree)")

    if "volume" in what_list:
        # plot lattice volume.
        marker = "o"
        cnt += 1
        ax = ax_list[cnt]
        ax.plot(self.times, [lattice.volume for lattice in self.lattices], label="Volume", marker=marker)
        ax.set_ylabel(r"$V\, (A^3)$")

    for ix, ax in enumerate(ax_list):
        set_logscale(ax, xy_log)
        set_axlims(ax, xlims, "x")
        if ix == len(ax_list) - 1:
            ax.set_xlabel("t (ps)", fontsize=fontsize)
        ax.legend(loc="best", shadow=True, fontsize=fontsize)

    return fig
[docs]
@dataclasses.dataclass(kw_only=True)
class Msdtt0:
r"""
This object stores:
$$MSD(t,t_0) = \frac{1}{N} \sum_{i=1}^{N} (\vec{r}_i(t+t_0) - \vec{r}_i(t_0))^2$$
where $N$ is the number of particles of a particular chemical symbol and $\vec{r}_i(t)$ is the position vector.
"""
index_tmax: int
symbol: str
arr_tt0: np.ndarray
mda: MdAnalyzer
@property
def times(self) -> np.ndarray:
"""Time mesh."""
return self.mda.times
@property
def temperature(self) -> float:
"""Temperature in Kelvin."""
return self.mda.temperature
[docs]
@cached_property
def msd_t(self) -> np.ndarray:
"""Average of MSD(t,t_0) over t0."""
return np.mean(self.arr_tt0, axis=1)
def __str__(self) -> str:
    """String representation: same as to_string with default verbosity."""
    text = self.to_string()
    return text
[docs]
def to_string(self, verbose=0) -> str:
"""String representation with verbosity level verbose."""
lines = []
app = lines.append
r = self.get_linfit_results()
app(r.label)
return "\n".join(lines)
[docs]
def get_linfit_results(self):
    """
    Perform a linear fit of msd_t as a function of time.

    Returns:
        AttrDict with naive_d (slope converted with Ang2PsTocm2S and divided by 6),
        fit (result of linregress) and label (string for plots).
    """
    # Time values used as abscissas: the tail of the time mesh, shifted to start at zero.
    first = self.mda.nt - self.index_tmax
    elapsed = self.times[first:] - self.times[first]

    fit = linregress(elapsed, self.msd_t)
    naive_d = fit.slope * Ang2PsTocm2S / 6

    return AttrDict(
        naive_d=naive_d,
        fit=fit,
        label=rf"Linear fit: D={naive_d:.2E} cm$^2$/s, $r^2$={fit.rvalue**2:.2f}",
    )
[docs]
@add_fig_kwargs
def plot(self, ax=None, xy_log=None, fontsize=8, xlims=None, **kwargs) -> Figure:
    """
    Plot <msd($t, t_0$)>$ averaged over the initial time t0 together with its linear fit.

    Args:
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xy_log: None or empty string for linear scale. "x" for log scale on x-axis.
            "xy" for log scale on x- and y-axis. "x:semilog" for semilog scale on x-axis.
        fontsize: fontsize for legends and titles
        xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
            or scalar e.g. ``left``. If left (right) is None, default values are used.
    """
    mda = self.mda
    first = mda.nt - self.index_tmax
    ts = self.times[first:] - self.times[first]

    ax, fig, plt = get_ax_fig_plt(ax=ax)

    curve_label = (
        self.symbol + r" <msd($t, t_0$)>$\{$t_0$\}$, t = [0, " + str(int(self.times[self.index_tmax])) + " ps]"
    )
    ax.plot(ts, self.msd_t, label=curve_label, color=mda.color_symbol[self.symbol])

    # Linear fit.
    linfit = self.get_linfit_results()
    ax.plot(ts, linfit.fit.slope * ts + linfit.fit.intercept, label=linfit.label)

    set_axlims(ax, xlims, "x")
    ax.legend(fontsize=fontsize, loc="upper left")
    ax.set_xlabel("t (ps)", fontsize=fontsize)
    ax.set_ylabel(r"Average mean square displacement ($\mathrm{{\AA}^2}$)", fontsize=fontsize)
    # set_ticks_fontsize(ax, fontsize)
    set_logscale(ax, xy_log)

    box_text = f"{self.mda.latex_formula_n_temp}\n{self.mda.latex_avg_volume}"
    ax.add_artist(AnchoredText(box_text, loc="upper right", prop=dict(size=fontsize)))

    return fig
[docs]
def get_sigma_berend(self, t1: float, t2: float, nblock_step: int = 1, tot_block: int = 1000) -> SigmaBerend:
    """
    Run sigma_berend on two rows of MSD(t, t0), selected by the times t1 and t2,
    and collect the results in a SigmaBerend object.

    Args:
        t1: First time. Converted to the index it1 in self.times with sfind_ind_val.
        t2: Second time. Converted to the index it2. it2 must be greater than it1.
        nblock_step: Passed to sigma_berend as first argument.
        tot_block: Passed to sigma_berend as second argument.

    Raises:
        ValueError: if it1 >= it2 or if one of the two indices is not smaller
            than the first dimension of arr_tt0.
    """
    # choose the time elapsed
    it1, _ = sfind_ind_val(self.times, t1)
    it2, _ = sfind_ind_val(self.times, t2)
    size_t = self.arr_tt0.shape[0]
    if it1 >= it2:
        raise ValueError(f"For input {t1=} and {t2=}, got {it1=} >= {it2=}")
    if it1 >= size_t:
        raise ValueError(f"For input {t1=}, got {it1=} >= {size_t=}")
    if it2 >= size_t:
        raise ValueError(f"For input {t2=}, got {it2=} >= {size_t=}")
    block_sizes1, sigmas1, delta_sigmas1 = sigma_berend(nblock_step, tot_block, self.arr_tt0[it1, :])
    block_sizes2, sigmas2, delta_sigmas2 = sigma_berend(nblock_step, tot_block, self.arr_tt0[it2, :])
    # Build instance from locals dict.
    # NB: the names of the local variables must match the names of the SigmaBerend fields,
    # so do not rename them.
    mda = self.mda
    latex_formula = mda.latex_formula
    temperature = mda.temperature
    time1, time2 = mda.times[it1], mda.times[it2]
    data = locals()
    return SigmaBerend(**{k: data[k] for k in [field.name for field in dataclasses.fields(SigmaBerend)]})
[docs]
def get_diffusion_with_sigma(
    self, block_size1: int, block_size2: int, fit_time_start: float, fit_time_stop: float, sigma_berend: SigmaBerend
) -> DiffusionData:
    """
    Compute diffusion coefficient with uncertainty.

    The errors and the block sizes found at the two time indices stored in `sigma_berend`
    are interpolated linearly in the time index. The block size is then used as stride
    to select points of the averaged MSD inside the fit window, and these points are
    passed to linear_lsq_linefit with weights 1 / error**2.

    Args:
        block_size1: Block size used to select the error at the first time (looked up in block_sizes1).
        block_size2: Block size used to select the error at the second time (looked up in block_sizes2).
        fit_time_start: Initial time for the fit.
        fit_time_stop: Final time for the fit.
        sigma_berend: SigmaBerend object, e.g. as returned by get_sigma_berend.

    Raises:
        ValueError: if the interpolated stride is not positive (the selection loop could not advance).
    """
    # NB: the result is built from locals(): the names of the local variables must match
    # the names of the DiffusionData fields, so do not rename them.
    sigmas1, block_sizes1 = sigma_berend.sigmas1, sigma_berend.block_sizes1
    sigmas2, block_sizes2 = sigma_berend.sigmas2, sigma_berend.block_sizes2
    it1, it2 = sigma_berend.it1, sigma_berend.it2

    mda = self.mda
    symbol = self.symbol
    temperature = mda.temperature
    latex_formula = mda.latex_formula
    engine = mda.engine
    avg_volume = mda.avg_volume
    ncarriers = len(mda.structure.indices_from_symbol(symbol))
    times = self.times

    ib1, size1 = sfind_ind_val(block_sizes1, block_size1)
    sig1 = sigmas1[ib1]
    print(f"{sig1=}")
    ib2, size2 = sfind_ind_val(block_sizes2, block_size2)
    sig2 = sigmas2[ib2]
    print(f"{sig2=}")

    # fit to a linear behaviour the errors
    mSigma = (sig2 - sig1) / (it2 - it1)
    qSigma = sig1 - mSigma * it1
    mDataInBlock = (size2 - size1) / (it2 - it1)
    qDataInBlock = size1 - mDataInBlock * it1

    # and find error for anytime.
    msd_tt0 = self.arr_tt0
    err_msd = np.zeros(msd_tt0.shape[0], dtype=float)
    for t in range(msd_tt0.shape[0]):
        err_msd[t] = abs(mSigma * t + qSigma)

    # and find error for anytime
    dataScorrelated = np.zeros(msd_tt0.shape[0], dtype=int)
    for t in range(msd_tt0.shape[0]):
        dataScorrelated[t] = int(mDataInBlock * t + qDataInBlock)

    # average over the initial times.
    msd_t = np.mean(msd_tt0, axis=1)

    fit_istart, _ = sfind_ind_val(times, fit_time_start)
    fit_istop, _ = sfind_ind_val(times, fit_time_stop)

    # Select points separated by the interpolated block size.
    msdSScorrelated, timeArrScorrelated, errMSDScorrelated = [], [], []
    counter, condition = fit_istart, True
    while condition:
        if counter >= fit_istop:
            condition = False
        else:
            index = dataScorrelated[counter]
            if index < 1:
                # A stride <= 0 would never reach fit_istop.
                raise ValueError(f"Invalid stride {index} at time index {counter}: it should be >= 1")
            msdSScorrelated.append(msd_t[counter])
            timeArrScorrelated.append(times[counter])
            errMSDScorrelated.append(err_msd[counter])
            counter = counter + index

    msdSScorrelated = np.array(msdSScorrelated, dtype=float)
    timeArrScorrelated = np.array(timeArrScorrelated, dtype=float)
    errMSDScorrelated = np.array(errMSDScorrelated, dtype=float)

    best_angcoeff, quote, var_angcoeff, var_quote = linear_lsq_linefit(
        msdSScorrelated, timeArrScorrelated, 1 / (errMSDScorrelated) ** 2
    )
    min_angcoeff = best_angcoeff - np.sqrt(var_angcoeff)
    max_angcoeff = best_angcoeff + np.sqrt(var_angcoeff)
    diffusion_coeff = best_angcoeff * Ang2PsTocm2S / 6
    err_diffusion_coeff = np.sqrt(var_angcoeff) * Ang2PsTocm2S / 6

    # TODO: Charge from oxidation state?
    conductivity = e2s / kbs * ncarriers * diffusion_coeff / avg_volume / temperature * 1.0e09
    err_conductivity = e2s / kbs * ncarriers / avg_volume / temperature * err_diffusion_coeff * 1.0e09

    print(f"{ncarriers=} for {symbol=}")
    print(f"{best_angcoeff * Ang2PsTocm2S / 6:.2E}")
    print(best_angcoeff)
    print(max_angcoeff)
    print(min_angcoeff)

    # Build instance from locals dict.
    data = locals()
    return DiffusionData(**{k: data[k] for k in [field.name for field in dataclasses.fields(DiffusionData)]})
"""
('timeStepJump = ' + str(timeStepJump) + '\n')
('time, cel and pos were cut before ' + str(timeArray[0]+tInitial)+ 'ps' + '\n')
('t elapsed max is ' + str(estimatedTMax)+ 'ps' + '\n')
('trajectory length is ' + str(timeArrayTmp[timeArrayTmp.shape[0]-1])+ 'ps' + '\n')
('error on msd(t) evaluated at ' + str(estimatedFirstTElapsed) + 'ps' + ' and ' + str(estimatedSecondTElapsed) + 'ps' + '\n')
('evaluated n. of blocks at ' + str(estimatedFirstTElapsed) + 'ps' + ' is ' + str(block_size1) + '\n')
('evaluated n. of blocks at ' + str(estimatedSecondTElapsed) + 'ps' + ' is ' + str(block_size2) + '\n')
('number of decorrelated msd(t) data that we fit: ' + str(timeArrScorrelated.shape[0]) + '\n')
"""
[docs]
@add_fig_kwargs
def plot_mat(self, cmap="jet", fontsize=8, ax=None, **kwargs) -> Figure:
    """
    Show the MSD(t, t0) matrix stored in ``self.arr_tt0`` as a color map with a colorbar.

    Args:
        cmap: Name of the matplotlib colormap passed to ``matshow``.
        fontsize: Font size for the axis labels and the title.
        ax: |matplotlib-Axes| or None if a new figure should be created.
    """
    ax, fig, plt = get_ax_fig_plt(ax=ax)

    image = ax.matshow(self.arr_tt0, cmap=cmap)
    fig.colorbar(image, ax=ax)

    text_opts = dict(fontsize=fontsize)
    ax.set_xlabel("t0 (ps)", **text_opts)
    ax.set_ylabel("t (ps)", **text_opts)
    ax.set_title(self.symbol + ": msd($t, t_0$)", **text_opts)

    return fig
[docs]
class Msdtt0List(list):
    """
    A list of Msdtt0 objects.
    """

    def __str__(self) -> str:
        return self.to_string()

    def to_string(self, verbose=0) -> str:
        """
        String representation with verbosity level verbose.

        The strings returned by the ``to_string`` method of each item are joined with newlines.
        """
        return "\n".join(msdtt0.to_string(verbose=verbose) for msdtt0 in self)

    @add_fig_kwargs
    def plot(self, sharex=True, sharey=True, fontsize=8, **kwargs) -> Figure:
        """
        Plot all Msdtt0 objects on a grid with one row per object.

        Args:
            sharex: True if the subplots should share the x-axis.
            sharey: True if the subplots should share the y-axis.
            fontsize: Font size forwarded to the ``plot`` method of each item.
        """
        # The grid has a single column, hence there is never an unused axes to switch off.
        ax_list, fig, plt = get_axarray_fig_plt(
            None, nrows=len(self), ncols=1, sharex=sharex, sharey=sharey, squeeze=False
        )

        for msdtt0, ax in zip(self, ax_list.ravel(), strict=False):
            msdtt0.plot(ax=ax, fontsize=fontsize, show=False, **kwargs)

        return fig
[docs]
@dataclasses.dataclass(kw_only=True)
class SigmaBerend:
    """
    Stores the variance of correlated data as function of block number.

    Two data sets are stored (suffix 1 and 2), one for each of the two times
    at which the error on the MSD has been analyzed.
    """

    temperature: float
    latex_formula: str

    # First data set: time index, time (shown in ps in the legend) and the arrays
    # plotted as x, y and y-error in the first panel.
    it1: int
    time1: float
    block_sizes1: np.ndarray
    sigmas1: np.ndarray
    delta_sigmas1: np.ndarray

    # Second data set, same meaning as above (second panel).
    it2: int
    time2: float
    block_sizes2: np.ndarray
    sigmas2: np.ndarray
    delta_sigmas2: np.ndarray

    @add_fig_kwargs
    def plot(self, fontsize=8, ax_list=None, **kwargs) -> Figure:
        """
        Plot variance of correlated data as function of block number.

        Args:
            fontsize: Font size for legends and axis labels.
            ax_list: List of |matplotlib-Axes| or None if a new figure should be created.
        """
        ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=1, ncols=2, sharex=True, sharey=True, squeeze=False)

        first = (self.block_sizes1, self.sigmas1, self.delta_sigmas1, self.time1)
        second = (self.block_sizes2, self.sigmas2, self.delta_sigmas2, self.time2)

        for ix, ax in enumerate(ax_list.ravel()):
            # Only the axes with index 1 shows the second data set.
            xs, ys, yerr, time = second if ix == 1 else first

            label = (
                r"$\sigma(\mathrm{MSD}($"
                + f"{time:2.1f} ps$))$ \n"
                + self.latex_formula
                + f", T = {self.temperature:4.0f}K"
            )
            ax.errorbar(xs, ys, yerr=yerr, linestyle="-", label=label)

            ax.legend(fontsize=fontsize, loc="lower right")
            ax.set_xlabel("N. of data in block", fontsize=fontsize)
            ax.set_ylabel(r"$\sigma$ ($\AA^2$)", fontsize=fontsize)
            ax.grid(True)

        fig.suptitle("Variance of correlated data as function of block number")

        return fig
[docs]
@dataclasses.dataclass(kw_only=True)
class DiffusionData(HasPickleIO):
    """
    Diffusion results for a given temperature.
    """

    diffusion_coeff: float
    err_diffusion_coeff: float
    conductivity: float
    err_conductivity: float
    temperature: float
    symbol: str
    latex_formula: str
    avg_volume: float
    ncarriers: int
    block_size1: int
    block_size2: int
    fit_time_start: float  # msd(t) fit starts at this time.
    fit_time_stop: float  # msd(t) fit ends at this time.
    min_angcoeff: float  # min angular coefficient of msd(t)
    max_angcoeff: float  # max angular coefficient of msd(t)
    best_angcoeff: float  # best angular coefficient of msd(t)
    engine: str
    msd_t: np.ndarray
    err_msd: np.ndarray
    sigma_berend: SigmaBerend
    times: np.ndarray

    # TODO: Find better names
    quote: float  # intercept used by plot() to draw the fitted lines.
    timeArrScorrelated: np.ndarray
    msdSScorrelated: np.ndarray
    errMSDScorrelated: np.ndarray

    @add_fig_kwargs
    def plot(self, ax=None, fontsize=8, **kwargs) -> Figure:
        """
        Plot MSD(t) with errors, the subset of points stored in the ``*Scorrelated`` arrays
        and the lines obtained with the best, min and max angular coefficients.

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            fontsize: Font size for labels, legend and text boxes.
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        # Use as many time points as the number of msd(t) values.
        ts = self.times[: self.msd_t.shape[0]]

        ax.errorbar(ts, self.msd_t, yerr=self.err_msd, color="mediumblue", label=self.symbol)
        ax.errorbar(self.timeArrScorrelated, self.msdSScorrelated, yerr=self.errMSDScorrelated, linestyle="-")

        # Straight lines sharing the same intercept.
        for angcoeff in (self.best_angcoeff, self.min_angcoeff, self.max_angcoeff):
            ax.errorbar(ts, angcoeff * ts + self.quote, linestyle="--")

        ax.set_xlabel("t (ps)", fontsize=fontsize)
        ax.set_ylabel(r"$\mathrm{MSD}_\mathrm{tr}$ $\mathrm{(\AA}^2\mathrm{)}$", fontsize=fontsize)

        text_props = dict(size=fontsize)
        d_text = f"D$_{{tr}}$ = ({self.diffusion_coeff:.2E}\u00b1{self.err_diffusion_coeff:.2E}) cm$^2$/s"
        ax.add_artist(AnchoredText(d_text, loc="upper left", prop=text_props))

        ax.legend(fontsize=fontsize, loc="lower right")

        system_text = f"{self.latex_formula}\nT = {self.temperature} K"
        ax.add_artist(AnchoredText(system_text, loc="upper right", prop=dict(size=fontsize)))

        return fig
[docs]
class DiffusionDataList(list):
    """
    A list of DiffusionData objects.
    """

    # @classmethod
    # def from_topdir(cls, topdir):
    #     return cls.from_files(filepaths)

    # @classmethod
    # def from_files(cls, filepaths):
    #     new = cls()
    #     for path in filepaths:
    #         new.append(DiffusionData.from_file(path))
    #     return new

    # def _nrows_ncols_nplots(self, size=None):
    #     size = size or len(self)
    #     nrows, ncols, nplots = 1, 1, size
    #     if nplots > 1:
    #         ncols = 2; nrows = nplots // ncols + nplots % ncols
    #     return nrows, ncols, nplots

    # def filter(self, filter_dict: dict) -> DiffusionDataList:
    #     """
    #     filter_dict = [{symbol: "Li"}
    #     """
    #     new = DiffusionDataList()
    #     for df_data in self:
    #         if all(getattr(df_data, a_name) == a_value for (a_name, a_value) in filter_dict.items()):
    #             new.append(df_data)
    #     return new

    # def write_csv(self, filename: PathLike) -> None:
    #     """
    #     Writes data to a file that can be easily plotted in other software.
    #     Args:
    #         filename: Supported formats are csv and dat.
    #             If the extension is csv, a csv file is written. Otherwise, a dat format is assumed.
    #     """
    #     delimiter = ", "
    #     with open(filename, "wt") as f:
    #         f.write("T,diffusion,err_diffusion,volume,symbol,composition\n")
    #         #for dt, msd, msdc, mscd in zip(
    #         #    self.dt, self.msd, self.msd_components, self.mscd
    #         #):
    #         #    f.write(delimiter.join([str(v) for v in [dt, msd, *list(msdc), mscd]]))
    #         #    f.write("\n")

    def get_dataframe(self, add_keys=None) -> pd.DataFrame:
        """
        Dataframe with diffusion results, sorted by latex_formula, engine and temperature.

        Args:
            add_keys: optional list of attributes to add.
        """
        keys = [
            "temperature",
            "latex_formula",
            "symbol",
            "diffusion_coeff",
            "err_diffusion_coeff",
            "conductivity",
            "err_conductivity",
            "engine",
        ]
        if add_keys is not None:
            keys.extend(list_strings(add_keys))

        # One column per attribute, one row per DiffusionData object.
        columns = {k: np.array([getattr(data, k) for data in self]) for k in keys}
        return pd.DataFrame(columns).sort_values(["latex_formula", "engine", "temperature"])

    @add_fig_kwargs
    def plot_sigma_berend(self, **kwargs) -> Figure:
        """
        Plot variance of correlated data as function of block number
        for all objects stored in DiffusionDataList.
        One row with two panels for each object.
        """
        ax_mat, fig, plt = get_axarray_fig_plt(
            None, nrows=len(self), ncols=2, sharex=False, sharey=False, squeeze=False
        )

        for data, ax_row in zip(self, ax_mat, strict=False):
            data.sigma_berend.plot(ax_list=ax_row, show=False)

        return fig

    @add_fig_kwargs
    def plot(self, **kwargs) -> Figure:
        """
        Call the ``plot`` method of all the objects stored in DiffusionDataList,
        using one row of the grid for each object.
        """
        # Single column grid: every axes is used so there is nothing to switch off.
        ax_list, fig, plt = get_axarray_fig_plt(
            None, nrows=len(self), ncols=1, sharex=False, sharey=False, squeeze=False
        )

        for data, ax in zip(self, ax_list.ravel(), strict=False):
            data.plot(ax=ax, show=False)

        return fig

    def get_arrhenius_nmtuples(self, df: pd.DataFrame, hue: str | None) -> list:
        """
        Return list of AttrDict objects with the keys: temps, d_coeffs, d_errs.
        Values are sorted by temperature.

        Args:
            df: |pandas-DataFrame|.
            hue: Variable defining how to group data. If None, no grouping is performed.

        Raises:
            ValueError: if a temperature appears more than once in df (hue is None) or inside a group.
        """

        def to_attrdict(frame) -> AttrDict:
            return AttrDict(
                temps=frame["temperature"].values,
                d_coeffs=frame["diffusion_coeff"].values,
                d_errs=frame["err_diffusion_coeff"].values,
            )

        df = df.sort_values(["temperature"])

        if hue is None:
            if not df["temperature"].is_unique:
                raise ValueError("Found duplicated values in temperature column. Please specify hue")
            return [to_attrdict(df)]

        nt_list = []
        for _, grp in df.groupby(hue):
            grp = grp.sort_values(["temperature"])
            if not grp["temperature"].is_unique:
                raise ValueError(f"Found duplicated values in temperature column for {hue=}")
            nt_list.append(to_attrdict(grp))

        return nt_list

    def yield_figs(self, **kwargs):  # pragma: no cover
        """
        This function *generates* a predefined list of matplotlib figures with minimal input from the user.
        """
        yield self.plot_sigma_berend(show=False)
        yield self.plot(show=False)

    def expose(self, exposer="mpl", **kwargs):
        """Expose the results to an exposer."""
        from abipy.tools.plotting import Exposer

        with Exposer.as_exposer(exposer) as e:
            e(self.yield_figs(**kwargs))
[docs]
class MultiMdAnalyzer(HasPickleIO):
    """
    High-level interface to analyze multiple MD trajectories.
    """

    @classmethod
    def from_abiml_dirs(cls, directories: list, step_skip=1, pmode="processes") -> MultiMdAnalyzer:
        """
        Build an instance from a list of directories produced by abiml.py md

        Args:
            directories: List of directories. Each one is passed to MdAnalyzer.from_abiml_dir.
            step_skip: Forwarded to MdAnalyzer.from_abiml_dir.
            pmode: Parallelization mode passed to pool_nprocs_pmode.
        """
        p = pool_nprocs_pmode(len(directories), pmode=pmode)
        using_msg = f"Reading {len(directories)} abiml directories {p.using_msg}"
        args = [(dirpath, step_skip) for dirpath in directories]
        with p.pool_cls(p.nprocs) as pool, Timer(header=using_msg, footer=""):
            return cls(pool.starmap(MdAnalyzer.from_abiml_dir, args))

    @classmethod
    def from_lammps_dirs(
        cls, directories: list, step_skip=1, basename="in.lammps", pmode="processes"
    ) -> MultiMdAnalyzer:
        """
        Build an instance from a list of directories containing LAMMPS results.

        Args:
            directories: List of directories. Each one is passed to MdAnalyzer.from_lammps_dir.
            step_skip: Forwarded to MdAnalyzer.from_lammps_dir.
            basename: Forwarded to MdAnalyzer.from_lammps_dir.
            pmode: Parallelization mode passed to pool_nprocs_pmode.
        """
        p = pool_nprocs_pmode(len(directories), pmode=pmode)
        using_msg = f"Reading {len(directories)} LAMMPS directories {p.using_msg}"
        args = [(dirpath, step_skip, basename) for dirpath in directories]
        with p.pool_cls(p.nprocs) as pool, Timer(header=using_msg, footer=""):
            return cls(pool.starmap(MdAnalyzer.from_lammps_dir, args))

    @classmethod
    def from_vaspruns(cls, vasprun_filepaths: list, step_skip=1, pmode="processes") -> MultiMdAnalyzer:
        """
        Build an instance from a list of vasprun files.

        Args:
            vasprun_filepaths: List of paths. Each one is passed to MdAnalyzer.from_vaspruns.
            step_skip: Forwarded to MdAnalyzer.from_vaspruns.
            pmode: Parallelization mode passed to pool_nprocs_pmode.
        """
        p = pool_nprocs_pmode(len(vasprun_filepaths), pmode=pmode)
        using_msg = f"Reading {len(vasprun_filepaths)} vasprun files {p.using_msg}..."
        args = [(vrun, step_skip) for vrun in vasprun_filepaths]
        with p.pool_cls(p.nprocs) as pool, Timer(header=using_msg, footer=""):
            return cls(pool.starmap(MdAnalyzer.from_vaspruns, args))

    @classmethod
    def from_hist_files(cls, hist_filepaths: list, step_skip=1, pmode="processes") -> MultiMdAnalyzer:
        """
        Build an instance from a list of ABINIT HIST.nc files.

        Args:
            hist_filepaths: List of paths. Each one is passed to MdAnalyzer.from_hist_file.
            step_skip: Forwarded to MdAnalyzer.from_hist_file.
            pmode: Parallelization mode passed to pool_nprocs_pmode.
        """
        p = pool_nprocs_pmode(len(hist_filepaths), pmode=pmode)
        using_msg = f"Reading {len(hist_filepaths)} HIST.nc files {p.using_msg}..."
        args = [(ncpath, step_skip) for ncpath in hist_filepaths]
        with p.pool_cls(p.nprocs) as pool, Timer(header=using_msg, footer=""):
            return cls(pool.starmap(MdAnalyzer.from_hist_file, args))

    # @classmethod
    # def from_qe_inputs(cls, filepaths: list, step_skip=1, pmode="processes") -> MultiMdAnalyzer:
    #     """
    #     Build an instance from a list of QE input files.
    #     """
    #     p = pool_nprocs_pmode(len(filepaths), pmode=pmode)
    #     using_msg = f"Reading {len(filespaths)} QE input files {p.using_msg}..."
    #     args = [(path, step_skip) for path in filepaths]
    #     with p.pool_cls(p.nprocs) as pool, Timer(header=using_msg, footer="") as timer:
    #         return cls(pool.starmap(MdAnalyzer.from_qe_input, args))

    def __init__(self, mdas: list[MdAnalyzer], temp_colormap="jet"):
        """
        Args:
            mdas: List of MdAnalyzer
            temp_colormap: matplotlib colormap for temperatures.
        """
        self.mdas = mdas

        # Sort analyzers according to temperature (only if they all refer to the same system).
        if self.has_same_system():
            self.mdas = sorted(mdas, key=lambda x: x.temperature)

        self.set_temp_colormap(temp_colormap)

    def __iter__(self):
        return self.mdas.__iter__()

    def __len__(self) -> int:
        return len(self.mdas)

    def __getitem__(self, items):
        return self.mdas.__getitem__(items)

    def has_same_system(self) -> bool:
        """True if all analyzers have the same latex_formula as the first one."""
        return all(mda.latex_formula == self[0].latex_formula for mda in self)

    def set_temp_colormap(self, colormap) -> None:
        """Set the colormap for the list of temperatures."""
        import matplotlib.pyplot as plt

        self.temp_cmap = plt.get_cmap(colormap)

    def get_params_dataframe(self) -> pd.DataFrame:
        """Dataframe with the parameters of the different MdAnalyzers."""
        return pd.DataFrame([mda.get_params_dict() for mda in self])

    def __str__(self) -> str:
        return self.to_string()

    def to_string(self, verbose: int = 0) -> str:
        """String representation with verbosity level verbose."""
        lines = [
            marquee("MD PARAMETERS", mark="="),
            self.get_params_dataframe().to_string(),
            "\n",
        ]
        return "\n".join(lines)

    def iter_mdat(self):
        """Iterate over (MdAnalyzer, temperature)."""
        for mda in self:
            yield mda, mda.temperature

    def iter_mdatc(self):
        """Iterate over (MdAnalyzer, temperature, color)."""
        for itemp, mda in enumerate(self):
            # The position in the list selects the color in the temperature colormap.
            yield mda, mda.temperature, self.temp_cmap(float(itemp) / len(self))

    def get_msdtt0_symbol_tmax(self, symbol: str, tmax: float, atom_inds=None, nprocs=None) -> Msdtt0List:
        """Get the Msdtt0 objects for the given symbol and max time."""
        msdtt0_list = Msdtt0List()
        for mda in self:
            msdtt0_list.append(mda.get_msdtt0_symbol_tmax(symbol, tmax, atom_inds=atom_inds, nprocs=nprocs))
        return msdtt0_list

    # def color_itemp(self, itemp: int):
    #     return self.temp_cmap(float(itemp) / len(self))

    # def temps_colors(self) -> tuple[list, list]:
    #     return ([mda.temperature for mda in self],
    #             [self.color_itemp(itemp) for itemp in range(len(self))])

    def _nrows_ncols_nplots(self, size=None):
        """
        Return (nrows, ncols, nplots) for a grid with `size` plots (default: number of analyzers).
        Two columns are used if there is more than one plot.
        """
        nplots = size or len(self)
        if nplots <= 1:
            return 1, 1, nplots

        ncols = 2
        nrows = nplots // ncols + nplots % ncols
        return nrows, ncols, nplots

    @add_fig_kwargs
    def plot_sqdt_symbols(self, symbols, t0: float = 0.0, xy_log=None, fontsize=8, xlims=None, **kwargs) -> Figure:
        """
        Plot the square displacement averaged over all atoms of the same specie vs time
        for the different temperatures.

        Args:
            symbols: List of chemical symbols to consider. "all" for all symbols in structure.
            t0: Initial time in ps.
            xy_log: None or empty string for linear scale. "x" for log scale on x-axis.
                "xy" for log scale on x- and y-axis. "x:semilog" for semilog scale on x-axis.
            fontsize: fontsize for legends and titles.
            xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
                or scalar e.g. ``left``. If left (right) is None, default values are used.
        """
        symbols = self[0]._select_symbols(symbols)
        nrows, ncols, nplots = self._nrows_ncols_nplots(size=len(symbols))

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=True, sharey=True, squeeze=False)
        ax_list = ax_list.ravel()

        # Hide the last axes if the grid has more slots than symbols.
        if nplots % ncols != 0:
            ax_list[-1].axis("off")

        # Plot data: one axes per symbol, one curve per temperature.
        for mda, temp, color in self.iter_mdatc():
            it0, ts = mda.get_it_ts(t0)
            for ax, symbol in zip(ax_list, symbols, strict=False):
                ax.plot(ts, mda.get_sqdt_symbol(symbol, it0=it0), label=f"T = {temp} K", color=color)

        # Decorate axes.
        for ax, symbol in zip(ax_list, symbols, strict=False):
            ax.set_title(symbol, fontsize=fontsize)
            set_axlims(ax, xlims, "x")
            ax.legend(fontsize=fontsize, loc="upper left")
            ax.set_xlabel("t (ps)", fontsize=fontsize)
            ax.set_ylabel(r"average mean square displacement ($\mathrm{{\AA}^2}$)", fontsize=fontsize)
            # set_ticks_fontsize(ax, fontsize)
            set_logscale(ax, xy_log)

        return fig
[docs]
def sfind_ind_val(array, value, arr_is_sorted=False) -> tuple:
    """
    Return the tuple (index, array[index]) for the entry of `array` selected by `value`.

    Args:
        array: Sequence of numbers.
        value: Value to look for.
        arr_is_sorted: If True, the index is the one returned by ``find_le(array, value)``
            (Log(N) bisection). Otherwise, the index of the entry with the smallest
            absolute difference from `value` is used.
    """
    if not arr_is_sorted:
        values = np.asarray(array)
        closest = np.argmin(np.abs(values - value))
        return closest, values[closest]

    # Use Log(N) bisection.
    ind = find_le(array, value)
    return ind, array[ind]
[docs]
def linear_lsq_linefit(x, z, weights):
    """
    Weighted least-squares fit of the straight line ``x = m * z + q``.

    Args:
        x: Values to be fitted (dependent variable).
        z: Abscissas (independent variable).
        weights: Weight of each point.

    Return: (m, q, varM, varQ) i.e. slope, intercept and the quantities
        ``sum(w) / D`` and ``sum(w z^2) / D`` with D the determinant of the normal equations.
    """
    # Weighted sums entering the normal equations.
    sum_w = np.sum(weights)
    sum_wz = np.sum(z * weights)
    sum_wx = np.sum(x * weights)
    sum_wzz = np.sum(z**2 * weights)
    sum_wxz = np.sum((x * z) * weights)

    det = sum_w * sum_wzz - sum_wz**2

    slope = (sum_w * sum_wxz - sum_wz * sum_wx) / det
    intercept = (sum_wx * sum_wzz - sum_wxz * sum_wz) / det

    return slope, intercept, sum_w / det, sum_wzz / det
def _func(size_t0: int, it: int, pos_tac: np.ndarray, msd_tt0: np.ndarray):
    """
    Fill in place row `it` of `msd_tt0` using the first `size_t0` time origins of `pos_tac`.

    The squared differences are summed over the last axis of `pos_tac` and averaged over the second one.
    """
    displacements = pos_tac[it : it + size_t0, :, :] - pos_tac[:size_t0, :, :]
    squared_norms = np.sum(displacements**2, axis=2)
    msd_tt0[it, :] = np.mean(squared_norms, axis=1)
[docs]
def msd_tt0_from_tac(pos_tac: np.ndarray, size_t: int, nprocs=None) -> np.ndarray:
    r"""
    Calculates the MSD for every possible pair of time points in the input array, using the formula:

    $$MSD(t,t_0) = \frac{1}{N} \sum_{i=1}^{N} (\vec{r}_i(t+t_0) - \vec{r}_i(t_0))^2$$

    where $N$ is the number of particles, and $\vec{r}_i(t)$ is the position vector.

    Args:
        pos_tac: Array with positions. The first axis is the time index, the squared differences
            are summed over the last axis and averaged over the second one.
        size_t: Number of values of t. Must be smaller than the number of time points in pos_tac.
        nprocs: 1 for a serial loop. Any other value is passed to pool_nprocs_pmode to build a pool of threads.

    Return: array of shape (size_t, size_t0) with size_t0 = pos_tac.shape[0] - size_t.
    """
    # Check if size_t is valid.
    n_time_points = pos_tac.shape[0]
    if size_t >= n_time_points:
        raise ValueError(f"{size_t=} must be less than {n_time_points}")

    size_t0 = n_time_points - size_t
    msd_tt0 = np.empty((size_t, size_t0), dtype=float)

    if nprocs == 1:
        # Serial version: positions at the time origins are the same for all t.
        origins = pos_tac[:size_t0, :, :]
        for it in range(size_t):
            shifted = pos_tac[it : it + size_t0, :, :]
            msd_tt0[it, :] = np.mean(np.sum((shifted - origins) ** 2, axis=2), axis=1)
        return msd_tt0

    # Each task fills one row of msd_tt0 in place.
    p = pool_nprocs_pmode(nprocs, pmode="threads")
    using_msg = f"Computing MSD(t,t_0) matrix of shape {(size_t, size_t0)} {p.using_msg}"
    tasks = [(size_t0, it, pos_tac, msd_tt0) for it in range(size_t)]
    with p.pool_cls(p.nprocs) as pool, Timer(header=using_msg, footer=""):
        pool.starmap(_func, tasks)

    return msd_tt0
[docs]
def block_mean_var(data, data_mean, n_block) -> tuple[float, float]:
    """
    Perform the block mean and the block variance of data.

    The first ``n_block * (N // n_block)`` values of data are divided in `n_block` consecutive blocks
    of equal length (values in excess are ignored) and the mean of each block is computed.

    Args:
        data: Array with N values along the first axis.
        data_mean: Mean value used as reference for the block means.
        n_block: Number of blocks.

    Return: (sigma2, delta_sigma2) where
        sigma2 = sum_b (mean_b - data_mean)^2 / (n_block * (n_block - 1)) and
        delta_sigma2 = sqrt(2 / (n_block - 1)^3) * sigma2.
    """
    data = np.asarray(data)
    n_inblock = data.shape[0] // n_block
    tail_shape = data.shape[1:]

    if n_inblock > 0:
        # Vectorized block means: reshape to (n_block, n_inblock, ...) and average inside each block.
        blocks = data[: n_block * n_inblock].reshape((n_block, n_inblock) + tail_shape)
        block_means = blocks.mean(axis=1)
    else:
        # More blocks than data: all the blocks are empty and their mean is taken to be zero.
        block_means = np.zeros((n_block,) + tail_shape, dtype=float)

    sigma2 = np.sum((block_means - data_mean) ** 2, axis=0) / n_block
    sigma2 = sigma2 / (n_block - 1)
    delta_sigma2 = np.sqrt(2.0 / (n_block - 1) ** 3) * sigma2

    return sigma2, delta_sigma2
[docs]
def sigma_berend(nblock_step: int, tot_block: int, data: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the block standard deviation of data for an increasing number of blocks.

    Entry 0 of the output arrays is a placeholder (sigma = inf, delta_sigma = 0, data_in_block = 0).
    Entry i > 0 is obtained by calling block_mean_var with ``2 + (i - 1) * nblock_step`` blocks.

    Args:
        nblock_step: Increment in the number of blocks between two consecutive entries.
        tot_block: Size of the output arrays.
        data: Array with the data to analyze.

    Return: (data_in_block, sigma, delta_sigma)
    """
    ndata = data.shape[0]
    mean = np.mean(data)

    sigma2 = np.zeros(tot_block, dtype=float)
    delta_sigma2 = np.zeros(tot_block, dtype=float)
    data_in_block = np.zeros(tot_block, dtype=float)

    sigma2[0] = float("inf")
    delta_sigma2[0] = 0

    # Only tot_block - 1 slots are available after the placeholder: zip stops there,
    # so no write past the end of the arrays is attempted when nblock_step > 1.
    nblock_values = range(2, nblock_step * tot_block + 1, nblock_step)
    for counter, nblock in zip(range(1, tot_block), nblock_values):
        sigma2[counter], delta_sigma2[counter] = block_mean_var(data, mean, nblock)
        data_in_block[counter] = int(ndata / nblock)

    sigma = np.sqrt(sigma2)
    # Error propagation: delta(sqrt(s2)) = 0.5 * delta(s2) / sqrt(s2)
    delta_sigma = 0.5 * delta_sigma2 / sigma

    return data_in_block, sigma, delta_sigma
def _lin_fit(th_invt, log10d0, e_act):
    """Straight line with intercept `log10d0` and slope `-e_act` evaluated at `th_invt`."""
    slope_term = e_act * th_invt
    return log10d0 - slope_term
[docs]
@dataclasses.dataclass(kw_only=True)
class ArrheniusEntry:
    """
    Diffusion coefficients computed at different temperatures for a given system.
    Provides methods to fit the data as a function of 1000/T.
    """

    key: str  # Label associated to the entry.
    symbol: str  # Chemical symbol the diffusion data refers to.
    composition: Composition  # from_file builds it from the "composition" column.
    temps: np.ndarray  # "temperature" column.
    diffusions: np.ndarray  # "diffusion" column.
    err_diffusions: np.ndarray  # "err_diffusion" column. from_file uses zeros if the column is missing.
    volumes: np.ndarray  # "volume" column.
    mpl_style: dict  # Options for matplotlib.

    @classmethod
    def from_file(cls, filepath: PathLike, key, mpl_style) -> ArrheniusEntry:
        """
        Read data in CSV format. Assuming header with at least the following entries:

            temperature,diffusion,err_diffusion,volume,symbol,composition

        The err_diffusion column is optional. The symbol and composition columns
        must contain the same value in all the rows.

        Raises:
            RuntimeError: if any exception is raised while reading or parsing the file.
        """
        try:
            df = pd.read_csv(filepath, skipinitialspace=True)

            def get_unique(col):
                v0 = df[col].values[0]
                if np.any(v0 != df[col].values):
                    raise ValueError(f"All values for column: {col} should be unique while found:\n{df[col].values}")
                return v0

            symbol = get_unique("symbol")
            composition = Composition(get_unique("composition"))
            temps = df["temperature"].values
            volumes = df["volume"].values
            diffusions = df["diffusion"].values

            err_diffusions = np.zeros(len(temps))
            if "err_diffusion" in df.keys():
                err_diffusions = df["err_diffusion"].values

        except Exception as exc:
            raise RuntimeError(f"Exception while reading {filepath=}") from exc

        return cls(
            key=key,
            symbol=symbol,
            composition=composition,
            temps=temps,
            diffusions=diffusions,
            err_diffusions=err_diffusions,
            volumes=volumes,
            mpl_style=mpl_style,
        )

    # def __post_init__(self):
    #     self.latex_formula = latexify(self.formula)

    # @property
    # def latex_formula(self) -> str:
    #     return self._latex_formula

    # @latex_formula.setter
    # def latex_formula(self, value):
    #     """LaTeX formatted formula. E.g., Fe2O3 is transformed to Fe$_{2}$O$_{3}$."""
    #     self._latex_formula = latexify(value)

    @staticmethod
    def _fit_log10(th_invt, values, errors, fit_thinvt):
        """
        Fit log10(values) as a function of th_invt with the straight line defined by _lin_fit.

        Return: (log10_values, err_log10, popt, pcov, fit_log10) where fit_log10 is the fitted line
            evaluated on fit_thinvt or None if fit_thinvt is None.
        """
        log10_values = np.log10(values)
        # NB: log10(x) = ln(x) log10(e) --> d_x log_10(x) = log10(e) / x
        err_log10 = np.log10(np.e) * errors / values

        # curve_fit divides the residuals by sigma. Vanishing uncertainties (e.g. the zeros used
        # by from_file when the err_diffusion column is missing) would produce infinite weights
        # so, in this case, an unweighted fit is performed.
        sigma = err_log10 if np.all(err_log10 != 0) else None
        popt, pcov = optimize.curve_fit(_lin_fit, th_invt, log10_values, sigma=sigma)

        fit_log10 = None
        if fit_thinvt is not None:
            fit_log10 = _lin_fit(fit_thinvt, popt[0], popt[1])

        return log10_values, err_log10, popt, pcov, fit_log10

    def get_diffusion_data(self, fit_thinvt=None) -> AttrDict:
        """
        Fit diffusion(T) taking into account uncertainties.
        If one or more uncertainties are zero, an unweighted fit is performed.

        Args:
            fit_thinvt: Optional mesh of 1000/T values on which the fitted line is evaluated.

        Return dict with activation energy in eV and fit parameters.
        """
        th_invt = 1000 / self.temps
        log10, err_log10, popt, pcov, fit_log10 = self._fit_log10(
            th_invt, self.diffusions, self.err_diffusions, fit_thinvt
        )

        # The slope refers to log10(D) vs 1000/T: convert it to an energy in eV.
        e_act = popt[1] * 1000 * kBoltzEv * np.log(10)

        return AttrDict(
            th_invt=th_invt,
            log10=log10,
            err_log10=err_log10,
            fit_thinvt=fit_thinvt,
            fit_log10=fit_log10,
            e_act=e_act,
            popt=popt,
            pcov=pcov,
        )

    def get_conductivity_data(self, ncar: float, fit_thinvt=None) -> tuple[AttrDict, AttrDict]:
        """
        Compute conductivity sigma and T x sigma(T) assuming ncar carriers.
        Fit values taking into account uncertainties.
        If one or more uncertainties are zero, an unweighted fit is performed.

            sigma(T) = Nq^2 D(T) / (VkT)

        Args:
            ncar: Number of carriers.
            fit_thinvt: Optional mesh of 1000/T values on which the fitted lines are evaluated.

        Return: two dicts. The first one with the results for sigma(T), the second one for T x sigma(T).
        """
        temps = self.temps
        diffusions = self.diffusions
        err_diffusions = self.err_diffusions
        volumes = self.volumes
        th_invt = 1000 / temps

        # if charge is None:
        #     ncar = symb2oxi[symbol] * self.composition[self.symbol]
        # ncar = charge * self.composition[self.symbol]

        prefactor = e2s / kbs * ncar

        # sigma(S/cm)
        conds = prefactor * diffusions / volumes / temps * 1.0e09
        err_conds = prefactor * err_diffusions / volumes / temps * 1.0e09
        log10_conds, err_log10_conds, cond_popt, cond_pcov, fit_log10_cond = self._fit_log10(
            th_invt, conds, err_conds, fit_thinvt
        )

        # temp(K) * sigma(S/cm)
        t_conds = prefactor * diffusions / volumes * 1.0e09
        err_tconds = prefactor * err_diffusions / volumes * 1.0e09
        log10_tconds, err_log10_tconds, tcond_popt, tcond_pcov, fit_log10_tcond = self._fit_log10(
            th_invt, t_conds, err_tconds, fit_thinvt
        )

        return (
            AttrDict(
                th_invt=th_invt,
                log10=log10_conds,
                err_log10=err_log10_conds,
                fit_thinvt=fit_thinvt,
                fit_log10=fit_log10_cond,
                popt=cond_popt,
                cov=cond_pcov,
            ),
            AttrDict(
                th_invt=th_invt,
                log10=log10_tconds,
                err_log10=err_log10_tconds,
                fit_thinvt=fit_thinvt,
                fit_log10=fit_log10_tcond,
                popt=tcond_popt,
                cov=tcond_pcov,
            ),
        )
[docs]
class ArrheniusPlotter:
"""
This object stores diffusion coefficients D(T) computed for different structures and/or
different ML-potentials and allows one to produce Arrhenius plots on the same figure.
Internally, the results are indexed by a unique key that is used as label in the matplotlib plot.
The style for each key can be customized by setting the mpl_style dict
with the options that will be passed to ax.plot.
In the simplest case, one reads the data from external files in CSV format.
Example:
from abipy.dynamics.analyzer import ArrheniusPlotter
key_path = {
"matgl-MD": "diffusion_cLLZO-matgl.csv",
"m3gnet-MD": "diffusion_cLLZO-m3gnet.csv",
}
mpl_style_key = {
"matgl-MD" : dict(c='blue'),
"m3gnet-MD": dict(c='purple'),
}
plotter = ArrheniusPlotter()
for key, path in key_path.items():
plotter.add_entry_from_file(path, key, mpl_style=mpl_style_key[key])
# temperature grid refined
thinvt_arange = (0.6, 2.5, 0.01)
xlims, ylims = (0.5, 2.0), (-7, -4)
plotter.plot(thinvt_arange=thinvt_arange,
xlims=xlims, ylims=ylims, text='LLZO cubic', savefig=None)
"""
def __init__(self, entries=None):
    """
    Args:
        entries: List of ArrheniusEntry. An empty list is used if None (or any other false value).
    """
    self.entries = entries if entries else []
    # Map: chemical symbol --> oxidation state.
    self.symb2oxi = common_oxidation_states()
def __iter__(self):
    """Iterate over the entries."""
    return iter(self.entries)
def __len__(self) -> int:
    """Number of entries."""
    entries = self.entries
    return len(entries)
def __getitem__(self, items):
    """Return the entries selected by `items` (index or slice)."""
    return self.entries[items]
[docs]
def copy(self):
    """Deep copy of the object."""
    from copy import deepcopy

    return deepcopy(self)
[docs]
def keys(self) -> list[str]:
    """List with the key of each entry, in the same order as the entries. Keys must be unique."""
    found = []
    for item in self:
        found.append(item.key)
    return found
[docs]
def symbols(self) -> set[str]:
    """Return set with the chemical symbols of all the entries."""
    return {item.symbol for item in self}
[docs]
def index_key(self, key: str) -> int:
    """
    Find the index of the first entry with the given key.

    Raises:
        KeyError: if key is not found. The message lists the available keys.
    """
    for i, entry in enumerate(self):
        if entry.key == key:
            return i

    # NB: keys is a method and must be called to show the list of keys in the message.
    raise KeyError(f"Cannot find {key=} in {self.keys()=}")
[docs]
def pop_key(self, key: str) -> ArrheniusEntry:
    """Remove the entry with the given key from the list of entries and return it."""
    return self.entries.pop(self.index_key(key))
[docs]
def append(self, entry: ArrheniusEntry) -> None:
    """
    Append new entry.

    Raises:
        KeyError: if an entry with the same key is already stored.
    """
    already_there = entry.key in self.keys()
    if already_there:
        raise KeyError(f"{entry.key=} is already present.")

    self.entries.append(entry)
[docs]
def set_style(self, key: str, mpl_style: dict) -> None:
    """Replace the matplotlib style of the entry with the given key."""
    target = self.entries[self.index_key(key)]
    target.mpl_style = mpl_style
[docs]
def get_min_max_temp(self) -> tuple[float, float]:
    """
    Compute the min and max temperature for all entries.

    Temperatures are read from the ``temps`` array of each entry.
    """
    min_temp = min(e.temps.min() for e in self)
    max_temp = max(e.temps.max() for e in self)
    return min_temp, max_temp
[docs]
def add_entry_from_file(self, filepath: PathLike, key: str, mpl_style=None) -> None:
    """
    Build an ArrheniusEntry with ``ArrheniusEntry.from_file`` and append it.

    Args:
        filepath: Path of the file passed to ArrheniusEntry.from_file.
        key: Key associated to the new entry.
        mpl_style: dict with matplotlib options. An empty dict is used if None.
    """
    style = mpl_style or {}
    new_entry = ArrheniusEntry.from_file(filepath, key, style)
    self.append(new_entry)
[docs]
@add_fig_kwargs
def plot(
    self,
    thinvt_arange=None,
    what="diffusion",
    ncar=None,
    colormap="jet",
    with_t=True,
    text=None,
    ax=None,
    fontsize=8,
    xlims=None,
    ylims=None,
    **kwargs,
) -> Figure:
    """
    Arrhenius plot.

    Args:
        thinvt_arange: start, stop, step for 1000/T mesh. If None, the mesh is automatically computed.
        what: Selects the quantity to plot. Possible values: "diffusion", "sigma", "tsigma".
        ncar: Number of carriers used if what is "sigma" or "tsigma". If None, it is computed
            from the oxidation state in ``self.symb2oxi`` and the composition of the entry.
        colormap: Colormap used to select the color if entry.mpl_style does not provide it.
        with_t: True to add a twin axes with the value of T
        text: Optional text added inside the axes.
        ax: |matplotlib-Axes| or None if a new figure should be created.
        fontsize: fontsize for legends and titles.
        xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
            or scalar e.g. ``left``. If left (right) is None, default values are used.
        ylims: Similar to xlims but for the y-axis.
    """
    ax, fig, plt = get_ax_fig_plt(ax=ax)
    cmap = plt.get_cmap(colormap)

    # Build temperature grid for fit.
    if thinvt_arange is None:
        t_min, t_max = self.get_min_max_temp()
        # Enlarge the 1000/T interval spanned by the data.
        thinvt_arange = (0.8 * (1000 / t_max), 1.2 * (1000 / t_min), 0.01)

    fit_thinvt = np.arange(thinvt_arange[0], thinvt_arange[1], step=thinvt_arange[2])

    for ie, entry in enumerate(self):
        # Default style for the options not specified by the entry.
        mpl_style = entry.mpl_style.copy()
        mpl_style.setdefault("marker", ".")
        if not any(name in mpl_style for name in ("linestyle", "ls")):
            mpl_style["linestyle"] = ""
        if not any(name in mpl_style for name in ("color", "c")):
            mpl_style["color"] = cmap(ie / len(self))

        symbol, composition = entry.symbol, entry.composition

        # The activation energy shown in the label always comes from the fit of the diffusion data.
        diff = entry.get_diffusion_data(fit_thinvt=fit_thinvt)
        label = r"D$_{\mathrm{%s}}$ (%s): " % (symbol, entry.key)
        label += r"E$_\mathrm{a}$=" + f"{diff.e_act:.2F}" + " eV"

        if what == "diffusion":
            data = diff
        else:
            ncar_ = ncar
            if ncar_ is None:
                if symbol not in self.symb2oxi:
                    raise ValueError(f"No entry for {symbol=} found in symb2oxi! Please add the oxistate manually!")
                charge = self.symb2oxi[symbol]
                ncar_ = charge * composition[symbol]
                label += f", q={charge}"

            sigma_data, tsigma_data = entry.get_conductivity_data(ncar_, fit_thinvt=fit_thinvt)
            data = {"sigma": sigma_data, "tsigma": tsigma_data}[what]

        if data.err_log10 is None:
            # Plot data without errors.
            lines = ax.plot(data.th_invt, data.log10, label=label, **mpl_style)
        else:
            # Plot data with errors.
            lines = ax.errorbar(
                data.th_invt, data.log10, yerr=data.err_log10, label=label, capsize=5.0, **mpl_style
            )

        # Fitted line with the same color as the data.
        ax.plot(fit_thinvt, data.fit_log10, linestyle="--", color=lines[0].get_color())

    ylabels = {
        "diffusion": r"log$_{10}$ (D(cm$^2$/s))",
        "sigma": r"log$_{10}$ [$\sigma$(S/cm)]",
        "tsigma": r"log$_{10}$ [$\sigma$T(SK/cm)]",
    }
    ax.set_xlabel("1000 / T(K)", fontsize=18)
    ax.set_ylabel(ylabels[what], fontsize=18)
    set_axlims(ax, xlims, "x")
    set_axlims(ax, ylims, "y")
    ax.legend(loc="lower left", fontsize=12)
    # set_ticks_fontsize(ax, fontsize=14)

    # from matplotlib.ticker import MultipleLocator
    # ax.yaxis.set_major_locator(MultipleLocator(1))
    # ax.yaxis.set_minor_locator(MultipleLocator(0.2))

    if text:
        ax.text(
            0.96,
            0.85,
            text,
            verticalalignment="bottom",
            horizontalalignment="right",
            transform=ax.transAxes,
            color="black",
            fontsize=18,
            bbox=dict(facecolor="none", edgecolor="grey", pad=10),
        )

    if with_t:
        import matplotlib.ticker as mticker

        # Add a twin axes and set its limits so it matches the first.
        ax_t = ax.twiny()
        ax_t.set_xlabel("T (K)", fontsize=18)
        ax_t.set_xlim(ax.get_xlim())

        # Ticks are placed on the 1000/T scale and labelled with the value of T.
        def show_temperature(x, pos):
            return f"{1000 / x:.0f}"

        ax_t.xaxis.set_major_formatter(mticker.FuncFormatter(show_temperature))

    return fig