"""
Python interface to fftprof. Provides objects to benchmark
the FFT libraries used by ABINIT and plot the results with matplotlib.
"""
from __future__ import annotations
import os
import tempfile
import numpy as np
from subprocess import Popen, PIPE
from shutil import which
from monty.fnmatch import WildCard
from abipy.tools.plotting import add_fig_kwargs
from abipy.tools.typing import Figure
# Public API: the only name exported by ``from <this module> import *``.
__all__ = ["FFTBenchmark"]
# Matplotlib color used to draw each FFT algorithm, keyed by the integer fftalg code.
# NOTE: fftalg 311 ("y") and 511 ("c") have no entry: both were already
# commented out of this table (reason not recorded here).
_color_fftalg = {112: "r", 312: "b", 412: "y", 401: "c", 512: "g"}
# Matplotlib line style used for a run with a given number of threads (keys 1..8).
# The four basic styles are repeated once, so n and n + 4 threads share a style.
_linestyle_nt = dict(enumerate(("-", "--", "-.", ":") * 2, start=1))
# Matplotlib marker used for a run with a given number of threads (keys 1..8).
_markers_nt = dict(enumerate(("1", ">", "2", "<", "3", "*", "4", "+"), start=1))
class FFT_Test:
    """
    This object stores the wall-time of the FFT
    as a function of the size of the problem.
    """

    def __init__(self, ecut, ngfft, wall_time, info):
        """
        Args:
            ecut: List of cutoff energies in Hartree.
            ngfft: List with FFT divisions.
            wall_time: List of wall_time for the different ecut.
            info: Dictionary with extra information. Every key becomes an
                instance attribute; ``available``, ``fftalg`` and ``nthreads``
                must be present and are converted with ``int``.
        """
        self.ecut = np.asarray(ecut)
        self.necut = len(ecut)
        self.ngfft = np.asarray(ngfft, int)
        self.wall_time = np.asarray(wall_time)

        # Keep the raw dictionary so that speedup_wrt can clone it.
        self._info = info
        self.__dict__.update(info)

        # The values stored in info may be strings: normalize the ones used as integers.
        self.available = int(self.available)
        self.fftalg = int(self.fftalg)
        self.nthreads = int(self.nthreads)

    def __str__(self) -> str:
        """String representation."""
        return "fftalg: %s, nt: %d" % (self.fftalg, self.nthreads)

    def speedup_wrt(self, other: "FFT_Test"):
        """
        Compare self with other, return a new instance of FFT_Test containing the speedup.

        The new object has the same ecut, ngfft and wall_time as self plus a
        ``speedup`` attribute equal to ``other.wall_time / self.wall_time``.
        """
        # Work on a copy so that the info dictionary of self is left untouched.
        info = self._info.copy()
        info["speedup"] = other.wall_time / self.wall_time
        return FFT_Test(self.ecut, self.ngfft, self.wall_time, info)

    def plot_ax(self, ax, ydata=None, fact=1.0):
        """
        Plot ydata on the axis ax. If ydata is None, the wall_time is plotted.

        Args:
            ax: Axis object providing a matplotlib-like ``plot`` method.
            ydata: Name of the instance attribute to plot (e.g. "speedup").
            fact: Multiplicative factor applied to the y-values.

        Returns:
            The line object returned by ``ax.plot``.
        """
        # Algorithms missing from the color table are drawn in black
        # instead of aborting the whole plot with a KeyError.
        color = _color_fftalg.get(self.fftalg, "k")

        # The style tables are keyed 1..N: wrap larger thread counts around so
        # that benchmarks run with more threads than table entries can still be
        # plotted (the legend label keeps the curves distinguishable).
        linestyle = _linestyle_nt[(self.nthreads - 1) % len(_linestyle_nt) + 1]
        marker = _markers_nt[(self.nthreads - 1) % len(_markers_nt) + 1]

        xx = self.ecut
        yy = self.wall_time if ydata is None else self.__dict__[ydata]
        yy = yy * fact

        line, = ax.plot(xx, yy,
                        color=color, linestyle=linestyle, label=str(self), linewidth=3.0,
                        marker=marker, markersize=10)
        return line
class FFTBenchmark:
    """
    Container class storing the results of the FFT benchmark.
    Use the class method ``from_file`` to generate a new instance.
    """

    @classmethod
    def from_file(cls, fileobj):
        """Build the benchmark from a PROF file (filename or file-like object)."""
        return parse_prof_file(fileobj)

    def __init__(self, title, FFT_tests):
        """
        Args:
            title: Title used for the plots.
            FFT_tests: Iterable of FFT_Test objects.
        """
        self.title = title

        # Materialize the input first: FFT_tests may be a one-shot iterator and
        # it is needed both for the list of tests and for the list of fftalgs.
        self.tests = list(FFT_tests)

        # Sorted list of the distinct FFT algorithms found in the tests.
        self._fftalgs = sorted({test.fftalg for test in self.tests})

    def iter_fftalgs(self):
        """Iterator over the FFT algorithms."""
        return iter(self._fftalgs)

    def tests_with_fftalg(self, fftalg):
        """Return the list of FFT_tests with a given fftalg."""
        return [t for t in self.tests if t.fftalg == fftalg]

    @add_fig_kwargs
    def plot(self, exclude_algs=None, exclude_threads=None, **kwargs) -> Figure:
        """
        Plot the wall-time and the speed-up.

        The upper panel shows the wall-time of each available test versus ecut.
        The lower panel shows, for the multi-threaded tests, the speedup with
        respect to the single-thread run of the same fftalg divided by the
        number of threads, together with the ideal value 1.0.

        Args:
            exclude_algs: Optional iterable with the fftalg values to exclude.
            exclude_threads: Optional iterable with the number of threads to exclude.
            kwargs: Not used in the body; presumably consumed by the
                ``add_fig_kwargs`` decorator (TODO confirm).

        Returns:
            The matplotlib figure. If the benchmark contains no test, the
            figure with the two empty panels is returned.

        Raises:
            ValueError: if a plotted fftalg has no test with nthreads == 1.
        """
        import matplotlib.pyplot as plt
        from matplotlib.ticker import FuncFormatter

        exc_algs = set() if exclude_algs is None else {int(alg) for alg in exclude_algs}
        exc_nths = set() if exclude_threads is None else {int(nth) for nth in exclude_threads}

        fig = plt.figure()

        # Upper panel: wall-time versus ecut.
        ax1 = fig.add_subplot(2, 1, 1)
        for test in self.tests:
            if test.fftalg in exc_algs:
                continue
            if test.available and test.nthreads not in exc_nths:
                test.plot_ax(ax1)

        ax1.legend(loc="upper left")
        ax1.grid(True)
        ax1.set_title(self.title)
        ax1.set_xlabel('Ecut [Hartree]')
        ax1.set_ylabel('WALL time (s)')

        # Lower panel: parallel efficiency versus FFT divisions.
        ax2 = fig.add_subplot(2, 1, 2)
        ax2.grid(True)
        ax2.set_xlabel('FFT divisions')
        ax2.set_ylabel('Efficiency')

        if not self.tests:
            # Nothing to plot and no reference test to take the x-axis from.
            return fig

        for fftalg in self.iter_fftalgs():
            if fftalg in exc_algs:
                continue
            tests = self.tests_with_fftalg(fftalg)

            # The single-thread run is the reference for the speedup.
            ref_test = next((t for t in tests if t.nthreads == 1), None)
            if ref_test is None:
                raise ValueError("Ref file not found")

            for t in tests:
                if t.nthreads == 1 or not t.available or t.nthreads in exc_nths:
                    continue
                # Efficiency = speedup / nthreads.
                tspeed = t.speedup_wrt(ref_test)
                tspeed.plot_ax(ax2, ydata="speedup", fact=1.0 / t.nthreads)

        # Use FFT divs as labels: every tick that falls on a computed ecut
        # shows the corresponding FFT divisions, the others are left blank.
        # A formatter (instead of a fixed list of label strings) keeps each
        # label attached to the tick position it was computed for.
        t0 = self.tests[0]

        def ngfft_label(xtick, pos=None):
            xecut = float(xtick)
            for ecut, ndiv in zip(t0.ecut, t0.ngfft):
                if abs(ecut - xecut) < 0.1:
                    return str(ndiv)
            return ""

        ax2.xaxis.set_major_formatter(FuncFormatter(ngfft_label))
        # Rotate labels.
        ax2.tick_params(axis="x", labelrotation=35)

        # Ideal efficiency.
        ax2.plot(t0.ecut, [1.0] * t0.necut, "b-", linewidth=3.0)

        return fig
def parse_prof_file(fileobj):
    """
    Parse the PROF file generated by fftprof.F90.

    Args:
        fileobj: String or file-like object.

    Returns:
        Instance of FFTBenchmark. Blank lines are ignored; if the file has no
        data row, the tests contain empty arrays.

    Raises:
        IndexError: if the input contains no non-blank line.
    """
    # The file contains
    # 1) An initial header with info on the routine.
    # 2) The list of fftalgs that have been analyzed.
    # Results:
    # ecut ngfft(1:3) wall_times
    #
    # Example
    # Benchmark: routine = fourwf, cplex = 1, option= 2, istwfk= 1
    # fftalg = 112, fftcache = 16, ndat = 1, nthreads = 1, available = 1
    # fftalg = 401, fftcache = 16, ndat = 1, nthreads = 1, available = 1
    # fftalg = 412, fftcache = 16, ndat = 1, nthreads = 1, available = 1
    # fftalg = 312, fftcache = 16, ndat = 1, nthreads = 1, available = 1
    # 20.0 91 91 90 0.0307 0.0330 0.0275 0.0283
    # 30.0 101 101 100 0.0395 0.0494 0.0404 0.0333
    if hasattr(fileobj, "readlines"):
        raw_lines = fileobj.readlines()
    else:
        with open(fileobj, "r") as fh:
            raw_lines = fh.readlines()

    # Blank lines (typically a trailing one) carry no data and would
    # otherwise make the numeric parsing below fail.
    lines = [line for line in raw_lines if line.strip()]

    # Parse the header: the first line is the title (without the leading
    # comment character and the surrounding whitespace/newline).
    title = lines[0][1:].strip()

    # Each following "#" line describes one test as comma-separated key = value pairs.
    info_of_test = []
    pos = 1
    while pos < len(lines) and lines[pos].startswith("#"):
        line = lines[pos].replace("#", "").replace("\n", "")
        info = {}
        for tk in line.split(","):
            kw = tk.replace(" ", "").split("=")
            info[kw[0]] = kw[1]
        info_of_test.append(info)
        pos += 1

    # Parse the columns: ecut, three FFT-grid columns, then one wall_time per test.
    data = [[] for _ in info_of_test]
    ecut, ngfft = [], []
    for line in lines[pos:]:
        vals = [float(v) for v in line.split()]
        ecut.append(vals[0])
        ngfft.append(vals[1:4])
        for idx, v in enumerate(vals[4:]):
            data[idx].append(v)

    # Instantiate FFTBenchmark.
    fft_tests = [
        FFT_Test(ecut, ngfft, wall_time, info)
        for wall_time, info in zip(data, info_of_test)
    ]

    return FFTBenchmark(title, fft_tests)
class FFTProfError(Exception):
    """Error raised by the FFTProf wrapper (executable not found or failed run)."""
class FFTProf:
    """Wrapper around fftprof Fortran executable."""

    Error = FFTProfError

    def __init__(self, fft_input, executable="fftprof"):
        """
        Args:
            fft_input: String with the input passed to fftprof via stdin.
            executable: Name or path of the fftprof executable.

        Raises:
            FFTProfError: if the executable cannot be located.
        """
        self.verbose = 1
        self.fft_input = fft_input

        self.executable = which(executable)
        if self.executable is None:
            raise self.Error("Cannot find executable %s in $PATH" % executable)

    @classmethod
    def from_file(cls, filename, executable="fftprof") -> FFTProf:
        """Build the object from a file containing the fftprof input."""
        with open(filename, "r") as fh:
            fft_input = fh.read()

        return cls(fft_input, executable=executable)

    def run(self) -> int:
        """
        Execute fftprof in a subprocess.

        A new temporary directory (``self.workdir``) is created at each call
        and is not removed here, since ``plot`` reads the PROF files from it.
        The output of the run is stored in ``self.stdout_data`` and
        ``self.stderr_data``.

        Returns:
            The exit status of the executable (a non-zero status raises).

        Raises:
            FFTProfError: if the executable returns a non-zero exit status.
        """
        self.workdir = tempfile.mkdtemp()

        self.stdin_fname = os.path.join(self.workdir, "fftprof.in")
        self.stdout_fname = os.path.join(self.workdir, "fftprof.out")
        self.stderr_fname = os.path.join(self.workdir, "fftprof.err")

        with open(self.stdin_fname, "w") as fh:
            fh.write(self.fft_input)

        # Shell-like rendering of the command, only used in the error message.
        args = [self.executable, "<", self.stdin_fname, ">", self.stdout_fname, "2>", self.stderr_fname]
        cmd_str = " ".join(args)

        # Redirect the streams with file objects instead of going through the
        # shell: paths containing spaces or shell metacharacters are passed
        # verbatim and no shell has to interpret the command line.
        with open(self.stdin_fname, "r") as fin, \
                open(self.stdout_fname, "w") as fout, \
                open(self.stderr_fname, "w") as ferr:
            p = Popen([self.executable], stdin=fin, stdout=fout, stderr=ferr, cwd=self.workdir)
            p.wait()

        self.returncode = p.returncode

        # The output lives in the files, not in pipes: read it back from there.
        with open(self.stdout_fname, "r") as out, open(self.stderr_fname, "r") as err:
            self.stdout_data = out.read()
            self.stderr_data = err.read()

        if self.returncode != 0:
            if self.verbose:
                print("*** stdout: ***\n", self.stdout_data)
                print("*** stderr: ***\n", self.stderr_data)

            raise self.Error("%s returned %s\n cmd_str: %s" % (self, self.returncode, cmd_str))

        return self.returncode

    def plot(self):
        """
        Plot every PROF_* file found in the working directory.

        ``run`` must have been called before, because the files are searched
        in ``self.workdir``.
        """
        names = WildCard("PROF_*").filter(os.listdir(self.workdir))
        candidates = [os.path.join(self.workdir, name) for name in names]

        for prof_file in candidates:
            if not os.path.isfile(prof_file):
                continue
            if self.verbose:
                print("About to plot prof_file: ", prof_file)

            bench = FFTBenchmark.from_file(prof_file)
            bench.plot()