2026-02-21 20:48:10 +01:00

230 lines
8.8 KiB
C++

#include "Conversation.hpp"
#include <boost/beast/websocket/ssl.hpp>
#include <boost/beast/websocket.hpp>
#include <boost/beast/ssl.hpp>
#include <boost/beast/core/detail/base64.hpp>
#include <boost/asio/connect.hpp>
#include <boost/algorithm/string.hpp>
#include <iostream>
#include <sstream>
#include <openssl/ssl.h>
using tcp = boost::asio::ip::tcp;
namespace ssl = boost::asio::ssl;
namespace websocket = boost::beast::websocket;
namespace beast = boost::beast;
// Encode a raw byte buffer as a base64 string.
// beast's encoded_size() returns the exact padded output length, so the
// string is sized once up front and filled in place.
static std::string base64Encode(const std::vector<char>& data) {
    std::string encoded;
    encoded.resize(beast::detail::base64::encoded_size(data.size()));
    beast::detail::base64::encode(&encoded[0], data.data(), data.size());
    return encoded;
}
// Decode a base64 string back into raw bytes.
// decoded_size() is an upper bound (it ignores padding), so the buffer is
// trimmed to the byte count the decoder actually produced.
static std::vector<char> base64Decode(const std::string& str) {
    std::vector<char> decoded(beast::detail::base64::decoded_size(str.size()));
    const auto written = beast::detail::base64::decode(decoded.data(), str.data(), str.size());
    decoded.resize(written.first); // .first = bytes written, .second = input consumed
    return decoded;
}
// Render a JSON scalar as plain text: strings come back unquoted, integers
// via to_string, and anything else falls back to the serialized form.
static std::string toString(const nlohmann::json& j) {
    if (j.is_string()) {
        return j.get<std::string>();
    }
    if (j.is_number_integer()) {
        return std::to_string(j.get<int64_t>());
    }
    return j.dump();
}
// Construct a conversation bound to one agent.
//
// agentId        - ElevenLabs agent identifier, appended to the wss URL.
// requiresAuth   - stored; its use is not visible in this file.
// audioInterface - capture/playback backend driven by run()/handleMessage().
// callback*      - optional user hooks; each is null-checked before invocation
//                  in handleMessage().
//
// NOTE(review): set_default_verify_paths() loads the system CA store, but no
// set_verify_mode(ssl::verify_peer) call is visible here or in run() — the
// peer certificate may not actually be verified. Confirm.
Conversation::Conversation(const std::string& agentId, bool requiresAuth,
std::shared_ptr<AudioInterface> audioInterface,
CallbackAgentResponse callbackAgentResponse,
CallbackAgentResponseCorrection callbackAgentResponseCorrection,
CallbackUserTranscript callbackUserTranscript,
CallbackLatencyMeasurement callbackLatencyMeasurement)
: agentId_(agentId),
requiresAuth_(requiresAuth),
audioInterface_(std::move(audioInterface)),
callbackAgentResponse_(std::move(callbackAgentResponse)),
callbackAgentResponseCorrection_(std::move(callbackAgentResponseCorrection)),
callbackUserTranscript_(std::move(callbackUserTranscript)),
callbackLatencyMeasurement_(std::move(callbackLatencyMeasurement)) {
sslCtx_.set_default_verify_paths();
}
// Tear down the session on destruction: endSession() signals the worker,
// closes the socket, stops audio, and joins the thread, so no member is
// destroyed while the worker still references it.
Conversation::~Conversation() {
endSession();
}
// Launch the background worker thread that owns the websocket session.
//
// Fix: guard against double-start. Assigning a new std::thread into a
// workerThread_ that is still joinable (session already running, or ended
// but never joined) calls std::terminate; the original did so unconditionally.
void Conversation::startSession() {
    if (workerThread_.joinable()) {
        return; // a session is already running (or awaiting a join) — no-op
    }
    shouldStop_.store(false);
    workerThread_ = std::thread(&Conversation::run, this);
}
// Stop the session: flag the worker loop to exit, close the websocket so a
// blocked read() returns, stop audio capture, and join the worker thread.
// Safe to call more than once — every step is guarded or idempotent.
//
// NOTE(review): close() is issued from the caller's thread while run() may be
// blocked in ws_->read() on the worker thread. Beast websocket streams do not
// document concurrent operations from multiple threads as safe — confirm this
// is acceptable or serialize via a strand. The close error code is
// deliberately ignored (best-effort shutdown).
void Conversation::endSession() {
shouldStop_.store(true);
if (ws_) {
beast::error_code ec;
ws_->close(websocket::close_code::normal, ec);
}
if (audioInterface_) {
audioInterface_->stop();
}
if (workerThread_.joinable()) {
workerThread_.join();
}
}
// Block until the worker thread exits on its own (socket closed or error),
// then report the conversation id the server assigned. Returns an empty
// string if the session never completed initiation.
std::string Conversation::waitForSessionEnd() {
    if (workerThread_.joinable()) workerThread_.join();
    return conversationId_;
}
void Conversation::sendUserMessage(const std::string& text) {
if (!ws_) {
throw std::runtime_error("Session not started");
}
nlohmann::json j = {
{"type", "user_message"},
{"text", text}
};
ws_->write(boost::asio::buffer(j.dump()));
}
void Conversation::registerUserActivity() {
if (!ws_) throw std::runtime_error("Session not started");
nlohmann::json j = {{"type", "user_activity"}};
ws_->write(boost::asio::buffer(j.dump()));
}
void Conversation::sendContextualUpdate(const std::string& content) {
if (!ws_) throw std::runtime_error("Session not started");
nlohmann::json j = {{"type", "contextual_update"}, {"content", content}};
ws_->write(boost::asio::buffer(j.dump()));
}
// Build the websocket endpoint URL for this agent.
// Hard-coded base env for demo; in production you'd call the ElevenLabs
// environment endpoint instead.
std::string Conversation::getWssUrl() const {
    return "wss://api.elevenlabs.io/v1/convai/conversation?agent_id=" + agentId_;
}
// Worker-thread entry point. Sequence: parse the wss URL, resolve and connect
// TCP, set SNI, TLS handshake, websocket handshake, send the
// conversation-initiation message, start audio capture, then loop reading and
// dispatching server events until shouldStop_ is set or the read fails.
// Any exception is logged to stderr and ends the thread.
void Conversation::run() {
try {
auto url = getWssUrl();
std::string protocol, host, target;
unsigned short port = 443;
// Very naive parse: wss://host[:port]/path?query
if (boost::starts_with(url, "wss://")) {
protocol = "wss";
host = url.substr(6);
} else {
throw std::runtime_error("Only wss:// URLs supported in this demo");
}
// Split off the request target (path + query); default to "/".
auto slashPos = host.find('/');
if (slashPos == std::string::npos) {
target = "/";
} else {
target = host.substr(slashPos);
host = host.substr(0, slashPos);
}
// Optional explicit port; stoi here would throw on a non-numeric port,
// which the outer catch turns into a logged error.
auto colonPos = host.find(':');
if (colonPos != std::string::npos) {
port = static_cast<unsigned short>(std::stoi(host.substr(colonPos + 1)));
host = host.substr(0, colonPos);
}
tcp::resolver resolver(ioc_);
auto const results = resolver.resolve(host, std::to_string(port));
beast::ssl_stream<tcp::socket> stream(ioc_, sslCtx_);
boost::asio::connect(beast::get_lowest_layer(stream), results);
// SNI must be set before the TLS handshake or virtual-hosted servers
// (and SAN matching) will fail.
if (!SSL_set_tlsext_host_name(stream.native_handle(), host.c_str())) {
throw std::runtime_error("Failed to set SNI hostname on SSL stream");
}
stream.handshake(ssl::stream_base::client);
ws_ = std::make_unique<websocket_t>(std::move(stream));
ws_->set_option(websocket::stream_base::timeout::suggested(beast::role_type::client));
ws_->handshake(host, target);
// send initiation data
nlohmann::json init = {
{"type", "conversation_initiation_client_data"},
{"custom_llm_extra_body", nlohmann::json::object()},
{"conversation_config_override", nlohmann::json::object()},
{"dynamic_variables", nlohmann::json::object()}
};
ws_->write(boost::asio::buffer(init.dump()));
// Prepare audio callback
// NOTE(review): this callback presumably runs on the audio interface's own
// thread and writes to ws_ while this thread is blocked in ws_->read().
// Beast websocket streams are not documented as safe for that — confirm
// AudioInterface serializes, or route writes through a strand/mutex.
auto inputCb = [this](const std::vector<char>& audio) {
nlohmann::json msg = {
{"user_audio_chunk", base64Encode(audio)}
};
ws_->write(boost::asio::buffer(msg.dump()));
};
audioInterface_->start(inputCb);
// Read loop: one complete websocket message per iteration.
beast::flat_buffer buffer;
while (!shouldStop_.load()) {
beast::error_code ec;
ws_->read(buffer, ec);
if (ec) {
// Includes the normal "closed" error raised when endSession() closes
// the socket from another thread.
std::cerr << "Websocket read error: " << ec.message() << std::endl;
break;
}
auto text = beast::buffers_to_string(buffer.data());
buffer.consume(buffer.size());
try {
auto message = nlohmann::json::parse(text);
handleMessage(message);
} catch (const std::exception& ex) {
std::cerr << "JSON parse error: " << ex.what() << std::endl;
}
}
} catch (const std::exception& ex) {
std::cerr << "Conversation error: " << ex.what() << std::endl;
}
}
// Dispatch one parsed server event to the matching handler/callback.
//
// Fix: the original indexed the const `message` with operator[], which is
// undefined behavior in nlohmann::json when a key is missing — any malformed
// server message could crash or corrupt state. at() throws
// json::out_of_range instead, which the read loop in run() catches and logs.
// Event sub-objects are also bound by const reference now rather than
// deep-copied.
void Conversation::handleMessage(const nlohmann::json& message) {
    const std::string type = message.value("type", "");
    if (type == "conversation_initiation_metadata") {
        conversationId_ = message.at("conversation_initiation_metadata_event")
                              .at("conversation_id").get<std::string>();
    } else if (type == "audio") {
        const auto& event = message.at("audio_event");
        // Drop audio chunks that predate the most recent interruption.
        const int eventId = std::stoi(toString(event.at("event_id")));
        if (eventId <= lastInterruptId_.load()) return;
        auto audioBytes = base64Decode(event.at("audio_base_64").get<std::string>());
        audioInterface_->output(audioBytes);
    } else if (type == "agent_response" && callbackAgentResponse_) {
        const auto& event = message.at("agent_response_event");
        callbackAgentResponse_(event.at("agent_response").get<std::string>());
    } else if (type == "agent_response_correction" && callbackAgentResponseCorrection_) {
        const auto& event = message.at("agent_response_correction_event");
        callbackAgentResponseCorrection_(event.at("original_agent_response").get<std::string>(),
                                         event.at("corrected_agent_response").get<std::string>());
    } else if (type == "user_transcript" && callbackUserTranscript_) {
        const auto& event = message.at("user_transcription_event");
        callbackUserTranscript_(event.at("user_transcript").get<std::string>());
    } else if (type == "interruption") {
        const auto& event = message.at("interruption_event");
        lastInterruptId_.store(std::stoi(toString(event.at("event_id"))));
        audioInterface_->interrupt();
    } else if (type == "ping") {
        const auto& event = message.at("ping_event");
        // Echo the event id back so the server can measure round-trip time.
        nlohmann::json pong = {{"type", "pong"}, {"event_id", event.at("event_id")}};
        ws_->write(boost::asio::buffer(pong.dump()));
        if (callbackLatencyMeasurement_ && event.contains("ping_ms")) {
            const auto& pingMs = event.at("ping_ms");
            const int latency = pingMs.is_number() ? pingMs.get<int>()
                                                   : std::stoi(pingMs.get<std::string>());
            callbackLatencyMeasurement_(latency);
        }
    }
    // Note: client tool call handling omitted for brevity.
}