Skip to content

Instantly share code, notes, and snippets.

@internetimagery
Last active July 8, 2022 08:35
Show Gist options
  • Save internetimagery/e56e60bb59ba6a0aa8654919122bebee to your computer and use it in GitHub Desktop.
Save internetimagery/e56e60bb59ba6a0aa8654919122bebee to your computer and use it in GitHub Desktop.
Localize imports in batch, so that modules are imported only when they are needed. Useful when trying to improve import times.
# Permission to use, copy, modify, and/or distribute this software for any purpose with or without
# fee is hereby granted.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO
# THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
# AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
# OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# Example:
# python /path/to/local_import.py --write /path/to/source.py
import os
import ast
import token
import asttokens
from fnmatch import fnmatch
from functools import partial
from contextlib import closing
from multiprocessing import Pool
def main(files, include="", exclude="", write=False, processes=3):
    """Move imports into the functions that use them, for every file in ``files``.

    Args:
        files: file and/or directory paths. Directories are searched
            recursively for ``*.py`` files.
        include: optional fnmatch pattern; only matching imports may move.
        exclude: optional fnmatch pattern; matching imports never move.
        write: when true, overwrite each changed file in place; otherwise
            print the rewritten source between ``>>>`` / ``<<<`` markers.
        processes: number of worker processes used to rewrite files.
    """
    # Module-level functions wrapped in partial stay picklable for the pool.
    filter_ = partial(_filter_import, include, exclude)
    generate = partial(_generate_new_source, filter_)
    with closing(Pool(processes)) as pool:
        for path, new_source in pool.imap(generate, _expand_paths(files)):
            if not path:
                # File was unchanged or could not be processed.
                continue
            # Use the parameter, not the CLI-only global ``args``, so main()
            # also works when imported and called as a function.
            if write:
                print("Writing:", path)
                with open(path, "w") as h:
                    h.write(new_source)
            else:
                print(">>>", path)
                print(new_source)
                print("<<<")
def _generate_new_source(filter_, filepath):
    """Compute the rewritten source for one file.

    Args:
        filter_: callable ``(import_node, name) -> bool`` deciding which
            imports may be moved.
        filepath: path of the Python source file to rewrite.

    Returns:
        ``(filepath, new_source)`` when at least one import was moved,
        otherwise ``(None, None)``. Failures are reported and also yield
        ``(None, None)``. This runs inside main()'s worker pool, so one bad
        file must not abort the remaining files.
    """
    try:
        with open(filepath, "r") as h:
            source_code = h.read()
        source_ast = asttokens.ASTTokens(source_code, parse=True)
    except Exception as err:
        print("Failed to load {}: {}".format(filepath, err))
        return None, None
    try:
        modifications = tuple(
            _get_modifications(source_ast.tree, {}, source_ast, filter_)
        )
    except Exception as err:
        # Previously this propagated through Pool.imap and killed the batch.
        print("Failed to process {}: {}".format(filepath, err))
        return None, None
    if not modifications:
        return None, None
    return filepath, asttokens.util.replace(source_code, modifications)
# Definition statements whose bodies are scopes that imports can be moved into.
_FUNCTION_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef)


def _def_time_nodes(func):
    """Return the parts of a ``def`` statement evaluated in the enclosing scope.

    Decorators, default values and (unless evaluation is deferred) annotations
    run when the ``def`` itself executes. Imports they reference must
    therefore stay visible in the enclosing scope.
    """
    arguments = func.args
    nodes = list(func.decorator_list)
    nodes.extend(arguments.defaults)
    nodes.extend(d for d in arguments.kw_defaults if d is not None)
    params = (
        list(getattr(arguments, "posonlyargs", []))  # Python 3.8+
        + arguments.args
        + arguments.kwonlyargs
        + [p for p in (arguments.vararg, arguments.kwarg) if p is not None]
    )
    nodes.extend(p.annotation for p in params if p.annotation is not None)
    if func.returns is not None:
        nodes.append(func.returns)
    return nodes


def _class_time_nodes(cls):
    """Return the parts of a ``class`` header evaluated in the enclosing scope."""
    return cls.bases + [kw.value for kw in cls.keywords] + cls.decorator_list


def _insertion_token(scope):
    """Return the token before which moved imports are inserted into ``scope``.

    This is the first statement after the docstring, if there is one. A body
    consisting only of a docstring references no names, so nothing is ever
    moved into it.
    """
    body = scope.body
    if len(body) > 1 and ast.get_docstring(scope, clean=False) is not None:
        return body[1].first_token
    return body[0].first_token


def _get_modifications(parent, global_unused, parser, filter_):
    """Yield text edits that move imports into the functions that use them.

    Analyses the scope formed by ``parent.body`` (a module or function
    definition), then recurses into the functions defined directly in it,
    including methods of classes defined in it.

    Args:
        parent: AST node whose ``body`` is the scope to analyse; nodes must
            carry asttokens position info (``first_token``).
        global_unused: mapping of import key (``asname or name``) to its
            Import/ImportFrom node, for imports declared in enclosing scopes
            that those scopes do not use themselves.
        parser: the ``asttokens.ASTTokens`` instance for the file.
        filter_: callable ``(node, key) -> bool``; falsy leaves that import
            where it is.

    Yields:
        ``(start, end, text)`` character-offset edits into ``parser.text``
        for ``asttokens.util.replace``. These remove original import
        statements and insert local imports.
    """
    local_imports = {}  # import key -> statement, for imports declared here
    local_names = set()  # every name / dotted prefix referenced in this scope
    functions = []  # nested function definitions, searched afterwards
    skip = set()  # Attribute sub-expressions already covered by their parent
    for child in _iter_scope(parent):
        if isinstance(child, (ast.ImportFrom, ast.Import)):
            for alias in child.names:
                local_imports[alias.asname or alias.name] = child
        if isinstance(child, _FUNCTION_TYPES):
            # The body is its own scope; only def-time parts are evaluated here.
            functions.append(child)
            tosearch = _def_time_nodes(child)
        elif isinstance(child, ast.ClassDef):
            # _iter_scope yields the class body statements separately.
            tosearch = _class_time_nodes(child)
        else:
            tosearch = [child]
        for check in tosearch:
            for n in asttokens.util.walk(check):
                if n in skip:
                    continue
                if isinstance(n, (ast.Attribute, ast.Name)):
                    parts = parser.get_text(n).split(".")
                    # Record every dotted prefix: "a", "a.b", "a.b.c".
                    for i in range(1, len(parts) + 1):
                        local_names.add(".".join(parts[:i]))
                if isinstance(n, ast.Attribute):
                    # The full dotted text is recorded; don't re-add the pieces.
                    skip.add(n.value)

    def used_here(key):
        # ``import a.b`` binds ``a``, so any use of ``a`` needs that statement.
        return key.partition(".")[0] in local_names

    # Bring imports from enclosing scopes into this scope when it uses them.
    local_unused = {}
    insert_at = indent = None
    for name, node in global_unused.items():
        if not used_here(name):
            # Still unused: a nested function may need it instead.
            local_unused[name] = node
            continue
        # Moving deletes the whole statement, so every name it binds must be
        # allowed to move; otherwise leave the statement untouched.
        if not all(filter_(node, a.asname or a.name) for a in node.names):
            continue
        start, end = parser.get_text_range(node)
        yield start, end, ""
        if insert_at is None:
            tok = _insertion_token(parent)
            insert_at = tok.startpos
            # Reuse the character before the statement (space/tab) as indent.
            indent = parser.text[insert_at - 1] * tok.start[1]
        yield insert_at, insert_at, "\n{0}# Import automatically moved local\n{0}{1}\n{0}\n{0}".format(
            indent, _format_import(node, name)
        )
    # Imports declared here but unused here may move into nested functions.
    # A statement binding several names only qualifies when none of them is
    # used here, because moving removes the whole statement.
    for name, node in local_imports.items():
        if not any(used_here(a.asname or a.name) for a in node.names):
            local_unused[name] = node
    for func in functions:
        for edit in _get_modifications(func, local_unused, parser, filter_):
            yield edit
def _expand_paths(paths):
for path in paths:
if os.path.isfile(path):
yield path
elif os.path.isdir(path):
for root, _, files in os.walk(path):
for f in files:
if f.endswith(".py"):
yield os.path.join(root, f)
def _filter_import(include, exclude, node, alias):
text = [n.name for n in node.names if (n.asname or n.name) == alias][0]
if isinstance(node, ast.ImportFrom):
text = "{}.{}".format(node.module, text)
if exclude and fnmatch(text, exclude):
return False
if include and not fnmatch(text, include):
return False
return True
def _format_import(node, alias):
    """Render a standalone import statement that binds only ``alias`` from ``node``.

    ``from`` imports keep their module and any leading relative-import dots,
    e.g. ``from ..pkg import name as alias``.
    """
    name = [n for n in node.names if (n.asname or n.name) == alias][0]
    text = "import {}".format(_format_name(name))
    if isinstance(node, ast.ImportFrom):
        # Relative imports: node.level leading dots; node.module may be None.
        text = "from {}{} {}".format("." * (node.level or 0), node.module or "", text)
    return text
def _format_name(node):
if node.asname:
return "{} as {}".format(node.name, node.asname)
return node.name
def _iter_scope(parent):
for child in parent.body:
yield child
if isinstance(child, ast.ClassDef):
for node in _iter_scope(child):
yield node
if __name__ == "__main__":
    # argparse is only needed for command-line use, so import it here.
    import argparse

    cli = argparse.ArgumentParser(
        description="Move imports locally in python source files."
    )
    cli.add_argument("files", nargs="+", help="Paths to source files")
    # Optional flags: (flag spellings, add_argument keyword options).
    for flags, options in (
        (
            ("--include", "-i"),
            {"help": "Include only specified imports. Can use * wildcard."},
        ),
        (
            ("--exclude", "-e"),
            {"help": "Exclude specified imports. Can use * wildcard."},
        ),
        (
            ("--write", "-w"),
            {
                "action": "store_true",
                "default": False,
                "help": "Write changes to the source files instead of printing",
            },
        ),
    ):
        cli.add_argument(*flags, **options)
    args = cli.parse_args()
    main(args.files, include=args.include, exclude=args.exclude, write=args.write)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment