bump to 0.8.0

This commit is contained in:
Terrence
2024-11-14 23:15:43 +08:00
parent ddb375173e
commit ec918748f1
34 changed files with 1039 additions and 497 deletions

View File

@ -2,6 +2,7 @@
#include "system_info.h"
#include "ml307_ssl_transport.h"
#include "audio_codec.h"
#include "protocols/mqtt_protocol.h"
#include <cstring>
#include <esp_log.h>
@ -27,8 +28,8 @@ Application::Application() {
}
Application::~Application() {
if (ws_client_ != nullptr) {
delete ws_client_;
if (protocol_ != nullptr) {
delete protocol_;
}
if (opus_decoder_ != nullptr) {
opus_decoder_destroy(opus_decoder_);
@ -48,10 +49,6 @@ void Application::CheckNewVersion() {
ota_.SetPostData(Board::GetInstance().GetJson());
ota_.CheckVersion();
if (ota_.HasNewVersion()) {
// Wait for the chat state to be idle
while (chat_state_ != kChatStateIdle) {
vTaskDelay(100);
}
SetChatState(kChatStateUpgrading);
ota_.StartUpgrade([](int progress, size_t speed) {
char buffer[64];
@ -84,49 +81,36 @@ void Application::Alert(const std::string&& title, const std::string&& message)
void Application::PlayLocalFile(const char* data, size_t size) {
ESP_LOGI(TAG, "PlayLocalFile: %zu bytes", size);
SetDecodeSampleRate(16000);
auto codec = Board::GetInstance().GetAudioCodec();
codec->EnableOutput(true);
for (const char* p = data; p < data + size; ) {
auto p3 = (BinaryProtocol3*)p;
p += sizeof(BinaryProtocol3);
auto payload_size = ntohs(p3->payload_size);
std::string opus;
opus.resize(payload_size);
memcpy(opus.data(), p3->payload, payload_size);
p += payload_size;
{
std::lock_guard<std::mutex> lock(mutex_);
auto packet = new AudioPacket();
packet->type = kAudioPacketTypeStart;
audio_decode_queue_.push_back(packet);
}
ParseBinaryProtocol3(data, size);
{
std::lock_guard<std::mutex> lock(mutex_);
auto packet = new AudioPacket();
packet->type = kAudioPacketTypeStop;
audio_decode_queue_.push_back(packet);
cv_.notify_all();
audio_decode_queue_.emplace_back(std::move(opus));
}
cv_.notify_all();
}
void Application::ToggleChatState() {
Schedule([this]() {
if (chat_state_ == kChatStateIdle) {
SetChatState(kChatStateConnecting);
StartWebSocketClient();
if (ws_client_ && ws_client_->IsConnected()) {
if (protocol_->OpenAudioChannel()) {
opus_encoder_.ResetState();
#ifdef CONFIG_USE_AFE_SR
audio_processor_.Start();
#endif
SetChatState(kChatStateListening);
ESP_LOGI(TAG, "Communication started");
} else {
SetChatState(kChatStateIdle);
}
} else if (chat_state_ == kChatStateSpeaking) {
AbortSpeaking();
} else if (chat_state_ == kChatStateListening) {
if (ws_client_ && ws_client_->IsConnected()) {
ws_client_->Close();
}
protocol_->CloseAudioChannel();
}
});
}
@ -139,9 +123,11 @@ void Application::Start() {
builtin_led->SetBlue();
builtin_led->StartContinuousBlink(100);
/* Setup the display */
auto display = board.GetDisplay();
display->SetupUI();
/* Setup the audio codec */
auto codec = board.GetAudioCodec();
opus_decode_sample_rate_ = codec->output_sample_rate();
opus_decoder_ = opus_decoder_create(opus_decode_sample_rate_, 1, NULL);
@ -150,10 +136,6 @@ void Application::Start() {
input_resampler_.Configure(codec->input_sample_rate(), 16000);
reference_resampler_.Configure(codec->input_sample_rate(), 16000);
}
codec->EnableInput(true);
codec->EnableOutput(true);
codec->EnableOutput(false);
codec->OnInputData([this, codec](std::vector<int16_t>&& data) {
if (codec->input_sample_rate() != 16000) {
if (codec->input_channels() == 2) {
@ -196,8 +178,7 @@ void Application::Start() {
#endif
});
// OPUS encoder / decoder use a lot of stack memory
const size_t opus_stack_size = 4096 * 8;
const size_t opus_stack_size = 4096 * 8; // OPUS encoder / decoder use a lot of stack memory
audio_encode_task_stack_ = (StackType_t*)heap_caps_malloc(opus_stack_size, MALLOC_CAP_SPIRAM);
audio_encode_task_ = xTaskCreateStatic([](void* arg) {
Application* app = (Application*)arg;
@ -205,18 +186,12 @@ void Application::Start() {
vTaskDelete(NULL);
}, "opus_encode", opus_stack_size, this, 1, audio_encode_task_stack_, &audio_encode_task_buffer_);
xTaskCreate([](void* arg) {
Application* app = (Application*)arg;
app->AudioPlayTask();
vTaskDelete(NULL);
}, "play_audio", 4096 * 4, this, 4, NULL);
codec->Start();
/* Wait for the network to be ready */
board.StartNetwork();
// Blink the LED to indicate the device is running
builtin_led->SetGreen();
builtin_led->BlinkOnce();
const size_t main_loop_stack_size = 4096 * 2;
const size_t main_loop_stack_size = 4096 * 8;
main_loop_task_stack_ = (StackType_t*)heap_caps_malloc(main_loop_stack_size, MALLOC_CAP_SPIRAM);
xTaskCreateStatic([](void* arg) {
Application* app = (Application*)arg;
@ -224,23 +199,23 @@ void Application::Start() {
vTaskDelete(NULL);
}, "main_loop", main_loop_stack_size, this, 1, main_loop_task_stack_, &main_loop_task_buffer_);
// Launch a task to check for new firmware version
xTaskCreate([](void* arg) {
Application* app = (Application*)arg;
app->CheckNewVersion();
vTaskDelete(NULL);
}, "check_new_version", 4096 * 2, this, 1, NULL);
// Check for new firmware version or get the MQTT broker address
while (true) {
CheckNewVersion();
if (ota_.HasMqttConfig()) {
break;
}
Alert("Error", "Missing MQTT config");
vTaskDelay(pdMS_TO_TICKS(10000));
}
#ifdef CONFIG_USE_AFE_SR
audio_processor_.Initialize(codec->input_channels(), codec->input_reference());
audio_processor_.OnOutput([this](std::vector<int16_t>&& data) {
Schedule([this, data = std::move(data)]() {
if (chat_state_ == kChatStateListening) {
std::lock_guard<std::mutex> lock(mutex_);
audio_encode_queue_.emplace_back(std::move(data));
cv_.notify_all();
}
});
std::lock_guard<std::mutex> lock(mutex_);
audio_encode_queue_.emplace_back(std::move(data));
cv_.notify_all();
});
wake_word_detect_.Initialize(codec->input_channels(), codec->input_reference());
@ -261,24 +236,18 @@ void Application::Start() {
wake_word_detect_.OnWakeWordDetected([this]() {
Schedule([this]() {
if (chat_state_ == kChatStateIdle) {
// Encode the wake word data and start websocket client at the same time
// They both consume a lot of time (700ms), so we can do them in parallel
SetChatState(kChatStateConnecting);
wake_word_detect_.EncodeWakeWordData();
SetChatState(kChatStateConnecting);
if (ws_client_ == nullptr) {
StartWebSocketClient();
}
if (ws_client_ && ws_client_->IsConnected()) {
auto encoded = wake_word_detect_.GetWakeWordStream();
// Send the wake word data to the server
ws_client_->Send(encoded.data(), encoded.size(), true);
if (protocol_->OpenAudioChannel()) {
std::string opus;
// Encode and send the wake word data to the server
while (wake_word_detect_.GetWakeWordOpus(opus)) {
protocol_->SendAudio(opus);
}
opus_encoder_.ResetState();
// Send a ready message to indicate the server that the wake word data is sent
SetChatState(kChatStateWakeWordDetected);
// If connected, the hello message is already sent, so we can start communication
audio_processor_.Start();
ESP_LOGI(TAG, "Audio processor started");
} else {
SetChatState(kChatStateIdle);
}
@ -293,7 +262,68 @@ void Application::Start() {
wake_word_detect_.StartDetection();
#endif
chat_state_ = kChatStateIdle;
// Initialize the protocol
display->SetText("Starting\nProtocol...");
protocol_ = new MqttProtocol(ota_.GetMqttConfig());
protocol_->OnIncomingAudio([this](const std::string& data) {
std::lock_guard<std::mutex> lock(mutex_);
audio_decode_queue_.emplace_back(std::move(data));
cv_.notify_all();
});
protocol_->OnAudioChannelClosed([this]() {
Schedule([this]() {
SetChatState(kChatStateIdle);
});
});
protocol_->OnIncomingJson([this](const cJSON* root) {
// Parse JSON data
auto type = cJSON_GetObjectItem(root, "type");
if (strcmp(type->valuestring, "tts") == 0) {
auto state = cJSON_GetObjectItem(root, "state");
if (strcmp(state->valuestring, "start") == 0) {
Schedule([this]() {
skip_to_end_ = false;
SetChatState(kChatStateSpeaking);
});
} else if (strcmp(state->valuestring, "stop") == 0) {
Schedule([this]() {
auto codec = Board::GetInstance().GetAudioCodec();
codec->WaitForOutputDone();
SetChatState(kChatStateListening);
});
} else if (strcmp(state->valuestring, "sentence_start") == 0) {
auto text = cJSON_GetObjectItem(root, "text");
if (text != NULL) {
ESP_LOGI(TAG, ">> %s", text->valuestring);
}
}
} else if (strcmp(type->valuestring, "stt") == 0) {
auto text = cJSON_GetObjectItem(root, "text");
if (text != NULL) {
ESP_LOGI(TAG, ">> %s", text->valuestring);
}
} else if (strcmp(type->valuestring, "llm") == 0) {
auto emotion = cJSON_GetObjectItem(root, "emotion");
if (emotion != NULL) {
ESP_LOGD(TAG, "EMOTION: %s", emotion->valuestring);
}
} else if (strcmp(type->valuestring, "hello") == 0) {
// Get sample rate from hello message
auto audio_params = cJSON_GetObjectItem(root, "audio_params");
if (audio_params != NULL) {
auto sample_rate = cJSON_GetObjectItem(audio_params, "sample_rate");
if (sample_rate != NULL) {
SetDecodeSampleRate(sample_rate->valueint);
}
}
}
});
// Blink the LED to indicate the device is running
builtin_led->SetGreen();
builtin_led->BlinkOnce();
SetChatState(kChatStateIdle);
display->UpdateDisplay();
}
@ -321,16 +351,12 @@ void Application::MainLoop() {
void Application::AbortSpeaking() {
ESP_LOGI(TAG, "Abort speaking");
skip_to_end_ = true;
std::string json = "{\"type\":\"abort\"}";
protocol_->SendText(json);
if (ws_client_ && ws_client_->IsConnected()) {
cJSON* root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "type", "abort");
char* json = cJSON_PrintUnformatted(root);
ws_client_->Send(json);
cJSON_Delete(root);
free(json);
}
skip_to_end_ = true;
auto codec = Board::GetInstance().GetAudioCodec();
codec->ClearOutputQueue();
}
void Application::SetChatState(ChatState state) {
@ -359,6 +385,9 @@ void Application::SetChatState(ChatState state) {
case kChatStateIdle:
builtin_led->TurnOff();
display->SetText("I'm\nIdle.");
#ifdef CONFIG_USE_AFE_SR
audio_processor_.Stop();
#endif
break;
case kChatStateConnecting:
builtin_led->SetBlue();
@ -369,11 +398,17 @@ void Application::SetChatState(ChatState state) {
builtin_led->SetRed();
builtin_led->TurnOn();
display->SetText("I'm\nListening...");
#ifdef CONFIG_USE_AFE_SR
audio_processor_.Start();
#endif
break;
case kChatStateSpeaking:
builtin_led->SetGreen();
builtin_led->TurnOn();
display->SetText("I'm\nSpeaking...");
#ifdef CONFIG_USE_AFE_SR
audio_processor_.Stop();
#endif
break;
case kChatStateWakeWordDetected:
builtin_led->SetBlue();
@ -385,37 +420,20 @@ void Application::SetChatState(ChatState state) {
break;
}
if (ws_client_ && ws_client_->IsConnected()) {
cJSON* root = cJSON_CreateObject();
cJSON_AddStringToObject(root, "type", "state");
cJSON_AddStringToObject(root, "state", state_str[chat_state_]);
char* json = cJSON_PrintUnformatted(root);
ws_client_->Send(json);
cJSON_Delete(root);
free(json);
}
}
BinaryProtocol3* Application::AllocateBinaryProtocol3(const uint8_t* payload, size_t payload_size) {
auto protocol = (BinaryProtocol3*)heap_caps_malloc(sizeof(BinaryProtocol3) + payload_size, MALLOC_CAP_SPIRAM);
assert(protocol != nullptr);
protocol->type = 0;
protocol->reserved = 0;
protocol->payload_size = htons(payload_size);
assert(sizeof(BinaryProtocol3) == 4UL);
memcpy(protocol->payload, payload, payload_size);
return protocol;
std::string json = "{\"type\":\"state\",\"state\":\"";
json += state_str[chat_state_];
json += "\"}";
protocol_->SendText(json);
}
void Application::AudioEncodeTask() {
ESP_LOGI(TAG, "Audio encode task started");
const int max_audio_play_queue_size_ = 2; // avoid decoding too fast
auto codec = Board::GetInstance().GetAudioCodec();
while (true) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this]() {
return !audio_encode_queue_.empty() || (!audio_decode_queue_.empty() && audio_play_queue_.size() < max_audio_play_queue_size_);
return !audio_encode_queue_.empty() || !audio_decode_queue_.empty();
});
if (!audio_encode_queue_.empty()) {
@ -423,108 +441,38 @@ void Application::AudioEncodeTask() {
audio_encode_queue_.pop_front();
lock.unlock();
// Encode audio data
opus_encoder_.Encode(pcm, [this](const uint8_t* opus, size_t opus_size) {
auto protocol = AllocateBinaryProtocol3(opus, opus_size);
Schedule([this, protocol, opus_size]() {
if (ws_client_ && ws_client_->IsConnected()) {
if (!ws_client_->Send(protocol, sizeof(BinaryProtocol3) + opus_size, true)) {
ESP_LOGE(TAG, "Failed to send audio data");
}
}
heap_caps_free(protocol);
Schedule([this, data = std::string(reinterpret_cast<const char*>(opus), opus_size)]() {
protocol_->SendAudio(data);
});
});
} else if (!audio_decode_queue_.empty()) {
auto packet = std::move(audio_decode_queue_.front());
auto opus = std::move(audio_decode_queue_.front());
audio_decode_queue_.pop_front();
lock.unlock();
if (packet->type == kAudioPacketTypeData && !skip_to_end_) {
int frame_size = opus_decode_sample_rate_ * opus_duration_ms_ / 1000;
packet->pcm.resize(frame_size);
int ret = opus_decode(opus_decoder_, packet->opus.data(), packet->opus.size(), packet->pcm.data(), frame_size, 0);
if (ret < 0) {
ESP_LOGE(TAG, "Failed to decode audio, error code: %d", ret);
delete packet;
continue;
}
if (opus_decode_sample_rate_ != codec->output_sample_rate()) {
int target_size = output_resampler_.GetOutputSamples(frame_size);
std::vector<int16_t> resampled(target_size);
output_resampler_.Process(packet->pcm.data(), frame_size, resampled.data());
packet->pcm = std::move(resampled);
}
if (skip_to_end_) {
continue;
}
std::lock_guard<std::mutex> lock(mutex_);
audio_play_queue_.push_back(packet);
cv_.notify_all();
}
}
}
int frame_size = opus_decode_sample_rate_ * opus_duration_ms_ / 1000;
std::vector<int16_t> pcm(frame_size);
void Application::HandleAudioPacket(AudioPacket* packet) {
switch (packet->type)
{
case kAudioPacketTypeData: {
if (skip_to_end_) {
break;
}
// This will block until the audio device has finished playing the audio
auto codec = Board::GetInstance().GetAudioCodec();
codec->OutputData(packet->pcm);
break;
}
case kAudioPacketTypeStart:
break_speaking_ = false;
skip_to_end_ = false;
Schedule([this]() {
SetChatState(kChatStateSpeaking);
});
break;
case kAudioPacketTypeStop:
Schedule([this]() {
if (ws_client_ && ws_client_->IsConnected()) {
SetChatState(kChatStateListening);
} else {
SetChatState(kChatStateIdle);
int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0);
if (ret < 0) {
ESP_LOGE(TAG, "Failed to decode audio, error code: %d", ret);
continue;
}
});
break;
case kAudioPacketTypeSentenceStart:
ESP_LOGI(TAG, "<< %s", packet->text.c_str());
break;
case kAudioPacketTypeSentenceEnd:
if (break_speaking_) {
skip_to_end_ = true;
// Resample if the sample rate is different
if (opus_decode_sample_rate_ != codec->output_sample_rate()) {
int target_size = output_resampler_.GetOutputSamples(frame_size);
std::vector<int16_t> resampled(target_size);
output_resampler_.Process(pcm.data(), frame_size, resampled.data());
pcm = std::move(resampled);
}
codec->OutputData(pcm);
}
break;
default:
ESP_LOGI(TAG, "Unknown packet type: %d", packet->type);
break;
}
delete packet;
}
void Application::AudioPlayTask() {
ESP_LOGI(TAG, "Audio play task started");
while (true) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this]() {
return !audio_play_queue_.empty();
});
auto packet = std::move(audio_play_queue_.front());
audio_play_queue_.pop_front();
cv_.notify_all();
lock.unlock();
HandleAudioPacket(packet);
}
}
@ -543,127 +491,3 @@ void Application::SetDecodeSampleRate(int sample_rate) {
output_resampler_.Configure(opus_decode_sample_rate_, codec->output_sample_rate());
}
}
void Application::ParseBinaryProtocol3(const char* data, size_t size) {
for (const char* p = data; p < data + size; ) {
auto protocol = (BinaryProtocol3*)p;
p += sizeof(BinaryProtocol3);
auto packet = new AudioPacket();
packet->type = kAudioPacketTypeData;
auto payload_size = ntohs(protocol->payload_size);
packet->opus.resize(payload_size);
memcpy(packet->opus.data(), protocol->payload, payload_size);
p += payload_size;
std::lock_guard<std::mutex> lock(mutex_);
audio_decode_queue_.push_back(packet);
}
}
void Application::StartWebSocketClient() {
if (ws_client_ != nullptr) {
ESP_LOGW(TAG, "WebSocket client already exists");
delete ws_client_;
}
std::string url = CONFIG_WEBSOCKET_URL;
std::string token = "Bearer " + std::string(CONFIG_WEBSOCKET_ACCESS_TOKEN);
ws_client_ = Board::GetInstance().CreateWebSocket();
ws_client_->SetHeader("Authorization", token.c_str());
ws_client_->SetHeader("Protocol-Version", std::to_string(PROTOCOL_VERSION).c_str());
ws_client_->SetHeader("Device-Id", SystemInfo::GetMacAddress().c_str());
ws_client_->OnConnected([this]() {
ESP_LOGI(TAG, "Websocket connected");
// Send hello message to describe the client
// keys: message type, version, wakeup_model, audio_params (format, sample_rate, channels)
std::string message = "{";
message += "\"type\":\"hello\",";
message += "\"audio_params\":{";
message += "\"format\":\"opus\", \"sample_rate\":16000, \"channels\":1";
message += "}}";
ws_client_->Send(message);
});
ws_client_->OnData([this](const char* data, size_t len, bool binary) {
if (binary) {
ParseBinaryProtocol3(data, len);
cv_.notify_all();
} else {
// Parse JSON data
auto root = cJSON_Parse(data);
auto type = cJSON_GetObjectItem(root, "type");
if (type != NULL) {
if (strcmp(type->valuestring, "tts") == 0) {
auto packet = new AudioPacket();
auto state = cJSON_GetObjectItem(root, "state");
if (strcmp(state->valuestring, "start") == 0) {
packet->type = kAudioPacketTypeStart;
auto sample_rate = cJSON_GetObjectItem(root, "sample_rate");
if (sample_rate != NULL) {
SetDecodeSampleRate(sample_rate->valueint);
}
// If the device is speaking, we need to skip the last session
skip_to_end_ = true;
} else if (strcmp(state->valuestring, "stop") == 0) {
packet->type = kAudioPacketTypeStop;
} else if (strcmp(state->valuestring, "sentence_end") == 0) {
packet->type = kAudioPacketTypeSentenceEnd;
} else if (strcmp(state->valuestring, "sentence_start") == 0) {
packet->type = kAudioPacketTypeSentenceStart;
packet->text = cJSON_GetObjectItem(root, "text")->valuestring;
}
std::lock_guard<std::mutex> lock(mutex_);
audio_decode_queue_.push_back(packet);
cv_.notify_all();
} else if (strcmp(type->valuestring, "stt") == 0) {
auto text = cJSON_GetObjectItem(root, "text");
if (text != NULL) {
ESP_LOGI(TAG, ">> %s", text->valuestring);
}
} else if (strcmp(type->valuestring, "llm") == 0) {
auto emotion = cJSON_GetObjectItem(root, "emotion");
if (emotion != NULL) {
ESP_LOGD(TAG, "EMOTION: %s", emotion->valuestring);
}
} else {
ESP_LOGW(TAG, "Unknown message type: %s", type->valuestring);
}
} else {
ESP_LOGE(TAG, "Missing message type, data: %s", data);
}
cJSON_Delete(root);
}
});
ws_client_->OnError([this](int error) {
ESP_LOGE(TAG, "Websocket error: %d", error);
});
ws_client_->OnDisconnected([this]() {
ESP_LOGI(TAG, "Websocket disconnected");
Schedule([this]() {
auto codec = Board::GetInstance().GetAudioCodec();
codec->EnableOutput(false);
#ifdef CONFIG_USE_AFE_SR
audio_processor_.Stop();
#endif
delete ws_client_;
ws_client_ = nullptr;
SetChatState(kChatStateIdle);
});
});
if (!ws_client_->Connect(url.c_str())) {
ESP_LOGE(TAG, "Failed to connect to websocket server");
return;
}
// 建立语音通道后打开音频输出,避免待机时喇叭底噪
auto codec = Board::GetInstance().GetAudioCodec();
codec->EnableOutput(true);
}