@aquadzn
Last active November 25, 2023 20:47

DINO Video inference
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "aquagraphist@gmail.com",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vngKopoMm6YF"
},
"source": [
"# Run on Google Colab using a GPU"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZRpyWAFbnWlA"
},
"source": [
"## Clone the repo"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hcACfYD7nXYy"
},
"source": [
"!rm -rf sample_data/\n",
"!mkdir out"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ntvf2-VJo0wR"
},
"source": [
"!mkdir input\n",
"!mkdir output"
],
"execution_count": 175,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "l2VdQYme11i_"
},
"source": [
"!git clone https://github.com/facebookresearch/dino.git"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "orPqkK37nkZ5"
},
"source": [
"Download a model, here I used deit small 8 pretrained"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kfgejeGwniAZ"
},
"source": [
"!wget https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth -O dino/dino_deitsmall8_pretrain.pth"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "uu92SGGC1ezm"
},
"source": [
"## Look for a video to use and download it\n",
"\n",
"I'm using this one for example\n",
"https://www.pexels.com/fr-fr/video/chien-course-exterieur-journee-ensoleillee-4166347/\n",
"\n",
"\n",
"Then you need to extract frames from the video, you can use ffmpeg.\n",
"\n",
"Video is 60 fps and ~6 sec so you'll get ~360 jpg images\n",
"\n",
"%03d is from 001 to 999"
]
},
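{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to fetch the video from inside the notebook, something like the cell below should work; the URL is only a placeholder, so replace it with the actual file link from the Pexels page above."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Placeholder URL: substitute the direct download link of your video\n",
"!wget -O video.mp4 \"https://example.com/your-video.mp4\""
],
"execution_count": null,
"outputs": []
},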
{
"cell_type": "code",
"metadata": {
"id": "LWk0Ymx5n8wN"
},
"source": [
"!ffmpeg -i video.mp4 input/img-%03d.jpg"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "J49nJpl5pP05"
},
"source": [
"%cd dino/"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "bgU_sdRJpRHQ"
},
"source": [
"## Code"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3nwQ4Ct4pWCr"
},
"source": [
"Requirements:\n",
"\n",
"\n",
"* Opencv\n",
"* scikit-image\n",
"* maptlotlib\n",
"* pytorch\n",
"* numpy\n",
"* Pillow\n",
"\n"
]
},
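{
"cell_type": "markdown",
"metadata": {},
"source": [
"Colab already ships with all of these; on a fresh environment they can be installed roughly like this:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Install the dependencies listed above (already preinstalled on Colab)\n",
"!pip install opencv-python scikit-image matplotlib torch torchvision numpy Pillow"
],
"execution_count": null,
"outputs": []
},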
{
"cell_type": "code",
"metadata": {
"id": "3-YQ-kbN30k1"
},
"source": [
"# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n",
"import os\n",
"import gc\n",
"import sys\n",
"import argparse\n",
"import cv2\n",
"import random\n",
"import colorsys\n",
"import requests\n",
"from io import BytesIO\n",
"\n",
"import skimage.io\n",
"from skimage.measure import find_contours\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Polygon\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision\n",
"from torchvision import transforms as pth_transforms\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"import utils\n",
"import vision_transformer as vits"
],
"execution_count": 57,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Iru4HEMrsmk"
},
"source": [
"You may need to resize each tensor otherwise you'll get an OOM error\n",
"\n",
"Line 9: `pth_transforms.Resize(512),`\n",
"\n",
"\n",
"Also, the color of video from blogpost is obtained by using cmap=\"inferno\""
]
},
{
"cell_type": "code",
"metadata": {
"id": "eaJJqEOB-LH0"
},
"source": [
"def predict_video(args):\n",
" for frame in sorted(os.listdir(args.image_path)):\n",
" with open(os.path.join(args.image_path, frame), 'rb') as f:\n",
" img = Image.open(f)\n",
" img = img.convert('RGB')\n",
"\n",
" transform = pth_transforms.Compose([\n",
" pth_transforms.ToTensor(),\n",
" pth_transforms.Resize(512),\n",
" pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
" ])\n",
" img = transform(img)\n",
"\n",
" # make the image divisible by the patch size\n",
" w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size\n",
" img = img[:, :w, :h].unsqueeze(0)\n",
"\n",
" w_featmap = img.shape[-2] // args.patch_size\n",
" h_featmap = img.shape[-1] // args.patch_size\n",
"\n",
" attentions = model.forward_selfattention(img.cuda())\n",
"\n",
" nh = attentions.shape[1] # number of head\n",
"\n",
" # we keep only the output patch attention\n",
" attentions = attentions[0, :, 0, 1:].reshape(nh, -1)\n",
"\n",
" # we keep only a certain percentage of the mass\n",
" val, idx = torch.sort(attentions)\n",
" val /= torch.sum(val, dim=1, keepdim=True)\n",
" cumval = torch.cumsum(val, dim=1)\n",
" th_attn = cumval > (1 - args.threshold)\n",
" idx2 = torch.argsort(idx)\n",
" for head in range(nh):\n",
" th_attn[head] = th_attn[head][idx2[head]]\n",
" th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()\n",
" # interpolate\n",
" th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode=\"nearest\")[0].cpu().numpy()\n",
"\n",
" attentions = attentions.reshape(nh, w_featmap, h_featmap)\n",
" attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode=\"nearest\")[0].cpu().numpy()\n",
"\n",
" # save attentions heatmaps\n",
" os.makedirs(args.output_dir, exist_ok=True)\n",
"\n",
" # Saving only last attention layer\n",
" fname = os.path.join(args.output_dir, \"attn-\" + frame)\n",
" plt.imsave(\n",
" fname=fname,\n",
" arr=sum(attentions[i] * 1/attentions.shape[0] for i in range(attentions.shape[0])),\n",
" cmap=\"inferno\",\n",
" format=\"jpg\"\n",
" )\n",
" print(f\"{fname} saved.\")"
],
"execution_count": 176,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3T1S_2H5O_XX"
},
"source": [
"#@title Args\n",
"\n",
"pretrained_weights_path = \"dino_deitsmall8_pretrain.pth\" #@param {type:\"string\"}\n",
"arch = 'deit_small' #@param [\"deit_small\", \"deit_tiny\", \"vit_base\"]\n",
"input_path = \"../input/\" #@param {type:\"string\"}\n",
"output_path = \"../output/\" #@param {type:\"string\"}\n",
"threshold = 0.6 #@param {type:\"number\"}\n",
"\n",
"\n",
"parser = argparse.ArgumentParser('Visualize Self-Attention maps')\n",
"parser.add_argument('--arch', default='deit_small', type=str,\n",
" choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')\n",
"parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')\n",
"parser.add_argument('--pretrained_weights', default='', type=str,\n",
" help=\"Path to pretrained weights to load.\")\n",
"parser.add_argument(\"--checkpoint_key\", default=\"teacher\", type=str,\n",
" help='Key to use in the checkpoint (example: \"teacher\")')\n",
"parser.add_argument(\"--image_path\", default=None, type=str, help=\"Path of the image to load.\")\n",
"parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')\n",
"parser.add_argument(\"--threshold\", type=float, default=0.6, help=\"\"\"We visualize masks\n",
" obtained by thresholding the self-attention maps to keep xx% of the mass.\"\"\")\n",
"\n",
"args = parser.parse_args(args=[])\n",
"\n",
"args.arch = arch\n",
"args.pretrained_weights = pretrained_weights_path\n",
"args.image_path = \"../video/\"\n",
"args.output_dir = \"../out/\"\n",
"args.threshold = threshold"
],
"execution_count": 164,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "j2eEQ8YQ69aG",
"outputId": "1e94923c-b8d7-4a5b-fa75-ce3cdf729823"
},
"source": [
"model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)\n",
"for p in model.parameters():\n",
" p.requires_grad = False\n",
"model.eval()\n",
"model.cuda()\n",
"if os.path.isfile(args.pretrained_weights):\n",
" state_dict = torch.load(args.pretrained_weights, map_location=\"cpu\")\n",
" if args.checkpoint_key is not None and args.checkpoint_key in state_dict:\n",
" print(f\"Take key {args.checkpoint_key} in provided checkpoint dict\")\n",
" state_dict = state_dict[args.checkpoint_key]\n",
" state_dict = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n",
" msg = model.load_state_dict(state_dict, strict=False)\n",
" print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))\n",
"else:\n",
" print(\"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.\")\n",
" url = None\n",
" if args.arch == \"deit_small\" and args.patch_size == 16:\n",
" url = \"dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth\"\n",
" elif args.arch == \"deit_small\" and args.patch_size == 8:\n",
" url = \"dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth\" # model used for visualizations in our paper\n",
" elif args.arch == \"vit_base\" and args.patch_size == 16:\n",
" url = \"dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\"\n",
" elif args.arch == \"vit_base\" and args.patch_size == 8:\n",
" url = \"dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth\"\n",
" if url is not None:\n",
" print(\"Since no pretrained weights have been provided, we load the reference pretrained DINO weights.\")\n",
" state_dict = torch.hub.load_state_dict_from_url(url=\"https://dl.fbaipublicfiles.com/dino/\" + url)\n",
" model.load_state_dict(state_dict, strict=True)\n",
" else:\n",
" print(\"There is no reference weights available for this model => We use random weights.\")\n"
],
"execution_count": 165,
"outputs": [
{
"output_type": "stream",
"text": [
"Pretrained weights found at dino_deitsmall8_pretrain.pth and loaded with msg: <All keys matched successfully>\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aRk_XpBVBlIn",
"outputId": "ed864ea9-788f-4b6b-b3f8-768291b5c3ba"
},
"source": [
"torch.cuda.empty_cache()\n",
"gc.collect()"
],
"execution_count": 166,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"26740"
]
},
"metadata": {
"tags": []
},
"execution_count": 166
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z-z9LXVWsak_"
},
"source": [
"## Run inference\n",
"\n",
"\n",
"Resize if OOM"
]
},
{
"cell_type": "code",
"metadata": {
"id": "V6SgQq0s84ZX"
},
"source": [
"predict_video(args)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "SXOZzCiTsmGl"
},
"source": [
"## Output images to video\n",
"\n",
"Input video is 60 fps, change if yours is different"
]
},
{
"cell_type": "code",
"metadata": {
"id": "vIhTNIEnkEfR"
},
"source": [
"!ffmpeg -framerate 60 -i ../output/attn-image-%03d.jpg ../output.mp4"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "4-ZTjxJBtCn_"
},
"source": [
"If you want both input and output videos side by side:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "U-kSzhZZtFow"
},
"source": [
"!ffmpeg -i ../video.mp4 -i ../output.mp4 -filter_complex '[0:v]pad=iw*2:ih[int];[int][1:v]overlay=W/2:0[vid]' -map '[vid]' -c:v libx264 -crf 23 -preset veryfast final.mp4"
],
"execution_count": null,
"outputs": []
}
]
}
@rahulkrprajapati

Been getting a few errors while trying to run this notebook.

[screenshot: Revision 1]
I think args.input_path and args.output_dir should be changed to these values (shown in the screenshot), otherwise an error is thrown during inference.

I've also been getting another error while running inference: predict_video(args) throws the following error for me. It would be great to get some insight into how to correct it.
[screenshot: Error 1]
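Presumably the change in the "Revision 1" screenshot points the args at the input_path / output_path values defined in the Args cell; a rough sketch of that assumption:

# point the parser args at the form-field values from the Args cell
args.image_path = "../input/"   # directory of extracted frames
args.output_dir = "../output/"  # directory for the attention maps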

@sachdevkartik

sachdevkartik commented May 14, 2021

[FIX] Error: 'VisionTransformer' object has no attribute 'forward_selfattention'

Change the following code in line 21 of def predict_video(args):
attentions = model.forward_selfattention(img.cuda()) -> attentions = model.get_last_selfattention(img.cuda())

@rahulkrprajapati

rahulkrprajapati commented May 15, 2021

Thanks 👍 !! Finally got it working.
Any idea how I could get the attention heat maps with a transparent background instead of the all-black opaque background, so that they feel like an overlay on the input video? I essentially want to get an output like the attached example.
[attached example: output]
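One possible approach (a rough sketch, not from the gist itself): colorize the attention map with the same inferno colormap, then alpha-blend it over the original frame, e.g. with PIL. The overlay_attention helper below is hypothetical:

import numpy as np
import matplotlib.cm as cm
from PIL import Image

def overlay_attention(frame_path, attn_map, alpha=0.6):
    # attn_map: the 2D head-averaged attention array computed in predict_video
    frame = Image.open(frame_path).convert("RGB")
    # normalize to [0, 1] and colorize with the same "inferno" colormap
    attn = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    heat = Image.fromarray((cm.inferno(attn)[..., :3] * 255).astype(np.uint8))
    heat = heat.resize(frame.size)
    # blend the heatmap over the frame instead of saving it on a black background
    return Image.blend(frame, heat, alpha)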

@Bennetash

Good evening I would like to know if you can help me with this error, I don't know how to call the "vision_transformer" module. I really appreciate it
[screenshot of the error]

@felipe-parodi

> Good evening I would like to know if you can help me with this error, I don't know how to call the "vision_transformer" module. I really appreciate it

Did you resolve this? If not, be sure to run all of the cells in order and change into the DINO directory beforehand with "%cd dino/". It might also help to import it as "vits" (import vision_transformer as vits), since that is how it is referenced in the notebook.

@Bennetash

> Good evening I would like to know if you can help me with this error, I don't know how to call the "vision_transformer" module. I really appreciate it

> Did you resolve this? If not, be sure to run all of the cells in order and change into the DINO directory beforehand with "%cd dino/" and it might be helpful to import it as "vits" (import vision_transformer as vits) as that is how it is called in the notebook.

Yes, thank you very much Felipe; as you said, I had missed running the %cd dino/ cell. I finally got to the end of the code, but now I'm stuck again because I get a new warning that does not let me continue. Please help me with this if you can.
[screenshot of the warning]

@sachdevkartik

[FIX] File not found '../output/attn-image-%03d.jpg'

Actually there is a typo in there, it should be:
!ffmpeg -framerate 60 -i ../out/attn-img-%03d.jpg ../output.mp4


@felipe-parodi

Has anyone tried using the other models? I get several size mismatch errors
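The size mismatches usually mean the checkpoint doesn't match args.arch and args.patch_size. A sketch of switching to vit_base with patch size 16, reusing the weight URL from the model-loading cell (run the wget from inside dino/, i.e. after the %cd dino/ cell; the exact file name is an assumption):

!wget https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth -O dino_vitbase16_pretrain.pth

args.arch = "vit_base"
args.patch_size = 16
args.pretrained_weights = "dino_vitbase16_pretrain.pth"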

@Bennetash

> Has anyone tried using the other models? I get several size mismatch errors

I have not tried it with another network, but I will try and let you know.

@patriciajelee

Can anyone help me out here? I've tried changing args.input_path to "../input/" and args.output_dir to "../output/" as suggested above, but I keep running into the following error:

predict_video(args)

IsADirectoryError Traceback (most recent call last)
in ()
----> 1 predict_video(args)

in predict_video(args)
1 def predict_video(args):
2 for frame in sorted(os.listdir(args.image_path)):
----> 3 with open(os.path.join(args.image_path, frame), 'rb') as f:
4 img = Image.open(f)
5 img = img.convert('RGB')

IsADirectoryError: [Errno 21] Is a directory: '../input/.ipynb_checkpoints'
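That error comes from the hidden .ipynb_checkpoints folder Jupyter creates inside ../input/: os.listdir() returns it and open() then fails on a directory. Skipping anything that isn't a regular file in the frame loop avoids it, roughly like this:

for frame in sorted(os.listdir(args.image_path)):
    frame_path = os.path.join(args.image_path, frame)
    if not os.path.isfile(frame_path):  # skip .ipynb_checkpoints and other directories
        continue
    with open(frame_path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGB')
    # ... rest of predict_video unchanged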
