MultiplexedSocket: use incrementing channel number
'Birthday paradox' implies that if the two peers create channels
at the same time, there is a non-zero chance of collision even for
small amount of channels, because channel ids are 16 bits:
* >50% probability of collision with N=256
* >1% probability of collision with N=40
* 0.2% probability of collision with N=16
For this reason, use incrementing channel numbers,
but start at 0x8000 for the server and at 1 for the client.
This method guarantees 0-collision for
at least the first 2^15-1 channels opened each side,
and reduces the probability of collision
after that.
Change-Id: I913cf2962a1fc577a6d24b9184a25d75d2473574
diff --git a/src/multiplexed_socket.cpp b/src/multiplexed_socket.cpp
index b34bf07..81a777c 100644
--- a/src/multiplexed_socket.cpp
+++ b/src/multiplexed_socket.cpp
@@ -60,12 +60,13 @@
Impl(MultiplexedSocket& parent,
std::shared_ptr<asio::io_context> ctx,
const DeviceId& deviceId,
- std::unique_ptr<TlsSocketEndpoint> endpoint,
+ std::unique_ptr<TlsSocketEndpoint> ep,
std::shared_ptr<dht::log::Logger> logger)
: parent_(parent)
, ctx_(std::move(ctx))
, deviceId(deviceId)
- , endpoint(std::move(endpoint))
+ , endpoint(std::move(ep))
+ , nextChannel_(endpoint->isInitiator() ? 0x0001u : 0x8000u)
, eventLoopThread_ {[this] {
try {
eventLoop();
@@ -193,6 +194,7 @@
std::mutex socketsMutex {};
std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
+ uint16_t nextChannel_;
// Main loop to parse incoming packets
std::atomic_bool stop {false};
@@ -551,20 +553,16 @@
std::shared_ptr<ChannelSocket>
MultiplexedSocket::addChannel(const std::string& name)
{
- // Note: because both sides can request the same channel number at the same time
- // it's better to use a random channel number instead of just incrementing the request.
- thread_local dht::crypto::random_device rd;
- std::uniform_int_distribution<uint16_t> dist;
- auto offset = dist(rd);
std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
- for (int i = 1; i < UINT16_MAX; ++i) {
- auto c = (offset + i) % UINT16_MAX;
- if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL
- || pimpl_->sockets.find(c) != pimpl_->sockets.end())
- continue;
- auto channel = pimpl_->makeSocket(name, c, true);
- return channel;
- }
+ if (pimpl_->sockets.size() < UINT16_MAX)
+ for (unsigned i = 0; i < UINT16_MAX; ++i) {
+ auto c = pimpl_->nextChannel_++;
+ if (c == CONTROL_CHANNEL
+ || c == PROTOCOL_CHANNEL
+ || pimpl_->sockets.find(c) != pimpl_->sockets.end())
+ continue;
+ return pimpl_->makeSocket(name, c, true);
+ }
return {};
}