Created
October 9, 2025 09:01
-
-
Save antocuni/f197fdc9f13a35aabe6157c5bdf79ab7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Add trailing commas to multiline lists, tuples, dicts, sets, and function calls. | |
| """ | |
| import libcst as cst | |
| from pathlib import Path | |
| import sys | |
class TrailingCommaTransformer(cst.CSTTransformer):
    """Add trailing commas to multiline collections and calls.

    Handles list, tuple (parenthesized only), set and dict literals, call
    argument lists, and multi-element subscripts. A construct is rewritten
    only when it is considered multiline (see ``_is_multiline``) and its
    last element does not already carry a comma.
    """

    @staticmethod
    def _has_newline(whitespace) -> bool:
        """Return True if a libcst whitespace node contains a line break.

        ``ParenthesizedWhitespace`` always spans at least one newline;
        other whitespace nodes are checked through their ``value`` string.
        ``None`` (no whitespace node available) yields False.
        """
        if isinstance(whitespace, cst.ParenthesizedWhitespace):
            return True
        value = getattr(whitespace, 'value', None)
        return isinstance(value, str) and '\n' in value

    def _is_multiline(self, elements, rbracket) -> bool:
        """Decide whether a bracketed sequence is laid out over several lines.

        Args:
            elements: The sequence's element nodes (``Element``, ``Arg``,
                ``DictElement``, ``SubscriptElement``, ...).
            rbracket: The closing-bracket node, or ``None`` when the
                construct has no such node (function calls).

        Returns:
            True when any of the following holds, False otherwise (and
            always False for an empty sequence):

            * there is a newline before the closing bracket;
            * the last element is a call argument followed by a newline
              (calls keep that whitespace on the argument itself);
            * an existing comma between elements is followed by a newline.
        """
        if not elements:
            return False

        # Newline right before the closing bracket of a literal/subscript.
        if self._has_newline(getattr(rbracket, 'whitespace_before', None)):
            return True

        # Calls have no closing-bracket node: the whitespace before ")" is
        # stored on the last argument instead.
        if self._has_newline(getattr(elements[-1], 'whitespace_after_arg', None)):
            return True

        # Any explicit comma that is followed by a line break.
        for element in elements:
            comma = getattr(element, 'comma', None)
            # MaybeSentinel.DEFAULT means "no explicit comma in the source".
            if isinstance(comma, cst.Comma) and self._has_newline(comma.whitespace_after):
                return True
        return False

    def _add_trailing_comma_to_sequence(self, elements):
        """Return ``elements`` with a comma attached to its last item.

        The input is returned unchanged when it is empty or when the last
        item already has an explicit comma. Otherwise a new tuple is
        returned; the added ``Comma`` carries no whitespace of its own, so
        the existing whitespace before the closing bracket is preserved.
        """
        if not elements:
            return elements

        last_element = elements[-1]
        if isinstance(getattr(last_element, 'comma', None), cst.Comma):
            return elements

        new_last = last_element.with_changes(comma=cst.Comma())
        # Unpacking (rather than ``+``) also accepts list input.
        return (*elements[:-1], new_last)

    def leave_List(self, original_node: cst.List, updated_node: cst.List) -> cst.List:
        """Transform list literals."""
        if self._is_multiline(updated_node.elements, updated_node.rbracket):
            new_elements = self._add_trailing_comma_to_sequence(updated_node.elements)
            return updated_node.with_changes(elements=new_elements)
        return updated_node

    def leave_Tuple(self, original_node: cst.Tuple, updated_node: cst.Tuple) -> cst.Tuple:
        """Transform tuple literals (only those written with parentheses)."""
        if not (updated_node.lpar and updated_node.rpar):
            return updated_node

        # ``rpar`` is emitted in order after the elements, so index 0 is the
        # parenthesis that directly follows the last element.
        closing = updated_node.rpar[0]
        if self._is_multiline(updated_node.elements, closing):
            new_elements = self._add_trailing_comma_to_sequence(updated_node.elements)
            return updated_node.with_changes(elements=new_elements)
        return updated_node

    def leave_Set(self, original_node: cst.Set, updated_node: cst.Set) -> cst.Set:
        """Transform set literals."""
        if self._is_multiline(updated_node.elements, updated_node.rbrace):
            new_elements = self._add_trailing_comma_to_sequence(updated_node.elements)
            return updated_node.with_changes(elements=new_elements)
        return updated_node

    def leave_Dict(self, original_node: cst.Dict, updated_node: cst.Dict) -> cst.Dict:
        """Transform dict literals."""
        if self._is_multiline(updated_node.elements, updated_node.rbrace):
            new_elements = self._add_trailing_comma_to_sequence(updated_node.elements)
            return updated_node.with_changes(elements=new_elements)
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Transform function calls."""
        args = updated_node.args

        # ``f(x for x in y)`` must stay as is: a comma after a bare
        # generator argument is a syntax error.
        if len(args) == 1:
            sole_value = args[0].value
            if isinstance(sole_value, cst.GeneratorExp) and not sole_value.lpar:
                return updated_node

        # ``Call.rpar`` holds parentheses around the whole call expression,
        # not the argument list's ")", so no bracket node is passed here.
        if self._is_multiline(args, None):
            new_args = self._add_trailing_comma_to_sequence(args)
            return updated_node.with_changes(args=new_args)
        return updated_node

    def leave_Subscript(self, original_node: cst.Subscript, updated_node: cst.Subscript) -> cst.Subscript:
        """Transform subscript operations (e.g., array[1, 2, 3]).

        Only subscripts with two or more elements are touched: adding a
        comma to a single index (``a[x]`` -> ``a[x,]``) would turn the
        index into a 1-tuple and change the program's meaning.
        """
        slices = updated_node.slice
        if not isinstance(slices, (list, tuple)) or len(slices) < 2:
            return updated_node

        if self._is_multiline(slices, updated_node.rbracket):
            new_slices = self._add_trailing_comma_to_sequence(slices)
            return updated_node.with_changes(slice=new_slices)
        return updated_node
def process_file(file_path: Path, dry_run: bool = False) -> bool:
    """
    Process a single Python file.

    The file is read and written as bytes so that libcst detects the source
    encoding itself and the file's original line endings are kept, instead
    of depending on the platform's default text encoding and newline
    translation.

    Args:
        file_path: Path of the Python source file to rewrite.
        dry_run: When True, report the change but leave the file untouched.

    Returns:
        True if the file was (or, in dry-run mode, would be) modified,
        False if nothing changed or if an error occurred. Errors are
        reported on stderr and never propagate to the caller.
    """
    try:
        raw_source = file_path.read_bytes()

        # Parse the source code
        source_tree = cst.parse_module(raw_source)

        # Transform the tree
        modified_tree = source_tree.visit(TrailingCommaTransformer())

        # Compare generated code so only real edits count as a change.
        if modified_tree.code == source_tree.code:
            return False

        if not dry_run:
            # ``bytes`` re-encodes with the encoding detected at parse time.
            file_path.write_bytes(modified_tree.bytes)
        print(f"{'[DRY RUN] Would modify' if dry_run else 'Modified'}: {file_path}")
        return True
    except Exception as e:
        # Per-file boundary: one unreadable or unparsable file must not
        # abort the whole batch run.
        print(f"Error processing {file_path}: {e}", file=sys.stderr)
        return False
def main():
    """Command-line entry point.

    Parses ``--dry-run`` and an optional ``path`` (default ``spy``),
    collects the Python files to process (the path itself if it is a file,
    otherwise every ``*.py`` found recursively below it, sorted), runs
    ``process_file`` on each and prints a summary count.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="Add trailing commas to multiline collections in Python files"
    )
    cli.add_argument(
        '--dry-run',
        action='store_true',
        help='Show what would be changed without modifying files'
    )
    cli.add_argument(
        'path',
        nargs='?',
        default='spy',
        help='Path to process (default: spy/)'
    )
    options = cli.parse_args()

    # Find all Python files
    target = Path(options.path)
    python_files = [target] if target.is_file() else sorted(target.rglob('*.py'))
    print(f"Found {len(python_files)} Python files")

    # Count the files that process_file reports as modified.
    modified_count = sum(
        1
        for py_file in python_files
        if process_file(py_file, dry_run=options.dry_run)
    )

    action = '[DRY RUN] Would modify' if options.dry_run else 'Modified'
    print(f"\n{action} {modified_count} file(s)")


if __name__ == '__main__':
    main()
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
created with this
claudeprompt: