@woshiyyya · Created August 22, 2024 21:22
ADAG hides the actual method error stack trace and prints a timeout error instead
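
For context, here is a minimal sketch of the suspected failure mode (hypothetical Sender/Receiver actors, assuming 2 GPUs and a Ray version with compiled graphs; this is an illustration of the symptom described above, not part of the original repro): when the upstream actor method raises before it can write to its NCCL-backed output channel, the driver reports a channel timeout instead of the original stack trace.

import ray
import torch
from ray.dag.input_node import InputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType


@ray.remote(num_gpus=1)
class Sender:  # hypothetical actor, for illustration only
    def send(self, _):
        # This is the error whose stack trace gets hidden.
        raise RuntimeError("real error inside the sender")


@ray.remote(num_gpus=1)
class Receiver:  # hypothetical actor, for illustration only
    def recv(self, tensor):
        return tensor.sum().item()


sender, receiver = Sender.remote(), Receiver.remote()
with InputNode() as inp:
    out = sender.send.bind(inp)
    out.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
    dag = receiver.recv.bind(out)

compiled = dag.experimental_compile()
# Expected: the RuntimeError traceback; observed (per this gist): a timeout error.
print(ray.get(compiled.execute(torch.zeros(1))))

The full repro script follows.
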
import ray
from ray.air.util.torch_dist import _init_torch_distributed
from ray.air._internal.util import find_free_port
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import List
import numpy as np
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from open_clip import get_model_config
from open_clip.model import _build_vision_tower, _build_text_tower
from open_clip.transform import PreprocessCfg, image_transform_v2
from open_clip.transformer import text_global_pool
from open_clip.loss import ClipLoss

@dataclass
class TorchDistributedConfig:
    rank: int
    local_rank: int
    world_size: int
    local_world_size: int
    master_addr: str
    master_port: str
    gpu_ids: List[int]

def initialize_dist_group(workers):
    """Initialize PyTorch Distributed Process Group for a set of workers."""
    worker_metadata = ray.get([worker.get_metadata.remote() for worker in workers])
    for worker_id, metadata in enumerate(worker_metadata):
        metadata["worker_id"] = worker_id

    aggregated_metadata = defaultdict(list)
    for metadata in worker_metadata:
        aggregated_metadata[metadata["address"]].append(metadata)
    for metadata_list_per_ip in aggregated_metadata.values():
        metadata_list_per_ip.sort(key=lambda x: x["gpu_ids"])

    rank = 0
    world_size = len(workers)
    dist_configs = dict()
    for metadata_list_per_ip in aggregated_metadata.values():
        local_rank = 0
        local_world_size = len(metadata_list_per_ip)
        visible_device_ids = []
        for metadata in metadata_list_per_ip:
            visible_device_ids += metadata["gpu_ids"]
        for metadata in metadata_list_per_ip:
            if rank == 0:
                master_addr = metadata["address"]
                master_port = metadata["port"]
            worker_id = metadata["worker_id"]
            worker_config = TorchDistributedConfig(
                rank=rank,
                local_rank=local_rank,
                world_size=world_size,
                local_world_size=local_world_size,
                master_addr=master_addr,
                master_port=master_port,
                gpu_ids=visible_device_ids,
            )
            rank += 1
            local_rank += 1
            dist_configs[worker_id] = worker_config

    ray.get(
        [
            worker.init_dist_group.remote(dist_configs[worker_id])
            for worker_id, worker in enumerate(workers)
        ]
    )
    print("Finished initializing distributed process group.")

class BaseWorker:
    def __init__(self) -> None:
        pass

    def get_metadata(self):
        return {
            "gpu_ids": ray.get_gpu_ids(),
            "address": ray.util.get_node_ip_address(),
            "port": find_free_port(),
        }

    def init_dist_group(self, dist_config):
        self.dist_config = dist_config
        _init_torch_distributed(
            init_method="env", backend="nccl", **asdict(dist_config)
        )
        # print(f"Rank {self.dist_config.rank}: Initialized")
        # if self.dist_config.rank == 0:
        #     print(asdict(self.dist_config))

@ray.remote(num_gpus=1)
class VisionEncoder(BaseWorker):
    def __init__(self, model_name='ViT-L-14') -> None:
        super().__init__()
        self.model_name = model_name
        model_config = get_model_config(model_name)
        assert model_config is not None, f"incorrect {model_name}!"
        self.model = _build_vision_tower(model_config['embed_dim'], model_config['vision_cfg'])
        num_visual_blocks = model_config['vision_cfg']['layers']
        self.tp_plan = {
            "transformer": {
                # **{f"resblocks.{i}.attn.out_proj": RowwiseParallel() for i in range(num_visual_blocks)},
                **{f"resblocks.{i}.mlp.c_fc": ColwiseParallel() for i in range(num_visual_blocks)},
                **{f"resblocks.{i}.mlp.c_proj": RowwiseParallel() for i in range(num_visual_blocks)},
            }
        }
        self.clip_loss_fn = ClipLoss(cache_labels=True)
        self.init_logit_scale = np.log(1 / 0.07)
        self.device_set = False

    def init_parallel_strategy(self):
        # Apply parallel strategy for model (TP/DP/...)
        self.rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(self.rank)
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(self.world_size,))
        parallelize_module(self.model, tp_mesh, self.tp_plan)
        self.model.to("cuda")
        # print(f"Vision: rank: {self.rank} | device: {torch.cuda.current_device()}")

    def encode_image(self, images, normalize: bool = True):
        features = self.model(images)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, inputs):
        images, text = inputs
        print(f"**** FWD VISION-{self.rank}")
        if not self.device_set:
            torch.cuda.set_device(self.rank)
            self.device_set = True
        self.vision_features = self.encode_image(images.cuda())
        return self.vision_features.detach()

    def backward(self, text_features, dummy_features=None):
        # Calculate some MM loss
        print(f"**** BWD VISION-{self.rank}")
        # text_features, logit_scale = features[0]
        # text_features = torch.ones((4))
        self.loss = self.clip_loss_fn(self.vision_features, text_features, self.init_logit_scale)
        self.loss.backward()
        return None

    # def read_input(self, input):
    #     image_input, text_input = input
    #     return image_input, text_input

@ray.remote(num_gpus=1)
class TextEncoder(BaseWorker):
    def __init__(self, model_name='ViT-L-14') -> None:
        super().__init__()
        self.model_name = model_name
        model_config = get_model_config(model_name)
        assert model_config is not None, f"incorrect {model_name}!"
        self.model = _build_text_tower(model_config['embed_dim'], model_config['text_cfg'])
        num_text_blocks = model_config['text_cfg']['layers']
        self.tp_plan = {
            "transformer": {
                # **{f"resblocks.{i}.attn.out_proj": RowwiseParallel() for i in range(num_text_blocks)},
                **{f"resblocks.{i}.mlp.c_fc": ColwiseParallel() for i in range(num_text_blocks)},
                **{f"resblocks.{i}.mlp.c_proj": RowwiseParallel() for i in range(num_text_blocks)},
            }
        }
        self.init_logit_scale = np.log(1 / 0.07)
        # self.logit_scale = nn.Parameter(torch.ones([]) * self.init_logit_scale).cuda()
        self.clip_loss_fn = ClipLoss(cache_labels=True)
        self.device_set = False

    def init_parallel_strategy(self):
        # Apply parallel strategy for model (TP/DP/...)
        self.rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(self.rank)
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(self.world_size,))
        parallelize_module(self.model, tp_mesh, self.tp_plan)
        self.model.to("cuda")

    def encode_text(self, text, normalize: bool = True):
        cast_dtype = self.model.transformer.get_cast_dtype()
        x = self.model.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding.to(cast_dtype)
        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
        x = self.model.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x, _ = text_global_pool(x, text, self.model.pool_type)
        if self.model.text_projection is not None:
            if isinstance(self.model.text_projection, nn.Linear):
                x = self.model.text_projection(x)
            else:
                x = x @ self.model.text_projection
            return F.normalize(x, dim=-1) if normalize else x
        return x

    def forward(self, inputs):
        images, text = inputs
        print(f"**** FWD TEXT-{self.rank}")
        if not self.device_set:
            torch.cuda.set_device(self.rank)
            self.device_set = True
        # Move text to cuda to avoid timeout error
        # self.text_features = self.encode_text(text.cuda())
        # Reproduce timeout
        self.text_features = self.encode_text(text)
        return self.text_features.detach()  # , self.logit_scale.detach()

    def backward(self, vision_features, dummy_features=None):
        # Calculate some MM loss
        print(f"**** BWD TEXT-{self.rank}")
        self.loss = self.clip_loss_fn(vision_features, self.text_features, self.init_logit_scale)
        self.loss.backward()
        return None

    def read_input(self, input):
        return input

if __name__ == "__main__":
    N = 1
    for i in range(N):
        print(f"Run ({i})...")
        text_encoders = [TextEncoder.remote() for _ in range(2)]
        initialize_dist_group(text_encoders)
        vision_encoders = [VisionEncoder.remote() for _ in range(2)]
        initialize_dist_group(vision_encoders)

        # Apply the tensor parallel plans to the text and vision encoders
        ray.get([worker.init_parallel_strategy.remote() for worker in text_encoders])
        ray.get([worker.init_parallel_strategy.remote() for worker in vision_encoders])

        outputs = []
        with InputNode() as input_node:
            # dummy read similar to https://github.com/ray-project/ray/issues/47041
            inputs = text_encoders[0].read_input.bind(input_node)
            text_activations = [worker.forward.bind(inputs) for worker in text_encoders]
            vision_activations = [worker.forward.bind(inputs) for worker in vision_encoders]
            text_activations[0].with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
            vision_activations[0].with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
            text_bwd = [
                text_encoders[0].backward.bind(vision_activations[0]),
                text_encoders[1].backward.bind(vision_activations[0], text_activations[1]),
            ]
            vision_bwd = [
                vision_encoders[0].backward.bind(text_activations[0]),
                vision_encoders[1].backward.bind(text_activations[0], vision_activations[1]),
            ]
            outputs = vision_bwd + text_bwd
            dag = MultiOutputNode(outputs)

        dag = dag.experimental_compile()
        rand_images = torch.randn(4, 3, 224, 224)
        rand_text = torch.randint(0, 49408, (4, 77))
        print(ray.get(dag.execute((rand_images, rand_text))))
        print("Done...")
        dag.teardown()
        for actor in vision_encoders + text_encoders:
            ray.kill(actor)
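
Notes on running the repro: as written it needs a single node with at least 4 GPUs (two TextEncoder and two VisionEncoder actors, each declared with num_gpus=1) and the open_clip package installed. The inline comments in TextEncoder.forward suggest that moving the tokens to the GPU first (self.encode_text(text.cuda())) avoids the problem, while keeping them on the CPU presumably raises the real error inside encode_text, which the compiled graph then surfaces only as a timeout rather than as the original stack trace.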