Skip to content

Instantly share code, notes, and snippets.

@shreyshahi
Last active March 5, 2024 00:28
Show Gist options
  • Select an option

  • Save shreyshahi/3b6227338a916660e2d422cbc975a2d6 to your computer and use it in GitHub Desktop.

Select an option

Save shreyshahi/3b6227338a916660e2d422cbc975a2d6 to your computer and use it in GitHub Desktop.
Simple code to make Playground v2.5 dream about cats
# Code inspired by https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
# The slerp function is based on the one in the above gist.
import itertools
from pathlib import Path

import numpy as np
import torch
from diffusers import DiffusionPipeline
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between two arrays ``v0`` and ``v1``.

    Works on NumPy arrays or torch tensors. Torch inputs are detached and
    copied to the CPU for the computation. The result is then moved back to
    ``v0``'s original device.

    Args:
        t: Interpolation fraction; 0 gives ``v0`` and 1 gives ``v1``.
        v0: Start point (``np.ndarray`` or ``torch.Tensor``).
        v1: End point, same shape and kind as ``v0``.
        DOT_THRESHOLD: If the absolute cosine similarity of ``v0`` and
            ``v1`` exceeds this, plain linear interpolation is used instead,
            because ``sin(theta)`` in the slerp weights approaches zero.

    Returns:
        The interpolated value. It is a tensor if the inputs were tensors and
        an ndarray otherwise. Floating inputs keep their promoted dtype;
        non-floating inputs yield float64.
    """
    inputs_are_torch = not isinstance(v0, np.ndarray)
    input_device = None
    if inputs_are_torch:
        input_device = v0.device
        # detach() so tensors that track gradients can still be converted.
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()

    out_dtype = np.result_type(v0, v1)
    if not np.issubdtype(out_dtype, np.floating):
        out_dtype = np.float64

    # Compute in float64. For half-precision inputs, the sum of squares
    # inside the norm overflows to inf when there are many elements.
    a = np.asarray(v0, dtype=np.float64)
    b = np.asarray(v1, dtype=np.float64)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)

    # The angle to a zero vector is undefined; lerp rather than return NaN.
    use_lerp = norm_product == 0
    if not use_lerp:
        # Clip so rounding cannot push the cosine just outside arccos's domain.
        dot = np.clip(np.vdot(a, b) / norm_product, -1.0, 1.0)
        use_lerp = np.abs(dot) > DOT_THRESHOLD

    if use_lerp:
        v2 = (1 - t) * a + t * b
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = np.sin(theta_t) / sin_theta_0
        v2 = s0 * a + s1 * b

    # asarray (not astype) so a 0-d result is still an ndarray for from_numpy.
    v2 = np.asarray(v2, dtype=out_dtype)
    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)
    return v2
def _latent_walk(latent_shape, num_interpolated_frames):
    """Yield latents along an endless slerp walk between random keyframes.

    Keyframes are drawn with ``torch.randn(latent_shape)``. Each segment
    yields ``num_interpolated_frames`` latents for ``t`` in ``[0, 1)``. The
    end keyframe is therefore emitted once, as the first frame of the next
    segment, instead of twice at the seam.
    """
    start = torch.randn(latent_shape)
    while True:
        end = torch.randn(latent_shape)
        for i in range(num_interpolated_frames):
            yield slerp(i / num_interpolated_frames, start, end)
        start = end


def main(
    prompt="photograph of a cat high quality",
    folder="cats_playground",
    max_frame_number=5000,
    num_interpolated_frames=500,
    quality=90,
):
    """Render a random latent-space walk as numbered JPEG frames.

    Loads the Playground v2.5 pipeline in float16 on CUDA. It then renders
    ``prompt`` once per latent from the walk and saves each image as
    ``<folder>/<frame_number:06d>.jpg``.

    Args:
        prompt: Text prompt used for every frame.
        folder: Output directory; created if it does not exist.
        max_frame_number: Total number of frames to render.
        num_interpolated_frames: Frames per keyframe-to-keyframe segment.
        quality: JPEG quality passed to ``Image.save``.

    Raises:
        ValueError: If ``num_interpolated_frames`` is less than 1. The walk
            would never advance in that case.
    """
    # Validate before the expensive model load.
    if num_interpolated_frames < 1:
        raise ValueError("num_interpolated_frames must be at least 1")

    pipeline = DiffusionPipeline.from_pretrained(
        "playgroundai/playground-v2.5-1024px-aesthetic",
        torch_dtype=torch.float16,
        use_safetensors=True,
    ).to("cuda")

    # NOTE(review): presumably the 4-channel, 1/8-resolution latent for a
    # 1024px image (128 * 8 = 1024) -- confirm against the model's VAE.
    latent_shape = (1, 4, 128, 128)

    out_dir = Path(folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    # islice stops exactly at max_frame_number, even mid-segment.
    walk = itertools.islice(
        _latent_walk(latent_shape, num_interpolated_frames), max_frame_number
    )
    for frame_number, latent in enumerate(walk):
        print(f"Creating and saving frame number {frame_number:06d}")
        # The pipeline was loaded in float16, so match the latent dtype.
        image = pipeline(prompt, latents=latent.half()).images[0]
        image.save(out_dir / f"{frame_number:06d}.jpg", quality=quality)
if __name__ == "__main__":
    # Start rendering only when run as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment