From a082a5e0ba81af15653fe56501c8d2530d3621dc Mon Sep 17 00:00:00 2001 From: Armin Ronacher Date: Fri, 26 Aug 2011 12:07:49 +0100 Subject: [PATCH] Cleanup in the test finder --- flask/testsuite/__init__.py | 51 ++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/flask/testsuite/__init__.py b/flask/testsuite/__init__.py index ff0224bc..5ebc786e 100644 --- a/flask/testsuite/__init__.py +++ b/flask/testsuite/__init__.py @@ -20,9 +20,6 @@ from contextlib import contextmanager from werkzeug.utils import import_string, find_modules -common_prefix = __name__ + '.' - - def add_to_path(path): def _samefile(x, y): try: @@ -50,23 +47,18 @@ def iter_suites(): yield mod.suite() -def find_all_tests(): - suites = [suite()] +def find_all_tests(suite): + suites = [suite] while suites: s = suites.pop() try: suites.extend(s) except TypeError: - yield s - - -def find_all_tests_with_name(): - for testcase in find_all_tests(): - yield testcase, '%s.%s.%s' % ( - testcase.__class__.__module__, - testcase.__class__.__name__, - testcase._testMethodName - ) + yield s, '%s.%s.%s' % ( + s.__class__.__module__, + s.__class__.__name__, + s._testMethodName + ) @contextmanager @@ -111,6 +103,10 @@ def emits_module_deprecation_warning(f): class FlaskTestCase(unittest.TestCase): + """Baseclass for all the tests that Flask uses. Use these methods + for testing instead of the camelcased ones in the baseclass for + consistency. + """ def ensure_clean_request_context(self): # make sure we're not leaking a request context since we are @@ -136,20 +132,27 @@ class FlaskTestCase(unittest.TestCase): class BetterLoader(unittest.TestLoader): + """A nicer loader that solves two problems. First of all we are setting + up tests from different sources and we're doing this programmatically + which breaks the default loading logic so this is required anyways. + Secondly this loader has a nicer interpolation for test names than the + default one so you can just do ``run-tests.py ViewTestCase`` and it + will work. + """ + + def getRootSuite(self): + return suite() def loadTestsFromName(self, name, module=None): + root = self.getRootSuite() if name == 'suite': - return suite() - for testcase, testname in find_all_tests_with_name(): - if testname == name: - return testcase - if testname.startswith(common_prefix): - if testname[len(common_prefix):] == name: - return testcase + return root all_tests = [] - for testcase, testname in find_all_tests_with_name(): - if testname.endswith('.' + name) or ('.' + name + '.') in testname or \ + for testcase, testname in find_all_tests(root): + if testname == name or \ + testname.endswith('.' + name) or \ + ('.' + name + '.') in testname or \ testname.startswith(name + '.'): all_tests.append(testcase)