diff --git a/flask.py b/flask.py index e569aaaa..94072ce9 100644 --- a/flask.py +++ b/flask.py @@ -13,6 +13,7 @@ from __future__ import with_statement import os import sys +from itertools import chain from jinja2 import Environment, PackageLoader, FileSystemLoader from werkzeug import Request as RequestBase, Response as ResponseBase, \ LocalStack, LocalProxy, create_environ, SharedDataMiddleware, \ @@ -58,6 +59,12 @@ class Request(RequestBase): endpoint = view_args = routing_exception = None + @property + def module(self): + """The name of the current module""" + if self.endpoint and '.' in self.endpoint: + return self.endpoint.rsplit('.', 1)[0] + @cached_property def json(self): """If the mimetype is `application/json` this will contain the @@ -141,11 +148,13 @@ def url_for(endpoint, **values): :param values: the variable arguments of the URL rule """ ctx = _request_ctx_stack.top - if '.' not in endpoint and \ - ctx.request.endpoint is not None \ - and '.' in ctx.request.endpoint: - endpoint = ctx.request.endpoint.rsplit('.', 1)[0] + '.' + endpoint - return ctx.url_adapter.build(endpoint.lstrip('.'), values) + if '.' not in endpoint: + mod = ctx.request.module + if mod is not None: + endpoint = mod + '.' + endpoint + elif endpoint.startswith('.'): + endpoint = endpoint[1:] + return ctx.url_adapter.build(endpoint, values) def get_template_attribute(template_name, attribute): @@ -311,6 +320,16 @@ class Module(object): def add_url_rule(self, rule, endpoint, view_func=None, **options): self._record(self._register_rule, (rule, endpoint, view_func, options)) + def before_request(self, f): + self._record(lambda s: s.app.before_request_funcs + .setdefault(self.name, []).append(f), ()) + return f + + def after_request(self, f): + self._record(lambda s: s.app.after_request_funcs + .setdefault(self.name, []).append(f), ()) + return f + def _record(self, func, args): self._register_events.append((func, args)) @@ -402,14 +421,14 @@ class Flask(object): #: getting hold of the currently logged in user. #: To register a function here, use the :meth:`before_request` #: decorator. - self.before_request_funcs = [] + self.before_request_funcs = {} #: a list of functions that are called at the end of the #: request. The function is passed the current response #: object and modify it in place or replace it. #: To register a function here use the :meth:`after_request` #: decorator. - self.after_request_funcs = [] + self.after_request_funcs = {} #: a list of functions that are called without arguments #: to populate the template context. Each returns a dictionary @@ -698,12 +717,12 @@ class Flask(object): def before_request(self, f): """Registers a function to run before each request.""" - self.before_request_funcs.append(f) + self.before_request_funcs.setdefault(None, []).append(f) return f def after_request(self, f): """Register a function to be run after each request.""" - self.after_request_funcs.append(f) + self.after_request_funcs.setdefault(None, []).append(f) return f def context_processor(self, f): @@ -768,7 +787,11 @@ class Flask(object): if it was the return value from the view and further request handling is stopped. """ - for func in self.before_request_funcs: + funcs = self.before_request_funcs.get(None, ()) + mod = request.module + if mod and mod in self.before_request_funcs: + funcs = chain(funcs, self.before_request_funcs[mod]) + for func in funcs: rv = func() if rv is not None: return rv @@ -782,10 +805,14 @@ class Flask(object): :return: a new response object or the same, has to be an instance of :attr:`response_class`. """ - session = _request_ctx_stack.top.session - if not isinstance(session, _NullSession): - self.save_session(session, response) - for handler in self.after_request_funcs: + ctx = _request_ctx_stack.top + mod = ctx.request.module + if not isinstance(ctx.session, _NullSession): + self.save_session(ctx.session, response) + funcs = self.after_request_funcs.get(None, ()) + if mod and mod in self.after_request_funcs: + funcs = chain(funcs, self.after_request_funcs[mod]) + for handler in funcs: response = handler(response) return response diff --git a/tests/flask_tests.py b/tests/flask_tests.py index 5f07fbe8..2f372514 100644 --- a/tests/flask_tests.py +++ b/tests/flask_tests.py @@ -298,6 +298,31 @@ class TemplatingTestCase(unittest.TestCase): assert macro('World') == 'Hello World!' +class ModuleTestCase(unittest.TestCase): + + def test_basic_module(self): + app = flask.Flask(__name__) + admin = flask.Module('admin', url_prefix='/admin') + @admin.route('/') + def index(): + return 'admin index' + @admin.route('/login') + def login(): + return 'admin login' + @admin.route('/logout') + def logout(): + return 'admin logout' + @app.route('/') + def index(): + return 'the index' + app.register_module('admin', admin) + c = app.test_client() + assert c.get('/').data == 'the index' + assert c.get('/admin/').data == 'admin index' + assert c.get('/admin/login').data == 'admin login' + assert c.get('/admin/logout').data == 'admin logout' + + def suite(): from minitwit_tests import MiniTwitTestCase from flaskr_tests import FlaskrTestCase