Skip to content

Instantly share code, notes, and snippets.

@Paillat-dev
Created January 21, 2026 10:16
Show Gist options
  • Select an option

  • Save Paillat-dev/eb618eb0f5a3a90929e666969ec5a647 to your computer and use it in GitHub Desktop.

Select an option

Save Paillat-dev/eb618eb0f5a3a90929e666969ec5a647 to your computer and use it in GitHub Desktop.
Python `require_extra` utility
import importlib
import importlib.metadata as md
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping
class ExtraRequirementError(RuntimeError):
    """Signal that an optional extra/backend dependency check failed.

    ``require_extra`` raises this when the import is claimed by no
    distribution, is provided by an unexpected distribution, or when the
    installed version is outside the accepted range.
    """
def _norm_dist_name(name: str) -> str:
    """Normalize a distribution name per PEP 503.

    Strips surrounding whitespace, lowercases, and collapses every run of
    ``-``, ``_`` and ``.`` into a single hyphen, so spellings such as
    ``Py_Cord``, ``py.cord`` and ``py__cord`` all compare equal.

    Args:
        name: Distribution name to normalize.

    Returns:
        Normalized distribution name.

    Example:
        >>> _norm_dist_name("Py_Cord")
        'py-cord'
        >>> _norm_dist_name("zope__interface")
        'zope-interface'
    """
    import re

    # PEP 503: re.sub(r"[-_.]+", "-", name).lower(). Runs must be collapsed,
    # otherwise "a__b" and "a-b" would not be recognised as the same project.
    return re.sub(r"[-_.]+", "-", name.strip()).lower()
def _parse_version(v: str) -> tuple[int, ...]:
    """Parse the release segment of a version string into a numeric tuple.

    Reads dot-separated integers from the start of the release segment and
    stops at the first component that is not purely numeric. Pre-, post-,
    dev- and local-version suffixes are therefore ignored. A leading ``v``
    and a PEP 440 epoch prefix (``N!``) are skipped.

    Args:
        v: Version string to parse (e.g., "2.7.0rc2").

    Returns:
        Tuple of version numbers (e.g., (2, 7, 0)), or ``(0,)`` when the
        string does not start with a numeric release segment.

    Example:
        >>> _parse_version("2.7.0rc2")
        (2, 7, 0)
        >>> _parse_version("3.11")
        (3, 11)
    """
    import re

    # re.ASCII restricts \d to 0-9 so int() never sees other Unicode digits;
    # the pattern stops at the first non-numeric part ("rc2", ".dev1", "+cu118").
    match = re.match(r"\s*v?(?:\d+!)?(\d+(?:\.\d+)*)", v, re.ASCII | re.IGNORECASE)
    if match is None:
        return (0,)
    return tuple(int(part) for part in match.group(1).split("."))
def _packages_providing(top_level: str) -> set[str]:
"""Get normalized distribution names providing a top-level module.
Args:
top_level: Top-level module name (e.g., "discord").
Returns:
Set of normalized distribution names that provide the module.
Example:
>>> _packages_providing("discord")
{'py-cord'}
"""
mapping: Mapping[str, list[str]] = md.packages_distributions()
providers = mapping.get(top_level, [])
return {_norm_dist_name(p) for p in providers}
def require_extra(
*,
package: str,
import_name: str,
version: tuple[int, ...] | None = None,
) -> None:
"""Ensure import is provided exclusively by package with correct version.
Validates that:
1. The import name can be imported successfully.
2. Exactly one distribution provides it, and it's the specified package.
3. The package version is >= the required version (same major version).
Args:
package: Required package distribution name (e.g., "py-cord").
import_name: Top-level import name to check (e.g., "discord").
version: Minimum required version tuple (e.g., (2, 7, 0)).
Major version must match exactly. Defaults to None (no check).
Raises:
ModuleNotFoundError: If import fails or package not found in metadata.
ExtraRequirementError: If wrong provider or version mismatch detected.
Example:
>>> require_extra(package="py-cord", import_name="discord", version=(2, 7, 0))
# Raises if discord.py is installed instead of py-cord
# Raises if py-cord version is 2.6.x or 3.x
"""
pkg_norm = _norm_dist_name(package)
# 1) Ensure the import works
try:
importlib.import_module(import_name)
except ModuleNotFoundError as e:
version_str = f" (need >= {'.'.join(map(str, version))})" if version else ""
msg = (
f"`{import_name}` is required but could not be imported.\n"
+ f"Install with: pip install '{package}'{version_str}"
)
raise ModuleNotFoundError(msg) from e
# 2) Ensure ONLY the requested package provides it
providers = _packages_providing(import_name)
if not providers:
msg = (
f"Imported `{import_name}`, but no distribution claims it in metadata.\n"
+ f"Try: pip install -U --force-reinstall {package}"
)
raise ExtraRequirementError(msg)
if providers != {pkg_norm}:
other = sorted(p for p in providers if p != pkg_norm)
uninstall = f" pip uninstall -y {' '.join(other)}\n" if other else ""
msg = (
f"`{import_name}` must be provided exclusively by `{package}`, "
+ f"but found: {', '.join(sorted(providers))}.\n"
+ f"Fix by uninstalling conflicts:\n{uninstall}"
+ f" pip install -U '{package}'"
)
raise ExtraRequirementError(msg)
# 3) Version check
if version is None:
return
try:
installed_str = md.version(package)
except md.PackageNotFoundError as e:
msg = f"Distribution `{package}` not found in metadata.\nInstall with: pip install '{package}'"
raise ModuleNotFoundError(msg) from e
installed = _parse_version(installed_str)
required = tuple(version)
max_len = max(len(installed), len(required))
inst_padded = installed + (0,) * (max_len - len(installed))
req_padded = required + (0,) * (max_len - len(required))
if installed[0] != required[0]:
msg = (
f"`{package}` major version must be {required[0]}.x, "
+ f"but found {installed_str}.\n"
+ f"Install: pip install '{package}>={required[0]}.{required[1] if len(required) > 1 else 0}'"
)
raise ExtraRequirementError(msg)
if inst_padded < req_padded:
msg = (
f"`{package}` >= {'.'.join(map(str, required))} required, "
+ f"found {installed_str}.\n"
+ f"Upgrade: pip install -U '{package}'"
)
raise ExtraRequirementError(msg)
# Public API: the checker plus the exception type its callers need to catch.
__all__ = ("ExtraRequirementError", "require_extra")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment