diff --git a/flask/app.py b/flask/app.py index e784eea2..6fc95c2c 100644 --- a/flask/app.py +++ b/flask/app.py @@ -65,70 +65,6 @@ def setupmethod(f): return update_wrapper(wrapper_func, f) -class ExceptionHandlerDict(Mapping): - """A dict storing exception handlers or falling back to the default ones - - Designed to be app.error_handler_spec[blueprint_or_none] - And hold a Exception → handler function mapping. - Converts error codes to default HTTPException subclasses. - - Returns None if no handler is defined for blueprint or app - """ - def __init__(self, app, blueprint): - super(ExceptionHandlerDict, self).__init__() - self.app = app - self.data = {} - if blueprint: # fall back to app mapping - self.fallback = app.error_handler_spec[None] - else: - self.fallback = {} - - @staticmethod - def get_class(exc_class_or_code): - if isinstance(exc_class_or_code, integer_types): - # ensure that we register only exceptions as keys - exc_class = default_exceptions[exc_class_or_code] - else: - assert issubclass(exc_class_or_code, Exception) - exc_class = exc_class_or_code - return exc_class - - def __contains__(self, e_or_c): - clazz = self.get_class(e_or_c) - return clazz in self.data or clazz in self.fallback - - def __getitem__(self, e_or_c): - clazz = self.get_class(e_or_c) - item = self.data.get(clazz) - if item is not None: - return item - elif len(self.fallback): - return self.fallback[clazz] - else: - raise KeyError(e_or_c) - - def __setitem__(self, e_or_c, handler): - assert callable(handler) - self.data[self.get_class(e_or_c)] = handler - - def __iter__(self): - return iterkeys(self.data) - - def __len__(self): - return len(self.data) - - def find_handler(self, ex_instance): - assert isinstance(ex_instance, Exception) - - for superclass in type(ex_instance).mro(): - if superclass is BaseException: - return None - handler = self.get(superclass) - if handler is not None: - return handler - return None - - class Flask(_PackageBoundObject): """The flask object implements a WSGI application and acts as the central object. It is passed the name of the module or package of the @@ -429,7 +365,7 @@ class Flask(_PackageBoundObject): # support for the now deprecated `error_handlers` attribute. The # :attr:`error_handler_spec` shall be used now. - self._error_handlers = ExceptionHandlerDict(self, None) + self._error_handlers = {} #: A dictionary of all registered error handlers. The key is ``None`` #: for error handlers active on the application, otherwise the key is @@ -1142,6 +1078,23 @@ class Flask(_PackageBoundObject): return f return decorator + @staticmethod + def _ensure_exc_class(exc_class_or_code): + """ensure that we register only exceptions as handler keys""" + if isinstance(exc_class_or_code, integer_types): + exc_class = default_exceptions[exc_class_or_code] + elif isinstance(exc_class_or_code, type): + exc_class = exc_class_or_code + else: + exc_class = type(exc_class_or_code) + + assert issubclass(exc_class, Exception) + + if issubclass(exc_class, HTTPException): + return exc_class, exc_class.code + else: + return exc_class, None + @setupmethod def errorhandler(self, code_or_exception): """A decorator that is used to register a function give a given @@ -1211,9 +1164,10 @@ class Flask(_PackageBoundObject): 'Handlers can only be registered for exception classes or HTTP error codes.' .format(code_or_exception)) - handlers = self.error_handler_spec.setdefault(key, ExceptionHandlerDict(self, key)) + exc_class, code = self._ensure_exc_class(code_or_exception) - handlers[code_or_exception] = f + handlers = self.error_handler_spec.setdefault(key, {}).setdefault(code, {}) + handlers[exc_class] = f @setupmethod def template_filter(self, name=None): @@ -1456,10 +1410,29 @@ class Flask(_PackageBoundObject): def _find_error_handler(self, e): """Finds a registered error handler for the request’s blueprint. - If nether blueprint nor App has a suitable handler registered, returns None + If neither blueprint nor App has a suitable handler registered, returns None """ - handlers = self.error_handler_spec.get(request.blueprint, self.error_handler_spec[None]) - return handlers.find_handler(e) + exc_class, code = self._ensure_exc_class(e) + + def find_superclass(d): + if not d: + return None + for superclass in exc_class.mro(): + if superclass is BaseException: + return None + handler = d.get(superclass) + if handler is not None: + return handler + return None + + # try blueprint handlers + handler = find_superclass(self.error_handler_spec.get(request.blueprint, {}).get(code)) + + if handler is not None: + return handler + + # fall back to app handlers + return find_superclass(self.error_handler_spec[None].get(code)) def handle_http_exception(self, e): """Handles an HTTP exception. By default this will invoke the diff --git a/tests/test_user_error_handler.py b/tests/test_user_error_handler.py new file mode 100644 index 00000000..78f4de3c --- /dev/null +++ b/tests/test_user_error_handler.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +from werkzeug.exceptions import Forbidden, InternalServerError +import flask + + +def test_error_handler_subclass(): + app = flask.Flask(__name__) + + class ParentException(Exception): + pass + + class ChildExceptionUnregistered(ParentException): + pass + + class ChildExceptionRegistered(ParentException): + pass + + @app.errorhandler(ParentException) + def parent_exception_handler(e): + assert isinstance(e, ParentException) + return 'parent' + + @app.errorhandler(ChildExceptionRegistered) + def child_exception_handler(e): + assert isinstance(e, ChildExceptionRegistered) + return 'child-registered' + + @app.route('/parent') + def parent_test(): + raise ParentException() + + @app.route('/child-unregistered') + def unregistered_test(): + raise ChildExceptionUnregistered() + + @app.route('/child-registered') + def registered_test(): + raise ChildExceptionRegistered() + + + c = app.test_client() + + assert c.get('/parent').data == b'parent' + assert c.get('/child-unregistered').data == b'parent' + assert c.get('/child-registered').data == b'child-registered' + + +def test_error_handler_http_subclass(): + app = flask.Flask(__name__) + + class ForbiddenSubclassRegistered(Forbidden): + pass + + class ForbiddenSubclassUnregistered(Forbidden): + pass + + @app.errorhandler(403) + def code_exception_handler(e): + assert isinstance(e, Forbidden) + return 'forbidden' + + @app.errorhandler(ForbiddenSubclassRegistered) + def subclass_exception_handler(e): + assert isinstance(e, ForbiddenSubclassRegistered) + return 'forbidden-registered' + + @app.route('/forbidden') + def forbidden_test(): + raise Forbidden() + + @app.route('/forbidden-registered') + def registered_test(): + raise ForbiddenSubclassRegistered() + + @app.route('/forbidden-unregistered') + def unregistered_test(): + raise ForbiddenSubclassUnregistered() + + + c = app.test_client() + + assert c.get('/forbidden').data == b'forbidden' + assert c.get('/forbidden-unregistered').data == b'forbidden' + assert c.get('/forbidden-registered').data == b'forbidden-registered' + + +def test_error_handler_blueprint(): + bp = flask.Blueprint('bp', __name__) + + @bp.errorhandler(500) + def bp_exception_handler(e): + return 'bp-error' + + @bp.route('/error') + def bp_test(): + raise InternalServerError() + + app = flask.Flask(__name__) + + @app.errorhandler(500) + def app_exception_handler(e): + return 'app-error' + + @app.route('/error') + def app_test(): + raise InternalServerError() + + app.register_blueprint(bp, url_prefix='/bp') + + c = app.test_client() + + assert c.get('/error').data == b'app-error' + assert c.get('/bp/error').data == b'bp-error' \ No newline at end of file