| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "aquagraphist@gmail.com", | |
| "provenance": [], | |
| "collapsed_sections": [] | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "vngKopoMm6YF" | |
| }, | |
| "source": [ | |
| "# Run on Google Colab using a GPU" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "ZRpyWAFbnWlA" | |
| }, | |
| "source": [ | |
| "## Clone the repo" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "hcACfYD7nXYy" | |
| }, | |
| "source": [ | |
| "!rm -rf sample_data/\n", | |
| "!mkdir out" | |
| ], | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "ntvf2-VJo0wR" | |
| }, | |
| "source": [ | |
| "!mkdir input\n", | |
| "!mkdir output" | |
| ], | |
| "execution_count": 175, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "l2VdQYme11i_" | |
| }, | |
| "source": [ | |
| "!git clone https://github.com/facebookresearch/dino.git" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "orPqkK37nkZ5" | |
| }, | |
| "source": [ | |
| "Download a model, here I used deit small 8 pretrained" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "kfgejeGwniAZ" | |
| }, | |
| "source": [ | |
| "!wget https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth -O dino/dino_deitsmall8_pretrain.pth" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "uu92SGGC1ezm" | |
| }, | |
| "source": [ | |
| "## Look for a video to use and download it\n", | |
| "\n", | |
| "I'm using this one for example\n", | |
| "https://www.pexels.com/fr-fr/video/chien-course-exterieur-journee-ensoleillee-4166347/\n", | |
| "\n", | |
| "\n", | |
| "Then you need to extract frames from the video, you can use ffmpeg.\n", | |
| "\n", | |
| "Video is 60 fps and ~6 sec so you'll get ~360 jpg images\n", | |
| "\n", | |
| "%03d is from 001 to 999" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "LWk0Ymx5n8wN" | |
| }, | |
| "source": [ | |
| "!ffmpeg -i video.mp4 input/img-%03d.jpg" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "J49nJpl5pP05" | |
| }, | |
| "source": [ | |
| "%cd dino/" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "bgU_sdRJpRHQ" | |
| }, | |
| "source": [ | |
| "## Code" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "3nwQ4Ct4pWCr" | |
| }, | |
| "source": [ | |
| "Requirements:\n", | |
| "\n", | |
| "\n", | |
| "* Opencv\n", | |
| "* scikit-image\n", | |
| "* maptlotlib\n", | |
| "* pytorch\n", | |
| "* numpy\n", | |
| "* Pillow\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "3-YQ-kbN30k1" | |
| }, | |
| "source": [ | |
| "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n", | |
| "import os\n", | |
| "import gc\n", | |
| "import sys\n", | |
| "import argparse\n", | |
| "import cv2\n", | |
| "import random\n", | |
| "import colorsys\n", | |
| "import requests\n", | |
| "from io import BytesIO\n", | |
| "\n", | |
| "import skimage.io\n", | |
| "from skimage.measure import find_contours\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "from matplotlib.patches import Polygon\n", | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torchvision\n", | |
| "from torchvision import transforms as pth_transforms\n", | |
| "import numpy as np\n", | |
| "from PIL import Image\n", | |
| "\n", | |
| "import utils\n", | |
| "import vision_transformer as vits" | |
| ], | |
| "execution_count": 57, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "_Iru4HEMrsmk" | |
| }, | |
| "source": [ | |
| "You may need to resize each tensor otherwise you'll get an OOM error\n", | |
| "\n", | |
| "Line 9: `pth_transforms.Resize(512),`\n", | |
| "\n", | |
| "\n", | |
| "Also, the color of video from blogpost is obtained by using cmap=\"inferno\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "eaJJqEOB-LH0" | |
| }, | |
| "source": [ | |
| "def predict_video(args):\n", | |
| " for frame in sorted(os.listdir(args.image_path)):\n", | |
| " with open(os.path.join(args.image_path, frame), 'rb') as f:\n", | |
| " img = Image.open(f)\n", | |
| " img = img.convert('RGB')\n", | |
| "\n", | |
| " transform = pth_transforms.Compose([\n", | |
| " pth_transforms.ToTensor(),\n", | |
| " pth_transforms.Resize(512),\n", | |
| " pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", | |
| " ])\n", | |
| " img = transform(img)\n", | |
| "\n", | |
| " # make the image divisible by the patch size\n", | |
| " w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size\n", | |
| " img = img[:, :w, :h].unsqueeze(0)\n", | |
| "\n", | |
| " w_featmap = img.shape[-2] // args.patch_size\n", | |
| " h_featmap = img.shape[-1] // args.patch_size\n", | |
| "\n", | |
| " attentions = model.forward_selfattention(img.cuda())\n", | |
| "\n", | |
| " nh = attentions.shape[1] # number of head\n", | |
| "\n", | |
| " # we keep only the output patch attention\n", | |
| " attentions = attentions[0, :, 0, 1:].reshape(nh, -1)\n", | |
| "\n", | |
| " # we keep only a certain percentage of the mass\n", | |
| " val, idx = torch.sort(attentions)\n", | |
| " val /= torch.sum(val, dim=1, keepdim=True)\n", | |
| " cumval = torch.cumsum(val, dim=1)\n", | |
| " th_attn = cumval > (1 - args.threshold)\n", | |
| " idx2 = torch.argsort(idx)\n", | |
| " for head in range(nh):\n", | |
| " th_attn[head] = th_attn[head][idx2[head]]\n", | |
| " th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()\n", | |
| " # interpolate\n", | |
| " th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode=\"nearest\")[0].cpu().numpy()\n", | |
| "\n", | |
| " attentions = attentions.reshape(nh, w_featmap, h_featmap)\n", | |
| " attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode=\"nearest\")[0].cpu().numpy()\n", | |
| "\n", | |
| " # save attentions heatmaps\n", | |
| " os.makedirs(args.output_dir, exist_ok=True)\n", | |
| "\n", | |
| " # Saving only last attention layer\n", | |
| " fname = os.path.join(args.output_dir, \"attn-\" + frame)\n", | |
| " plt.imsave(\n", | |
| " fname=fname,\n", | |
| " arr=sum(attentions[i] * 1/attentions.shape[0] for i in range(attentions.shape[0])),\n", | |
| " cmap=\"inferno\",\n", | |
| " format=\"jpg\"\n", | |
| " )\n", | |
| " print(f\"{fname} saved.\")" | |
| ], | |
| "execution_count": 176, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "3T1S_2H5O_XX" | |
| }, | |
| "source": [ | |
| "#@title Args\n", | |
| "\n", | |
| "pretrained_weights_path = \"dino_deitsmall8_pretrain.pth\" #@param {type:\"string\"}\n", | |
| "arch = 'deit_small' #@param [\"deit_small\", \"deit_tiny\", \"vit_base\"]\n", | |
| "input_path = \"../input/\" #@param {type:\"string\"}\n", | |
| "output_path = \"../output/\" #@param {type:\"string\"}\n", | |
| "threshold = 0.6 #@param {type:\"number\"}\n", | |
| "\n", | |
| "\n", | |
| "parser = argparse.ArgumentParser('Visualize Self-Attention maps')\n", | |
| "parser.add_argument('--arch', default='deit_small', type=str,\n", | |
| " choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')\n", | |
| "parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')\n", | |
| "parser.add_argument('--pretrained_weights', default='', type=str,\n", | |
| " help=\"Path to pretrained weights to load.\")\n", | |
| "parser.add_argument(\"--checkpoint_key\", default=\"teacher\", type=str,\n", | |
| " help='Key to use in the checkpoint (example: \"teacher\")')\n", | |
| "parser.add_argument(\"--image_path\", default=None, type=str, help=\"Path of the image to load.\")\n", | |
| "parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')\n", | |
| "parser.add_argument(\"--threshold\", type=float, default=0.6, help=\"\"\"We visualize masks\n", | |
| " obtained by thresholding the self-attention maps to keep xx% of the mass.\"\"\")\n", | |
| "\n", | |
| "args = parser.parse_args(args=[])\n", | |
| "\n", | |
| "args.arch = arch\n", | |
| "args.pretrained_weights = pretrained_weights_path\n", | |
| "args.image_path = \"../video/\"\n", | |
| "args.output_dir = \"../out/\"\n", | |
| "args.threshold = threshold" | |
| ], | |
| "execution_count": 164, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "j2eEQ8YQ69aG", | |
| "outputId": "1e94923c-b8d7-4a5b-fa75-ce3cdf729823" | |
| }, | |
| "source": [ | |
| "model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)\n", | |
| "for p in model.parameters():\n", | |
| " p.requires_grad = False\n", | |
| "model.eval()\n", | |
| "model.cuda()\n", | |
| "if os.path.isfile(args.pretrained_weights):\n", | |
| " state_dict = torch.load(args.pretrained_weights, map_location=\"cpu\")\n", | |
| " if args.checkpoint_key is not None and args.checkpoint_key in state_dict:\n", | |
| " print(f\"Take key {args.checkpoint_key} in provided checkpoint dict\")\n", | |
| " state_dict = state_dict[args.checkpoint_key]\n", | |
| " state_dict = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n", | |
| " msg = model.load_state_dict(state_dict, strict=False)\n", | |
| " print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))\n", | |
| "else:\n", | |
| " print(\"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.\")\n", | |
| " url = None\n", | |
| " if args.arch == \"deit_small\" and args.patch_size == 16:\n", | |
| " url = \"dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth\"\n", | |
| " elif args.arch == \"deit_small\" and args.patch_size == 8:\n", | |
| " url = \"dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth\" # model used for visualizations in our paper\n", | |
| " elif args.arch == \"vit_base\" and args.patch_size == 16:\n", | |
| " url = \"dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\"\n", | |
| " elif args.arch == \"vit_base\" and args.patch_size == 8:\n", | |
| " url = \"dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth\"\n", | |
| " if url is not None:\n", | |
| " print(\"Since no pretrained weights have been provided, we load the reference pretrained DINO weights.\")\n", | |
| " state_dict = torch.hub.load_state_dict_from_url(url=\"https://dl.fbaipublicfiles.com/dino/\" + url)\n", | |
| " model.load_state_dict(state_dict, strict=True)\n", | |
| " else:\n", | |
| " print(\"There is no reference weights available for this model => We use random weights.\")\n" | |
| ], | |
| "execution_count": 165, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Pretrained weights found at dino_deitsmall8_pretrain.pth and loaded with msg: <All keys matched successfully>\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "aRk_XpBVBlIn", | |
| "outputId": "ed864ea9-788f-4b6b-b3f8-768291b5c3ba" | |
| }, | |
| "source": [ | |
| "torch.cuda.empty_cache()\n", | |
| "gc.collect()" | |
| ], | |
| "execution_count": 166, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "26740" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 166 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "z-z9LXVWsak_" | |
| }, | |
| "source": [ | |
| "## Run inference\n", | |
| "\n", | |
| "\n", | |
| "Resize if OOM" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "V6SgQq0s84ZX" | |
| }, | |
| "source": [ | |
| "predict_video(args)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "SXOZzCiTsmGl" | |
| }, | |
| "source": [ | |
| "## Output images to video\n", | |
| "\n", | |
| "Input video is 60 fps, change if yours is different" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "vIhTNIEnkEfR" | |
| }, | |
| "source": [ | |
| "!ffmpeg -framerate 60 -i ../output/attn-image-%03d.jpg ../output.mp4" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "4-ZTjxJBtCn_" | |
| }, | |
| "source": [ | |
| "If you want both input and output videos side by side:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "U-kSzhZZtFow" | |
| }, | |
| "source": [ | |
| "!ffmpeg -i ../video.mp4 -i ../output.mp4 -filter_complex '[0:v]pad=iw*2:ih[int];[int][1:v]overlay=W/2:0[vid]' -map '[vid]' -c:v libx264 -crf 23 -preset veryfast final.mp4" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
[FIX] Error: 'VisionTransformer' object has no attribute 'forward_selfattention'
Change the following code in line 21 of def predict_video(args):
attentions = model.forward_selfattention(img.cuda()) -> attentions = model.get_last_selfattention(img.cuda())
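For context, a minimal sketch of the corrected call inside predict_video (the surrounding lines are from the notebook; wrapping the forward pass in torch.no_grad() is my addition, not in the original):

    # Inside predict_video(args), once the frame tensor `img` is ready:
    with torch.no_grad():  # gradients aren't needed for visualization (my addition)
        attentions = model.get_last_selfattention(img.cuda())  # was: model.forward_selfattention(...)

    nh = attentions.shape[1]  # number of attention heads
    # keep only the CLS token's attention over the image patches
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)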
Good evening, could you help me with this error? I don't know how to import the "vision_transformer" module. I'd really appreciate it.
Did you resolve this? If not, be sure to run all of the cells in order and change into the DINO directory first with "%cd dino/". It also helps to import it as "vits" (import vision_transformer as vits), since that is how it is referenced in the notebook.
Yes, thank you very much Felipe; as you said, I had missed running the "%cd dino/" cell. I finally got to the end of the notebook, but I'm stuck again on a new warning that doesn't let me continue. Please help me with this if you can.

[FIX] File not found '../output/attn-image-%03d.jpg'
Actually there is a typo in there, it should be:
!ffmpeg -framerate 60 -i ../out/attn-img-%03d.jpg ../output.mp4
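If in doubt about the exact pattern, you can list the output directory first (a quick check assuming the notebook's default args.output_dir of "../out/"; adjust the path if you changed it):

    import os
    print(sorted(os.listdir("../out"))[:3])  # expect e.g. ['attn-img-001.jpg', 'attn-img-002.jpg', ...]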
Has anyone tried using the other models? I get several size mismatch errors
I have not tried it with another network, but I will try and let you know.
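A likely cause of the size mismatch (my reading of the notebook's weight-loading cell, not confirmed in this thread): args.arch, args.patch_size, and the checkpoint file must all agree. A sketch of consistent combinations, using the checkpoint names from the DINO release:

    # (arch, patch_size) -> matching pretrained checkpoint
    combos = {
        ("deit_small", 16): "dino_deitsmall16_pretrain.pth",
        ("deit_small", 8):  "dino_deitsmall8_pretrain.pth",
        ("vit_base", 16):   "dino_vitbase16_pretrain.pth",
        ("vit_base", 8):    "dino_vitbase8_pretrain.pth",
    }
    # For example, to try ViT-B/16, set all three together before rebuilding the model:
    args.arch = "vit_base"
    args.patch_size = 16
    args.pretrained_weights = "dino_vitbase16_pretrain.pth"  # download this checkpoint first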
Can anyone help me out here? I've tried changing args.input_path to "../input/" and args.output_dir to "../output/" as suggested above but keep running into the following error:
predict_video(args)

IsADirectoryError                         Traceback (most recent call last)
in ()
----> 1 predict_video(args)

in predict_video(args)
      1 def predict_video(args):
      2     for frame in sorted(os.listdir(args.image_path)):
----> 3         with open(os.path.join(args.image_path, frame), 'rb') as f:
      4             img = Image.open(f)
      5             img = img.convert('RGB')

IsADirectoryError: [Errno 21] Is a directory: '../input/.ipynb_checkpoints'
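The error comes from os.listdir picking up the hidden .ipynb_checkpoints directory that Jupyter/Colab creates inside input/. A minimal sketch of one fix, assuming you only change the top of the loop in predict_video:

    for frame in sorted(os.listdir(args.image_path)):
        path = os.path.join(args.image_path, frame)
        if frame.startswith('.') or not os.path.isfile(path):
            continue  # skip .ipynb_checkpoints and anything that isn't a regular file
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
        # ... rest of the loop body unchanged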

Been getting a few errors while trying to run this notebook.


I think args.image_path and args.output_dir should be changed to these values (the input_path and output_path set in the Args form), otherwise an error is thrown during inference. I've also been getting another error while running inference: predict_video(args) throws the following error for me. It would be great to get some insight on how to correct this.
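For reference, a sketch of the corrected overrides at the bottom of the Args cell (assuming the frames were extracted into input/ at the repo root, as in the ffmpeg step; the original cell hard-codes "../video/" and "../out/" instead of the form values):

    args.arch = arch
    args.pretrained_weights = pretrained_weights_path
    args.image_path = input_path     # "../input/"  -> where ffmpeg wrote the frames
    args.output_dir = output_path    # "../output/" -> where the attention maps are saved
    args.threshold = threshold

If you use output_path here, the encoding step at the end should read ../output/attn-img-%03d.jpg rather than ../out/.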