# coding: utf-8
"""Decorators for AbinitInput or MultiDataset objects."""
from __future__ import annotations
import abc
import pymatgen.io.abinit.abiobjects as aobj
from monty.json import MSONable
from abipy.tools.serialization import pmg_serialize
from abipy.flowtk.abiobjects import LdauParams, LexxParams
from .inputs import AbinitInput, MultiDataset
class SpinDecorator(AbinitInputDecorator):
    """This decorator changes the spin polarization."""

    def __init__(self, spinmode, kptopt_ifspinor=4):
        """
        Args:
            spinmode: Spin mode specification, normalized via ``aobj.SpinMode.as_spinmode``.
            kptopt_ifspinor: Value assigned to ``kptopt`` when the spin mode has
                ``nspinor == 2`` (unless the input already uses ``kptopt == 3``).
        """
        self.spinmode = aobj.SpinMode.as_spinmode(spinmode)
        self.kptopt_ifspinor = kptopt_ifspinor

    @pmg_serialize
    def as_dict(self) -> dict:
        """Return a JSON-serializable dict representation of the decorator."""
        return dict(spinmode=self.spinmode.as_dict(), kptopt_ifspinor=self.kptopt_ifspinor)

    @classmethod
    def from_dict(cls, d: dict) -> SpinDecorator:
        """Rebuild the decorator from the dict produced by :meth:`as_dict`."""
        return cls(aobj.SpinMode.from_dict(d["spinmode"]), kptopt_ifspinor=d["kptopt_ifspinor"])

    def _decorate(self, inp, deepcopy=True):
        """
        Set the spin-related variables on ``inp``.

        Args:
            inp: Input object to decorate.
            deepcopy: If True, work on ``inp.deepcopy()`` so the caller's object is untouched.

        Returns:
            The decorated input (a copy if ``deepcopy`` is True, else ``inp`` itself).
        """
        if deepcopy:
            inp = inp.deepcopy()
        inp.set_vars(self.spinmode.to_abivars())

        # ABINIT message (version 7.11.5):
        #   When non-collinear magnetism is activated (nspden=4),
        #   time-reversal symmetry cannot be used in the present
        #   state of the code (to be checked and validated).
        #   Action: choose kptopt different from 1 or 2.
        # Hence kptopt is set to kptopt_ifspinor (default 4: spatial symmetries,
        # no time-reversal) unless the input already has kptopt == 3
        # (no time-reversal, no spatial symmetries), which is needed for DFPT at q != 0.
        if self.spinmode.nspinor == 2 and inp.get("kptopt") != 3:
            inp.set_vars(kptopt=self.kptopt_ifspinor)
        return inp
class SmearingDecorator(AbinitInputDecorator):
    """This decorator changes the electronic smearing."""

    def __init__(self, smearing):
        """
        Args:
            smearing: Smearing specification, normalized via ``aobj.Smearing.as_smearing``.
        """
        self.smearing = aobj.Smearing.as_smearing(smearing)

    @pmg_serialize
    def as_dict(self) -> dict:
        """Return a JSON-serializable dict representation of the decorator."""
        return {"smearing": self.smearing.as_dict()}

    @classmethod
    def from_dict(cls, d: dict) -> SmearingDecorator:
        """Rebuild the decorator from the dict produced by :meth:`as_dict`."""
        return cls(aobj.Smearing.from_dict(d["smearing"]))

    def _decorate(self, inp, deepcopy=True):
        """
        Set the smearing variables on ``inp``.

        Args:
            inp: Input object to decorate.
            deepcopy: If True, work on ``inp.deepcopy()`` so the caller's object is untouched.

        Returns:
            The decorated input (a copy if ``deepcopy`` is True, else ``inp`` itself).
        """
        if deepcopy:
            inp = inp.deepcopy()
        inp.set_vars(self.smearing.to_abivars())
        return inp
class XcDecorator(AbinitInputDecorator):
    """Change the exchange-correlation functional."""

    def __init__(self, ixc: int):
        """
        Args:
            ixc: Value of the ABINIT input variable ``ixc``.
        """
        self.ixc = ixc

    @pmg_serialize
    def as_dict(self) -> dict:
        """Return a JSON-serializable dict representation of the decorator."""
        return {"ixc": self.ixc}

    @classmethod
    def from_dict(cls, d: dict) -> XcDecorator:
        """Rebuild the decorator from the dict produced by :meth:`as_dict`."""
        return cls(d["ixc"])

    def _decorate(self, inp, deepcopy=True):
        """
        Set ``ixc`` on ``inp``; ``usekden`` is always passed as None.

        Args:
            inp: Input object to decorate.
            deepcopy: If True, work on ``inp.deepcopy()`` so the caller's object is untouched.

        Returns:
            The decorated input (a copy if ``deepcopy`` is True, else ``inp`` itself).
        """
        if deepcopy:
            inp = inp.deepcopy()
        # TODO: Don't understand why abinit does not enable usekden if MGGA!
        # Intended logic once clarified: usekden = 1 if ixc.ismgga() else None
        usekden = None
        inp.set_vars(ixc=self.ixc, usekden=usekden)
        return inp
class LdaUDecorator(AbinitInputDecorator):
    """This decorator adds LDA+U parameters to an :class:`AbinitInput` object."""

    def __init__(self, symbols_luj, usepawu=1, unit="eV"):
        """
        Args:
            symbols_luj: dictionary mapping chemical symbols to another dict with
                the keys "l", "u" and "j", e.g. ``{"Ni": {"l": 2, "u": 8.0, "j": 0.8}}``.
            usepawu: Abinit input variable.
            unit: Energy unit for U and J.
        """
        self.symbols_luj, self.usepawu, self.unit = symbols_luj, usepawu, unit

    @pmg_serialize
    def as_dict(self) -> dict:
        """Return a JSON-serializable dict representation of the decorator."""
        return dict(symbols_luj=self.symbols_luj, usepawu=self.usepawu, unit=self.unit)

    @classmethod
    def from_dict(cls, d: dict) -> LdaUDecorator:
        """Rebuild the decorator from :meth:`as_dict` output, ignoring the "@..." metadata keys."""
        return cls(**{k: v for k, v in d.items() if not k.startswith("@")})

    def _decorate(self, inp, deepcopy=True):
        """
        Add the LDA+U variables to ``inp``.

        Args:
            inp: Input object to decorate. It must use PAW pseudopotentials.
            deepcopy: If True, work on ``inp.deepcopy()`` so the caller's object is untouched.

        Returns:
            The decorated input (a copy if ``deepcopy`` is True, else ``inp`` itself).

        Raises:
            self.Error: if ``inp`` is not a PAW calculation.
        """
        if not inp.ispaw:
            raise self.Error("LDA+U requires PAW!")
        if deepcopy:
            inp = inp.deepcopy()

        luj_params = LdauParams(usepawu=self.usepawu, structure=inp.structure)
        # Apply U and J only to the symbols present both in the structure and in symbols_luj.
        for symbol in inp.structure.symbol_set:
            if symbol not in self.symbols_luj:
                continue
            args = self.symbols_luj[symbol]
            luj_params.luj_for_symbol(symbol, l=args["l"], u=args["u"], j=args["j"], unit=self.unit)

        inp.set_vars(luj_params.to_abivars())
        return inp
class LexxDecorator(AbinitInputDecorator):
    """This decorator adds local exact exchange to an :class:`AbinitInput` object."""

    def __init__(self, symbols_lexx, exchmix=None):
        """
        Args:
            symbols_lexx: dictionary mapping chemical symbols to the angular momentum l on which lexx is applied.
                Example: to perform a LEXX calculation for NiO in which the LEXX is computed only
                for the l=2 channel of the nickel atoms, use ``{"Ni": 2}``.
            exchmix: ratio of exact exchange when useexexch is used. If None, exchmix is not set
                and ABINIT's default (0.25, corresponding to PBE0) applies.
        """
        self.symbols_lexx, self.exchmix = symbols_lexx, exchmix

    @classmethod
    def from_dict(cls, d: dict) -> LexxDecorator:
        """Rebuild the decorator from :meth:`as_dict` output, ignoring the "@..." metadata keys."""
        return cls(**{k: v for k, v in d.items() if not k.startswith("@")})

    @pmg_serialize
    def as_dict(self) -> dict:
        """Return a JSON-serializable dict representation of the decorator."""
        return {"symbols_lexx": self.symbols_lexx, "exchmix": self.exchmix}

    def _decorate(self, inp, deepcopy=True):
        """
        Add the local exact exchange variables to ``inp``.

        Args:
            inp: Input object to decorate. It must use PAW pseudopotentials.
            deepcopy: If True, work on ``inp.deepcopy()`` so the caller's object is untouched.

        Returns:
            The decorated input (a copy if ``deepcopy`` is True, else ``inp`` itself).

        Raises:
            self.Error: if ``inp`` is not a PAW calculation.
        """
        if not inp.ispaw:
            raise self.Error("LEXX requires PAW!")
        if deepcopy:
            inp = inp.deepcopy()

        lexx_params = LexxParams(inp.structure)
        # Only symbols present both in the structure and in symbols_lexx are affected.
        for symbol in inp.structure.symbol_set:
            if symbol in self.symbols_lexx:
                lexx_params.lexx_for_symbol(symbol, l=self.symbols_lexx[symbol])
        inp.set_vars(lexx_params.to_abivars())

        # ABINIT rejects useexexch = 1 unless ixc is one of 11 or 23
        # ("you should change the input variables ixc or useexexch"), so fall back
        # to ixc = 11 when ixc is unset or incompatible (None is not in the tuple).
        if inp.get("ixc") not in (11, 23):
            inp.set_vars(ixc=11)
        if self.exchmix is not None:
            inp.set_vars(exchmix=self.exchmix)
        return inp
# Stubs for decorators that are not implemented yet (kept commented out).
#class ScfMixingDecorator(AbinitInputDecorator):
#class MagneticMomentDecorator(AbinitInputDecorator):
#    """Add reasonable guesses for the initial magnetic moments."""
#class SpinOrbitDecorator(AbinitInputDecorator):
#    """Enable spin-orbit in the input."""
#     def __init__(self, no_spatial_symmetries=True, no_time_reversal=False, spnorbscl=None):
#        self.use_spatial_symmetries = use_spati
#        self.use_spatial_symmetries
#
#    def _decorate(self, inp, deepcopy=True)
#        if deepcopy: inp = inp.deepcopy()
#        kptopt =
#        if inp.ispaw:
#            for dt in inp.datasets:
#               dt.set_vars(pawspnorb=1, kptopt=kptopt)
#        return inp
#class PerformanceDecorator(AbinitInputDecorator):
#    """Change the variables in order to speedup the calculation."""
#    fftgw
#    boxcutmin
#    fft
#    def __init__(self, accuracy):
#        self.accuracy = accuracy
#
#    def _decorate(self, inp, deepcopy=True)
#        if deepcopy: inp = inp.deepcopy()
#         for dt in inp[1:]:
#            runlevel = dt.runlevel
#        return inp
#class DmftDecorator(AbinitInputDecorator):
#    """Add DMFT variables."""