Kyutai TTS connection via React hook
// For context, see: https://kyutai.org/next/tts
// and https://github.com/kyutai-labs/delayed-streams-modeling/issues/111
import { useEffect, useState, useRef, useCallback } from "react";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { encode, decode, decodeAsync } from "@msgpack/msgpack";
import ttsVoices from "@/assets/tts_voices.json";

// The server streams 24 kHz mono PCM (see the "PcmMessagePack" format below)
const SAMPLE_RATE = 24000;
export function encodeWav(
  float32Array: Float32Array,
  sampleRate: number
): Blob {
  // Convert Float32 to 16-bit PCM
  const buffer = new ArrayBuffer(44 + float32Array.length * 2);
  const view = new DataView(buffer);

  // RIFF identifier 'RIFF'
  view.setUint32(0, 0x52494646, false);
  // file length minus RIFF identifier length and file description length
  view.setUint32(4, 36 + float32Array.length * 2, true);
  // RIFF type 'WAVE'
  view.setUint32(8, 0x57415645, false);
  // format chunk identifier 'fmt '
  view.setUint32(12, 0x666d7420, false);
  // format chunk length
  view.setUint32(16, 16, true);
  // sample format (1 = uncompressed PCM)
  view.setUint16(20, 1, true);
  // channel count
  view.setUint16(22, 1, true);
  // sample rate
  view.setUint32(24, sampleRate, true);
  // byte rate (sample rate * block align)
  view.setUint32(28, sampleRate * 2, true);
  // block align (channel count * bytes per sample)
  view.setUint16(32, 2, true);
  // bits per sample
  view.setUint16(34, 16, true);
  // data chunk identifier 'data'
  view.setUint32(36, 0x64617461, false);
  // data chunk length
  view.setUint32(40, float32Array.length * 2, true);

  // Write PCM samples, clamped to [-1, 1]
  for (let i = 0; i < float32Array.length; i++) {
    const s = Math.max(-1, Math.min(1, float32Array[i]));
    view.setInt16(44 + i * 2, s < 0 ? s * 0x8000 : s * 0x7fff, true);
  }

  return new Blob([buffer], { type: "audio/wav" });
}
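
// Example usage (a browser-only sketch; the sine wave and download wiring are
// illustrative, not part of the TTS flow): encode one second of a 440 Hz tone
// and turn it into an object URL.
//
//   const tone = new Float32Array(SAMPLE_RATE);
//   for (let i = 0; i < tone.length; i++) {
//     tone[i] = 0.5 * Math.sin((2 * Math.PI * 440 * i) / SAMPLE_RATE);
//   }
//   const url = URL.createObjectURL(encodeWav(tone, SAMPLE_RATE));
//   // e.g. set `url` as the href of an <a download="tone.wav"> element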
export type TTSTextMessage = {
  type: "Text";
  text: string;
  start_s: number;
  stop_s: number;
};

export type TTSAudioMessage = {
  type: "Audio";
  pcm: number[];
};

export type TTSErrorMessage = {
  type: "Error";
  message: string;
};

export type TTSReadyMessage = {
  type: "Ready";
};

export type TTSMessage =
  | TTSTextMessage
  | TTSAudioMessage
  | TTSErrorMessage
  | TTSReadyMessage;

export interface Voice {
  path_on_server: string;
  name?: string;
  default_text?: string;
  priority?: number; // Optional priority for sorting, higher first
}

export interface TtsStreamingQuery {
  seed?: number;
  temperature?: number;
  top_k?: number;
  format?: string; // default: "PcmMessagePack"
  voice?: string;
  voices?: string[];
  max_seq_len?: number;
  cfg_alpha?: number;
  auth_id?: string;
  // This is ignored by the server, but we use it to force a new connection
  request_id?: number;
}
export function ttsStreamingQueryToUrlParams(query: TtsStreamingQuery): string {
  const urlEscape = encodeURIComponent;
  const params = Object.entries(query)
    .filter((entry) => entry[1] !== undefined && entry[1] !== null)
    .map(([key, value]) => {
      if (Array.isArray(value)) {
        // Join array values as comma-separated
        return `${key}=${urlEscape(value.join(","))}`;
      }
      return `${key}=${urlEscape(value as string | number)}`;
    })
    .join("&");
  return params ? `?${params}` : "";
}
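
// Example (with a hypothetical voice path):
//   ttsStreamingQueryToUrlParams({ voice: "expresso/speaker.wav", cfg_alpha: 1.5 })
// returns "?voice=expresso%2Fspeaker.wav&cfg_alpha=1.5". Array values are
// comma-joined, so { voices: ["a", "b"] } becomes "?voices=a%2Cb".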
async function decodeFromBlob(blob: Blob): Promise<unknown> {
  if (blob.stream) {
    return await decodeAsync(blob.stream());
  } else {
    // Fall back to decoding the whole buffer if stream() is not available
    return decode(await blob.arrayBuffer());
  }
}
export function useTts() {
  let voices = ttsVoices as Voice[];

  // Set a default priority for each voice based on its dataset
  // (the first component of path_on_server)
  const priorityMap: Record<string, number> = {
    expresso: 10,
    "cml-tts": -1,
    vctk: -5,
    "unmute-prod-website": -7,
    ears: -10,
  };
  voices = voices.map((voice) => {
    const dataset = voice.path_on_server.split("/")[0];
    return {
      ...voice,
      priority: (voice.priority || 0) + (priorityMap[dataset] || 0), // Default to 0 if not found
    };
  });

  // Sort by priority, then put the voices that have a name first
  voices.sort((a, b) => {
    // Sort by priority descending (higher priority first)
    const priorityA = a.priority ?? 0;
    const priorityB = b.priority ?? 0;
    if (priorityA !== priorityB) {
      return priorityB - priorityA;
    }
    if (a.name && !b.name) return -1;
    if (!a.name && b.name) return 1;
    // If both have names, sort alphabetically
    if (a.name && b.name) return a.name.localeCompare(b.name);
    return 0;
  });

  const [voice, setVoice] = useState<Voice>(voices[0]);
  const [requestId, setRequestId] = useState(0);
  const [currentTime, setCurrentTime] = useState(0);
  const [transcriptWithTimestamps, setTranscriptWithTimestamps] = useState<
    TTSTextMessage[]
  >([]);

  const query: TtsStreamingQuery = {
    voice: voice.path_on_server,
    cfg_alpha: 1.5,
    format: "PcmMessagePack",
    auth_id: "public_token",
    request_id: requestId,
  };
  // Replace with your server URL
  const baseUrl = "ws://localhost:8089/api/tts_streaming";
  const [text, setText] = useState<string | null>(null);
  const [shouldConnect, setShouldConnect] = useState(false);
  const { sendMessage, readyState, lastMessage } = useWebSocket(
    baseUrl + ttsStreamingQueryToUrlParams(query),
    {},
    shouldConnect
  );
  const pcmChunksRef = useRef<Float32Array[]>([]);
  // We need to keep track of the duration in a state variable so that
  // the timer that disables isMidPlayback is up to date
  const [totalDuration, setTotalDuration] = useState(0);
  const audioContextRef = useRef<AudioContext | null>(null);
  // Are we paused using the play/pause button?
  const [isPaused, setIsPaused] = useState(false);
  // Are we in the middle of playback? Also true if we are paused,
  // but false when playback is done
  const [isMidPlayback, setIsMidPlayback] = useState(false);

  // Send text to generate when connection opens
  useEffect(() => {
    if (readyState === ReadyState.OPEN && text) {
      sendMessage(
        encode({
          type: "Text",
          text,
        })
      );
      sendMessage(encode({ type: "Eos" }));
    }
  }, [readyState, sendMessage, text]);

  // Sync AudioContext state with isPaused
  useEffect(() => {
    const audioCtx = audioContextRef.current;
    if (!audioCtx) return;
    if (!isPaused && audioCtx.state === "suspended") {
      audioCtx.resume();
    } else if (isPaused && audioCtx.state === "running") {
      audioCtx.suspend();
    }
  }, [isPaused]);
  const ttsPlay = (textToGenerate: string, voiceToGenerate: Voice) => {
    if (textToGenerate === text && voiceToGenerate === voice) {
      if (isPaused) {
        setIsPaused(false);
      } else if (!isMidPlayback) {
        // Replay the audio we already have buffered
        for (let i = 0; i < pcmChunksRef.current.length; i++) {
          queueUpChunk(i);
        }
      } else {
        console.error("Got into a weird state");
      }
    } else {
      setShouldConnect(false);
      setVoice(voiceToGenerate);
      pcmChunksRef.current = []; // Reset PCM buffer
      setTranscriptWithTimestamps([]); // Reset text messages
      setTotalDuration(0);
      audioContextRef.current?.close();
      audioContextRef.current = null;
      setTimeout(() => {
        setRequestId((id) => id + 1); // force new connection
        setText(textToGenerate);
        setShouldConnect(true);
        setIsPaused(false);
      }, 0);
    }
  };
  const ttsPause = () => {
    setIsPaused(true);
  };

  // Clear all state related to the current connection
  const clearState = useCallback(() => {
    pcmChunksRef.current = [];
    setTranscriptWithTimestamps([]);
    setTotalDuration(0);
    audioContextRef.current?.close();
    audioContextRef.current = null;
    setCurrentTime(0);
    setIsMidPlayback(false);
    setIsPaused(true);
    setShouldConnect(false);
  }, []);

  // Schedule chunk i of pcmChunksRef for playback
  const queueUpChunk = useCallback(
    (i: number) => {
      if (i === 0) {
        // Only create the AudioContext when we start playing so that the start time is 0
        audioContextRef.current = new window.AudioContext();
        setIsMidPlayback(true);
      }
      if (!pcmChunksRef.current[i]) return;
      const audioCtx = audioContextRef.current;
      if (!audioCtx) return;
      const buffer = audioCtx.createBuffer(
        1,
        pcmChunksRef.current[i].length,
        SAMPLE_RATE
      );
      buffer.getChannelData(0).set(pcmChunksRef.current[i]);
      const source = audioCtx.createBufferSource();
      source.buffer = buffer;
      source.connect(audioCtx.destination);
      // Start right after all previously queued audio. Summing the previous
      // chunk lengths (rather than multiplying the index by this chunk's
      // duration) stays gapless even if chunk sizes vary.
      const startAt = pcmChunksRef.current
        .slice(0, i)
        .reduce((acc, chunk) => acc + chunk.length / SAMPLE_RATE, 0);
      source.start(startAt);
    },
    [audioContextRef]
  );
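
  // Worked example of the scheduling above: with (hypothetical) 1920-sample
  // chunks at 24 kHz, each chunk lasts 1920 / 24000 = 0.08 s, so chunk 0
  // starts at 0 s, chunk 1 at 0.08 s, chunk 2 at 0.16 s, and so on; a shorter
  // final chunk still starts exactly where the previous one ends.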
  // Handle incoming messages
  useEffect(() => {
    if (lastMessage === null) return;
    if (!shouldConnect) return; // Don't process messages from closed connections

    // We need async for decodeFromBlob
    const handleMessage = async () => {
      if (!(lastMessage.data instanceof Blob)) {
        console.error("Expected Blob data, but received:", lastMessage.data);
        return;
      }
      const data = (await decodeFromBlob(lastMessage.data)) as TTSMessage;
      switch (data.type) {
        case "Text":
          setTranscriptWithTimestamps((prev) => [
            ...prev,
            data as TTSTextMessage,
          ]);
          break;
        case "Audio":
          // Collect PCM data
          if (data.pcm && data.pcm.length > 0) {
            // Convert to Float32Array if not already
            const chunk = new Float32Array(data.pcm);
            pcmChunksRef.current.push(chunk);
            setTotalDuration((prev) => prev + chunk.length / SAMPLE_RATE);
            queueUpChunk(pcmChunksRef.current.length - 1);
          }
          break;
        case "Error":
          console.error("Received error message:", data.message);
          break;
        case "Ready":
          break;
        default:
          console.warn("Unknown message type:", data);
      }
    };
    handleMessage();
  }, [lastMessage, queueUpChunk, shouldConnect]);
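
  // Typical message flow seen by the handler above (inferred from the message
  // types, not a protocol spec): one "Ready" after connecting, then
  // interleaved "Text" (word timestamps) and "Audio" (PCM) messages as
  // generation proceeds; "Error" may arrive at any point.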
  // Update isMidPlayback
  useEffect(() => {
    // Not currently playing -> no timer needed
    if (!isMidPlayback || isPaused || !audioContextRef.current) {
      return;
    }
    if (totalDuration < 1) {
      // Hack: prevent race condition where the audio gets immediately paused
      return;
    }
    const audioCtx = audioContextRef.current;
    const now = audioCtx.currentTime;
    const endTime = pcmChunksRef.current.reduce(
      (acc, chunk) => acc + chunk.length / SAMPLE_RATE,
      0
    );
    const remaining = Math.max(0, endTime - now);
    if (remaining === 0) {
      setIsMidPlayback(false);
      return;
    }
    const timeout = setTimeout(() => {
      setIsMidPlayback(false);
      audioContextRef.current?.close();
      audioContextRef.current = null;
    }, remaining * 1000);
    return () => clearTimeout(timeout);
  }, [isMidPlayback, totalDuration, isPaused]);
  // Track currentTime while playing
  useEffect(() => {
    let rafId: number;
    function updateTime() {
      const audioCtx = audioContextRef.current;
      if (audioCtx && isMidPlayback && !isPaused) {
        setCurrentTime(audioCtx.currentTime);
        rafId = requestAnimationFrame(updateTime);
      }
    }
    if (isMidPlayback && !isPaused) {
      rafId = requestAnimationFrame(updateTime);
    }
    // Do not reset currentTime on pause; only reset when playback ends
    if (!isMidPlayback) {
      setCurrentTime(0);
    }
    return () => {
      if (rafId) cancelAnimationFrame(rafId);
    };
  }, [isMidPlayback, isPaused]);
  // Concatenate all received PCM chunks into a single Float32Array
  const getPcmData = useCallback(() => {
    if (!pcmChunksRef.current.length) return null;
    const totalLength = pcmChunksRef.current.reduce(
      (acc, arr) => acc + arr.length,
      0
    );
    const result = new Float32Array(totalLength);
    let offset = 0;
    for (const chunk of pcmChunksRef.current) {
      result.set(chunk, offset);
      offset += chunk.length;
    }
    return result;
  }, []);

  // Encode everything received so far as a downloadable WAV blob
  const getAudioBlob = useCallback(() => {
    const pcm = getPcmData();
    if (!pcm) return null;
    return encodeWav(pcm, SAMPLE_RATE);
  }, [getPcmData]);

  return {
    readyState,
    ttsPlay,
    ttsPause,
    isPaused,
    isMidPlayback,
    voices,
    currentTime,
    transcriptWithTimestamps,
    clearState,
    getAudioBlob,
    shouldConnect,
  };
}
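
// Minimal usage sketch (a hypothetical component; assumes a .tsx module and a
// Kyutai TTS server reachable at the baseUrl above):
//
//   function TtsDemo() {
//     const { ttsPlay, ttsPause, isPaused, isMidPlayback, voices, currentTime } =
//       useTts();
//     return (
//       <div>
//         <button onClick={() => ttsPlay("Hello from Kyutai TTS!", voices[0])}>
//           Play
//         </button>
//         <button onClick={ttsPause} disabled={!isMidPlayback || isPaused}>
//           Pause
//         </button>
//         <p>{currentTime.toFixed(1)} s</p>
//       </div>
//     );
//   }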