Generated code from RayCodegen (skypilot 0.10.5)
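In outline, this is the driver script SkyPilot generates for a single job (job id 1 here): it initializes Ray on the cluster, reserves a placement group for the task, runs the task's bash command on each assigned node through a Ray remote function while streaming and logging its output, and records the job's status via the skylet `job_lib`.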
import copy
import functools
import getpass
import hashlib
import io
import multiprocessing.pool
import os
import pathlib
import selectors
import shlex
import subprocess
import sys
import tempfile
import textwrap
import time
from typing import Dict, List, Optional, Tuple, Union
# Set the environment variables to avoid deduplicating logs and
# scheduler events. This should be set in driver code, since we are
# not using `ray job submit` anymore, and the environment variables
# from the ray cluster are not inherited.
os.environ['RAY_DEDUP_LOGS'] = '0'
os.environ['RAY_SCHEDULER_EVENTS'] = '0'
import ray
import ray.util as ray_util
from sky.skylet import autostop_lib
from sky.skylet import constants
from sky.skylet import job_lib
from sky.utils import log_utils
from sky.utils import subprocess_utils

SKY_REMOTE_WORKDIR = '~/sky_workdir'
CANCELLED_RETURN_CODE = 137

kwargs = dict()
# Only set `_temp_dir` to SkyPilot's ray cluster directory when the
# directory exists, for backward compatibility with VMs launched
# before #1790.
if os.path.exists('/tmp/ray_skypilot'):
    kwargs['_temp_dir'] = '/tmp/ray_skypilot'
ray.init(
    address='auto',
    namespace='__sky__1__',
    log_to_driver=True,
    **kwargs
)
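# The Ray namespace ('__sky__1__') appears to encode the SkyPilot job id (1),
# keeping this driver's named Ray resources separate from other jobs'.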


def get_or_fail(futures, pg) -> Tuple[List[int], List[Optional[int]]]:
    """Wait for tasks; if any fails, cancel all unready tasks."""
    if not futures:
        return [], []
    returncodes = [1] * len(futures)
    pids = [None] * len(futures)
    failed = False
    # Wait for 1 task to be ready.
    ready = []
    # Keep invoking ray.wait if ready is empty. This is because
    # ray.wait with timeout=None will only wait for 10**6 seconds,
    # which will cause tasks running for more than 12 days to return
    # before becoming ready.
    # (Such tasks are common in serving jobs.)
    # Reference: https://github.com/ray-project/ray/blob/ray-2.9.3/python/ray/_private/worker.py#L2845-L2846

    def handle_ready_tasks(tasks: List[ray.ObjectRef]) -> None:
        nonlocal returncodes, pids, failed
        for task in tasks:
            idx = futures.index(task)
            res = ray.get(task)
            returncodes[idx] = res['return_code']
            pids[idx] = res['pid']
            if res['return_code'] != 0:
                failed = True

    while not ready:
        ready, unready = ray.wait(futures)
    handle_ready_tasks(ready)
    while unready:
        if failed:
            for task in unready:
                # ray.cancel without force fails to kill tasks.
                # We use force=True to kill unready tasks.
                ray.cancel(task, force=True)
                # Use SIGKILL=128+9 to indicate the task was forcibly
                # killed.
                idx = futures.index(task)
                returncodes[idx] = CANCELLED_RETURN_CODE
            break
        ready, unready = ray.wait(unready)
        handle_ready_tasks(ready)
    # Remove the placement group after all tasks are done, so that
    # the next job can be scheduled on the released resources
    # immediately.
    ray_util.remove_placement_group(pg)
    sys.stdout.flush()
    return returncodes, pids


run_fn = None
futures = []
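# `run_fn` is presumably populated by the codegen when the task's run section
# is a Python callable; in this file it stays None, so the literal `script`
# defined further below is executed instead. `futures` collects the Ray
# ObjectRefs of the per-node commands scheduled at the bottom of this file and
# is consumed by `get_or_fail` above.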


class _ProcessingArgs:
    """Arguments for processing logs."""

    def __init__(self,
                 log_path: str,
                 stream_logs: bool,
                 start_streaming_at: str = '',
                 end_streaming_at: Optional[str] = None,
                 skip_lines: Optional[List[str]] = None,
                 replace_crlf: bool = False,
                 line_processor: Optional[log_utils.LineProcessor] = None,
                 streaming_prefix: Optional[str] = None) -> None:
        self.log_path = log_path
        self.stream_logs = stream_logs
        self.start_streaming_at = start_streaming_at
        self.end_streaming_at = end_streaming_at
        self.skip_lines = skip_lines
        self.replace_crlf = replace_crlf
        self.line_processor = line_processor
        self.streaming_prefix = streaming_prefix


def _get_context():
    # TODO(aylei): remove this after we drop backward compatibility for
    # 0.9.x in 0.12.0.
    # Keep backward compatibility with old versions of the SkyPilot runtime.
    if 'context' in globals():
        return context.get()
    else:
        return None
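# NOTE: this generated file never defines `context` at module level, so
# `_get_context()` appears to always return None here and the coroutine code
# paths below (which reference `context_utils`) are not exercised.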


def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
    """Process the stream of a process."""
    out_io = io.TextIOWrapper(io_stream,
                              encoding='utf-8',
                              newline='',
                              errors='replace',
                              write_through=True)

    start_streaming_flag = False
    end_streaming_flag = False
    streaming_prefix = args.streaming_prefix if args.streaming_prefix else ''
    line_processor = (log_utils.LineProcessor()
                      if args.line_processor is None else args.line_processor)

    out = []
    with open(args.log_path, 'a', encoding='utf-8') as fout:
        with line_processor:
            while True:
                ctx = _get_context()
                if ctx is not None and ctx.is_canceled():
                    return
                line = out_io.readline()
                if not line:
                    break
                # start_streaming_at logic in processor.process_line(line)
                if args.replace_crlf and line.endswith('\r\n'):
                    # Replace CRLF with LF to avoid ray logging to the same
                    # line due to separating lines with '\n'.
                    line = line[:-2] + '\n'
                if (args.skip_lines is not None and
                        any(skip in line for skip in args.skip_lines)):
                    continue
                if args.start_streaming_at in line:
                    start_streaming_flag = True
                if (args.end_streaming_at is not None and
                        args.end_streaming_at in line):
                    # Keep executing the loop, only stop streaming.
                    # E.g., this is used for `sky bench` to hide the
                    # redundant messages of `sky launch` while
                    # saving them in log files.
                    end_streaming_flag = True
                if (args.stream_logs and start_streaming_flag and
                        not end_streaming_flag):
                    print(streaming_prefix + line,
                          end='',
                          file=out_stream,
                          flush=True)
                if args.log_path != '/dev/null':
                    fout.write(line)
                    fout.flush()
                line_processor.process_line(line)
                out.append(line)
    return ''.join(out)


def process_subprocess_stream(proc, stdout_stream_handler,
                              stderr_stream_handler) -> Tuple[str, str]:
    """Process the streams of a process in threads, blocking."""
    if proc.stderr is not None:
        # Asyncio does not work here, as the output processing can be
        # executed in a different thread.
        # selectors could handle the multiplexing of stdout/stderr, but it
        # introduces buffering that makes the output non-streaming.
        with multiprocessing.pool.ThreadPool(processes=1) as pool:
            stderr_fut = pool.apply_async(stderr_stream_handler,
                                          args=(proc.stderr, sys.stderr))
            # Do not launch a thread for stdout, as rich.status does not
            # work in a thread; it is used in log_utils.RayUpLineProcessor.
            stdout = stdout_stream_handler(proc.stdout, sys.stdout)
            stderr = stderr_fut.get()
    else:
        stdout = stdout_stream_handler(proc.stdout, sys.stdout)
        stderr = ''
    return stdout, stderr


def run_with_log(
    cmd: Union[List[str], str],
    log_path: str,
    *,
    require_outputs: bool = False,
    stream_logs: bool = False,
    start_streaming_at: str = '',
    end_streaming_at: Optional[str] = None,
    skip_lines: Optional[List[str]] = None,
    shell: bool = False,
    with_ray: bool = False,
    process_stream: bool = True,
    line_processor: Optional[log_utils.LineProcessor] = None,
    streaming_prefix: Optional[str] = None,
    log_cmd: bool = False,
    **kwargs,
) -> Union[int, Tuple[int, str, str]]:
    """Runs a command and logs its output to a file.

    Args:
        cmd: The command to run.
        log_path: The path to the log file.
        stream_logs: Whether to stream the logs to stdout/stderr.
        require_outputs: Whether to return the stdout/stderr of the command.
        process_stream: Whether to post-process the stdout/stderr of the
            command, such as replacing or skipping lines on the fly. If
            enabled, lines are printed only when '\r' or '\n' is found.

    Returns the returncode, or (returncode, stdout, stderr) of the command.
    Note that stdout and stderr are already decoded.
    """
    assert process_stream or not require_outputs, (
        process_stream, require_outputs,
        'require_outputs should be False when process_stream is False')
    log_path = os.path.expanduser(log_path)
    dirname = os.path.dirname(log_path)
    os.makedirs(dirname, exist_ok=True)
    # Redirect stderr to stdout when using ray, to preserve the order of
    # stdout and stderr.
    stdout_arg = stderr_arg = None
    ctx = _get_context()
    if process_stream or ctx is not None:
        # Capture stdout/stderr of the subprocess if:
        # 1. post-processing is needed (process_stream=True), or
        # 2. potential contextual handling is needed (ctx is not None).
        # TODO(aylei): can we always capture the stdout/stderr?
        stdout_arg = subprocess.PIPE
        stderr_arg = subprocess.PIPE if not with_ray else subprocess.STDOUT
    # Use stdin=subprocess.DEVNULL by default, as allowing inputs will mess up
    # the terminal output when typing in the terminal that starts the API
    # server.
    stdin = kwargs.pop('stdin', subprocess.DEVNULL)
    if log_cmd:
        with open(log_path, 'a', encoding='utf-8') as f:
            print(f'Running command: {cmd}', file=f)
    with subprocess.Popen(cmd,
                          stdout=stdout_arg,
                          stderr=stderr_arg,
                          start_new_session=True,
                          shell=shell,
                          stdin=stdin,
                          **kwargs) as proc:
        try:
            if ctx is not None:
                # When running in a coroutine, use kill_pg if available to
                # avoid the overhead of refreshing the process tree in the
                # daemon.
                subprocess_utils.kill_process_daemon(proc.pid,
                                                     use_kill_pg=True)
            else:
                # For backward compatibility, do not specify use_kill_pg by
                # default.
                subprocess_utils.kill_process_daemon(proc.pid)
            stdout = ''
            stderr = ''
            stdout_stream_handler = None
            stderr_stream_handler = None
            if process_stream:
                if skip_lines is None:
                    skip_lines = []
                # Skip these lines caused by the `-i` option of bash; we
                # failed to find another way to turn off these two warnings:
                # https://stackoverflow.com/questions/13300764/how-to-tell-bash-not-to-issue-warnings-cannot-set-terminal-process-group-and  # pylint: disable=line-too-long
                # `ssh -T -i -tt` still causes the problem.
                skip_lines += [
                    'bash: cannot set terminal process group',
                    'bash: no job control in this shell',
                ]
                # We need this even if the log_path is '/dev/null' to ensure
                # the progress bar is shown.
                # NOTE: Lines are printed only when '\r' or '\n' is found.
                args = _ProcessingArgs(
                    log_path=log_path,
                    stream_logs=stream_logs,
                    start_streaming_at=start_streaming_at,
                    end_streaming_at=end_streaming_at,
                    skip_lines=skip_lines,
                    line_processor=line_processor,
                    # Replace CRLF when the output is logged to driver by ray.
                    replace_crlf=with_ray,
                    streaming_prefix=streaming_prefix,
                )
                stdout_stream_handler = functools.partial(
                    _handle_io_stream,
                    args=args,
                )
                if proc.stderr is not None:
                    err_args = copy.copy(args)
                    err_args.line_processor = None
                    stderr_stream_handler = functools.partial(
                        _handle_io_stream,
                        args=err_args,
                    )
            if ctx is not None:
                # When running in a coroutine, always process the subprocess
                # stream to:
                # 1. handle context cancellation, and
                # 2. redirect subprocess stdout/stderr to the contextual
                #    stdout/stderr of the current coroutine.
                stdout, stderr = context_utils.pipe_and_wait_process(
                    ctx,
                    proc,
                    stdout_stream_handler=stdout_stream_handler,
                    stderr_stream_handler=stderr_stream_handler)
            elif process_stream:
                # When running in a process, only process the subprocess
                # stream when necessary, to avoid unnecessary stream-handling
                # overhead.
                stdout, stderr = process_subprocess_stream(
                    proc, stdout_stream_handler, stderr_stream_handler)
            # Ensure returncode is set.
            proc.wait()
            if require_outputs:
                return proc.returncode, stdout, stderr
            return proc.returncode
        except KeyboardInterrupt:
            # Kill the subprocess directly; otherwise, the underlying
            # process would only be killed after the python program exits,
            # leaving the stream handling stuck at `readline`.
            subprocess_utils.kill_children_processes()
            raise


def make_task_bash_script(codegen: str,
                          env_vars: Optional[Dict[str, str]] = None) -> str:
    # set -a is used to export all variables and functions to the environment
    # so that the bash `user_script` can access `conda activate`. Detail: #436.
    # Reference: https://www.gnu.org/software/bash/manual/html_node/The-Set-Builtin.html  # pylint: disable=line-too-long
    # DEACTIVATE_SKY_REMOTE_PYTHON_ENV: Deactivate the SkyPilot runtime env,
    # as the ray cluster is started within the runtime env, which may cause
    # the user program to run in that env as well.
    # PYTHONUNBUFFERED is used to disable python output buffering.
    script = [
        textwrap.dedent(f"""\
            #!/bin/bash
            source ~/.bashrc
            set -a
            . $(conda info --base 2> /dev/null)/etc/profile.d/conda.sh > /dev/null 2>&1 || true
            set +a
            {constants.DEACTIVATE_SKY_REMOTE_PYTHON_ENV}
            export PYTHONUNBUFFERED=1
            cd {constants.SKY_REMOTE_WORKDIR}"""),
    ]
    if env_vars is not None:
        for k, v in env_vars.items():
            script.append(f'export {k}={shlex.quote(str(v))}')
    script += [
        codegen,
        '',  # New line at EOF.
    ]
    script = '\n'.join(script)
    return script


def add_ray_env_vars(
        env_vars: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    # Adds Ray-related environment variables.
    if env_vars is None:
        env_vars = {}
    ray_env_vars = [
        'CUDA_VISIBLE_DEVICES', 'RAY_CLIENT_MODE', 'RAY_JOB_ID',
        'RAY_RAYLET_PID', 'OMP_NUM_THREADS'
    ]
    env_dict = dict(os.environ)
    for env_var in ray_env_vars:
        if env_var in env_dict:
            env_vars[env_var] = env_dict[env_var]
    return env_vars
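# NOTE: `add_ray_env_vars` is defined but not referenced elsewhere in this
# generated file.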


def run_bash_command_with_log(bash_command: str,
                              log_path: str,
                              env_vars: Optional[Dict[str, str]] = None,
                              stream_logs: bool = False,
                              with_ray: bool = False):
    with tempfile.NamedTemporaryFile('w', prefix='sky_app_',
                                     delete=False) as fp:
        bash_command = make_task_bash_script(bash_command, env_vars=env_vars)
        fp.write(bash_command)
        fp.flush()
        script_path = fp.name
        # Need the `-i` option to make sure `source ~/.bashrc` works.
        inner_command = f'/bin/bash -i {script_path}'
        return run_with_log(inner_command,
                            log_path,
                            stream_logs=stream_logs,
                            with_ray=with_ray,
                            shell=True)


def run_bash_command_with_log_and_return_pid(
        bash_command: str,
        log_path: str,
        env_vars: Optional[Dict[str, str]] = None,
        stream_logs: bool = False,
        with_ray: bool = False):
    return_code = run_bash_command_with_log(bash_command, log_path, env_vars,
                                            stream_logs, with_ray)
    return {'return_code': return_code, 'pid': os.getpid()}


run_bash_command_with_log = run_bash_command_with_log
run_bash_command_with_log_and_return_pid = ray.remote(
    run_bash_command_with_log_and_return_pid)
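# `run_bash_command_with_log_and_return_pid` is now a Ray remote function; it
# is scheduled onto the placement group bundle below and returns a dict with
# the command's return code and the worker process's pid.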

if hasattr(autostop_lib, 'set_last_active_time_to_now'):
    autostop_lib.set_last_active_time_to_now()

job_lib.set_status(1, job_lib.JobStatus.PENDING)

pg = ray_util.placement_group([{"CPU": 0.5}], 'STRICT_SPREAD')
plural = 's' if 1 > 1 else ''
node_str = f'1 node{plural}'
message = ('├── '
           'Waiting for task resources on '
           f'{node_str}.')
print(message, flush=True)
# FIXME: This will print the error message from the autoscaler if it is
# waiting for other tasks to finish. We should hide the error message.
ray.get(pg.ready())
print('\x1b[2m└── \x1b[0mJob started. Streaming logs... \x1b[2m(Ctrl-C to exit log streaming; job will not be killed)\x1b[0m', flush=True)
job_lib.set_job_started(1)
job_lib.scheduler.schedule_step()


@ray.remote
def check_ip():
    return ray.util.get_node_ip_address()


gang_scheduling_id_to_ip = ray.get([
    check_ip.options(
        num_cpus=0.5,
        scheduling_strategy=ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_bundle_index=i
        )).remote()
    for i in range(pg.bundle_count)
])
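# One `check_ip` task is pinned to each placement group bundle, so
# `gang_scheduling_id_to_ip[i]` is the IP of the node that bundle i landed on.
# The mapping below turns those IPs into stable node ranks.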
cluster_ips_to_node_id = {ip: i for i, ip in enumerate(['10.12.2.52'])}
job_ip_rank_list = sorted(gang_scheduling_id_to_ip,
                          key=cluster_ips_to_node_id.get)
job_ip_rank_map = {ip: i for i, ip in enumerate(job_ip_rank_list)}
job_ip_list_str = '\n'.join(job_ip_rank_list)

sky_env_vars_dict = {}
sky_env_vars_dict['SKYPILOT_NODE_IPS'] = job_ip_list_str
sky_env_vars_dict['SKYPILOT_NUM_NODES'] = len(job_ip_rank_list)
sky_env_vars_dict['SKYPILOT_TASK_ID'] = 'sky-2025-11-14-21-37-15-955021_scale-test-active_1'
sky_env_vars_dict['SKYPILOT_CLUSTER_INFO'] = '{"cluster_name": "scale-test-active", "cloud": "Kubernetes", "region": "gke_sky-dev-465_us-west1-a_repro", "zone": null}'

script = "echo 'test'"
rclone_flush_script = '\n# Only waits if cached mount is enabled (RCLONE_MOUNT_CACHED_LOG_DIR is not empty)\n# findmnt alone is not enough, as some clouds (e.g. AWS on ARM64) uses\n# rclone for normal mounts as well.\nif [ $(findmnt -t fuse.rclone --noheading | wc -l) -gt 0 ] && [ -d ~/.sky/rclone_log ] && [ "$(ls -A ~/.sky/rclone_log)" ]; then\n flushed=0\n # extra second on top of --vfs-cache-poll-interval to\n # avoid race condition between rclone log line creation and this check.\n sleep 1\n while [ $flushed -eq 0 ]; do\n # sleep for the same interval as --vfs-cache-poll-interval\n sleep 10\n flushed=1\n for file in ~/.sky/rclone_log/*; do\n exitcode=0\n tac $file | grep "vfs cache: cleaned:" -m 1 | grep "in use 0, to upload 0, uploading 0" -q || exitcode=$?\n if [ $exitcode -ne 0 ]; then\n echo "skypilot: cached mount is still uploading to remote"\n flushed=0\n break\n fi\n done\n done\n echo "skypilot: cached mount uploaded complete"\nfi'
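# `rclone_flush_script` is appended to the user command below so the task does
# not finish before rclone's VFS cache has uploaded cached-mount writes to
# remote storage.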

if run_fn is not None:
    script = run_fn(0, gang_scheduling_id_to_ip)

if script is not None:
    script = f'unset RAY_RAYLET_PID; {script}'
    script += rclone_flush_script
    sky_env_vars_dict['SKYPILOT_NUM_GPUS_PER_NODE'] = 0

    ip = gang_scheduling_id_to_ip[0]
    rank = job_ip_rank_map[ip]

    if len(cluster_ips_to_node_id) == 1:  # Single-node task on single-node cluster
        name_str = 'sky-cmd,' if 'sky-cmd' != None else 'task,'
        log_path = os.path.expanduser(
            os.path.join('~/sky_logs/1-sky-cmd/tasks', 'run.log'))
    else:  # Single-node or multi-node task on multi-node cluster
        idx_in_cluster = cluster_ips_to_node_id[ip]
        if cluster_ips_to_node_id[ip] == 0:
            node_name = 'head'
        else:
            node_name = f'worker{idx_in_cluster}'
        name_str = f'{node_name}, rank={rank},'
        log_path = os.path.expanduser(
            os.path.join('~/sky_logs/1-sky-cmd/tasks',
                         f'{rank}-{node_name}.log'))
    sky_env_vars_dict['SKYPILOT_NODE_RANK'] = rank
    sky_env_vars_dict['SKYPILOT_INTERNAL_JOB_ID'] = 1

    futures.append(
        run_bash_command_with_log_and_return_pid.options(
            name=name_str,
            num_cpus=0.5,
            scheduling_strategy=ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_bundle_index=0)
        ).remote(
            script,
            log_path,
            env_vars=sky_env_vars_dict,
            stream_logs=True,
            with_ray=True,
        ))

returncodes, _ = get_or_fail(futures, pg)
if sum(returncodes) != 0:
    job_lib.set_status(1, job_lib.JobStatus.FAILED)
    # Schedule the next pending job immediately to make the job
    # scheduling more efficient.
    job_lib.scheduler.schedule_step()
    # This waits for all streaming logs to finish.
    time.sleep(0.5)
    reason = ''
    # 139 is the return code of SIGSEGV, i.e. Segmentation Fault.
    if any(r == 139 for r in returncodes):
        reason = '(likely due to Segmentation Fault)'
    if any(r == 137 for r in returncodes):
        # Find the first non-137 return code; fall back to 137 if every
        # task was killed with SIGKILL.
        non_137 = next((r for r in returncodes if r != 137), 137)
        reason = (f'(A Worker failed with return code {non_137}, SkyPilot '
                  'cleaned up the processes on other nodes with return '
                  'code 137)')
    print('ERROR: Job 1 failed with '
          'return code list:',
          returncodes,
          reason,
          flush=True)
    # Need this to set the job status in ray job to be FAILED.
    sys.exit(1)
else:
    job_lib.set_status(1, job_lib.JobStatus.SUCCEEDED)
    # Schedule the next pending job immediately to make the job
    # scheduling more efficient.
    job_lib.scheduler.schedule_step()
    # This waits for all streaming logs to finish.
    time.sleep(0.5)