exploring CLIP token embeddings as used for Stable Diffusion inputs
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "27582118-8452-4e16-bbdf-d6bdb7e665d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import diffusers, torch, transformers, tokenizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ea9f738-051b-41d1-b5a4-641554ba4ace",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_PATH = Path(os.environ['INVOKEAI_ROOT']) / 'models' / 'sdxl' / 'main' / 'stable-diffusion-xl-base-1-0'\n",
    "assert MODEL_PATH.exists() and MODEL_PATH.is_dir()\n",
    "\n",
    "device = torch.device(\"cuda:0\")\n",
    "torch.set_default_device(device)\n",
    "# enter inference mode for the rest of the notebook (deliberately never exited)\n",
    "cx = torch.inference_mode(True).__enter__()\n",
    "\n",
    "# SDXL's second, larger text encoder\n",
    "clip = transformers.CLIPTextModelWithProjection.from_pretrained(MODEL_PATH / 'text_encoder_2', use_safetensors=True)\n",
    "embeddings = clip.get_input_embeddings()\n",
    "positional_embeddings = clip.text_model.embeddings.position_embedding\n",
    "\n",
    "# tokenizers load no weights, so use_safetensors does not apply here\n",
    "tokenizer = transformers.CLIPTokenizerFast.from_pretrained(MODEL_PATH / 'tokenizer_2')"
   ]
  },
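  {
   "cell_type": "markdown",
   "id": "b1c2d3e4-0a1b-4c2d-8e3f-a1b2c3d4e5f6",
   "metadata": {},
   "source": [
    "*Added sanity check (not part of the original run):* both lookup tables are plain `nn.Embedding` layers. For SDXL's second text encoder (OpenCLIP ViT-bigG) we expect a 49408-entry vocabulary and 77 position slots, both with 1280-dimensional rows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2d3e4f5-1b2c-4d3e-9f4a-b2c3d4e5f6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# both lookup tables are ordinary nn.Embedding layers; print their shapes\n",
    "print(type(embeddings).__name__, tuple(embeddings.weight.shape))\n",
    "print(type(positional_embeddings).__name__, tuple(positional_embeddings.weight.shape))"
   ]
  },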
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9fea5e58-052a-4c75-bd17-b29200bbcb70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([49406, 320, 2713, 2870, 3086, 633, 518, 7991, 6913, 281,\n",
       " 1170, 970, 593, 8922, 11729, 269, 49407], device='cuda:0')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "prompt = \"a classic oil painting from the dutch masters: still life with banana sushi.\"\n",
    "token_ids = torch.tensor(tokenizer(prompt)['input_ids'])\n",
    "display(token_ids)"
   ]
  },
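  {
   "cell_type": "markdown",
   "id": "d3e4f5a6-2c3d-4e4f-8a5b-c3d4e5f6a7b8",
   "metadata": {},
   "source": [
    "*Added for illustration:* mapping the ids back to token strings shows the `<|startoftext|>`/`<|endoftext|>` markers (49406/49407) the tokenizer wraps around the prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4f5a6b7-3d4e-4f5a-9b6c-d4e5f6a7b8c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ids 49406 and 49407 are the start/end-of-text special tokens\n",
    "tokenizer.convert_ids_to_tokens(token_ids.tolist())"
   ]
  },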
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a3309e51-8688-4421-92d8-698a1f7f8382",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([17, 1280])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0030, 0.0038, 0.0004, ..., -0.0003, 0.0038, -0.0007],\n",
       " [-0.0020, -0.0234, 0.0159, ..., -0.0003, -0.0149, 0.0073],\n",
       " [-0.0136, -0.0032, -0.0109, ..., 0.0050, -0.0047, -0.0118],\n",
       " ...,\n",
       " [ 0.0173, 0.0183, -0.0243, ..., -0.0076, -0.0182, -0.0043],\n",
       " [ 0.0042, 0.0105, -0.0081, ..., -0.0052, 0.0157, -0.0053],\n",
       " [ 0.0024, 0.0029, 0.0049, ..., 0.0023, -0.0013, 0.0016]],\n",
       " device='cuda:0', grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "prompt_token_embeddings = embeddings(token_ids)\n",
    "display(prompt_token_embeddings.shape, prompt_token_embeddings)"
   ]
  },
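  {
   "cell_type": "markdown",
   "id": "f5a6b7c8-4e5f-4a6b-8c7d-e5f6a7b8c9d0",
   "metadata": {},
   "source": [
    "*Added check:* an embedding lookup is nothing more than row indexing into the weight matrix, so the forward call above should match fancy-indexing exactly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6b7c8d9-5f6a-4b7c-9d8e-f6a7b8c9d0e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# an nn.Embedding forward is just fancy-indexing into its weight matrix\n",
    "torch.equal(prompt_token_embeddings, embeddings.weight[token_ids])"
   ]
  },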
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6e67c6e6-7af4-45b0-b467-11dbe96a3734",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n",
       " device='cuda:0')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 17, 1280])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[-2.9507e-03, 3.7556e-03, 4.4179e-04, ..., -2.7585e-04,\n",
       " 3.7937e-03, -7.4005e-04],\n",
       " [-4.5052e-03, 4.9324e-03, -2.1343e-03, ..., -3.2425e-05,\n",
       " 2.7370e-03, -2.0905e-03],\n",
       " [-1.7939e-03, -5.5611e-05, -2.3365e-03, ..., 1.3590e-04,\n",
       " 2.1152e-03, -7.2002e-04],\n",
       " ...,\n",
       " [-6.4888e-03, 1.2619e-02, 1.1116e-02, ..., 1.4324e-03,\n",
       " 1.5526e-03, -4.5509e-03],\n",
       " [-2.0618e-03, 8.5602e-03, 9.1782e-03, ..., 4.4751e-04,\n",
       " -2.1400e-03, -5.3635e-03],\n",
       " [-5.2986e-03, 5.6686e-03, 6.6223e-03, ..., 1.4315e-03,\n",
       " 1.6890e-03, -4.6959e-03]]], device='cuda:0',\n",
       " grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "position_ids = clip.text_model.embeddings.position_ids[:, :token_ids.shape[-1]]\n",
    "display(position_ids)\n",
    "prompt_position_embeddings = positional_embeddings(position_ids)\n",
    "display(prompt_position_embeddings.shape, prompt_position_embeddings)"
   ]
  },
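  {
   "cell_type": "markdown",
   "id": "b7c8d9e0-6a7b-4c8d-8e9f-a7b8c9d0e1f2",
   "metadata": {},
   "source": [
    "*Added check:* the `position_ids` buffer is nothing special, just a precomputed `arange` over the model's position slots, truncated here to the prompt length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d9e0f1-7b8c-4d9e-9fa0-b8c9d0e1f2a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the truncated position_ids buffer is just 0..16 with a batch dimension\n",
    "torch.equal(position_ids, torch.arange(token_ids.shape[-1]).unsqueeze(0))"
   ]
  },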
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1553e9e6-f706-444f-b887-6c8f25eb0492",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 17, 1280])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.0059, 0.0075, 0.0009, ..., -0.0006, 0.0076, -0.0015],\n",
       " [-0.0065, -0.0185, 0.0137, ..., -0.0003, -0.0121, 0.0053],\n",
       " [-0.0154, -0.0032, -0.0133, ..., 0.0051, -0.0026, -0.0125],\n",
       " ...,\n",
       " [ 0.0109, 0.0309, -0.0131, ..., -0.0062, -0.0166, -0.0089],\n",
       " [ 0.0021, 0.0191, 0.0011, ..., -0.0048, 0.0136, -0.0107],\n",
       " [-0.0029, 0.0086, 0.0116, ..., 0.0037, 0.0004, -0.0031]]],\n",
       " device='cuda:0', grad_fn=<AddBackward0>)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Use CLIPTextEmbeddings' complete forward method, which both looks up the token embeddings and adds the positional embeddings to them.\n",
    "prompt_combined_embeddings = clip.text_model.embeddings(token_ids)\n",
    "display(prompt_combined_embeddings.shape, prompt_combined_embeddings)"
   ]
  },
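  {
   "cell_type": "markdown",
   "id": "d9e0f1a2-8c9d-4ea0-8fb1-c9d0e1f2a3b4",
   "metadata": {},
   "source": [
    "*Added check:* the combined forward pass should equal the sum of the two pieces we computed by hand (the positional term broadcasts over the batch dimension)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f1a2b3-9d0e-4fb1-9ac2-d0e1f2a3b4c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the forward pass above is just token embeddings + position embeddings\n",
    "torch.allclose(prompt_combined_embeddings, prompt_token_embeddings + prompt_position_embeddings)"
   ]
  },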
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4acdee70-e769-459d-96fe-cdb43399310a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reversing that addition in reduced precision to see how much of the positional information is intact\n",
    "bf16_positions = prompt_combined_embeddings.to(dtype=torch.bfloat16) - prompt_token_embeddings.to(dtype=torch.bfloat16)\n",
    "fp16_positions = prompt_combined_embeddings.to(dtype=torch.float16) - prompt_token_embeddings.to(dtype=torch.float16)"
   ]
  },
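  {
   "cell_type": "markdown",
   "id": "f1a2b3c4-0e1f-4ac2-8bd3-e1f2a3b4c5d6",
   "metadata": {},
   "source": [
    "*Added for reference:* the same subtraction in full precision should recover the positional embeddings essentially exactly, so the reduced-precision versions above isolate the rounding loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2b3c4d5-1f2a-4bd3-9ce4-f2a3b4c5d6e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# in fp32 the subtraction recovers the positional embeddings (near-)exactly\n",
    "fp32_positions = prompt_combined_embeddings - prompt_token_embeddings\n",
    "(fp32_positions - prompt_position_embeddings).abs().max()"
   ]
  },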
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ae908245-1dd8-43e8-b297-e7632c0cdd87",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0291, 2.8052, 3.0012, 2.8337, 2.8358, 2.9264, 2.6316, 2.7296,\n",
       " 2.6183, 2.4568, 2.6012, 2.5254, 2.3077, 2.2098, 1.9971, 1.9277,\n",
       " 0.5930]], device='cuda:0', grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# trying to get an idea of how many bits the position info is shifted from the token info\n",
    "prompt_token_embeddings.abs().mean(-1).log2() - bf16_positions.abs().mean(-1).log2()"
   ]
  },
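  {
   "cell_type": "markdown",
   "id": "b3c4d5e6-2a3b-4ce4-8df5-a3b4c5d6e7f8",
   "metadata": {},
   "source": [
    "*Added interpretation:* at most positions the positional component sits roughly 2–3 bits below the token component. bfloat16 keeps only 8 significand bits, so rounding the combined embedding leaves an error on the order of $2^{-8}$ relative to its dominant (token) component; a positional component $k$ bits smaller then sees a relative error on the order of $2^{k-8}$, i.e. a few percent, which is the scale of the mantissa errors measured next."
   ]
  },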
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6de74ec4-539d-4432-95e6-8f306ca83641",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0041, 0.0696, nan, 0.0773, 0.0743, 0.0745, 0.0678, 0.0708, 0.0742,\n",
       " 0.0604, 0.0655, 0.0639, 0.0599, 0.0488, 0.0423, 0.0431, 0.0181]],\n",
       " device='cuda:0', grad_fn=<MeanBackward1>)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "position_mantissas = prompt_position_embeddings.frexp().mantissa\n",
    "bf16_position_mantissas = bf16_positions.to(dtype=torch.float32).frexp().mantissa # bf16 does not implement frexp itself\n",
    "mantissa_error = (bf16_position_mantissas - position_mantissas) / position_mantissas\n",
    "display(mantissa_error.abs().mean(-1))"
   ]
  },
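  {
   "cell_type": "markdown",
   "id": "c4d5e6f7-3b4c-4df5-9ea6-b4c5d6e7f8a9",
   "metadata": {},
   "source": [
    "*Added note:* the `nan` at position 2 is likely a divide-by-zero artifact: if any component of that positional embedding is exactly zero, its `frexp` mantissa is zero, the relative-error division produces `nan`, and the `nan` propagates through the mean. The same `nan` appears in the float16 measurement below."
   ]
  },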
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8680e9dd-5bc0-4c2a-8c55-9b24bf624db1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[8.7450e-06, 1.1561e-02, nan, 1.1818e-02, 1.1582e-02, 1.1504e-02,\n",
       " 9.6999e-03, 1.2325e-02, 1.4067e-02, 1.2303e-02, 9.1884e-03, 1.1376e-02,\n",
       " 9.2605e-03, 8.0006e-03, 9.0526e-03, 6.5079e-03, 2.7143e-03]],\n",
       " device='cuda:0', grad_fn=<MeanBackward1>)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fp16_position_mantissas = fp16_positions.frexp().mantissa\n",
    "fp16_mantissa_error = (fp16_position_mantissas - position_mantissas) / position_mantissas\n",
    "display(fp16_mantissa_error.abs().mean(-1))"
   ]
  },
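  {
   "cell_type": "markdown",
   "id": "d5e6f7a8-4c5d-4ea6-8fb7-c5d6e7f8a9b0",
   "metadata": {},
   "source": [
    "*Added comparison:* float16 carries 11 significand bits to bfloat16's 8, so its recovered positional information should be roughly 8× more accurate; the measured ~1% mantissa error versus ~4–8% for bfloat16 is consistent with that, the trade-off being float16's much narrower exponent range."
   ]
  },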
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c29a0cce-ce2d-4041-9db7-2abd88184519",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "v311_invoke",
   "language": "python",
   "name": "v311_invoke"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}