# Filename: testing.py
"""
Common unit testing support for km3pipe.
"""
from __future__ import absolute_import, print_function, division
import sys
from functools import wraps
from unittest import TestCase # noqa
from mock import MagicMock # noqa
from mock import Mock # noqa
from mock import patch # noqa
from numpy.testing import assert_allclose # noqa
import pytest # noqa
# Alias for ``pytest.mark.skipif`` so that test modules can import it
# from here together with the other testing helpers.
skipif = pytest.mark.skipif
# Module metadata
__author__ = 'Tamas Gal'
__copyright__ = 'Copyright 2016, Tamas Gal and the KM3NeT collaboration.'
__credits__ = []
__license__ = 'MIT'
__maintainer__ = __author__
__email__ = 'tgal@km3net.de'
__status__ = 'Development'
class surrogate(object):
    """
    Add empty module stub that can be imported
    for every subpath in path.

    Those stubs can later be patched by mock's
    patch decorator.

    This class was written by Kostia Balytskyi (ikostia @github)

    A ``surrogate`` can be used as a decorator or as a context manager.
    In both cases the stubs are removed again afterwards, also when the
    wrapped code raises. The same instance may be used repeatedly, e.g.
    a decorated function may be called more than once.

    Example
    -------
    @surrogate('sys.my.cool.module1')
    @surrogate('sys.my.cool.module2')
    @mock.patch('sys.my.cool.module1', mock1)
    @mock.patch('sys.my.cool.module2', mock2)
    def function():
        from sys.my import cool
        from sys.my.cool import module1
        from sys.my.cool import module2
    """

    # Sentinel for "attribute was absent"; ``None`` is a legitimate value.
    _MISSING = object()

    def __init__(self, path):
        self.path = path
        self.elements = self.path.split('.')

    def __enter__(self):
        self.prepared = self.prepare()
        # Returning self (previously None) is backward compatible and
        # allows ``with surrogate(...) as s:``.
        return self

    def __exit__(self, *args):
        if self.prepared:
            self.restore()
        # Implicitly returns None, so exceptions propagate.

    def __call__(self, func):
        @wraps(func)
        def _wrapper(*args, **kwargs):
            prepared = self.prepare()
            try:
                return func(*args, **kwargs)
            finally:
                # Undo the stubs even if ``func`` raised; otherwise they
                # would leak into ``sys.modules`` for later code.
                if prepared:
                    self.restore()
        return _wrapper

    @property
    def nothing_to_stub(self):
        """Check if there are no modules to stub"""
        return len(self.elements) == 0

    def prepare(self):
        """Preparations before actual function call

        Returns
        -------
        bool
            True if stubs were installed (``restore`` must be called
            afterwards), False if every module of the path already exists.
        """
        self._determine_existing_modules()
        if self.nothing_to_stub:
            return False
        self._create_module_stubs()
        self._save_base_module()
        self._add_module_stubs()
        return True

    def restore(self):
        """Post-actions to restore initial state of the system"""
        self._remove_module_stubs()
        self._restore_base_module()

    def _get_importing_path(self, elements):
        """Return importing path for a module that is last in elements list"""
        ip = '.'.join(elements)
        if self.known_path:
            ip = self.known_path + '.' + ip
        return ip

    def _create_module_stubs(self):
        """Create stubs for all not-existing modules"""
        # last module in our sequence
        # it should be loaded
        last_module = type(
            self.elements[-1], (object, ), {
                '__all__': [],
                '_importing_path': self._get_importing_path(self.elements)
            }
        )
        modules = [last_module]
        # now we create a module stub for each element in a path.
        # each module stub contains `__all__` list and a member that
        # points to the next module stub in sequence
        for element in reversed(self.elements[:-1]):
            next_module = modules[-1]
            module = type(
                element, (object, ), {
                    next_module.__name__: next_module,
                    '__all__': [next_module.__name__]
                }
            )
            modules.append(module)
        self.modules = list(reversed(modules))
        # give the outermost stub a `__path__` so it is treated as a package
        self.modules[0].__path__ = []

    def _determine_existing_modules(self):
        """Find out which of the modules from specified path are already
        imported (i.e. present in sys.modules); those modules should not be
        replaced by stubs.

        Sets ``known_path`` to the dotted prefix that already exists and
        ``elements`` to the remaining (to-be-stubbed) path components.
        """
        # Always start from the full path: ``self.elements`` is trimmed
        # below, so reusing it would drop the already-imported prefix on a
        # second ``prepare()`` (e.g. a decorated function called twice).
        elements = self.path.split('.')
        known = 0
        while known < len(elements) and \
                '.'.join(elements[:known + 1]) in sys.modules:
            known += 1
        self.known_path = '.'.join(elements[:known])
        self.elements = elements[known:]

    def _save_base_module(self):
        """Remember state of the last of existing modules

        The last of the sequence of existing modules is the only one we
        will change, so we must remember its state in order to restore it
        afterwards.
        """
        first = self.elements[0]
        # save last of the existing modules (None if none exists)
        self.base_module = sys.modules.get(self.known_path)
        self.base_all = []
        self._base_all_orig = self._MISSING
        self._base_attr_orig = self._MISSING
        if self.base_module is None:
            return
        # save the original `__all__` and any attribute the first stub
        # will shadow, so both can be put back exactly
        self._base_all_orig = getattr(
            self.base_module, '__all__', self._MISSING
        )
        if self._base_all_orig is not self._MISSING:
            self.base_all = list(self._base_all_orig)
        self._base_attr_orig = getattr(self.base_module, first, self._MISSING)
        # change base_module's `__all__` attribute
        # to include the first module of the sequence
        self.base_module.__all__ = self.base_all + [first]
        setattr(self.base_module, first, self.modules[0])

    def _add_module_stubs(self):
        """Push created module stubs into sys.modules"""
        for i, module in enumerate(self.modules):
            module._importing_path = \
                self._get_importing_path(self.elements[:i + 1])
            sys.modules[module._importing_path] = module

    def _remove_module_stubs(self):
        """Remove fake modules from sys.modules"""
        for module in reversed(self.modules):
            sys.modules.pop(module._importing_path, None)

    def _restore_base_module(self):
        """Restore the state of the last existing module"""
        if self.base_module is None:
            return
        first = self.elements[0]
        if self._base_all_orig is self._MISSING:
            # `__all__` was added by us; remove it again
            if hasattr(self.base_module, '__all__'):
                del self.base_module.__all__
        else:
            self.base_module.__all__ = self._base_all_orig
        if self._base_attr_orig is self._MISSING:
            if hasattr(self.base_module, first):
                delattr(self.base_module, first)
        else:
            # put back the attribute the stub was shadowing
            setattr(self.base_module, first, self._base_attr_orig)