diff --git a/flask/testsuite/__init__.py b/flask/testsuite/__init__.py index 8df7a7fd..ff0224bc 100644 --- a/flask/testsuite/__init__.py +++ b/flask/testsuite/__init__.py @@ -20,6 +20,9 @@ from contextlib import contextmanager from werkzeug.utils import import_string, find_modules +common_prefix = __name__ + '.' + + def add_to_path(path): def _samefile(x, y): try: @@ -47,6 +50,25 @@ def iter_suites(): yield mod.suite() +def find_all_tests(): + suites = [suite()] + while suites: + s = suites.pop() + try: + suites.extend(s) + except TypeError: + yield s + + +def find_all_tests_with_name(): + for testcase in find_all_tests(): + yield testcase, '%s.%s.%s' % ( + testcase.__class__.__module__, + testcase.__class__.__name__, + testcase._testMethodName + ) + + @contextmanager def catch_warnings(): """Catch warnings in a with block in a list""" @@ -113,6 +135,36 @@ class FlaskTestCase(unittest.TestCase): return self.assertEqual(x, y) +class BetterLoader(unittest.TestLoader): + + def loadTestsFromName(self, name, module=None): + if name == 'suite': + return suite() + for testcase, testname in find_all_tests_with_name(): + if testname == name: + return testcase + if testname.startswith(common_prefix): + if testname[len(common_prefix):] == name: + return testcase + + all_tests = [] + for testcase, testname in find_all_tests_with_name(): + if testname.endswith('.' + name) or ('.' + name + '.') in testname or \ + testname.startswith(name + '.'): + all_tests.append(testcase) + + if not all_tests: + print >> sys.stderr, 'Error: could not find test case for "%s"' % name + sys.exit(1) + + if len(all_tests) == 1: + return all_tests[0] + rv = unittest.TestSuite() + for test in all_tests: + rv.addTest(test) + return rv + + def suite(): setup_paths() suite = unittest.TestSuite() diff --git a/run-tests.py b/run-tests.py index b74e7f71..7d44febc 100644 --- a/run-tests.py +++ b/run-tests.py @@ -1,58 +1,3 @@ -import sys import unittest -from unittest.loader import TestLoader -from flask.testsuite import suite - -common_prefix = suite.__module__ + '.' - - -def find_all_tests(): - suites = [suite()] - while suites: - s = suites.pop() - try: - suites.extend(s) - except TypeError: - yield s - - -def find_all_tests_with_name(): - for testcase in find_all_tests(): - yield testcase, '%s.%s.%s' % ( - testcase.__class__.__module__, - testcase.__class__.__name__, - testcase._testMethodName - ) - - -class BetterLoader(TestLoader): - - def loadTestsFromName(self, name, module=None): - if name == 'suite': - return suite() - for testcase, testname in find_all_tests_with_name(): - if testname == name: - return testcase - if testname.startswith(common_prefix): - if testname[len(common_prefix):] == name: - return testcase - - all_tests = [] - for testcase, testname in find_all_tests_with_name(): - if testname.endswith('.' + name) or ('.' + name + '.') in testname or \ - testname.startswith(name + '.'): - all_tests.append(testcase) - - if not all_tests: - print >> sys.stderr, 'Error: could not find test case for "%s"' % name - sys.exit(1) - - if len(all_tests) == 1: - return all_tests[0] - rv = unittest.TestSuite() - for test in all_tests: - rv.addTest(test) - return rv - - +from flask.testsuite import BetterLoader unittest.main(testLoader=BetterLoader(), defaultTest='suite')