Created
January 21, 2026 10:16
-
-
Save Paillat-dev/eb618eb0f5a3a90929e666969ec5a647 to your computer and use it in GitHub Desktop.
Python `require_extra` utility
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import importlib | |
| import importlib.metadata as md | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from collections.abc import Mapping | |
class ExtraRequirementError(RuntimeError):
    """Signal that an optional extra/backend dependency is not usable as installed.

    Raised by ``require_extra`` when the import is provided by the wrong
    distribution(s), when no distribution claims it in metadata, or when the
    installed version falls outside the required range.
    """
| def _norm_dist_name(name: str) -> str: | |
| """Normalize distribution name per PEP 503. | |
| Converts to lowercase and replaces underscores/dots with hyphens. | |
| Args: | |
| name: Distribution name to normalize. | |
| Returns: | |
| Normalized distribution name. | |
| Example: | |
| >>> _norm_dist_name("Py_Cord") | |
| 'py-cord' | |
| """ | |
| return name.strip().casefold().replace("_", "-").replace(".", "-") | |
| def _parse_version(v: str) -> tuple[int, ...]: | |
| """Parse version string into numeric tuple. | |
| Extracts numeric components and stops at first non-numeric part. | |
| Args: | |
| v: Version string to parse (e.g., "2.7.0rc2"). | |
| Returns: | |
| Tuple of version numbers (e.g., (2, 7, 0)). | |
| Example: | |
| >>> _parse_version("2.7.0rc2") | |
| (2, 7, 0) | |
| >>> _parse_version("3.11") | |
| (3, 11) | |
| """ | |
| parts: list[int] = [] | |
| buf = "" | |
| for ch in v: | |
| if ch.isdigit(): | |
| buf += ch | |
| elif buf: | |
| parts.append(int(buf)) | |
| buf = "" | |
| if buf: | |
| parts.append(int(buf)) | |
| return tuple(parts) if parts else (0,) | |
def _packages_providing(top_level: str) -> set[str]:
    """Return the normalized names of all distributions that ship ``top_level``.

    Looks the module up in ``importlib.metadata.packages_distributions()`` and
    normalizes each distribution name with ``_norm_dist_name``. An unknown
    module yields an empty set.

    Args:
        top_level: Top-level module name (e.g., "discord").

    Returns:
        Set of normalized distribution names that provide the module.

    Example:
        >>> _packages_providing("discord")
        {'py-cord'}
    """
    dists_by_module = md.packages_distributions()
    claimed_by = dists_by_module.get(top_level, ())
    return set(map(_norm_dist_name, claimed_by))
| def require_extra( | |
| *, | |
| package: str, | |
| import_name: str, | |
| version: tuple[int, ...] | None = None, | |
| ) -> None: | |
| """Ensure import is provided exclusively by package with correct version. | |
| Validates that: | |
| 1. The import name can be imported successfully. | |
| 2. Exactly one distribution provides it, and it's the specified package. | |
| 3. The package version is >= the required version (same major version). | |
| Args: | |
| package: Required package distribution name (e.g., "py-cord"). | |
| import_name: Top-level import name to check (e.g., "discord"). | |
| version: Minimum required version tuple (e.g., (2, 7, 0)). | |
| Major version must match exactly. Defaults to None (no check). | |
| Raises: | |
| ModuleNotFoundError: If import fails or package not found in metadata. | |
| ExtraRequirementError: If wrong provider or version mismatch detected. | |
| Example: | |
| >>> require_extra(package="py-cord", import_name="discord", version=(2, 7, 0)) | |
| # Raises if discord.py is installed instead of py-cord | |
| # Raises if py-cord version is 2.6.x or 3.x | |
| """ | |
| pkg_norm = _norm_dist_name(package) | |
| # 1) Ensure the import works | |
| try: | |
| importlib.import_module(import_name) | |
| except ModuleNotFoundError as e: | |
| version_str = f" (need >= {'.'.join(map(str, version))})" if version else "" | |
| msg = ( | |
| f"`{import_name}` is required but could not be imported.\n" | |
| + f"Install with: pip install '{package}'{version_str}" | |
| ) | |
| raise ModuleNotFoundError(msg) from e | |
| # 2) Ensure ONLY the requested package provides it | |
| providers = _packages_providing(import_name) | |
| if not providers: | |
| msg = ( | |
| f"Imported `{import_name}`, but no distribution claims it in metadata.\n" | |
| + f"Try: pip install -U --force-reinstall {package}" | |
| ) | |
| raise ExtraRequirementError(msg) | |
| if providers != {pkg_norm}: | |
| other = sorted(p for p in providers if p != pkg_norm) | |
| uninstall = f" pip uninstall -y {' '.join(other)}\n" if other else "" | |
| msg = ( | |
| f"`{import_name}` must be provided exclusively by `{package}`, " | |
| + f"but found: {', '.join(sorted(providers))}.\n" | |
| + f"Fix by uninstalling conflicts:\n{uninstall}" | |
| + f" pip install -U '{package}'" | |
| ) | |
| raise ExtraRequirementError(msg) | |
| # 3) Version check | |
| if version is None: | |
| return | |
| try: | |
| installed_str = md.version(package) | |
| except md.PackageNotFoundError as e: | |
| msg = f"Distribution `{package}` not found in metadata.\nInstall with: pip install '{package}'" | |
| raise ModuleNotFoundError(msg) from e | |
| installed = _parse_version(installed_str) | |
| required = tuple(version) | |
| max_len = max(len(installed), len(required)) | |
| inst_padded = installed + (0,) * (max_len - len(installed)) | |
| req_padded = required + (0,) * (max_len - len(required)) | |
| if installed[0] != required[0]: | |
| msg = ( | |
| f"`{package}` major version must be {required[0]}.x, " | |
| + f"but found {installed_str}.\n" | |
| + f"Install: pip install '{package}>={required[0]}.{required[1] if len(required) > 1 else 0}'" | |
| ) | |
| raise ExtraRequirementError(msg) | |
| if inst_padded < req_padded: | |
| msg = ( | |
| f"`{package}` >= {'.'.join(map(str, required))} required, " | |
| + f"found {installed_str}.\n" | |
| + f"Upgrade: pip install -U '{package}'" | |
| ) | |
| raise ExtraRequirementError(msg) | |
# Public API: the checker plus the exception its callers are expected to catch.
__all__ = ("ExtraRequirementError", "require_extra")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment