Last active
January 17, 2026 05:53
-
-
Save laohyx/c73609f55e495efaeec3304b07eb3623 to your computer and use it in GitHub Desktop.
Reads .safetensors files and prints tensor names and shapes. Best suited for HuggingFace LLM models.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Script to read LLM LoRA files or model directories containing safetensors files. | |
| Displays total parameters and layer-wise shapes with aggregation. | |
| """ | |
import argparse
import math
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

from rich import box
from rich.console import Console
from rich.table import Table
from rich.tree import Tree
from safetensors import safe_open
def read_safetensors_file(file_path: str) -> Dict[str, Tuple]:
    """Map every tensor name in one safetensors file to its shape tuple.

    Only metadata is read (via ``get_slice``); tensor data is never loaded
    into memory.
    """
    with safe_open(file_path, framework="numpy", device="cpu") as handle:
        # get_slice exposes shape metadata without materializing the tensor.
        return {
            name: tuple(handle.get_slice(name).get_shape())
            for name in handle.keys()
        }
def read_safetensors_directory(dir_path: str) -> Dict[str, Tuple]:
    """Merge name->shape maps from every *.safetensors file in a directory.

    Files are processed in sorted order so later shards overwrite earlier
    ones on duplicate keys, deterministically.

    Raises:
        ValueError: if the directory contains no .safetensors files.
    """
    shard_paths = sorted(Path(dir_path).glob("*.safetensors"))
    if not shard_paths:
        raise ValueError(f"No .safetensors files found in {dir_path}")
    merged: Dict[str, Tuple] = {}
    for shard in shard_paths:
        merged.update(read_safetensors_file(str(shard)))
    return merged
def extract_layer_number(key: str) -> Tuple[str, int]:
    """Split a tensor key into a layer-agnostic pattern and its layer index.

    The numeric layer segment (e.g. the ``0`` in ``model.layers.0.q_proj``)
    is replaced by the literal placeholder ``[NUM]``.

    Args:
        key: Fully qualified tensor name.

    Returns:
        ``(base_pattern, layer_number)`` where ``base_pattern`` contains
        ``[NUM]`` in place of the layer index, or ``(key, -1)`` when no
        layer number is recognized.
    """
    # Try well-known container names first (model.layers.0.xxx, blocks.12.xxx,
    # h.5.xxx, ...), then fall back to any ".<digits>." segment.
    patterns = [
        r'(.*?\.(?:layers|blocks|layer|h)\.)(\d+)(\..+)',
        r'(.*?\.)(\d+)(\..+)',
    ]
    for pattern in patterns:
        match = re.search(pattern, key)
        if match:
            # Capture the digits directly instead of re-scanning a slice of
            # the key with a second regex (the previous implementation's
            # extra search was redundant and fragile).
            prefix, number, suffix = match.groups()
            return f"{prefix}[NUM]{suffix}", int(number)
    return key, -1
def _compress_ranges(layers: List[int]) -> str:
    """Collapse sorted layer indices into a compact range string.

    Example: [0, 1, 2, 5] -> "0-2,5".
    """
    pieces = []
    start = end = layers[0]
    for num in layers[1:]:
        if num == end + 1:
            end = num
            continue
        # Flush the finished run, then start a new one.
        pieces.append(str(start) if start == end else f"{start}-{end}")
        start = end = num
    pieces.append(str(start) if start == end else f"{start}-{end}")
    return ",".join(pieces)


def aggregate_layers(tensors: Dict[str, Tuple]) -> Dict[str, Dict]:
    """
    Aggregate tensors that differ only in their layer index.

    Keys sharing a layer pattern (per extract_layer_number) are merged into
    one entry whose key embeds the layer ranges, e.g.
    ``model.layers.[0-31].mlp.weight``.

    Args:
        tensors: Mapping of tensor name -> shape tuple.

    Returns:
        Mapping of (possibly aggregated) key -> {"shape": tuple, "count": int}.

    NOTE(review): if layers sharing a pattern disagree on shape, the last
    one seen wins silently.
    """
    pattern_groups = defaultdict(lambda: {"shape": None, "layers": set(), "example_key": None})
    ungrouped = {}
    for key, shape in tensors.items():
        base_pattern, layer_num = extract_layer_number(key)
        if layer_num < 0:
            ungrouped[key] = shape
            continue
        group = pattern_groups[base_pattern]
        group["shape"] = shape
        group["layers"].add(layer_num)
        if group["example_key"] is None:
            group["example_key"] = key

    aggregated = {}
    for pattern, info in sorted(pattern_groups.items()):
        layers = sorted(info["layers"])
        if len(layers) > 1:
            formatted_key = pattern.replace("[NUM]", f"[{_compress_ranges(layers)}]")
            aggregated[formatted_key] = {
                "shape": info["shape"],
                "count": len(layers),
            }
        else:
            # A pattern seen for a single layer keeps its original key.
            aggregated[info["example_key"]] = {
                "shape": info["shape"],
                "count": 1,
            }
    # Tensors without a layer number pass through unchanged.
    for key, shape in sorted(ungrouped.items()):
        aggregated[key] = {"shape": shape, "count": 1}
    return aggregated
def calculate_params(shape: Tuple) -> int:
    """Return the element count for a tensor shape (product of dimensions).

    An empty or falsy shape yields 0, matching this script's convention for
    metadata-only entries (note: a true 0-d scalar would mathematically
    have 1 element).
    """
    if not shape:
        return 0
    # math.prod replaces the hand-rolled accumulation loop.
    return math.prod(shape)
def build_hierarchy_tree(aggregated: Dict[str, Dict]) -> Dict:
    """Turn dotted tensor keys into a nested dict; leaves carry shape/count."""
    root: Dict = {}
    for dotted_key, info in aggregated.items():
        segments = dotted_key.split('.')
        node = root
        last = len(segments) - 1
        for idx, segment in enumerate(segments):
            if idx == last:
                # Terminal segment: store the tensor metadata as a leaf.
                node[segment] = {
                    "_leaf": True,
                    "shape": info["shape"],
                    "count": info["count"],
                }
                continue
            existing = node.get(segment)
            if existing is None:
                node[segment] = existing = {}
            elif isinstance(existing, dict) and existing.get("_leaf"):
                # A leaf already occupies this name; do not descend — the
                # remaining segments attach at the current level instead.
                continue
            node = existing
    return root
def count_non_leaf_children(tree: Dict) -> int:
    """Return how many direct children of *tree* are branch (non-leaf) nodes."""
    if not isinstance(tree, dict):
        return 0
    return sum(
        1
        for child in tree.values()
        if not (isinstance(child, dict) and child.get("_leaf"))
    )
def merge_single_child_path(tree: Dict, current_path: str = "") -> Tuple[str, Dict]:
    """Collapse chains of single-branch nodes into one dotted path.

    Walks downward while a node has exactly one non-leaf child, accumulating
    the dotted path as it goes.

    Returns:
        (merged_path, node) where *node* is the first node with zero or
        multiple branch children (or a leaf / non-dict value).
    """
    if not isinstance(tree, dict) or tree.get("_leaf"):
        return current_path, tree
    # Collect the branch children; leaves never extend the merged path.
    branches = [
        (name, child)
        for name, child in tree.items()
        if not (isinstance(child, dict) and child.get("_leaf"))
    ]
    if len(branches) != 1:
        return current_path, tree
    name, child = branches[0]
    extended = f"{current_path}.{name}" if current_path else name
    return merge_single_child_path(child, extended)
def _print_leaf(name: str, leaf: Dict, console: Console, indent: int):
    """Render one leaf line: tensor name, shape, optional layer count, params."""
    shape_str = "×".join(map(str, leaf["shape"]))
    param_count = calculate_params(leaf["shape"]) * leaf["count"]
    count_str = f" (×{leaf['count']} layers)" if leaf["count"] > 1 else ""
    console.print(
        " " * indent + f"[cyan]{name}[/cyan]: "
        f"[yellow]{shape_str}[/yellow]{count_str} "
        f"[dim]({param_count:,} params)[/dim]"
    )


def print_tree_recursive(tree: Dict, parent_key: str, console: Console, indent: int = 0):
    """Recursively print the tree with rich formatting, merging single-child paths.

    Args:
        tree: Nested dict produced by build_hierarchy_tree.
        parent_key: Dotted path of ancestor nodes (threaded through recursion;
            carried for context, not printed directly).
        console: Rich console used for styled output.
        indent: Current indentation depth.
    """
    for key in sorted(tree.keys()):
        value = tree[key]
        if isinstance(value, dict) and value.get("_leaf"):
            _print_leaf(key, value, console, indent)
            continue
        # Branch: collapse runs of single-branch nodes into one printed path.
        merged_path, final_node = merge_single_child_path(value, key)
        console.print(" " * indent + f"[bold green]{merged_path}[/bold green]")
        for child_key in sorted(final_node.keys()):
            child_value = final_node[child_key]
            if isinstance(child_value, dict) and child_value.get("_leaf"):
                _print_leaf(child_key, child_value, console, indent + 1)
            else:
                # Deeper branch: recurse with the merged path as the parent.
                print_tree_recursive({child_key: child_value}, merged_path, console, indent + 1)
def main():
    """CLI entry point: parse arguments, read tensors, and print the report."""
    parser = argparse.ArgumentParser(
        description="Read and analyze safetensors files (LoRA or model files)"
    )
    parser.add_argument(
        "path",
        type=str,
        help="Path to a safetensors file or directory containing safetensors files"
    )
    parser.add_argument(
        "--no-aggregate",
        action="store_true",
        help="Don't aggregate layers, show all tensors individually"
    )
    args = parser.parse_args()

    console = Console()

    # Guard clause: bail out early on a missing path.
    path = Path(args.path)
    if not path.exists():
        console.print(f"[red]Error: Path {args.path} does not exist[/red]")
        return

    console.print(f"\n[bold]Reading safetensors from:[/bold] {args.path}\n")
    try:
        reader = read_safetensors_file if path.is_file() else read_safetensors_directory
        tensors = reader(str(path))
    except Exception as e:
        console.print(f"[red]Error reading safetensors: {e}[/red]")
        return

    # Headline statistics.
    total_params = sum(calculate_params(shape) for shape in tensors.values())
    console.print(f"[bold green]Total tensors:[/bold green] {len(tensors):,}")
    console.print(f"[bold green]Total parameters:[/bold green] {total_params:,}")
    console.print(f"[bold green]Total parameters (M):[/bold green] {total_params / 1_000_000:.2f}M")
    console.print(f"[bold green]Total parameters (B):[/bold green] {total_params / 1_000_000_000:.4f}B\n")

    if args.no_aggregate:
        console.print("[bold]All Tensors (No Aggregation):[/bold]\n")
        for key, shape in sorted(tensors.items()):
            dims = "×".join(map(str, shape))
            n_params = calculate_params(shape)
            console.print(f"[cyan]{key}[/cyan]: [yellow]{dims}[/yellow] [dim]({n_params:,} params)[/dim]")
    else:
        console.print("[bold]Layer Structure (Aggregated):[/bold]\n")
        hierarchy = build_hierarchy_tree(aggregate_layers(tensors))
        print_tree_recursive(hierarchy, "", console)
    console.print()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment