Created on January 13, 2026 at 05:07
-
-
Save yiliu30/172f4b039e6fd90bdd088c31491e6d6a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| from functools import wraps | |
| # from vllm import envs | |
| from loguru import logger | |
def _available_cpu_count():
    """Return how many CPUs this process may run on (always at least 1).

    Tries ``psutil.Process().cpu_affinity()`` first, because that was the
    original source of the count. It then falls back to
    ``os.sched_getaffinity(0)``, and finally to ``os.cpu_count()`` for
    platforms or environments where neither affinity query is available.
    """
    try:
        import psutil

        # An empty affinity list is treated as a single usable core.
        return max(1, len(psutil.Process().cpu_affinity() or []))
    except (ImportError, AttributeError):
        # psutil is not installed, or Process.cpu_affinity is not provided on
        # this platform.
        pass
    try:
        return max(1, len(os.sched_getaffinity(0)))
    except AttributeError:
        # os.sched_getaffinity does not exist on every platform.
        return max(1, os.cpu_count() or 1)


def _resolve_world_size(max_world_size):
    """Return the torch.distributed world size, capped to ``max_world_size``.

    Returns 1 when torch.distributed is unavailable in this build or no
    process group has been initialized.
    """
    import torch

    dist = torch.distributed
    world_size = 1
    # is_available() guards builds compiled without distributed support,
    # where is_initialized() may not be usable.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
    return max(1, min(world_size, max_world_size))


def with_thread_limits(*, div_omp=None, div_torch=None, max_world_size=8):
    """Build a decorator that limits OpenMP and PyTorch CPU threads during a call.

    On each call of the decorated function, the wrapper:

    1. Sets ``OMP_NUM_THREADS`` to ``max(1, cores // div_omp)``.
    2. Calls ``torch.set_num_threads(max(1, cores // div_torch))``.
    3. Runs the function.
    4. Restores the previous ``OMP_NUM_THREADS`` value (or removes the
       variable if it was unset) and the previous torch thread count.
       This happens even if the function raises.

    Here ``cores`` is the number of CPUs this process may run on.

    Args:
        div_omp: Divisor applied to the core count for ``OMP_NUM_THREADS``.
            Defaults to the (capped) torch.distributed world size.
        div_torch: Divisor applied to the core count for
            ``torch.set_num_threads``. Defaults to the (capped)
            torch.distributed world size.
        max_world_size: Upper bound applied to the world size before it is
            used as the default divisor. The default of 8 matches the
            previously hard-coded cap. It presumably reflects the number of
            ranks sharing one host — confirm against deployment.

    Returns:
        A decorator that wraps a function with the thread limits above.

    Raises:
        ValueError: If any divisor or ``max_world_size`` is less than 1.

    NOTE(review): OpenMP runtimes typically read ``OMP_NUM_THREADS`` only
    when they initialize. Changing it here may therefore affect only
    libraries or subprocesses that initialize later. The torch thread count
    is applied immediately via ``torch.set_num_threads``.
    """
    for arg_name, value in (
        ("div_omp", div_omp),
        ("div_torch", div_torch),
        ("max_world_size", max_world_size),
    ):
        if value is not None and value < 1:
            raise ValueError(f"{arg_name} must be >= 1, got {value!r}")

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            import torch

            world_size = _resolve_world_size(max_world_size)
            omp_divisor = world_size if div_omp is None else div_omp
            torch_divisor = world_size if div_torch is None else div_torch
            num_cores = _available_cpu_count()

            # Capture the current settings before touching anything, so the
            # finally clause can always restore them.
            old_omp = os.environ.get("OMP_NUM_THREADS")
            old_torch = torch.get_num_threads()
            try:
                # The mutations live inside the try so that a failure
                # part-way through setup still restores the original state.
                os.environ["OMP_NUM_THREADS"] = str(max(1, num_cores // omp_divisor))
                torch.set_num_threads(max(1, num_cores // torch_divisor))
                logger.warning(
                    "Setting OMP_NUM_THREADS to {} and torch.set_num_threads to {} "
                    "for {} available CPU cores and world size {}",
                    os.environ["OMP_NUM_THREADS"],
                    torch.get_num_threads(),
                    num_cores,
                    world_size,
                )
                return func(*args, **kwargs)
            finally:
                if old_omp is None:
                    os.environ.pop("OMP_NUM_THREADS", None)
                else:
                    os.environ["OMP_NUM_THREADS"] = old_omp
                torch.set_num_threads(old_torch)

        return wrapper

    return decorator
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.