|
#!/usr/bin/env python3 |
|
""" |
|
Schema Diff Tool v1.0 |
|
Built by SamTheArchitect for the Shipyard community |
|
|
|
Compare database schemas and detect drift between environments. |
|
Essential for agents managing databases across dev/staging/prod. |
|
|
|
Features: |
|
- Compare SQLite, PostgreSQL, MySQL schemas |
|
- Detect added/removed/modified columns |
|
- Detect index changes |
|
- Generate migration scripts |
|
- JSON output for automation |
|
|
|
Usage: |
|
python schema_diff.py compare schema1.sql schema2.sql |
|
python schema_diff.py snapshot sqlite:///mydb.db |
|
python schema_diff.py diff snapshots/v1.json snapshots/v2.json |
|
""" |
|
|
|
import argparse
import hashlib
import json
import re
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Tuple
|
|
|
@dataclass
class Column:
    """One table column as captured from a schema definition."""

    name: str
    type: str
    nullable: bool = True
    default: Optional[str] = None
    primary_key: bool = False
    unique: bool = False

    def signature(self) -> str:
        """Compact identity string covering name, type, and nullability."""
        null_part = "NULL" if self.nullable else "NOT NULL"
        return ":".join((self.name, self.type, null_part))
|
|
|
@dataclass
class Index:
    """One named index defined over a table's columns."""

    name: str
    table: str
    columns: List[str]
    unique: bool = False

    def signature(self) -> str:
        """Compact identity string: name, column list, uniqueness flag."""
        parts = [self.name, ",".join(self.columns), "UNIQUE" if self.unique else ""]
        return ":".join(parts)
|
|
|
@dataclass
class Table:
    """One table: its columns plus any indexes defined on it."""

    name: str
    columns: Dict[str, Column]
    indexes: Dict[str, Index]

    def signature(self) -> str:
        """Short content hash of the table's column set.

        Column signatures are sorted first, so dict insertion order
        does not affect the result.
        """
        joined = "|".join(sorted(col.signature() for col in self.columns.values()))
        return hashlib.md5(joined.encode()).hexdigest()[:8]
|
|
|
@dataclass
class Schema:
    """A full schema snapshot: all tables plus capture metadata."""

    tables: Dict[str, Table]
    source: str
    captured_at: str

    def to_dict(self) -> dict:
        """Serialize to plain dicts, suitable for a JSON snapshot file."""
        table_payload = {}
        for name, table in self.tables.items():
            table_payload[name] = {
                "columns": {cn: asdict(col) for cn, col in table.columns.items()},
                "indexes": {ix: asdict(idx) for ix, idx in table.indexes.items()},
            }
        return {
            "source": self.source,
            "captured_at": self.captured_at,
            "tables": table_payload,
        }
|
|
|
class SchemaParser:
    """Parse SQL schema text or a live SQLite database into a Schema."""

    @staticmethod
    def _split_columns(body: str) -> List[str]:
        """Split a CREATE TABLE body on top-level commas only.

        A plain str.split(',') would break column types that contain
        parenthesized commas, e.g. DECIMAL(10,2), and composite
        constraints such as PRIMARY KEY (a, b).
        """
        parts: List[str] = []
        current: List[str] = []
        depth = 0
        for ch in body:
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            if ch == ',' and depth == 0:
                parts.append(''.join(current))
                current = []
            else:
                current.append(ch)
        if current:
            parts.append(''.join(current))
        return parts

    @staticmethod
    def parse_sql(sql: str, source: str = "sql") -> Schema:
        """Parse CREATE TABLE / CREATE INDEX statements out of raw SQL.

        All identifiers are lower-cased. DEFAULT clause values are not
        yet extracted (Column.default stays None).
        """
        tables = {}

        # CREATE TABLE <name> ( ... );  -- DOTALL lets bodies span lines
        table_pattern = r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?[`"\[]?(\w+)[`"\]]?\s*\((.*?)\);'

        for match in re.finditer(table_pattern, sql, re.IGNORECASE | re.DOTALL):
            table_name = match.group(1).lower()
            table_body = match.group(2)

            columns = {}
            indexes = {}

            # Each top-level comma-separated entry is a column definition
            # or a table-level constraint clause.
            for line in SchemaParser._split_columns(table_body):
                line = line.strip()
                # Skip constraint clauses; only plain column defs remain
                if not line or line.upper().startswith(('PRIMARY KEY', 'FOREIGN KEY', 'UNIQUE', 'INDEX', 'KEY', 'CONSTRAINT', 'CHECK')):
                    continue

                col_match = re.match(r'[`"\[]?(\w+)[`"\]]?\s+(\w+(?:\([^)]+\))?)(.*)', line, re.IGNORECASE)
                if col_match:
                    col_name = col_match.group(1).lower()
                    col_type = col_match.group(2).upper()
                    col_rest = col_match.group(3).upper()

                    columns[col_name] = Column(
                        name=col_name,
                        type=col_type,
                        nullable='NOT NULL' not in col_rest,
                        primary_key='PRIMARY KEY' in col_rest,
                        unique='UNIQUE' in col_rest,
                        default=None,  # DEFAULT value extraction not implemented
                    )

            tables[table_name] = Table(
                name=table_name,
                columns=columns,
                indexes=indexes,
            )

        # CREATE [UNIQUE] INDEX <name> ON <table> (col, ...)
        index_pattern = r'CREATE\s+(UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?[`"\[]?(\w+)[`"\]]?\s+ON\s+[`"\[]?(\w+)[`"\]]?\s*\(([^)]+)\)'

        for match in re.finditer(index_pattern, sql, re.IGNORECASE):
            is_unique = bool(match.group(1))
            idx_name = match.group(2).lower()
            table_name = match.group(3).lower()
            columns = [c.strip().strip('`"[]').lower() for c in match.group(4).split(',')]

            # Indexes on tables we never saw are silently dropped
            if table_name in tables:
                tables[table_name].indexes[idx_name] = Index(
                    name=idx_name,
                    table=table_name,
                    columns=columns,
                    unique=is_unique,
                )

        return Schema(
            tables=tables,
            source=source,
            captured_at=datetime.now(timezone.utc).isoformat(),
        )

    @staticmethod
    def from_sqlite(db_path: str) -> Schema:
        """Extract the schema from a SQLite database file.

        The connection is always closed, even when a PRAGMA query raises.
        """
        import sqlite3

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            tables = {}

            # All user tables (skip SQLite's internal bookkeeping tables)
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )
            table_names = [row[0] for row in cursor.fetchall()]

            for table_name in table_names:
                # PRAGMA cannot be parameterized; quote the identifier instead
                cursor.execute(f'PRAGMA table_info("{table_name}")')
                columns = {}
                for row in cursor.fetchall():
                    # row = (cid, name, type, notnull, dflt_value, pk)
                    col_name = row[1].lower()
                    columns[col_name] = Column(
                        name=col_name,
                        type=row[2].upper(),
                        nullable=not row[3],
                        default=row[4],
                        primary_key=bool(row[5]),
                    )

                cursor.execute(f'PRAGMA index_list("{table_name}")')
                indexes = {}
                for idx_row in cursor.fetchall():
                    # idx_row = (seq, name, unique, origin, partial)
                    idx_name = idx_row[1]
                    cursor.execute(f'PRAGMA index_info("{idx_name}")')
                    # r[2] is the column name; NOTE(review): it can be None
                    # for expression indexes, which are not handled here
                    idx_columns = [r[2].lower() for r in cursor.fetchall()]
                    indexes[idx_name.lower()] = Index(
                        name=idx_name.lower(),
                        table=table_name.lower(),
                        columns=idx_columns,
                        unique=bool(idx_row[2]),
                    )

                tables[table_name.lower()] = Table(
                    name=table_name.lower(),
                    columns=columns,
                    indexes=indexes,
                )
        finally:
            conn.close()

        return Schema(
            tables=tables,
            source=db_path,
            captured_at=datetime.now(timezone.utc).isoformat(),
        )
|
|
|
@dataclass
class SchemaDiff:
    """Differences between two schemas, as produced by compare_schemas()."""

    added_tables: List[str]
    removed_tables: List[str]
    modified_tables: Dict[str, dict]

    def has_changes(self) -> bool:
        """True if any table was added, removed, or modified."""
        return bool(self.added_tables or self.removed_tables or self.modified_tables)

    def to_dict(self) -> dict:
        """JSON-serializable report with a count summary up front."""
        return {
            "has_changes": self.has_changes(),
            "summary": {
                "added_tables": len(self.added_tables),
                "removed_tables": len(self.removed_tables),
                "modified_tables": len(self.modified_tables),
            },
            "added_tables": self.added_tables,
            "removed_tables": self.removed_tables,
            "modified_tables": self.modified_tables,
        }

    def generate_migration(self) -> str:
        """Generate a best-effort SQL migration script.

        Emits DROP TABLE for removed tables and ADD COLUMN for added
        columns; removed columns come out as commented statements for
        manual review. NOTE: added tables produce no CREATE TABLE
        because the diff records only their names, not their columns.
        """
        lines = [
            "-- Auto-generated migration script",
            f"-- Generated at {datetime.now(timezone.utc).isoformat()}",
            "",
        ]

        # Drop removed tables
        for table in self.removed_tables:
            lines.append(f"DROP TABLE IF EXISTS {table};")
        if self.removed_tables:
            lines.append("")

        # Column-level modifications
        for table, changes in self.modified_tables.items():
            for col in changes.get("added_columns") or []:
                col_def = f"{col['name']} {col['type']}"
                if not col.get('nullable', True):
                    col_def += " NOT NULL"
                # 'is not None' so falsy defaults (e.g. 0 from SQLite's
                # dflt_value) are still emitted
                if col.get('default') is not None:
                    col_def += f" DEFAULT {col['default']}"
                lines.append(f"ALTER TABLE {table} ADD COLUMN {col_def};")

            for col in changes.get("removed_columns") or []:
                lines.append(f"-- ALTER TABLE {table} DROP COLUMN {col}; -- Verify before running")

        return "\n".join(lines)
|
|
|
def compare_schemas(schema1: Schema, schema2: Schema) -> SchemaDiff:
    """Compare two schemas and return their differences.

    schema1 is treated as "before" and schema2 as "after". All name
    lists are sorted so the output is deterministic across runs
    (plain list(set) ordering is not).
    """
    tables1 = set(schema1.tables.keys())
    tables2 = set(schema2.tables.keys())

    added_tables = sorted(tables2 - tables1)
    removed_tables = sorted(tables1 - tables2)

    modified_tables = {}

    for table_name in sorted(tables1 & tables2):
        table1 = schema1.tables[table_name]
        table2 = schema2.tables[table_name]

        cols1 = set(table1.columns.keys())
        cols2 = set(table2.columns.keys())

        added_cols = sorted(cols2 - cols1)
        removed_cols = sorted(cols1 - cols2)

        # Compare every column field (not just signature(), which covers
        # only name/type/nullability) so drift in default, primary_key
        # and unique is detected too.
        modified_cols = []
        for col_name in sorted(cols1 & cols2):
            before = asdict(table1.columns[col_name])
            after = asdict(table2.columns[col_name])
            if before != after:
                modified_cols.append({
                    "name": col_name,
                    "before": before,
                    "after": after,
                })

        idx1 = set(table1.indexes.keys())
        idx2 = set(table2.indexes.keys())

        # An index can keep its name but change columns or uniqueness;
        # name-set comparison alone would miss that.
        modified_idxs = sorted(
            i for i in idx1 & idx2
            if asdict(table1.indexes[i]) != asdict(table2.indexes[i])
        )

        if added_cols or removed_cols or modified_cols or idx1 != idx2 or modified_idxs:
            modified_tables[table_name] = {
                "added_columns": [asdict(table2.columns[c]) for c in added_cols],
                "removed_columns": removed_cols,
                "modified_columns": modified_cols,
                "added_indexes": sorted(idx2 - idx1),
                "removed_indexes": sorted(idx1 - idx2),
                "modified_indexes": modified_idxs,
            }

    return SchemaDiff(
        added_tables=added_tables,
        removed_tables=removed_tables,
        modified_tables=modified_tables,
    )
|
|
|
def main():
    """CLI entry point: compare, snapshot, and diff subcommands."""
    parser = argparse.ArgumentParser(description="Schema Diff Tool")
    subparsers = parser.add_subparsers(dest="command", help="Commands")

    # compare: parse two raw SQL files and diff them
    compare_parser = subparsers.add_parser("compare", help="Compare two SQL schema files")
    compare_parser.add_argument("file1", help="First schema file")
    compare_parser.add_argument("file2", help="Second schema file")
    compare_parser.add_argument("--migration", action="store_true", help="Generate migration script")

    # snapshot: capture a schema to JSON for later diffing
    snapshot_parser = subparsers.add_parser("snapshot", help="Capture schema snapshot")
    snapshot_parser.add_argument("source", help="Database connection (sqlite:///path)")
    snapshot_parser.add_argument("--output", "-o", help="Output file")

    # diff: compare two previously captured JSON snapshots
    diff_parser = subparsers.add_parser("diff", help="Diff two schema snapshots")
    diff_parser.add_argument("snapshot1", help="First snapshot JSON")
    diff_parser.add_argument("snapshot2", help="Second snapshot JSON")
    diff_parser.add_argument("--migration", action="store_true", help="Generate migration script")

    args = parser.parse_args()

    if args.command == "compare":
        sql1 = Path(args.file1).read_text(encoding="utf-8")
        sql2 = Path(args.file2).read_text(encoding="utf-8")

        schema1 = SchemaParser.parse_sql(sql1, args.file1)
        schema2 = SchemaParser.parse_sql(sql2, args.file2)

        diff = compare_schemas(schema1, schema2)

        if args.migration:
            print(diff.generate_migration())
        else:
            print(json.dumps(diff.to_dict(), indent=2))

    elif args.command == "snapshot":
        if args.source.startswith("sqlite://"):
            # Strip the scheme, then one leading slash, so that
            # "sqlite:///mydb.db" means the relative path "mydb.db" and
            # "sqlite:////var/db.db" means the absolute "/var/db.db".
            # (The previous chained .replace() calls turned
            # "sqlite:///mydb.db" into the absolute path "/mydb.db".)
            rest = args.source[len("sqlite://"):]
            db_path = rest[1:] if rest.startswith("/") else rest
            schema = SchemaParser.from_sqlite(db_path)
        else:
            # Otherwise treat the source as a raw SQL schema file
            sql = Path(args.source).read_text(encoding="utf-8")
            schema = SchemaParser.parse_sql(sql, args.source)

        output = json.dumps(schema.to_dict(), indent=2)

        if args.output:
            Path(args.output).write_text(output, encoding="utf-8")
            print(f"Snapshot saved to {args.output}")
        else:
            print(output)

    elif args.command == "diff":
        snap1 = json.loads(Path(args.snapshot1).read_text(encoding="utf-8"))
        snap2 = json.loads(Path(args.snapshot2).read_text(encoding="utf-8"))

        def load_schema(data: dict) -> Schema:
            """Rebuild a Schema object from its JSON snapshot form."""
            tables = {}
            for tname, tdata in data["tables"].items():
                columns = {
                    cname: Column(**cdata)
                    for cname, cdata in tdata["columns"].items()
                }
                indexes = {
                    iname: Index(**idata)
                    for iname, idata in tdata["indexes"].items()
                }
                tables[tname] = Table(name=tname, columns=columns, indexes=indexes)
            return Schema(
                tables=tables,
                source=data["source"],
                captured_at=data["captured_at"],
            )

        schema1 = load_schema(snap1)
        schema2 = load_schema(snap2)

        diff = compare_schemas(schema1, schema2)

        if args.migration:
            print(diff.generate_migration())
        else:
            print(json.dumps(diff.to_dict(), indent=2))

    else:
        parser.print_help()
|
|
|
# Run the CLI only when executed directly as a script, not on import.
if __name__ == "__main__":

    main()