Last active
August 9, 2024 12:17
-
-
Save cn-ml/a9d1e8ec7049f0c77ff748bee10671dc to your computer and use it in GitHub Desktop.
Docker container database backup script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| from argparse import OPTIONAL, ArgumentParser | |
| from datetime import datetime | |
| from functools import wraps | |
| from gzip import compress, decompress | |
| from pathlib import Path | |
| from re import compile | |
| from subprocess import PIPE, run | |
| from typing import ( | |
| Callable, | |
| Iterable, | |
| Literal, | |
| ParamSpec, | |
| TypeVar, | |
| cast, | |
| get_args, | |
| ) | |
| Params = ParamSpec("Params") | |
| T = TypeVar("T") | |
| DbType = Literal["postgres", "mariadb"] | |
| DB_TYPES = set(cast(tuple[DbType, ...], get_args(DbType))) | |
| DEFAULT_DB: DbType = "postgres" | |
| TIMESTAMP_FORMAT = "%Y%m%d%H%M%S" | |
| FILENAME_REGEX = compile(r"^backup_(?P<time>\d{14})\.sql\.gz$") | |
def listify(f: Callable[Params, Iterable[T]]) -> Callable[Params, list[T]]:
    """Decorate *f* so that the iterable it returns is collected into a list.

    The wrapper keeps *f*'s metadata (name, docstring, ``__wrapped__``) via
    ``functools.wraps`` and forwards all positional and keyword arguments.
    """

    @wraps(f)
    def _collect(*args: Params.args, **kwargs: Params.kwargs) -> list[T]:
        return [*f(*args, **kwargs)]

    return _collect
class BackupDirectory:
    """A folder of gzip-compressed SQL dumps named ``backup_<timestamp>.sql.gz``.

    Also used as the argparse ``type=`` converter for the ``directory``
    argument, so the constructor accepts a plain string.
    """

    def __init__(self, directory: str | Path):
        """Wrap *directory* (str or Path); the folder does not need to exist yet."""
        # Path() accepts both str and Path, so no isinstance check is needed.
        self.directory = Path(directory)
        print(f"Backup directory at {self.directory}")
        super().__init__()

    def create(self, content: str) -> BackupFile:
        """Write *content* as a new, timestamped backup here and return it."""
        return BackupFile.create(self.directory, content)

    def find_backups(self) -> Iterable[BackupFile]:
        """Yield a BackupFile for each regular file whose name matches FILENAME_REGEX.

        Yields nothing if the directory does not exist. Files whose 14-digit
        stamp is not a valid date/time are skipped with a warning. The order
        is whatever ``Path.iterdir`` returns (unspecified).
        """
        try:
            files = [entry for entry in self.directory.iterdir() if entry.is_file()]
        except FileNotFoundError:
            # No backups yet; BackupFile.write creates the folder on first write.
            return
        for file in files:
            match = FILENAME_REGEX.search(file.name)
            if match is None:
                continue  # unrelated file
            try:
                timestamp = datetime.strptime(match.group("time"), TIMESTAMP_FORMAT)
            except ValueError as e:
                # Matches \d{14} but is not a real date/time (e.g. month 13).
                print(f"Warning: invalid timestamp for {file}: {e}")
                continue
            yield BackupFile(file, timestamp)

    def get_recent_backup(self) -> BackupFile | None:
        """Return the backup with the newest filename timestamp, or None if there is none."""
        existing_backups = list(self.find_backups())
        print(f"Found {len(existing_backups)} existing backups!")
        if not existing_backups:
            return None
        return max(existing_backups, key=lambda backup: backup.timestamp)
class BackupFile:
    """One gzip-compressed SQL backup on disk plus the timestamp it was stamped with."""

    def __init__(self, file: Path, timestamp: datetime):
        self.file = file
        self.timestamp = timestamp
        super().__init__()

    def write(self, content: str):
        """Gzip *content* (encoded as UTF-8) into this backup file.

        Parent directories are created as needed. The compressed bytes go to a
        sibling ``<name>.tmp`` file, which is then renamed over the target.
        An interrupted run therefore never leaves a truncated
        ``backup_*.sql.gz`` that later runs would treat as the most recent
        backup. The ``.tmp`` name does not match the backup filename pattern.
        """
        directory = self.file.parent
        directory.mkdir(parents=True, exist_ok=True)
        compressed = compress(content.encode("utf-8"))
        temporary = self.file.with_name(self.file.name + ".tmp")
        try:
            written = temporary.write_bytes(compressed)
            temporary.replace(self.file)
        except BaseException:
            # Don't leave a half-written temp file behind; re-raise the cause.
            temporary.unlink(missing_ok=True)
            raise
        print(f"{written} bytes written to backup at {self.file.absolute()}")

    def read(self) -> str:
        """Return the decompressed backup contents decoded as UTF-8."""
        return decompress(self.file.read_bytes()).decode("utf-8")

    @classmethod
    def create(cls, directory: Path, content: str | None = None) -> BackupFile:
        """Make a backup in *directory* stamped with the current local time.

        If *content* is given it is written immediately. TIMESTAMP_FORMAT has
        one-second resolution, so backups created within the same second map
        to the same filename.
        """
        timestamp = datetime.now()
        file = directory / f"backup_{timestamp.strftime(TIMESTAMP_FORMAT)}.sql.gz"
        backup = cls(file, timestamp)
        if content is not None:
            backup.write(content)
        return backup
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Positional arguments, in order:
    - ``directory``: converted to a BackupDirectory.
    - ``container``: the database container's name.
    - ``role``: optional, default ``"root"``.
    - ``type``: optional, one of DB_TYPES, default DEFAULT_DB.

    A password may be supplied with ``--password``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "directory", type=BackupDirectory, help="Destination folder for backup files."
    )
    parser.add_argument(
        "container", type=str, help="Container name of the database container."
    )
    parser.add_argument(
        "role",
        nargs=OPTIONAL,
        type=str,
        help="Role name of the root user.",
        default="root",
    )
    parser.add_argument(
        "--password",
        type=str,
        required=False,
        help="Password of the user.",
    )
    parser.add_argument(
        "type",
        nargs=OPTIONAL,
        help="Type of the database container.",
        # DB_TYPES is a set whose iteration order for strings can change
        # between runs (hash randomization). argparse prints choices in
        # --help and error messages, so pass them in a stable order.
        choices=sorted(DB_TYPES),
        default=DEFAULT_DB,
    )  # constrained to DbType by choices
    return parser.parse_args()
def read_output(*command: str):
    """Run *command* directly (no shell) and return everything it wrote to stdout.

    stderr is not captured; the child inherits this process's stderr.
    A non-zero exit status raises subprocess.CalledProcessError.
    """
    completed = run(command, check=True, stdout=PIPE)
    return completed.stdout
def read_docker_command(container: str, *cmd: str):
    """Execute *cmd* inside *container* with ``docker exec`` and return stdout as text.

    The output is decoded strictly as UTF-8, so any non-UTF-8 bytes raise
    UnicodeDecodeError.
    """
    docker_argv = ("docker", "exec", container, *cmd)
    stdout_bytes = read_output(*docker_argv)
    return stdout_bytes.decode("utf-8")
| def create_backup(container: str, role: str, password: str | None, db_type: DbType): | |
| match db_type: | |
| case "postgres": | |
| return read_docker_command( | |
| container, "pg_dumpall", "--username", role, "--clean", "--if-exists" | |
| ) | |
| case "mariadb": | |
| assert password is not None, f"Password required for {db_type}" | |
| return read_docker_command( | |
| container, | |
| "mariadb-dump", | |
| "-u", | |
| role, | |
| f"--password={password}", | |
| "--skip-dump-date", | |
| "--all-databases", | |
| ) | |
| raise NotImplementedError(f"Unhandled db type {db_type!r}!") | |
def run_backup(
    directory: BackupDirectory,
    container: str,
    role: str,
    password: str | None,
    db_type: DbType,
):
    """Dump the databases and store the dump unless it equals the latest backup.

    Returns:
        The BackupFile that now holds this state. This is the most recent
        existing backup when the dump is unchanged, otherwise a newly written
        backup. A previous backup that cannot be read or decompressed is
        reported and treated as "changed", so a fresh backup is written.
    """
    import zlib  # for zlib.error, which gzip.decompress can raise on corrupt data

    backup = create_backup(container, role, password, db_type)
    recent_backup = directory.get_recent_backup()
    if recent_backup is not None:
        try:
            recent_content = recent_backup.read()
        except (OSError, EOFError, UnicodeDecodeError, zlib.error) as e:
            # OSError also covers gzip.BadGzipFile; EOFError means truncated
            # data. A broken previous backup must not block every future run.
            print(f"Warning: could not read most recent backup {recent_backup.file}: {e}")
        else:
            if recent_content == backup:
                age = datetime.now() - recent_backup.timestamp
                print(f"Backup has not changed since most recent backup ({age} old)")
                return recent_backup
    return directory.create(backup)
def main():
    """Command-line entry point: parse the arguments and perform one backup run."""
    options = parse_args()
    # argparse returns an untyped Namespace; cast() only informs type checkers.
    run_backup(
        cast(BackupDirectory, options.directory),
        cast(str, options.container),
        cast(str, options.role),
        cast(str | None, options.password),
        cast(DbType, options.type),
    )


# Run a backup only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment