zbh1 debug
Traceback (most recent call last):
  File "/home/ray/default/skeleton_zb_h1.py", line 106, in <module>
    ray.get(dag.execute(1))
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2648, in get
    return object_refs.get(timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/compiled_dag_ref.py", line 90, in get
    return_vals = self._dag._execute_until(self._execution_index, timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/dag/compiled_dag_node.py", line 1785, in _execute_until
    return self._dag_output_fetcher.read(timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/common.py", line 282, in read
    outputs = self._read_list(timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/common.py", line 307, in _read_list
    results.append(c.read(timeout))
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/shared_memory_channel.py", line 600, in read
    return self._channel_dict[actor_id].read(timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/shared_memory_channel.py", line 454, in read
    ret = self._worker.get_objects(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 848, in get_objects
    data_metadata_pairs = self.core_worker.get_objects(
  File "python/ray/_raylet.pyx", line 3509, in ray._raylet.CoreWorker.get_objects
  File "python/ray/includes/common.pxi", line 100, in ray._raylet.check_status
ray.exceptions.RayChannelError: System error: Channel closed.
(Worker pid=34143) Destructing NCCL group on actor: Actor(Worker, 566c8e599d55ba925e665ae107000000)
(Worker pid=34145) ERROR:root:Compiled DAG task exited with exception
(Worker pid=34145) Traceback (most recent call last):
(Worker pid=34145)   File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/dag/compiled_dag_node.py", line 116, in do_exec_tasks
(Worker pid=34145)     done = tasks[operation.exec_task_idx].exec_operation(
(Worker pid=34145)   File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/dag/compiled_dag_node.py", line 490, in exec_operation
(Worker pid=34145)     return self._read()
(Worker pid=34145)   File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/dag/compiled_dag_node.py", line 407, in _read
(Worker pid=34145)     input_data = self.input_reader.read()
(Worker pid=34145)   File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/common.py", line 282, in read
(Worker pid=34145)     outputs = self._read_list(timeout)
(Worker pid=34145)   File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/common.py", line 307, in _read_list
(Worker pid=34145)     results.append(c.read(timeout))
(Worker pid=34145)   File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/shared_memory_channel.py", line 600, in read
(Worker pid=34145)     return self._channel_dict[actor_id].read(timeout)
(Worker pid=34145)   File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/intra_process_channel.py", line 63, in read
(Worker pid=34145)     return ctx.get_data(self._channel_id)
(Worker pid=34145)   File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/channel/serialization_context.py", line 42, in get_data
(Worker pid=34145)     channel_id in self.intra_process_channel_buffers
(Worker pid=34145) AssertionError: Channel 3de60c8c-0851-48ec-9573-1672445e8174 does not exist in the buffer.
(Worker pid=34144) Destructing NCCL group on actor: Actor(Worker, 5e591526c813185d7de35f0607000000)
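
Reading the logs: the worker-side AssertionError from the intra-process channel is the primary failure, and the driver's RayChannelError ("Channel closed") is the downstream symptom once that worker's task loop exits. Below is a minimal sketch that isolates the pattern under suspicion, a multi-return task whose second output is never consumed. This is a hypothesis to test, not a confirmed root cause; the actor name A and the payload are illustrative, not from the gist.

import ray
from ray.dag import InputNode, MultiOutputNode


@ray.remote
class A:
    @ray.method(num_returns=2)
    def f(self, x):
        # The second return value deliberately has no consumer in the DAG.
        return x, None


a = A.remote()
with InputNode() as inp:
    b, c = a.f.bind(inp)  # c is never bound to any downstream node
    dag = MultiOutputNode([b])

compiled_dag = dag.experimental_compile()
# Hypothesis: if dangling outputs are the trigger, this fails the same way
# as the full zbh1 script below.
print(ray.get(compiled_dag.execute(1)))

The full zbh1 script follows.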
import ray
import ray.cluster_utils
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.dag import InputNode, MultiOutputNode
from typing import Optional
from ray.dag.compiled_dag_node import CompiledDAG


@ray.remote(num_cpus=0, num_gpus=1)
class Worker:
    def __init__(self, rank: Optional[int] = None):
        self.rank = rank
        self.trace = []

    def fwd(self, value):
        self.trace.append(("FWD", self.rank))
        return value

    @ray.method(num_returns=2)
    def bwd(self, value):
        self.trace.append(("BWD", self.rank))
        return value, None

    def w(self, value):
        self.trace.append(("W", self.rank))
        return None

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace

    def read_input(self, input):
        return input

    def no_op(self, value):
        return value

    def no_op_two(self, value1, value2):
        return value1, value2


def generate_zbh1_dag(num_workers: int, num_microbatches: int, num_lead_microbatches: int):
    workers = [Worker.remote(rank) for rank in range(num_workers)]
    with InputNode() as inp:
        fwd_queues = [[] for _ in range(num_workers)]
        bwd_queues = [[] for _ in range(num_workers)]
        # Once a worker's counter reaches 0, it cannot execute another fwd until it
        # executes a bwd first.
        fwd_counter = [num_lead_microbatches - i for i in range(num_workers)]
        bwd_counter = [0 for _ in range(num_workers)]
        # All of the done batches.
        done = []
        # FWD on worker 0.
        input_data = workers[0].read_input.bind(inp)
        for i in range(num_microbatches):
            fwd_queues[0].append(input_data)
        while len(done) < num_microbatches:
            for i, worker in enumerate(workers):
                if fwd_counter[i] > 0 and fwd_queues[i]:
                    b = fwd_queues[i].pop(0)
                    b = worker.fwd.bind(b)
                    if i < num_workers - 1:
                        fwd_queues[i + 1].append(b)
                        # Use NCCL channel for communication between workers.
                        b.with_type_hint(
                            TorchTensorType(transport=TorchTensorType.NCCL)
                        )
                    else:
                        bwd_queues[i].append(b)
                    fwd_counter[i] -= 1
                elif bwd_queues[i]:
                    b = bwd_queues[i].pop(0)
                    # Is there any problem if c doesn't have a downstream node?
                    b, c = worker.bwd.bind(b)
                    if i > 0:
                        bwd_queues[i - 1].append(b)
                        # Use NCCL channel for communication between workers.
                        b.with_type_hint(
                            TorchTensorType(transport=TorchTensorType.NCCL)
                        )
                    else:
                        done.append(b)
                    fwd_counter[i] += 1
                    # Code change for Zero Bubble PP
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    bwd_counter[i] += 1
                    if bwd_counter[i] > i:
                        c = worker.w.bind(c)
                    if bwd_counter[i] == num_microbatches:
                        for _ in range(i):
                            c = worker.w.bind(c)
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
        dag = MultiOutputNode(done)
    compiled_dag = dag.experimental_compile()
    return compiled_dag, workers


if __name__ == "__main__":
    dag, workers = generate_zbh1_dag(num_workers=4, num_lead_microbatches=4, num_microbatches=8)
    ray.get(dag.execute(1))
    for i, worker in enumerate(workers):
        trace = ray.get(worker.pop_trace.remote())
        print(i, trace)
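
The comment in generate_zbh1_dag asks whether it is a problem if c has no downstream node. If dangling outputs turn out to be the trigger, one candidate workaround (an untested assumption, not a confirmed fix) is to route every otherwise-unconsumed output into the MultiOutputNode so each channel has a reader:

# Counterpart to the repro sketch above: the same multi-return pattern, but
# with the second output routed into the MultiOutputNode so it has a consumer.
import ray
from ray.dag import InputNode, MultiOutputNode


@ray.remote
class A:
    @ray.method(num_returns=2)
    def f(self, x):
        return x, None


a = A.remote()
with InputNode() as inp:
    b, c = a.f.bind(inp)
    dag = MultiOutputNode([b, c])  # c now has a downstream node

compiled_dag = dag.experimental_compile()
print(ray.get(compiled_dag.execute(1)))

Applied to the zbh1 script, this would mean collecting the tail of each worker's w chain into the list passed to MultiOutputNode instead of dropping it.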
woshiyyya commented Sep 6, 2024

[Image: PP@2x — zbh1 pipeline schedule DAG]

The DAG I built is shown above. It is adapted from the 1F1B DAG (white arrows); I added the red arrows to support zbh1.
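
For context on the split: ZB-H1 divides each backward pass into B (gradient w.r.t. the stage input, the bwd method here) and W (gradient w.r.t. the weights, the w method), keeping B on the critical path to the previous stage while deferring W to fill pipeline bubbles. A toy autograd sketch of that split (illustrative only, not from the gist):

import torch

x = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(8, 8, requires_grad=True)
y = x @ weight
grad_out = torch.ones_like(y)

# B: input gradient, needed immediately by the previous pipeline stage.
(grad_x,) = torch.autograd.grad(y, x, grad_out, retain_graph=True)
# W: weight gradient, can run later to fill a pipeline bubble.
(grad_weight,) = torch.autograd.grad(y, weight, grad_out)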
