Jupyter notebook for creating a Whisper model without number tokens. Existing models: https://huggingface.co/collections/waveletdeboshir/whisper-without-numbers-67004c5d7bf9e1a99e373d54
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/waveletdeboshir/244b62ab800bc32f21a6bbb366cb81ba/removenumberswhisper.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yjhLPsMyLfKA"
},
"source": [
| "# Remove number tokens from Whisper model and tokenizer\n", | |
| "\n", | |
| "my lib versions:\n", | |
| "* transformers 4.46.3\n", | |
| "* torch 2.4.0" | |
| ] | |
| }, | |
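{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your environment is missing these, a minimal install sketch (the pins are just the versions listed above; adjust for your CUDA setup):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install the library versions this notebook was tested with\n",
"!pip install transformers==4.46.3 torch==2.4.0"
]
},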
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "08cOSlseLfKC"
},
"outputs": [],
"source": [
"# Keep Hugging Face hub downloads in a local ./models/ folder\n",
"import os\n",
"os.environ[\"HF_HUB_CACHE\"] = \"./models/\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5ToUGeRtLfKD"
},
"outputs": [],
"source": [
"import torch\n",
"from transformers import WhisperProcessor, WhisperTokenizer, WhisperForConditionalGeneration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZmWeSMCeLfKD"
},
"outputs": [],
"source": [
"# Whisper size: tiny, base, small, medium, large-v2, large-v3, large-v3-turbo\n",
"size = \"large-v3\"\n",
"new_name = \"no-numbers\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "r75Xw4BWLfKD",
"outputId": "30e65a13-75a4-4db6-a86f-8ca1d781f86d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'whisper-large-v3'...\n",
"remote: Enumerating objects: 78, done.\u001b[K\n",
"remote: Counting objects: 100% (38/38), done.\u001b[K\n",
"remote: Compressing objects: 100% (38/38), done.\u001b[K\n",
"remote: Total 78 (delta 21), reused 0 (delta 0), pack-reused 40 (from 1)\u001b[K\n",
"Unpacking objects: 100% (78/78), 1.21 MiB | 3.32 MiB/s, done.\n",
"Filtering content: 100% (7/7), 11.00 GiB | 5.00 MiB/s, done.\n"
]
}
],
"source": [
"!mkdir models\n",
"!cd models && git clone https://huggingface.co/openai/whisper-{size}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mURtVpcALfKE"
},
"outputs": [],
"source": [
"# Load initial model and tokenizer\n",
"tokenizer = WhisperTokenizer.from_pretrained(f\"./models/whisper-{size}\")\n",
"model = WhisperForConditionalGeneration.from_pretrained(f\"./models/whisper-{size}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RhbCH5tsLfKE"
},
"source": [
"# Find tokens with numbers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5g-dJg8nLfKE"
},
"outputs": [],
"source": [
"# Find token indices whose decoded form contains a digit\n",
"number_tokens = [\n",
"    i\n",
"    for i in range(tokenizer.vocab_size)\n",
"    if any(c in \"0123456789\" for c in tokenizer.decode([i], add_special_tokens=False))\n",
"]"
]
},
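{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustrative sanity check (not part of the original run): decode a few of the matched ids to confirm they really contain digits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at a few tokens slated for removal\n",
"print([tokenizer.decode([i], add_special_tokens=False) for i in number_tokens[:20]])"
]
},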
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ITBRITeXLfKE",
"outputId": "4779f787-b444-4ccd-f57b-afa7e7441303"
},
"outputs": [
{
"data": {
"text/plain": [
"426"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(number_tokens)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AKlWluOuLfKE"
},
"outputs": [],
"source": [
"# Optionally remove Roman numerals too (except I, V, X, which are also ordinary letters)\n",
"for roman in [\"II\", \"III\", \"IV\", \"VI\", \"VII\", \"VIII\", \"IX\", \"XI\", \"XII\", \"XIII\", \"XIV\", \"XV\", \"XVI\", \"XVII\", \"XVIII\", \"XIX\", \"XX\"]:\n",
"    t = tokenizer.encode(roman, add_special_tokens=False)\n",
"    if len(t) == 1:\n",
"        number_tokens.append(t[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6b5gqg7jLfKE",
"outputId": "6c097469-1751-4b14-b4db-6a69dde96ed3"
},
"outputs": [
{
"data": {
"text/plain": [
"431"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(number_tokens)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "D4JnhX8dLfKF",
"outputId": "7095ea0c-882a-4e88-d3b6-ec1c840d74c1"
},
"outputs": [
{
"data": {
"text/plain": [
"51866"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Full output dimension of the LM head (includes the added special tokens)\n",
"model.proj_out.out_features"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aF7RbEnfLfKF"
},
"outputs": [],
"source": [
"# Token indices to keep (use a set for fast membership tests)\n",
"number_token_set = set(number_tokens)\n",
"kept_ids = [n for n in range(model.proj_out.out_features) if n not in number_token_set]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ONbm310bLfKF",
"outputId": "783960a0-719b-4daa-b560-75714a08bd0a"
},
"outputs": [
{
"data": {
"text/plain": [
"51435"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(kept_ids)"
]
},
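{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: `kept_ids.index(old_id)` below is a linear scan and gets called many times. An optional lookup table (hypothetical name `old2new`) gives the same remapping in O(1); the cells below keep the original `kept_ids.index` calls."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: O(1) old-id -> new-id mapping, equivalent to kept_ids.index(old_id)\n",
"old2new = {old_id: new_id for new_id, old_id in enumerate(kept_ids)}\n",
"assert all(old2new[t] == kept_ids.index(t) for t in kept_ids[:5])"
]
},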
{
"cell_type": "markdown",
"metadata": {
"id": "SsOB6OsYLfKF"
},
"source": [
"# Update model weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-Df70724LfKF"
},
"outputs": [],
"source": [
"# Work on a deep copy so the original model stays intact\n",
"import copy\n",
"new_model = copy.deepcopy(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "htCG5SjsLfKF"
},
"outputs": [],
"source": [
"new_size = len(kept_ids)\n",
"\n",
"# New embedding layer\n",
"endoftext_idx = tokenizer.convert_tokens_to_ids(\"<|endoftext|>\")\n",
"new_emb = torch.nn.Embedding(\n",
"    new_size,\n",
"    model.model.decoder.embed_tokens.embedding_dim,\n",
"    padding_idx=kept_ids.index(endoftext_idx)  # new idx of <|endoftext|> token\n",
")\n",
"\n",
"# New proj_out layer\n",
"new_head = torch.nn.Linear(\n",
"    in_features=model.proj_out.in_features,\n",
"    out_features=new_size,\n",
"    bias=False\n",
")\n",
"\n",
"# Copy the weights of the kept tokens, row by row\n",
"for new_id, old_id in enumerate(kept_ids):\n",
"    new_emb.weight.data[new_id] = model.model.decoder.embed_tokens.weight.data[old_id]\n",
"    new_head.weight.data[new_id] = model.proj_out.weight.data[old_id]\n",
"\n",
"# Swap the resized layers into the copied model\n",
"new_model.model.decoder.embed_tokens = new_emb\n",
"new_model.proj_out = new_head"
]
},
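{
"cell_type": "markdown",
"metadata": {},
"source": [
"An illustrative shape check (not in the original run): both resized layers should now have exactly `new_size` rows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert new_model.model.decoder.embed_tokens.weight.shape[0] == new_size\n",
"assert new_model.proj_out.weight.shape[0] == new_size"
]
},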
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "evfARbfBLfKG"
},
"outputs": [],
"source": [
"# Change model config: new vocab size, name, and remapped special token ids\n",
"new_model.config.__dict__['vocab_size'] = new_size\n",
"new_model.config.__dict__['_name_or_path'] = f'waveletdeboshir/whisper-{size}-{new_name}'\n",
"\n",
"new_model.config.__dict__[\"bos_token_id\"] = kept_ids.index(model.config.__dict__[\"bos_token_id\"])\n",
"new_model.config.__dict__[\"decoder_start_token_id\"] = kept_ids.index(model.config.__dict__[\"decoder_start_token_id\"])\n",
"new_model.config.__dict__[\"eos_token_id\"] = kept_ids.index(model.config.__dict__[\"eos_token_id\"])\n",
"new_model.config.__dict__[\"pad_token_id\"] = kept_ids.index(model.config.__dict__[\"pad_token_id\"])\n",
"new_model.config.__dict__[\"suppress_tokens\"] = []\n",
"new_model.config.__dict__[\"forced_decoder_ids\"] = [\n",
"    [1, kept_ids.index(tokenizer.convert_tokens_to_ids(\"<|en|>\"))],  # language\n",
"    [2, kept_ids.index(tokenizer.convert_tokens_to_ids(\"<|transcribe|>\"))],  # task\n",
"    [3, kept_ids.index(tokenizer.convert_tokens_to_ids(\"<|notimestamps|>\"))]  # no timestamps\n",
"]\n",
"\n",
"# Remap begin_suppress_tokens, dropping any that were removed\n",
"beg_sup = []\n",
"for t in model.config.__dict__['begin_suppress_tokens']:\n",
"    if t in kept_ids:\n",
"        beg_sup.append(kept_ids.index(t))\n",
"new_model.config.__dict__['begin_suppress_tokens'] = beg_sup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5OdE5pxzLfKG"
},
"outputs": [],
"source": [
"# Change generation config the same way\n",
"beg_sup = []\n",
"for t in model.generation_config.__dict__['begin_suppress_tokens']:\n",
"    if t in kept_ids:\n",
"        beg_sup.append(kept_ids.index(t))\n",
"new_model.generation_config.__dict__['begin_suppress_tokens'] = beg_sup\n",
"\n",
"new_model.generation_config.__dict__[\"bos_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"bos_token_id\"])\n",
"new_model.generation_config.__dict__[\"decoder_start_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"decoder_start_token_id\"])\n",
"new_model.generation_config.__dict__[\"eos_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"eos_token_id\"])\n",
"new_model.generation_config.__dict__[\"forced_decoder_ids\"] = [\n",
"    [1, None],  # language is detected at generation time\n",
"    [2, kept_ids.index(tokenizer.convert_tokens_to_ids(\"<|transcribe|>\"))]\n",
"]\n",
"\n",
"# Remap language token ids, dropping any that were removed\n",
"new_lang_to_id = {}\n",
"for key, value in model.generation_config.__dict__[\"lang_to_id\"].items():\n",
"    if value in kept_ids:\n",
"        new_lang_to_id[key] = kept_ids.index(value)\n",
"new_model.generation_config.__dict__[\"lang_to_id\"] = new_lang_to_id\n",
"\n",
"new_model.generation_config.__dict__[\"no_timestamps_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"no_timestamps_token_id\"])\n",
"new_model.generation_config.__dict__[\"pad_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"pad_token_id\"])\n",
"new_model.generation_config.__dict__[\"prev_sot_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"prev_sot_token_id\"])\n",
"new_model.generation_config.__dict__[\"suppress_tokens\"] = []\n",
"new_model.generation_config.__dict__[\"task_to_id\"] = {\n",
"    key: kept_ids.index(value) for key, value in model.generation_config.__dict__[\"task_to_id\"].items()\n",
"}"
]
},
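{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another illustrative consistency check: by construction `kept_ids[new_id] == old_id`, so a remapped special id should map back to the original one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# e.g. the remapped eos id should point back at the original eos id\n",
"assert kept_ids[new_model.generation_config.eos_token_id] == model.generation_config.eos_token_id"
]
},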
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P5JvnUMKLfKG",
"outputId": "28d8254b-9909-4858-e2e7-90b074550d6a"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/daryavozdaeva/Work/anaconda3/envs/torch-env/lib/python3.10/site-packages/transformers/modeling_utils.py:2817: UserWarning: Moving the following attributes in the config to the generation config: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [210, 49826]}. You are seeing this warning because you've set generation parameters in the model config, as opposed to in the generation config.\n",
"  warnings.warn(\n"
]
}
],
"source": [
"new_model.save_pretrained(f\"models/whisper-{size}-{new_name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ul1_b4lLfKG"
},
"source": [
"# Change tokenizer\n",
"\n",
"First, copy all the tokenizer files to a separate folder, `models/tokenizer`.\n",
"\n",
"Then create a new folder to hold the changed tokenizer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BLu9uAUlLfKG"
},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uit2dCmCLfKG"
},
"outputs": [],
"source": [
"target_folder = \"tokenizer-nonumbers\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZGnu62MJLfKG",
"outputId": "3179787b-2b37-4fa8-f492-14e3937aee91"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mkdir: ./models/tokenizer-nonumbers: File exists\n",
"mkdir: ./models/tokenizer: File exists\n"
]
}
],
"source": [
"!mkdir ./models/{target_folder}\n",
"\n",
"!mkdir ./models/tokenizer\n",
"!cp ./models/whisper-{size}/added_tokens.json ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/merges.txt ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/special_tokens_map.json ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/tokenizer.json ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/tokenizer_config.json ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/vocab.json ./models/tokenizer/"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1-qpKVe0LfKH"
},
"source": [
"Now we remap the token ids in every tokenizer file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IYkS5hhALfKH"
},
"outputs": [],
"source": [
"# Added tokens: keep surviving tokens and remap their ids\n",
"with open(\"./models/tokenizer/added_tokens.json\", \"r\") as f:\n",
"    added_tokens = json.load(f)\n",
"\n",
"ch_added_tokens = {}\n",
"for key, value in added_tokens.items():\n",
"    if value in kept_ids:\n",
"        ch_added_tokens[key] = kept_ids.index(value)\n",
"\n",
"with open(f\"./models/{target_folder}/added_tokens.json\", \"w\") as f:\n",
"    json.dump(ch_added_tokens, f, indent=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8VySManzLfKH"
},
"outputs": [],
"source": [
"# Special tokens map\n",
"with open(\"./models/tokenizer/special_tokens_map.json\", \"r\") as f:\n",
"    special_tokens_map = json.load(f)\n",
"\n",
"special_tokens_map[\"additional_special_tokens\"] = [\"<|endoftext|>\"] + list(ch_added_tokens.keys())\n",
"with open(f\"./models/{target_folder}/special_tokens_map.json\", \"w\") as f:\n",
"    json.dump(special_tokens_map, f, indent=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "83HypuxfLfKH"
},
"outputs": [],
"source": [
"# Tokenizer config\n",
"with open(\"./models/tokenizer/tokenizer_config.json\", \"r\") as f:\n",
"    tok_config = json.load(f)\n",
"\n",
"ch_added_tokens_decoder = {}\n",
"for key, value in tok_config[\"added_tokens_decoder\"].items():\n",
"    if int(key) in kept_ids:\n",
"        ch_added_tokens_decoder[str(kept_ids.index(int(key)))] = value\n",
"\n",
"tok_config[\"added_tokens_decoder\"] = ch_added_tokens_decoder\n",
"tok_config[\"additional_special_tokens\"] = [\"<|endoftext|>\"] + list(ch_added_tokens.keys())\n",
"\n",
"with open(f\"./models/{target_folder}/tokenizer_config.json\", \"w\") as f:\n",
"    json.dump(tok_config, f, indent=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SUuHp_gHLfKH"
},
"outputs": [],
"source": [
"# Tokenizer\n",
"with open(\"./models/tokenizer/tokenizer.json\", \"r\") as f:\n",
"    tok = json.load(f)\n",
"\n",
"# Remap ids of added tokens, dropping removed ones\n",
"ch_added_tokens = []\n",
"for t in tok[\"added_tokens\"]:\n",
"    if t[\"id\"] in kept_ids:\n",
"        t[\"id\"] = kept_ids.index(t[\"id\"])\n",
"        ch_added_tokens.append(t)\n",
"\n",
"tok[\"added_tokens\"] = ch_added_tokens\n",
"\n",
"# Remap the vocab\n",
"ch_vocab = {}\n",
"for key, value in tok[\"model\"][\"vocab\"].items():\n",
"    if value in kept_ids:\n",
"        ch_vocab[key] = kept_ids.index(value)\n",
"\n",
"tok[\"model\"][\"vocab\"] = ch_vocab\n",
"\n",
"# Remap post-processor special token ids (mutated in place inside tok)\n",
"for value in tok[\"post_processor\"][\"special_tokens\"].values():\n",
"    if value[\"ids\"][0] in kept_ids:\n",
"        value[\"ids\"][0] = kept_ids.index(value[\"ids\"][0])\n",
"\n",
"with open(f\"./models/{target_folder}/tokenizer.json\", \"w\") as f:\n",
"    json.dump(tok, f, indent=4, ensure_ascii=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QlWeJHvzLfKH"
},
"outputs": [],
"source": [
"# Vocab\n",
"with open(f\"./models/{target_folder}/vocab.json\", \"w\") as f:\n",
"    json.dump(ch_vocab, f, indent=4, ensure_ascii=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GGukXC4NLfKH"
},
"source": [
"Merges file: keep only the merge rules that remain usable with the reduced vocab."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y2Gr3szzLfKI"
},
"outputs": [],
"source": [
"with open(\"./models/tokenizer/merges.txt\", \"r\") as f:\n",
"    merges = f.read().split(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fvDAPljJLfKI"
},
"outputs": [],
"source": [
"# Keep a merge rule only if both parts are still usable and the merged result survives.\n",
"# merges[1:-1] skips the version header line and the trailing empty line.\n",
"not_found = []\n",
"not_found_merged_tokens = []\n",
"found = []\n",
"\n",
"for merge in merges[1:-1]:\n",
"    m = merge.split()\n",
"    if (m[0] not in ch_vocab.keys() or m[1] not in ch_vocab.keys() or m[0] in not_found_merged_tokens or m[1] in not_found_merged_tokens) and (m[0] + m[1] not in ch_vocab.keys()):\n",
"        not_found.append(merge)\n",
"        not_found_merged_tokens.append(m[0] + m[1])\n",
"    else:\n",
"        found.append(merge)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1H5HdL2CLfKI",
"outputId": "e23e4756-d05c-47af-dd70-9f73c9afb3d2"
},
"outputs": [
{
"data": {
"text/plain": [
"49583"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(found)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LX6cCI-WLfKI"
},
"outputs": [],
"source": [
"# Write the header line back too: the slow tokenizer skips the first\n",
"# and last lines of merges.txt when it loads the file\n",
"with open(f\"./models/{target_folder}/merges.txt\", \"w\") as f:\n",
"    f.write(\"\\n\".join([merges[0]] + found) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dpM8x4rtLfKI"
},
"outputs": [],
"source": [
"# Load changed tokenizer from folder\n",
"changed_tok = WhisperTokenizer.from_pretrained(f\"./models/{target_folder}/\", local_files_only=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0vyoXcWRLfKJ",
"outputId": "e55af0e0-0a48-4d5f-de4b-b4f93e6b3c97"
},
"outputs": [
{
"data": {
"text/plain": [
"[49827, 49933, 8458, 34531, 210, None, None, None, None, None, 49826]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"changed_tok.encode(\"Текст 12345\")"
]
},
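{
"cell_type": "markdown",
"metadata": {},
"source": [
"The five `None` entries are the five digits of \"12345\": they no longer map to any token id, which is exactly what we wanted."
]
},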
{
"cell_type": "markdown",
"metadata": {
"id": "7cBTn84zLfKJ"
},
"source": [
"# Try new model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_lSobjYKLfKJ"
},
"outputs": [],
"source": [
"# Copy the new tokenizer files, plus the normalizer file and\n",
"# preprocessor config from the original model, into the new model folder\n",
"!cp ./models/{target_folder}/* ./models/whisper-{size}-{new_name}/\n",
"!cp ./models/whisper-{size}/normalizer.json ./models/whisper-{size}-{new_name}/\n",
"!cp ./models/whisper-{size}/preprocessor_config.json ./models/whisper-{size}-{new_name}/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"referenced_widgets": [
"57f391a5689f452bad263e42a8622adf"
]
},
"id": "zt1X-QKWLfKJ",
"outputId": "3d78b56b-e24f-4d5d-805f-0dc1d68c37db"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57f391a5689f452bad263e42a8622adf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load new model, processor and tokenizer from folder\n",
"\n",
"tokenizer = WhisperTokenizer.from_pretrained(f\"./models/whisper-{size}-{new_name}/\", local_files_only=True)\n",
"model = WhisperForConditionalGeneration.from_pretrained(f\"./models/whisper-{size}-{new_name}\", local_files_only=True)\n",
"preprocessor = WhisperProcessor.from_pretrained(f\"./models/whisper-{size}-{new_name}\", local_files_only=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jpz3nInpLfKK"
},
"source": [
"Check that everything works on a test file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KUIexZMiLfKK"
},
"outputs": [],
"source": [
"import torchaudio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UyQIo_2PLfKK"
},
"outputs": [],
"source": [
"wav, sr = torchaudio.load(\"test.wav\")\n",
"\n",
"# Whisper expects 16 kHz input\n",
"if sr != 16000:\n",
"    wav = torchaudio.functional.resample(wav, sr, 16000)\n",
"\n",
"processed = preprocessor(wav[0], sampling_rate=16000, return_tensors=\"pt\")\n",
"\n",
"predicted_ids = model.generate(processed.input_features, language=\"ru\", task=\"transcribe\")\n",
"\n",
"transcriptions = preprocessor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
"\n",
"print(transcriptions)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch-env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}