Skip to content

Instantly share code, notes, and snippets.

@usmanm
Created October 22, 2024 20:51
Show Gist options
  • Select an option

  • Save usmanm/5dd8ee59c18b290333254435d5e73927 to your computer and use it in GitHub Desktop.

Select an option

Save usmanm/5dd8ee59c18b290333254435d5e73927 to your computer and use it in GitHub Desktop.
/*
g++ -O3 -Wall -shared -std=c++11 \
-undefined dynamic_lookup \
$(python3 -m pybind11 --includes) \
-Isdk/macos-arm/include \
krisp_processor.cpp \
sdk/macos-arm/lib/libkrisp-audio-sdk.a \
sdk/macos-arm/external/libresample.a \
-o krisp_processor$(python3-config --extension-suffix)
*/
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <krisp-audio-sdk.hpp>
#include <krisp-audio-sdk-nc.hpp>
using namespace Krisp::AudioSdk;
static std::pair<SamplingRate, bool> getKrispSamplingRate(uint32_t rate)
{
std::pair<SamplingRate, bool> result;
result.second = true;
switch (rate)
{
case 8000:
result.first = SamplingRate::Sr8000Hz;
break;
case 16000:
result.first = SamplingRate::Sr16000Hz;
break;
case 32000:
result.first = SamplingRate::Sr32000Hz;
break;
case 44100:
result.first = SamplingRate::Sr44100Hz;
break;
case 48000:
result.first = SamplingRate::Sr48000Hz;
break;
case 88200:
result.first = SamplingRate::Sr88200Hz;
break;
case 96000:
result.first = SamplingRate::Sr96000Hz;
break;
}
return result;
}
namespace py = pybind11;
/// Buffers 16-bit samples handed over from Python and feeds them, one fixed-size
/// frame at a time, through a Krisp `Nc<int16_t>` session.
///
/// Samples that do not fill a whole frame stay buffered and are prepended to the
/// data delivered by the next store_audio_chunk() call.
///
/// NOTE(review): every instance calls globalInit() in its constructor and
/// globalDestroy() in its destructor. Whether the SDK tolerates several live
/// instances (or a constructor that throws after globalInit()) is not visible
/// from this file — confirm against the SDK documentation.
class KrispProcessor
{
public:
    /// Creates the session for the given stream layout.
    ///
    /// @param sampleRate Sampling rate in Hz; must be accepted by getKrispSamplingRate().
    /// @param channels   Number of channels; used only to scale the frame length.
    /// @param weight     Path of the model file handed to the SDK via `ModelInfo::path`.
    /// @throws std::runtime_error if the sample rate is reported as unsupported
    ///         or `channels` is zero.
    KrispProcessor(uint32_t sampleRate, uint32_t channels, const std::wstring &weight) : _weight(weight)
    {
        auto out = getKrispSamplingRate(sampleRate);
        if (!out.second)
        {
            throw std::runtime_error("unsupported sample rate");
        }
        // A zero channel count would make _frame_length zero and turn the frame
        // count computation in get_processed_frames() into a division by zero.
        if (channels == 0)
        {
            throw std::runtime_error("channel count must be non-zero");
        }
        const SamplingRate inRate = out.first;
        const SamplingRate outRate = inRate; // no resampling: output rate equals input rate
        constexpr FrameDuration frameDurationMillis = FrameDuration::Fd10ms;
        // Relies on the enumerator's numeric value being the duration in
        // milliseconds — TODO confirm against the SDK header.
        _frame_size = (sampleRate * static_cast<uint32_t>(frameDurationMillis)) / 1000;
        _frame_length = _frame_size * channels;

        globalInit(L"");

        ModelInfo ncModelInfo;
        ncModelInfo.path = _weight;

        NcSessionConfig ncCfg;
        ncCfg.inputSampleRate = inRate;
        ncCfg.inputFrameDuration = frameDurationMillis;
        ncCfg.outputSampleRate = outRate;
        ncCfg.modelInfo = &ncModelInfo;
        ncCfg.enableSessionStats = false;

        _session = Nc<int16_t>::create(ncCfg);
        _frame_buffer.resize(_frame_length);
    }

    // Non-copyable: the destructor calls globalDestroy(), so an implicit copy
    // would run that teardown a second time.
    KrispProcessor(const KrispProcessor &) = delete;
    KrispProcessor &operator=(const KrispProcessor &) = delete;

    /// Releases the session first, then tears down the SDK's global state.
    ~KrispProcessor()
    {
        _session.reset();
        globalDestroy();
    }

    /// Appends the samples of `audio_chunk` to the internal buffer.
    ///
    /// Previously buffered samples (including the partial frame left over by the
    /// last get_processed_frames() call) are kept in front of the new data.
    ///
    /// @param audio_chunk Array of 16-bit samples. The data is read linearly
    ///        from the buffer pointer, i.e. the array is assumed to be
    ///        contiguous — this is not checked here.
    void store_audio_chunk(const py::array_t<short> &audio_chunk)
    {
        py::buffer_info info = audio_chunk.request();
        const short *chunk_ptr = static_cast<const short *>(info.ptr);
        const size_t chunk_size = static_cast<size_t>(info.size);
        if (chunk_size == 0)
        {
            return;
        }
        // Offsets are counted in samples (elements), not bytes.
        _audio_data.insert(_audio_data.end(), chunk_ptr, chunk_ptr + chunk_size);
    }

    /// @return Number of samples currently buffered and not yet processed.
    size_t get_samples_count() const
    {
        return _audio_data.size();
    }

    /// Processes every complete frame in the buffer and writes the results,
    /// frame after frame, into `python_output_frames`.
    ///
    /// Samples beyond the last complete frame remain buffered for the next call.
    ///
    /// @param python_output_frames Destination array; must hold at least as many
    ///        whole frames as are buffered. Written linearly through its buffer
    ///        pointer, so it is assumed to be contiguous and writable.
    /// @return Number of frames processed and written.
    /// @throws std::runtime_error if the destination holds fewer whole frames
    ///         than are buffered; the buffered audio is left untouched then.
    unsigned get_processed_frames(py::array_t<short> &python_output_frames)
    {
        py::buffer_info buf_info = python_output_frames.request();
        short *output_ptr = static_cast<short *>(buf_info.ptr);
        const size_t buffer_frame_count = static_cast<size_t>(buf_info.size) / _frame_length;
        const size_t audio_frame_count = _audio_data.size() / _frame_length;
        if (buffer_frame_count < audio_frame_count)
        {
            throw std::runtime_error("buffer is too small");
        }

        unsigned processed_frames = 0;
        for (size_t i = 0; i < audio_frame_count; ++i)
        {
            const size_t offset = i * _frame_length;
            // The trailing arguments (100, nullptr) are passed through unchanged;
            // presumably a suppression level and an optional per-frame stats
            // output — confirm against the SDK reference.
            _session->process(
                _audio_data.data() + offset,
                _frame_length,
                _frame_buffer.data(),
                _frame_length,
                100,
                nullptr);
            std::copy(_frame_buffer.begin(), _frame_buffer.end(), output_ptr + offset);
            ++processed_frames;
        }

        // Drop the consumed frames; erase() shifts the partial trailing frame to
        // the front so the buffer size always equals the pending sample count.
        const size_t consumed = audio_frame_count * _frame_length;
        _audio_data.erase(_audio_data.begin(),
                          _audio_data.begin() + static_cast<std::ptrdiff_t>(consumed));
        return processed_frames;
    }

private:
    std::wstring _weight;                  // model path passed to the SDK
    uint32_t _frame_size = 0;              // samples per channel in one frame
    uint32_t _frame_length = 0;            // samples per frame across all channels
    std::shared_ptr<Nc<int16_t>> _session; // SDK session created in the constructor
    std::vector<short> _audio_data;        // pending, not yet processed samples
    std::vector<short> _frame_buffer;      // scratch output for a single frame
};
// Python module `krisp_processor`: exposes KrispProcessor with its
// three-argument constructor and its three public methods under their C++ names.
PYBIND11_MODULE(krisp_processor, m)
{
    py::class_<KrispProcessor> processor(m, "KrispProcessor");

    // Constructor: (sample rate, channel count, model path).
    processor.def(py::init<unsigned, unsigned, std::wstring>());

    // Methods, registered one per statement.
    processor.def("store_audio_chunk", &KrispProcessor::store_audio_chunk);
    processor.def("get_processed_frames", &KrispProcessor::get_processed_frames);
    processor.def("get_samples_count", &KrispProcessor::get_samples_count);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment