Skip to content

Instantly share code, notes, and snippets.

@dokato
Created November 18, 2025 22:30
Show Gist options
  • Select an option

  • Save dokato/dc61a07ef5156c12a56d5bd5e0fec652 to your computer and use it in GitHub Desktop.

Select an option

Save dokato/dc61a07ef5156c12a56d5bd5e0fec652 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Serialize a Python AST and save it as a pickle.

Usage examples
- From a file:  $ python serialize_python_ast.py --input path/to/code.py --out ast.pkl
- From stdin:   $ cat code.py | python serialize_python_ast.py --out ast.pkl

Output format (pickle)
- Pickled Python `ast.AST` object: the root node returned by `ast.parse`
  in its default 'exec' mode, i.e. an `ast.Module`.
- Written with `pickle.HIGHEST_PROTOCOL` of the Python that created it, so
  an older interpreter may be unable to read the file.
- Open with:
      with open('ast.pkl', 'rb') as f:
          tree = pickle.load(f)
- Only unpickle files you trust: loading a pickle can execute arbitrary code.
"""
from __future__ import annotations
import argparse
import ast
import io
import pickle
import json
import sys
from pathlib import Path
from typing import Iterable, List, Tuple, Union
Primitive = Union[str, int, float, bool, None]

# Runtime classes that count as "primitive".  ``type(None)`` is NoneType, whose
# only instance is ``None``, so matching it is the same as ``value is None``.
_PRIMITIVE_TYPES = (str, int, float, bool, type(None))


def is_primitive(value: object) -> bool:
    """Return True when *value* is a JSON-style scalar: str, int, float, bool or None.

    ``bool`` is listed explicitly for readability even though it is already a
    subclass of ``int``.  Containers, bytes, complex numbers and AST nodes are
    not considered primitive.
    """
    return isinstance(value, _PRIMITIVE_TYPES)
def attr_value_repr(value: object) -> str:
    """Render an attribute value as a single JSON token for a serialized line.

    - Primitive values (str, int, float, bool, None) are JSON-encoded directly,
      e.g. ``"foo"``, ``1``, ``true``, ``null``.
    - AST nodes (e.g. ``Load``/``Store`` contexts) render as their class name
      as a JSON string, e.g. ``"Load"``.
    - Any other object (lists, bytes, complex, Ellipsis, ...) renders as the
      JSON string of ``repr(value)``, e.g. ``"b'x'"``.

    Every branch uses ``ensure_ascii=False``, so non-ASCII text (for example an
    identifier inside a rendered list) is written verbatim in all cases rather
    than being escaped in some branches and not in others.
    """
    if is_primitive(value):
        return json.dumps(value, ensure_ascii=False)
    if isinstance(value, ast.AST):
        return json.dumps(value.__class__.__name__, ensure_ascii=False)
    # Fallback: quote the repr so the value stays a single token on the line.
    return json.dumps(repr(value), ensure_ascii=False)
def iter_children_with_edges(node: ast.AST) -> Iterable[Tuple[str, int, ast.AST]]:
    """Generate ``(field_name, index, child)`` for every direct AST child of *node*.

    Fields are visited in the node's declared field order.  A child held
    directly in a field gets ``index == -1``; a child inside a list field gets
    its position in that list.  Non-AST values (identifiers, constants, None,
    and non-AST list items) are skipped.
    """
    for name, field in ast.iter_fields(node):
        if isinstance(field, list):
            yield from (
                (name, pos, elem)
                for pos, elem in enumerate(field)
                if isinstance(elem, ast.AST)
            )
        elif isinstance(field, ast.AST):
            yield name, -1, field
def collect_node_attributes(node: ast.AST) -> List[Tuple[str, object]]:
    """Collect the scalar (non-child) attributes of *node* as ``(key, value)`` pairs.

    Pairs are returned in this order:

    1. Source-position attributes (``lineno``, ``col_offset``, ``end_lineno``,
       ``end_col_offset``) when the node has them and they are not ``None``.
    2. The node's fields in declared order, keeping:

       - expression-context nodes (``Load``/``Store``/``Del``, i.e. any
         ``ast.expr_context``) as their class-name string;
       - lists whose items are all primitives, as a shallow copy.  An empty
         list qualifies too, so empty child lists such as
         ``decorator_list=[]`` are included;
       - every other non-AST, non-list value unchanged.  This is mostly
         primitives (identifiers, ints, ``None``) but also covers
         ``ast.Constant`` values such as ``bytes``, ``complex`` and
         ``Ellipsis``, which would otherwise be silently lost.

    Real AST children (other AST nodes, and lists containing AST nodes) are
    excluded; they are reached via ``iter_children_with_edges`` instead.
    """
    attrs: List[Tuple[str, object]] = []

    # Positional metadata first, when present and set.
    for key in ("lineno", "col_offset", "end_lineno", "end_col_offset"):
        pos = getattr(node, key, None)
        if pos is not None:
            attrs.append((key, pos))

    for field_name, value in ast.iter_fields(node):
        if isinstance(value, ast.AST):
            # Context markers are tiny singleton-like nodes; show them inline as
            # an attribute.  ``ast.expr_context`` is the base of Load/Store/Del
            # and of the deprecated AugLoad/AugStore/Param, so no deprecated
            # names need to be referenced here.  Other AST values are children.
            if isinstance(value, ast.expr_context):
                attrs.append((field_name, value.__class__.__name__))
            continue
        if isinstance(value, list):
            # Only all-primitive lists are attributes (e.g. Global.names);
            # lists holding AST nodes are children.
            if all(is_primitive(item) for item in value):
                attrs.append((field_name, list(value)))
            continue
        # Primitive field, None, or a non-JSON constant (bytes, complex, ...).
        attrs.append((field_name, value))
    return attrs
def preorder_serialize(node: ast.AST, out: io.TextIOBase, *, indent: int = 0, parent_field: str | None = None, parent_index: int | None = None) -> None:
    """Write *node* and all of its descendants to *out*, one line per node, in pre-order.

    Each line consists of ``indent`` spaces, the node's class name, an optional
    ``field=NAME`` / ``field=NAME[i]`` tag naming the parent field the node came
    from, and then ``key=value`` pairs from ``collect_node_attributes``
    rendered with ``attr_value_repr``.  Children are indented two spaces more
    than their parent.  For example, the assignment target of ``x = 1`` is
    written as (leading indentation omitted)::

        Name field=targets[0] lineno=1 col_offset=0 end_lineno=1 end_col_offset=1 id="x" ctx="Store"

    Expression-context children held in a scalar field (``ctx`` =
    ``Load``/``Store``/``Del``) are already shown as the ``ctx=...``
    attribute, so they do not get a separate line.

    The tree is walked with an explicit stack instead of recursion, so deeply
    nested ASTs (e.g. long ``a + b + c + ...`` chains) do not hit Python's
    recursion limit.

    Args:
        node: Root of the (sub)tree to write.
        out: Text stream to write to; only its ``write`` method is used.
        indent: Number of leading spaces on the root's line.
        parent_field: Name of the field holding ``node`` in its parent, if any.
        parent_index: Position of ``node`` in that field when the field is a
            list; ``None`` or a negative value means a scalar field.
    """
    stack = [(node, indent, parent_field, parent_index)]
    while stack:
        current, depth, field, index = stack.pop()

        parts: List[str] = [" " * depth + current.__class__.__name__]
        if field is not None:
            if index is None or index < 0:
                parts.append(f"field={field}")
            else:
                parts.append(f"field={field}[{index}]")
        for key, value in collect_node_attributes(current):
            parts.append(f"{key}={attr_value_repr(value)}")
        out.write(" ".join(parts) + "\n")

        children = [
            (child, depth + 2, child_field, None if idx == -1 else idx)
            for child_field, idx, child in iter_children_with_edges(current)
            # A scalar ctx node was just written as an attribute; don't repeat it.
            if not (idx == -1 and isinstance(child, ast.expr_context))
        ]
        # Push in reverse so the first child is popped (and written) first,
        # preserving field order in the pre-order output.
        stack.extend(reversed(children))
def main(argv: List[str] | None = None) -> int:
    """Parse Python source and write its AST as a pickle (command-line entry point).

    Args:
        argv: Command-line arguments excluding the program name.  ``None``
            lets argparse read ``sys.argv[1:]``.

    Returns:
        Exit status: 0 on success; 1 if the input cannot be read or decoded,
        or the output cannot be written; 2 if the source fails to parse.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring is hand-formatted; keep its line breaks in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--input", type=Path, default=None, help="Path to input Python source. If omitted, reads from stdin.")
    ap.add_argument("--out", type=Path, default=Path("ast.pkl"), help="Output file path for the pickled AST object")
    ap.add_argument("--encoding", type=str, default="utf-8", help="Text encoding of the input source, file or stdin (default utf-8)")
    args = ap.parse_args(argv)

    # Read source.  Failures here are user errors (missing file, wrong
    # encoding, unknown codec name), so report them instead of a traceback.
    try:
        if args.input is None:
            filename = "<stdin>"
            raw_stdin = getattr(sys.stdin, "buffer", None)
            if raw_stdin is not None:
                # Decode the bytes ourselves so --encoding applies to stdin too
                # instead of whatever the platform's stdin encoding happens to be.
                source = raw_stdin.read().decode(args.encoding)
            else:
                # sys.stdin may be a text-only stream (e.g. replaced by io.StringIO).
                source = sys.stdin.read()
        else:
            filename = str(args.input)
            source = args.input.read_text(encoding=args.encoding)
    except (OSError, UnicodeDecodeError, LookupError) as e:
        sys.stderr.write(f"error: cannot read input: {e}\n")
        return 1

    # Parse AST; passing filename makes error messages point at the real source.
    try:
        tree = ast.parse(source, filename=filename)
    except SyntaxError as e:
        sys.stderr.write(f"SyntaxError: {e}\n")
        return 2
    except ValueError as e:
        # Some Python versions raise ValueError (not SyntaxError) for NUL bytes.
        sys.stderr.write(f"ValueError: {e}\n")
        return 2

    # Serialize (pickle the AST root node).
    try:
        with args.out.open("wb") as f:
            pickle.dump(tree, f, protocol=pickle.HIGHEST_PROTOCOL)
    except OSError as e:
        sys.stderr.write(f"error: cannot write output: {e}\n")
        return 1
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code as the exit status.
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment