# Source code for abiflows.core.testing

# coding: utf-8
# flake8: noqa
"""
Common test support for abiflows test scripts.

This single module should provide all the common functionality for abiflows tests
in a single location, so that test scripts can just import it and work right away.
This module heavily depends on the abipy.core.testing module.
"""
import os
import shutil
import glob
#import tempfile
#import unittest
import numpy.testing.utils as nptu

from mongoengine import connect, Document
from mongoengine.connection import get_db #, get_connection
#from monty.os.path import which
#from monty.string import is_string
from abipy.core.testing import AbipyTest
from fireworks.core.launchpad import LaunchPad
from fireworks.core.fworker import FWorker
from fireworks.core.rocket_launcher import rapidfire
from abiflows.fireworks.utils.fw_utils import get_fw_by_task_index

import logging
# Module-level logger.
# NOTE(review): the logger is named after the file path (__file__), not the dotted
# module name (__name__) -- confirm this is intentional before relying on
# hierarchical logging configuration.
logger = logging.getLogger(__file__)

# Directory that contains this module.
root = os.path.dirname(__file__)

# Public API exported by ``from abiflows.core.testing import *``.
__all__ = [
    "AbiflowsTest"
]


# Name of the MongoDB database used by the test helpers in this module.
# It is created by the setup_* class methods and dropped by the teardown_* ones.
TESTDB_NAME = "abiflows_unittest"


def has_mongodb(host='localhost', port=27017, name='mongodb_test', username=None, password=None):
    """
    Return True if a MongoDB server can actually be reached, False otherwise.

    Creating a ``MongoClient`` does not contact the server in PyMongo 3 and later,
    so a ``ping`` command is issued to force a round trip; otherwise the probe
    would succeed even when no server is running.

    Args:
        host: Hostname of the MongoDB server.
        port: TCP port of the MongoDB server.
        name: Database used as authentication source when ``username`` is given.
        username: Optional user name. No authentication is attempted if empty.
        password: Password associated with ``username``.

    Returns:
        bool: True if pymongo is importable and the server answers the ping
        (with valid credentials, if provided). Any failure yields False.
    """
    try:
        from pymongo import MongoClient
    except ImportError:
        return False

    # Fail fast when no server is listening instead of waiting for the
    # default server selection timeout.
    options = {"serverSelectionTimeoutMS": 2000}
    if username:
        # Credentials are handed to the client because Database.authenticate
        # is not available in recent PyMongo releases.
        options.update(username=username, password=password, authSource=name)
    # NOTE(review): the journal write-concern flag (j=True) used previously is not
    # passed any more: the probe performs no writes, and the option name differs
    # between PyMongo major versions -- confirm nothing relied on it.

    client = None
    try:
        client = MongoClient(host, port, **options)
        client.admin.command("ping")
    except Exception:
        # Best-effort probe: any connection/authentication problem means "not available".
        return False
    else:
        return True
    finally:
        if client is not None:
            client.close()


def has_fireworks():
    """Report whether the ``fireworks`` package can be imported."""
    try:
        import fireworks  # noqa: F401 -- imported only to probe availability
    except ImportError:
        available = False
    else:
        available = True
    return available


class AbiflowsTest(AbipyTest):
    """Extends AbipyTest with methods specific for the testing of workflows"""

    def assertFwSerializable(self, obj):
        """
        Check that ``obj`` survives a to_dict/from_dict round trip.

        The serialized dictionary must contain the ``_fw_name`` key and the
        object rebuilt with ``from_dict`` must serialize to the same dictionary.
        """
        # assertIn gives a descriptive failure message and is not stripped
        # when Python runs with -O, unlike a bare assert.
        self.assertIn('_fw_name', obj.to_dict())
        expected = obj.to_dict()
        rebuilt = obj.__class__.from_dict(obj.to_dict())
        self.assertDictEqual(expected, rebuilt.to_dict())

    @classmethod
    def setup_fireworks(cls):
        """
        Sets up the fworker and launchpad if a connection to a local mongodb is available.
        cls.lp is set to None if not available
        """
        cls.fworker = FWorker()
        try:
            cls.lp = LaunchPad(name=TESTDB_NAME, strm_lvl='ERROR')
            cls.lp.reset(password=None, require_password=False)
        except Exception:
            # Any failure while creating/resetting the launchpad disables it.
            cls.lp = None

    @classmethod
    def teardown_fireworks(cls, module_dir=None):
        """
        Removes the fireworks test database if cls.lp is present and deletes all the launcher directories

        Args:
            module_dir: directory searched for ``launcher_*`` folders to delete.
                Nothing is deleted if None or empty.
        """
        # getattr: tolerate a teardown issued without a previous setup_fireworks.
        if getattr(cls, "lp", None):
            cls.lp.connection.drop_database(TESTDB_NAME)

        if module_dir:
            for ldir in glob.glob(os.path.join(module_dir, "launcher_*")):
                shutil.rmtree(ldir)

    @classmethod
    def setup_mongoengine(cls):
        """
        Connects mongoengine to the test database, dropping any previous content.
        cls.db and cls._connection are set to None if the connection fails.
        """
        try:
            cls._connection = connect(db=TESTDB_NAME)
            # Start from an empty database.
            cls._connection.drop_database(TESTDB_NAME)
            cls.db = get_db()
        except Exception:
            cls.db = None
            cls._connection = None

    @classmethod
    def teardown_mongoengine(cls):
        """Drops the test database if a mongoengine connection was established."""
        # getattr: tolerate a teardown issued without a previous setup_mongoengine.
        if getattr(cls, "_connection", None):
            cls._connection.drop_database(TESTDB_NAME)

    def get_document_class_from_mixin(self, mixin_cls):
        """
        Utility function to generate a mongoengine Document class from the mixin.
        Needed to save the object in the db with mongoengine

        Args:
            mixin_cls: the mixin class to be combined with ``Document``.

        Returns:
            A new Document subclass stored in the ``test_<mixin class name>`` collection.
        """
        class TestDocument(mixin_cls, Document):
            meta = {'collection': "test_{}".format(mixin_cls.__name__)}

        return TestDocument
class AbiflowsIntegrationTest(object):
    """
    Provides utility methods and variables for integration tests, that can't subclass unittest.TestCase
    """

    # variable to enable/disable the checks on the numerical quantities as output of the workflow
    check_numerical_values = True

    @staticmethod
    def assertArrayAlmostEqual(actual, desired, decimal=7, err_msg='', verbose=True):
        """
        Tests if two arrays are almost equal to a tolerance. The CamelCase
        naming is so that it is consistent with standard unittest methods.

        All the arguments are forwarded positionally to
        ``assert_almost_equal`` from numpy's testing utilities, whose return
        value is returned unchanged.
        """
        return nptu.assert_almost_equal(actual, desired, decimal, err_msg, verbose)


def check_restart_task_type(lp, fworker, tmpdir, fw_id, task_tag):
    """
    Resume a paused Firework of a workflow and check that the restart completes.

    The Firework selected by ``task_tag`` with index 1 must be PAUSED; it is
    resumed and launched once. After that the Firework with index 2 must be
    READY. All the remaining Fireworks are then run and the last one selected
    by ``task_tag`` must be COMPLETED.

    Args:
        lp: LaunchPad holding the workflow.
        fworker: FWorker used to run the launches.
        tmpdir: directory in which the launches are executed (converted with ``str``).
        fw_id: id of a Firework belonging to the workflow under test.
        task_tag: tag passed to ``get_fw_by_task_index`` to select the Fireworks
            (presumably the task type -- verify against that helper).

    Raises:
        AssertionError: if any of the expected Fireworks is missing or is not
            in the expected state.
    """
    # resume the task for tag
    wf = lp.get_wf_by_fw_id(fw_id)
    fw = get_fw_by_task_index(wf, task_tag, index=1)
    assert fw is not None
    assert fw.state == "PAUSED"
    lp.resume_fw(fw.fw_id)

    # run the FW
    rapidfire(lp, fworker, m_dir=str(tmpdir), nlaunches=1)

    # the job should have a detour for the restart
    wf = lp.get_wf_by_fw_id(fw_id)
    fw = get_fw_by_task_index(wf, task_tag, index=2)
    assert fw is not None
    assert fw.state == "READY"

    # run all the following and check that the last is correctly completed (if convergence is not achieved
    # the final state should be FIZZLED)
    rapidfire(lp, fworker, m_dir=str(tmpdir))
    wf = lp.get_wf_by_fw_id(fw_id)
    fw = get_fw_by_task_index(wf, task_tag, index=-1)
    # Same guard as the previous lookups: fail with a clear assertion rather
    # than an AttributeError if no Firework is found.
    assert fw is not None
    assert fw.state == "COMPLETED"