diff --git a/flask/ext/__init__.py b/flask/ext/__init__.py index 86ad3da5..1eda6df5 100644 --- a/flask/ext/__init__.py +++ b/flask/ext/__init__.py @@ -57,6 +57,17 @@ class _ExtensionImporter(object): __import__(realname) except ImportError: exc_type, exc_value, tb = exc_info() + # since we only establish the entry in sys.modules at the + # very this seems to be redundant, but if recursive imports + # happen we will call into the move import a second time. + # On the second invocation we still don't have an entry for + # fullname in sys.modules, but we will end up with the same + # fake module name and that import will succeed since this + # one already has a temporary entry in the modules dict. + # Since this one "succeeded" temporarily that second + # invocation now will have created a fullname entry in + # sys.modules which we have to kill. + modules.pop(fullname, None) if self.is_important_traceback(realname, tb): raise exc_type, exc_value, tb continue diff --git a/flask/testsuite/__init__.py b/flask/testsuite/__init__.py index 49b85b23..9f9e0a60 100644 --- a/flask/testsuite/__init__.py +++ b/flask/testsuite/__init__.py @@ -134,6 +134,32 @@ class FlaskTestCase(unittest.TestCase): def assert_equal(self, x, y): return self.assertEqual(x, y) + def assert_raises(self, exc_type, callable=None, *args, **kwargs): + catcher = _ExceptionCatcher(self, exc_type) + if callable is None: + return catcher + with catcher: + callable(*args, **kwargs) + + +class _ExceptionCatcher(object): + + def __init__(self, test_case, exc_type): + self.test_case = test_case + self.exc_type = exc_type + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + exception_name = self.exc_type.__name__ + if exc_type is None: + self.test_case.fail('Expected exception of type %r' % + exception_name) + elif not issubclass(exc_type, self.exc_type): + raise exc_type, exc_value, tb + return True + class BetterLoader(unittest.TestLoader): """A nicer loader that solves two problems. First of all we are setting diff --git a/flask/testsuite/ext.py b/flask/testsuite/ext.py index dc90952f..966db439 100644 --- a/flask/testsuite/ext.py +++ b/flask/testsuite/ext.py @@ -8,6 +8,8 @@ :copyright: (c) 2011 by Armin Ronacher. :license: BSD, see LICENSE for more details. """ +from __future__ import with_statement + import sys import unittest from flask.testsuite import FlaskTestCase @@ -92,6 +94,11 @@ class ExtImportHookTestCase(FlaskTestCase): from flask.ext.oldext_package.submodule import test_function self.assert_equal(test_function(), 42) + def test_flaskext_broken_package_no_module_caching(self): + for x in xrange(2): + with self.assert_raises(ImportError): + import flask.ext.broken + def suite(): suite = unittest.TestSuite() diff --git a/flask/testsuite/test_apps/flask_broken/__init__.py b/flask/testsuite/test_apps/flask_broken/__init__.py new file mode 100644 index 00000000..c194c04f --- /dev/null +++ b/flask/testsuite/test_apps/flask_broken/__init__.py @@ -0,0 +1,2 @@ +import flask.ext.broken.b +import missing_module diff --git a/flask/testsuite/test_apps/flask_broken/b.py b/flask/testsuite/test_apps/flask_broken/b.py new file mode 100644 index 00000000..e69de29b diff --git a/scripts/flaskext_compat.py b/scripts/flaskext_compat.py index f0b1739d..bb3ada03 100644 --- a/scripts/flaskext_compat.py +++ b/scripts/flaskext_compat.py @@ -25,7 +25,8 @@ ext_module.__package__ = ext_module.__name__ class _ExtensionImporter(object): """This importer redirects imports from the flask.ext module to other - locations. + locations. For implementation details see the code in Flask 0.8 + that does the same. """ _module_choices = ['flask_%s', 'flaskext.%s'] prefix = ext_module.__name__ + '.' @@ -45,6 +46,7 @@ class _ExtensionImporter(object): __import__(realname) except ImportError: exc_type, exc_value, tb = sys.exc_info() + sys.modules.pop(fullname, None) if self.is_important_traceback(realname, tb): raise exc_type, exc_value, tb continue @@ -55,12 +57,6 @@ class _ExtensionImporter(object): raise ImportError('No module named %s' % fullname) def is_important_traceback(self, important_module, tb): - """Walks a traceback's frames and checks if any of the frames - originated in the given important module. If that is the case - then we were able to import the module itself but apparently - something went wrong when the module was imported. (Eg: import - of an import failed). - """ while tb is not None: if tb.tb_frame.f_globals.get('__name__') == important_module: return True