# Script which modifies source code away from the deprecated "flask.ext" # format. # # Run in the terminal by typing: `python flaskext_migrate.py ` # # Author: Keyan Pishdadian 2015 from redbaron import RedBaron import sys def read_source(input_file): """Parses the input_file into a RedBaron FST.""" with open(input_file, "r") as source_code: red = RedBaron(source_code.read()) return red def write_source(red, input_file): """Overwrites the input_file once the FST has been modified.""" with open(input_file, "w") as source_code: source_code.write(red.dumps()) def fix_imports(red): """Wrapper which fixes "from" style imports and then "import" style.""" red = fix_standard_imports(red) red = fix_from_imports(red) return red def fix_from_imports(red): """ Converts "from" style imports to not use "flask.ext". Handles: Case 1: from flask.ext.foo import bam --> from flask_foo import bam Case 2: from flask.ext import foo --> import flask_foo as foo """ from_imports = red.find_all("FromImport") for x, node in enumerate(from_imports): values = node.value if (values[0].value == 'flask') and (values[1].value == 'ext'): # Case 1 if len(node.value) == 3: package = values[2].value modules = node.modules() module_string = _get_modules(modules) if len(modules) > 1: node.replace("from flask_%s import %s" % (package, module_string)) else: name = node.names()[0] node.replace("from flask_%s import %s as %s" % (package, module_string, name)) # Case 2 else: module = node.modules()[0] node.replace("import flask_%s as %s" % (module, module)) return red def fix_standard_imports(red): """ Handles import modification in the form: import flask.ext.foo" --> import flask_foo """ imports = red.find_all("ImportNode") for x, node in enumerate(imports): try: if (node.value[0].value[0].value == 'flask' and node.value[0].value[1].value == 'ext'): package = node.value[0].value[2].value name = node.names()[0].split('.')[-1] if name == package: node.replace("import flask_%s" % (package)) else: node.replace("import flask_%s as %s" % (package, name)) except IndexError: pass return red def _get_modules(module): """ Takes a list of modules and converts into a string. The module list can include parens, this function checks each element in the list, if there is a paren then it does not add a comma before the next element. Otherwise a comma and space is added. This is to preserve module imports which are multi-line and/or occur within parens. While also not affecting imports which are not enclosed. """ modules_string = [cur + ', ' if cur.isalnum() and next.isalnum() else cur for (cur, next) in zip(module, module[1:]+[''])] return ''.join(modules_string) def fix_function_calls(red): """ Modifies function calls in the source to reflect import changes. Searches the AST for AtomtrailerNodes and replaces them. """ atoms = red.find_all("Atomtrailers") for x, node in enumerate(atoms): try: if (node.value[0].value == 'flask' and node.value[1].value == 'ext'): params = _form_function_call(node) node.replace("flask_%s%s" % (node.value[2], params)) except IndexError: pass return red def _form_function_call(node): """ Reconstructs function call strings when making attribute access calls. """ node_vals = node.value output = "." for x, param in enumerate(node_vals[3::]): if param.dumps()[0] == "(": output = output[0:-1] + param.dumps() return output else: output += param.dumps() + "." def check_user_input(): """Exits and gives error message if no argument is passed in the shell.""" if len(sys.argv) < 2: sys.exit("No filename was included, please try again.") def fix_tester(ast): """Wrapper which allows for testing when not running from shell.""" ast = fix_imports(ast) ast = fix_function_calls(ast) return ast.dumps() def fix(): """Wrapper for user argument checking and import fixing.""" check_user_input() input_file = sys.argv[1] ast = read_source(input_file) ast = fix_imports(ast) ast = fix_function_calls(ast) write_source(ast, input_file) if __name__ == "__main__": fix()