Skip to content

Instantly share code, notes, and snippets.

@tanukon
Last active September 28, 2024 15:25
Show Gist options
  • Select an option

  • Save tanukon/a4c894ec72d952e93fde9bd2cd56da37 to your computer and use it in GitHub Desktop.

Select an option

Save tanukon/a4c894ec72d952e93fde9bd2cd56da37 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/envs/InstantStyle/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import gc\n",
"import os\n",
"\n",
"import cv2\n",
"import torch\n",
"from diffusers import StableDiffusionXLPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline\n",
"from PIL import Image\n",
"\n",
"from ip_adapter import IPAdapterXL"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Basic InstantStyle Usage: Style Change"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# --- model / device configuration ---\n",
"device = \"cuda\"  # inference device\n",
"\n",
"# SDXL base model, loaded with from_pretrained below\n",
"base_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
"\n",
"# IP-Adapter assets: image encoder directory and adapter checkpoint\n",
"image_encoder_path = \"sdxl_models/image_encoder\"\n",
"ip_ckpt = \"sdxl_models/ip-adapter_sdxl.bin\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading pipeline components...: 100%|██████████| 7/7 [00:52<00:00, 7.53s/it]\n"
]
}
],
"source": [
"# build the SDXL text-to-image pipeline in half precision, without the invisible watermarker\n",
"pipe = StableDiffusionXLPipeline.from_pretrained(\n",
"    base_model_path,\n",
"    add_watermarker=False,\n",
"    torch_dtype=torch.float16,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# reduce VAE memory consumption by encoding/decoding in tiles;\n",
"# pipe.enable_vae_tiling() only forwards to this call and is deprecated in recent diffusers\n",
"pipe.vae.enable_tiling()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/InstantStyle/ip_adapter/ip_adapter.py:149: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n"
]
}
],
"source": [
"# Attach the IP-Adapter. target_blocks selects the attention blocks it is applied to:\n",
"#   [\"block\"]                                                  : original IP-Adapter\n",
"#   [\"up_blocks.0.attentions.1\"]                               : style blocks only\n",
"#   [\"up_blocks.0.attentions.1\", \"down_blocks.2.attentions.1\"] : style + layout blocks\n",
"target_blocks = [\"up_blocks.0.attentions.1\"]\n",
"ip_model = IPAdapterXL(\n",
"    pipe,\n",
"    image_encoder_path,\n",
"    ip_ckpt,\n",
"    device,\n",
"    target_blocks=target_blocks,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Collect style reference images in a deterministic (sorted) order, skipping\n",
"# non-image entries such as .DS_Store or .ipynb_checkpoints that Image.open would reject.\n",
"image_dir = './styles'\n",
"image_path_list = sorted(\n",
"    os.path.join(image_dir, f)\n",
"    for f in os.listdir(image_dir)\n",
"    if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif', '.tif', '.tiff'))\n",
")\n",
"assert image_path_list, f\"no style images found in {image_dir}\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 30/30 [00:23<00:00, 1.28it/s]\n",
"100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n",
"100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n",
"100%|██████████| 30/30 [00:23<00:00, 1.28it/s]\n",
"100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n",
"100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n"
]
}
],
"source": [
"# Generate one stylized image per reference: the style comes from the reference\n",
"# image, the content from the text prompt. Results are saved under the same file name.\n",
"result_dir = './results'\n",
"os.makedirs(result_dir, exist_ok=True)\n",
"\n",
"for image_path in image_path_list:\n",
"    # os.path.basename is separator-aware; split('/') breaks on Windows paths from os.path.join\n",
"    image_name = os.path.basename(image_path)\n",
"    with Image.open(image_path) as img:\n",
"        # force 3-channel RGB so RGBA / palette / grayscale references are handled uniformly\n",
"        image = img.convert('RGB').resize((512, 512))\n",
"    with torch.inference_mode():\n",
"        images = ip_model.generate(\n",
"            pil_image=image,\n",
"            prompt=\"a cat, masterpiece, best quality, high quality\",\n",
"            negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n",
"            scale=1.0,\n",
"            guidance_scale=5,\n",
"            num_samples=1,\n",
"            num_inference_steps=30,\n",
"            seed=42,\n",
"        )\n",
"    images[0].save(os.path.join(result_dir, image_name))\n",
"\n",
"    # release the generated images and cached GPU blocks before the next reference\n",
"    del images\n",
"    torch.cuda.empty_cache()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Better disentanglement: subtract content features from the reference image"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def disentanglement(style_image_path: str, neg_prompt: str, neg_content_scale: float = 0.5) -> None:\n",
"    \"\"\"Generate one stylized image while suppressing the reference's own content.\n",
"\n",
"    The style comes from the reference image. `neg_prompt` describes content of that\n",
"    reference that should not leak into the result; it is passed to the IP-Adapter as\n",
"    neg_content_prompt with weight `neg_content_scale`.\n",
"    The result is saved in the notebook-global `result_dir` as 'de_' + the file name.\n",
"\n",
"    Args:\n",
"        style_image_path: path to the style reference image.\n",
"        neg_prompt: text naming the reference content to subtract.\n",
"        neg_content_scale: strength of the content subtraction (default 0.5).\n",
"    \"\"\"\n",
"    # os.path.basename handles both '/' and the platform separator\n",
"    image_name = os.path.basename(style_image_path)\n",
"    with Image.open(style_image_path) as img:\n",
"        image = img.convert('RGB').resize((512, 512))\n",
"    os.makedirs(result_dir, exist_ok=True)\n",
"\n",
"    with torch.inference_mode():\n",
"        images = ip_model.generate(\n",
"            pil_image=image,\n",
"            prompt=\"a cat, masterpiece, best quality, high quality\",\n",
"            negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n",
"            scale=1.0,\n",
"            guidance_scale=5,\n",
"            num_samples=1,\n",
"            num_inference_steps=30,\n",
"            seed=42,\n",
"            neg_content_prompt=neg_prompt,\n",
"            neg_content_scale=neg_content_scale,\n",
"        )\n",
"    images[0].save(os.path.join(result_dir, 'de_' + image_name))\n",
"\n",
"    # free the output and cached GPU memory between calls\n",
"    del images\n",
"    torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# keep style3's look but subtract its own subject (\"mushroom\") from the result\n",
"result_dir = './results'\n",
"os.makedirs(result_dir, exist_ok=True)\n",
"style_image_path = './styles/style3.jpg'\n",
"\n",
"disentanglement(style_image_path, neg_prompt=\"mushroom\")\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 30/30 [00:08<00:00, 3.49it/s]\n"
]
}
],
"source": [
"# keep style6's look but subtract its own subject (\"building\") from the result\n",
"result_dir = './results'\n",
"os.makedirs(result_dir, exist_ok=True)\n",
"style_image_path = './styles/style6.jpg'\n",
"\n",
"disentanglement(style_image_path, neg_prompt=\"building\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Image-based style change: structure from a Canny edge map (ControlNet), style from the reference image"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Model paths are repeated so this section can also be run from a fresh kernel.\n",
"base_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
"image_encoder_path = \"sdxl_models/image_encoder\"\n",
"ip_ckpt = \"sdxl_models/ip-adapter_sdxl.bin\"\n",
"device = \"cuda\"\n",
"\n",
"# On Restart & Run All, the section-1 `pipe` and `ip_model` are still referenced here\n",
"# (presumably resident on the GPU). Drop them before loading the ControlNet pipeline\n",
"# so two SDXL pipelines are not held at once. Harmless in a fresh kernel.\n",
"ip_model = None\n",
"pipe = None\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Canny-edge ControlNet for SDXL, half precision, moved to the inference device\n",
"controlnet_path = \"diffusers/controlnet-canny-sdxl-1.0\"\n",
"controlnet = ControlNetModel.from_pretrained(\n",
"    controlnet_path,\n",
"    torch_dtype=torch.float16,\n",
"    use_safetensors=False,\n",
").to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading pipeline components...: 100%|██████████| 7/7 [00:02<00:00, 3.24it/s]\n"
]
}
],
"source": [
"# SDXL pipeline conditioned on the Canny ControlNet, half precision, no watermarker\n",
"pipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n",
"    base_model_path,\n",
"    controlnet=controlnet,\n",
"    torch_dtype=torch.float16,\n",
"    add_watermarker=False,\n",
")\n",
"# reduce VAE memory by encoding/decoding in tiles; pipe.enable_vae_tiling() only\n",
"# forwards to this call and is deprecated in recent diffusers releases\n",
"pipe.vae.enable_tiling()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/InstantStyle/ip_adapter/ip_adapter.py:149: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n"
]
}
],
"source": [
"# attach the IP-Adapter to the ControlNet pipeline, applied to the style block only\n",
"ip_model = IPAdapterXL(\n",
"    pipe,\n",
"    image_encoder_path,\n",
"    ip_ckpt,\n",
"    device,\n",
"    target_blocks=[\"up_blocks.0.attentions.1\"],\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Collect style reference images in a deterministic (sorted) order, skipping\n",
"# non-image entries such as .DS_Store or .ipynb_checkpoints that Image.open would reject.\n",
"image_dir = './styles'\n",
"image_path_list = sorted(\n",
"    os.path.join(image_dir, f)\n",
"    for f in os.listdir(image_dir)\n",
"    if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif', '.tif', '.tiff'))\n",
")\n",
"assert image_path_list, f\"no style images found in {image_dir}\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# control image: Canny edge map of the content photo\n",
"control_image_path = \"./assets/nero.jpg\"\n",
"input_image = cv2.imread(control_image_path)\n",
"if input_image is None:  # cv2.imread returns None instead of raising on a missing/unreadable file\n",
"    raise FileNotFoundError(control_image_path)\n",
"detected_map = cv2.Canny(input_image, 50, 200)\n",
"# Canny returns a single-channel map; GRAY2RGB is the matching conversion\n",
"# (COLOR_BGR2RGB is defined for 3/4-channel input)\n",
"canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_GRAY2RGB))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n",
"100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n",
"100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n",
"100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n",
"100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n",
"100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n"
]
}
],
"source": [
"# Generate one image per style reference: style from the reference image, content from\n",
"# the text prompt, structure from the Canny map. Saved as 'nero_' + the reference file name.\n",
"result_dir = './results'\n",
"os.makedirs(result_dir, exist_ok=True)\n",
"\n",
"for image_path in image_path_list:\n",
"    # os.path.basename is separator-aware; split('/') breaks on Windows paths from os.path.join\n",
"    image_name = os.path.basename(image_path)\n",
"    with Image.open(image_path) as img:\n",
"        # force 3-channel RGB so RGBA / palette / grayscale references are handled uniformly\n",
"        image = img.convert('RGB').resize((512, 512))\n",
"\n",
"    with torch.inference_mode():\n",
"        images = ip_model.generate(\n",
"            pil_image=image,\n",
"            prompt=\"a cat, masterpiece, best quality, high quality\",\n",
"            negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n",
"            scale=1.0,\n",
"            guidance_scale=5,\n",
"            num_samples=1,\n",
"            num_inference_steps=30,\n",
"            seed=42,\n",
"            image=canny_map,\n",
"            controlnet_conditioning_scale=0.6,\n",
"        )\n",
"    images[0].save(os.path.join(result_dir, 'nero_' + image_name))\n",
"\n",
"    # release the generated images and cached GPU blocks before the next reference\n",
"    del images\n",
"    torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "InstantStyle",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment