// Listing metadata: 230 lines, 8.8 KiB, C++
#include "Conversation.hpp"
|
|
|
|
#include <boost/beast/websocket/ssl.hpp>
|
|
#include <boost/beast/websocket.hpp>
|
|
#include <boost/beast/ssl.hpp>
|
|
#include <boost/beast/core/detail/base64.hpp>
|
|
#include <boost/asio/connect.hpp>
|
|
#include <boost/algorithm/string.hpp>
|
|
#include <iostream>
|
|
#include <sstream>
|
|
#include <openssl/ssl.h>
|
|
|
|
using tcp = boost::asio::ip::tcp;
|
|
namespace ssl = boost::asio::ssl;
|
|
namespace websocket = boost::beast::websocket;
|
|
namespace beast = boost::beast;
|
|
|
|
// Encodes the bytes in `data` as base64 text.
//
// The output string is pre-sized with Beast's encoded_size() for the input
// length and then filled in place by Beast's encode(); the value returned by
// encode() is not used.
static std::string base64Encode(const std::vector<char>& data) {
    const auto byteCount = data.size();

    std::string encoded;
    encoded.resize(beast::detail::base64::encoded_size(byteCount), '\0');

    beast::detail::base64::encode(&encoded[0], data.data(), byteCount);
    return encoded;
}
|
|
|
|
static std::vector<char> base64Decode(const std::string& str) {
|
|
auto decodedSize = beast::detail::base64::decoded_size(str.size());
|
|
std::vector<char> out(decodedSize);
|
|
auto result = beast::detail::base64::decode(out.data(), str.data(), str.size());
|
|
out.resize(result.first);
|
|
return out;
|
|
}
|
|
|
|
// Renders a JSON value as a plain string:
//   - string values are returned as-is (no surrounding quotes),
//   - integer values are formatted via std::to_string,
//   - anything else falls back to its JSON serialization (dump()).
static std::string toString(const nlohmann::json& j){
    if (j.is_string()) {
        return j.get<std::string>();
    }
    return j.is_number_integer() ? std::to_string(j.get<int64_t>()) : j.dump();
}
|
|
|
|
// Builds a conversation object; no network activity happens here.
//
// agentId / requiresAuth are stored as given, the audio interface and the four
// callbacks are moved into their members, and the SSL context is pointed at
// the platform's default CA certificate locations.
// NOTE(review): only the CA paths are loaded here; this constructor does not
// set a verify mode on the SSL context.
Conversation::Conversation(const std::string& agentId, bool requiresAuth,
                           std::shared_ptr<AudioInterface> audioInterface,
                           CallbackAgentResponse callbackAgentResponse,
                           CallbackAgentResponseCorrection callbackAgentResponseCorrection,
                           CallbackUserTranscript callbackUserTranscript,
                           CallbackLatencyMeasurement callbackLatencyMeasurement)
    : agentId_(agentId)
    , requiresAuth_(requiresAuth)
    , audioInterface_(std::move(audioInterface))
    , callbackAgentResponse_(std::move(callbackAgentResponse))
    , callbackAgentResponseCorrection_(std::move(callbackAgentResponseCorrection))
    , callbackUserTranscript_(std::move(callbackUserTranscript))
    , callbackLatencyMeasurement_(std::move(callbackLatencyMeasurement))
{
    sslCtx_.set_default_verify_paths();
}
|
|
|
|
// Tears the session down on destruction by delegating to endSession().
Conversation::~Conversation()
{
    endSession();
}
|
|
|
|
// Starts the conversation by launching run() on a background worker thread.
//
// Throws std::runtime_error if a worker thread from an earlier startSession()
// has not been joined yet. Assigning a new std::thread over a joinable one
// would call std::terminate, so a double start is rejected explicitly. To
// start again after a session has finished, call endSession() or
// waitForSessionEnd() first (both join the worker).
void Conversation::startSession() {
    if (workerThread_.joinable()) {
        throw std::runtime_error("Session already started");
    }
    shouldStop_.store(false);
    workerThread_ = std::thread(&Conversation::run, this);
}
|
|
|
|
void Conversation::endSession() {
|
|
shouldStop_.store(true);
|
|
if (ws_) {
|
|
beast::error_code ec;
|
|
ws_->close(websocket::close_code::normal, ec);
|
|
}
|
|
if (audioInterface_) {
|
|
audioInterface_->stop();
|
|
}
|
|
if (workerThread_.joinable()) {
|
|
workerThread_.join();
|
|
}
|
|
}
|
|
|
|
std::string Conversation::waitForSessionEnd() {
|
|
if (workerThread_.joinable()) {
|
|
workerThread_.join();
|
|
}
|
|
return conversationId_;
|
|
}
|
|
|
|
void Conversation::sendUserMessage(const std::string& text) {
|
|
if (!ws_) {
|
|
throw std::runtime_error("Session not started");
|
|
}
|
|
nlohmann::json j = {
|
|
{"type", "user_message"},
|
|
{"text", text}
|
|
};
|
|
ws_->write(boost::asio::buffer(j.dump()));
|
|
}
|
|
|
|
void Conversation::registerUserActivity() {
|
|
if (!ws_) throw std::runtime_error("Session not started");
|
|
nlohmann::json j = {{"type", "user_activity"}};
|
|
ws_->write(boost::asio::buffer(j.dump()));
|
|
}
|
|
|
|
void Conversation::sendContextualUpdate(const std::string& content) {
|
|
if (!ws_) throw std::runtime_error("Session not started");
|
|
nlohmann::json j = {{"type", "contextual_update"}, {"content", content}};
|
|
ws_->write(boost::asio::buffer(j.dump()));
|
|
}
|
|
|
|
// Builds the websocket URL for this agent by appending the agent id to a
// fixed endpoint.
std::string Conversation::getWssUrl() const {
    // Hard-coded base env for demo; in production you'd call the ElevenLabs env endpoint.
    std::string url = "wss://api.elevenlabs.io/v1/convai/conversation?agent_id=";
    url += agentId_;
    return url;
}
|
|
|
|
void Conversation::run() {
|
|
try {
|
|
auto url = getWssUrl();
|
|
std::string protocol, host, target;
|
|
unsigned short port = 443;
|
|
|
|
// Very naive parse: wss://host[:port]/path?query
|
|
if (boost::starts_with(url, "wss://")) {
|
|
protocol = "wss";
|
|
host = url.substr(6);
|
|
} else {
|
|
throw std::runtime_error("Only wss:// URLs supported in this demo");
|
|
}
|
|
auto slashPos = host.find('/');
|
|
if (slashPos == std::string::npos) {
|
|
target = "/";
|
|
} else {
|
|
target = host.substr(slashPos);
|
|
host = host.substr(0, slashPos);
|
|
}
|
|
auto colonPos = host.find(':');
|
|
if (colonPos != std::string::npos) {
|
|
port = static_cast<unsigned short>(std::stoi(host.substr(colonPos + 1)));
|
|
host = host.substr(0, colonPos);
|
|
}
|
|
|
|
tcp::resolver resolver(ioc_);
|
|
auto const results = resolver.resolve(host, std::to_string(port));
|
|
|
|
beast::ssl_stream<tcp::socket> stream(ioc_, sslCtx_);
|
|
boost::asio::connect(beast::get_lowest_layer(stream), results);
|
|
if (!SSL_set_tlsext_host_name(stream.native_handle(), host.c_str())) {
|
|
throw std::runtime_error("Failed to set SNI hostname on SSL stream");
|
|
}
|
|
stream.handshake(ssl::stream_base::client);
|
|
|
|
ws_ = std::make_unique<websocket_t>(std::move(stream));
|
|
ws_->set_option(websocket::stream_base::timeout::suggested(beast::role_type::client));
|
|
ws_->handshake(host, target);
|
|
|
|
// send initiation data
|
|
nlohmann::json init = {
|
|
{"type", "conversation_initiation_client_data"},
|
|
{"custom_llm_extra_body", nlohmann::json::object()},
|
|
{"conversation_config_override", nlohmann::json::object()},
|
|
{"dynamic_variables", nlohmann::json::object()}
|
|
};
|
|
ws_->write(boost::asio::buffer(init.dump()));
|
|
|
|
// Prepare audio callback
|
|
auto inputCb = [this](const std::vector<char>& audio) {
|
|
nlohmann::json msg = {
|
|
{"user_audio_chunk", base64Encode(audio)}
|
|
};
|
|
ws_->write(boost::asio::buffer(msg.dump()));
|
|
};
|
|
audioInterface_->start(inputCb);
|
|
|
|
beast::flat_buffer buffer;
|
|
while (!shouldStop_.load()) {
|
|
beast::error_code ec;
|
|
ws_->read(buffer, ec);
|
|
if (ec) {
|
|
std::cerr << "Websocket read error: " << ec.message() << std::endl;
|
|
break;
|
|
}
|
|
auto text = beast::buffers_to_string(buffer.data());
|
|
buffer.consume(buffer.size());
|
|
try {
|
|
auto message = nlohmann::json::parse(text);
|
|
handleMessage(message);
|
|
} catch (const std::exception& ex) {
|
|
std::cerr << "JSON parse error: " << ex.what() << std::endl;
|
|
}
|
|
}
|
|
} catch (const std::exception& ex) {
|
|
std::cerr << "Conversation error: " << ex.what() << std::endl;
|
|
}
|
|
}
|
|
|
|
// Dispatches one parsed server message according to its "type" field.
//
//   conversation_initiation_metadata -> stores the conversation id
//   audio                            -> decodes and plays the chunk unless its
//                                       event id is <= the last interruption id
//   agent_response                   -> callbackAgentResponse_ (if set)
//   agent_response_correction        -> callbackAgentResponseCorrection_ (if set)
//   user_transcript                  -> callbackUserTranscript_ (if set)
//   interruption                     -> records the event id, interrupts audio
//   ping                             -> replies with a pong echoing event_id and
//                                       optionally reports ping_ms
// Unknown types are ignored.
//
// Required fields are fetched with at(), so a message that lacks one raises a
// nlohmann::json exception (a std::exception) instead of hitting the undefined
// behaviour of const operator[] on a missing key. Event sub-objects are bound
// by const reference to avoid deep-copying them (audio payloads can be large).
void Conversation::handleMessage(const nlohmann::json& message) {
    const std::string type = message.value("type", "");

    if (type == "conversation_initiation_metadata") {
        conversationId_ = message.at("conversation_initiation_metadata_event")
                              .at("conversation_id")
                              .get<std::string>();
    } else if (type == "audio") {
        const auto& event = message.at("audio_event");
        const int eventId = std::stoi(toString(event.at("event_id")));
        // Drop audio that was generated before the most recent interruption.
        if (eventId <= lastInterruptId_.load()) return;
        auto audioBytes = base64Decode(event.at("audio_base_64").get_ref<const std::string&>());
        audioInterface_->output(audioBytes);
    } else if (type == "agent_response" && callbackAgentResponse_) {
        const auto& event = message.at("agent_response_event");
        callbackAgentResponse_(event.at("agent_response").get<std::string>());
    } else if (type == "agent_response_correction" && callbackAgentResponseCorrection_) {
        const auto& event = message.at("agent_response_correction_event");
        callbackAgentResponseCorrection_(event.at("original_agent_response").get<std::string>(),
                                         event.at("corrected_agent_response").get<std::string>());
    } else if (type == "user_transcript" && callbackUserTranscript_) {
        const auto& event = message.at("user_transcription_event");
        callbackUserTranscript_(event.at("user_transcript").get<std::string>());
    } else if (type == "interruption") {
        const auto& event = message.at("interruption_event");
        lastInterruptId_.store(std::stoi(toString(event.at("event_id"))));
        audioInterface_->interrupt();
    } else if (type == "ping") {
        const auto& event = message.at("ping_event");

        nlohmann::json pong = nlohmann::json::object();
        pong["type"] = "pong";
        pong["event_id"] = event.at("event_id");
        ws_->write(boost::asio::buffer(pong.dump()));

        if (callbackLatencyMeasurement_ && event.contains("ping_ms")) {
            const auto& pingMs = event.at("ping_ms");
            // Accept a numeric or a string-encoded value; any other JSON type
            // (e.g. null) carries no measurement and is skipped rather than
            // raising after the pong has already been sent.
            if (pingMs.is_number()) {
                callbackLatencyMeasurement_(pingMs.get<int>());
            } else if (pingMs.is_string()) {
                callbackLatencyMeasurement_(std::stoi(pingMs.get<std::string>()));
            }
        }
    }
    // Note: client tool call handling omitted for brevity.
}
|