diff --git a/flask.py b/flask.py index f7ba22e6..18e4a446 100644 --- a/flask.py +++ b/flask.py @@ -68,6 +68,22 @@ class _RequestGlobals(object): pass +class _NullSession(SecureCookie): + """Class used to generate nicer error messages if sessions are not + available. Will still allow read-only access to the empty session + but fail on setting. + """ + + def _fail(self, *args, **kwargs): + raise RuntimeError('the session is unavailable because no secret ' + 'key was set. Set the secret_key on the ' + 'application to something unique and secret') + __setitem__ = __delitem__ = clear = pop = popitem = \ + update = setdefault = _fail + del _fail + + + class _RequestContext(object): """The request context contains all request relevant information. It is created at the beginning of the request and pushed to the @@ -80,6 +96,8 @@ class _RequestContext(object): self.url_adapter = app.url_map.bind_to_environ(environ) self.request = app.request_class(environ) self.session = app.open_session(self.request) + if self.session is None: + self.session = _NullSession() self.g = _RequestGlobals() self.flashes = None @@ -384,8 +402,7 @@ class Flask(object): object) :param response: an instance of :attr:`response_class` """ - if session is not None: - session.save_cookie(response, self.session_cookie_name) + session.save_cookie(response, self.session_cookie_name) def add_url_rule(self, rule, endpoint, **options): """Connects a URL rule. Works exactly like the :meth:`route` @@ -603,7 +620,7 @@ class Flask(object): instance of :attr:`response_class`. """ session = _request_ctx_stack.top.session - if session is not None: + if not isinstance(session, _NullSession): self.save_session(session, response) for handler in self.after_request_funcs: response = handler(response) diff --git a/tests/flask_tests.py b/tests/flask_tests.py index 0d73c954..b9edd366 100644 --- a/tests/flask_tests.py +++ b/tests/flask_tests.py @@ -72,6 +72,20 @@ class BasicFunctionality(unittest.TestCase): assert c.post('/set', data={'value': '42'}).data == 'value set' assert c.get('/get').data == '42' + def test_missing_session(self): + app = flask.Flask(__name__) + def expect_exception(f, *args, **kwargs): + try: + f(*args, **kwargs) + except RuntimeError, e: + assert e.args and 'session is unavailable' in e.args[0] + else: + assert False, 'expected exception' + with app.test_request_context(): + assert flask.session.get('missing_key') is None + expect_exception(flask.session.__setitem__, 'foo', 42) + expect_exception(flask.session.pop, 'foo') + def test_request_processing(self): app = flask.Flask(__name__) evts = []