@woshiyyya
Created August 19, 2024 23:54
DistMM DAG Timeout Failure
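The script below spins up two tensor-parallel encoder groups ("text" and "vision"), each as a pair of Ray actors with one GPU apiece (so at least four GPUs are required), initializes a torch.distributed NCCL process group per pair, and then builds a compiled Ray DAG in which each group's forward activations feed the other group's backward pass over NCCL channels. As the gist title indicates, executing this DAG times out.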
import ray
from ray.air.util.torch_dist import _init_torch_distributed
from ray.air._internal.util import find_free_port
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import List
import numpy as np
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


@dataclass
class TorchDistributedConfig:
    rank: int
    local_rank: int
    world_size: int
    local_world_size: int
    master_addr: str
    master_port: str
    gpu_ids: List[int]


def initialize_dist_group(workers):
    """Initialize PyTorch Distributed Process Group for a set of workers."""
    worker_metadata = ray.get([worker.get_metadata.remote() for worker in workers])
    for worker_id, metadata in enumerate(worker_metadata):
        metadata["worker_id"] = worker_id

    # Group workers by node IP and sort each group by GPU id so that
    # ranks and local ranks are assigned deterministically.
    aggregated_metadata = defaultdict(list)
    for metadata in worker_metadata:
        aggregated_metadata[metadata["address"]].append(metadata)
    for metadata_list_per_ip in aggregated_metadata.values():
        metadata_list_per_ip.sort(key=lambda x: x["gpu_ids"])

    rank = 0
    world_size = len(workers)
    dist_configs = dict()
    for metadata_list_per_ip in aggregated_metadata.values():
        local_rank = 0
        local_world_size = len(metadata_list_per_ip)
        visible_device_ids = []
        for metadata in metadata_list_per_ip:
            visible_device_ids += metadata["gpu_ids"]
        for metadata in metadata_list_per_ip:
            if rank == 0:
                master_addr = metadata["address"]
                master_port = metadata["port"]
            worker_id = metadata["worker_id"]
            worker_config = TorchDistributedConfig(
                rank=rank,
                local_rank=local_rank,
                world_size=world_size,
                local_world_size=local_world_size,
                master_addr=master_addr,
                master_port=master_port,
                gpu_ids=visible_device_ids,
            )
            rank += 1
            local_rank += 1
            dist_configs[worker_id] = worker_config

    ray.get(
        [
            worker.init_dist_group.remote(dist_configs[worker_id])
            for worker_id, worker in enumerate(workers)
        ]
    )
    print("Finished initializing distributed process group.")


class BaseWorker:
    def __init__(self) -> None:
        pass

    def get_metadata(self):
        return {
            "gpu_ids": ray.get_gpu_ids(),
            "address": ray.util.get_node_ip_address(),
            "port": find_free_port(),
        }

    def init_dist_group(self, dist_config):
        self.dist_config = dist_config
        _init_torch_distributed(
            init_method="env", backend="nccl", **asdict(dist_config)
        )
        # print(f"Rank {self.dist_config.rank}: Initialized")
        # if self.dist_config.rank == 0:
        #     print(asdict(self.dist_config))


@ray.remote(num_gpus=1)
class EncoderWorker(BaseWorker):
    def __init__(self, name) -> None:
        super().__init__()
        # Define the encoder model.
        self.name = name
        self.model = nn.Sequential(nn.Linear(10, 100), nn.Linear(100, 10))
        self.tp_plan = {
            "0": ColwiseParallel(),
            "1": RowwiseParallel(),
        }

    def init_parallel_strategy(self):
        # Apply parallel strategy for model (TP/DP/...)
        self.rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(self.rank)
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(self.world_size,))
        parallelize_module(self.model, tp_mesh, self.tp_plan)

    def calculate_loss(self, features_1, features_2):
        return (features_1 * features_2).sum(1).mean()

    def forward(self, input):
        print(f"**** FWD {self.name}-{self.rank}")
        self.features = self.model(input)
        return self.features.detach()

    def backward(self, features_1, features_2):
        # Calculate some MM loss against the other modality's activations.
        print(f"**** BWD {self.name}-{self.rank}")
        self.loss = self.calculate_loss(self.features, features_1)
        self.loss.backward()
        return None

    def read_input(self, input):
        return input


if __name__ == "__main__":
    text_encoders = [EncoderWorker.remote("text") for _ in range(2)]
    initialize_dist_group(text_encoders)
    vision_encoders = [EncoderWorker.remote("vision") for _ in range(2)]
    initialize_dist_group(vision_encoders)

    # Apply the tensor parallel plan on both encoder groups.
    ray.get([worker.init_parallel_strategy.remote() for worker in text_encoders])
    ray.get([worker.init_parallel_strategy.remote() for worker in vision_encoders])

    outputs = []
    with InputNode() as input_node:
        inputs = text_encoders[0].read_input.bind(input_node)
        text_activations = [worker.forward.bind(inputs) for worker in text_encoders]
        vision_activations = [worker.forward.bind(inputs) for worker in vision_encoders]
        # Mark every activation edge for GPU-to-GPU NCCL transport.
        for dag_node in text_activations + vision_activations:
            dag_node.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
        # Cross-modal backward: text workers consume vision activations and vice versa.
        text_bwd = [
            worker.backward.bind(vision_activations[0], vision_activations[1])
            for worker in text_encoders
        ]
        vision_bwd = [
            worker.backward.bind(text_activations[0], text_activations[1])
            for worker in vision_encoders
        ]
        outputs = vision_bwd + text_bwd
        dag = MultiOutputNode(outputs)

    dag = dag.experimental_compile()
    print(ray.get(dag.execute(torch.randn(5, 10))))
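A second variant of the script follows. Instead of binding both remote activations directly into every backward call, it funnels each group's activations through an aggregate_activations step on one worker per group, marks only the DAG edges that actually cross actor boundaries with the NCCL transport hint, and passes the single aggregated activation to the other group's backward.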
import ray
from ray.air.util.torch_dist import _init_torch_distributed
from ray.air._internal.util import find_free_port
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import List
import numpy as np
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


@dataclass
class TorchDistributedConfig:
    rank: int
    local_rank: int
    world_size: int
    local_world_size: int
    master_addr: str
    master_port: str
    gpu_ids: List[int]


def initialize_dist_group(workers):
    """Initialize PyTorch Distributed Process Group for a set of workers."""
    worker_metadata = ray.get([worker.get_metadata.remote() for worker in workers])
    for worker_id, metadata in enumerate(worker_metadata):
        metadata["worker_id"] = worker_id

    # Group workers by node IP and sort each group by GPU id so that
    # ranks and local ranks are assigned deterministically.
    aggregated_metadata = defaultdict(list)
    for metadata in worker_metadata:
        aggregated_metadata[metadata["address"]].append(metadata)
    for metadata_list_per_ip in aggregated_metadata.values():
        metadata_list_per_ip.sort(key=lambda x: x["gpu_ids"])

    rank = 0
    world_size = len(workers)
    dist_configs = dict()
    for metadata_list_per_ip in aggregated_metadata.values():
        local_rank = 0
        local_world_size = len(metadata_list_per_ip)
        visible_device_ids = []
        for metadata in metadata_list_per_ip:
            visible_device_ids += metadata["gpu_ids"]
        for metadata in metadata_list_per_ip:
            if rank == 0:
                master_addr = metadata["address"]
                master_port = metadata["port"]
            worker_id = metadata["worker_id"]
            worker_config = TorchDistributedConfig(
                rank=rank,
                local_rank=local_rank,
                world_size=world_size,
                local_world_size=local_world_size,
                master_addr=master_addr,
                master_port=master_port,
                gpu_ids=visible_device_ids,
            )
            rank += 1
            local_rank += 1
            dist_configs[worker_id] = worker_config

    ray.get(
        [
            worker.init_dist_group.remote(dist_configs[worker_id])
            for worker_id, worker in enumerate(workers)
        ]
    )
    print("Finished initializing distributed process group.")


class BaseWorker:
    def __init__(self) -> None:
        pass

    def get_metadata(self):
        return {
            "gpu_ids": ray.get_gpu_ids(),
            "address": ray.util.get_node_ip_address(),
            "port": find_free_port(),
        }

    def init_dist_group(self, dist_config):
        self.dist_config = dist_config
        _init_torch_distributed(
            init_method="env", backend="nccl", **asdict(dist_config)
        )
        # print(f"Rank {self.dist_config.rank}: Initialized")
        # if self.dist_config.rank == 0:
        #     print(asdict(self.dist_config))


@ray.remote(num_gpus=1)
class EncoderWorker(BaseWorker):
    def __init__(self, name) -> None:
        super().__init__()
        # Define the encoder model.
        self.name = name
        self.model = nn.Sequential(nn.Linear(10, 100), nn.Linear(100, 10))
        self.tp_plan = {
            "0": ColwiseParallel(),
            "1": RowwiseParallel(),
        }

    def init_parallel_strategy(self):
        # Apply parallel strategy for model (TP/DP/...)
        self.rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(self.rank)
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(self.world_size,))
        parallelize_module(self.model, tp_mesh, self.tp_plan)

    def calculate_loss(self, features_1, features_2):
        return (features_1 * features_2).sum(1).mean()

    def forward(self, input):
        print(f"**** FWD {self.name}-{self.rank}")
        self.features = self.model(input)
        return self.features.detach()

    def backward(self, features_1):
        # Calculate some MM loss against the other modality's activations.
        print(f"**** BWD {self.name}-{self.rank}")
        self.loss = self.calculate_loss(self.features, features_1)
        self.loss.backward()
        return None

    def read_input(self, input):
        return input

    def aggregate_activations(self, *args):
        # Placeholder aggregation: simply forward the first activation.
        return args[0]


if __name__ == "__main__":
    text_encoders = [EncoderWorker.remote("text") for _ in range(2)]
    initialize_dist_group(text_encoders)
    vision_encoders = [EncoderWorker.remote("vision") for _ in range(2)]
    initialize_dist_group(vision_encoders)

    # Apply the tensor parallel plan on both encoder groups.
    ray.get([worker.init_parallel_strategy.remote() for worker in text_encoders])
    ray.get([worker.init_parallel_strategy.remote() for worker in vision_encoders])

    outputs = []
    with InputNode() as input_node:
        inputs = text_encoders[0].read_input.bind(input_node)
        text_activations = [worker.forward.bind(inputs) for worker in text_encoders]
        vision_activations = [worker.forward.bind(inputs) for worker in vision_encoders]
        # Only the activation edges that cross actor boundaries (index 1 of each
        # group, sent to that group's rank-0 aggregator) are marked for NCCL transport.
        for i, dag_node in enumerate(text_activations + vision_activations):
            if i == 1 or i == 3:
                dag_node.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
        agg_text_activation = text_encoders[0].aggregate_activations.bind(*text_activations)
        agg_vision_activation = vision_encoders[0].aggregate_activations.bind(*vision_activations)
        agg_text_activation.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
        agg_vision_activation.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
        # Cross-modal backward: text workers consume the aggregated vision activation
        # and vice versa.
        text_bwd = [
            worker.backward.bind(agg_vision_activation)
            for worker in text_encoders
        ]
        vision_bwd = [
            worker.backward.bind(agg_text_activation)
            for worker in vision_encoders
        ]
        outputs = vision_bwd + text_bwd
        dag = MultiOutputNode(outputs)

    dag = dag.experimental_compile()
    print(ray.get(dag.execute(torch.randn(5, 10))))
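For reference, here is a minimal sketch of the same compiled-DAG pattern with no GPUs, NCCL, or torch.distributed involved. The EchoWorker actor and its methods are illustrative stand-ins (not part of the gist), but the InputNode / bind / MultiOutputNode / experimental_compile / execute calls are the same ones used above; a reduced script like this can help separate a hang in the DAG wiring itself from one in the collective communication.

import ray
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode


# Hypothetical CPU-only actor used purely to exercise the DAG plumbing.
@ray.remote
class EchoWorker:
    def forward(self, x):
        # Stand-in for EncoderWorker.forward.
        return x * 2

    def backward(self, y):
        # Stand-in for EncoderWorker.backward.
        return y + 1


workers = [EchoWorker.remote() for _ in range(2)]

with InputNode() as inp:
    fwd = [w.forward.bind(inp) for w in workers]
    # Each worker's backward consumes the other worker's forward output,
    # mirroring the cross-modal wiring in the gist.
    bwd = [workers[i].backward.bind(fwd[1 - i]) for i in range(2)]
    dag = MultiOutputNode(bwd)

dag = dag.experimental_compile()
print(ray.get(dag.execute(3)))  # Expected output: [7, 7]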