Created
January 15, 2026 19:57
-
-
Save vjeranc/ef99c2c5faeaf9d492ec68f633746789 to your computer and use it in GitHub Desktop.
If a helper is called only inside one function, inline its definition into that function.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import ast | |
| import os | |
| import re | |
| from collections import defaultdict | |
| from typing import Any | |
def get_project_files(root: str = ".") -> list[str]:
    """Return the paths of all ``.py`` files found under *root*.

    Directories whose name starts with a dot, as well as ``__pycache__``,
    ``venv`` and ``env``, are pruned from the walk, so nothing below them is
    reported. Paths are built by joining the walked directory with the file
    name and are returned in the order ``os.walk`` yields them.
    """
    skipped_dirs = {"__pycache__", "venv", "env"}
    python_files: list[str] = []
    for folder, subfolders, names in os.walk(root):
        # Prune in place so os.walk does not descend into skipped folders.
        subfolders[:] = [
            sub
            for sub in subfolders
            if not (sub.startswith(".") or sub in skipped_dirs)
        ]
        python_files.extend(
            os.path.join(folder, name) for name in names if name.endswith(".py")
        )
    return python_files
class UsageVisitor(ast.NodeVisitor):
    """Record, per outermost function, which names are read inside it.

    While a ``def`` / ``async def`` is being visited, every ``ast.Name`` in
    ``Load`` context found in its decorators, signature or body (nested
    functions included) is attributed to the outermost enclosing function
    definition. Classes are not tracked, so a method counts as its own
    outermost function. Names read outside any function are ignored.
    """

    def __init__(self) -> None:
        # name -> outermost functions in which that name is read
        self.usages: dict[str, set[str]] = defaultdict(set)
        # outermost function -> names read in a decorator, argument list or
        # return annotation (of that function or of a function nested in it)
        self.header_usages: dict[str, set[str]] = defaultdict(set)
        # Innermost function being visited; None outside any function
        self.current_function: str | None = None
        # Outermost function being visited; None outside any function
        self.top_level_function: str | None = None
        # True while decorators / arguments / return annotation are visited
        self.in_header: bool = False

    def _visit_function(self, node: "ast.FunctionDef | ast.AsyncFunctionDef") -> None:
        """Visit one function definition, header first and body second.

        Shared by the sync and async visitors, which need identical handling.
        The traversal state is saved on entry and restored on exit so nested
        definitions do not leak their state into the enclosing one.
        """
        saved_state = (self.current_function, self.top_level_function, self.in_header)

        self.current_function = node.name
        if self.top_level_function is None:
            self.top_level_function = node.name

        # Header: decorators, arguments (defaults and annotations), return type.
        self.in_header = True
        for decorator in node.decorator_list:
            self.visit(decorator)
        self.visit(node.args)
        if node.returns:
            self.visit(node.returns)

        # Body.
        self.in_header = False
        for statement in node.body:
            self.visit(statement)

        self.current_function, self.top_level_function, self.in_header = saved_state

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Handle a ``def`` statement."""
        self._visit_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Handle an ``async def`` statement."""
        self._visit_function(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Record a name read that happens inside a function definition."""
        if isinstance(node.ctx, ast.Load) and self.top_level_function is not None:
            self.usages[node.id].add(self.top_level_function)
            if self.in_header:
                self.header_usages[self.top_level_function].add(node.id)
def process_file(file_path: str, all_files_content: dict[str, str]) -> bool:
    """Inline single-use top-level helpers of one file into their only caller.

    A top-level function or simple assignment (``name = ...`` or
    ``name: T = ...``) is moved to the start of the body of a top-level
    function when that function is the only one that reads the name, the name
    is not read in the function's header (decorators, defaults, annotations),
    it is not referenced by any other top-level statement, it is not declared
    ``global`` anywhere, and it does not occur as a word in any other file of
    ``all_files_content``. A moved helper loses its leading underscores when
    the shortened name is a valid, non-keyword identifier that does not
    already occur in the file.

    Args:
        file_path: Path of the Python file to rewrite in place.
        all_files_content: Mapping of file path to source text, used to detect
            references from other files; the entry for ``file_path`` itself is
            ignored.

    Returns:
        True when the file was rewritten, False when it was left untouched
        (empty file, unparsable source, or nothing to move).

    Known limitations:
        * Renaming is textual, so string literals and comments that mention
          the helper as a standalone word are rewritten too.
        * Every non-blank helper line is re-indented, including continuation
          lines of multi-line string literals.
        * A moved assignment is re-evaluated on every call of its new owner.
        * Top-level statements sharing a physical line (``a = 1; b = 2``) are
          not told apart.
    """
    import keyword  # function-scope import: not part of the module's imports

    func_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    def get_source_range(node: Any) -> tuple[int, int]:
        """Return the 0-based half-open line span of *node*, decorators included."""
        first = node.lineno
        decorators = getattr(node, "decorator_list", None)
        if decorators:
            first = min(first, decorators[0].lineno)
        return first - 1, node.end_lineno

    def find_insertion_point(func: Any) -> "tuple[int, str] | None":
        """Return (0-based line index, indent) for inserting helpers into *func*.

        Helpers go after the docstring when there is one, otherwise before the
        first body statement (before its decorators, if any). Returns None
        when there is no whole-line position available, e.g. for
        ``def f(): return 1``.
        """
        first_stmt = func.body[0]
        stmt_start, stmt_end = get_source_range(first_stmt)
        decorators = getattr(first_stmt, "decorator_list", None)
        anchor = decorators[0] if decorators else first_stmt
        first_line = lines[stmt_start]
        # Only indentation (plus the "@" of a decorator) may precede the
        # statement; anything else means the body shares a line with the header.
        if first_line[: anchor.col_offset].strip() not in ("", "@"):
            return None
        # Reuse the literal leading whitespace so tab-indented files stay valid.
        indent = first_line[: len(first_line) - len(first_line.lstrip())]
        value = getattr(first_stmt, "value", None)
        is_docstring = (
            isinstance(first_stmt, ast.Expr)
            and isinstance(value, ast.Constant)
            and isinstance(value.value, str)
        )
        if not is_docstring:
            return stmt_start, indent
        if len(func.body) > 1 and get_source_range(func.body[1])[0] < stmt_end:
            # Another statement shares the docstring's last line.
            return None
        return stmt_end, indent

    def rename_in_lines(text_lines: list[str], old_name: str, new_name: str) -> list[str]:
        """Replace standalone occurrences of *old_name* in *text_lines*.

        The lookbehind also rejects a preceding dot so attribute accesses such
        as ``self._name`` are left alone.
        """
        pattern = re.compile(r"(?<![\w.])" + re.escape(old_name) + r"(?!\w)")
        return [pattern.sub(new_name, line) for line in text_lines]

    # newline="" keeps the file's own line endings on both read and write.
    with open(file_path, "r", encoding="utf-8", newline="") as f:
        original_source = f.read()
    # Split only on \n, \r\n and \r so indices match the AST line numbers;
    # str.splitlines() would also split on form feeds and similar characters.
    lines: list[str] = re.findall(r"[^\r\n]*(?:\r\n|\r|\n)|[^\r\n]+", original_source)
    if not lines:
        return False
    newline = "\r\n" if lines[0].endswith("\r\n") else "\n"

    try:
        tree = ast.parse(original_source)
    except (SyntaxError, ValueError):
        # ValueError: some Python versions raise it for null bytes in source.
        return False

    # Top-level functions and simple assignments, in definition order.
    top_level_nodes: list[tuple[str, Any]] = []
    definition_targets: set[int] = set()  # id() of the Name being defined
    for stmt in tree.body:
        if isinstance(stmt, func_types):
            top_level_nodes.append((stmt.name, stmt))
        elif isinstance(stmt, ast.Assign):
            if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
                top_level_nodes.append((stmt.targets[0].id, stmt))
                definition_targets.add(id(stmt.targets[0]))
        elif isinstance(stmt, ast.AnnAssign):
            if isinstance(stmt.target, ast.Name):
                top_level_nodes.append((stmt.target.id, stmt))
                definition_targets.add(id(stmt.target))
    if not top_level_nodes:
        return False

    name_counts: dict[str, int] = defaultdict(int)
    for name, _ in top_level_nodes:
        name_counts[name] += 1
    name_to_node = {name: node for name, node in top_level_nodes}

    # Names referenced (read, rebound or deleted) by top-level statements that
    # are not function definitions; such names must stay at module level.
    module_level_names: set[str] = set()
    for stmt in tree.body:
        if isinstance(stmt, func_types):
            continue
        for sub in ast.walk(stmt):
            if isinstance(sub, ast.Name) and id(sub) not in definition_targets:
                module_level_names.add(sub.id)

    # Names some function rebinds through a ``global`` declaration.
    global_names: set[str] = set()
    for sub in ast.walk(tree):
        if isinstance(sub, ast.Global):
            global_names.update(sub.names)

    visitor = UsageVisitor()
    visitor.visit(tree)

    moves: dict[str, str] = {}  # helper name -> target function name
    insertion_points: dict[str, tuple[int, str]] = {}  # target -> (line, indent)
    for name, node in top_level_nodes:
        if name.startswith("__") or name_counts[name] > 1:
            continue
        if name in module_level_names or name in global_names:
            continue
        # Copy so the visitor's sets stay untouched; ignore self-recursion.
        callers = set(visitor.usages.get(name, ())) - {name}
        if len(callers) != 1:
            continue
        (target_name,) = callers
        target_node = name_to_node.get(target_name)
        if not isinstance(target_node, func_types) or name_counts[target_name] > 1:
            continue
        # Used in the target's header: it must exist before the body runs.
        if name in visitor.header_usages.get(target_name, ()):
            continue
        point = find_insertion_point(target_node)
        if point is None:
            continue
        word = re.compile(r"\b" + re.escape(name) + r"\b")
        used_elsewhere = any(
            word.search(content)
            for other_path, content in all_files_content.items()
            if other_path != file_path
        )
        if used_elsewhere:
            continue
        moves[name] = target_name
        insertion_points[target_name] = point

    def in_cycle(start: str) -> bool:
        """Tell whether following the move targets from *start* returns to it."""
        current = moves[start]
        for _ in range(len(moves)):
            if current == start:
                return True
            if current not in moves:
                return False
            current = moves[current]
        return False

    # Helpers that only reference each other would all be removed and never
    # re-inserted anywhere; keep them at module level instead.
    for name in [candidate for candidate in moves if in_cycle(candidate)]:
        del moves[name]
    if not moves:
        return False

    # Only moved helpers are renamed; anything that stays at module level may
    # still be referenced under its original name.
    renames: dict[str, str] = {}
    for name in moves:
        if not name.startswith("_"):
            continue
        clean_name = name.lstrip("_")
        if not clean_name.isidentifier() or keyword.iskeyword(clean_name):
            continue
        if re.search(r"\b" + re.escape(clean_name) + r"\b", original_source):
            continue  # would collide with an existing name
        renames[name] = clean_name

    # Children are listed in definition order so earlier helpers come first.
    children_of: dict[str, list[str]] = defaultdict(list)
    for name, _ in top_level_nodes:
        if name in moves:
            children_of[moves[name]].append(name)

    processed_sources: dict[str, list[str]] = {}  # name -> final source lines

    def get_node_source(name: str) -> list[str]:
        """Return the lines of *name* with all of its helpers inlined."""
        if name in processed_sources:
            return processed_sources[name]
        node = name_to_node[name]
        start, end = get_source_range(node)
        node_lines = lines[start:end]

        children = children_of.get(name, [])
        if children:
            insert_at, indent = insertion_points[name]
            inserted: list[str] = []
            for child in children:
                for line in get_node_source(child):
                    if not line.endswith(("\n", "\r")):
                        line += newline  # helper was the unterminated last line
                    # Blank lines stay blank instead of gaining trailing spaces.
                    inserted.append(indent + line if line.strip() else line)
            relative = insert_at - start
            node_lines[relative:relative] = inserted
            # Call sites in this function must follow the helpers' new names.
            for child in children:
                if child in renames:
                    node_lines = rename_in_lines(node_lines, child, renames[child])

        if name in renames:
            node_lines = rename_in_lines(node_lines, name, renames[name])

        processed_sources[name] = node_lines
        return node_lines

    edits: list[tuple[int, int, list[str]]] = []  # (start, end, replacement)
    for name, node in top_level_nodes:
        start, end = get_source_range(node)
        if name in moves:
            edits.append((start, end, []))  # removed here, re-emitted in target
        elif name in children_of:
            edits.append((start, end, get_node_source(name)))

    # Apply bottom-up so earlier line indices stay valid.
    for start, end, replacement in sorted(edits, key=lambda edit: edit[0], reverse=True):
        lines[start:end] = replacement

    with open(file_path, "w", encoding="utf-8", newline="") as f:
        f.writelines(lines)
    return True
def main() -> None:
    """Refactor every Python file found under the current directory.

    All files are read up front because ``process_file`` needs the text of
    every other file to check for cross-file references. A file that cannot
    be read or is not valid UTF-8 is reported and skipped instead of aborting
    the run after other files may already have been rewritten.
    """
    all_content: dict[str, str] = {}
    for path in get_project_files():
        try:
            with open(path, "r", encoding="utf-8") as handle:
                all_content[path] = handle.read()
        except (OSError, UnicodeDecodeError) as exc:
            print(f"Skipping {path}: {exc}")

    print(f"Scanning {len(all_content)} files...")
    count = 0
    for path in all_content:
        if process_file(path, all_content):
            print(f"Refactored {path}")
            count += 1
    print(f"Done. Modified {count} files.")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment