diff --git a/flask/app.py b/flask/app.py index 6baee3a6..1f3d39ce 100644 --- a/flask/app.py +++ b/flask/app.py @@ -14,7 +14,7 @@ from threading import Lock from datetime import timedelta from itertools import chain from functools import update_wrapper -from collections import Mapping +from collections import Mapping, deque from werkzeug.datastructures import ImmutableDict from werkzeug.routing import Map, Rule, RequestRedirect, BuildError @@ -1411,27 +1411,37 @@ class Flask(_PackageBoundObject): If neither blueprint nor App has a suitable handler registered, returns None """ exc_class, code = self._get_exc_class_and_code(type(e)) - - def find_superclass(handler_map): + + def find_handler(handler_map): if not handler_map: - return None - for superclass in exc_class.__mro__: - if superclass is BaseException: - return None - handler = handler_map.get(superclass) + return + queue = deque(exc_class.__mro__) + # Protect from geniuses who might create circular references in + # __mro__ + done = set() + + while True: + cls = queue.popleft() + if cls in done: + continue + done.add(cls) + handler = handler_map.get(cls) if handler is not None: - handler_map[exc_class] = handler # cache for next time exc_class is raised + # cache for next time exc_class is raised + handler_map[exc_class] = handler return handler - return None - + + queue.extend(cls.__mro__) + # try blueprint handlers - handler = find_superclass(self.error_handler_spec.get(request.blueprint, {}).get(code)) - + handler = find_handler(self.error_handler_spec + .get(request.blueprint, {}) + .get(code)) if handler is not None: return handler - + # fall back to app handlers - return find_superclass(self.error_handler_spec[None].get(code)) + return find_handler(self.error_handler_spec[None].get(code)) def handle_http_exception(self, e): """Handles an HTTP exception. By default this will invoke the diff --git a/tests/test_user_error_handler.py b/tests/test_user_error_handler.py index 78f4de3c..33131f3e 100644 --- a/tests/test_user_error_handler.py +++ b/tests/test_user_error_handler.py @@ -4,110 +4,108 @@ import flask def test_error_handler_subclass(): - app = flask.Flask(__name__) + app = flask.Flask(__name__) - class ParentException(Exception): - pass + class ParentException(Exception): + pass - class ChildExceptionUnregistered(ParentException): - pass + class ChildExceptionUnregistered(ParentException): + pass - class ChildExceptionRegistered(ParentException): - pass + class ChildExceptionRegistered(ParentException): + pass - @app.errorhandler(ParentException) - def parent_exception_handler(e): - assert isinstance(e, ParentException) - return 'parent' + @app.errorhandler(ParentException) + def parent_exception_handler(e): + assert isinstance(e, ParentException) + return 'parent' - @app.errorhandler(ChildExceptionRegistered) - def child_exception_handler(e): - assert isinstance(e, ChildExceptionRegistered) - return 'child-registered' + @app.errorhandler(ChildExceptionRegistered) + def child_exception_handler(e): + assert isinstance(e, ChildExceptionRegistered) + return 'child-registered' - @app.route('/parent') - def parent_test(): - raise ParentException() + @app.route('/parent') + def parent_test(): + raise ParentException() - @app.route('/child-unregistered') - def unregistered_test(): - raise ChildExceptionUnregistered() + @app.route('/child-unregistered') + def unregistered_test(): + raise ChildExceptionUnregistered() - @app.route('/child-registered') - def registered_test(): - raise ChildExceptionRegistered() + @app.route('/child-registered') + def registered_test(): + raise ChildExceptionRegistered() + c = app.test_client() - c = app.test_client() - - assert c.get('/parent').data == b'parent' - assert c.get('/child-unregistered').data == b'parent' - assert c.get('/child-registered').data == b'child-registered' + assert c.get('/parent').data == b'parent' + assert c.get('/child-unregistered').data == b'parent' + assert c.get('/child-registered').data == b'child-registered' def test_error_handler_http_subclass(): - app = flask.Flask(__name__) - - class ForbiddenSubclassRegistered(Forbidden): - pass + app = flask.Flask(__name__) - class ForbiddenSubclassUnregistered(Forbidden): - pass + class ForbiddenSubclassRegistered(Forbidden): + pass - @app.errorhandler(403) - def code_exception_handler(e): - assert isinstance(e, Forbidden) - return 'forbidden' + class ForbiddenSubclassUnregistered(Forbidden): + pass - @app.errorhandler(ForbiddenSubclassRegistered) - def subclass_exception_handler(e): - assert isinstance(e, ForbiddenSubclassRegistered) - return 'forbidden-registered' + @app.errorhandler(403) + def code_exception_handler(e): + assert isinstance(e, Forbidden) + return 'forbidden' - @app.route('/forbidden') - def forbidden_test(): - raise Forbidden() + @app.errorhandler(ForbiddenSubclassRegistered) + def subclass_exception_handler(e): + assert isinstance(e, ForbiddenSubclassRegistered) + return 'forbidden-registered' - @app.route('/forbidden-registered') - def registered_test(): - raise ForbiddenSubclassRegistered() + @app.route('/forbidden') + def forbidden_test(): + raise Forbidden() - @app.route('/forbidden-unregistered') - def unregistered_test(): - raise ForbiddenSubclassUnregistered() + @app.route('/forbidden-registered') + def registered_test(): + raise ForbiddenSubclassRegistered() + @app.route('/forbidden-unregistered') + def unregistered_test(): + raise ForbiddenSubclassUnregistered() - c = app.test_client() + c = app.test_client() - assert c.get('/forbidden').data == b'forbidden' - assert c.get('/forbidden-unregistered').data == b'forbidden' - assert c.get('/forbidden-registered').data == b'forbidden-registered' + assert c.get('/forbidden').data == b'forbidden' + assert c.get('/forbidden-unregistered').data == b'forbidden' + assert c.get('/forbidden-registered').data == b'forbidden-registered' def test_error_handler_blueprint(): - bp = flask.Blueprint('bp', __name__) - - @bp.errorhandler(500) - def bp_exception_handler(e): - return 'bp-error' - - @bp.route('/error') - def bp_test(): - raise InternalServerError() - - app = flask.Flask(__name__) - - @app.errorhandler(500) - def app_exception_handler(e): - return 'app-error' - - @app.route('/error') - def app_test(): - raise InternalServerError() - - app.register_blueprint(bp, url_prefix='/bp') - - c = app.test_client() - - assert c.get('/error').data == b'app-error' - assert c.get('/bp/error').data == b'bp-error' \ No newline at end of file + bp = flask.Blueprint('bp', __name__) + + @bp.errorhandler(500) + def bp_exception_handler(e): + return 'bp-error' + + @bp.route('/error') + def bp_test(): + raise InternalServerError() + + app = flask.Flask(__name__) + + @app.errorhandler(500) + def app_exception_handler(e): + return 'app-error' + + @app.route('/error') + def app_test(): + raise InternalServerError() + + app.register_blueprint(bp, url_prefix='/bp') + + c = app.test_client() + + assert c.get('/error').data == b'app-error' + assert c.get('/bp/error').data == b'bp-error'