Browse Source

Tests pass now.

pull/1165/head
Markus Unterwaditzer 10 years ago
parent
commit
8fa5e32d9a
  1. 36
      examples/blueprintexample/blueprintexample_test.py
  2. 33
      examples/blueprintexample/test_blueprintexample.py
  3. 76
      examples/flaskr/flaskr_tests.py
  4. 77
      examples/flaskr/test_flaskr.py
  5. 150
      examples/minitwit/minitwit_tests.py
  6. 151
      examples/minitwit/test_minitwit.py
  7. 3
      setup.cfg
  8. 169
      tests/__init__.py
  9. 8
      tests/test_appctx.py
  10. 2
      tests/test_apps/importerror.py
  11. 47
      tests/test_basic.py
  12. 22
      tests/test_blueprints.py
  13. 16
      tests/test_config.py
  14. 6
      tests/test_deprecations.py
  15. 8
      tests/test_examples.py
  16. 6
      tests/test_ext.py
  17. 34
      tests/test_helpers.py
  18. 27
      tests/test_regression.py
  19. 6
      tests/test_reqctx.py
  20. 6
      tests/test_signals.py
  21. 6
      tests/test_subclassing.py
  22. 6
      tests/test_templating.py
  23. 50
      tests/test_testing.py
  24. 6
      tests/test_views.py

36
examples/blueprintexample/blueprintexample_test.py

@ -1,36 +0,0 @@
# -*- coding: utf-8 -*-
"""
Blueprint Example Tests
~~~~~~~~~~~~~~
Tests the Blueprint example app
"""
import blueprintexample
import unittest
class BlueprintExampleTestCase(unittest.TestCase):
def setUp(self):
self.app = blueprintexample.app.test_client()
def test_urls(self):
r = self.app.get('/')
self.assertEquals(r.status_code, 200)
r = self.app.get('/hello')
self.assertEquals(r.status_code, 200)
r = self.app.get('/world')
self.assertEquals(r.status_code, 200)
#second blueprint instance
r = self.app.get('/pages/hello')
self.assertEquals(r.status_code, 200)
r = self.app.get('/pages/world')
self.assertEquals(r.status_code, 200)
if __name__ == '__main__':
unittest.main()

33
examples/blueprintexample/test_blueprintexample.py

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
"""
Blueprint Example Tests
~~~~~~~~~~~~~~
Tests the Blueprint example app
"""
import pytest
import blueprintexample
@pytest.fixture
def client():
return blueprintexample.app.test_client()
def test_urls(client):
r = client.get('/')
assert r.status_code == 200
r = client.get('/hello')
assert r.status_code == 200
r = client.get('/world')
assert r.status_code == 200
# second blueprint instance
r = client.get('/pages/hello')
assert r.status_code == 200
r = client.get('/pages/world')
assert r.status_code == 200

76
examples/flaskr/flaskr_tests.py

@ -1,76 +0,0 @@
# -*- coding: utf-8 -*-
"""
Flaskr Tests
~~~~~~~~~~~~
Tests the Flaskr application.
:copyright: (c) 2014 by Armin Ronacher.
:license: BSD, see LICENSE for more details.
"""
import os
import flaskr
import unittest
import tempfile
class FlaskrTestCase(unittest.TestCase):
def setUp(self):
"""Before each test, set up a blank database"""
self.db_fd, flaskr.app.config['DATABASE'] = tempfile.mkstemp()
flaskr.app.config['TESTING'] = True
self.app = flaskr.app.test_client()
with flaskr.app.app_context():
flaskr.init_db()
def tearDown(self):
"""Get rid of the database again after each test."""
os.close(self.db_fd)
os.unlink(flaskr.app.config['DATABASE'])
def login(self, username, password):
return self.app.post('/login', data=dict(
username=username,
password=password
), follow_redirects=True)
def logout(self):
return self.app.get('/logout', follow_redirects=True)
# testing functions
def test_empty_db(self):
"""Start with a blank database."""
rv = self.app.get('/')
assert b'No entries here so far' in rv.data
def test_login_logout(self):
"""Make sure login and logout works"""
rv = self.login(flaskr.app.config['USERNAME'],
flaskr.app.config['PASSWORD'])
assert b'You were logged in' in rv.data
rv = self.logout()
assert b'You were logged out' in rv.data
rv = self.login(flaskr.app.config['USERNAME'] + 'x',
flaskr.app.config['PASSWORD'])
assert b'Invalid username' in rv.data
rv = self.login(flaskr.app.config['USERNAME'],
flaskr.app.config['PASSWORD'] + 'x')
assert b'Invalid password' in rv.data
def test_messages(self):
"""Test that messages work"""
self.login(flaskr.app.config['USERNAME'],
flaskr.app.config['PASSWORD'])
rv = self.app.post('/add', data=dict(
title='<Hello>',
text='<strong>HTML</strong> allowed here'
), follow_redirects=True)
assert b'No entries here so far' not in rv.data
assert b'&lt;Hello&gt;' in rv.data
assert b'<strong>HTML</strong> allowed here' in rv.data
if __name__ == '__main__':
unittest.main()

77
examples/flaskr/test_flaskr.py

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
"""
Flaskr Tests
~~~~~~~~~~~~
Tests the Flaskr application.
:copyright: (c) 2014 by Armin Ronacher.
:license: BSD, see LICENSE for more details.
"""
import pytest
import os
import flaskr
import tempfile
@pytest.fixture
def client(request):
db_fd, flaskr.app.config['DATABASE'] = tempfile.mkstemp()
flaskr.app.config['TESTING'] = True
client = flaskr.app.test_client()
with flaskr.app.app_context():
flaskr.init_db()
def teardown():
os.close(db_fd)
os.unlink(flaskr.app.config['DATABASE'])
request.addfinalizer(teardown)
return client
def login(client, username, password):
return client.post('/login', data=dict(
username=username,
password=password
), follow_redirects=True)
def logout(client):
return client.get('/logout', follow_redirects=True)
def test_empty_db(client):
"""Start with a blank database."""
rv = client.get('/')
assert b'No entries here so far' in rv.data
def test_login_logout(client):
"""Make sure login and logout works"""
rv = login(client, flaskr.app.config['USERNAME'],
flaskr.app.config['PASSWORD'])
assert b'You were logged in' in rv.data
rv = logout(client)
assert b'You were logged out' in rv.data
rv = login(client, flaskr.app.config['USERNAME'] + 'x',
flaskr.app.config['PASSWORD'])
assert b'Invalid username' in rv.data
rv = login(client, flaskr.app.config['USERNAME'],
flaskr.app.config['PASSWORD'] + 'x')
assert b'Invalid password' in rv.data
def test_messages(client):
"""Test that messages work"""
login(client, flaskr.app.config['USERNAME'],
flaskr.app.config['PASSWORD'])
rv = client.post('/add', data=dict(
title='<Hello>',
text='<strong>HTML</strong> allowed here'
), follow_redirects=True)
assert b'No entries here so far' not in rv.data
assert b'&lt;Hello&gt;' in rv.data
assert b'<strong>HTML</strong> allowed here' in rv.data

150
examples/minitwit/minitwit_tests.py

@ -1,150 +0,0 @@
# -*- coding: utf-8 -*-
"""
MiniTwit Tests
~~~~~~~~~~~~~~
Tests the MiniTwit application.
:copyright: (c) 2014 by Armin Ronacher.
:license: BSD, see LICENSE for more details.
"""
import os
import minitwit
import unittest
import tempfile
class MiniTwitTestCase(unittest.TestCase):
def setUp(self):
"""Before each test, set up a blank database"""
self.db_fd, minitwit.app.config['DATABASE'] = tempfile.mkstemp()
self.app = minitwit.app.test_client()
with minitwit.app.app_context():
minitwit.init_db()
def tearDown(self):
"""Get rid of the database again after each test."""
os.close(self.db_fd)
os.unlink(minitwit.app.config['DATABASE'])
# helper functions
def register(self, username, password, password2=None, email=None):
"""Helper function to register a user"""
if password2 is None:
password2 = password
if email is None:
email = username + '@example.com'
return self.app.post('/register', data={
'username': username,
'password': password,
'password2': password2,
'email': email,
}, follow_redirects=True)
def login(self, username, password):
"""Helper function to login"""
return self.app.post('/login', data={
'username': username,
'password': password
}, follow_redirects=True)
def register_and_login(self, username, password):
"""Registers and logs in in one go"""
self.register(username, password)
return self.login(username, password)
def logout(self):
"""Helper function to logout"""
return self.app.get('/logout', follow_redirects=True)
def add_message(self, text):
"""Records a message"""
rv = self.app.post('/add_message', data={'text': text},
follow_redirects=True)
if text:
assert b'Your message was recorded' in rv.data
return rv
# testing functions
def test_register(self):
"""Make sure registering works"""
rv = self.register('user1', 'default')
assert b'You were successfully registered ' \
b'and can login now' in rv.data
rv = self.register('user1', 'default')
assert b'The username is already taken' in rv.data
rv = self.register('', 'default')
assert b'You have to enter a username' in rv.data
rv = self.register('meh', '')
assert b'You have to enter a password' in rv.data
rv = self.register('meh', 'x', 'y')
assert b'The two passwords do not match' in rv.data
rv = self.register('meh', 'foo', email='broken')
assert b'You have to enter a valid email address' in rv.data
def test_login_logout(self):
"""Make sure logging in and logging out works"""
rv = self.register_and_login('user1', 'default')
assert b'You were logged in' in rv.data
rv = self.logout()
assert b'You were logged out' in rv.data
rv = self.login('user1', 'wrongpassword')
assert b'Invalid password' in rv.data
rv = self.login('user2', 'wrongpassword')
assert b'Invalid username' in rv.data
def test_message_recording(self):
"""Check if adding messages works"""
self.register_and_login('foo', 'default')
self.add_message('test message 1')
self.add_message('<test message 2>')
rv = self.app.get('/')
assert b'test message 1' in rv.data
assert b'&lt;test message 2&gt;' in rv.data
def test_timelines(self):
"""Make sure that timelines work"""
self.register_and_login('foo', 'default')
self.add_message('the message by foo')
self.logout()
self.register_and_login('bar', 'default')
self.add_message('the message by bar')
rv = self.app.get('/public')
assert b'the message by foo' in rv.data
assert b'the message by bar' in rv.data
# bar's timeline should just show bar's message
rv = self.app.get('/')
assert b'the message by foo' not in rv.data
assert b'the message by bar' in rv.data
# now let's follow foo
rv = self.app.get('/foo/follow', follow_redirects=True)
assert b'You are now following &#34;foo&#34;' in rv.data
# we should now see foo's message
rv = self.app.get('/')
assert b'the message by foo' in rv.data
assert b'the message by bar' in rv.data
# but on the user's page we only want the user's message
rv = self.app.get('/bar')
assert b'the message by foo' not in rv.data
assert b'the message by bar' in rv.data
rv = self.app.get('/foo')
assert b'the message by foo' in rv.data
assert b'the message by bar' not in rv.data
# now unfollow and check if that worked
rv = self.app.get('/foo/unfollow', follow_redirects=True)
assert b'You are no longer following &#34;foo&#34;' in rv.data
rv = self.app.get('/')
assert b'the message by foo' not in rv.data
assert b'the message by bar' in rv.data
if __name__ == '__main__':
unittest.main()

151
examples/minitwit/test_minitwit.py

@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
"""
MiniTwit Tests
~~~~~~~~~~~~~~
Tests the MiniTwit application.
:copyright: (c) 2014 by Armin Ronacher.
:license: BSD, see LICENSE for more details.
"""
import os
import minitwit
import tempfile
import pytest
@pytest.fixture
def client(request):
db_fd, minitwit.app.config['DATABASE'] = tempfile.mkstemp()
client = minitwit.app.test_client()
with minitwit.app.app_context():
minitwit.init_db()
def teardown():
"""Get rid of the database again after each test."""
os.close(db_fd)
os.unlink(minitwit.app.config['DATABASE'])
request.addfinalizer(teardown)
return client
def register(client, username, password, password2=None, email=None):
"""Helper function to register a user"""
if password2 is None:
password2 = password
if email is None:
email = username + '@example.com'
return client.post('/register', data={
'username': username,
'password': password,
'password2': password2,
'email': email,
}, follow_redirects=True)
def login(client, username, password):
"""Helper function to login"""
return client.post('/login', data={
'username': username,
'password': password
}, follow_redirects=True)
def register_and_login(client, username, password):
"""Registers and logs in in one go"""
register(client, username, password)
return login(client, username, password)
def logout(client):
"""Helper function to logout"""
return client.get('/logout', follow_redirects=True)
def add_message(client, text):
"""Records a message"""
rv = client.post('/add_message', data={'text': text},
follow_redirects=True)
if text:
assert b'Your message was recorded' in rv.data
return rv
def test_register(client):
"""Make sure registering works"""
rv = register(client, 'user1', 'default')
assert b'You were successfully registered ' \
b'and can login now' in rv.data
rv = register(client, 'user1', 'default')
assert b'The username is already taken' in rv.data
rv = register(client, '', 'default')
assert b'You have to enter a username' in rv.data
rv = register(client, 'meh', '')
assert b'You have to enter a password' in rv.data
rv = register(client, 'meh', 'x', 'y')
assert b'The two passwords do not match' in rv.data
rv = register(client, 'meh', 'foo', email='broken')
assert b'You have to enter a valid email address' in rv.data
def test_login_logout(client):
"""Make sure logging in and logging out works"""
rv = register_and_login(client, 'user1', 'default')
assert b'You were logged in' in rv.data
rv = logout(client)
assert b'You were logged out' in rv.data
rv = login(client, 'user1', 'wrongpassword')
assert b'Invalid password' in rv.data
rv = login(client, 'user2', 'wrongpassword')
assert b'Invalid username' in rv.data
def test_message_recording(client):
"""Check if adding messages works"""
register_and_login(client, 'foo', 'default')
add_message(client, 'test message 1')
add_message(client, '<test message 2>')
rv = client.get('/')
assert b'test message 1' in rv.data
assert b'&lt;test message 2&gt;' in rv.data
def test_timelines(client):
"""Make sure that timelines work"""
register_and_login(client, 'foo', 'default')
add_message(client, 'the message by foo')
logout(client)
register_and_login(client, 'bar', 'default')
add_message(client, 'the message by bar')
rv = client.get('/public')
assert b'the message by foo' in rv.data
assert b'the message by bar' in rv.data
# bar's timeline should just show bar's message
rv = client.get('/')
assert b'the message by foo' not in rv.data
assert b'the message by bar' in rv.data
# now let's follow foo
rv = client.get('/foo/follow', follow_redirects=True)
assert b'You are now following &#34;foo&#34;' in rv.data
# we should now see foo's message
rv = client.get('/')
assert b'the message by foo' in rv.data
assert b'the message by bar' in rv.data
# but on the user's page we only want the user's message
rv = client.get('/bar')
assert b'the message by foo' not in rv.data
assert b'the message by bar' in rv.data
rv = client.get('/foo')
assert b'the message by foo' in rv.data
assert b'the message by bar' not in rv.data
# now unfollow and check if that worked
rv = client.get('/foo/unfollow', follow_redirects=True)
assert b'You are no longer following &#34;foo&#34;' in rv.data
rv = client.get('/')
assert b'the message by foo' not in rv.data
assert b'the message by bar' in rv.data

3
setup.cfg

@ -1,3 +1,6 @@
[pytest]
norecursedirs= scripts docs
[aliases] [aliases]
release = egg_info -RDb '' release = egg_info -RDb ''

169
tests/__init__.py

@ -11,16 +11,15 @@
""" """
from __future__ import print_function from __future__ import print_function
import pytest
import os import os
import sys import sys
import flask import flask
import warnings import warnings
import unittest
from functools import update_wrapper from functools import update_wrapper
from contextlib import contextmanager from contextlib import contextmanager
from werkzeug.utils import import_string, find_modules from flask._compat import StringIO
from flask._compat import reraise, StringIO
def add_to_path(path): def add_to_path(path):
@ -43,29 +42,6 @@ def add_to_path(path):
sys.path.insert(0, path) sys.path.insert(0, path)
def iter_suites():
"""Yields all testsuites."""
for module in find_modules(__name__):
mod = import_string(module)
if hasattr(mod, 'suite'):
yield mod.suite()
def find_all_tests(suite):
"""Yields all the tests and their names from a given suite."""
suites = [suite]
while suites:
s = suites.pop()
try:
suites.extend(s)
except TypeError:
yield s, '%s.%s.%s' % (
s.__class__.__module__,
s.__class__.__name__,
s._testMethodName
)
@contextmanager @contextmanager
def catch_warnings(): def catch_warnings():
"""Catch warnings in a with block in a list""" """Catch warnings in a with block in a list"""
@ -76,6 +52,7 @@ def catch_warnings():
warnings.filters = filters[:] warnings.filters = filters[:]
old_showwarning = warnings.showwarning old_showwarning = warnings.showwarning
log = [] log = []
def showwarning(message, category, filename, lineno, file=None, line=None): def showwarning(message, category, filename, lineno, file=None, line=None):
log.append(locals()) log.append(locals())
try: try:
@ -107,12 +84,23 @@ def emits_module_deprecation_warning(f):
return update_wrapper(new_f, f) return update_wrapper(new_f, f)
class FlaskTestCase(unittest.TestCase): class TestFlask(object):
"""Baseclass for all the tests that Flask uses. Use these methods """Baseclass for all the tests that Flask uses. Use these methods
for testing instead of the camelcased ones in the baseclass for for testing instead of the camelcased ones in the baseclass for
consistency. consistency.
""" """
@pytest.fixture(autouse=True)
def setup_path(self, monkeypatch):
monkeypatch.syspath_prepend(
os.path.abspath(os.path.join(
os.path.dirname(__file__), 'test_apps'))
)
@pytest.fixture(autouse=True)
def leak_detector(self, request):
request.addfinalizer(self.ensure_clean_request_context)
def ensure_clean_request_context(self): def ensure_clean_request_context(self):
# make sure we're not leaking a request context since we are # make sure we're not leaking a request context since we are
# testing flask internally in debug mode in a few cases # testing flask internally in debug mode in a few cases
@ -121,133 +109,42 @@ class FlaskTestCase(unittest.TestCase):
leaks.append(flask._request_ctx_stack.pop()) leaks.append(flask._request_ctx_stack.pop())
self.assert_equal(leaks, []) self.assert_equal(leaks, [])
def setup_method(self, method):
self.setup()
def teardown_method(self, method):
self.teardown()
def setup(self): def setup(self):
pass pass
def teardown(self): def teardown(self):
pass pass
def setUp(self):
self.setup()
def tearDown(self):
unittest.TestCase.tearDown(self)
self.ensure_clean_request_context()
self.teardown()
def assert_equal(self, x, y): def assert_equal(self, x, y):
return self.assertEqual(x, y) assert x == y
def assert_raises(self, exc_type, callable=None, *args, **kwargs): def assert_raises(self, exc_type, callable=None, *args, **kwargs):
catcher = _ExceptionCatcher(self, exc_type) if callable:
if callable is None: return pytest.raises(exc_type, callable, *args, **kwargs)
return catcher else:
with catcher: return pytest.raises(exc_type)
callable(*args, **kwargs)
def assert_true(self, x, msg=None): def assert_true(self, x, msg=None):
self.assertTrue(x, msg) assert x
assert_ = assert_true assert_ = assert_true
def assert_false(self, x, msg=None): def assert_false(self, x, msg=None):
self.assertFalse(x, msg) assert not x
def assert_in(self, x, y): def assert_in(self, x, y):
self.assertIn(x, y) assert x in y
def assert_not_in(self, x, y): def assert_not_in(self, x, y):
self.assertNotIn(x, y) assert x not in y
def assert_isinstance(self, obj, cls): def assert_isinstance(self, obj, cls):
self.assertIsInstance(obj, cls) assert isinstance(obj, cls)
if sys.version_info[:2] == (2, 6):
def assertIn(self, x, y):
assert x in y, "%r unexpectedly not in %r" % (x, y)
def assertNotIn(self, x, y):
assert x not in y, "%r unexpectedly in %r" % (x, y)
def assertIsInstance(self, x, y):
assert isinstance(x, y), "not isinstance(%r, %r)" % (x, y)
class _ExceptionCatcher(object): def fail(self, msg):
raise AssertionError(msg)
def __init__(self, test_case, exc_type):
self.test_case = test_case
self.exc_type = exc_type
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, tb):
exception_name = self.exc_type.__name__
if exc_type is None:
self.test_case.fail('Expected exception of type %r' %
exception_name)
elif not issubclass(exc_type, self.exc_type):
reraise(exc_type, exc_value, tb)
return True
class BetterLoader(unittest.TestLoader):
"""A nicer loader that solves two problems. First of all we are setting
up tests from different sources and we're doing this programmatically
which breaks the default loading logic so this is required anyways.
Secondly this loader has a nicer interpolation for test names than the
default one so you can just do ``run-tests.py ViewTestCase`` and it
will work.
"""
def getRootSuite(self):
return suite()
def loadTestsFromName(self, name, module=None):
root = self.getRootSuite()
if name == 'suite':
return root
all_tests = []
for testcase, testname in find_all_tests(root):
if testname == name or \
testname.endswith('.' + name) or \
('.' + name + '.') in testname or \
testname.startswith(name + '.'):
all_tests.append(testcase)
if not all_tests:
raise LookupError('could not find test case for "%s"' % name)
if len(all_tests) == 1:
return all_tests[0]
rv = unittest.TestSuite()
for test in all_tests:
rv.addTest(test)
return rv
def setup_path():
add_to_path(os.path.abspath(os.path.join(
os.path.dirname(__file__), 'test_apps')))
def suite():
"""A testsuite that has all the Flask tests. You can use this
function to integrate the Flask tests into your own testsuite
in case you want to test that monkeypatches to Flask do not
break it.
"""
setup_path()
suite = unittest.TestSuite()
for other_suite in iter_suites():
suite.addTest(other_suite)
return suite
def main():
"""Runs the testsuite as command line application."""
try:
unittest.main(testLoader=BetterLoader(), defaultTest='suite')
except Exception as e:
print('Error: %s' % e)

8
tests/test_appctx.py

@ -11,10 +11,10 @@
import flask import flask
import unittest import unittest
from tests import FlaskTestCase from tests import TestFlask
class AppContextTestCase(FlaskTestCase): class TestAppContext(TestFlask):
def test_basic_url_generation(self): def test_basic_url_generation(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -109,10 +109,10 @@ class AppContextTestCase(FlaskTestCase):
return u'' return u''
c = app.test_client() c = app.test_client()
c.get('/') c.get('/')
self.assertEqual(called, ['request', 'app']) self.assert_equal(called, ['request', 'app'])
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(AppContextTestCase)) suite.addTest(unittest.makeSuite(TestAppContext))
return suite return suite

2
tests/test_apps/importerror.py

@ -1,2 +1,2 @@
# NoImportsTestCase # TestNoImports
raise NotImplementedError raise NotImplementedError

47
tests/test_basic.py

@ -17,14 +17,14 @@ import pickle
import unittest import unittest
from datetime import datetime from datetime import datetime
from threading import Thread from threading import Thread
from tests import FlaskTestCase, emits_module_deprecation_warning from tests import TestFlask, emits_module_deprecation_warning
from flask._compat import text_type from flask._compat import text_type
from werkzeug.exceptions import BadRequest, NotFound, Forbidden from werkzeug.exceptions import BadRequest, NotFound, Forbidden
from werkzeug.http import parse_date from werkzeug.http import parse_date
from werkzeug.routing import BuildError from werkzeug.routing import BuildError
class BasicFunctionalityTestCase(FlaskTestCase): class TestBasicFunctionality(TestFlask):
def test_options_work(self): def test_options_work(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -522,8 +522,8 @@ class BasicFunctionalityTestCase(FlaskTestCase):
return 'Test' return 'Test'
c = app.test_client() c = app.test_client()
resp = c.get('/') resp = c.get('/')
self.assertEqual(resp.status_code, 200) self.assert_equal(resp.status_code, 200)
self.assertEqual(resp.headers['X-Foo'], 'a header') self.assert_equal(resp.headers['X-Foo'], 'a header')
def test_teardown_request_handler(self): def test_teardown_request_handler(self):
called = [] called = []
@ -840,22 +840,22 @@ class BasicFunctionalityTestCase(FlaskTestCase):
with app.test_request_context(): with app.test_request_context():
rv = flask.make_response( rv = flask.make_response(
flask.jsonify({'msg': 'W00t'}), 400) flask.jsonify({'msg': 'W00t'}), 400)
self.assertEqual(rv.status_code, 400) self.assert_equal(rv.status_code, 400)
self.assertEqual(rv.data, b'{\n "msg": "W00t"\n}') self.assert_equal(rv.data, b'{\n "msg": "W00t"\n}')
self.assertEqual(rv.mimetype, 'application/json') self.assert_equal(rv.mimetype, 'application/json')
rv = flask.make_response( rv = flask.make_response(
flask.Response(''), 400) flask.Response(''), 400)
self.assertEqual(rv.status_code, 400) self.assert_equal(rv.status_code, 400)
self.assertEqual(rv.data, b'') self.assert_equal(rv.data, b'')
self.assertEqual(rv.mimetype, 'text/html') self.assert_equal(rv.mimetype, 'text/html')
rv = flask.make_response( rv = flask.make_response(
flask.Response('', headers={'Content-Type': 'text/html'}), flask.Response('', headers={'Content-Type': 'text/html'}),
400, [('X-Foo', 'bar')]) 400, [('X-Foo', 'bar')])
self.assertEqual(rv.status_code, 400) self.assert_equal(rv.status_code, 400)
self.assertEqual(rv.headers['Content-Type'], 'text/html') self.assert_equal(rv.headers['Content-Type'], 'text/html')
self.assertEqual(rv.headers['X-Foo'], 'bar') self.assert_equal(rv.headers['X-Foo'], 'bar')
def test_url_generation(self): def test_url_generation(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -872,7 +872,7 @@ class BasicFunctionalityTestCase(FlaskTestCase):
# Test base case, a URL which results in a BuildError. # Test base case, a URL which results in a BuildError.
with app.test_request_context(): with app.test_request_context():
self.assertRaises(BuildError, flask.url_for, 'spam') self.assert_raises(BuildError, flask.url_for, 'spam')
# Verify the error is re-raised if not the current exception. # Verify the error is re-raised if not the current exception.
try: try:
@ -883,7 +883,7 @@ class BasicFunctionalityTestCase(FlaskTestCase):
try: try:
raise RuntimeError('Test case where BuildError is not current.') raise RuntimeError('Test case where BuildError is not current.')
except RuntimeError: except RuntimeError:
self.assertRaises(BuildError, app.handle_url_build_error, error, 'spam', {}) self.assert_raises(BuildError, app.handle_url_build_error, error, 'spam', {})
# Test a custom handler. # Test a custom handler.
def handler(error, endpoint, values): def handler(error, endpoint, values):
@ -936,7 +936,7 @@ class BasicFunctionalityTestCase(FlaskTestCase):
def test_request_locals(self): def test_request_locals(self):
self.assert_equal(repr(flask.g), '<LocalProxy unbound>') self.assert_equal(repr(flask.g), '<LocalProxy unbound>')
self.assertFalse(flask.g) self.assert_false(flask.g)
def test_test_app_proper_environ(self): def test_test_app_proper_environ(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -1205,9 +1205,9 @@ class BasicFunctionalityTestCase(FlaskTestCase):
assert flask.url_for('123') == '/bar/123' assert flask.url_for('123') == '/bar/123'
c = app.test_client() c = app.test_client()
self.assertEqual(c.get('/foo/').data, b'foo') self.assert_equal(c.get('/foo/').data, b'foo')
self.assertEqual(c.get('/bar/').data, b'bar') self.assert_equal(c.get('/bar/').data, b'bar')
self.assertEqual(c.get('/bar/123').data, b'123') self.assert_equal(c.get('/bar/123').data, b'123')
def test_preserve_only_once(self): def test_preserve_only_once(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -1286,7 +1286,7 @@ class BasicFunctionalityTestCase(FlaskTestCase):
self.assert_equal(sorted(flask.g), ['bar', 'foo']) self.assert_equal(sorted(flask.g), ['bar', 'foo'])
class SubdomainTestCase(FlaskTestCase): class TestSubdomain(TestFlask):
def test_basic_support(self): def test_basic_support(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -1355,10 +1355,3 @@ class SubdomainTestCase(FlaskTestCase):
self.assert_equal(rv.data, b'a') self.assert_equal(rv.data, b'a')
rv = app.test_client().open('/b/') rv = app.test_client().open('/b/')
self.assert_equal(rv.data, b'b') self.assert_equal(rv.data, b'b')
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(BasicFunctionalityTestCase))
suite.addTest(unittest.makeSuite(SubdomainTestCase))
return suite

22
tests/test_blueprints.py

@ -11,13 +11,13 @@
import flask import flask
import unittest import unittest
from tests import FlaskTestCase from tests import TestFlask
from flask._compat import text_type from flask._compat import text_type
from werkzeug.http import parse_cache_control_header from werkzeug.http import parse_cache_control_header
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
class BlueprintTestCase(FlaskTestCase): class TestBlueprint(TestFlask):
def test_blueprint_specific_error_handling(self): def test_blueprint_specific_error_handling(self):
frontend = flask.Blueprint('frontend', __name__) frontend = flask.Blueprint('frontend', __name__)
@ -303,11 +303,11 @@ class BlueprintTestCase(FlaskTestCase):
return flask.request.endpoint return flask.request.endpoint
c = app.test_client() c = app.test_client()
self.assertEqual(c.get('/').data, b'index') self.assert_equal(c.get('/').data, b'index')
self.assertEqual(c.get('/py/foo').data, b'bp.foo') self.assert_equal(c.get('/py/foo').data, b'bp.foo')
self.assertEqual(c.get('/py/bar').data, b'bp.bar') self.assert_equal(c.get('/py/bar').data, b'bp.bar')
self.assertEqual(c.get('/py/bar/123').data, b'bp.123') self.assert_equal(c.get('/py/bar/123').data, b'bp.123')
self.assertEqual(c.get('/py/bar/foo').data, b'bp.bar_foo') self.assert_equal(c.get('/py/bar/foo').data, b'bp.bar_foo')
def test_route_decorator_custom_endpoint_with_dots(self): def test_route_decorator_custom_endpoint_with_dots(self):
bp = flask.Blueprint('bp', __name__) bp = flask.Blueprint('bp', __name__)
@ -337,14 +337,14 @@ class BlueprintTestCase(FlaskTestCase):
def foo_foo_foo(): def foo_foo_foo():
pass pass
self.assertRaises( self.assert_raises(
AssertionError, AssertionError,
lambda: bp.add_url_rule( lambda: bp.add_url_rule(
'/bar/123', endpoint='bar.123', view_func=foo_foo_foo '/bar/123', endpoint='bar.123', view_func=foo_foo_foo
) )
) )
self.assertRaises( self.assert_raises(
AssertionError, AssertionError,
bp.route('/bar/123', endpoint='bar.123'), bp.route('/bar/123', endpoint='bar.123'),
lambda: None lambda: None
@ -354,7 +354,7 @@ class BlueprintTestCase(FlaskTestCase):
app.register_blueprint(bp, url_prefix='/py') app.register_blueprint(bp, url_prefix='/py')
c = app.test_client() c = app.test_client()
self.assertEqual(c.get('/py/foo').data, b'bp.foo') self.assert_equal(c.get('/py/foo').data, b'bp.foo')
# The rule's didn't actually made it through # The rule's didn't actually made it through
rv = c.get('/py/bar') rv = c.get('/py/bar')
assert rv.status_code == 404 assert rv.status_code == 404
@ -581,5 +581,5 @@ class BlueprintTestCase(FlaskTestCase):
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(BlueprintTestCase)) suite.addTest(unittest.makeSuite(TestBlueprint))
return suite return suite

16
tests/test_config.py

@ -15,21 +15,21 @@ import flask
import pkgutil import pkgutil
import unittest import unittest
from contextlib import contextmanager from contextlib import contextmanager
from tests import FlaskTestCase from tests import TestFlask
from flask._compat import PY2 from flask._compat import PY2
# config keys used for the ConfigTestCase # config keys used for the TestConfig
TEST_KEY = 'foo' TEST_KEY = 'foo'
SECRET_KEY = 'devkey' SECRET_KEY = 'devkey'
class ConfigTestCase(FlaskTestCase): class TestConfig(TestFlask):
def common_object_test(self, app): def common_object_test(self, app):
self.assert_equal(app.secret_key, 'devkey') self.assert_equal(app.secret_key, 'devkey')
self.assert_equal(app.config['TEST_KEY'], 'foo') self.assert_equal(app.config['TEST_KEY'], 'foo')
self.assert_not_in('ConfigTestCase', app.config) self.assert_not_in('TestConfig', app.config)
def test_config_from_file(self): def test_config_from_file(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -117,7 +117,7 @@ class ConfigTestCase(FlaskTestCase):
self.assert_true(msg.endswith("missing.cfg'")) self.assert_true(msg.endswith("missing.cfg'"))
else: else:
self.fail('expected IOError') self.fail('expected IOError')
self.assertFalse(app.config.from_envvar('FOO_SETTINGS', silent=True)) self.assert_false(app.config.from_envvar('FOO_SETTINGS', silent=True))
finally: finally:
os.environ = env os.environ = env
@ -207,7 +207,7 @@ def patch_pkgutil_get_loader(wrapper_class=LimitedLoaderMockWrapper):
pkgutil.get_loader = old_get_loader pkgutil.get_loader = old_get_loader
class InstanceTestCase(FlaskTestCase): class TestInstance(TestFlask):
def test_explicit_instance_paths(self): def test_explicit_instance_paths(self):
here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))
@ -379,6 +379,6 @@ class InstanceTestCase(FlaskTestCase):
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ConfigTestCase)) suite.addTest(unittest.makeSuite(TestConfig))
suite.addTest(unittest.makeSuite(InstanceTestCase)) suite.addTest(unittest.makeSuite(TestInstance))
return suite return suite

6
tests/test_deprecations.py

@ -11,14 +11,14 @@
import flask import flask
import unittest import unittest
from tests import FlaskTestCase, catch_warnings from tests import TestFlask, catch_warnings
class DeprecationsTestCase(FlaskTestCase): class TestDeprecations(TestFlask):
"""not used currently""" """not used currently"""
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(DeprecationsTestCase)) suite.addTest(unittest.makeSuite(TestDeprecations))
return suite return suite

8
tests/test_examples.py

@ -24,15 +24,15 @@ def suite():
setup_path() setup_path()
suite = unittest.TestSuite() suite = unittest.TestSuite()
try: try:
from minitwit_tests import MiniTwitTestCase from minitwit_tests import TestMiniTwit
except ImportError: except ImportError:
pass pass
else: else:
suite.addTest(unittest.makeSuite(MiniTwitTestCase)) suite.addTest(unittest.makeSuite(TestMiniTwit))
try: try:
from flaskr_tests import FlaskrTestCase from flaskr_tests import TestFlaskr
except ImportError: except ImportError:
pass pass
else: else:
suite.addTest(unittest.makeSuite(FlaskrTestCase)) suite.addTest(unittest.makeSuite(TestFlaskr))
return suite return suite

6
tests/test_ext.py

@ -15,10 +15,10 @@ try:
from imp import reload as reload_module from imp import reload as reload_module
except ImportError: except ImportError:
reload_module = reload reload_module = reload
from tests import FlaskTestCase from tests import TestFlask
from flask._compat import PY2 from flask._compat import PY2
class ExtImportHookTestCase(FlaskTestCase): class TestExtImportHook(TestFlask):
def setup(self): def setup(self):
# we clear this out for various reasons. The most important one is # we clear this out for various reasons. The most important one is
@ -132,5 +132,5 @@ class ExtImportHookTestCase(FlaskTestCase):
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ExtImportHookTestCase)) suite.addTest(unittest.makeSuite(TestExtImportHook))
return suite return suite

34
tests/test_helpers.py

@ -13,7 +13,7 @@ import os
import flask import flask
import unittest import unittest
from logging import StreamHandler from logging import StreamHandler
from tests import FlaskTestCase, catch_warnings, catch_stderr from tests import TestFlask, catch_warnings, catch_stderr
from werkzeug.http import parse_cache_control_header, parse_options_header from werkzeug.http import parse_cache_control_header, parse_options_header
from flask._compat import StringIO, text_type from flask._compat import StringIO, text_type
@ -27,7 +27,7 @@ def has_encoding(name):
return False return False
class JSONTestCase(FlaskTestCase): class TestJSON(TestFlask):
def test_json_bad_requests(self): def test_json_bad_requests(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -148,7 +148,7 @@ class JSONTestCase(FlaskTestCase):
rv = c.post('/', data=flask.json.dumps({ rv = c.post('/', data=flask.json.dumps({
'x': {'_foo': 42} 'x': {'_foo': 42}
}), content_type='application/json') }), content_type='application/json')
self.assertEqual(rv.data, b'"<42>"') self.assert_equal(rv.data, b'"<42>"')
def test_modified_url_encoding(self): def test_modified_url_encoding(self):
class ModifiedRequest(flask.Request): class ModifiedRequest(flask.Request):
@ -240,7 +240,7 @@ class JSONTestCase(FlaskTestCase):
except AssertionError: except AssertionError:
self.assert_equal(lines, sorted_by_str) self.assert_equal(lines, sorted_by_str)
class SendfileTestCase(FlaskTestCase): class TestSendfile(TestFlask):
def test_send_file_regular(self): def test_send_file_regular(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -422,7 +422,7 @@ class SendfileTestCase(FlaskTestCase):
rv.close() rv.close()
class LoggingTestCase(FlaskTestCase): class TestLogging(TestFlask):
def test_logger_cache(self): def test_logger_cache(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -450,7 +450,7 @@ class LoggingTestCase(FlaskTestCase):
with catch_stderr() as err: with catch_stderr() as err:
c.get('/') c.get('/')
out = err.getvalue() out = err.getvalue()
self.assert_in('WARNING in helpers [', out) self.assert_in('WARNING in test_helpers [', out)
self.assert_in(os.path.basename(__file__.rsplit('.', 1)[0] + '.py'), out) self.assert_in(os.path.basename(__file__.rsplit('.', 1)[0] + '.py'), out)
self.assert_in('the standard library is dead', out) self.assert_in('the standard library is dead', out)
self.assert_in('this is a debug statement', out) self.assert_in('this is a debug statement', out)
@ -572,7 +572,7 @@ class LoggingTestCase(FlaskTestCase):
'/myview/create') '/myview/create')
class NoImportsTestCase(FlaskTestCase): class TestNoImports(TestFlask):
"""Test Flasks are created without import. """Test Flasks are created without import.
Avoiding ``__import__`` helps create Flask instances where there are errors Avoiding ``__import__`` helps create Flask instances where there are errors
@ -590,7 +590,7 @@ class NoImportsTestCase(FlaskTestCase):
self.fail('Flask(import_name) is importing import_name.') self.fail('Flask(import_name) is importing import_name.')
class StreamingTestCase(FlaskTestCase): class TestStreaming(TestFlask):
def test_streaming_with_context(self): def test_streaming_with_context(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -604,7 +604,7 @@ class StreamingTestCase(FlaskTestCase):
return flask.Response(flask.stream_with_context(generate())) return flask.Response(flask.stream_with_context(generate()))
c = app.test_client() c = app.test_client()
rv = c.get('/?name=World') rv = c.get('/?name=World')
self.assertEqual(rv.data, b'Hello World!') self.assert_equal(rv.data, b'Hello World!')
def test_streaming_with_context_as_decorator(self): def test_streaming_with_context_as_decorator(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -619,7 +619,7 @@ class StreamingTestCase(FlaskTestCase):
return flask.Response(generate()) return flask.Response(generate())
c = app.test_client() c = app.test_client()
rv = c.get('/?name=World') rv = c.get('/?name=World')
self.assertEqual(rv.data, b'Hello World!') self.assert_equal(rv.data, b'Hello World!')
def test_streaming_with_context_and_custom_close(self): def test_streaming_with_context_and_custom_close(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -645,16 +645,16 @@ class StreamingTestCase(FlaskTestCase):
Wrapper(generate()))) Wrapper(generate())))
c = app.test_client() c = app.test_client()
rv = c.get('/?name=World') rv = c.get('/?name=World')
self.assertEqual(rv.data, b'Hello World!') self.assert_equal(rv.data, b'Hello World!')
self.assertEqual(called, [42]) self.assert_equal(called, [42])
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
if flask.json_available: if flask.json_available:
suite.addTest(unittest.makeSuite(JSONTestCase)) suite.addTest(unittest.makeSuite(TestJSON))
suite.addTest(unittest.makeSuite(SendfileTestCase)) suite.addTest(unittest.makeSuite(TestSendfile))
suite.addTest(unittest.makeSuite(LoggingTestCase)) suite.addTest(unittest.makeSuite(TestLogging))
suite.addTest(unittest.makeSuite(NoImportsTestCase)) suite.addTest(unittest.makeSuite(TestNoImports))
suite.addTest(unittest.makeSuite(StreamingTestCase)) suite.addTest(unittest.makeSuite(TestStreaming))
return suite return suite

27
tests/test_regression.py

@ -9,14 +9,15 @@
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
import pytest
import os import os
import gc import gc
import sys import sys
import flask import flask
import threading import threading
import unittest
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from tests import FlaskTestCase from tests import TestFlask
_gc_lock = threading.Lock() _gc_lock = threading.Lock()
@ -51,13 +52,16 @@ class _NoLeakAsserter(object):
gc.enable() gc.enable()
class MemoryTestCase(FlaskTestCase): @pytest.mark.skipif(os.environ.get('RUN_FLASK_MEMORY_TESTS') != '1',
reason='Turned off due to envvar.')
class TestMemory(TestFlask):
def assert_no_leak(self): def assert_no_leak(self):
return _NoLeakAsserter(self) return _NoLeakAsserter(self)
def test_memory_consumption(self): def test_memory_consumption(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@app.route('/') @app.route('/')
def index(): def index():
return flask.render_template('simple_template.html', whiskey=42) return flask.render_template('simple_template.html', whiskey=42)
@ -84,33 +88,28 @@ class MemoryTestCase(FlaskTestCase):
safe_join('/foo', '..') safe_join('/foo', '..')
class ExceptionTestCase(FlaskTestCase): class TestException(TestFlask):
def test_aborting(self): def test_aborting(self):
class Foo(Exception): class Foo(Exception):
whatever = 42 whatever = 42
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.testing = True app.testing = True
@app.errorhandler(Foo) @app.errorhandler(Foo)
def handle_foo(e): def handle_foo(e):
return str(e.whatever) return str(e.whatever)
@app.route('/') @app.route('/')
def index(): def index():
raise flask.abort(flask.redirect(flask.url_for('test'))) raise flask.abort(flask.redirect(flask.url_for('test')))
@app.route('/test') @app.route('/test')
def test(): def test():
raise Foo() raise Foo()
with app.test_client() as c: with app.test_client() as c:
rv = c.get('/') rv = c.get('/')
self.assertEqual(rv.headers['Location'], 'http://localhost/test') self.assert_equal(rv.headers['Location'], 'http://localhost/test')
rv = c.get('/test') rv = c.get('/test')
self.assertEqual(rv.data, b'42') self.assert_equal(rv.data, b'42')
def suite():
suite = unittest.TestSuite()
if os.environ.get('RUN_FLASK_MEMORY_TESTS') == '1':
suite.addTest(unittest.makeSuite(MemoryTestCase))
suite.addTest(unittest.makeSuite(ExceptionTestCase))
return suite

6
tests/test_reqctx.py

@ -15,10 +15,10 @@ try:
from greenlet import greenlet from greenlet import greenlet
except ImportError: except ImportError:
greenlet = None greenlet = None
from tests import FlaskTestCase from tests import TestFlask
class RequestContextTestCase(FlaskTestCase): class TestRequestContext(TestFlask):
def test_teardown_on_pop(self): def test_teardown_on_pop(self):
buffer = [] buffer = []
@ -197,5 +197,5 @@ class RequestContextTestCase(FlaskTestCase):
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(RequestContextTestCase)) suite.addTest(unittest.makeSuite(TestRequestContext))
return suite return suite

6
tests/test_signals.py

@ -11,10 +11,10 @@
import flask import flask
import unittest import unittest
from tests import FlaskTestCase from tests import TestFlask
class SignalsTestCase(FlaskTestCase): class TestSignals(TestFlask):
def test_template_rendered(self): def test_template_rendered(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -149,5 +149,5 @@ class SignalsTestCase(FlaskTestCase):
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
if flask.signals_available: if flask.signals_available:
suite.addTest(unittest.makeSuite(SignalsTestCase)) suite.addTest(unittest.makeSuite(TestSignals))
return suite return suite

6
tests/test_subclassing.py

@ -12,11 +12,11 @@
import flask import flask
import unittest import unittest
from logging import StreamHandler from logging import StreamHandler
from tests import FlaskTestCase from tests import TestFlask
from flask._compat import StringIO from flask._compat import StringIO
class FlaskSubclassingTestCase(FlaskTestCase): class TestFlaskSubclassing(TestFlask):
def test_suppressed_exception_logging(self): def test_suppressed_exception_logging(self):
class SuppressedFlask(flask.Flask): class SuppressedFlask(flask.Flask):
@ -42,5 +42,5 @@ class FlaskSubclassingTestCase(FlaskTestCase):
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(FlaskSubclassingTestCase)) suite.addTest(unittest.makeSuite(TestFlaskSubclassing))
return suite return suite

6
tests/test_templating.py

@ -14,10 +14,10 @@ import unittest
import logging import logging
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from tests import FlaskTestCase from tests import TestFlask
class TemplatingTestCase(FlaskTestCase): class TestTemplating(TestFlask):
def test_context_processing(self): def test_context_processing(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -348,5 +348,5 @@ class TemplatingTestCase(FlaskTestCase):
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TemplatingTestCase)) suite.addTest(unittest.makeSuite(TestTemplating))
return suite return suite

50
tests/test_testing.py

@ -8,14 +8,15 @@
:copyright: (c) 2014 by Armin Ronacher. :copyright: (c) 2014 by Armin Ronacher.
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
import pytest
import flask import flask
import unittest import unittest
from tests import FlaskTestCase from tests import TestFlask
from flask._compat import text_type from flask._compat import text_type
class TestToolsTestCase(FlaskTestCase): class TestTestTools(TestFlask):
def test_environ_defaults_from_config(self): def test_environ_defaults_from_config(self):
app = flask.Flask(__name__) app = flask.Flask(__name__)
@ -212,46 +213,45 @@ class TestToolsTestCase(FlaskTestCase):
self.assert_true('vodka' in flask.request.args) self.assert_true('vodka' in flask.request.args)
class SubdomainTestCase(FlaskTestCase): class TestSubdomain(TestFlask):
def setUp(self): @pytest.fixture
self.app = flask.Flask(__name__) def app(self, request):
self.app.config['SERVER_NAME'] = 'example.com' app = flask.Flask(__name__)
self.client = self.app.test_client() app.config['SERVER_NAME'] = 'example.com'
ctx = app.test_request_context()
ctx.push()
self._ctx = self.app.test_request_context() def teardown():
self._ctx.push() if ctx is not None:
ctx.pop()
request.addfinalizer(teardown)
return app
def tearDown(self): @pytest.fixture
if self._ctx is not None: def client(self, app):
self._ctx.pop() return app.test_client()
def test_subdomain(self): def test_subdomain(self, app, client):
@self.app.route('/', subdomain='<company_id>') @app.route('/', subdomain='<company_id>')
def view(company_id): def view(company_id):
return company_id return company_id
url = flask.url_for('view', company_id='xxx') url = flask.url_for('view', company_id='xxx')
response = self.client.get(url) response = client.get(url)
self.assert_equal(200, response.status_code) self.assert_equal(200, response.status_code)
self.assert_equal(b'xxx', response.data) self.assert_equal(b'xxx', response.data)
def test_nosubdomain(self): def test_nosubdomain(self, app, client):
@self.app.route('/<company_id>') @app.route('/<company_id>')
def view(company_id): def view(company_id):
return company_id return company_id
url = flask.url_for('view', company_id='xxx') url = flask.url_for('view', company_id='xxx')
response = self.client.get(url) response = client.get(url)
self.assert_equal(200, response.status_code) self.assert_equal(200, response.status_code)
self.assert_equal(b'xxx', response.data) self.assert_equal(b'xxx', response.data)
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestToolsTestCase))
suite.addTest(unittest.makeSuite(SubdomainTestCase))
return suite

6
tests/test_views.py

@ -12,10 +12,10 @@
import flask import flask
import flask.views import flask.views
import unittest import unittest
from tests import FlaskTestCase from tests import TestFlask
from werkzeug.http import parse_set_header from werkzeug.http import parse_set_header
class ViewTestCase(FlaskTestCase): class TestView(TestFlask):
def common_test(self, app): def common_test(self, app):
c = app.test_client() c = app.test_client()
@ -165,5 +165,5 @@ class ViewTestCase(FlaskTestCase):
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ViewTestCase)) suite.addTest(unittest.makeSuite(TestView))
return suite return suite

Loading…
Cancel
Save