diff --git a/flask/app.py b/flask/app.py index 1f7df2e8..6baee3a6 100644 --- a/flask/app.py +++ b/flask/app.py @@ -8,18 +8,18 @@ :copyright: (c) 2015 by Armin Ronacher. :license: BSD, see LICENSE for more details. """ - import os import sys from threading import Lock from datetime import timedelta from itertools import chain from functools import update_wrapper +from collections import Mapping from werkzeug.datastructures import ImmutableDict from werkzeug.routing import Map, Rule, RequestRedirect, BuildError from werkzeug.exceptions import HTTPException, InternalServerError, \ - MethodNotAllowed, BadRequest + MethodNotAllowed, BadRequest, default_exceptions from .helpers import _PackageBoundObject, url_for, get_flashed_messages, \ locked_cached_property, _endpoint_from_view_func, find_package @@ -33,7 +33,7 @@ from .templating import DispatchingJinjaLoader, Environment, \ _default_template_ctx_processor from .signals import request_started, request_finished, got_request_exception, \ request_tearing_down, appcontext_tearing_down -from ._compat import reraise, string_types, text_type, integer_types +from ._compat import reraise, string_types, text_type, integer_types, iterkeys # a lock used for logger initialization _logger_lock = Lock() @@ -1078,6 +1078,21 @@ class Flask(_PackageBoundObject): return f return decorator + @staticmethod + def _get_exc_class_and_code(exc_class_or_code): + """Ensure that we register only exceptions as handler keys""" + if isinstance(exc_class_or_code, integer_types): + exc_class = default_exceptions[exc_class_or_code] + else: + exc_class = exc_class_or_code + + assert issubclass(exc_class, Exception) + + if issubclass(exc_class, HTTPException): + return exc_class, exc_class.code + else: + return exc_class, None + @setupmethod def errorhandler(self, code_or_exception): """A decorator that is used to register a function give a given @@ -1136,16 +1151,21 @@ class Flask(_PackageBoundObject): @setupmethod def _register_error_handler(self, key, code_or_exception, f): - if isinstance(code_or_exception, HTTPException): - code_or_exception = code_or_exception.code - if isinstance(code_or_exception, integer_types): - assert code_or_exception != 500 or key is None, \ - 'It is currently not possible to register a 500 internal ' \ - 'server error on a per-blueprint level.' - self.error_handler_spec.setdefault(key, {})[code_or_exception] = f - else: - self.error_handler_spec.setdefault(key, {}).setdefault(None, []) \ - .append((code_or_exception, f)) + """ + :type key: None|str + :type code_or_exception: int|T<=Exception + :type f: callable + """ + if isinstance(code_or_exception, HTTPException): # old broken behavior + raise ValueError( + 'Tried to register a handler for an exception instance {0!r}. ' + 'Handlers can only be registered for exception classes or HTTP error codes.' + .format(code_or_exception)) + + exc_class, code = self._get_exc_class_and_code(code_or_exception) + + handlers = self.error_handler_spec.setdefault(key, {}).setdefault(code, {}) + handlers[exc_class] = f @setupmethod def template_filter(self, name=None): @@ -1386,6 +1406,33 @@ class Flask(_PackageBoundObject): self.url_default_functions.setdefault(None, []).append(f) return f + def _find_error_handler(self, e): + """Finds a registered error handler for the request’s blueprint. + If neither blueprint nor App has a suitable handler registered, returns None + """ + exc_class, code = self._get_exc_class_and_code(type(e)) + + def find_superclass(handler_map): + if not handler_map: + return None + for superclass in exc_class.__mro__: + if superclass is BaseException: + return None + handler = handler_map.get(superclass) + if handler is not None: + handler_map[exc_class] = handler # cache for next time exc_class is raised + return handler + return None + + # try blueprint handlers + handler = find_superclass(self.error_handler_spec.get(request.blueprint, {}).get(code)) + + if handler is not None: + return handler + + # fall back to app handlers + return find_superclass(self.error_handler_spec[None].get(code)) + def handle_http_exception(self, e): """Handles an HTTP exception. By default this will invoke the registered error handlers and fall back to returning the @@ -1393,15 +1440,12 @@ class Flask(_PackageBoundObject): .. versionadded:: 0.3 """ - handlers = self.error_handler_spec.get(request.blueprint) # Proxy exceptions don't have error codes. We want to always return # those unchanged as errors if e.code is None: return e - if handlers and e.code in handlers: - handler = handlers[e.code] - else: - handler = self.error_handler_spec[None].get(e.code) + + handler = self._find_error_handler(e) if handler is None: return e return handler(e) @@ -1443,20 +1487,15 @@ class Flask(_PackageBoundObject): # wants the traceback preserved in handle_http_exception. Of course # we cannot prevent users from trashing it themselves in a custom # trap_http_exception method so that's their fault then. - - blueprint_handlers = () - handlers = self.error_handler_spec.get(request.blueprint) - if handlers is not None: - blueprint_handlers = handlers.get(None, ()) - app_handlers = self.error_handler_spec[None].get(None, ()) - for typecheck, handler in chain(blueprint_handlers, app_handlers): - if isinstance(e, typecheck): - return handler(e) - + if isinstance(e, HTTPException) and not self.trap_http_exception(e): return self.handle_http_exception(e) - reraise(exc_type, exc_value, tb) + handler = self._find_error_handler(e) + + if handler is None: + reraise(exc_type, exc_value, tb) + return handler(e) def handle_exception(self, e): """Default exception handling that kicks in when an exception @@ -1470,7 +1509,7 @@ class Flask(_PackageBoundObject): exc_type, exc_value, tb = sys.exc_info() got_request_exception.send(self, exception=e) - handler = self.error_handler_spec[None].get(500) + handler = self._find_error_handler(InternalServerError()) if self.propagate_exceptions: # if we want to repropagate the exception, we can attempt to diff --git a/tests/test_user_error_handler.py b/tests/test_user_error_handler.py new file mode 100644 index 00000000..78f4de3c --- /dev/null +++ b/tests/test_user_error_handler.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +from werkzeug.exceptions import Forbidden, InternalServerError +import flask + + +def test_error_handler_subclass(): + app = flask.Flask(__name__) + + class ParentException(Exception): + pass + + class ChildExceptionUnregistered(ParentException): + pass + + class ChildExceptionRegistered(ParentException): + pass + + @app.errorhandler(ParentException) + def parent_exception_handler(e): + assert isinstance(e, ParentException) + return 'parent' + + @app.errorhandler(ChildExceptionRegistered) + def child_exception_handler(e): + assert isinstance(e, ChildExceptionRegistered) + return 'child-registered' + + @app.route('/parent') + def parent_test(): + raise ParentException() + + @app.route('/child-unregistered') + def unregistered_test(): + raise ChildExceptionUnregistered() + + @app.route('/child-registered') + def registered_test(): + raise ChildExceptionRegistered() + + + c = app.test_client() + + assert c.get('/parent').data == b'parent' + assert c.get('/child-unregistered').data == b'parent' + assert c.get('/child-registered').data == b'child-registered' + + +def test_error_handler_http_subclass(): + app = flask.Flask(__name__) + + class ForbiddenSubclassRegistered(Forbidden): + pass + + class ForbiddenSubclassUnregistered(Forbidden): + pass + + @app.errorhandler(403) + def code_exception_handler(e): + assert isinstance(e, Forbidden) + return 'forbidden' + + @app.errorhandler(ForbiddenSubclassRegistered) + def subclass_exception_handler(e): + assert isinstance(e, ForbiddenSubclassRegistered) + return 'forbidden-registered' + + @app.route('/forbidden') + def forbidden_test(): + raise Forbidden() + + @app.route('/forbidden-registered') + def registered_test(): + raise ForbiddenSubclassRegistered() + + @app.route('/forbidden-unregistered') + def unregistered_test(): + raise ForbiddenSubclassUnregistered() + + + c = app.test_client() + + assert c.get('/forbidden').data == b'forbidden' + assert c.get('/forbidden-unregistered').data == b'forbidden' + assert c.get('/forbidden-registered').data == b'forbidden-registered' + + +def test_error_handler_blueprint(): + bp = flask.Blueprint('bp', __name__) + + @bp.errorhandler(500) + def bp_exception_handler(e): + return 'bp-error' + + @bp.route('/error') + def bp_test(): + raise InternalServerError() + + app = flask.Flask(__name__) + + @app.errorhandler(500) + def app_exception_handler(e): + return 'app-error' + + @app.route('/error') + def app_test(): + raise InternalServerError() + + app.register_blueprint(bp, url_prefix='/bp') + + c = app.test_client() + + assert c.get('/error').data == b'app-error' + assert c.get('/bp/error').data == b'bp-error' \ No newline at end of file