|
|
|
@ -43,20 +43,25 @@ def fix_from_imports(red):
|
|
|
|
|
Case 2: from flask.ext import foo --> import flask_foo as foo |
|
|
|
|
""" |
|
|
|
|
from_imports = red.find_all("FromImport") |
|
|
|
|
for x in range(len(from_imports)): |
|
|
|
|
values = from_imports[x].value |
|
|
|
|
for x, node in enumerate(from_imports): |
|
|
|
|
values = node.value |
|
|
|
|
if (values[0].value == 'flask') and (values[1].value == 'ext'): |
|
|
|
|
# Case 1 |
|
|
|
|
if len(from_imports[x].value) == 3: |
|
|
|
|
if len(node.value) == 3: |
|
|
|
|
package = values[2].value |
|
|
|
|
modules = from_imports[x].modules() |
|
|
|
|
modules = node.modules() |
|
|
|
|
if len(modules) > 1: |
|
|
|
|
r = "{}," * len(modules) |
|
|
|
|
from_imports[x].replace("from flask_%s import %s" |
|
|
|
|
node.replace("from flask_%s import %s" |
|
|
|
|
% (package, r.format(*modules)[:-1])) |
|
|
|
|
else: |
|
|
|
|
name = node.names()[0] |
|
|
|
|
node.replace("from flask_%s import %s as %s" |
|
|
|
|
% (package, modules.pop(), name)) |
|
|
|
|
# Case 2 |
|
|
|
|
else: |
|
|
|
|
module = from_imports[x].modules()[0] |
|
|
|
|
from_imports[x].replace("import flask_%s as %s" |
|
|
|
|
module = node.modules()[0] |
|
|
|
|
node.replace("import flask_%s as %s" |
|
|
|
|
% (module, module)) |
|
|
|
|
return red |
|
|
|
|
|
|
|
|
@ -70,13 +75,13 @@ def fix_standard_imports(red):
|
|
|
|
|
original import statement. |
|
|
|
|
""" |
|
|
|
|
imports = red.find_all("ImportNode") |
|
|
|
|
for x in range(len(imports)): |
|
|
|
|
values = imports[x].value |
|
|
|
|
for x, node in enumerate(imports): |
|
|
|
|
try: |
|
|
|
|
if (values[x].value[0].value == 'flask' and |
|
|
|
|
values[x].value[1].value == 'ext'): |
|
|
|
|
package = values[x].value[2].value |
|
|
|
|
imports[x].replace("import flask_%s" % package) |
|
|
|
|
if (node.value[0].value == 'flask' and |
|
|
|
|
node.value[1].value == 'ext'): |
|
|
|
|
package = node.value[2].value |
|
|
|
|
name = node.names()[0] |
|
|
|
|
imports[x].replace("import flask_%s as %s" % (package, name)) |
|
|
|
|
except IndexError: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
@ -88,6 +93,8 @@ def fix(ast):
|
|
|
|
|
return fix_imports(ast).dumps() |
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
if len(sys.argv) < 2: |
|
|
|
|
sys.exit("No filename was included, please try again.") |
|
|
|
|
input_file = sys.argv[1] |
|
|
|
|
ast = read_source(input_file) |
|
|
|
|
ast = fix_imports(ast) |
|
|
|
|