Browse Source

Add test for naming module and fix logic to cover

pull/1342/head
Keyan Pishdadian 10 years ago
parent
commit
b759aa2b95
  1. 39
      scripts/flaskext_migrate.py
  2. 11
      tests/test_import_migration.py

39
scripts/flaskext_migrate.py

@ -43,21 +43,26 @@ def fix_from_imports(red):
Case 2: from flask.ext import foo --> import flask_foo as foo Case 2: from flask.ext import foo --> import flask_foo as foo
""" """
from_imports = red.find_all("FromImport") from_imports = red.find_all("FromImport")
for x in range(len(from_imports)): for x, node in enumerate(from_imports):
values = from_imports[x].value values = node.value
if (values[0].value == 'flask') and (values[1].value == 'ext'): if (values[0].value == 'flask') and (values[1].value == 'ext'):
# Case 1 # Case 1
if len(from_imports[x].value) == 3: if len(node.value) == 3:
package = values[2].value package = values[2].value
modules = from_imports[x].modules() modules = node.modules()
r = "{}," * len(modules) if len(modules) > 1:
from_imports[x].replace("from flask_%s import %s" r = "{}," * len(modules)
% (package, r.format(*modules)[:-1])) node.replace("from flask_%s import %s"
% (package, r.format(*modules)[:-1]))
else:
name = node.names()[0]
node.replace("from flask_%s import %s as %s"
% (package, modules.pop(), name))
# Case 2 # Case 2
else: else:
module = from_imports[x].modules()[0] module = node.modules()[0]
from_imports[x].replace("import flask_%s as %s" node.replace("import flask_%s as %s"
% (module, module)) % (module, module))
return red return red
@ -70,13 +75,13 @@ def fix_standard_imports(red):
original import statement. original import statement.
""" """
imports = red.find_all("ImportNode") imports = red.find_all("ImportNode")
for x in range(len(imports)): for x, node in enumerate(imports):
values = imports[x].value
try: try:
if (values[x].value[0].value == 'flask' and if (node.value[0].value == 'flask' and
values[x].value[1].value == 'ext'): node.value[1].value == 'ext'):
package = values[x].value[2].value package = node.value[2].value
imports[x].replace("import flask_%s" % package) name = node.names()[0]
imports[x].replace("import flask_%s as %s" % (package, name))
except IndexError: except IndexError:
pass pass
@ -88,6 +93,8 @@ def fix(ast):
return fix_imports(ast).dumps() return fix_imports(ast).dumps()
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) < 2:
sys.exit("No filename was included, please try again.")
input_file = sys.argv[1] input_file = sys.argv[1]
ast = read_source(input_file) ast = read_source(input_file)
ast = fix_imports(ast) ast = fix_imports(ast)

11
tests/test_import_migration.py

@ -1,7 +1,8 @@
# Tester for the flaskext_migrate.py module located in flask/scripts/ # Tester for the flaskext_migrate.py module located in flask/scripts/
# #
# Author: Keyan Pishdadian # Author: Keyan Pishdadian
import sys
sys.path.append('scripts')
import pytest import pytest
from redbaron import RedBaron from redbaron import RedBaron
import flaskext_migrate as migrate import flaskext_migrate as migrate
@ -16,7 +17,7 @@ def test_simple_from_import():
def test_from_to_from_import(): def test_from_to_from_import():
red = RedBaron("from flask.ext.foo import bar") red = RedBaron("from flask.ext.foo import bar")
output = migrate.fix(red) output = migrate.fix(red)
assert output == "from flask_foo import bar" assert output == "from flask_foo import bar as bar"
def test_multiple_import(): def test_multiple_import():
@ -38,3 +39,9 @@ def test_module_import():
red = RedBaron("import flask.ext.foo") red = RedBaron("import flask.ext.foo")
output = migrate.fix(red) output = migrate.fix(red)
assert output == "import flask_foo" assert output == "import flask_foo"
def test_module_import():
red = RedBaron("from flask.ext.foo import bar as baz")
output = migrate.fix(red)
assert output == "from flask_foo import bar as baz"

Loading…
Cancel
Save