Skip to content

Instantly share code, notes, and snippets.

@rahulunair
Last active July 13, 2024 03:44
Show Gist options
  • Select an option

  • Save rahulunair/e9add016f78fc3a5f4c9d3bd5155054d to your computer and use it in GitHub Desktop.

Select an option

Save rahulunair/e9add016f78fc3a5f4c9d3bd5155054d to your computer and use it in GitHub Desktop.
import logging
import warnings
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.models import resnet50
import intel_extension_for_pytorch as ipex
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline
from diffusers import DiffusionPipeline
from sentence_transformers import SentenceTransformer
# Setup logging and warnings
logging.basicConfig(level=logging.ERROR)
st_logger = logging.getLogger("sentence_transformers")
st_logger.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
def parse_arguments():
    """Parse command-line options for the XPU benchmark script."""
    cli = argparse.ArgumentParser(
        description="Benchmark deep learning workloads on Intel XPUs."
    )
    cli.add_argument(
        "--iterations",
        type=int,
        default=5,
        help="Number of benchmark iterations to run",
    )
    cli.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="Number of training epochs for fine-tuning",
    )
    cli.add_argument(
        "--use_accelerate",
        action="store_true",
        help="Enable multi-XPU support with Accelerate.",
    )
    return cli.parse_args()
def device_selection(args):
    """Pick the compute device for the run.

    Returns the Accelerate-managed device when ``--use_accelerate`` was
    given, otherwise the first Intel XPU (``"xpu:0"``).
    """
    if args.use_accelerate:
        # Only initialize Accelerate's process/device state when multi-XPU
        # support was actually requested; the original constructed an
        # Accelerator unconditionally and then discarded it.
        device = Accelerator().device
    else:
        device = "xpu:0"
    print(f"Using device: {device}")
    return device
def run_benchmark(task_func, *args):
    """Time repeated executions of *task_func*.

    The LAST positional argument is the iteration count; every preceding
    argument is forwarded to *task_func* on each call.

    Returns:
        (mean_seconds, std_seconds) over the per-iteration wall times.
    """
    *task_args, iterations = args
    times = []
    for _ in range(iterations):
        # perf_counter is monotonic and high-resolution, unlike time.time,
        # so timings cannot be skewed by wall-clock adjustments.
        start = time.perf_counter()
        task_func(*task_args)
        times.append(time.perf_counter() - start)
    return np.mean(times), np.std(times)
def compute_transformer_embeddings(model, tokenizer, sentences, device):
    """Return L2-normalized CLS-token embeddings for *sentences*.

    Tokenizes the batch, runs the transformer under bfloat16 autocast on
    the XPU, takes the first (CLS) token's hidden state, and normalizes
    each row to unit length.
    """
    tokens = tokenizer(
        sentences, padding=True, truncation=True, return_tensors="pt"
    ).to(device)
    with torch.inference_mode():
        with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
            cls_states = model(**tokens)[0][:, 0]
            normalized = torch.nn.functional.normalize(cls_states, p=2, dim=1)
    return normalized
def compute_st_embeddings(st_model, tokenizer, sentences, device):
    """Encode *sentences* with a SentenceTransformer into unit-norm tensors.

    NOTE(review): *tokenizer* is unused here — SentenceTransformer does its
    own tokenization — but the parameter is kept for signature parity with
    compute_transformer_embeddings so both fit run_benchmark the same way.
    """
    with torch.inference_mode(), torch.xpu.amp.autocast(
        enabled=True, dtype=torch.bfloat16
    ):
        return st_model.encode(
            sentences,
            convert_to_tensor=True,
            device=device,
            normalize_embeddings=True,
        )
def infer_resnet(model, dataloader, device):
    """Run one inference pass over every batch in *dataloader*.

    Returns the model output for the LAST batch only (earlier outputs are
    discarded — this function exists to measure throughput, not collect
    predictions), or None when the dataloader yields no batches.  The
    original raised UnboundLocalError on an empty dataloader.
    """
    model.eval()
    output = None
    with torch.inference_mode(), torch.xpu.amp.autocast(
        enabled=True, dtype=torch.bfloat16
    ):
        for images, _ in dataloader:
            output = model(images.to(device))
    return output
def train_resnet(model, dataloader, criterion, optimizer, device, epochs):
    """Fine-tune *model* on *dataloader* for *epochs* epochs.

    The forward pass and loss run under bfloat16 autocast; backward and
    the optimizer step run outside it.  Returns total elapsed seconds.
    """
    start = time.time()
    for _ in range(epochs):
        model.train()
        for batch, labels in dataloader:
            batch = batch.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
                predictions = model(batch)
                loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
    return time.time() - start
def setup_resnet(device, batch_size=256, pretrained=True):
    """Build a ResNet-50 + CIFAR-10 training setup optimized with IPEX.

    Downloads CIFAR-10 into ./data if needed and returns the tuple
    (model, dataloader, criterion, optimizer).
    """
    # Standard ImageNet preprocessing so CIFAR images match ResNet's input.
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    cifar_train = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=preprocessing
    )
    loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
    net = resnet50(pretrained=pretrained)
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(net.parameters(), lr=0.001)
    # IPEX rewrites the model/optimizer pair for bfloat16 execution on XPU.
    net, opt = ipex.optimize(net.to(device), optimizer=opt, dtype=torch.bfloat16)
    return net, loader, loss_fn, opt
def setup_text_generation(device):
    """Load microsoft/phi-2 in bfloat16 and optimize it with IPEX.

    Returns (model_on_device, tokenizer).  Also sets the torch default
    device so tensors created by the model land on the XPU.
    """
    torch.set_default_device(str(device))
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2", torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    model = ipex.llm.optimize(model, dtype=torch.bfloat16)
    return model.to(device), tokenizer
def generate_text(model, tokenizer, device, prompt, max_length):
    """Generate a completion for *prompt* and time the generation step.

    Only ``model.generate`` is timed; tokenization and decoding are
    excluded.  Returns (decoded_text, generation_seconds).
    """
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # perf_counter is monotonic, so the measured interval cannot be
    # corrupted by wall-clock adjustments (time.time can go backwards).
    start_time = time.perf_counter()
    outputs = model.generate(**inputs, max_length=max_length)
    generation_time = time.perf_counter() - start_time
    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return text, generation_time
def setup_diffusion_model(device):
    """Load Stable Diffusion v1.4 (fp16 weights, bfloat16 compute) onto *device*."""
    pipeline = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", variant="fp16", torch_dtype=torch.bfloat16
    )
    return pipeline.to(device)
def generate_images(pipe, prompt, num_images):
    """Generate *num_images* images from *prompt* and return total seconds.

    The generated images themselves are discarded — this is a pure
    throughput benchmark.
    """
    # perf_counter: monotonic clock, safe for measuring durations.
    start_time = time.perf_counter()
    for _ in range(num_images):
        pipe(prompt)
    return time.perf_counter() - start_time
def main():
    """Run the full XPU benchmark suite and print per-workload timings."""
    opts = parse_arguments()
    device = device_selection(opts)
    n_iter = opts.iterations
    n_epochs = opts.epochs

    # Embedding workloads: raw HF transformer + SentenceTransformer wrapper
    # over the same base model, so their timings are comparable.
    embedding_model = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(embedding_model)
    transformer_model = AutoModel.from_pretrained(embedding_model).to(device)
    st_model = SentenceTransformer(embedding_model).to(device)
    sentences = ["Example sentence for embedding."] * 10

    # Vision workload: ResNet-50 on CIFAR-10.
    resnet_model, resnet_dataloader, criterion, optimizer = setup_resnet(device)

    # Generative workloads (diffusion left disabled by the author).
    phi2_model, phi2_tokenizer = setup_text_generation(device)
    #diffusion_pipe = setup_diffusion_model(device)

    prompt_text = '''def print_prime(n):
"""
Print all primes between 1 and n
"""'''
    generated_text, text_gen_time = generate_text(
        phi2_model, phi2_tokenizer, device, prompt_text, max_length=64
    )
    print(f"Text Generation Time for 64 tokens: {text_gen_time:.2f} seconds")
    #image_gen_time = generate_images(diffusion_pipe, "A futuristic cityscape", 10)
    #print(f"Image Generation Time for 10 images: {image_gen_time:.2f} seconds")

    transformer_time, _ = run_benchmark(
        compute_transformer_embeddings,
        transformer_model,
        tokenizer,
        sentences,
        device,
        n_iter,
    )
    print(f"Transformer Embeddings Time: {transformer_time:.2f} seconds")

    st_time, _ = run_benchmark(
        compute_st_embeddings, st_model, tokenizer, sentences, device, n_iter
    )
    print(f"Sentence Transformer Embeddings Time: {st_time:.2f} seconds")

    resnet_inference_time, _ = run_benchmark(
        infer_resnet, resnet_model, resnet_dataloader, device, n_iter
    )
    print(f"ResNet Inference Time: {resnet_inference_time:.2f} seconds")

    training_time_resnet = train_resnet(
        resnet_model, resnet_dataloader, criterion, optimizer, device, n_epochs
    )
    print(f"Completed ResNet Training for {n_epochs} epochs in {training_time_resnet:.2f} seconds")


if __name__ == "__main__":
    main()
@rahulunair
Copy link
Author

PVC 1100 VM vs PVC 1100 baremetal

PVC 1100 VM

  • Text Generation Time for 64 tokens: 2.24 seconds
  • Transformer Embeddings Time: 0.07 seconds
  • Sentence Transformer Embeddings Time: 0.04 seconds
  • ResNet50 Inference Time (cifar): 41.58 seconds
  • ResNet50 finetuning (cifar) for 1 epochs in 76.19 seconds

pvc 1100 baremetal

  • Text Generation Time for 64 tokens: 2.14 seconds
  • Transformer Embeddings Time: 0.04 seconds
  • Sentence Transformer Embeddings Time: 0.02 seconds
  • ResNet50 Inference Time (cifar): 43.14 seconds
  • ResNet50 finetuning (cifar) for 1 epochs in 72.11 seconds

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment