From 3d67736e090f55fcdc476452ac89a79fdd14cc23 Mon Sep 17 00:00:00 2001 From: Daniel Richman Date: Sat, 17 Aug 2013 22:40:06 +0000 Subject: [PATCH] Check error handlers for specific classes first This allows adding error handlers like this: @app.errorhandler(werkzeug.exceptions.Forbidden) And subclassing HTTPExceptions: class ForbiddenBecauseReason(Forbidden): pass @app.errorhandler(ForbiddenBecauseReason) def error1(): return "Forbidden because reason", 403 @app.errorhandler(403) def error2(): return "Forbidden", 403 ... the idea being, that a flask extension might want to raise an exception, with the default behaviour of creating a HTTP error page, but still allowing the user to add a view/handler specific to that exception (e.g., "Forbidden because you are not in the right group"). --- CHANGES | 6 ++++++ flask/app.py | 5 +++-- flask/testsuite/basic.py | 41 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/CHANGES b/CHANGES index 8699718c..dcedf67c 100644 --- a/CHANGES +++ b/CHANGES @@ -14,6 +14,12 @@ Version 1.0 `False` it will only be modified if the session actually modifies. Non permanent sessions are not affected by this and will always expire if the browser window closes. +- Error handlers that match specific classes are now checked first, + thereby allowing catching exceptions that are subclasses of HTTP + exceptions (in ``werkzeug.execptions``). This makes it possible + for an extension author to create exceptions that will by default + result in the HTTP error of their choosing, but may be caught with + a custom error handler if desired. Version 0.10.2 -------------- diff --git a/flask/app.py b/flask/app.py index 805dc166..c97e8b3c 100644 --- a/flask/app.py +++ b/flask/app.py @@ -1365,8 +1365,6 @@ class Flask(_PackageBoundObject): # wants the traceback preserved in handle_http_exception. Of course # we cannot prevent users from trashing it themselves in a custom # trap_http_exception method so that's their fault then. - if isinstance(e, HTTPException) and not self.trap_http_exception(e): - return self.handle_http_exception(e) blueprint_handlers = () handlers = self.error_handler_spec.get(request.blueprint) @@ -1377,6 +1375,9 @@ class Flask(_PackageBoundObject): if isinstance(e, typecheck): return handler(e) + if isinstance(e, HTTPException) and not self.trap_http_exception(e): + return self.handle_http_exception(e) + reraise(exc_type, exc_value, tb) def handle_exception(self, e): diff --git a/flask/testsuite/basic.py b/flask/testsuite/basic.py index 71a1f832..8c13bda7 100644 --- a/flask/testsuite/basic.py +++ b/flask/testsuite/basic.py @@ -18,7 +18,7 @@ from datetime import datetime from threading import Thread from flask.testsuite import FlaskTestCase, emits_module_deprecation_warning from flask._compat import text_type -from werkzeug.exceptions import BadRequest, NotFound +from werkzeug.exceptions import BadRequest, NotFound, Forbidden from werkzeug.http import parse_date from werkzeug.routing import BuildError @@ -626,12 +626,18 @@ class BasicFunctionalityTestCase(FlaskTestCase): @app.errorhandler(500) def internal_server_error(e): return 'internal server error', 500 + @app.errorhandler(Forbidden) + def forbidden(e): + return 'forbidden', 403 @app.route('/') def index(): flask.abort(404) @app.route('/error') def error(): 1 // 0 + @app.route('/forbidden') + def error2(): + flask.abort(403) c = app.test_client() rv = c.get('/') self.assert_equal(rv.status_code, 404) @@ -639,6 +645,9 @@ class BasicFunctionalityTestCase(FlaskTestCase): rv = c.get('/error') self.assert_equal(rv.status_code, 500) self.assert_equal(b'internal server error', rv.data) + rv = c.get('/forbidden') + self.assert_equal(rv.status_code, 403) + self.assert_equal(b'forbidden', rv.data) def test_before_request_and_routing_errors(self): app = flask.Flask(__name__) @@ -668,6 +677,36 @@ class BasicFunctionalityTestCase(FlaskTestCase): c = app.test_client() self.assert_equal(c.get('/').data, b'42') + def test_http_error_subclass_handling(self): + class ForbiddenSubclass(Forbidden): + pass + + app = flask.Flask(__name__) + @app.errorhandler(ForbiddenSubclass) + def handle_forbidden_subclass(e): + self.assert_true(isinstance(e, ForbiddenSubclass)) + return 'banana' + @app.errorhandler(403) + def handle_forbidden_subclass(e): + self.assert_false(isinstance(e, ForbiddenSubclass)) + self.assert_true(isinstance(e, Forbidden)) + return 'apple' + + @app.route('/1') + def index1(): + raise ForbiddenSubclass() + @app.route('/2') + def index2(): + flask.abort(403) + @app.route('/3') + def index3(): + raise Forbidden() + + c = app.test_client() + self.assert_equal(c.get('/1').data, b'banana') + self.assert_equal(c.get('/2').data, b'apple') + self.assert_equal(c.get('/3').data, b'apple') + def test_trapping_of_bad_request_key_errors(self): app = flask.Flask(__name__) app.testing = True