Browse Source

Add a test and cover edge case with parens

pull/1342/head
Keyan Pishdadian 10 years ago
parent
commit
9cbe83ef0d
  1. 35
      scripts/flaskext_migrate.py
  2. 51
      scripts/test_import_migration.py

35
scripts/flaskext_migrate.py

@ -50,14 +50,14 @@ def fix_from_imports(red):
if len(node.value) == 3: if len(node.value) == 3:
package = values[2].value package = values[2].value
modules = node.modules() modules = node.modules()
module_string = _get_modules(modules)
if len(modules) > 1: if len(modules) > 1:
r = "{}," * len(modules)
node.replace("from flask_%s import %s" node.replace("from flask_%s import %s"
% (package, r.format(*modules)[:-1])) % (package, module_string))
else: else:
name = node.names()[0] name = node.names()[0]
node.replace("from flask_%s import %s as %s" node.replace("from flask_%s import %s as %s"
% (package, modules.pop(), name)) % (package, module_string, name))
# Case 2 # Case 2
else: else:
module = node.modules()[0] module = node.modules()[0]
@ -88,13 +88,36 @@ def fix_standard_imports(red):
return red return red
def _get_modules(module):
"""
Takes a list of modules and converts into a string
The module list can include parens, this function checks each element in
the list, if there is a paren then it does not add a comma before the next
element. Otherwise a comma and space is added. This is to preserve module
imports which are multi-line and/or occur within parens. While also not
affecting imports which are not enclosed.
"""
modules_string = [cur + ', ' if cur.isalnum() and next.isalnum()
else cur
for (cur, next) in zip(module, module[1:]+[''])]
return ''.join(modules_string)
def check_user_input():
"""Exits and gives error message if no argument is passed in the shell."""
if len(sys.argv) < 2:
sys.exit("No filename was included, please try again.")
def fix(ast): def fix(ast):
"""Wrapper which allows for testing when not running from shell""" """Wrapper which allows for testing when not running from shell."""
return fix_imports(ast).dumps() return fix_imports(ast).dumps()
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) < 2: check_user_input()
sys.exit("No filename was included, please try again.")
input_file = sys.argv[1] input_file = sys.argv[1]
ast = read_source(input_file) ast = read_source(input_file)
ast = fix_imports(ast) ast = fix_imports(ast)

51
scripts/test_import_migration.py

@ -0,0 +1,51 @@
# Tester for the flaskext_migrate.py module located in flask/scripts/
#
# Author: Keyan Pishdadian
import pytest
from redbaron import RedBaron
import flaskext_migrate as migrate
def test_simple_from_import():
red = RedBaron("from flask.ext import foo")
output = migrate.fix(red)
assert output == "import flask_foo as foo"
def test_from_to_from_import():
red = RedBaron("from flask.ext.foo import bar")
output = migrate.fix(red)
assert output == "from flask_foo import bar as bar"
def test_multiple_import():
red = RedBaron("from flask.ext.foo import bar, foobar, something")
output = migrate.fix(red)
assert output == "from flask_foo import bar, foobar, something"
def test_multiline_import():
red = RedBaron("from flask.ext.foo import \
bar,\
foobar,\
something")
output = migrate.fix(red)
assert output == "from flask_foo import bar, foobar, something"
def test_module_import():
red = RedBaron("import flask.ext.foo")
output = migrate.fix(red)
assert output == "import flask_foo"
def test_module_import():
red = RedBaron("from flask.ext.foo import bar as baz")
output = migrate.fix(red)
assert output == "from flask_foo import bar as baz"
def test_parens_import():
red = RedBaron("from flask.ext.foo import (bar, foo, foobar)")
output = migrate.fix(red)
assert output == "from flask_foo import (bar, foo, foobar)"
Loading…
Cancel
Save