Created October 22, 2024 20:51
-
-
Save usmanm/5dd8ee59c18b290333254435d5e73927 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| g++ -O3 -Wall -shared -std=c++11 \ | |
| -undefined dynamic_lookup \ | |
| $(python3 -m pybind11 --includes) \ | |
| -Isdk/macos-arm/include \ | |
| krisp_processor.cpp \ | |
| sdk/macos-arm/lib/libkrisp-audio-sdk.a \ | |
| sdk/macos-arm/external/libresample.a \ | |
| -o krisp_processor$(python3-config --extension-suffix) | |
| */ | |
#include <algorithm>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <krisp-audio-sdk.hpp>
#include <krisp-audio-sdk-nc.hpp>

using namespace Krisp::AudioSdk;
| static std::pair<SamplingRate, bool> getKrispSamplingRate(uint32_t rate) | |
| { | |
| std::pair<SamplingRate, bool> result; | |
| result.second = true; | |
| switch (rate) | |
| { | |
| case 8000: | |
| result.first = SamplingRate::Sr8000Hz; | |
| break; | |
| case 16000: | |
| result.first = SamplingRate::Sr16000Hz; | |
| break; | |
| case 32000: | |
| result.first = SamplingRate::Sr32000Hz; | |
| break; | |
| case 44100: | |
| result.first = SamplingRate::Sr44100Hz; | |
| break; | |
| case 48000: | |
| result.first = SamplingRate::Sr48000Hz; | |
| break; | |
| case 88200: | |
| result.first = SamplingRate::Sr88200Hz; | |
| break; | |
| case 96000: | |
| result.first = SamplingRate::Sr96000Hz; | |
| break; | |
| } | |
| return result; | |
| } | |
| namespace py = pybind11; | |
| class KrispProcessor | |
| { | |
| public: | |
| KrispProcessor(uint32_t sampleRate, uint32_t channels, const std::wstring &weight) : _weight(weight) | |
| { | |
| auto out = getKrispSamplingRate(sampleRate); | |
| if (!out.second) | |
| { | |
| throw std::runtime_error("unsupported sample rate"); | |
| } | |
| const SamplingRate inRate = out.first; | |
| const SamplingRate outRate = inRate; | |
| constexpr FrameDuration frameDurationMillis = FrameDuration::Fd10ms; | |
| _frame_size = (sampleRate * static_cast<uint32_t>(frameDurationMillis)) / 1000; | |
| _frame_length = _frame_size * channels; | |
| globalInit(L""); | |
| ModelInfo ncModelInfo; | |
| ncModelInfo.path = _weight; | |
| NcSessionConfig ncCfg; | |
| ncCfg.inputSampleRate = inRate; | |
| ncCfg.inputFrameDuration = frameDurationMillis; | |
| ncCfg.outputSampleRate = outRate; | |
| ncCfg.modelInfo = &ncModelInfo; | |
| ncCfg.enableSessionStats = false; | |
| _session = Nc<int16_t>::create(ncCfg); | |
| _frame_buffer.resize(_frame_length); | |
| } | |
| ~KrispProcessor() | |
| { | |
| _session.reset(); | |
| globalDestroy(); | |
| } | |
| void store_audio_chunk(const py::array_t<short> &audio_chunk) | |
| { | |
| py::buffer_info info = audio_chunk.request(); | |
| const short *chunk_ptr = static_cast<short *>(info.ptr); | |
| size_t chunk_size = static_cast<size_t>(info.size); | |
| _audio_data.resize(chunk_size + _remaining_sample_count); | |
| std::memcpy(_audio_data.data() + _remaining_sample_count * sizeof(short), | |
| static_cast<const void *>(chunk_ptr), | |
| chunk_size * sizeof(short)); | |
| _remaining_sample_count = 0; | |
| } | |
| size_t get_samples_count() | |
| { | |
| return _audio_data.size(); | |
| } | |
| unsigned get_processed_frames(py::array_t<short> &python_output_frames) | |
| { | |
| py::buffer_info buf_info = python_output_frames.request(); | |
| short *output_ptr = static_cast<short *>(buf_info.ptr); | |
| size_t buffer_frame_count = static_cast<size_t>(buf_info.size) / _frame_length; | |
| size_t audio_frame_count = _audio_data.size() / _frame_length; | |
| if (buffer_frame_count < audio_frame_count) | |
| { | |
| throw std::runtime_error("buffer is too small"); | |
| } | |
| _remaining_sample_count = _audio_data.size() % _frame_length; | |
| unsigned processed_frames = 0; | |
| auto frame_start_it = _audio_data.begin(); | |
| auto frame_end_it = _audio_data.begin(); | |
| for (unsigned i = 0; i < audio_frame_count; ++i) | |
| { | |
| std::advance(frame_end_it, _frame_length); | |
| _session->process( | |
| &(*frame_start_it), | |
| _frame_length, | |
| _frame_buffer.data(), | |
| _frame_length, | |
| 100, | |
| nullptr); | |
| std::copy(_frame_buffer.begin(), _frame_buffer.end(), output_ptr + i * _frame_length); | |
| frame_start_it = frame_end_it; | |
| ++processed_frames; | |
| } | |
| if (_remaining_sample_count) | |
| { | |
| std::copy(frame_end_it, frame_end_it + static_cast<long>(_remaining_sample_count), _audio_data.begin()); | |
| } | |
| return processed_frames; | |
| } | |
| private: | |
| std::wstring _weight; | |
| uint32_t _frame_size; | |
| uint32_t _frame_length; | |
| uint64_t _remaining_sample_count = 0; | |
| std::shared_ptr<Nc<int16_t>> _session; | |
| std::vector<short> _audio_data; | |
| std::vector<short> _frame_buffer; | |
| }; | |
| PYBIND11_MODULE(krisp_processor, m) | |
| { | |
| py::class_<KrispProcessor>(m, "KrispProcessor") | |
| .def(py::init<unsigned, unsigned, std::wstring>()) | |
| .def("store_audio_chunk", &KrispProcessor::store_audio_chunk) | |
| .def("get_processed_frames", &KrispProcessor::get_processed_frames) | |
| .def("get_samples_count", &KrispProcessor::get_samples_count); | |
| } |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.