Browse Source

Merge pull request #1291 from flying-sheep/errorhandler-rework

Fixed and intuitivized exception handling
pull/1432/merge
Markus Unterwaditzer 10 years ago
parent
commit
aed464b92b
  1. 95
      flask/app.py
  2. 113
      tests/test_user_error_handler.py

95
flask/app.py

@ -8,18 +8,18 @@
:copyright: (c) 2015 by Armin Ronacher. :copyright: (c) 2015 by Armin Ronacher.
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
import os import os
import sys import sys
from threading import Lock from threading import Lock
from datetime import timedelta from datetime import timedelta
from itertools import chain from itertools import chain
from functools import update_wrapper from functools import update_wrapper
from collections import Mapping
from werkzeug.datastructures import ImmutableDict from werkzeug.datastructures import ImmutableDict
from werkzeug.routing import Map, Rule, RequestRedirect, BuildError from werkzeug.routing import Map, Rule, RequestRedirect, BuildError
from werkzeug.exceptions import HTTPException, InternalServerError, \ from werkzeug.exceptions import HTTPException, InternalServerError, \
MethodNotAllowed, BadRequest MethodNotAllowed, BadRequest, default_exceptions
from .helpers import _PackageBoundObject, url_for, get_flashed_messages, \ from .helpers import _PackageBoundObject, url_for, get_flashed_messages, \
locked_cached_property, _endpoint_from_view_func, find_package locked_cached_property, _endpoint_from_view_func, find_package
@ -33,7 +33,7 @@ from .templating import DispatchingJinjaLoader, Environment, \
_default_template_ctx_processor _default_template_ctx_processor
from .signals import request_started, request_finished, got_request_exception, \ from .signals import request_started, request_finished, got_request_exception, \
request_tearing_down, appcontext_tearing_down request_tearing_down, appcontext_tearing_down
from ._compat import reraise, string_types, text_type, integer_types from ._compat import reraise, string_types, text_type, integer_types, iterkeys
# a lock used for logger initialization # a lock used for logger initialization
_logger_lock = Lock() _logger_lock = Lock()
@ -1078,6 +1078,21 @@ class Flask(_PackageBoundObject):
return f return f
return decorator return decorator
@staticmethod
def _get_exc_class_and_code(exc_class_or_code):
"""Ensure that we register only exceptions as handler keys"""
if isinstance(exc_class_or_code, integer_types):
exc_class = default_exceptions[exc_class_or_code]
else:
exc_class = exc_class_or_code
assert issubclass(exc_class, Exception)
if issubclass(exc_class, HTTPException):
return exc_class, exc_class.code
else:
return exc_class, None
@setupmethod @setupmethod
def errorhandler(self, code_or_exception): def errorhandler(self, code_or_exception):
"""A decorator that is used to register a function give a given """A decorator that is used to register a function give a given
@ -1136,16 +1151,21 @@ class Flask(_PackageBoundObject):
@setupmethod @setupmethod
def _register_error_handler(self, key, code_or_exception, f): def _register_error_handler(self, key, code_or_exception, f):
if isinstance(code_or_exception, HTTPException): """
code_or_exception = code_or_exception.code :type key: None|str
if isinstance(code_or_exception, integer_types): :type code_or_exception: int|T<=Exception
assert code_or_exception != 500 or key is None, \ :type f: callable
'It is currently not possible to register a 500 internal ' \ """
'server error on a per-blueprint level.' if isinstance(code_or_exception, HTTPException): # old broken behavior
self.error_handler_spec.setdefault(key, {})[code_or_exception] = f raise ValueError(
else: 'Tried to register a handler for an exception instance {0!r}. '
self.error_handler_spec.setdefault(key, {}).setdefault(None, []) \ 'Handlers can only be registered for exception classes or HTTP error codes.'
.append((code_or_exception, f)) .format(code_or_exception))
exc_class, code = self._get_exc_class_and_code(code_or_exception)
handlers = self.error_handler_spec.setdefault(key, {}).setdefault(code, {})
handlers[exc_class] = f
@setupmethod @setupmethod
def template_filter(self, name=None): def template_filter(self, name=None):
@ -1386,6 +1406,33 @@ class Flask(_PackageBoundObject):
self.url_default_functions.setdefault(None, []).append(f) self.url_default_functions.setdefault(None, []).append(f)
return f return f
def _find_error_handler(self, e):
"""Finds a registered error handler for the request’s blueprint.
If neither blueprint nor App has a suitable handler registered, returns None
"""
exc_class, code = self._get_exc_class_and_code(type(e))
def find_superclass(handler_map):
if not handler_map:
return None
for superclass in exc_class.__mro__:
if superclass is BaseException:
return None
handler = handler_map.get(superclass)
if handler is not None:
handler_map[exc_class] = handler # cache for next time exc_class is raised
return handler
return None
# try blueprint handlers
handler = find_superclass(self.error_handler_spec.get(request.blueprint, {}).get(code))
if handler is not None:
return handler
# fall back to app handlers
return find_superclass(self.error_handler_spec[None].get(code))
def handle_http_exception(self, e): def handle_http_exception(self, e):
"""Handles an HTTP exception. By default this will invoke the """Handles an HTTP exception. By default this will invoke the
registered error handlers and fall back to returning the registered error handlers and fall back to returning the
@ -1393,15 +1440,12 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
handlers = self.error_handler_spec.get(request.blueprint)
# Proxy exceptions don't have error codes. We want to always return # Proxy exceptions don't have error codes. We want to always return
# those unchanged as errors # those unchanged as errors
if e.code is None: if e.code is None:
return e return e
if handlers and e.code in handlers:
handler = handlers[e.code] handler = self._find_error_handler(e)
else:
handler = self.error_handler_spec[None].get(e.code)
if handler is None: if handler is None:
return e return e
return handler(e) return handler(e)
@ -1444,19 +1488,14 @@ class Flask(_PackageBoundObject):
# we cannot prevent users from trashing it themselves in a custom # we cannot prevent users from trashing it themselves in a custom
# trap_http_exception method so that's their fault then. # trap_http_exception method so that's their fault then.
blueprint_handlers = ()
handlers = self.error_handler_spec.get(request.blueprint)
if handlers is not None:
blueprint_handlers = handlers.get(None, ())
app_handlers = self.error_handler_spec[None].get(None, ())
for typecheck, handler in chain(blueprint_handlers, app_handlers):
if isinstance(e, typecheck):
return handler(e)
if isinstance(e, HTTPException) and not self.trap_http_exception(e): if isinstance(e, HTTPException) and not self.trap_http_exception(e):
return self.handle_http_exception(e) return self.handle_http_exception(e)
handler = self._find_error_handler(e)
if handler is None:
reraise(exc_type, exc_value, tb) reraise(exc_type, exc_value, tb)
return handler(e)
def handle_exception(self, e): def handle_exception(self, e):
"""Default exception handling that kicks in when an exception """Default exception handling that kicks in when an exception
@ -1470,7 +1509,7 @@ class Flask(_PackageBoundObject):
exc_type, exc_value, tb = sys.exc_info() exc_type, exc_value, tb = sys.exc_info()
got_request_exception.send(self, exception=e) got_request_exception.send(self, exception=e)
handler = self.error_handler_spec[None].get(500) handler = self._find_error_handler(InternalServerError())
if self.propagate_exceptions: if self.propagate_exceptions:
# if we want to repropagate the exception, we can attempt to # if we want to repropagate the exception, we can attempt to

113
tests/test_user_error_handler.py

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
from werkzeug.exceptions import Forbidden, InternalServerError
import flask
def test_error_handler_subclass():
app = flask.Flask(__name__)
class ParentException(Exception):
pass
class ChildExceptionUnregistered(ParentException):
pass
class ChildExceptionRegistered(ParentException):
pass
@app.errorhandler(ParentException)
def parent_exception_handler(e):
assert isinstance(e, ParentException)
return 'parent'
@app.errorhandler(ChildExceptionRegistered)
def child_exception_handler(e):
assert isinstance(e, ChildExceptionRegistered)
return 'child-registered'
@app.route('/parent')
def parent_test():
raise ParentException()
@app.route('/child-unregistered')
def unregistered_test():
raise ChildExceptionUnregistered()
@app.route('/child-registered')
def registered_test():
raise ChildExceptionRegistered()
c = app.test_client()
assert c.get('/parent').data == b'parent'
assert c.get('/child-unregistered').data == b'parent'
assert c.get('/child-registered').data == b'child-registered'
def test_error_handler_http_subclass():
app = flask.Flask(__name__)
class ForbiddenSubclassRegistered(Forbidden):
pass
class ForbiddenSubclassUnregistered(Forbidden):
pass
@app.errorhandler(403)
def code_exception_handler(e):
assert isinstance(e, Forbidden)
return 'forbidden'
@app.errorhandler(ForbiddenSubclassRegistered)
def subclass_exception_handler(e):
assert isinstance(e, ForbiddenSubclassRegistered)
return 'forbidden-registered'
@app.route('/forbidden')
def forbidden_test():
raise Forbidden()
@app.route('/forbidden-registered')
def registered_test():
raise ForbiddenSubclassRegistered()
@app.route('/forbidden-unregistered')
def unregistered_test():
raise ForbiddenSubclassUnregistered()
c = app.test_client()
assert c.get('/forbidden').data == b'forbidden'
assert c.get('/forbidden-unregistered').data == b'forbidden'
assert c.get('/forbidden-registered').data == b'forbidden-registered'
def test_error_handler_blueprint():
bp = flask.Blueprint('bp', __name__)
@bp.errorhandler(500)
def bp_exception_handler(e):
return 'bp-error'
@bp.route('/error')
def bp_test():
raise InternalServerError()
app = flask.Flask(__name__)
@app.errorhandler(500)
def app_exception_handler(e):
return 'app-error'
@app.route('/error')
def app_test():
raise InternalServerError()
app.register_blueprint(bp, url_prefix='/bp')
c = app.test_client()
assert c.get('/error').data == b'app-error'
assert c.get('/bp/error').data == b'bp-error'
Loading…
Cancel
Save