diff --git a/flask/__init__.py b/flask/__init__.py index 6e7883fb..52eec667 100644 --- a/flask/__init__.py +++ b/flask/__init__.py @@ -34,7 +34,7 @@ from .templating import render_template, render_template_string # the signals from .signals import signals_available, template_rendered, request_started, \ request_finished, got_request_exception, request_tearing_down, \ - appcontext_tearing_down + appcontext_tearing_down, message_flashed # We're not exposing the actual json module but a convenient wrapper around # it. diff --git a/flask/helpers.py b/flask/helpers.py index fe651004..d24dde6b 100644 --- a/flask/helpers.py +++ b/flask/helpers.py @@ -35,6 +35,7 @@ except ImportError: from jinja2 import FileSystemLoader +from .signals import message_flashed from .globals import session, _request_ctx_stack, _app_ctx_stack, \ current_app, request @@ -361,6 +362,8 @@ def flash(message, category='message'): flashes = session.get('_flashes', []) flashes.append((category, message)) session['_flashes'] = flashes + message_flashed.send(current_app._get_current_object(), + message=message, category=category) def get_flashed_messages(with_categories=False, category_filter=[]): diff --git a/flask/signals.py b/flask/signals.py index 78a77bd5..14b728c6 100644 --- a/flask/signals.py +++ b/flask/signals.py @@ -50,3 +50,4 @@ request_finished = _signals.signal('request-finished') request_tearing_down = _signals.signal('request-tearing-down') got_request_exception = _signals.signal('got-request-exception') appcontext_tearing_down = _signals.signal('appcontext-tearing-down') +message_flashed = _signals.signal('message-flashed') diff --git a/flask/testsuite/signals.py b/flask/testsuite/signals.py index da1a68ca..0e5d0cea 100644 --- a/flask/testsuite/signals.py +++ b/flask/testsuite/signals.py @@ -8,6 +8,8 @@ :copyright: (c) 2011 by Armin Ronacher. :license: BSD, see LICENSE for more details. """ +from __future__ import with_statement + import flask import unittest from flask.testsuite import FlaskTestCase @@ -95,6 +97,31 @@ class SignalsTestCase(FlaskTestCase): finally: flask.got_request_exception.disconnect(record, app) + def test_flash_signal(self): + app = flask.Flask(__name__) + app.config['SECRET_KEY'] = 'secret' + + @app.route('/') + def index(): + flask.flash('This is a flash message', category='notice') + return flask.redirect('/other') + + recorded = [] + def record(sender, message, category): + recorded.append((message, category)) + + flask.message_flashed.connect(record, app) + try: + client = app.test_client() + with client.session_transaction(): + client.get('/') + self.assert_equal(len(recorded), 1) + message, category = recorded[0] + self.assert_equal(message, 'This is a flash message') + self.assert_equal(category, 'notice') + finally: + flask.message_flashed.disconnect(record, app) + def suite(): suite = unittest.TestSuite()