diff --git a/flasky.py b/flasky.py index f5fb0d2..91b0a78 100644 --- a/flasky.py +++ b/flasky.py @@ -1,10 +1,56 @@ -from flask import Flask, make_response, request +from flask import Flask, make_response, request, current_app from simplejson import dumps from pymongo import MongoClient import datetime import dateutil.parser import bson +from datetime import timedelta +from functools import update_wrapper + + +def crossdomain(origin=None, methods=None, headers=None, + max_age=21600, attach_to_all=True, + automatic_options=True): + if methods is not None: + methods = ', '.join(sorted(x.upper() for x in methods)) + if headers is not None and not isinstance(headers, basestring): + headers = ', '.join(x.upper() for x in headers) + if not isinstance(origin, basestring): + origin = ', '.join(origin) + if isinstance(max_age, timedelta): + max_age = max_age.total_seconds() + + def get_methods(): + if methods is not None: + return methods + + options_resp = current_app.make_default_options_response() + return options_resp.headers['allow'] + + def decorator(f): + def wrapped_function(*args, **kwargs): + if automatic_options and request.method == 'OPTIONS': + resp = current_app.make_default_options_response() + else: + resp = make_response(f(*args, **kwargs)) + if not attach_to_all and request.method != 'OPTIONS': + return resp + + h = resp.headers + + h['Access-Control-Allow-Origin'] = origin + h['Access-Control-Allow-Methods'] = get_methods() + h['Access-Control-Max-Age'] = str(max_age) + if headers is not None: + h['Access-Control-Allow-Headers'] = headers + return resp + + f.provide_automatic_options = False + return update_wrapper(wrapped_function, f) + return decorator + + app = Flask(__name__) client = MongoClient(**{'host': 'localhost', 'port': 27017}) db = client.showtimes @@ -16,11 +62,13 @@ miscObjHandler = lambda obj: ( @app.route('/flask/') +@crossdomain(origin='*') def hello_world(): return 'This comes from Flask ^_^' @app.route('/groups/', methods=['GET']) +@crossdomain(origin='*') def groups(): known_groups = ['sf', 'major'] r = make_response(dumps(known_groups)) @@ -30,6 +78,7 @@ def groups(): @app.route('/theaters/', methods=['GET']) @app.route('/theaters//', methods=['GET']) +@crossdomain(origin='*') def list_theaters(group=None): if not group: result = db.theater.find() @@ -43,6 +92,7 @@ def list_theaters(group=None): @app.route('/showtimes//', methods=['GET']) @app.route('/showtimes///', methods=['GET']) +@crossdomain(origin='*') def list_showtimes(group=None, code=None): day = request.args.get('d', '') q = {}