Skip to content

Instantly share code, notes, and snippets.

@giladbarnea
Created November 9, 2025 21:55
Show Gist options
  • Select an option

  • Save giladbarnea/05121a93ba1dda30fee54b58dc409626 to your computer and use it in GitHub Desktop.

Select an option

Save giladbarnea/05121a93ba1dda30fee54b58dc409626 to your computer and use it in GitHub Desktop.
print_python_logical_branches.py
#!/usr/bin/env uvx --with=rich python3
"""
Control flow fork analyzer - Final polished version
Shows branching points in code with gradient coloring and clear formatting
Usage:
python analyze_final.py <file_path> <symbol>
Examples:
python analyze_final.py utils/asr/transcription_functions.py whisper_transcribe
python analyze_final.py preprocessor/transcribe/initiate_transcription.py InitiateTranscriptionPreProcessor._process
"""
import ast
import sys
from pathlib import Path
from rich.console import Console
from rich.panel import Panel
from rich.style import Style
from rich.syntax import Syntax
from rich.text import Text
# Shared rich Console instance; every function in this script prints through it.
console = Console()
class ForkAnalyzer(ast.NodeVisitor):
    """
    Analyzes an AST to find control flow forks inside one target function.

    A "fork" is any point where execution flow branches:
    - if, elif, else, try, except, finally, for, while, match/case
      (control flow keywords)
    - return, raise (flow termination)
    - Statements immediately following control blocks ("regardless" siblings)

    ``with`` / ``async with`` statements are not forks themselves, but their
    bodies are searched so that forks nested inside them are still reported.
    Nested function and class definitions inside the target are not searched.

    After ``visit(tree)``, ``fork_lines`` holds the 1-based line numbers of
    every fork found in the target function or method.
    """

    # Number of lines above a clause's first statement that are searched for
    # the clause keyword when the caller supplies no lower bound.
    _KEYWORD_LOOKBACK = 9

    # ast.TryStar (3.11+) and ast.Match (3.10+) do not exist on older
    # interpreters, so they are looked up defensively.
    _TRY_NODES = tuple(
        node_type
        for node_type in (ast.Try, getattr(ast, "TryStar", None))
        if node_type is not None
    )
    _LOOP_NODES = (ast.For, ast.AsyncFor, ast.While)
    _WITH_NODES = (ast.With, ast.AsyncWith)
    _MATCH_NODE = getattr(ast, "Match", None)

    def __init__(self, source_lines, target_class=None, target_method=None):
        """
        Args:
            source_lines: The analyzed file split into lines (index 0 is line 1).
            target_class: Class that owns the target method, either a bare
                class name or a dotted path such as ``Outer.Inner``.
                None means "match the function name regardless of class".
            target_method: Name of the function or method to analyze.
        """
        self.source_lines = source_lines
        self.fork_lines = set()
        self.target_class = target_class  # None for module-level functions
        self.target_method = target_method
        self.in_target_class = False
        self.in_target_method = False
        # Names of the classes currently being visited, outermost first.
        self._class_stack = []

    def add_fork(self, lineno):
        """Record ``lineno`` as a fork if it is a valid 1-based line number."""
        if 1 <= lineno <= len(self.source_lines):
            self.fork_lines.add(lineno)

    def visit_ClassDef(self, node):
        """Track whether the visitor is currently inside the target class."""
        self._class_stack.append(node.name)
        previous = self.in_target_class
        if self.target_class and (
            node.name == self.target_class
            or ".".join(self._class_stack) == self.target_class
        ):
            self.in_target_class = True
        try:
            self.generic_visit(node)
        finally:
            # Restore (rather than clear) the flag so that leaving a nested
            # class does not end the enclosing target class prematurely.
            self.in_target_class = previous
            self._class_stack.pop()

    def visit_FunctionDef(self, node):
        """Process the body of the target function; keep descending otherwise."""
        is_target = node.name == self.target_method and (
            self.target_class is None or self.in_target_class
        )
        if is_target:
            self.in_target_method = True
            self.process_body(node.body)
            self.in_target_method = False
        else:
            self.generic_visit(node)

    # Async functions are matched and processed exactly like sync ones.
    visit_AsyncFunctionDef = visit_FunctionDef

    def _find_keyword_line(self, keyword, first_stmt_line, after_line=None):
        """
        Scan backwards from ``first_stmt_line`` for a line starting with
        ``keyword`` followed by a colon.

        The statement's own line is included so one-line clauses such as
        ``else: x = 1`` resolve to themselves. When ``after_line`` is given the
        scan stops just after it; otherwise a fixed look-back window is used.
        Returns ``first_stmt_line`` when the keyword is not found.
        """
        if after_line is None:
            lowest = first_stmt_line - self._KEYWORD_LOOKBACK
        else:
            lowest = after_line + 1
        lowest = max(1, lowest)
        highest = min(first_stmt_line, len(self.source_lines))
        for line_no in range(highest, lowest - 1, -1):
            text = self.source_lines[line_no - 1].strip()
            if text.startswith(keyword) and text[len(keyword):].lstrip().startswith(":"):
                return line_no
        # Fallback to the first statement line if we can't find it
        return first_stmt_line

    def find_else_line(self, first_else_stmt_line, after_line=None):
        """Find the 'else:' keyword line by scanning backwards from first statement"""
        return self._find_keyword_line("else", first_else_stmt_line, after_line)

    def find_finally_line(self, first_finally_stmt_line, after_line=None):
        """Find the 'finally:' keyword line by scanning backwards from first statement"""
        return self._find_keyword_line("finally", first_finally_stmt_line, after_line)

    @staticmethod
    def _last_line(stmts):
        """Return the last source line of a statement list, or None if unknown."""
        return getattr(stmts[-1], "end_lineno", None)

    def _is_elif(self, orelse):
        """
        Tell a real ``elif`` apart from an ``else:`` block that merely starts
        with an ``if``; both appear in the AST as an ``If`` in ``orelse``.
        """
        if len(orelse) != 1 or not isinstance(orelse[0], ast.If):
            return False
        lineno = orelse[0].lineno
        if not 1 <= lineno <= len(self.source_lines):
            # No source text to inspect: fall back to treating it as elif.
            return True
        return self.source_lines[lineno - 1].lstrip().startswith("elif")

    def _process_else(self, orelse, preceding):
        """Mark the ``else:`` line that follows ``preceding`` and process its body."""
        else_line = self.find_else_line(
            orelse[0].lineno, after_line=self._last_line(preceding)
        )
        self.add_fork(else_line)
        self.process_body(orelse)

    def _process_if(self, stmt):
        """Mark an ``if`` statement and its elif/else branches."""
        self.add_fork(stmt.lineno)
        self.process_body(stmt.body)
        if not stmt.orelse:
            return
        if self._is_elif(stmt.orelse):
            # elif - the nested If marks its own line via recursion
            self.process_body(stmt.orelse)
        else:
            self._process_else(stmt.orelse, stmt.body)

    def _process_try(self, stmt):
        """Mark a ``try`` statement and its except/else/finally clauses."""
        self.add_fork(stmt.lineno)
        self.process_body(stmt.body)
        preceding = stmt.body
        for handler in stmt.handlers:
            self.add_fork(handler.lineno)
            self.process_body(handler.body)
            preceding = handler.body
        if stmt.orelse:
            self._process_else(stmt.orelse, preceding)
            preceding = stmt.orelse
        if stmt.finalbody:
            finally_line = self.find_finally_line(
                stmt.finalbody[0].lineno, after_line=self._last_line(preceding)
            )
            self.add_fork(finally_line)
            self.process_body(stmt.finalbody)

    def _process_loop(self, stmt):
        """Mark a ``for`` / ``async for`` / ``while`` loop and its else clause."""
        self.add_fork(stmt.lineno)
        self.process_body(stmt.body)
        if stmt.orelse:
            self._process_else(stmt.orelse, stmt.body)

    def _process_match(self, stmt):
        """Mark a ``match`` statement and each of its ``case`` clauses."""
        self.add_fork(stmt.lineno)
        for match_case in stmt.cases:
            # match_case nodes carry no position; the pattern's line is used.
            self.add_fork(match_case.pattern.lineno)
            self.process_body(match_case.body)

    def process_body(self, body):
        """Recursively process statement bodies"""
        for i, stmt in enumerate(body):
            if isinstance(stmt, (ast.Return, ast.Raise)):
                self.add_fork(stmt.lineno)
                continue
            if isinstance(stmt, self._WITH_NODES):
                # Not a branch itself, but forks may be nested inside.
                self.process_body(stmt.body)
                continue

            if isinstance(stmt, ast.If):
                self._process_if(stmt)
            elif isinstance(stmt, self._TRY_NODES):
                self._process_try(stmt)
            elif isinstance(stmt, self._LOOP_NODES):
                self._process_loop(stmt)
            elif self._MATCH_NODE is not None and isinstance(stmt, self._MATCH_NODE):
                self._process_match(stmt)
            else:
                continue

            # Mark sibling after a control block (the "regardless" concept)
            if i + 1 < len(body):
                self.add_fork(body[i + 1].lineno)
def highlight_line(line_text, base_style=None):
    """
    Return one line of Python code as a syntax-highlighted rich ``Text``.

    The line is highlighted with rich's ``Syntax`` (monokai theme), trailing
    newline characters are dropped, and every ``Style`` span is rebuilt from
    only its color, bold, italic and underline attributes, so any background
    color carried by the theme is left out.

    Args:
        line_text: The source line to highlight.
        base_style: Optional style (e.g. "dim") applied over the whole line
            after the syntax spans.
    """
    rendered = Syntax(
        line_text,
        "python",
        theme="monokai",
        line_numbers=False,
        word_wrap=False,
        background_color="default",  # Don't use theme background
    ).highlight(line_text)

    visible = rendered.plain.rstrip("\n\r")
    limit = len(visible)
    result = Text(visible)

    for span in rendered.spans:
        # Skip spans that lie entirely in the stripped trailing newline part.
        if span.end > limit and span.start >= limit:
            continue
        style = span.style
        if isinstance(style, Style):
            # Keep only foreground attributes; this drops the background.
            style = Style(
                color=style.color,
                bold=style.bold,
                italic=style.italic,
                underline=style.underline,
            )
        # Spans running past the visible text are truncated to its end.
        result.stylize(style, span.start, min(span.end, limit))

    if base_style:
        # Layer the dim/dark base style over the complete line.
        result.stylize(base_style, 0, len(result))
    return result
def display_fork(source_lines, line_no, all_fork_lines):
    """
    Display a fork with gradient-colored context lines.

    Prints the fork line at full brightness, then up to two following lines in
    progressively dimmer styles. Context output stops early when the next line
    is itself a fork, because that fork is displayed separately. If a further
    non-fork line exists after the context, an indented "..." row is printed.

    Args:
        source_lines: All lines of the analyzed file (index 0 is line 1).
        line_no: 1-based line number of the fork to display.
        all_fork_lines: Collection of every fork line number in the function.
    """
    total = len(source_lines)
    # Reject out-of-range numbers; a non-positive one would otherwise index
    # from the end of the list.
    if not 1 <= line_no <= total:
        return

    # (offset from the fork line, line-number style, code base style)
    rows = (
        (0, "bold white", None),
        (1, "dim white", "dim"),
        (2, "dim", "rgb(100,100,100)"),
    )
    for offset, number_style, code_style in rows:
        current = line_no + offset
        if current > total:
            return  # Ran out of source; nothing left to show
        if offset and current in all_fork_lines:
            return  # Stop here, that fork will be displayed separately
        row = Text()
        row.append(f"{current:4d}", style=number_style)
        row.append(" │ ")
        row.append_text(highlight_line(source_lines[current - 1], base_style=code_style))
        console.print(row)

    # Ellipsis if more lines exist (and the next one is not a fork)
    following = line_no + 3
    if following <= total and following not in all_fork_lines:
        next_text = source_lines[following - 1]
        # Indent the ellipsis like the line it stands in for.
        indent_spaces = " " * (len(next_text) - len(next_text.lstrip()))
        # Four-space gutter keeps the separator aligned with numbered rows.
        console.print(
            f"[dim]    [/dim] │ [rgb(100,100,100)]{indent_spaces}...[/rgb(100,100,100)]"
        )
def parse_symbol(symbol):
    """
    Parse a symbol into class and method components.

    The symbol is split on its last dot. Surrounding whitespace is ignored,
    and an empty class part (a symbol such as ".func") is reported as None so
    that it is treated like a plain function name.

    Returns:
        tuple: (class_name, method_name) or (None, function_name)

    Examples:
        "whisper_transcribe" -> (None, "whisper_transcribe")
        "InitiateTranscriptionPreProcessor._process" -> ("InitiateTranscriptionPreProcessor", "_process")
        "Outer.Inner.method" -> ("Outer.Inner", "method")
    """
    symbol = symbol.strip()
    class_name, dot, method_name = symbol.rpartition(".")
    if not dot:
        # Module-level function
        return None, symbol
    # Class method: ClassName.method_name
    return (class_name or None), method_name
def main():
    """
    Command-line entry point.

    Expects ``sys.argv`` to hold a file path and a symbol (``function`` or
    ``Class.method``), analyzes that function's control flow and prints each
    branching point with a few lines of context. Exits with status 1 on bad
    arguments, an unreadable or unparsable file, or when nothing is found.
    """
    # Local import: user-supplied text must be escaped before it is embedded
    # in rich markup, otherwise "[" in a path or symbol is read as a tag.
    from rich.markup import escape

    # Parse command-line arguments
    if len(sys.argv) != 3:
        console.print("[red]Error: Expected exactly two arguments: <file_path> <symbol>[/red]")
        console.print("\n[bold]Usage:[/bold]")
        console.print(" python analyze_final.py <file_path> <symbol>")
        console.print("\n[bold]Examples:[/bold]")
        console.print(" python analyze_final.py utils/asr/transcription_functions.py whisper_transcribe")
        console.print(" python analyze_final.py preprocessor/transcribe/initiate_transcription.py InitiateTranscriptionPreProcessor._process")
        sys.exit(1)

    file_path = Path(sys.argv[1])
    symbol = sys.argv[2]
    shown_path = escape(str(file_path))

    # Check if file exists
    if not file_path.exists():
        console.print(f"[red]Error: File not found: {shown_path}[/red]")
        sys.exit(1)

    # Parse the symbol
    target_class, target_method = parse_symbol(symbol)

    # Read and parse source. UTF-8 is Python's default source encoding; the
    # locale default used by a bare read_text() is not reliable.
    try:
        source = file_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        console.print(f"[red]Error: Could not read {shown_path}: {escape(str(exc))}[/red]")
        sys.exit(1)
    source_lines = source.splitlines()
    try:
        tree = ast.parse(source, filename=str(file_path))
    except (SyntaxError, ValueError) as exc:
        console.print(f"[red]Error: Could not parse {shown_path}: {escape(str(exc))}[/red]")
        sys.exit(1)

    # Analyze control flow
    analyzer = ForkAnalyzer(source_lines, target_class=target_class, target_method=target_method)
    analyzer.visit(tree)
    sorted_lines = sorted(analyzer.fork_lines)

    # An empty result means the target is missing OR has no forks; the
    # message covers both because the two cases are indistinguishable here.
    if not sorted_lines:
        if target_class:
            console.print(
                f"[red]Error: Could not find method '{escape(str(target_method))}' "
                f"in class '{escape(str(target_class))}' (or it has no branching points)[/red]"
            )
        else:
            console.print(
                f"[red]Error: Could not find function '{escape(str(target_method))}' "
                f"(or it has no branching points)[/red]"
            )
        sys.exit(1)

    # Build display name
    if target_class:
        display_name = f"{target_class}.{target_method}()"
    else:
        display_name = f"{target_method}()"

    # Display header
    console.print()
    console.print(Panel(
        f"[bold cyan]Control Flow Analysis[/bold cyan]\n"
        f"File: [blue]{shown_path}[/blue]\n"
        f"Function: [yellow]{escape(display_name)}[/yellow]\n"
        f"Branching points found: [green]{len(sorted_lines)}[/green]",
        border_style="cyan"
    ))
    console.print()

    # Display each fork
    fork_line_set = set(sorted_lines)  # Convert to set for O(1) lookup
    for line_no in sorted_lines:
        display_fork(source_lines, line_no, fork_line_set)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment