@vvolhejn
Created August 11, 2025 09:19
Kyutai TTS connection via React hook
// For context, see: https://kyutai.org/next/tts
// and https://github.com/kyutai-labs/delayed-streams-modeling/issues/111
import { useEffect, useState, useRef, useCallback } from "react";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { encode, decode, decodeAsync } from "@msgpack/msgpack";
import ttsVoices from "@/assets/tts_voices.json";
const SAMPLE_RATE = 24000;
export function encodeWav(
  float32Array: Float32Array,
  sampleRate: number
): Blob {
  // Convert Float32 to 16-bit PCM
  const buffer = new ArrayBuffer(44 + float32Array.length * 2);
  const view = new DataView(buffer);
  // RIFF identifier 'RIFF'
  view.setUint32(0, 0x52494646, false);
  // file length minus RIFF identifier length and file description length
  view.setUint32(4, 36 + float32Array.length * 2, true);
  // RIFF type 'WAVE'
  view.setUint32(8, 0x57415645, false);
  // format chunk identifier 'fmt '
  view.setUint32(12, 0x666d7420, false);
  // format chunk length
  view.setUint32(16, 16, true);
  // sample format (raw)
  view.setUint16(20, 1, true);
  // channel count
  view.setUint16(22, 1, true);
  // sample rate
  view.setUint32(24, sampleRate, true);
  // byte rate (sample rate * block align)
  view.setUint32(28, sampleRate * 2, true);
  // block align (channel count * bytes per sample)
  view.setUint16(32, 2, true);
  // bits per sample
  view.setUint16(34, 16, true);
  // data chunk identifier 'data'
  view.setUint32(36, 0x64617461, false);
  // data chunk length
  view.setUint32(40, float32Array.length * 2, true);
  // Write PCM samples
  for (let i = 0; i < float32Array.length; i++) {
    const s = Math.max(-1, Math.min(1, float32Array[i]));
    view.setInt16(44 + i * 2, s < 0 ? s * 0x8000 : s * 0x7fff, true);
  }
  return new Blob([buffer], { type: "audio/wav" });
}
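
// For instance (illustrative, not in the original gist): one second of audio at
// 24 kHz is 24000 samples, so encodeWav(new Float32Array(24000), 24000) produces
// a 48044-byte blob: a 44-byte header plus 2 bytes per mono 16-bit sample.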
export type TTSTextMessage = {
  type: "Text";
  text: string;
  start_s: number;
  stop_s: number;
};
export type TTSAudioMessage = {
  type: "Audio";
  pcm: number[];
};
export type TTSErrorMessage = {
  type: "Error";
  message: string;
};
export type TTSReadyMessage = {
  type: "Ready";
};
export type TTSMessage =
  | TTSTextMessage
  | TTSAudioMessage
  | TTSErrorMessage
  | TTSReadyMessage;
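
// Rough message flow, inferred from the handler further below rather than from a
// spec: the server sends "Ready" once the stream is set up, then interleaved
// "Text" messages (word-level timestamps) and "Audio" messages (raw PCM chunks);
// "Error" carries a server-side failure message.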
export interface Voice {
  path_on_server: string;
  name?: string;
  default_text?: string;
  priority?: number; // Optional priority for sorting, higher first
}
export interface TtsStreamingQuery {
  seed?: number;
  temperature?: number;
  top_k?: number;
  format?: string; // default: "PcmMessagePack"
  voice?: string;
  voices?: string[];
  max_seq_len?: number;
  cfg_alpha?: number;
  auth_id?: string;
  // This is ignored by the server, but we use it to force a new connection
  request_id?: number;
}
export function ttsStreamingQueryToUrlParams(query: TtsStreamingQuery): string {
  const urlEscape = encodeURIComponent;
  const params = Object.entries(query)
    .filter((entry) => entry[1] !== undefined && entry[1] !== null)
    .map(([key, value]) => {
      if (Array.isArray(value)) {
        // Join array values as comma-separated
        return `${key}=${urlEscape(value.join(","))}`;
      }
      return `${key}=${urlEscape(value as string | number)}`;
    })
    .join("&");
  return params ? `?${params}` : "";
}
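
// Example (illustrative voice path, not a real server path):
//   ttsStreamingQueryToUrlParams({ voice: "expresso/speaker.wav", cfg_alpha: 1.5 })
// returns "?voice=expresso%2Fspeaker.wav&cfg_alpha=1.5"; an empty query returns "".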
async function decodeFromBlob(blob: Blob): Promise<unknown> {
  if (blob.stream) {
    return await decodeAsync(blob.stream());
  } else {
    // if stream() is not available
    return decode(await blob.arrayBuffer());
  }
}
export function useTts() {
  let voices = ttsVoices as Voice[];
  // Set a default priority for each voice based on its dataset (the first component of path_on_server)
  const priorityMap: Record<string, number> = {
    expresso: 10,
    "cml-tts": -1,
    vctk: -5,
    "unmute-prod-website": -7,
    ears: -10,
  };
  voices = voices.map((voice) => {
    const dataset = voice.path_on_server.split("/")[0];
    return {
      ...voice,
      priority: (voice.priority || 0) + (priorityMap[dataset] || 0), // Default to 0 if not found
    };
  });
  // Sort voices: higher priority first, then named voices before unnamed ones, then alphabetically
  voices.sort((a, b) => {
    // Sort by priority descending (higher priority first)
    const priorityA = a.priority ?? 0;
    const priorityB = b.priority ?? 0;
    if (priorityA !== priorityB) {
      return priorityB - priorityA;
    }
    if (a.name && !b.name) return -1;
    if (!a.name && b.name) return 1;
    // If both have names, sort alphabetically
    if (a.name && b.name) return a.name.localeCompare(b.name);
    return 0;
  });
  const [voice, setVoice] = useState<Voice>(voices[0]);
  const [requestId, setRequestId] = useState(0);
  const [currentTime, setCurrentTime] = useState(0);
  const [transcriptWithTimestamps, setTranscriptWithTimestamps] = useState<
    TTSTextMessage[]
  >([]);
  const query: TtsStreamingQuery = {
    voice: voice.path_on_server,
    cfg_alpha: 1.5,
    format: "PcmMessagePack",
    auth_id: "public_token",
    request_id: requestId,
  };
  // Replace with your server URL
  const baseUrl = "ws://localhost:8089/api/tts_streaming";
  const [text, setText] = useState<string | null>(null);
  const [shouldConnect, setShouldConnect] = useState(false);
  const { sendMessage, readyState, lastMessage } = useWebSocket(
    baseUrl + ttsStreamingQueryToUrlParams(query),
    {},
    shouldConnect
  );
  const pcmChunksRef = useRef<Float32Array[]>([]);
  // We need to keep track of the duration in a state variable so that
  // the timer that disables isMidPlayback is up to date
  const [totalDuration, setTotalDuration] = useState(0);
  const audioContextRef = useRef<AudioContext | null>(null);
  // Are we paused using the play/pause button?
  const [isPaused, setIsPaused] = useState(false);
  // Are we in the middle of playback? Also true if we are paused, but false when playback is done
  const [isMidPlayback, setIsMidPlayback] = useState(false);
  // Send text to generate when connection opens
  useEffect(() => {
    if (readyState === ReadyState.OPEN && text) {
      sendMessage(
        encode({
          type: "Text",
          text,
        })
      );
      sendMessage(encode({ type: "Eos" }));
    }
  }, [readyState, sendMessage, text]);
  // Sync AudioContext state with isPaused
  useEffect(() => {
    const audioCtx = audioContextRef.current;
    if (!audioCtx) return;
    if (!isPaused && audioCtx.state === "suspended") {
      audioCtx.resume();
    } else if (isPaused && audioCtx.state === "running") {
      audioCtx.suspend();
    }
  }, [isPaused]);
  const ttsPlay = (textToGenerate: string, voiceToGenerate: Voice) => {
    if (textToGenerate === text && voiceToGenerate === voice) {
      // Same text and voice: resume from pause, or replay the already-buffered audio
      if (isPaused) {
        setIsPaused(false);
      } else if (!isMidPlayback) {
        for (let i = 0; i < pcmChunksRef.current.length; i++) {
          queueUpChunk(i);
        }
      } else {
        console.error("ttsPlay called while playback is already in progress");
      }
    } else {
      // New text or voice: tear down the old connection and audio, then reconnect
      setShouldConnect(false);
      setVoice(voiceToGenerate);
      pcmChunksRef.current = []; // Reset PCM buffer
      setTranscriptWithTimestamps([]); // Reset text messages
      setTotalDuration(0);
      audioContextRef.current?.close();
      audioContextRef.current = null;
      setTimeout(() => {
        setRequestId((id) => id + 1); // force new connection
        setText(textToGenerate);
        setShouldConnect(true);
        setIsPaused(false);
      }, 0);
    }
  };
  const ttsPause = () => {
    setIsPaused(true);
  };
  // Clear all state related to the current connection
  const clearState = useCallback(() => {
    pcmChunksRef.current = [];
    setTranscriptWithTimestamps([]);
    setTotalDuration(0);
    audioContextRef.current?.close();
    audioContextRef.current = null;
    setCurrentTime(0);
    setIsMidPlayback(false);
    setIsPaused(true);
    setShouldConnect(false);
  }, []);
  // Schedule chunk i for playback on the shared AudioContext
  const queueUpChunk = useCallback(
    (i: number) => {
      if (i === 0) {
        // Only create the AudioContext when we start playing so that the start time is 0
        audioContextRef.current = new window.AudioContext();
        setIsMidPlayback(true);
      }
      if (!pcmChunksRef.current) return;
      const audioCtx = audioContextRef.current;
      if (!audioCtx) return;
      const buffer = audioCtx.createBuffer(
        1,
        pcmChunksRef.current[i].length,
        SAMPLE_RATE
      );
      buffer.getChannelData(0).set(pcmChunksRef.current[i]);
      const source = audioCtx.createBufferSource();
      source.buffer = buffer;
      source.connect(audioCtx.destination);
      const chunkDuration = buffer.duration;
      // Schedule right after the previous chunks; assumes all chunks have the same duration
      const startAt = i * chunkDuration;
      source.start(startAt);
    },
    [audioContextRef]
  );
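
  // For example (hypothetical chunk size): if the server sends 1920-sample chunks,
  // each chunk lasts 1920 / 24000 = 0.08 s, so chunk 3 is scheduled at
  // t = 3 * 0.08 = 0.24 s on the AudioContext clock.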
  // Handle incoming messages
  useEffect(() => {
    if (lastMessage === null) return;
    if (!shouldConnect) return; // Don't process messages from closed connections
    // We need async for decodeFromBlob
    const handleMessage = async () => {
      if (!(lastMessage.data instanceof Blob)) {
        console.error("Expected Blob data, but received:", lastMessage.data);
        return;
      }
      const data = (await decodeFromBlob(lastMessage.data)) as TTSMessage;
      switch (data.type) {
        case "Text":
          setTranscriptWithTimestamps((prev) => [
            ...prev,
            data as TTSTextMessage,
          ]);
          break;
        case "Audio":
          // Collect PCM data
          if (data.pcm && data.pcm.length > 0) {
            // Convert to Float32Array if not already
            const chunk = new Float32Array(data.pcm);
            pcmChunksRef.current.push(chunk);
            setTotalDuration((prev) => prev + chunk.length / SAMPLE_RATE);
            queueUpChunk(pcmChunksRef.current.length - 1);
          }
          break;
        case "Error":
          console.error("Received error message:", data.message);
          break;
        case "Ready":
          break;
        default:
          console.warn("Unknown message type:", data);
      }
    };
    handleMessage();
  }, [lastMessage, queueUpChunk, shouldConnect]);
  // Update isMidPlayback
  useEffect(() => {
    // Not currently playing -> no timer needed
    if (!isMidPlayback || isPaused || !audioContextRef.current) {
      return;
    }
    if (totalDuration < 1) {
      // Hack: prevent race condition where the audio gets immediately paused
      return;
    }
    const audioCtx = audioContextRef.current;
    const now = audioCtx.currentTime;
    const endTime = pcmChunksRef.current.reduce(
      (acc, chunk) => acc + chunk.length / SAMPLE_RATE,
      0
    );
    const remaining = Math.max(0, endTime - now);
    if (remaining === 0) {
      setIsMidPlayback(false);
      return;
    }
    const timeout = setTimeout(() => {
      setIsMidPlayback(false);
      audioContextRef.current?.close();
      audioContextRef.current = null;
    }, remaining * 1000);
    return () => clearTimeout(timeout);
  }, [isMidPlayback, totalDuration, isPaused]);
  // Track currentTime while playing
  useEffect(() => {
    let rafId: number;
    function updateTime() {
      const audioCtx = audioContextRef.current;
      if (audioCtx && isMidPlayback && !isPaused) {
        setCurrentTime(audioCtx.currentTime);
        rafId = requestAnimationFrame(updateTime);
      }
    }
    if (isMidPlayback && !isPaused) {
      rafId = requestAnimationFrame(updateTime);
    }
    // Do not reset currentTime on pause; only reset when playback ends
    if (!isMidPlayback) {
      setCurrentTime(0);
    }
    return () => {
      if (rafId) cancelAnimationFrame(rafId);
    };
  }, [isMidPlayback, isPaused]);
  // Concatenate all PCM chunks received so far into one Float32Array
  const getPcmData = useCallback(() => {
    if (!pcmChunksRef.current.length) return null;
    const totalLength = pcmChunksRef.current.reduce(
      (acc, arr) => acc + arr.length,
      0
    );
    const result = new Float32Array(totalLength);
    let offset = 0;
    for (const chunk of pcmChunksRef.current) {
      result.set(chunk, offset);
      offset += chunk.length;
    }
    return result;
  }, []);
  // Encode everything received so far as a downloadable WAV blob
  const getAudioBlob = useCallback(() => {
    const pcm = getPcmData();
    if (!pcm) return null;
    return encodeWav(pcm, SAMPLE_RATE);
  }, [getPcmData]);
  return {
    readyState,
    ttsPlay,
    ttsPause,
    isPaused,
    isMidPlayback,
    voices,
    currentTime,
    transcriptWithTimestamps,
    clearState,
    getAudioBlob,
    shouldConnect,
  };
}
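
// Illustrative usage sketch (not part of the original gist; the surrounding
// component and variable names are hypothetical):
//
//   const { ttsPlay, ttsPause, isPaused, isMidPlayback, voices, currentTime,
//           transcriptWithTimestamps, getAudioBlob } = useTts();
//
//   // Start generating and playing; calling it again with the same text and voice
//   // resumes from pause or replays the buffered audio.
//   ttsPlay("Hello from Kyutai TTS!", voices[0]);
//
//   // Pause by suspending the underlying AudioContext.
//   ttsPause();
//
//   // Once audio has been received, everything so far can be saved as a WAV file.
//   const blob = getAudioBlob();
//   if (blob) {
//     const url = URL.createObjectURL(blob);
//     // e.g. assign `url` to an <a download> element
//   }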