Created
January 1, 2026 20:14
-
-
Save sergigp/c457948bf8307240449d586d79aadc1d to your computer and use it in GitHub Desktop.
Patch for react-native-sherpa-onnx-offline-tts to add a generate() method (saves speech to a WAV file) and a stopPlaying() JS wrapper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/ios/SherpaOnnxOfflineTts.mm b/ios/SherpaOnnxOfflineTts.mm | |
| index 1234567..abcdefg 100644 | |
| --- a/ios/SherpaOnnxOfflineTts.mm | |
| +++ b/ios/SherpaOnnxOfflineTts.mm | |
| @@ -5,6 +5,13 @@ | |
| // Initialize method exposed to React Native | |
| RCT_EXTERN_METHOD(initializeTTS:(double)sampleRate channels:(NSInteger)channels modelId:(NSString *)modelId) | |
| +// Generate method exposed to React Native (saves to WAV file) | |
| +// Promise-based bridge declaration: the Swift implementation in this patch resolves | |
| +// with the path of the written WAV file, or rejects with an error code and message. | |
| +RCT_EXTERN_METHOD(generate:(NSString *)text | |
| + sid:(NSInteger)sid | |
| + speed:(double)speed | |
| + resolver:(RCTPromiseResolveBlock)resolver | |
| + rejecter:(RCTPromiseRejectBlock)rejecter) | |
| + | |
| // Generate and Play method exposed to React Native | |
| RCT_EXTERN_METHOD(generateAndPlay:(NSString *)text | |
| sid:(NSInteger)sid | |
| diff --git a/ios/SherpaOnnxOfflineTts.swift b/ios/SherpaOnnxOfflineTts.swift | |
| index 1234567..abcdefg 100644 | |
| --- a/ios/SherpaOnnxOfflineTts.swift | |
| +++ b/ios/SherpaOnnxOfflineTts.swift | |
| @@ -113,6 +113,76 @@ class TTSManager: RCTEventEmitter, AudioPlayerDelegate { | |
| realTimeAudioPlayer?.playAudioData(from: audio) | |
| } | |
| + /// Synthesizes speech for `text`, saves it as a WAV file, and settles the promise. | |
| + /// | |
| + /// - Parameters: | |
| + /// - text: Text to synthesize; surrounding whitespace and newlines are trimmed first. | |
| + /// - sid: Forwarded unchanged to `tts.generate` (presumably the speaker id - confirm against the TTS wrapper). | |
| + /// - speed: Forwarded to `tts.generate` after conversion to `Float`. | |
| + /// - resolver: Called with the path returned by `saveToWavFile(audio:)` on success. | |
| + /// - rejecter: Called with `EMPTY_TEXT` when the trimmed text is empty, `NOT_INITIALIZED` | |
| + /// when `self.tts` is nil, or `WAV_SAVE_ERROR` when saving the file throws. | |
| + @objc(generate:sid:speed:resolver:rejecter:) | |
| + func generate(_ text: String, sid: Int, speed: Double, resolver: @escaping RCTPromiseResolveBlock, rejecter: @escaping RCTPromiseRejectBlock) { | |
| + let input = text.trimmingCharacters(in: .whitespacesAndNewlines) | |
| + if input.isEmpty { | |
| + rejecter("EMPTY_TEXT", "Input text is empty", nil) | |
| + return | |
| + } | |
| + | |
| + guard let engine = self.tts else { | |
| + rejecter("NOT_INITIALIZED", "TTS was never initialized", nil) | |
| + return | |
| + } | |
| + | |
| + // Time the synthesis call so slow generations show up in the log. | |
| + print("Generating audio for text length: \(input.count)") | |
| + let startedAt = Date() | |
| + let audio = engine.generate(text: input, sid: sid, speed: Float(speed)) | |
| + let elapsed = Date().timeIntervalSince(startedAt) | |
| + print("Time taken for TTS generation: \(elapsed) seconds") | |
| + | |
| + // Persist the audio; any failure is reported through the promise rejecter. | |
| + let savedPath: String | |
| + do { | |
| + savedPath = try saveToWavFile(audio: audio) | |
| + } catch { | |
| + rejecter("WAV_SAVE_ERROR", "Failed to save WAV file: \(error.localizedDescription)", error) | |
| + return | |
| + } | |
| + | |
| + print("WAV file saved to: \(savedPath)") | |
| + resolver(savedPath) | |
| + } | |
| + | |
| + /// Writes generated audio to a uniquely named 16-bit PCM mono WAV file in the temporary directory. | |
| + /// | |
| + /// - Parameter audio: Generated audio. Its `samples` are clamped to [-1, 1] and scaled to | |
| + /// `Int16`; its `sampleRate` becomes the sample rate of the file. | |
| + /// - Returns: Filesystem path of the written WAV file. | |
| + /// - Throws: An `NSError` in the `SherpaOnnxOfflineTts` domain when the audio has no samples | |
| + /// or the PCM format/buffer cannot be created, or whatever error `AVAudioFile` raises | |
| + /// while creating or writing the file. | |
| + private func saveToWavFile(audio: SherpaOnnxGeneratedAudioWrapper) throws -> String { | |
| + // Builds the error surfaced to JS through `localizedDescription`. | |
| + func failure(_ message: String) -> NSError { | |
| + return NSError(domain: "SherpaOnnxOfflineTts", code: -1, userInfo: [NSLocalizedDescriptionKey: message]) | |
| + } | |
| + | |
| + let sampleRate = Double(audio.sampleRate) | |
| + | |
| + // Convert Float samples to Int16 PCM, clamping so out-of-range values cannot trap the conversion. | |
| + let pcmSamples: [Int16] = audio.samples.map { sample in | |
| + let clampedSample = max(-1.0, min(1.0, sample)) | |
| + return Int16(clampedSample * 32767.0) | |
| + } | |
| + | |
| + guard !pcmSamples.isEmpty else { | |
| + throw failure("Generated audio contains no samples") | |
| + } | |
| + | |
| + // These initializers are failable (e.g. for an invalid sample rate); throw instead of crashing. | |
| + guard let audioFormat = AVAudioFormat(commonFormat: .pcmFormatInt16, | |
| + sampleRate: sampleRate, | |
| + channels: 1, | |
| + interleaved: false), | |
| + let buffer = AVAudioPCMBuffer(pcmFormat: audioFormat, frameCapacity: UInt32(pcmSamples.count)), | |
| + let channelData = buffer.int16ChannelData else { | |
| + throw failure("Could not create a 16-bit PCM buffer for sample rate \(sampleRate)") | |
| + } | |
| + | |
| + buffer.frameLength = UInt32(pcmSamples.count) | |
| + for (index, value) in pcmSamples.enumerated() { | |
| + channelData[0][index] = value | |
| + } | |
| + | |
| + // Create a unique temporary file URL. | |
| + let wavURL = URL(fileURLWithPath: NSTemporaryDirectory()) | |
| + .appendingPathComponent("sherpa-\(UUID().uuidString).wav") | |
| + | |
| + // The processing format must match the buffer passed to write(from:). The plain | |
| + // init(forWriting:settings:) uses a Float32 processing format, which does not match | |
| + // this Int16 buffer, so request Int16 explicitly. | |
| + let audioFile = try AVAudioFile(forWriting: wavURL, | |
| + settings: audioFormat.settings, | |
| + commonFormat: .pcmFormatInt16, | |
| + interleaved: false) | |
| + | |
| + do { | |
| + try audioFile.write(from: buffer) | |
| + } catch { | |
| + // Do not leave a truncated file behind in the temporary directory. | |
| + try? FileManager.default.removeItem(at: wavURL) | |
| + throw error | |
| + } | |
| + | |
| + return wavURL.path | |
| + } | |
| + | |
| // Clean up resources | |
| @objc func deinitialize() { | |
| self.realTimeAudioPlayer?.stop() | |
| diff --git a/src/index.tsx b/src/index.tsx | |
| index 1234567..abcdefg 100644 | |
| --- a/src/index.tsx | |
| +++ b/src/index.tsx | |
| @@ -9,6 +9,15 @@ const initialize = (modelId: string) => { | |
| TTSManager.initializeTTS(22050, 1, modelId); | |
| }; | |
| +/** | |
| + * Asks the native module to synthesize `text` and resolves with whatever | |
| + * `TTSManager.generate` resolves to (the native iOS side resolves with a WAV file path). | |
| + * Any failure is logged with `console.error` and then re-thrown to the caller. | |
| + */ | |
| +const generate = async (text: any, sid: any, speed: any) => { | |
| + try { | |
| + return await TTSManager.generate(text, sid, speed); | |
| + } catch (err) { | |
| + console.error(err); | |
| + throw err; | |
| + } | |
| +}; | |
| + | |
| const generateAndPlay = async (text: any, sid: any, speed: any) => { | |
| try { | |
| const result = await TTSManager.generateAndPlay(text, sid, speed); | |
| @@ -18,6 +27,10 @@ const generateAndPlay = async (text: any, sid: any, speed: any) => { | |
| } | |
| }; | |
| +/** | |
| + * Forwards to the native `TTSManager.stopPlaying` and returns nothing. | |
| + * NOTE(review): the native `stopPlaying` export is not part of this patch - confirm it exists. | |
| + */ | |
| +const stopPlaying = (): void => { | |
| + TTSManager.stopPlaying(); | |
| +}; | |
| + | |
| const deinitialize = () => { | |
| TTSManager.deinitialize(); | |
| }; | |
| @@ -34,7 +47,9 @@ const addVolumeListener = (callback: any) => { | |
| export default { | |
| initialize, | |
| + generate, | |
| generateAndPlay, | |
| + stopPlaying, | |
| deinitialize, | |
| addVolumeListener, | |
| }; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment