ADAG hides the actual method error stack trace and prints a timeout error instead.
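For context, the expectation this repro tests is roughly: an exception raised inside an actor method of a compiled (ADAG) graph should propagate to `ray.get()` together with the method's stack trace. Below is a minimal sketch of that expectation, using hypothetical actor and method names; a bare `raise` like this may well surface correctly on its own, and the masking reported here is observed in the full NCCL-based repro that follows.

```python
import ray
from ray.dag.input_node import InputNode


@ray.remote
class Worker:
    def step(self, x):
        # Whatever goes wrong in the method body is what we expect
        # to see at ray.get() time, traceback included.
        raise RuntimeError("actual error from inside the actor method")


worker = Worker.remote()
with InputNode() as inp:
    dag = worker.step.bind(inp)

compiled = dag.experimental_compile()
try:
    ray.get(compiled.execute(1))
except Exception as e:
    # Expected: the RuntimeError above, with its traceback.
    # Reported by this gist for the repro below: only a timeout error.
    print(type(e).__name__, e)
compiled.teardown()
```

The full repro follows.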
import ray
from ray.air.util.torch_dist import _init_torch_distributed
from ray.air._internal.util import find_free_port
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import List
import numpy as np

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from open_clip import get_model_config
from open_clip.model import _build_vision_tower, _build_text_tower
from open_clip.transform import PreprocessCfg, image_transform_v2
from open_clip.transformer import text_global_pool
from open_clip.loss import ClipLoss


@dataclass
class TorchDistributedConfig:
    rank: int
    local_rank: int
    world_size: int
    local_world_size: int
    master_addr: str
    master_port: str
    gpu_ids: List[int]


def initialize_dist_group(workers):
    """Initialize a PyTorch distributed process group for a set of workers."""
    worker_metadata = ray.get([worker.get_metadata.remote() for worker in workers])
    for worker_id, metadata in enumerate(worker_metadata):
        metadata["worker_id"] = worker_id

    # Group workers by node IP so local ranks can be assigned per node.
    aggregated_metadata = defaultdict(list)
    for metadata in worker_metadata:
        aggregated_metadata[metadata["address"]].append(metadata)
    for metadata_list_per_ip in aggregated_metadata.values():
        metadata_list_per_ip.sort(key=lambda x: x["gpu_ids"])

    rank = 0
    world_size = len(workers)
    dist_configs = dict()
    for metadata_list_per_ip in aggregated_metadata.values():
        local_rank = 0
        local_world_size = len(metadata_list_per_ip)
        visible_device_ids = []
        for metadata in metadata_list_per_ip:
            visible_device_ids += metadata["gpu_ids"]
        for metadata in metadata_list_per_ip:
            # The first worker processed (global rank 0) provides the master address/port.
            if rank == 0:
                master_addr = metadata["address"]
                master_port = metadata["port"]
            worker_id = metadata["worker_id"]
            worker_config = TorchDistributedConfig(
                rank=rank,
                local_rank=local_rank,
                world_size=world_size,
                local_world_size=local_world_size,
                master_addr=master_addr,
                master_port=master_port,
                gpu_ids=visible_device_ids,
            )
            rank += 1
            local_rank += 1
            dist_configs[worker_id] = worker_config

    ray.get(
        [
            worker.init_dist_group.remote(dist_configs[worker_id])
            for worker_id, worker in enumerate(workers)
        ]
    )
    print("Finished initializing distributed process group.")


class BaseWorker:
    def __init__(self) -> None:
        pass

    def get_metadata(self):
        return {
            "gpu_ids": ray.get_gpu_ids(),
            "address": ray.util.get_node_ip_address(),
            "port": find_free_port(),
        }

    def init_dist_group(self, dist_config):
        self.dist_config = dist_config
        _init_torch_distributed(
            init_method="env", backend="nccl", **asdict(dist_config)
        )
        # print(f"Rank {self.dist_config.rank}: Initialized")
        # if self.dist_config.rank == 0:
        #     print(asdict(self.dist_config))


@ray.remote(num_gpus=1)
class VisionEncoder(BaseWorker):
    def __init__(self, model_name='ViT-L-14') -> None:
        super().__init__()
        self.model_name = model_name
        model_config = get_model_config(model_name)
        assert model_config is not None, f"Unknown model config: {model_name}!"
        self.model = _build_vision_tower(model_config['embed_dim'], model_config['vision_cfg'])
        num_visual_blocks = model_config['vision_cfg']['layers']
        # Tensor-parallel plan: shard each MLP column-wise then row-wise across the TP mesh.
        self.tp_plan = {
            "transformer": {
                # **{f"resblocks.{i}.attn.out_proj": RowwiseParallel() for i in range(num_visual_blocks)},
                **{f"resblocks.{i}.mlp.c_fc": ColwiseParallel() for i in range(num_visual_blocks)},
                **{f"resblocks.{i}.mlp.c_proj": RowwiseParallel() for i in range(num_visual_blocks)},
            }
        }
        self.clip_loss_fn = ClipLoss(cache_labels=True)
        self.init_logit_scale = np.log(1 / 0.07)
        self.device_set = False

    def init_parallel_strategy(self):
        # Apply the parallel strategy (TP/DP/...) to the model.
        self.rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(self.rank)
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(self.world_size,))
        parallelize_module(self.model, tp_mesh, self.tp_plan)
        self.model.to("cuda")
        # print(f"Vision: rank: {self.rank} | device: {torch.cuda.current_device()}")

    def encode_image(self, images, normalize: bool = True):
        features = self.model(images)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, inputs):
        images, text = inputs
        print(f"**** FWD VISION-{self.rank}")
        if not self.device_set:
            torch.cuda.set_device(self.rank)
            self.device_set = True
        self.vision_features = self.encode_image(images.cuda())
        return self.vision_features.detach()

    def backward(self, text_features, dummy_features=None):
        # Compute the CLIP contrastive loss against the received text features.
        print(f"**** BWD VISION-{self.rank}")
        # text_features, logit_scale = features[0]
        # text_features = torch.ones((4))
        self.loss = self.clip_loss_fn(self.vision_features, text_features, self.init_logit_scale)
        self.loss.backward()
        return None

    # def read_input(self, input):
    #     image_input, text_input = input
    #     return image_input, text_input


@ray.remote(num_gpus=1)
class TextEncoder(BaseWorker):
    def __init__(self, model_name='ViT-L-14') -> None:
        super().__init__()
        self.model_name = model_name
        model_config = get_model_config(model_name)
        assert model_config is not None, f"Unknown model config: {model_name}!"
        self.model = _build_text_tower(model_config['embed_dim'], model_config['text_cfg'])
        num_text_blocks = model_config['text_cfg']['layers']
        self.tp_plan = {
            "transformer": {
                # **{f"resblocks.{i}.attn.out_proj": RowwiseParallel() for i in range(num_text_blocks)},
                **{f"resblocks.{i}.mlp.c_fc": ColwiseParallel() for i in range(num_text_blocks)},
                **{f"resblocks.{i}.mlp.c_proj": RowwiseParallel() for i in range(num_text_blocks)},
            }
        }
        self.init_logit_scale = np.log(1 / 0.07)
        # self.logit_scale = nn.Parameter(torch.ones([]) * self.init_logit_scale).cuda()
        self.clip_loss_fn = ClipLoss(cache_labels=True)
        self.device_set = False

    def init_parallel_strategy(self):
        # Apply the parallel strategy (TP/DP/...) to the model.
        self.rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(self.rank)
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(self.world_size,))
        parallelize_module(self.model, tp_mesh, self.tp_plan)
        self.model.to("cuda")

    def encode_text(self, text, normalize: bool = True):
        cast_dtype = self.model.transformer.get_cast_dtype()
        x = self.model.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding.to(cast_dtype)
        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
        x = self.model.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x, _ = text_global_pool(x, text, self.model.pool_type)
        if self.model.text_projection is not None:
            if isinstance(self.model.text_projection, nn.Linear):
                x = self.model.text_projection(x)
            else:
                x = x @ self.model.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def forward(self, inputs):
        images, text = inputs
        print(f"**** FWD TEXT-{self.rank}")
        if not self.device_set:
            torch.cuda.set_device(self.rank)
            self.device_set = True
        # Move text to cuda to avoid the timeout error:
        # self.text_features = self.encode_text(text.cuda())
        # Reproduce the timeout: pass the CPU token tensor straight into the model.
        self.text_features = self.encode_text(text)
        return self.text_features.detach()  # , self.logit_scale.detach()

    def backward(self, vision_features, dummy_features=None):
        # Compute the CLIP contrastive loss against the received vision features.
        print(f"**** BWD TEXT-{self.rank}")
        self.loss = self.clip_loss_fn(vision_features, self.text_features, self.init_logit_scale)
        self.loss.backward()
        return None

    def read_input(self, input):
        return input


if __name__ == "__main__":
    N = 1
    for i in range(N):
        print(f"Run ({i})...")
        text_encoders = [TextEncoder.remote() for _ in range(2)]
        initialize_dist_group(text_encoders)
        vision_encoders = [VisionEncoder.remote() for _ in range(2)]
        initialize_dist_group(vision_encoders)

        # Apply the tensor-parallel plans inside each encoder group.
        ray.get([worker.init_parallel_strategy.remote() for worker in text_encoders])
        ray.get([worker.init_parallel_strategy.remote() for worker in vision_encoders])

        with InputNode() as input_node:
            # Dummy read, similar to https://github.com/ray-project/ray/issues/47041
            inputs = text_encoders[0].read_input.bind(input_node)
            text_activations = [worker.forward.bind(inputs) for worker in text_encoders]
            vision_activations = [worker.forward.bind(inputs) for worker in vision_encoders]
            # Ship the rank-0 activations between the encoder groups over NCCL channels.
            text_activations[0].with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
            vision_activations[0].with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
            text_bwd = [
                text_encoders[0].backward.bind(vision_activations[0]),
                text_encoders[1].backward.bind(vision_activations[0], text_activations[1]),
            ]
            vision_bwd = [
                vision_encoders[0].backward.bind(text_activations[0]),
                vision_encoders[1].backward.bind(text_activations[0], vision_activations[1]),
            ]
            outputs = vision_bwd + text_bwd
            dag = MultiOutputNode(outputs)

        dag = dag.experimental_compile()
        rand_images = torch.randn(4, 3, 224, 224)
        rand_text = torch.randint(0, 49408, (4, 77))
        print(ray.get(dag.execute((rand_images, rand_text))))
        print("Done...")

        dag.teardown()
        for actor in vision_encoders + text_encoders:
            ray.kill(actor)
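Running the repro as written needs four GPUs (two `TextEncoder` and two `VisionEncoder` actors, each requesting `num_gpus=1`). The presumable trigger for the masked error is the CPU token tensor that `TextEncoder.forward` feeds straight into `encode_text` (the `.cuda()` call is left commented out to reproduce the timeout); per the description above, the driver then reports a timeout error rather than the stack trace from inside `forward`, and restoring the `.cuda()` call is the noted way to avoid it.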