Browse Source

Add test and logic for attribute access calls

pull/1342/head
Keyan Pishdadian 10 years ago
parent
commit
cd6ec40947
  1. 24
      scripts/flaskext_migrate.py
  2. 10
      scripts/test_import_migration.py

24
scripts/flaskext_migrate.py

@ -76,9 +76,12 @@ def fix_standard_imports(red):
try:
if (node.value[0].value[0].value == 'flask' and
node.value[0].value[1].value == 'ext'):
package = node.value[0].value[2]
package = node.value[0].value[2].value
name = node.names()[0].split('.')[-1]
node.replace("import flask_%s as %s" % (package, name))
if name == package:
node.replace("import flask_%s" % (package))
else:
node.replace("import flask_%s as %s" % (package, name))
except IndexError:
pass
@ -113,13 +116,28 @@ def fix_function_calls(red):
try:
if (node.value[0].value == 'flask' and
node.value[1].value == 'ext'):
node.replace("flask_%s%s" % (node.value[3], node.value[3]))
params = _form_function_call(node)
node.replace("flask_%s%s" % (node.value[2], params))
except IndexError:
pass
return red
def _form_function_call(node):
"""
Reconstructs function call strings when making attribute access calls.
"""
node_vals = node.value
output = "."
for x, param in enumerate(node_vals[3::]):
if param.dumps()[0] == "(":
output = output[0:-1] + param.dumps()
return output
else:
output += param.dumps() + "."
def check_user_input():
"""Exits and gives error message if no argument is passed in the shell."""
if len(sys.argv) < 2:

10
scripts/test_import_migration.py

@ -36,7 +36,7 @@ def test_multiline_import():
def test_module_import():
red = RedBaron("import flask.ext.foo")
output = migrate.fix_tester(red)
assert output == "import flask_foo as foo"
assert output == "import flask_foo"
def test_named_module_import():
@ -61,3 +61,11 @@ def test_function_call_migration():
red = RedBaron("flask.ext.foo(var)")
output = migrate.fix_tester(red)
assert output == "flask_foo(var)"
def test_nested_function_call_migration():
red = RedBaron("import flask.ext.foo\n\n"
"flask.ext.foo.bar(var)")
output = migrate.fix_tester(red)
assert output == ("import flask_foo\n\n"
"flask_foo.bar(var)")

Loading…
Cancel
Save