build: add benchmark
Change-Id: I42977a327869e1d545c140361a9a155ebf13de9e
diff --git a/tools/benchmark/main.cpp b/tools/benchmark/main.cpp
new file mode 100644
index 0000000..0a9e8d7
--- /dev/null
+++ b/tools/benchmark/main.cpp
@@ -0,0 +1,229 @@
+#include "../common.h"
+
+#include "connectionmanager.h"
+#include "multiplexed_socket.h"
+#include "certstore.h"
+
+#include <opendht/log.h>
+#include <opendht/utils.h>
+#include <opendht/thread_pool.h>
+
+#include <asio/executor_work_guard.hpp>
+#include <asio/io_context.hpp>
+
+namespace dhtnet {
+using namespace std::literals::chrono_literals;
+using clock = std::chrono::high_resolution_clock;
+using time_point = clock::time_point;
+using duration = clock::duration;
+
+struct ConnectionHandler // everything one benchmark peer needs; filled in by setupHandler()
+{
+ dht::crypto::Identity id; // identity generated for this peer
+ std::shared_ptr<Logger> logger; // may be null (bench() passes a null logger)
+ std::shared_ptr<tls::CertificateStore> certStore; // certificate store given to the DHT runner and the connection manager
+ std::shared_ptr<dht::DhtRunner> dht; // DHT node started and bootstrapped by setupHandler()
+ std::shared_ptr<ConnectionManager> connectionManager; // configured with dht, id, ioContext and certStore
+ std::shared_ptr<asio::io_context> ioContext; // io_context supplied by the caller
+ std::shared_ptr<std::thread> ioContextRunner; // thread handle supplied by the caller (created and joined by bench())
+};
+
+/// Create a fully wired peer: fresh identity, certificate store, running DHT node and a
+/// ConnectionManager bound to the caller's io_context.
+/// `factory` stays owned by the caller; only a raw pointer to it is stored in the config.
+std::unique_ptr<ConnectionHandler>
+setupHandler(const std::string& name,
+             std::shared_ptr<asio::io_context> ioContext,
+             std::shared_ptr<std::thread> ioContextRunner,
+             std::unique_ptr<IceTransportFactory>& factory,
+             std::shared_ptr<Logger> logger)
+{
+    auto h = std::make_unique<ConnectionHandler>();
+    // Every call creates its own CA identity and signs the peer identity with it.
+    auto ca = dht::crypto::generateIdentity("ca");
+    h->id = dht::crypto::generateIdentity(name, ca);
+    h->logger = logger;
+    h->certStore = std::make_shared<tls::CertificateStore>(name, h->logger);
+    // Use the caller's io_context directly; no private one is allocated.
+    h->ioContext = std::move(ioContext);
+    h->ioContextRunner = std::move(ioContextRunner);
+
+    dht::DhtRunner::Config dhtConfig;
+    dhtConfig.dht_config.id = h->id;
+    dhtConfig.threaded = true;
+
+    dht::DhtRunner::Context dhtContext;
+    dhtContext.certificateStore = [c = h->certStore](const dht::InfoHash& pk_id) {
+        std::vector<std::shared_ptr<dht::crypto::Certificate>> ret;
+        if (auto cert = c->getCertificate(pk_id.toString()))
+            ret.emplace_back(std::move(cert));
+        return ret;
+    };
+
+    h->dht = std::make_shared<dht::DhtRunner>();
+    h->dht->run(dhtConfig, std::move(dhtContext));
+    // Public bootstrap node; a local node (e.g. "127.0.0.1:36432") can be used instead.
+    h->dht->bootstrap("bootstrap.jami.net");
+
+    auto config = std::make_shared<ConnectionManager::Config>();
+    config->dht = h->dht;
+    config->id = h->id;
+    config->ioContext = h->ioContext;
+    config->factory = factory.get();
+    config->logger = logger;
+    config->certStore = h->certStore.get();
+    config->cachePath = (std::filesystem::current_path() / "temp").string();
+
+    h->connectionManager = std::make_shared<ConnectionManager>(config);
+    // Accept every incoming ICE request.
+    h->connectionManager->onICERequest([](const DeviceId&) { return true; });
+    return h;
+}
+
+struct BenchResult { // timings of one runBench() iteration; zero-initialised so bench() can accumulate them
+    duration connection {}; ///< from the connectDevice() call to its connection callback
+    duration send {};       ///< from the first write until the last echoed byte was received
+    bool success {false};   ///< result of the run, as set by runBench()
+};
+
+/// Run one benchmark iteration: create a fresh server and client peer, connect them,
+/// stream TX_GOAL bytes that the server echoes back, and time both phases.
+BenchResult
+runBench(std::shared_ptr<asio::io_context> ioContext,
+         std::shared_ptr<std::thread> ioContextRunner,
+         std::unique_ptr<IceTransportFactory>& factory,
+         std::shared_ptr<Logger> logger)
+{
+    constexpr size_t TX_SIZE = 64 * 1024;
+    constexpr size_t TX_NUM = 1024;
+    constexpr size_t TX_GOAL = TX_SIZE * TX_NUM;
+    // Shared with the callbacks below: declared before the peers so it outlives them; guarded by mtx.
+    BenchResult ret {};
+    std::mutex mtx;
+    std::condition_variable cv;
+    bool failed = false;
+    size_t rx = 0;
+    time_point start_connect, start_send;
+
+    fmt::println("Generating identities…");
+    auto server = setupHandler("server", ioContext, ioContextRunner, factory, logger);
+    auto client = setupHandler("client", ioContext, ioContextRunner, factory, logger);
+    client->connectionManager->onDhtConnected(client->id.first->getPublicKey());
+    server->connectionManager->onDhtConnected(server->id.first->getPublicKey());
+
+    server->connectionManager->onChannelRequest(
+        [](const std::shared_ptr<dht::crypto::Certificate>&, const std::string& name) { return name == "channelName"; });
+    server->connectionManager->onConnectionReady([](const DeviceId&, const std::string&, std::shared_ptr<ChannelSocket> socket) {
+        fmt::println("Server: Connection {}", socket ? "succeeded" : "failed");
+        if (socket)
+            socket->setOnRecv([s = socket.get()](const uint8_t* data, size_t size) { // echo everything back
+                std::error_code ec;
+                return s->write(data, size, ec);
+            });
+    });
+
+    std::this_thread::sleep_for(5s);
+    fmt::println("Connecting…");
+    start_connect = clock::now();
+    client->connectionManager->connectDevice(server->id.second, "channelName", [&](std::shared_ptr<ChannelSocket> socket, const DeviceId&) {
+        if (!socket) {
+            std::lock_guard<std::mutex> lk {mtx};
+            failed = true;
+            cv.notify_one(); // wake the waiter now instead of letting it run into the timeout
+            return;
+        }
+        socket->setOnRecv([&](const uint8_t*, size_t size) {
+            std::lock_guard<std::mutex> lk {mtx};
+            rx += size;
+            if (rx == TX_GOAL) {
+                ret.send = clock::now() - start_send;
+                fmt::println("Streamed {} bytes back and forth in {} ({} kBps)", rx, dht::print_duration(ret.send), (unsigned)(rx / (1000 * std::chrono::duration<double>(ret.send).count())));
+                cv.notify_one();
+            }
+            return size;
+        });
+        const auto connected_in = clock::now() - start_connect;
+        fmt::println("Connected in {}", dht::print_duration(connected_in));
+        std::vector<uint8_t> data(TX_SIZE, 'y');
+        std::unique_lock<std::mutex> lk {mtx};
+        ret.connection = connected_in;
+        start_send = clock::now();
+        lk.unlock(); // never hold mtx while writing: the receive callback needs it
+        std::error_code ec;
+        for (unsigned i = 0; i < TX_NUM; ++i) {
+            socket->write(data.data(), data.size(), ec);
+            if (ec)
+                fmt::println("error: {}", ec.message());
+        }
+    });
+
+    std::unique_lock<std::mutex> lock {mtx};
+    cv.wait_for(lock, 60s, [&] { return failed or rx == TX_GOAL; });
+    ret.success = rx == TX_GOAL; // a failed connection or a timeout is not a success
+    BenchResult out = ret;       // snapshot under the lock: late callbacks may still run
+    lock.unlock();
+    std::this_thread::sleep_for(500ms);
+    return out;
+}
+
+
+/// Run ITERATIONS benchmark rounds on one shared io_context thread and print the averages of the successful ones.
+void
+bench()
+{
+    std::shared_ptr<Logger> logger; // null logger; assign dht::log::getStdLogger() to get log output
+    auto factory = std::make_unique<IceTransportFactory>(logger);
+    auto ioContext = std::make_shared<asio::io_context>();
+    auto ioContextRunner = std::make_shared<std::thread>([context = ioContext]() {
+        try {
+            auto work = asio::make_work_guard(*context); // keeps run() from returning when idle
+            context->run();
+        } catch (const std::exception& ex) {
+            fmt::println(stderr, "Exception: {}", ex.what());
+        }
+    });
+    BenchResult total {}; // value-initialised: the durations are accumulated below
+    unsigned total_success = 0;
+    constexpr unsigned ITERATIONS = 20;
+    for (unsigned i = 0; i < ITERATIONS; ++i) {
+        fmt::println("Iteration {}", i);
+        const auto res = runBench(ioContext, ioContextRunner, factory, logger);
+        if (res.success) {
+            total.connection += res.connection;
+            total.send += res.send;
+            ++total_success;
+        }
+    }
+    if (total_success != 0) { // no successful run: nothing to average (and no division by zero)
+        fmt::println("Average connection time: {}", dht::print_duration(total.connection / total_success));
+        fmt::println("Average send time: {}", dht::print_duration(total.send / total_success));
+    }
+    fmt::println("Total success: {}", total_success);
+    std::this_thread::sleep_for(500ms);
+    ioContext->stop();
+    ioContextRunner->join();
+}
+
+}
+
+/// Set the PJSIP (pj_log) level from the SIPLOGLEVEL environment variable (default 0) and
+/// install a log callback that discards every message.
+static void
+setSipLogLevel()
+{
+    int level = 0;
+    // strtol never throws: a non-numeric value yields 0 and an out-of-range one saturates.
+    if (const char* envvar = getenv("SIPLOGLEVEL"))
+        level = static_cast<int>(std::clamp(strtol(envvar, nullptr, 10), 0L, 6L)); // from 0 (min) to 6 (max)
+
+    pj_log_set_level(level);
+    pj_log_set_log_func([](int /*level*/, const char* /*data*/, int /*len*/) {});
+}
+
+int
+main(int /*argc*/, char** /*argv*/) // command-line arguments are not used
+{
+    setSipLogLevel(); // apply the SIPLOGLEVEL-driven log level first
+    dhtnet::bench();  // run the whole benchmark and print its summary
+    return 0;
+}
\ No newline at end of file
diff --git a/tools/common.h b/tools/common.h
index 1181ee0..94d06d2 100644
--- a/tools/common.h
+++ b/tools/common.h
@@ -15,7 +15,10 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include <opendht/crypto.h>
-
+#include "connectionmanager.h"
+#include "multiplexed_socket.h"
+#include "ice_transport_factory.h"
+#include "certstore.h"
namespace dhtnet {
diff --git a/tools/dnc/main.cpp b/tools/dnc/main.cpp
index 929bcc1..278990b 100644
--- a/tools/dnc/main.cpp
+++ b/tools/dnc/main.cpp
@@ -113,15 +113,10 @@
static void
setSipLogLevel()
{
- char* envvar = getenv("SIPLOGLEVEL");
-
int level = 0;
-
- if (envvar != nullptr) {
- level = std::stoi(envvar);
-
+ if (char* envvar = getenv("SIPLOGLEVEL")) {
// From 0 (min) to 6 (max)
- level = std::max(0, std::min(level, 6));
+ level = std::clamp(std::stoi(envvar), 0, 6);
}
pj_log_set_level(level);