Created August 19, 2024 23:54
DistMM DAG Timeout Failure
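Reproduction script for the timeout: two text-encoder workers and two vision-encoder workers each form a 2-way tensor-parallel group, a compiled Ray DAG runs the forward pass on all four workers, exchanges the resulting activations over NCCL channels, and then runs a cross-modal backward pass on every worker.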
```python
import ray
from ray.air.util.torch_dist import _init_torch_distributed
from ray.air._internal.util import find_free_port
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import List
import numpy as np

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


@dataclass
class TorchDistributedConfig:
    rank: int
    local_rank: int
    world_size: int
    local_world_size: int
    master_addr: str
    master_port: str
    gpu_ids: List[int]


def initialize_dist_group(workers):
    """Initialize a PyTorch distributed process group for a set of workers."""
    worker_metadata = ray.get([worker.get_metadata.remote() for worker in workers])
    for worker_id, metadata in enumerate(worker_metadata):
        metadata["worker_id"] = worker_id

    # Group worker metadata by node IP and sort by GPU ids within each node.
    aggregated_metadata = defaultdict(list)
    for metadata in worker_metadata:
        aggregated_metadata[metadata["address"]].append(metadata)
    for metadata_list_per_ip in aggregated_metadata.values():
        metadata_list_per_ip.sort(key=lambda x: x["gpu_ids"])

    rank = 0
    world_size = len(workers)
    dist_configs = dict()
    for metadata_list_per_ip in aggregated_metadata.values():
        local_rank = 0
        local_world_size = len(metadata_list_per_ip)
        visible_device_ids = []
        for metadata in metadata_list_per_ip:
            visible_device_ids += metadata["gpu_ids"]
        for metadata in metadata_list_per_ip:
            if rank == 0:
                master_addr = metadata["address"]
                master_port = metadata["port"]
            worker_id = metadata["worker_id"]
            worker_config = TorchDistributedConfig(
                rank=rank,
                local_rank=local_rank,
                world_size=world_size,
                local_world_size=local_world_size,
                master_addr=master_addr,
                master_port=master_port,
                gpu_ids=visible_device_ids,
            )
            rank += 1
            local_rank += 1
            dist_configs[worker_id] = worker_config

    ray.get(
        [
            worker.init_dist_group.remote(dist_configs[worker_id])
            for worker_id, worker in enumerate(workers)
        ]
    )
    print("Finished initializing distributed process group.")


class BaseWorker:
    def __init__(self) -> None:
        pass

    def get_metadata(self):
        return {
            "gpu_ids": ray.get_gpu_ids(),
            "address": ray.util.get_node_ip_address(),
            "port": find_free_port(),
        }

    def init_dist_group(self, dist_config):
        self.dist_config = dist_config
        _init_torch_distributed(
            init_method="env", backend="nccl", **asdict(dist_config)
        )
        # print(f"Rank {self.dist_config.rank}: Initialized")
        # if self.dist_config.rank == 0:
        #     print(asdict(self.dist_config))


@ray.remote(num_gpus=1)
class EncoderWorker(BaseWorker):
    def __init__(self, name) -> None:
        super().__init__()
        # Define the encoder model.
        self.name = name
        self.model = nn.Sequential(nn.Linear(10, 100), nn.Linear(100, 10))
        self.tp_plan = {
            "0": ColwiseParallel(),
            "1": RowwiseParallel(),
        }

    def init_parallel_strategy(self):
        # Apply the parallel strategy for the model (TP/DP/...).
        self.rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(self.rank)
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(self.world_size,))
        parallelize_module(self.model, tp_mesh, self.tp_plan)

    def calculate_loss(self, features_1, features_2):
        return (features_1 * features_2).sum(1).mean()

    def forward(self, input):
        print(f"**** FWD {self.name}-{self.rank}")
        self.features = self.model(input)
        return self.features.detach()

    def backward(self, features_1, features_2):
        # Calculate some multimodal loss.
        print(f"**** BWD {self.name}-{self.rank}")
        self.loss = self.calculate_loss(self.features, features_1)
        self.loss.backward()
        return None

    def read_input(self, input):
        return input


if __name__ == "__main__":
    text_encoders = [EncoderWorker.remote("text") for _ in range(2)]
    initialize_dist_group(text_encoders)
    vision_encoders = [EncoderWorker.remote("vision") for _ in range(2)]
    initialize_dist_group(vision_encoders)

    # Apply the tensor parallel plan to the text and vision encoders.
    ray.get([worker.init_parallel_strategy.remote() for worker in text_encoders])
    ray.get([worker.init_parallel_strategy.remote() for worker in vision_encoders])

    outputs = []
    with InputNode() as input_node:
        inputs = text_encoders[0].read_input.bind(input_node)
        text_activations = [worker.forward.bind(inputs) for worker in text_encoders]
        vision_activations = [worker.forward.bind(inputs) for worker in vision_encoders]

        # Ship all activations between workers over NCCL channels.
        for dag_node in text_activations + vision_activations:
            dag_node.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))

        text_bwd = [
            worker.backward.bind(vision_activations[0], vision_activations[1])
            for worker in text_encoders
        ]
        vision_bwd = [
            worker.backward.bind(text_activations[0], text_activations[1])
            for worker in vision_encoders
        ]
        outputs = vision_bwd + text_bwd
        dag = MultiOutputNode(outputs)

    dag = dag.experimental_compile()
    print(ray.get(dag.execute(torch.randn(5, 10))))
```
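The second variant narrows the NCCL usage: `backward` takes a single peer activation, each modality's activations are first gathered on one worker via `aggregate_activations`, and the `TorchTensorType(transport=TorchTensorType.NCCL)` hint is applied only to the second activation of each modality and to the two aggregated outputs.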
```python
import ray
from ray.air.util.torch_dist import _init_torch_distributed
from ray.air._internal.util import find_free_port
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import List
import numpy as np

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


@dataclass
class TorchDistributedConfig:
    rank: int
    local_rank: int
    world_size: int
    local_world_size: int
    master_addr: str
    master_port: str
    gpu_ids: List[int]


def initialize_dist_group(workers):
    """Initialize a PyTorch distributed process group for a set of workers."""
    worker_metadata = ray.get([worker.get_metadata.remote() for worker in workers])
    for worker_id, metadata in enumerate(worker_metadata):
        metadata["worker_id"] = worker_id

    # Group worker metadata by node IP and sort by GPU ids within each node.
    aggregated_metadata = defaultdict(list)
    for metadata in worker_metadata:
        aggregated_metadata[metadata["address"]].append(metadata)
    for metadata_list_per_ip in aggregated_metadata.values():
        metadata_list_per_ip.sort(key=lambda x: x["gpu_ids"])

    rank = 0
    world_size = len(workers)
    dist_configs = dict()
    for metadata_list_per_ip in aggregated_metadata.values():
        local_rank = 0
        local_world_size = len(metadata_list_per_ip)
        visible_device_ids = []
        for metadata in metadata_list_per_ip:
            visible_device_ids += metadata["gpu_ids"]
        for metadata in metadata_list_per_ip:
            if rank == 0:
                master_addr = metadata["address"]
                master_port = metadata["port"]
            worker_id = metadata["worker_id"]
            worker_config = TorchDistributedConfig(
                rank=rank,
                local_rank=local_rank,
                world_size=world_size,
                local_world_size=local_world_size,
                master_addr=master_addr,
                master_port=master_port,
                gpu_ids=visible_device_ids,
            )
            rank += 1
            local_rank += 1
            dist_configs[worker_id] = worker_config

    ray.get(
        [
            worker.init_dist_group.remote(dist_configs[worker_id])
            for worker_id, worker in enumerate(workers)
        ]
    )
    print("Finished initializing distributed process group.")


class BaseWorker:
    def __init__(self) -> None:
        pass

    def get_metadata(self):
        return {
            "gpu_ids": ray.get_gpu_ids(),
            "address": ray.util.get_node_ip_address(),
            "port": find_free_port(),
        }

    def init_dist_group(self, dist_config):
        self.dist_config = dist_config
        _init_torch_distributed(
            init_method="env", backend="nccl", **asdict(dist_config)
        )
        # print(f"Rank {self.dist_config.rank}: Initialized")
        # if self.dist_config.rank == 0:
        #     print(asdict(self.dist_config))


@ray.remote(num_gpus=1)
class EncoderWorker(BaseWorker):
    def __init__(self, name) -> None:
        super().__init__()
        # Define the encoder model.
        self.name = name
        self.model = nn.Sequential(nn.Linear(10, 100), nn.Linear(100, 10))
        self.tp_plan = {
            "0": ColwiseParallel(),
            "1": RowwiseParallel(),
        }

    def init_parallel_strategy(self):
        # Apply the parallel strategy for the model (TP/DP/...).
        self.rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(self.rank)
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(self.world_size,))
        parallelize_module(self.model, tp_mesh, self.tp_plan)

    def calculate_loss(self, features_1, features_2):
        return (features_1 * features_2).sum(1).mean()

    def forward(self, input):
        print(f"**** FWD {self.name}-{self.rank}")
        self.features = self.model(input)
        return self.features.detach()

    def backward(self, features_1):
        # Calculate some multimodal loss.
        print(f"**** BWD {self.name}-{self.rank}")
        self.loss = self.calculate_loss(self.features, features_1)
        self.loss.backward()
        return None

    def read_input(self, input):
        return input

    def aggregate_activations(self, *args):
        return args[0]


if __name__ == "__main__":
    text_encoders = [EncoderWorker.remote("text") for _ in range(2)]
    initialize_dist_group(text_encoders)
    vision_encoders = [EncoderWorker.remote("vision") for _ in range(2)]
    initialize_dist_group(vision_encoders)

    # Apply the tensor parallel plan to the text and vision encoders.
    ray.get([worker.init_parallel_strategy.remote() for worker in text_encoders])
    ray.get([worker.init_parallel_strategy.remote() for worker in vision_encoders])

    outputs = []
    with InputNode() as input_node:
        inputs = text_encoders[0].read_input.bind(input_node)
        text_activations = [worker.forward.bind(inputs) for worker in text_encoders]
        vision_activations = [worker.forward.bind(inputs) for worker in vision_encoders]

        # Only the second activation of each modality is shipped over NCCL.
        for i, dag_node in enumerate(text_activations + vision_activations):
            if i == 1 or i == 3:
                dag_node.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))

        # Gather each modality's activations on one worker before backward.
        agg_text_activation = text_encoders[0].aggregate_activations.bind(*text_activations)
        agg_vision_activation = vision_encoders[0].aggregate_activations.bind(*vision_activations)
        agg_text_activation.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
        agg_vision_activation.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))

        text_bwd = [
            worker.backward.bind(agg_vision_activation)
            for worker in text_encoders
        ]
        vision_bwd = [
            worker.backward.bind(agg_text_activation)
            for worker in vision_encoders
        ]
        outputs = vision_bwd + text_bwd
        dag = MultiOutputNode(outputs)

    dag = dag.experimental_compile()
    print(ray.get(dag.execute(torch.randn(5, 10))))
```