Skip to content

Instantly share code, notes, and snippets.

@vjeranc
Created January 15, 2026 19:57
Show Gist options
  • Select an option

  • Save vjeranc/ef99c2c5faeaf9d492ec68f633746789 to your computer and use it in GitHub Desktop.

Select an option

Save vjeranc/ef99c2c5faeaf9d492ec68f633746789 to your computer and use it in GitHub Desktop.
If a helper is called from only one function, inline its definition into that function.
import ast
import keyword
import os
import re
from collections import defaultdict
from typing import Any
def get_project_files(root: str = ".") -> list[str]:
    """Return the paths of all ``.py`` files found under *root*.

    Directories whose name starts with "." and directories named
    ``__pycache__``, ``venv`` or ``env`` are pruned and never descended into.
    Paths are returned in ``os.walk`` order, joined onto *root*.
    """
    skipped = {"__pycache__", "venv", "env"}
    found: list[str] = []
    for current_dir, subdirs, names in os.walk(root):
        # Slice-assign so os.walk (top-down) does not descend into pruned dirs.
        subdirs[:] = [
            sub for sub in subdirs if not (sub.startswith(".") or sub in skipped)
        ]
        found.extend(
            os.path.join(current_dir, name) for name in names if name.endswith(".py")
        )
    return found
class UsageVisitor(ast.NodeVisitor):
    """Map each name to the top-level functions that read it.

    Loads inside nested functions are attributed to the outermost enclosing
    function. Functions are tracked by name only, and anything the visitor
    enters while no function is open counts as "top-level" (so a method of a
    module-level class is recorded under the method's name). Names loaded
    outside every function are not recorded.

    Attributes:
        usages: name -> names of the top-level functions whose header or
            body loads that name.
        header_usages: top-level function name -> names loaded in that
            function's own decorators, argument defaults/annotations and
            return annotation, i.e. code evaluated when its ``def`` runs.
        current_function: name of the innermost function being visited.
        top_level_function: name of the outermost function being visited,
            or None outside any function.
        in_header: True while visiting a top-level function's header.
    """

    def __init__(self) -> None:
        self.usages: dict[str, set[str]] = defaultdict(set)
        self.header_usages: dict[str, set[str]] = defaultdict(set)
        self.current_function: str | None = None
        self.top_level_function: str | None = None
        self.in_header: bool = False

    def _visit_function(self, node: "ast.FunctionDef | ast.AsyncFunctionDef") -> None:
        """Visit a sync or async function, separating header from body loads."""
        prev_current = self.current_function
        prev_top = self.top_level_function
        prev_in_header = self.in_header
        is_top_level = prev_top is None
        self.current_function = node.name
        if is_top_level:
            self.top_level_function = node.name
        # Only the top-level function's own header runs before its body.
        # A nested def's header is evaluated while the enclosing body
        # executes, so its loads are ordinary body usages.
        self.in_header = is_top_level
        for decorator in node.decorator_list:
            self.visit(decorator)
        self.visit(node.args)
        if node.returns is not None:
            self.visit(node.returns)
        self.in_header = False
        for stmt in node.body:
            self.visit(stmt)
        self.in_header = prev_in_header
        self.current_function = prev_current
        self.top_level_function = prev_top

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._visit_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._visit_function(node)

    def visit_Name(self, node: ast.Name) -> None:
        # Only reads matter: a load inside a function is a dependency on it.
        if isinstance(node.ctx, ast.Load) and self.top_level_function is not None:
            self.usages[node.id].add(self.top_level_function)
            if self.in_header:
                self.header_usages[self.top_level_function].add(node.id)
def process_file(file_path: str, all_files_content: dict[str, str]) -> bool:
    """Move top-level helpers used by exactly one top-level function into it.

    A top-level ``def``/``async def`` or simple assignment (``name = ...`` or
    ``name: T = ...``) becomes a nested definition at the top of its only
    caller's body (after the caller's docstring) when all of these hold:

    * it is the only tracked top-level definition of that name, not a dunder;
    * it is undecorated (decorators run once, at import time);
    * module-level code never mentions the name (reads, rebinding, imports,
      string constants such as ``__all__`` entries), no function declares it
      ``global``/``nonlocal``, the caller's header does not use it, and it does
      not occur as a word in any other project file;
    * the caller neither binds the same name (parameter/local) nor binds any
      free name the helper reads (it would be captured by the closure);
    * neither statement shares a line with another top-level statement, the
      caller's body starts on its own line, and the helper has no multi-line
      string literal other than docstrings (re-indenting would change it).

    Moves may chain (a moved helper can itself receive helpers); chains that
    loop back on themselves are left alone. A moved name with exactly one
    leading underscore drops it only when the new spelling is a non-keyword
    identifier not already present in the file and every occurrence of the
    old spelling is a plain name reference. Functions that stay top-level
    keep their names.

    NOTE: a moved helper is re-created on every call of its new host, so
    module-level state it held (a memo dict, a mutable default) resets per
    call.

    Args:
        file_path: Path of the Python file to rewrite in place.
        all_files_content: Source text of every project file, keyed by path.

    Returns:
        True if the file was rewritten; False if it was empty, did not parse,
        or had nothing that could be moved safely.

    Raises:
        OSError, UnicodeDecodeError: if *file_path* cannot be read as UTF-8.
    """

    def get_source_range(node: ast.stmt) -> tuple[int, int]:
        """Return *node*'s 0-based, end-exclusive line range, decorators included."""
        start = node.lineno - 1
        decorators = getattr(node, "decorator_list", None)
        if decorators:
            start = decorators[0].lineno - 1
        return start, node.end_lineno

    def is_docstring(stmt: ast.stmt) -> bool:
        """Return True if *stmt* is a bare string-literal expression."""
        return (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        )

    def bound_names(node: ast.AST) -> set[str]:
        """Return every name bound inside *node*, excluding *node*'s own name."""
        names: set[str] = set()
        for n in ast.walk(node):
            if n is node:
                continue
            if isinstance(n, ast.arg):
                names.add(n.arg)
            elif isinstance(n, ast.Name):
                if not isinstance(n.ctx, ast.Load):
                    names.add(n.id)
            elif isinstance(n, ast.alias):
                names.add((n.asname or n.name).split(".")[0])
            else:
                # Nested defs/classes, ``except ... as x``, match captures.
                for attr in ("name", "rest"):
                    value = getattr(n, attr, None)
                    if isinstance(value, str):
                        names.add(value)
        return names

    def loaded_names(node: ast.AST) -> set[str]:
        """Return every name read (``Load`` context) inside *node*."""
        return {
            n.id
            for n in ast.walk(node)
            if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
        }

    def has_multiline_literal(node: ast.AST) -> bool:
        """Return True if *node* holds a non-docstring literal spanning lines."""
        docstrings = {
            id(n.body[0].value)
            for n in ast.walk(node)
            if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
            and is_docstring(n.body[0])
        }
        for n in ast.walk(node):
            is_text = isinstance(n, ast.JoinedStr) or (
                isinstance(n, ast.Constant) and isinstance(n.value, (str, bytes))
            )
            if is_text and n.end_lineno != n.lineno and id(n) not in docstrings:
                return True
        return False

    def insertion_point(node: ast.stmt) -> "tuple[int, str] | None":
        """Return (0-based file line, indent) where *node*'s helpers go.

        Helpers go just before the first body statement, or just after a
        leading docstring so it remains the docstring. Returns None when no
        line boundary exists there (e.g. ``def f(): return g()``, or a
        docstring followed by ``; stmt`` on the same line).
        """
        body = node.body
        first = body[0]
        # Only a simple statement can share the header's line; anything
        # other than whitespace before it means the body starts there.
        if lines[first.lineno - 1][: first.col_offset].strip():
            return None
        first_start = get_source_range(first)[0]
        text = lines[first_start]
        indent = text[: len(text) - len(text.lstrip())]  # keeps tabs as tabs
        if not is_docstring(first):
            return first_start, indent
        after_docstring = first.end_lineno
        if len(body) > 1 and get_source_range(body[1])[0] < after_docstring:
            return None
        return after_docstring, indent

    def rename_in_lines(text_lines: list[str], old_name: str, new_name: str) -> list[str]:
        """Replace whole-word *old_name* with *new_name* on every line.

        Only used for names ``clean_name`` cleared, i.e. every textual
        occurrence of *old_name* in the file is a genuine name reference.
        """
        pattern = re.compile(r"\b" + re.escape(old_name) + r"\b")
        return [pattern.sub(new_name, text) for text in text_lines]

    with open(file_path, "r", encoding="utf-8") as f:
        original_source = f.read()
    # Split on "\n" only: str.splitlines() also breaks on \f, \v, \x1c-\x1e,
    # \x85, \u2028 and \u2029, which Python's tokenizer does not treat as line
    # ends, so its indices would drift from the AST's line numbers.
    lines: list[str] = re.findall(r"[^\n]*\n|[^\n]+", original_source)
    if not lines:
        return False
    try:
        tree = ast.parse(original_source)
    except (SyntaxError, ValueError):  # ValueError: NUL bytes on older Pythons
        return False

    # Movable top-level definitions in file order; siblings are re-emitted
    # inside their host in this order.
    top_level_nodes: list[tuple[str, ast.stmt, str]] = []  # (name, node, kind)
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            top_level_nodes.append((node.name, node, "func"))
        elif (
            isinstance(node, ast.Assign)
            and len(node.targets) == 1
            and isinstance(node.targets[0], ast.Name)
        ):
            top_level_nodes.append((node.targets[0].id, node, "var"))
        elif (
            isinstance(node, ast.AnnAssign)
            and isinstance(node.target, ast.Name)
            and node.value is not None  # a bare ``x: T`` binds nothing
        ):
            top_level_nodes.append((node.target.id, node, "var"))
    if not top_level_nodes:
        return False

    name_counts: dict[str, int] = defaultdict(int)
    for name, _, _ in top_level_nodes:
        name_counts[name] += 1
    name_to_node = {name: node for name, node, _ in top_level_nodes}
    func_names = {name for name, _, kind in top_level_nodes if kind == "func"}

    # Identifiers touched by module-level code: read, rebound (``x += 1``),
    # imported, defined, or spelled as a string (``__all__``). Top-level
    # function bodies are skipped; UsageVisitor covers them. The defining
    # target of each tracked assignment is not a "use" of itself.
    definition_targets = {
        id(node.targets[0] if isinstance(node, ast.Assign) else node.target)
        for _, node, kind in top_level_nodes
        if kind == "var"
    }
    module_level_refs: set[str] = set()
    for stmt in tree.body:
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for n in ast.walk(stmt):
            if isinstance(n, ast.Name):
                if id(n) not in definition_targets:
                    module_level_refs.add(n.id)
            elif isinstance(n, ast.alias):
                module_level_refs.add((n.asname or n.name).split(".")[0])
            elif isinstance(n, ast.Constant) and isinstance(n.value, str):
                module_level_refs.add(n.value)
            elif isinstance(getattr(n, "name", None), str):
                module_level_refs.add(n.name)

    declared_global = {
        declared
        for n in ast.walk(tree)
        if isinstance(n, (ast.Global, ast.Nonlocal))
        for declared in n.names
    }

    # How often each identifier appears as a real name reference or def/class
    # name; compared against raw word matches to decide if renaming is safe.
    reference_count: dict[str, int] = defaultdict(int)
    for n in ast.walk(tree):
        if isinstance(n, ast.Name):
            reference_count[n.id] += 1
        elif isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            reference_count[n.name] += 1

    # Line-based edits are only safe for statements that own their lines.
    owners_per_line: dict[int, int] = defaultdict(int)
    for stmt in tree.body:
        stmt_start, stmt_end = get_source_range(stmt)
        for i in range(stmt_start, stmt_end):
            owners_per_line[i] += 1

    def owns_its_lines(node: ast.stmt) -> bool:
        """Return True if no other top-level statement shares *node*'s lines."""
        node_start, node_end = get_source_range(node)
        return all(owners_per_line[i] == 1 for i in range(node_start, node_end))

    bound_cache: dict[str, set[str]] = {}

    def bound_in(name: str) -> set[str]:
        """Cached ``bound_names`` of the tracked definition *name*."""
        if name not in bound_cache:
            bound_cache[name] = bound_names(name_to_node[name])
        return bound_cache[name]

    visitor = UsageVisitor()
    visitor.visit(tree)

    moves: dict[str, str] = {}  # helper name -> host function name
    for name, node, kind in top_level_nodes:
        if name.startswith("__") or name_counts[name] != 1:
            continue
        if name in module_level_refs or name in declared_global:
            continue
        if kind == "func" and node.decorator_list:
            continue
        callers = visitor.usages.get(name, set()) - {name}  # ignore recursion
        if len(callers) != 1:
            continue
        (target,) = callers
        if target not in func_names or name_counts[target] != 1:
            continue
        target_node = name_to_node[target]
        if name in visitor.header_usages.get(target, set()):
            continue  # decorators/defaults/annotations run before the body
        if name in bound_in(target):
            continue  # a parameter or local of that name would be shadowed
        if not (owns_its_lines(node) and owns_its_lines(target_node)):
            continue
        if insertion_point(target_node) is None or has_multiline_literal(node):
            continue
        word = re.compile(r"\b" + re.escape(name) + r"\b")
        if any(
            word.search(content)
            for other_path, content in all_files_content.items()
            if other_path != file_path
        ):
            continue
        moves[name] = target

    def anchored(name: str) -> bool:
        """Return True if following moves from *name* ends at a top-level function."""
        seen: set[str] = set()
        while name in moves:
            if name in seen:
                return False
            seen.add(name)
            name = moves[name]
        return True

    # Drop helpers on (or leading into) a cycle; they would otherwise be
    # removed from the top level without being inserted anywhere.
    moves = {helper: host for helper, host in moves.items() if anchored(helper)}

    # A helper's free names must not be bound by any scope it ends up nested in.
    for helper in list(moves):
        helper_node = name_to_node[helper]
        free = loaded_names(helper_node) - bound_names(helper_node)
        host = moves[helper]
        while True:
            if free & bound_in(host):
                del moves[helper]
                break
            if host not in moves:
                break
            host = moves[host]

    if not moves:
        return False

    children_of: dict[str, list[str]] = defaultdict(list)
    for name, _, _ in top_level_nodes:
        if name in moves:
            children_of[moves[name]].append(name)

    def clean_name(name: str) -> str:
        """Drop a single leading underscore from a moved name when that is safe."""
        if not name.startswith("_") or name.startswith("__"):
            return name
        candidate = name[1:]
        if not candidate.isidentifier() or keyword.iskeyword(candidate):
            return name
        if re.search(r"\b" + re.escape(candidate) + r"\b", original_source):
            return name  # already spelled somewhere: could collide or shadow
        occurrences = len(re.findall(r"\b" + re.escape(name) + r"\b", original_source))
        if occurrences != reference_count[name]:
            return name  # also appears in strings/comments/attributes/kwargs
        return candidate

    new_names = {name: clean_name(name) for name in moves}

    processed: dict[str, list[str]] = {}

    def get_node_source(name: str) -> list[str]:
        """Return *name*'s source lines with its helpers inlined and renamed."""
        if name in processed:
            return processed[name]
        node = name_to_node[name]
        start, end = get_source_range(node)
        node_lines = lines[start:end]
        kids = children_of.get(name, [])
        if kids:
            insert_at, indent = insertion_point(node)
            inlined: list[str] = []
            for kid in kids:
                kid_lines = [
                    indent + text if text.strip() else text
                    for text in get_node_source(kid)
                ]
                if kid_lines and not kid_lines[-1].endswith("\n"):
                    kid_lines[-1] += "\n"  # kid was the file's unterminated last line
                inlined.extend(kid_lines)
            offset = insert_at - start
            node_lines[offset:offset] = inlined
        # Rename this definition only if it is being moved (hosts that stay
        # top-level keep their name so callers still resolve), plus every
        # helper just inlined, whose call sites in this text use the old name.
        to_rename = ([name] if name in moves else []) + kids
        for old in to_rename:
            if new_names[old] != old:
                node_lines = rename_in_lines(node_lines, old, new_names[old])
        processed[name] = node_lines
        return node_lines

    # Build every edit before mutating ``lines``: get_node_source reads the
    # original line positions.
    edits: list[tuple[int, int, list[str]]] = []  # (start, end, replacement)
    for name, node, _ in top_level_nodes:
        if name in moves:
            start, end = get_source_range(node)
            edits.append((start, end, []))
        elif name in children_of:
            start, end = get_source_range(node)
            edits.append((start, end, get_node_source(name)))
    # Apply bottom-up so earlier line indices stay valid.
    edits.sort(key=lambda edit: edit[0], reverse=True)
    for start, end, replacement in edits:
        lines[start:end] = replacement
    with open(file_path, "w", encoding="utf-8") as f:
        f.writelines(lines)
    return True
if __name__ == "__main__":
    files = get_project_files()
    # Read every file up front: process_file's cross-file usage check needs
    # the whole project's text. errors="replace" keeps a non-UTF-8 file
    # searchable for identifiers instead of aborting the whole run.
    all_content: dict[str, str] = {}
    for path in files:
        with open(path, "r", encoding="utf-8", errors="replace") as fo:
            all_content[path] = fo.read()
    print(f"Scanning {len(files)} files...")
    count = 0
    for path in files:
        try:
            changed = process_file(path, all_content)
        except UnicodeDecodeError as err:
            # process_file reads strictly; leave undecodable files untouched.
            print(f"Skipping {path}: not valid UTF-8 ({err.reason})")
            continue
        if changed:
            print(f"Refactored {path}")
            count += 1
    print(f"Done. Modified {count} files.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment