Browse Source

Cleanup in the test finder

pull/309/head
Armin Ronacher 14 years ago
parent
commit
a082a5e0ba
  1. 51
      flask/testsuite/__init__.py

51
flask/testsuite/__init__.py

@ -20,9 +20,6 @@ from contextlib import contextmanager
from werkzeug.utils import import_string, find_modules from werkzeug.utils import import_string, find_modules
common_prefix = __name__ + '.'
def add_to_path(path): def add_to_path(path):
def _samefile(x, y): def _samefile(x, y):
try: try:
@ -50,23 +47,18 @@ def iter_suites():
yield mod.suite() yield mod.suite()
def find_all_tests(): def find_all_tests(suite):
suites = [suite()] suites = [suite]
while suites: while suites:
s = suites.pop() s = suites.pop()
try: try:
suites.extend(s) suites.extend(s)
except TypeError: except TypeError:
yield s yield s, '%s.%s.%s' % (
s.__class__.__module__,
s.__class__.__name__,
def find_all_tests_with_name(): s._testMethodName
for testcase in find_all_tests(): )
yield testcase, '%s.%s.%s' % (
testcase.__class__.__module__,
testcase.__class__.__name__,
testcase._testMethodName
)
@contextmanager @contextmanager
@ -111,6 +103,10 @@ def emits_module_deprecation_warning(f):
class FlaskTestCase(unittest.TestCase): class FlaskTestCase(unittest.TestCase):
"""Baseclass for all the tests that Flask uses. Use these methods
for testing instead of the camelcased ones in the baseclass for
consistency.
"""
def ensure_clean_request_context(self): def ensure_clean_request_context(self):
# make sure we're not leaking a request context since we are # make sure we're not leaking a request context since we are
@ -136,20 +132,27 @@ class FlaskTestCase(unittest.TestCase):
class BetterLoader(unittest.TestLoader): class BetterLoader(unittest.TestLoader):
"""A nicer loader that solves two problems. First of all we are setting
up tests from different sources and we're doing this programmatically
which breaks the default loading logic so this is required anyways.
Secondly this loader has a nicer interpolation for test names than the
default one so you can just do ``run-tests.py ViewTestCase`` and it
will work.
"""
def getRootSuite(self):
return suite()
def loadTestsFromName(self, name, module=None): def loadTestsFromName(self, name, module=None):
root = self.getRootSuite()
if name == 'suite': if name == 'suite':
return suite() return root
for testcase, testname in find_all_tests_with_name():
if testname == name:
return testcase
if testname.startswith(common_prefix):
if testname[len(common_prefix):] == name:
return testcase
all_tests = [] all_tests = []
for testcase, testname in find_all_tests_with_name(): for testcase, testname in find_all_tests(root):
if testname.endswith('.' + name) or ('.' + name + '.') in testname or \ if testname == name or \
testname.endswith('.' + name) or \
('.' + name + '.') in testname or \
testname.startswith(name + '.'): testname.startswith(name + '.'):
all_tests.append(testcase) all_tests.append(testcase)

Loading…
Cancel
Save