Created
June 8, 2025 22:12
-
-
Save fepegar/15ae4b68c0da68ae97aba9a2ddc0b412 to your computer and use it in GitHub Desktop.
run-capi.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyP72aZCjnR5fkDHhkgn26Gc", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/fepegar/15ae4b68c0da68ae97aba9a2ddc0b412/run-capi.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Install into the running kernel's environment (%pip rather than !pip).\n", | |
| "# torch and rich are used by later cells; the remaining packages are presumably\n", | |
| "# required by the CAPI hub code -- TODO confirm, and pin versions for reproducibility.\n", | |
| "%pip install --quiet einops jaxtyping numpy omegaconf rich scikit-learn torch torchvision" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "G3vj01I9SjRj", | |
| "outputId": "81fbc920-6a13-4b6a-91ed-e4e2e3617e5a" | |
| }, | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.4/55.4 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m97.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m73.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m46.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m59.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25h" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import torch\n", | |
| "\n", | |
| "capi_vitl14_p205 = torch.hub.load('facebookresearch/capi:main', 'capi_vitl14_p205')" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "6Q2r8CuKSxd_", | |
| "outputId": "aca470ea-8785-4471-ac0c-69adae07f99c" | |
| }, | |
| "execution_count": 3, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "Downloading: \"https://github.com/facebookresearch/capi/zipball/main\" to /root/.cache/torch/hub/main.zip\n", | |
| "WARNING:utils:Can't import sklearnex. If installed, that speeds up scikit-learn 10-100x\n", | |
| "Downloading: \"https://dl.fbaipublicfiles.com/capi/capi_vitl14_p205.pth\" to /root/.cache/torch/hub/checkpoints/capi_vitl14_p205.pth\n", | |
| "100%|██████████| 900M/900M [00:47<00:00, 20.0MB/s]\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "img = torch.zeros(1, 3, 224, 224)\n", | |
| "global_repr, registers, feature_map = capi_vitl14_p205(img)\n", | |
| "global_repr.shape, registers.shape, feature_map.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "OETucVAkTovr", | |
| "outputId": "b201ea10-1ec9-4fb2-b766-9823a6dae3ec" | |
| }, | |
| "execution_count": 4, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([1, 1024]),\n", | |
| " torch.Size([1, 16, 1024]),\n", | |
| " torch.Size([1, 16, 16, 1024]))" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 4 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Load rich's IPython extension so that object reprs displayed by later cells\n", | |
| "# (e.g. the model below) are pretty-printed with styling.\n", | |
| "%load_ext rich" | |
| ], | |
| "metadata": { | |
| "id": "sTZh4MX4mam9" | |
| }, | |
| "execution_count": 6, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Display the module tree: an EncoderDecoder with encoder/decoder Transformer\n", | |
| "# blocks, a Conv2d patch embedding and RMSNorm output norms (see output below).\n", | |
| "capi_vitl14_p205" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 981 | |
| }, | |
| "id": "amKIooVSmbtH", | |
| "outputId": "2272eb5a-433a-4974-dae1-bedb51a7b401" | |
| }, | |
| "execution_count": 7, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "\n", | |
| "\u001b[1;35mEncoderDecoder\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mencoder\u001b[1m)\u001b[0m: \u001b[1;35mTransformer\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mblocks\u001b[1m)\u001b[0m: \u001b[1;35mModuleList\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m-\u001b[1;36m23\u001b[0m\u001b[1m)\u001b[0m: \u001b[1;36m24\u001b[0m x \u001b[1;35mBlock\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mresidual1\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mAttention\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mq_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mk_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mv_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mproj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mrope\u001b[1m)\u001b[0m: \u001b[1;35mRope\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mresidual2\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mMlp\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mfc1\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mact\u001b[1m)\u001b[0m: \u001b[1;35mGELU\u001b[0m\u001b[1m(\u001b[0m\u001b[33mapproximate\u001b[0m=\u001b[32m'none'\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mfc2\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mdecoder\u001b[1m)\u001b[0m: \u001b[1;35mTransformer\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mblocks\u001b[1m)\u001b[0m: \u001b[1;35mModuleList\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m-\u001b[1;36m11\u001b[0m\u001b[1m)\u001b[0m: \u001b[1;36m12\u001b[0m x \u001b[1;35mBlock\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mresidual1\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mAttention\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mq_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mk_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mv_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mproj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mrope\u001b[1m)\u001b[0m: \u001b[1;35mRope\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mresidual2\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mMlp\u001b[0m\u001b[1m(\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mfc1\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mact\u001b[1m)\u001b[0m: \u001b[1;35mGELU\u001b[0m\u001b[1m(\u001b[0m\u001b[33mapproximate\u001b[0m=\u001b[32m'none'\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mfc2\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mpatch_embed\u001b[1m)\u001b[0m: \u001b[1;35mConv2d\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m1024\u001b[0m, \u001b[33mkernel_size\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m14\u001b[0m, \u001b[1;36m14\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m14\u001b[0m, \u001b[1;36m14\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0menc_norm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| " \u001b[1m(\u001b[0mdec_norm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
| "\u001b[1m)\u001b[0m" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 7 | |
| } | |
| ] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment