# -*- coding: utf-8 -*-
"""
    flask.testsuite.signals
    ~~~~~~~~~~~~~~~~~~~~~~~

    Signalling.

    Tests that Flask's signals are sent to receivers connected for a
    specific application.  The tests are only collected when
    ``flask.signals_available`` is true.

    :copyright: (c) 2014 by Armin Ronacher.
    :license: BSD, see LICENSE for more details.
"""

import unittest

import flask
from flask.testsuite import FlaskTestCase


class SignalsTestCase(FlaskTestCase):
    """Check that each Flask signal fires with the expected payload.

    Every test connects its receivers with the app as the sender and
    disconnects them in a ``finally`` block, so a failing assertion does
    not leave receivers attached for later tests.
    """

    def test_template_rendered(self):
        """Rendering a template sends ``template_rendered`` exactly once,
        carrying the template object and the render context."""
        app = flask.Flask(__name__)

        @app.route('/')
        def index():
            return flask.render_template('simple_template.html', whiskey=42)

        recorded = []

        def record(sender, template, context):
            recorded.append((template, context))

        flask.template_rendered.connect(record, app)
        try:
            app.test_client().get('/')
            self.assert_equal(len(recorded), 1)
            template, context = recorded[0]
            self.assert_equal(template.name, 'simple_template.html')
            self.assert_equal(context['whiskey'], 42)
        finally:
            flask.template_rendered.disconnect(record, app)

    def test_request_signals(self):
        """``request_started`` fires before the ``before_request`` hooks and
        ``request_finished`` fires after the ``after_request`` hooks, seeing
        the response those hooks modified."""
        app = flask.Flask(__name__)
        calls = []

        def before_request_signal(sender):
            calls.append('before-signal')

        def after_request_signal(sender, response):
            # The after_request hook below replaced the body, so the signal
            # must observe the final response, not the view's return value.
            self.assert_equal(response.data, b'stuff')
            calls.append('after-signal')

        @app.before_request
        def before_request_handler():
            calls.append('before-handler')

        @app.after_request
        def after_request_handler(response):
            calls.append('after-handler')
            response.data = 'stuff'
            return response

        @app.route('/')
        def index():
            calls.append('handler')
            return 'ignored anyway'

        flask.request_started.connect(before_request_signal, app)
        flask.request_finished.connect(after_request_signal, app)

        try:
            rv = app.test_client().get('/')
            self.assert_equal(rv.data, b'stuff')

            self.assert_equal(calls, ['before-signal', 'before-handler',
                                      'handler', 'after-handler',
                                      'after-signal'])
        finally:
            flask.request_started.disconnect(before_request_signal, app)
            flask.request_finished.disconnect(after_request_signal, app)

    def test_request_exception_signal(self):
        """An unhandled error in a view sends ``got_request_exception`` with
        the raised exception, and the client receives a 500 response."""
        app = flask.Flask(__name__)
        recorded = []

        @app.route('/')
        def index():
            # Deliberately raise ZeroDivisionError inside the view.
            1 // 0

        def record(sender, exception):
            recorded.append(exception)

        flask.got_request_exception.connect(record, app)
        try:
            self.assert_equal(app.test_client().get('/').status_code, 500)
            self.assert_equal(len(recorded), 1)
            self.assert_true(isinstance(recorded[0], ZeroDivisionError))
        finally:
            flask.got_request_exception.disconnect(record, app)

    def test_appcontext_signals(self):
        """``appcontext_pushed`` fires for the request, and
        ``appcontext_popped`` fires only once the test client's ``with``
        block has exited."""
        app = flask.Flask(__name__)
        recorded = []

        def record_push(sender, **kwargs):
            recorded.append('push')

        def record_pop(sender, **kwargs):
            recorded.append('pop')

        @app.route('/')
        def index():
            return 'Hello'

        flask.appcontext_pushed.connect(record_push, app)
        flask.appcontext_popped.connect(record_pop, app)
        try:
            with app.test_client() as c:
                rv = c.get('/')
                self.assert_equal(rv.data, b'Hello')
                # Still inside the client block: the context is not popped.
                self.assert_equal(recorded, ['push'])
            self.assert_equal(recorded, ['push', 'pop'])
        finally:
            flask.appcontext_pushed.disconnect(record_push, app)
            flask.appcontext_popped.disconnect(record_pop, app)

    def test_flash_signal(self):
        """``flask.flash`` sends ``message_flashed`` with the message text
        and its category."""
        app = flask.Flask(__name__)
        # Flashing stores messages in the session, which needs a secret key.
        app.config['SECRET_KEY'] = 'secret'

        @app.route('/')
        def index():
            flask.flash('This is a flash message', category='notice')
            return flask.redirect('/other')

        recorded = []

        def record(sender, message, category):
            recorded.append((message, category))

        flask.message_flashed.connect(record, app)
        try:
            client = app.test_client()
            with client.session_transaction():
                client.get('/')
                self.assert_equal(len(recorded), 1)
                message, category = recorded[0]
                self.assert_equal(message, 'This is a flash message')
                self.assert_equal(category, 'notice')
        finally:
            flask.message_flashed.disconnect(record, app)


def suite():
    """Return a suite with the signal tests.

    The suite is empty when ``flask.signals_available`` is false, since
    none of the signals exercised above could be observed.
    """
    test_suite = unittest.TestSuite()
    if flask.signals_available:
        # unittest.makeSuite was deprecated in Python 3.11 and removed in
        # 3.13; a default TestLoader does the same job on every version.
        loader = unittest.TestLoader()
        test_suite.addTest(loader.loadTestsFromTestCase(SignalsTestCase))
    return test_suite