Skip to content

Instantly share code, notes, and snippets.

@Ceasar
Last active December 17, 2015 16:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ceasar/5640736 to your computer and use it in GitHub Desktop.
Save Ceasar/5640736 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
:copyright: (c) 2011 by Armin Ronacher.
:copyright: (c) 2011 by the Werkzeug Team, see AUTHORS for more details.
:license: BSD, see LICENSE for more details.
"""
import os
import sys
import unittest
class ImportStringError(ImportError):
    """Provides information about a failed :func:`import_string` attempt.

    The exception message lists every dotted prefix of the requested name
    that could be imported (with the file it came from), the first prefix
    that could not, and the class and text of the original exception.

    :param import_name: the dotted (or ``module:object``) name that failed.
    :param exception: the exception raised by the failed import.
    """

    #: String in dotted notation that failed to be imported.
    import_name = None
    #: Wrapped exception.
    exception = None

    def __init__(self, import_name, exception):
        self.import_name = import_name
        self.exception = exception

        msg = (
            'import_string() failed for %r. Possible reasons are:\n\n'
            '- missing __init__.py in a package;\n'
            '- package or module path not included in sys.path;\n'
            '- duplicated package or module name taking precedence in '
            'sys.path;\n'
            '- missing module, class, function or variable;\n\n'
            'Debugged import:\n\n%s\n\n'
            'Original exception:\n\n%s: %s')

        # Re-import each dotted prefix in turn to find where resolution stops.
        name = ''
        tracked = []
        missing = None
        for part in import_name.replace(':', '.').split('.'):
            name += (name and '.') + part
            imported = import_string(name, silent=True)
            # A silent import_string() signals failure with None; comparing
            # against None keeps falsy-but-present objects in the "found" list.
            if imported is None:
                missing = name
                break
            tracked.append((name, getattr(imported, '__file__', None)))

        track = ['- %r found in %r.' % (n, i) for n, i in tracked]
        if missing is not None:
            track.append('- %r not found.' % missing)

        # Format unconditionally: if every prefix happened to import this
        # time, the message must still not contain raw placeholders.
        msg = msg % (import_name, '\n'.join(track),
                     exception.__class__.__name__, str(exception))

        ImportError.__init__(self, msg)

    def __repr__(self):
        return '<%s(%r, %r)>' % (self.__class__.__name__, self.import_name,
                                 self.exception)
def _iter_modules(path):
    """Iterate over all modules in a package.

    :param path: iterable of directory names to scan, e.g. a package's
                 ``__path__``.
    :return: generator of ``(modname, ispkg)`` tuples where `modname` is the
             unqualified module name and `ispkg` tells whether it is a
             package.
    """
    import os
    import pkgutil
    if hasattr(pkgutil, 'iter_modules'):
        for _importer, modname, ispkg in pkgutil.iter_modules(path):
            yield modname, ispkg
        return

    # Fallback for interpreters whose pkgutil has no iter_modules().
    from inspect import getmodulename
    from pydoc import ispackage
    found = set()
    for directory in path:
        for filename in os.listdir(directory):
            full_path = os.path.join(directory, filename)
            # ispackage() inspects a filesystem path, so it needs the joined
            # path rather than the bare name.  Packages are directories, for
            # which getmodulename() returns None, so they are checked first.
            if ispackage(full_path):
                modname, ispkg = filename, True
            else:
                modname, ispkg = getmodulename(filename), False
            if modname and modname != '__init__' and modname not in found:
                found.add(modname)
                yield modname, ispkg
def import_string(import_name, silent=False):
    """Imports an object based on a string.  This is useful if you want to
    use import paths as endpoints or something similar.  An import path can
    be specified either in dotted notation (``xml.sax.saxutils.escape``)
    or with a colon as object delimiter (``xml.sax.saxutils:escape``).

    If `silent` is True the return value will be `None` if the import fails.

    :param import_name: the dotted name for the object to import.
    :param silent: if set to `True` import errors are ignored and
                   `None` is returned instead.
    :return: imported object
    :raises ImportStringError: if the import fails and `silent` is false.
    """
    # Python 2 only: force unicode import names to byte strings.  The
    # ``unicode`` builtin does not exist on Python 3, so it is only looked
    # up behind the version check.
    if sys.version_info[0] == 2 and isinstance(import_name, unicode):  # noqa: F821
        import_name = str(import_name)
    try:
        if ':' in import_name:
            module, obj = import_name.split(':', 1)
        elif '.' in import_name:
            module, obj = import_name.rsplit('.', 1)
        else:
            return __import__(import_name)
        try:
            return getattr(__import__(module, None, None, [obj]), obj)
        except (ImportError, AttributeError):
            # support importing modules not yet set up by the parent module
            # (or package for that matter)
            modname = module + '.' + obj
            __import__(modname)
            return sys.modules[modname]
    except ImportError as e:
        if silent:
            return None
        error = ImportStringError(import_name, e)
        # Keep the traceback of the original failure attached to the wrapping
        # exception.  The three-expression ``raise`` form cannot be parsed by
        # Python 3, so use with_traceback() where it exists and fall back to
        # a plain raise elsewhere.
        if hasattr(error, 'with_traceback'):
            raise error.with_traceback(sys.exc_info()[2])
        raise error
def find_modules(import_path, include_packages=False, recursive=False):
    """Yield the dotted names of the modules located below a package.

    Handy for importing every view / controller module so that metaclasses
    and decorators in them get a chance to register with the application.

    Sub-packages are only yielded when `include_packages` is true.  With
    `recursive` enabled each sub-package is imported in order to discover
    its own load path and then searched as well.

    :param import_path: the dotted name of the package to search.
    :param include_packages: set to `True` to yield packages as well.
    :param recursive: set to `True` to descend into sub-packages.
    :return: generator
    :raises ValueError: if `import_path` does not refer to a package.
    """
    package = import_string(import_path)
    search_path = getattr(package, '__path__', None)
    if search_path is None:
        raise ValueError('%r is not a package' % import_path)

    prefix = package.__name__ + '.'
    for short_name, is_package in _iter_modules(search_path):
        full_name = prefix + short_name
        if not is_package:
            yield full_name
            continue
        if include_packages:
            yield full_name
        if recursive:
            for nested in find_modules(full_name, include_packages, True):
                yield nested
def add_to_path(path):
    """Put `path` at the very front of ``sys.path``.

    Any existing entry that is equal to `path`, or that refers to the same
    file system location, is removed first, so the directory ends up listed
    exactly once and takes precedence over all other entries.

    :param path: directory to place at the front of ``sys.path``.
    :raises RuntimeError: if `path` is not an existing directory.
    """
    if not os.path.isdir(path):
        raise RuntimeError('Tried to add nonexisting path')

    def _refers_to_path(entry):
        if path == entry:
            return True
        try:
            return os.path.samefile(path, entry)
        except (IOError, OSError, AttributeError):
            # Windows has no samefile
            return False

    remaining = [entry for entry in sys.path if not _refers_to_path(entry)]
    # Slice assignment mutates the existing list object in place.
    sys.path[:] = [path] + remaining
def setup_path():
    """Move the ``test_apps`` directory next to this file to the front of
    ``sys.path``.
    """
    here = os.path.dirname(__file__)
    test_apps_dir = os.path.abspath(os.path.join(here, 'test_apps'))
    add_to_path(test_apps_dir)
def iter_suites():
    """Yield the testsuite of every module below this package that
    exposes a ``suite`` attribute.
    """
    for module_name in find_modules(__name__):
        loaded = import_string(module_name)
        if not hasattr(loaded, 'suite'):
            continue
        yield loaded.suite()
def suite():
    """Build one testsuite containing all the Flask tests.

    Use this function to plug the Flask tests into your own testsuite,
    for example to check that monkeypatches to Flask do not break it.

    :return: a :class:`unittest.TestSuite` holding every suite produced by
             :func:`iter_suites`.
    """
    setup_path()
    combined = unittest.TestSuite()
    combined.addTests(iter_suites())
    return combined
def find_all_tests(suite):
    """Walk a suite and yield ``(test, dotted_name)`` for every test in it.

    The dotted name has the form ``module.ClassName.method_name``.
    """
    pending = [suite]
    while pending:
        node = pending.pop()
        try:
            # Suites are iterable; queue their children for a later visit.
            pending.extend(node)
            continue
        except TypeError:
            # Not iterable, so this is an individual test.
            pass
        owner = node.__class__
        dotted_name = '%s.%s.%s' % (
            owner.__module__,
            owner.__name__,
            node._testMethodName
        )
        yield node, dotted_name
class BetterLoader(unittest.TestLoader):
    """A test loader that resolves names against the combined root suite.

    The tests are assembled programmatically from several sources, which
    the default name-based loading cannot handle.  In addition, names are
    matched loosely against the dotted test names, so that something like
    ``run-tests.py ViewTestCase`` just works.
    """

    def getRootSuite(self):
        """Return the suite that name lookups are performed against."""
        return suite()

    def loadTestsFromName(self, name, module=None):
        """Return the test(s) whose dotted name matches `name`.

        :param name: ``'suite'`` for the whole root suite, otherwise a full
                     dotted test name or a leading, trailing or inner dotted
                     part of one.
        :param module: accepted for signature compatibility; not used here.
        :return: the root suite, a single test, or a new suite holding all
                 matching tests.
        :raises LookupError: if no test matches `name`.
        """
        root = self.getRootSuite()
        if name == 'suite':
            return root

        def _selected(testname):
            return (
                testname == name
                or testname.endswith('.' + name)
                or ('.' + name + '.') in testname
                or testname.startswith(name + '.')
            )

        matches = [
            testcase
            for testcase, testname in find_all_tests(root)
            if _selected(testname)
        ]

        if not matches:
            raise LookupError('could not find test case for "%s"' % name)
        if len(matches) == 1:
            return matches[0]
        return unittest.TestSuite(matches)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment