diff --git a/flask.py b/flask.py index 10feaf22..6d557b26 100644 --- a/flask.py +++ b/flask.py @@ -12,6 +12,7 @@ from __future__ import with_statement import os import sys +import types from jinja2 import Environment, PackageLoader, FileSystemLoader from werkzeug import Request as RequestBase, Response as ResponseBase, \ @@ -639,6 +640,27 @@ class Flask(object): return f return decorator + def template_filter(self, arg=None): + """A decorator that is used to register custom template filter. + You can specify a name for the filter, otherwise the function + name will be used. Example:: + + @app.template_filter + def reverse(s): + return s[::-1] + + :param name: the optional name of the filter, otherwise the + function name will be used. + """ + if type(arg) is types.FunctionType: + self.jinja_env.filters[arg.__name__] = arg + return arg + + def decorator(f): + self.jinja_env.filters[arg or f.__name__] = f + return f + return decorator + def before_request(self, f): """Registers a function to run before each request.""" self.before_request_funcs.append(f) diff --git a/tests/flask_tests.py b/tests/flask_tests.py index 917f4168..91edb9c2 100644 --- a/tests/flask_tests.py +++ b/tests/flask_tests.py @@ -311,6 +311,32 @@ class TemplatingTestCase(unittest.TestCase): macro = flask.get_template_attribute('_macro.html', 'hello') assert macro('World') == 'Hello World!' + def test_template_filter_not_called(self): + app = flask.Flask(__name__) + @app.template_filter + def my_reverse(s): + return s[::-1] + assert 'my_reverse' in app.jinja_env.filters.keys() + assert app.jinja_env.filters['my_reverse'] == my_reverse + assert app.jinja_env.filters['my_reverse']('abcd') == 'dcba' + + def test_template_filter_called(self): + app = flask.Flask(__name__) + @app.template_filter() + def my_reverse(s): + return s[::-1] + assert 'my_reverse' in app.jinja_env.filters.keys() + assert app.jinja_env.filters['my_reverse'] == my_reverse + assert app.jinja_env.filters['my_reverse']('abcd') == 'dcba' + + def test_template_filter_with_name(self): + app = flask.Flask(__name__) + @app.template_filter('strrev') + def my_reverse(s): + return s[::-1] + assert 'strrev' in app.jinja_env.filters.keys() + assert app.jinja_env.filters['strrev'] == my_reverse + assert app.jinja_env.filters['strrev']('abcd') == 'dcba' def suite(): from minitwit_tests import MiniTwitTestCase