Browse Source

Basic module support is working, but does not look very nice.

pull/1638/head
Armin Ronacher 15 years ago
parent
commit
e0148a00c0
  1. 55
      flask.py
  2. 25
      tests/flask_tests.py

55
flask.py

@ -13,6 +13,7 @@ from __future__ import with_statement
import os import os
import sys import sys
from itertools import chain
from jinja2 import Environment, PackageLoader, FileSystemLoader from jinja2 import Environment, PackageLoader, FileSystemLoader
from werkzeug import Request as RequestBase, Response as ResponseBase, \ from werkzeug import Request as RequestBase, Response as ResponseBase, \
LocalStack, LocalProxy, create_environ, SharedDataMiddleware, \ LocalStack, LocalProxy, create_environ, SharedDataMiddleware, \
@ -58,6 +59,12 @@ class Request(RequestBase):
endpoint = view_args = routing_exception = None endpoint = view_args = routing_exception = None
@property
def module(self):
"""The name of the current module"""
if self.endpoint and '.' in self.endpoint:
return self.endpoint.rsplit('.', 1)[0]
@cached_property @cached_property
def json(self): def json(self):
"""If the mimetype is `application/json` this will contain the """If the mimetype is `application/json` this will contain the
@ -141,11 +148,13 @@ def url_for(endpoint, **values):
:param values: the variable arguments of the URL rule :param values: the variable arguments of the URL rule
""" """
ctx = _request_ctx_stack.top ctx = _request_ctx_stack.top
if '.' not in endpoint and \ if '.' not in endpoint:
ctx.request.endpoint is not None \ mod = ctx.request.module
and '.' in ctx.request.endpoint: if mod is not None:
endpoint = ctx.request.endpoint.rsplit('.', 1)[0] + '.' + endpoint endpoint = mod + '.' + endpoint
return ctx.url_adapter.build(endpoint.lstrip('.'), values) elif endpoint.startswith('.'):
endpoint = endpoint[1:]
return ctx.url_adapter.build(endpoint, values)
def get_template_attribute(template_name, attribute): def get_template_attribute(template_name, attribute):
@ -311,6 +320,16 @@ class Module(object):
def add_url_rule(self, rule, endpoint, view_func=None, **options): def add_url_rule(self, rule, endpoint, view_func=None, **options):
self._record(self._register_rule, (rule, endpoint, view_func, options)) self._record(self._register_rule, (rule, endpoint, view_func, options))
def before_request(self, f):
self._record(lambda s: s.app.before_request_funcs
.setdefault(self.name, []).append(f), ())
return f
def after_request(self, f):
self._record(lambda s: s.app.after_request_funcs
.setdefault(self.name, []).append(f), ())
return f
def _record(self, func, args): def _record(self, func, args):
self._register_events.append((func, args)) self._register_events.append((func, args))
@ -402,14 +421,14 @@ class Flask(object):
#: getting hold of the currently logged in user. #: getting hold of the currently logged in user.
#: To register a function here, use the :meth:`before_request` #: To register a function here, use the :meth:`before_request`
#: decorator. #: decorator.
self.before_request_funcs = [] self.before_request_funcs = {}
#: a list of functions that are called at the end of the #: a list of functions that are called at the end of the
#: request. The function is passed the current response #: request. The function is passed the current response
#: object and modify it in place or replace it. #: object and modify it in place or replace it.
#: To register a function here use the :meth:`after_request` #: To register a function here use the :meth:`after_request`
#: decorator. #: decorator.
self.after_request_funcs = [] self.after_request_funcs = {}
#: a list of functions that are called without arguments #: a list of functions that are called without arguments
#: to populate the template context. Each returns a dictionary #: to populate the template context. Each returns a dictionary
@ -698,12 +717,12 @@ class Flask(object):
def before_request(self, f): def before_request(self, f):
"""Registers a function to run before each request.""" """Registers a function to run before each request."""
self.before_request_funcs.append(f) self.before_request_funcs.setdefault(None, []).append(f)
return f return f
def after_request(self, f): def after_request(self, f):
"""Register a function to be run after each request.""" """Register a function to be run after each request."""
self.after_request_funcs.append(f) self.after_request_funcs.setdefault(None, []).append(f)
return f return f
def context_processor(self, f): def context_processor(self, f):
@ -768,7 +787,11 @@ class Flask(object):
if it was the return value from the view and further if it was the return value from the view and further
request handling is stopped. request handling is stopped.
""" """
for func in self.before_request_funcs: funcs = self.before_request_funcs.get(None, ())
mod = request.module
if mod and mod in self.before_request_funcs:
funcs = chain(funcs, self.before_request_funcs[mod])
for func in funcs:
rv = func() rv = func()
if rv is not None: if rv is not None:
return rv return rv
@ -782,10 +805,14 @@ class Flask(object):
:return: a new response object or the same, has to be an :return: a new response object or the same, has to be an
instance of :attr:`response_class`. instance of :attr:`response_class`.
""" """
session = _request_ctx_stack.top.session ctx = _request_ctx_stack.top
if not isinstance(session, _NullSession): mod = ctx.request.module
self.save_session(session, response) if not isinstance(ctx.session, _NullSession):
for handler in self.after_request_funcs: self.save_session(ctx.session, response)
funcs = self.after_request_funcs.get(None, ())
if mod and mod in self.after_request_funcs:
funcs = chain(funcs, self.after_request_funcs[mod])
for handler in funcs:
response = handler(response) response = handler(response)
return response return response

25
tests/flask_tests.py

@ -298,6 +298,31 @@ class TemplatingTestCase(unittest.TestCase):
assert macro('World') == 'Hello World!' assert macro('World') == 'Hello World!'
class ModuleTestCase(unittest.TestCase):
def test_basic_module(self):
app = flask.Flask(__name__)
admin = flask.Module('admin', url_prefix='/admin')
@admin.route('/')
def index():
return 'admin index'
@admin.route('/login')
def login():
return 'admin login'
@admin.route('/logout')
def logout():
return 'admin logout'
@app.route('/')
def index():
return 'the index'
app.register_module('admin', admin)
c = app.test_client()
assert c.get('/').data == 'the index'
assert c.get('/admin/').data == 'admin index'
assert c.get('/admin/login').data == 'admin login'
assert c.get('/admin/logout').data == 'admin logout'
def suite(): def suite():
from minitwit_tests import MiniTwitTestCase from minitwit_tests import MiniTwitTestCase
from flaskr_tests import FlaskrTestCase from flaskr_tests import FlaskrTestCase

Loading…
Cancel
Save