Created
November 13, 2024 14:40
-
-
Save tanukon/312cb492f8fe1eb84922056884b0e98f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import json\n", | |
| "import os\n", | |
| "import pandas as pd" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "main_dir = 'blog_data'\n", | |
| "# Collect only the source .mp4 videos (sorted for a stable order); the name is kept\n", | |
| "# because the conversion cell iterates `audio_list`. The conversion cell writes .mp3\n", | |
| "# files into this same directory, so listing every entry would feed those (and any\n", | |
| "# hidden/system files) back into VideoFileClip on a re-run.\n", | |
| "audio_list = sorted(\n", | |
| "    os.path.join(main_dir, f)\n", | |
| "    for f in os.listdir(main_dir)\n", | |
| "    if f.lower().endswith('.mp4')\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "MoviePy - Writing audio in blog_data/conan.mp3\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| " \r" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "MoviePy - Done.\n", | |
| "MoviePy - Writing audio in blog_data/snl_kylo_ren.mp3\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| " " | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "MoviePy - Done.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\r" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import moviepy.editor as mp\n", | |
| "\n", | |
| "# Extract the audio track of every source video into an .mp3 in the same folder.\n", | |
| "for video_path in audio_list:\n", | |
| "    # splitext/basename are separator-agnostic and keep dots that are part of the stem\n", | |
| "    stem = os.path.splitext(os.path.basename(video_path))[0]\n", | |
| "    audio_result_path = os.path.join(main_dir, stem + '.mp3')\n", | |
| "\n", | |
| "    # The context manager closes the clip, releasing its underlying file reader.\n", | |
| "    with mp.VideoFileClip(video_path) as clip:\n", | |
| "        clip.audio.write_audiofile(audio_result_path)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Output folder for the transcription CSV/JSON files written by the sections below.\n", | |
| "result_dir = './results'\n", | |
| "os.makedirs(result_dir, exist_ok=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Every model below is benchmarked on this one file (an .mp3 produced by the conversion cell).\n", | |
| "audio_path = 'blog_data/conan.mp3'" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Vanilla Whisper" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import whisper\n", | |
| "\n", | |
| "# Baseline: the `whisper` package with its \"turbo\" checkpoint.\n", | |
| "model = whisper.load_model(\"turbo\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "55.2 s ± 3.99 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "\n", | |
| "# Benchmark only; the next cell transcribes again to obtain `result` for export.\n", | |
| "result = model.transcribe(audio_path)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Run the transcription once outside the benchmark so `result` is available below,\n", | |
| "# and keep the raw Whisper output as JSON.\n", | |
| "result = model.transcribe(audio_path)\n", | |
| "\n", | |
| "json_path = os.path.join(result_dir, 'whisper_turbo.json')\n", | |
| "with open(json_path, 'w') as fp:\n", | |
| "    json.dump(result, fp)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Flatten Whisper's segment dicts into a start/end/text table and save it as CSV.\n", | |
| "segments = result[\"segments\"]\n", | |
| "result_dict = {\n", | |
| "    'start_time': [seg[\"start\"] for seg in segments],\n", | |
| "    'end_time': [seg[\"end\"] for seg in segments],\n", | |
| "    'transcription': [seg[\"text\"] for seg in segments],\n", | |
| "}\n", | |
| "\n", | |
| "result_df = pd.DataFrame(result_dict)\n", | |
| "result_df.to_csv(os.path.join(result_dir, 'vanilla-whisper.csv'), index=None)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For reference, the Whisper large-v3 computation time is shown below." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████████████████████████████████| 2.88G/2.88G [00:21<00:00, 141MiB/s]\n", | |
| "/opt/conda/envs/audioenv/lib/python3.10/site-packages/whisper/__init__.py:150: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", | |
| " checkpoint = torch.load(fp, map_location=device)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import whisper\n", | |
| "\n", | |
| "# Timing reference only. NOTE: this rebinds `model`, replacing the turbo model loaded\n", | |
| "# earlier, so re-run the turbo loading cell before re-running the turbo cells.\n", | |
| "model = whisper.load_model(\"large-v3\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4min 30s ± 49 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "\n", | |
| "# Benchmark large-v3 on the same file; its transcription is not exported.\n", | |
| "result = model.transcribe(audio_path)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Faster Whisper" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/opt/conda/envs/audioenv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
| " from .autonotebook import tqdm as notebook_tqdm\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from faster_whisper import WhisperModel\n", | |
| "\n", | |
| "# Hugging Face model id; by its name, a CTranslate2 conversion of Whisper large-v3-turbo.\n", | |
| "model_size = \"deepdml/faster-whisper-large-v3-turbo-ct2\"\n", | |
| "\n", | |
| "# Run on GPU with FP16\n", | |
| "model = WhisperModel(model_size_or_path=model_size, device=\"cuda\", compute_type=\"float16\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "3.27 s ± 28.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "\n", | |
| "# faster-whisper's transcribe() returns a lazy generator: decoding only happens while\n", | |
| "# it is iterated. Consume it with list() so the benchmark measures the full\n", | |
| "# transcription rather than only the eager setup work.\n", | |
| "segments, info = model.transcribe(audio_path, beam_size=5)\n", | |
| "segments = list(segments)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "segments, info = model.transcribe(audio_path, beam_size=5)\n", | |
| "# Run the transcription now and keep the segments: the returned generator can only be\n", | |
| "# consumed once, so the CSV export cell would otherwise yield nothing when re-run.\n", | |
| "segments = list(segments)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>start_time</th>\n", | |
| " <th>end_time</th>\n", | |
| " <th>transcription</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0.0</td>\n", | |
| " <td>1.58</td>\n", | |
| " <td>Is this your holding here, sir?</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1.7</td>\n", | |
| " <td>4.06</td>\n", | |
| " <td>It's a marine toad, Conan.</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " start_time end_time transcription\n", | |
| "0 0.0 1.58 Is this your holding here, sir?\n", | |
| "1 1.7 4.06 It's a marine toad, Conan." | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# One row per decoded segment, collected in a single pass (`segments` may be a one-shot generator).\n", | |
| "result_dict = {'start_time': [], 'end_time': [], 'transcription': []}\n", | |
| "for seg in segments:\n", | |
| "    row = (('start_time', seg.start), ('end_time', seg.end), ('transcription', seg.text))\n", | |
| "    for column, value in row:\n", | |
| "        result_dict[column].append(value)\n", | |
| "\n", | |
| "result_df = pd.DataFrame(result_dict)\n", | |
| "result_df.to_csv(os.path.join(result_dir, 'faster-whisper.csv'), index=None)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Whisper-X" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]\n", | |
| "INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []\n", | |
| "WARNING:py.warnings:/opt/conda/envs/audioenv/lib/python3.10/site-packages/pyannote/audio/pipelines/speaker_verification.py:45: UserWarning: Module 'speechbrain.pretrained' was deprecated, redirecting to 'speechbrain.inference'. Please update your script. This is a change from SpeechBrain 1.0. See: https://github.com/speechbrain/speechbrain/releases/tag/v1.0.0\n", | |
| " from speechbrain.pretrained import (\n", | |
| "\n", | |
| "WARNING:py.warnings:/opt/conda/envs/audioenv/lib/python3.10/site-packages/pyannote/audio/pipelines/speaker_verification.py:53: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n", | |
| " torchaudio.set_audio_backend(backend)\n", | |
| "\n", | |
| "WARNING:py.warnings:/opt/conda/envs/audioenv/lib/python3.10/site-packages/pyannote/audio/tasks/segmentation/mixins.py:37: UserWarning: `torchaudio.backend.common.AudioMetaData` has been moved to `torchaudio.AudioMetaData`. Please update the import path.\n", | |
| " from torchaudio.backend.common import AudioMetaData\n", | |
| "\n", | |
| "Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../.cache/torch/whisperx-vad-segmentation.bin`\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "No language specified, language will be first be detected for each audio file (increases inference time).\n", | |
| "Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.\n", | |
| "Model was trained with torch 1.10.0+cu102, yours is 2.5.1. Bad things might happen unless you revert torch to 1.x.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import whisperx\n", | |
| "\n", | |
| "# Same CTranslate2 turbo checkpoint as the Faster Whisper section, for a like-for-like comparison.\n", | |
| "model_size = \"deepdml/faster-whisper-large-v3-turbo-ct2\"\n", | |
| "\n", | |
| "# Load the checkpoint through WhisperX's batched transcription pipeline (fp16 on CUDA).\n", | |
| "# No language is passed here, so WhisperX detects it for each audio file (see the log output).\n", | |
| "model = whisperx.load_model(model_size, 'cuda', compute_type=\"float16\")\n", | |
| "# English alignment model and its metadata, consumed by whisperx.align() in the cells below.\n", | |
| "model_a, metadata = whisperx.load_align_model(language_code='en', device='cuda')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "WARNING:py.warnings:/opt/conda/envs/audioenv/lib/python3.10/site-packages/pyannote/audio/utils/reproducibility.py:74: ReproducibilityWarning: TensorFloat-32 (TF32) has been disabled as it might lead to reproducibility issues and lower accuracy.\n", | |
| "It can be re-enabled by calling\n", | |
| " >>> import torch\n", | |
| " >>> torch.backends.cuda.matmul.allow_tf32 = True\n", | |
| " >>> torch.backends.cudnn.allow_tf32 = True\n", | |
| "See https://github.com/pyannote/pyannote-audio/issues/1370 for more details.\n", | |
| "\n", | |
| " warnings.warn(\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Detected language: en (1.00) in first 30s of audio...\n", | |
| "Detected language: en (1.00) in first 30s of audio...\n", | |
| "Detected language: en (1.00) in first 30s of audio...\n", | |
| "Detected language: en (1.00) in first 30s of audio...\n", | |
| "Detected language: en (1.00) in first 30s of audio...\n", | |
| "Detected language: en (1.00) in first 30s of audio...\n", | |
| "Detected language: en (1.00) in first 30s of audio...\n", | |
| "Detected language: en (1.00) in first 30s of audio...\n", | |
| "29.9 s ± 69.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "\n", | |
| "# Timed end to end: audio loading, batched transcription (batch_size=16) and alignment.\n", | |
| "audio = whisperx.load_audio(audio_path)\n", | |
| "whisper_result = model.transcribe(audio, batch_size=16)\n", | |
| "result = whisperx.align(whisper_result[\"segments\"], model_a, metadata, audio, 'cuda', return_char_alignments=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "WARNING:py.warnings:/opt/conda/envs/audioenv/lib/python3.10/site-packages/pyannote/audio/utils/reproducibility.py:74: ReproducibilityWarning: TensorFloat-32 (TF32) has been disabled as it might lead to reproducibility issues and lower accuracy.\n", | |
| "It can be re-enabled by calling\n", | |
| " >>> import torch\n", | |
| " >>> torch.backends.cuda.matmul.allow_tf32 = True\n", | |
| " >>> torch.backends.cudnn.allow_tf32 = True\n", | |
| "See https://github.com/pyannote/pyannote-audio/issues/1370 for more details.\n", | |
| "\n", | |
| " warnings.warn(\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Detected language: en (1.00) in first 30s of audio...\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Same steps as the benchmark, run once so `result` is kept for the CSV export below.\n", | |
| "audio = whisperx.load_audio(audio_path)\n", | |
| "whisper_result = model.transcribe(audio, batch_size=16)\n", | |
| "result = whisperx.align(whisper_result[\"segments\"], model_a, metadata, audio, 'cuda', return_char_alignments=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Map the CSV column names to the keys of each aligned WhisperX segment, then save as CSV.\n", | |
| "column_keys = {'start_time': \"start\", 'end_time': \"end\", 'transcription': \"text\"}\n", | |
| "segments = result[\"segments\"]\n", | |
| "result_dict = {column: [seg[key] for seg in segments] for column, key in column_keys.items()}\n", | |
| "\n", | |
| "result_df = pd.DataFrame.from_dict(result_dict)\n", | |
| "result_df.to_csv(os.path.join(result_dir, 'WhisperX.csv'), index=None)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Distil-Whisper" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline\n", | |
| "\n", | |
| "# Use GPU + fp16 when CUDA is available, otherwise fall back to CPU + fp32.\n", | |
| "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", | |
| "torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n", | |
| "\n", | |
| "model_id = \"distil-whisper/distil-large-v3\"\n", | |
| "\n", | |
| "# NOTE: rebinds `model`; any model loaded by an earlier section is replaced.\n", | |
| "model = AutoModelForSpeechSeq2Seq.from_pretrained(\n", | |
| " model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True\n", | |
| ")\n", | |
| "model.to(device)\n", | |
| "\n", | |
| "# Tokenizer + feature extractor for the same checkpoint, passed to the pipeline below.\n", | |
| "processor = AutoProcessor.from_pretrained(model_id)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# ASR pipeline around the model loaded above. return_timestamps=True is requested so the\n", | |
| "# output carries timestamped `chunks`, which the CSV export cell reads.\n", | |
| "# NOTE(review): max_new_tokens=128 caps the tokens produced per generate() call - confirm\n", | |
| "# it does not truncate dense speech windows.\n", | |
| "pipe = pipeline(\n", | |
| " \"automatic-speech-recognition\",\n", | |
| " model=model,\n", | |
| " tokenizer=processor.tokenizer,\n", | |
| " feature_extractor=processor.feature_extractor,\n", | |
| " max_new_tokens=128,\n", | |
| " torch_dtype=torch_dtype,\n", | |
| " device=device,\n", | |
| " return_timestamps=True\n", | |
| ")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, None], [2, 50360]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "22.7 s ± 27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "\n", | |
| "# Benchmark the Distil-Whisper pipeline end to end on the same file.\n", | |
| "result = pipe(audio_path)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Run once outside %%timeit so `result` is available for the CSV export below.\n", | |
| "result = pipe(audio_path)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# The pipeline returns `chunks`, each with a (start, end) `timestamp` tuple and its text.\n", | |
| "segments = result['chunks']\n", | |
| "result_dict = {'start_time': [], 'end_time': [], 'transcription': []}\n", | |
| "\n", | |
| "for chunk in segments:\n", | |
| "    timestamp = chunk[\"timestamp\"]\n", | |
| "    result_dict['start_time'].append(timestamp[0])\n", | |
| "    result_dict['end_time'].append(timestamp[1])\n", | |
| "    result_dict['transcription'].append(chunk[\"text\"])\n", | |
| "\n", | |
| "result_df = pd.DataFrame(result_dict)\n", | |
| "result_df.to_csv(os.path.join(result_dir, 'Distil-Whisper.csv'), index=None)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Whisper-Medusa" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "※ Whisper-Medusa needs its own conda environment: switch the notebook kernel (e.g. to `medusa`) before running the cells below." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/opt/conda/envs/medusa/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
| " from .autonotebook import tqdm as notebook_tqdm\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import torch\n", | |
| "import torchaudio\n", | |
| "\n", | |
| "from whisper_medusa import WhisperMedusaModel\n", | |
| "from transformers import WhisperProcessor" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/opt/conda/envs/medusa/lib/python3.11/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", | |
| " warnings.warn(\n", | |
| "Loading checkpoint shards: 100%|██████████| 2/2 [00:23<00:00, 11.62s/it]\n", | |
| "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Whisper-Medusa checkpoint and its matching processor (used for features and decoding).\n", | |
| "model_name = \"aiola/whisper-medusa-linear-libri\"\n", | |
| "model = WhisperMedusaModel.from_pretrained(model_name)\n", | |
| "processor = WhisperProcessor.from_pretrained(model_name)\n", | |
| "\n", | |
| "model = model.to('cuda')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Target sample rate: audio at any other rate is resampled to this in the next cell.\n", | |
| "SAMPLING_RATE = 16000\n", | |
| "language = \"en\"\n", | |
| "# Packed into exponential_decay_length_penalty=(regulation_start, regulation_factor) for\n", | |
| "# generate(); presumably a length penalty starting after token 140 with decay factor 1.01 - confirm.\n", | |
| "regulation_factor=1.01\n", | |
| "regulation_start=140\n", | |
| "device = 'cuda'" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([2, 36625491])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Load the file; dim 0 is treated as the channel axis below, sr is the file's native rate.\n", | |
| "input_speech, sr = torchaudio.load(audio_path)\n", | |
| "print(input_speech.shape)\n", | |
| "if input_speech.shape[0] > 1: # If stereo, average the channels\n", | |
| " input_speech = input_speech.mean(dim=0, keepdim=True)\n", | |
| "\n", | |
| "if sr != SAMPLING_RATE:\n", | |
| " input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)\n", | |
| "\n", | |
| "exponential_decay_length_penalty = (regulation_start, regulation_factor)\n", | |
| "\n", | |
| "# NOTE(review): the Whisper feature extractor appears to pad/truncate input to a fixed\n", | |
| "# 30 s window by default, so likely only the start of this long file is transcribed -\n", | |
| "# confirm; long-form audio would need chunking.\n", | |
| "input_features = processor(input_speech.squeeze(), return_tensors=\"pt\", sampling_rate=SAMPLING_RATE).input_features\n", | |
| "input_features = input_features.to(device)\n", | |
| "\n", | |
| "model_output = model.generate(\n", | |
| " input_features,\n", | |
| " language=language,\n", | |
| " exponential_decay_length_penalty=exponential_decay_length_penalty,\n", | |
| ")\n", | |
| "# Presumably a batch of one: take the first generated sequence and decode it to text.\n", | |
| "predict_ids = model_output[0]\n", | |
| "pred = processor.decode(predict_ids, skip_special_tokens=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'Is this your holding,, sir\"? It\\'s a marine towed, conan. Really, you should probably be holding onto it, and get used to this. What, w, o,, o, o, o, o, o, o, o, o. o, o, o, o, o, o, o, o. o, o, o, o, o, o, o, o. o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, o,'" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Show the decoded Whisper-Medusa transcription for a qualitative check.\n", | |
| "pred" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "audioenv", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.10" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment