@laohyx · Last active January 17, 2026
Read .safetensors files and print layer names and shapes. Best for HuggingFace LLM models.
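
Requires Python 3 with the safetensors and rich packages:

    pip install safetensors rich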
#!/usr/bin/env python3
"""
Script to read LLM LoRA files or model directories containing safetensors files.
Displays total parameters and layer-wise shapes with aggregation.
"""
import argparse
from pathlib import Path
from collections import defaultdict
import re
from typing import Dict, Tuple
from safetensors import safe_open
from rich.console import Console


def read_safetensors_file(file_path: str) -> Dict[str, Tuple]:
    """Read a single safetensors file and return tensor names and shapes."""
    tensors = {}
    with safe_open(file_path, framework="numpy", device="cpu") as f:
        for key in f.keys():
            # Get tensor metadata without loading the full tensor data
            tensor_slice = f.get_slice(key)
            tensors[key] = tuple(tensor_slice.get_shape())
    return tensors


def read_safetensors_directory(dir_path: str) -> Dict[str, Tuple]:
    """Read all safetensors files in a directory."""
    all_tensors = {}
    safetensors_files = sorted(Path(dir_path).glob("*.safetensors"))
    if not safetensors_files:
        raise ValueError(f"No .safetensors files found in {dir_path}")
    for file_path in safetensors_files:
        tensors = read_safetensors_file(str(file_path))
        all_tensors.update(tensors)
    return all_tensors


def extract_layer_number(key: str) -> Tuple[str, int]:
    """
    Extract layer number from tensor key.
    Returns (base_pattern, layer_number) or (key, -1) if no layer number found.
    """
    # Match patterns like: model.layers.0.xxx, blocks.12.xxx, layer.5.xxx, etc.
    patterns = [
        r'(.*?\.(?:layers|blocks|layer|h)\.)\d+(\..+)',
        r'(.*?\.)\d+(\..+)',
    ]
    for pattern in patterns:
        match = re.search(pattern, key)
        if match:
            # Re-scan from just before the prefix's trailing dot so the
            # r'\.(\d+)\.' pattern can anchor on it and extract the layer index
            number_match = re.search(r'\.(\d+)\.', key[len(match.group(1)) - 1:])
            if number_match:
                layer_num = int(number_match.group(1))
                prefix = match.group(1)
                suffix = match.group(2)
                base_pattern = f"{prefix}[NUM]{suffix}"
                return base_pattern, layer_num
    return key, -1
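
# For example (key name borrowed from Llama-style checkpoints):
#   extract_layer_number("model.layers.12.self_attn.q_proj.weight")
#   -> ("model.layers.[NUM].self_attn.q_proj.weight", 12)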


def aggregate_layers(tensors: Dict[str, Tuple]) -> Dict[str, Dict]:
    """
    Aggregate tensors by layer patterns.
    Returns dict with aggregated info including shape and layer ranges.
    """
    # Group by base pattern
    pattern_groups = defaultdict(lambda: {"shape": None, "layers": set(), "example_key": None})
    ungrouped = {}
    for key, shape in tensors.items():
        base_pattern, layer_num = extract_layer_number(key)
        if layer_num >= 0:
            pattern_groups[base_pattern]["shape"] = shape
            pattern_groups[base_pattern]["layers"].add(layer_num)
            if pattern_groups[base_pattern]["example_key"] is None:
                pattern_groups[base_pattern]["example_key"] = key
        else:
            ungrouped[key] = shape
    # Format aggregated results
    aggregated = {}
    # Add grouped layers with layer ranges
    for pattern, info in sorted(pattern_groups.items()):
        layers = sorted(info["layers"])
        if len(layers) > 1:
            # Compress consecutive layer numbers into ranges, e.g. [0, 1, 2, 5] -> "0-2,5"
            layer_ranges = []
            start = layers[0]
            end = layers[0]
            for i in range(1, len(layers)):
                if layers[i] == end + 1:
                    end = layers[i]
                else:
                    if start == end:
                        layer_ranges.append(str(start))
                    else:
                        layer_ranges.append(f"{start}-{end}")
                    start = layers[i]
                    end = layers[i]
            if start == end:
                layer_ranges.append(str(start))
            else:
                layer_ranges.append(f"{start}-{end}")
            formatted_key = pattern.replace("[NUM]", f"[{','.join(layer_ranges)}]")
            aggregated[formatted_key] = {
                "shape": info["shape"],
                "count": len(layers)
            }
        else:
            # Single layer, use original key
            aggregated[info["example_key"]] = {
                "shape": info["shape"],
                "count": 1
            }
    # Add ungrouped tensors
    for key, shape in sorted(ungrouped.items()):
        aggregated[key] = {
            "shape": shape,
            "count": 1
        }
    return aggregated
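
# Illustration: for a hypothetical 32-layer model, the 32 keys
#   "model.layers.0.mlp.gate_proj.weight" ... "model.layers.31.mlp.gate_proj.weight"
# collapse into one entry "model.layers.[0-31].mlp.gate_proj.weight" with count=32.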


def calculate_params(shape: Tuple) -> int:
    """Calculate number of parameters from shape."""
    if not shape:
        return 0
    result = 1
    for dim in shape:
        result *= dim
    return result
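
# e.g. calculate_params((4096, 11008)) == 45_088_768, the product of all dimensions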


def build_hierarchy_tree(aggregated: Dict[str, Dict]) -> Dict:
    """Build a hierarchical tree structure from tensor keys."""
    tree = {}
    for key, info in aggregated.items():
        parts = key.split('.')
        current = tree
        for i, part in enumerate(parts):
            if i == len(parts) - 1:
                # Leaf node
                current[part] = {
                    "_leaf": True,
                    "shape": info["shape"],
                    "count": info["count"]
                }
            else:
                # Branch node
                if part not in current:
                    current[part] = {}
                elif isinstance(current[part], dict) and current[part].get("_leaf"):
                    # Handle conflict: already a leaf
                    continue
                current = current[part]
    return tree


def count_non_leaf_children(tree: Dict) -> int:
    """Count non-leaf children in a tree node."""
    if not isinstance(tree, dict):
        return 0
    count = 0
    for key, value in tree.items():
        if not (isinstance(value, dict) and value.get("_leaf")):
            count += 1
    return count


def merge_single_child_path(tree: Dict, current_path: str = "") -> Tuple[str, Dict]:
    """
    Merge single-child nodes into a single path.
    Returns (merged_path, final_node).
    """
    if not isinstance(tree, dict) or tree.get("_leaf"):
        return current_path, tree
    # Count non-leaf children
    non_leaf_count = count_non_leaf_children(tree)
    # Merge only when the branch is the node's sole child; if the node also
    # has leaf children, merging past them would hide them from the printed tree
    if non_leaf_count == 1 and len(tree) == 1:
        for key, value in tree.items():
            if not (isinstance(value, dict) and value.get("_leaf")):
                new_path = f"{current_path}.{key}" if current_path else key
                return merge_single_child_path(value, new_path)
    return current_path, tree
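
# Illustration (hypothetical tree): {"layers": {"[0-31]": {"mlp": ..., "self_attn": ...}}}
# passed with current_path="model" merges the single-child chain into the path
# "model.layers.[0-31]" and returns it with the node holding "mlp" and "self_attn".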


def print_tree_recursive(tree: Dict, parent_key: str, console: Console, indent: int = 0):
    """Recursively print the tree structure with rich formatting, merging single-child paths."""
    for key in sorted(tree.keys()):
        value = tree[key]
        full_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict) and value.get("_leaf"):
            # Leaf node
            shape_str = "×".join(map(str, value["shape"]))
            param_count = calculate_params(value["shape"]) * value["count"]
            count_str = f" (×{value['count']} layers)" if value["count"] > 1 else ""
            console.print(
                " " * indent + f"[cyan]{key}[/cyan]: "
                f"[yellow]{shape_str}[/yellow]{count_str} "
                f"[dim]({param_count:,} params)[/dim]"
            )
        else:
            # Branch node - check if we can merge single-child paths
            merged_path, final_node = merge_single_child_path(value, key)
            # Print the merged path
            console.print(" " * indent + f"[bold green]{merged_path}[/bold green]")
            # Print children of the final node
            for child_key in sorted(final_node.keys()):
                child_value = final_node[child_key]
                if isinstance(child_value, dict) and child_value.get("_leaf"):
                    # Leaf node
                    shape_str = "×".join(map(str, child_value["shape"]))
                    param_count = calculate_params(child_value["shape"]) * child_value["count"]
                    count_str = f" (×{child_value['count']} layers)" if child_value["count"] > 1 else ""
                    console.print(
                        " " * (indent + 1) + f"[cyan]{child_key}[/cyan]: "
                        f"[yellow]{shape_str}[/yellow]{count_str} "
                        f"[dim]({param_count:,} params)[/dim]"
                    )
                else:
                    # Recursive call for deeper branches
                    print_tree_recursive({child_key: child_value}, merged_path, console, indent + 1)


def main():
    parser = argparse.ArgumentParser(
        description="Read and analyze safetensors files (LoRA or model files)"
    )
    parser.add_argument(
        "path",
        type=str,
        help="Path to a safetensors file or directory containing safetensors files"
    )
    parser.add_argument(
        "--no-aggregate",
        action="store_true",
        help="Don't aggregate layers, show all tensors individually"
    )
    args = parser.parse_args()
    console = Console()
    # Check if path exists
    path = Path(args.path)
    if not path.exists():
        console.print(f"[red]Error: Path {args.path} does not exist[/red]")
        return
    # Read tensors
    console.print(f"\n[bold]Reading safetensors from:[/bold] {args.path}\n")
    try:
        if path.is_file():
            tensors = read_safetensors_file(str(path))
        else:
            tensors = read_safetensors_directory(str(path))
    except Exception as e:
        console.print(f"[red]Error reading safetensors: {e}[/red]")
        return
    # Calculate total parameters
    total_params = sum(calculate_params(shape) for shape in tensors.values())
    console.print(f"[bold green]Total tensors:[/bold green] {len(tensors):,}")
    console.print(f"[bold green]Total parameters:[/bold green] {total_params:,}")
    console.print(f"[bold green]Total parameters (M):[/bold green] {total_params / 1_000_000:.2f}M")
    console.print(f"[bold green]Total parameters (B):[/bold green] {total_params / 1_000_000_000:.4f}B\n")
    # Aggregate or show raw
    if args.no_aggregate:
        console.print("[bold]All Tensors (No Aggregation):[/bold]\n")
        for key, shape in sorted(tensors.items()):
            shape_str = "×".join(map(str, shape))
            param_count = calculate_params(shape)
            console.print(f"[cyan]{key}[/cyan]: [yellow]{shape_str}[/yellow] [dim]({param_count:,} params)[/dim]")
    else:
        console.print("[bold]Layer Structure (Aggregated):[/bold]\n")
        aggregated = aggregate_layers(tensors)
        # Build and print hierarchy
        hierarchy = build_hierarchy_tree(aggregated)
        print_tree_recursive(hierarchy, "", console)
    console.print()


if __name__ == "__main__":
    main()
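
Example usage (a minimal sketch; the file name read_safetensors.py and the model paths are hypothetical):

    python read_safetensors.py ./Meta-Llama-3-8B                      # model directory with *.safetensors shards
    python read_safetensors.py adapter_model.safetensors              # single LoRA file
    python read_safetensors.py adapter_model.safetensors --no-aggregate

The default output aggregates repeated per-layer tensors into ranges such as layers.[0-31]; pass --no-aggregate to list every tensor individually.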