Last active
September 28, 2024 15:25
-
-
Save tanukon/a4c894ec72d952e93fde9bd2cd56da37 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/opt/conda/envs/InstantStyle/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
| " from .autonotebook import tqdm as notebook_tqdm\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import gc\n", | |
| "import os\n", | |
| "\n", | |
| "import cv2\n", | |
| "import torch\n", | |
| "from diffusers import StableDiffusionXLPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline\n", | |
| "from PIL import Image\n", | |
| "\n", | |
| "from ip_adapter import IPAdapterXL" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Basic InstantStyle usage: style transfer from reference images" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# the model path definition\n", | |
| "base_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\n", | |
| "image_encoder_path = \"sdxl_models/image_encoder\"\n", | |
| "ip_ckpt = \"sdxl_models/ip-adapter_sdxl.bin\"\n", | |
| "device = \"cuda\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Loading pipeline components...: 100%|██████████| 7/7 [00:52<00:00, 7.53s/it]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# load SDXL pipeline\n", | |
| "pipe = StableDiffusionXLPipeline.from_pretrained(\n", | |
| " base_model_path,\n", | |
| " torch_dtype=torch.float16,\n", | |
| " add_watermarker=False,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# reduce memory consumption: switch on diffusers' tiled VAE mode to lower peak memory use\n", | |
| "pipe.enable_vae_tiling()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/ubuntu/InstantStyle/ip_adapter/ip_adapter.py:149: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", | |
| " state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# load ip-adapter\n", | |
| "# target_blocks=[\"block\"] for original IP-Adapter\n", | |
| "# target_blocks=[\"up_blocks.0.attentions.1\"] for style blocks only\n", | |
| "# target_blocks = [\"up_blocks.0.attentions.1\", \"down_blocks.2.attentions.1\"] # for style+layout blocks\n", | |
| "target_blocks = [\"up_blocks.0.attentions.1\"]\n", | |
| "ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=target_blocks)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "image_dir = './styles'\n", | |
| "image_path_list = [os.path.join(image_dir, f) for f in os.listdir(image_dir)]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 30/30 [00:23<00:00, 1.28it/s]\n", | |
| "100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n", | |
| "100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n", | |
| "100%|██████████| 30/30 [00:23<00:00, 1.28it/s]\n", | |
| "100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n", | |
| "100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# generate image variations with only image prompt\n", | |
| "\n", | |
| "result_dir = './results'\n", | |
| "os.makedirs(result_dir, exist_ok=True)\n", | |
| "\n", | |
| "for image_path in image_path_list:\n", | |
| " image_name = image_path.split('/')[-1]\n", | |
| " image = Image.open(image_path)\n", | |
| " image = image.resize((512, 512))\n", | |
| " with torch.inference_mode():\n", | |
| " images = ip_model.generate(pil_image=image,\n", | |
| " prompt=\"a cat, masterpiece, best quality, high quality\",\n", | |
| " negative_prompt= \"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n", | |
| " scale=1.0,\n", | |
| " guidance_scale=5,\n", | |
| " num_samples=1,\n", | |
| " num_inference_steps=30, \n", | |
| " seed=42,\n", | |
| " #neg_content_prompt=\"a rabbit\",\n", | |
| " #neg_content_scale=0.5,\n", | |
| " )\n", | |
| " images[0].save(os.path.join(result_dir, image_name))\n", | |
| " \n", | |
| " del images\n", | |
| " torch.cuda.empty_cache()\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Better disentanglement: subtracting content features from the reference image" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def disentanglement(style_image_path: str, neg_prompt: str):\n", | |
| " image_name = style_image_path.split('/')[-1]\n", | |
| " image = Image.open(style_image_path)\n", | |
| " image = image.resize((512, 512))\n", | |
| "\n", | |
| " with torch.inference_mode():\n", | |
| " images = ip_model.generate(pil_image=image,\n", | |
| " prompt=\"a cat, masterpiece, best quality, high quality\",\n", | |
| " negative_prompt= \"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n", | |
| " scale=1.0,\n", | |
| " guidance_scale=5,\n", | |
| " num_samples=1,\n", | |
| " num_inference_steps=30, \n", | |
| " seed=42,\n", | |
| " neg_content_prompt=neg_prompt,\n", | |
| " neg_content_scale=0.5,\n", | |
| " )\n", | |
| " images[0].save(os.path.join(result_dir, 'de_'+image_name))\n", | |
| " \n", | |
| " del images\n", | |
| " torch.cuda.empty_cache()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "style_image_path = './styles/style3.jpg'\n", | |
| "result_dir = './results'\n", | |
| "os.makedirs(result_dir, exist_ok=True)\n", | |
| "\n", | |
| "disentanglement(style_image_path=style_image_path, neg_prompt=\"mushroom\")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 30/30 [00:08<00:00, 3.49it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "style_image_path = './styles/style6.jpg'\n", | |
| "result_dir = './results'\n", | |
| "os.makedirs(result_dir, exist_ok=True)\n", | |
| "\n", | |
| "disentanglement(style_image_path=style_image_path, neg_prompt=\"building\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Image-based style change: Canny ControlNet plus a style reference" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# the model path definition\n", | |
| "base_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\n", | |
| "image_encoder_path = \"sdxl_models/image_encoder\"\n", | |
| "ip_ckpt = \"sdxl_models/ip-adapter_sdxl.bin\"\n", | |
| "device = \"cuda\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "controlnet_path = \"diffusers/controlnet-canny-sdxl-1.0\"\n", | |
| "controlnet = ControlNetModel.from_pretrained(controlnet_path, use_safetensors=False, torch_dtype=torch.float16).to(device)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Loading pipeline components...: 100%|██████████| 7/7 [00:02<00:00, 3.24it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# load SDXL pipeline\n", | |
| "pipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n", | |
| " base_model_path,\n", | |
| " controlnet=controlnet,\n", | |
| " torch_dtype=torch.float16,\n", | |
| " add_watermarker=False,\n", | |
| ")\n", | |
| "pipe.enable_vae_tiling()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/ubuntu/InstantStyle/ip_adapter/ip_adapter.py:149: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", | |
| " state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device, target_blocks=[\"up_blocks.0.attentions.1\"])\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "image_dir = './styles'\n", | |
| "image_path_list = [os.path.join(image_dir, f) for f in os.listdir(image_dir)]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# control image\n", | |
| "input_image = cv2.imread(\"./assets/nero.jpg\")\n", | |
| "detected_map = cv2.Canny(input_image, 50, 200)\n", | |
| "canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n", | |
| "100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n", | |
| "100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n", | |
| "100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n", | |
| "100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n", | |
| "100%|██████████| 30/30 [00:42<00:00, 1.41s/it]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# generate image variations with only image prompt\n", | |
| "\n", | |
| "result_dir = './results'\n", | |
| "os.makedirs(result_dir, exist_ok=True)\n", | |
| "\n", | |
| "for image_path in image_path_list:\n", | |
| " image_name = image_path.split('/')[-1]\n", | |
| " image = Image.open(image_path)\n", | |
| " image = image.resize((512, 512))\n", | |
| " \n", | |
| " with torch.inference_mode():\n", | |
| " images = ip_model.generate(pil_image=image,\n", | |
| " prompt=\"a cat, masterpiece, best quality, high quality\",\n", | |
| " negative_prompt= \"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n", | |
| " scale=1.0,\n", | |
| " guidance_scale=5,\n", | |
| " num_samples=1,\n", | |
| " num_inference_steps=30, \n", | |
| " seed=42,\n", | |
| " image=canny_map,\n", | |
| " controlnet_conditioning_scale=0.6,\n", | |
| " )\n", | |
| " images[0].save(os.path.join(result_dir, 'nero_'+image_name))\n", | |
| " \n", | |
| " del images\n", | |
| " torch.cuda.empty_cache()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "InstantStyle", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.15" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment