ConnectionManager: refactor to reduce contention
* merge the info, connecting and waiting maps into a per-device DeviceInfo,
improving consistency of various operations
* DeviceInfo has its own mutex,
and operations keep a weak pointer to it,
reducing contention
* DeviceInfoSet encapsulates atomic management of DeviceInfos,
enforcing proper locking order
Change-Id: I12c107782600355a34d460e7802d92ae9d590993
diff --git a/src/connectionmanager.cpp b/src/connectionmanager.cpp
index cf78c54..0304c21 100644
--- a/src/connectionmanager.cpp
+++ b/src/connectionmanager.cpp
@@ -40,14 +40,14 @@
static constexpr uint64_t ID_MAX_VAL = 9007199254740992;
using ValueIdDist = std::uniform_int_distribution<dht::Value::Id>;
-using CallbackId = std::pair<dhtnet::DeviceId, dht::Value::Id>;
+
std::string
callbackIdToString(const dhtnet::DeviceId& did, const dht::Value::Id& vid)
{
return fmt::format("{} {}", did.to_view(), vid);
}
-CallbackId parseCallbackId(std::string_view ci)
+std::pair<dhtnet::DeviceId, dht::Value::Id> parseCallbackId(std::string_view ci)
{
auto sep = ci.find(' ');
std::string_view deviceIdString = ci.substr(0, sep);
@@ -55,8 +55,7 @@
dhtnet::DeviceId deviceId(deviceIdString);
dht::Value::Id vid = std::stoul(std::string(vidString), nullptr, 10);
-
- return CallbackId(deviceId, vid);
+ return {deviceId, vid};
}
std::shared_ptr<ConnectionManager::Config>
@@ -101,12 +100,252 @@
// Used to store currently non ready TLS Socket
std::unique_ptr<TlsSocketEndpoint> tls_ {nullptr};
std::shared_ptr<MultiplexedSocket> socket_ {};
- std::set<CallbackId> cbIds_ {};
+ std::set<dht::Value::Id> cbIds_ {};
std::function<void(bool)> onConnected_;
std::unique_ptr<asio::steady_timer> waitForAnswer_ {};
+
+ void shutdown() {
+ std::lock_guard<std::mutex> lk(mutex_);
+ if (tls_)
+ tls_->shutdown();
+ if (socket_)
+ socket_->shutdown();
+ if (waitForAnswer_)
+ waitForAnswer_->cancel();
+ if (ice_) {
+ dht::ThreadPool::io().run(
+ [ice = std::shared_ptr<IceTransport>(std::move(ice_))] {});
+ }
+ }
+
+ std::map<std::string, std::string>
+ getInfo(const DeviceId& deviceId, dht::Value::Id valueId, tls::CertificateStore& certStore) const
+ {
+ std::map<std::string, std::string> connectionInfo;
+ connectionInfo["id"] = callbackIdToString(deviceId, valueId);
+ connectionInfo["device"] = deviceId.toString();
+ auto cert = tls_ ? tls_->peerCertificate() : (socket_ ? socket_->peerCertificate() : nullptr);
+ if (not cert)
+ cert = certStore.getCertificate(deviceId.toString());
+ if (cert) {
+ connectionInfo["peer"] = cert->issuer->getId().toString();
+ }
+ if (socket_) {
+ connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Connected));
+ connectionInfo["remoteAddress"] = socket_->getRemoteAddress();
+ } else if (tls_) {
+ connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::TLS));
+ connectionInfo["remoteAddress"] = tls_->getRemoteAddress();
+ } else if(ice_) {
+ connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::ICE));
+ connectionInfo["remoteAddress"] = ice_->getRemoteAddress(ICE_COMP_ID_SIP_TRANSPORT);
+ }
+ return connectionInfo;
+ }
};
+struct PendingCb {
+ std::string name;
+ ConnectCallback cb;
+};
+
+struct DeviceInfo {
+ const DeviceId deviceId;
+ mutable std::mutex mtx_ {};
+ std::map<dht::Value::Id, std::shared_ptr<ConnectionInfo>> info;
+ std::map<dht::Value::Id, PendingCb> connecting;
+ std::map<dht::Value::Id, PendingCb> waiting;
+ DeviceInfo(DeviceId id) : deviceId {id} {}
+
+ inline bool isConnecting() const {
+ return !connecting.empty() || !waiting.empty();
+ }
+
+ inline bool empty() const {
+ return info.empty() && connecting.empty() && waiting.empty();
+ }
+
+ dht::Value::Id newId(std::mt19937_64& rand) const {
+ ValueIdDist dist(1, ID_MAX_VAL);
+ dht::Value::Id id;
+ do {
+ id = dist(rand);
+ } while (info.find(id) != info.end()
+ || connecting.find(id) != connecting.end()
+ || waiting.find(id) != waiting.end());
+ return id;
+ }
+
+ std::shared_ptr<ConnectionInfo> getConnectedInfo() const {
+ for (auto& [id, ci] : info) {
+ if (ci->socket_)
+ return ci;
+ }
+ return {};
+ }
+
+ std::vector<PendingCb> extractPendingOperations(dht::Value::Id vid, const std::shared_ptr<ChannelSocket>& sock, bool accepted = true)
+ {
+ std::vector<PendingCb> ret;
+ if (vid == 0) {
+ // Extract all pending callbacks
+ ret.reserve(connecting.size() + waiting.size());
+ for (auto& [vid, cb] : connecting)
+ ret.emplace_back(std::move(cb));
+ connecting.clear();
+ for (auto& [vid, cb] : waiting)
+ ret.emplace_back(std::move(cb));
+ waiting.clear();
+ } else if (auto n = waiting.extract(vid)) {
+ // If it's a waiting operation, just move it
+ ret.emplace_back(std::move(n.mapped()));
+ } else if (auto n = connecting.extract(vid)) {
+ ret.emplace_back(std::move(n.mapped()));
+ // If sock is nullptr, execute if it's the last connecting operation
+ // If accepted is false, it means that underlying socket is ok, but channel is declined
+ if (!sock && connecting.empty() && accepted) {
+ for (auto& [vid, cb] : waiting)
+ ret.emplace_back(std::move(cb));
+ waiting.clear();
+ for (auto& [vid, cb] : connecting)
+ ret.emplace_back(std::move(cb));
+ connecting.clear();
+ }
+ }
+ return ret;
+ }
+
+ std::vector<std::shared_ptr<ConnectionInfo>> extractUnusedConnections() { // moves every ConnectionInfo out of `info`, leaving the map empty
+ std::vector<std::shared_ptr<ConnectionInfo>> unused {};
+ for (auto& [id, ci] : info) // bind as `ci`, not `info`, so the member map being iterated is not shadowed
+ unused.emplace_back(std::move(ci));
+ info.clear();
+ return unused;
+ }
+
+ void executePendingOperations(std::unique_lock<std::mutex>& lock, dht::Value::Id vid, const std::shared_ptr<ChannelSocket>& sock, bool accepted = true) {
+ auto ops = extractPendingOperations(vid, sock, accepted);
+ lock.unlock();
+ for (auto& cb : ops)
+ cb.cb(sock, deviceId);
+ }
+ void executePendingOperations(dht::Value::Id vid, const std::shared_ptr<ChannelSocket>& sock, bool accepted = true) {
+ std::unique_lock<std::mutex> lock(mtx_);
+ executePendingOperations(lock, vid, sock, accepted);
+ }
+
+ std::map<dht::Value::Id, std::string> getPendingIds() const {
+ std::map<dht::Value::Id, std::string> ret;
+ for (const auto& [id, pc]: connecting)
+ ret[id] = pc.name;
+ for (const auto& [id, pc]: waiting)
+ ret[id] = pc.name;
+ return ret;
+ }
+
+ std::vector<std::map<std::string, std::string>>
+ getConnectionList(tls::CertificateStore& certStore) const {
+ std::lock_guard<std::mutex> lk(mtx_);
+ std::vector<std::map<std::string, std::string>> ret;
+ ret.reserve(info.size());
+ for (auto& [id, ci] : info) {
+ std::lock_guard<std::mutex> lk(ci->mutex_);
+ ret.emplace_back(ci->getInfo(deviceId, id, certStore));
+ }
+ auto cert = certStore.getCertificate(deviceId.toString());
+ for (const auto& [vid, ci] : connecting) {
+ ret.emplace_back(std::map<std::string, std::string> {
+ {"id", callbackIdToString(deviceId, vid)},
+ {"status", std::to_string(static_cast<int>(ConnectionStatus::Connecting))},
+ {"device", deviceId.toString()},
+ {"peer", cert ? cert->issuer->getId().toString() : ""}
+ });
+ }
+ for (const auto& [vid, ci] : waiting) {
+ ret.emplace_back(std::map<std::string, std::string> {
+ {"id", callbackIdToString(deviceId, vid)},
+ {"status", std::to_string(static_cast<int>(ConnectionStatus::Waiting))},
+ {"device", deviceId.toString()},
+ {"peer", cert ? cert->issuer->getId().toString() : ""}
+ });
+ }
+ return ret;
+ }
+};
+
+class DeviceInfoSet {
+public:
+ std::shared_ptr<DeviceInfo> getDeviceInfo(const DeviceId& deviceId) {
+ std::lock_guard<std::mutex> lk(mtx_);
+ auto it = infos_.find(deviceId);
+ if (it != infos_.end())
+ return it->second;
+ return {};
+ }
+
+ std::vector<std::shared_ptr<DeviceInfo>> getDeviceInfos() {
+ std::vector<std::shared_ptr<DeviceInfo>> deviceInfos;
+ std::lock_guard<std::mutex> lk(mtx_);
+ deviceInfos.reserve(infos_.size());
+ for (auto& [deviceId, info] : infos_)
+ deviceInfos.emplace_back(info);
+ return deviceInfos;
+ }
+
+ std::shared_ptr<DeviceInfo> createDeviceInfo(const DeviceId& deviceId) {
+ std::lock_guard<std::mutex> lk(mtx_);
+ auto& info = infos_[deviceId];
+ if (!info)
+ info = std::make_shared<DeviceInfo>(deviceId);
+ return info;
+ }
+
+ bool removeDeviceInfo(const DeviceId& deviceId) {
+ std::lock_guard<std::mutex> lk(mtx_);
+ return infos_.erase(deviceId) != 0;
+ }
+
+ std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, const dht::Value::Id& id) {
+ if (auto info = getDeviceInfo(deviceId)) {
+ std::lock_guard<std::mutex> lk(info->mtx_);
+ auto it = info->info.find(id);
+ if (it != info->info.end())
+ return it->second;
+ }
+ return {};
+ }
+
+ std::vector<std::shared_ptr<ConnectionInfo>> getConnectedInfos() {
+ auto deviceInfos = getDeviceInfos();
+ std::vector<std::shared_ptr<ConnectionInfo>> ret;
+ ret.reserve(deviceInfos.size());
+ for (auto& info : deviceInfos) {
+ std::lock_guard<std::mutex> lk(info->mtx_);
+ for (auto& [id, ci] : info->info) {
+ if (ci->socket_)
+ ret.emplace_back(ci);
+ }
+ }
+ return ret;
+ }
+ std::vector<std::shared_ptr<DeviceInfo>> shutdown() {
+ std::vector<std::shared_ptr<DeviceInfo>> ret;
+ std::lock_guard<std::mutex> lk(mtx_);
+ ret.reserve(infos_.size());
+ for (auto& [deviceId, info] : infos_) {
+ ret.emplace_back(std::move(info));
+ }
+ infos_.clear();
+ return ret;
+ }
+
+private:
+ std::mutex mtx_ {};
+ std::map<DeviceId, std::shared_ptr<DeviceInfo>> infos_ {};
+};
+
+
/**
* returns whether or not UPnP is enabled and active_
* ie: if it is able to make port mappings
@@ -124,7 +363,7 @@
public:
explicit Impl(std::shared_ptr<ConnectionManager::Config> config_)
: config_ {std::move(createConfig(config_))}
- , rand {dht::crypto::getSeededRandomEngine<std::mt19937_64>()}
+ , rand_ {dht::crypto::getSeededRandomEngine<std::mt19937_64>()}
{
loadTreatedMessages();
if(!config_->ioContext) {
@@ -151,61 +390,40 @@
std::shared_ptr<dht::DhtRunner> dht() { return config_->dht; }
const dht::crypto::Identity& identity() const { return config_->id; }
- void removeUnusedConnections(const DeviceId& deviceId = {})
+ void shutdown()
{
- std::vector<std::shared_ptr<ConnectionInfo>> unused {};
-
- {
- std::lock_guard<std::mutex> lk(infosMtx_);
- for (auto it = infos_.begin(); it != infos_.end();) {
- auto& [key, info] = *it;
- if (info && (!deviceId || key.first == deviceId)) {
- unused.emplace_back(std::move(info));
- it = infos_.erase(it);
- } else {
- ++it;
- }
- }
+ if (isDestroying_.exchange(true))
+ return;
+ std::vector<std::shared_ptr<ConnectionInfo>> unused;
+ std::vector<std::pair<DeviceId, std::vector<PendingCb>>> pending;
+ for (auto& dinfo: infos_.shutdown()) {
+ std::lock_guard<std::mutex> lk(dinfo->mtx_);
+ auto p = dinfo->extractPendingOperations(0, nullptr, false);
+ if (!p.empty())
+ pending.emplace_back(dinfo->deviceId, std::move(p));
+ auto uc = dinfo->extractUnusedConnections();
+ unused.insert(unused.end(), std::make_move_iterator(uc.begin()), std::make_move_iterator(uc.end()));
}
- for (auto& info: unused) {
- if (info->tls_)
- info->tls_->shutdown();
- if (info->socket_)
- info->socket_->shutdown();
- if (info->waitForAnswer_)
- info->waitForAnswer_->cancel();
- }
+ for (auto& info: unused)
+ info->shutdown();
+ for (auto& op: pending)
+ for (auto& cb: op.second)
+ cb.cb(nullptr, op.first);
if (!unused.empty())
dht::ThreadPool::io().run([infos = std::move(unused)]() mutable {
infos.clear();
});
}
- void shutdown()
- {
- if (isDestroying_.exchange(true))
- return;
- decltype(pendingOperations_) po;
- {
- std::lock_guard<std::mutex> lk(connectCbsMtx_);
- po = std::move(pendingOperations_);
- }
- for (auto& [deviceId, pcbs] : po) {
- for (auto& [id, pending] : pcbs.connecting)
- pending.cb(nullptr, deviceId);
- for (auto& [id, pending] : pcbs.waiting)
- pending.cb(nullptr, deviceId);
- }
-
- removeUnusedConnections();
- }
-
- void connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk,
+ void connectDeviceStartIce(const std::shared_ptr<ConnectionInfo>& info,
+ const std::shared_ptr<dht::crypto::PublicKey>& devicePk,
const dht::Value::Id& vid,
const std::string& connType,
std::function<void(bool)> onConnected);
- void onResponse(const asio::error_code& ec, const DeviceId& deviceId, const dht::Value::Id& vid);
- bool connectDeviceOnNegoDone(const DeviceId& deviceId,
+ void onResponse(const asio::error_code& ec, const std::weak_ptr<ConnectionInfo>& info, const DeviceId& deviceId, const dht::Value::Id& vid);
+ bool connectDeviceOnNegoDone(const std::weak_ptr<DeviceInfo>& dinfo,
+ const std::shared_ptr<ConnectionInfo>& info,
+ const DeviceId& deviceId,
const std::string& name,
const dht::Value::Id& vid,
const std::shared_ptr<dht::crypto::Certificate>& cert);
@@ -235,9 +453,9 @@
* @param vid channel's id
* @param deviceId to identify the linked ConnectCallback
*/
- void sendChannelRequest(std::shared_ptr<MultiplexedSocket>& sock,
+ void sendChannelRequest(const std::weak_ptr<DeviceInfo>& dinfo,
+ const std::shared_ptr<MultiplexedSocket>& sock,
const std::string& name,
- const DeviceId& deviceId,
const dht::Value::Id& vid);
/**
* Triggered when a PeerConnectionRequest comes from the DHT
@@ -245,15 +463,29 @@
void answerTo(IceTransport& ice,
const dht::Value::Id& id,
const std::shared_ptr<dht::crypto::PublicKey>& fromPk);
- bool onRequestStartIce(const PeerConnectionRequest& req);
- bool onRequestOnNegoDone(const PeerConnectionRequest& req);
+ bool onRequestStartIce(const std::shared_ptr<ConnectionInfo>& info, const PeerConnectionRequest& req);
+ bool onRequestOnNegoDone(const std::weak_ptr<DeviceInfo>& dinfo, const std::shared_ptr<ConnectionInfo>& info, const PeerConnectionRequest& req);
void onDhtPeerRequest(const PeerConnectionRequest& req,
const std::shared_ptr<dht::crypto::Certificate>& cert);
+ /**
+ * Triggered when a new TLS socket is ready to use
+ * @param ok If succeed
+ * @param deviceId Related device
+ * @param vid vid of the connection request
+ * @param name non empty if TLS was created by connectDevice()
+ */
+ void onTlsNegotiationDone(const std::shared_ptr<DeviceInfo>& dinfo,
+ const std::shared_ptr<ConnectionInfo>& info,
+ bool ok,
+ const DeviceId& deviceId,
+ const dht::Value::Id& vid,
+ const std::string& name = "");
- void addNewMultiplexedSocket(const CallbackId& id, const std::shared_ptr<ConnectionInfo>& info);
+ void addNewMultiplexedSocket(const std::weak_ptr<DeviceInfo>& dinfo, const DeviceId& deviceId, const dht::Value::Id& vid, const std::shared_ptr<ConnectionInfo>& info);
void onPeerResponse(const PeerConnectionRequest& req);
void onDhtConnected(const dht::crypto::PublicKey& devicePk);
+
const std::shared_future<tls::DhParams> dhParams() const;
tls::CertificateStore& certStore() const { return *config_->certStore; }
@@ -327,162 +559,31 @@
*/
bool getUPnPActive() const;
- /**
- * Triggered when a new TLS socket is ready to use
- * @param ok If succeed
- * @param deviceId Related device
- * @param vid vid of the connection request
- * @param name non empty if TLS was created by connectDevice()
- */
- void onTlsNegotiationDone(bool ok,
- const DeviceId& deviceId,
- const dht::Value::Id& vid,
- const std::string& name = "");
-
std::shared_ptr<ConnectionManager::Config> config_;
std::unique_ptr<std::thread> ioContextRunner_;
- mutable std::mt19937_64 rand;
+ mutable std::mutex randMtx_;
+ mutable std::mt19937_64 rand_;
iOSConnectedCallback iOSConnectedCb_ {};
- std::mutex infosMtx_ {};
- // Note: Someone can ask multiple sockets, so to avoid any race condition,
- // each device can have multiple multiplexed sockets.
- std::map<CallbackId, std::shared_ptr<ConnectionInfo>> infos_ {};
-
- std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, const dht::Value::Id& id)
- {
- std::lock_guard<std::mutex> lk(infosMtx_);
- auto it = infos_.find({deviceId, id});
- if (it != infos_.end())
- return it->second;
- return {};
- }
-
- std::shared_ptr<ConnectionInfo> getConnectedInfo(const DeviceId& deviceId)
- {
- std::lock_guard<std::mutex> lk(infosMtx_);
- auto it = std::find_if(infos_.begin(), infos_.end(), [&](const auto& item) {
- auto& [key, value] = item;
- return key.first == deviceId && value && value->socket_;
- });
- if (it != infos_.end())
- return it->second;
- return {};
- }
+ DeviceInfoSet infos_ {};
ChannelRequestCallback channelReqCb_ {};
ConnectionReadyCallback connReadyCb_ {};
onICERequestCallback iceReqCb_ {};
-
- /**
- * Stores callback from connectDevice
- * @note: each device needs a vector because several connectDevice can
- * be done in parallel and we only want one socket
- */
- std::mutex connectCbsMtx_ {};
-
-
- struct PendingCb
- {
- std::string name;
- ConnectCallback cb;
- };
- struct PendingOperations {
- std::map<dht::Value::Id, PendingCb> connecting;
- std::map<dht::Value::Id, PendingCb> waiting;
- };
-
- std::map<DeviceId, PendingOperations> pendingOperations_ {};
-
- void executePendingOperations(const DeviceId& deviceId, const dht::Value::Id& vid, const std::shared_ptr<ChannelSocket>& sock, bool accepted = true)
- {
- std::vector<PendingCb> ret;
- std::unique_lock<std::mutex> lk(connectCbsMtx_);
- auto it = pendingOperations_.find(deviceId);
- if (it == pendingOperations_.end())
- return;
- auto& pendingOperations = it->second;
- if (vid == 0) {
- // Extract all pending callbacks
- for (auto& [vid, cb] : pendingOperations.connecting)
- ret.emplace_back(std::move(cb));
- pendingOperations.connecting.clear();
- for (auto& [vid, cb] : pendingOperations.waiting)
- ret.emplace_back(std::move(cb));
- pendingOperations.waiting.clear();
- } else if (auto n = pendingOperations.waiting.extract(vid)) {
- // If it's a waiting operation, just move it
- ret.emplace_back(std::move(n.mapped()));
- } else if (auto n = pendingOperations.connecting.extract(vid)) {
- ret.emplace_back(std::move(n.mapped()));
- // If sock is nullptr, execute if it's the last connecting operation
- // If accepted is false, it means that underlying socket is ok, but channel is declined
- if (!sock && pendingOperations.connecting.empty() && accepted) {
- for (auto& [vid, cb] : pendingOperations.waiting)
- ret.emplace_back(std::move(cb));
- pendingOperations.waiting.clear();
- for (auto& [vid, cb] : pendingOperations.connecting)
- ret.emplace_back(std::move(cb));
- pendingOperations.connecting.clear();
- }
- }
- if (pendingOperations.waiting.empty() && pendingOperations.connecting.empty())
- pendingOperations_.erase(it);
- lk.unlock();
- for (auto& cb : ret)
- cb.cb(sock, deviceId);
- }
-
- std::map<dht::Value::Id, std::string> getPendingIds(const DeviceId& deviceId, const dht::Value::Id vid = 0)
- {
- std::map<dht::Value::Id, std::string> ret;
- std::lock_guard<std::mutex> lk(connectCbsMtx_);
- auto it = pendingOperations_.find(deviceId);
- if (it == pendingOperations_.end())
- return ret;
- auto& pendingOp = it->second;
- for (const auto& [id, pc]: pendingOp.connecting) {
- if (vid == 0 || id == vid)
- ret[id] = pc.name;
- }
- for (const auto& [id, pc]: pendingOp.waiting) {
- if (vid == 0 || id == vid)
- ret[id] = pc.name;
- }
- return ret;
- }
-
- std::shared_ptr<ConnectionManager::Impl> shared()
- {
- return std::static_pointer_cast<ConnectionManager::Impl>(shared_from_this());
- }
- std::shared_ptr<ConnectionManager::Impl const> shared() const
- {
- return std::static_pointer_cast<ConnectionManager::Impl const>(shared_from_this());
- }
- std::weak_ptr<ConnectionManager::Impl> weak()
- {
- return std::static_pointer_cast<ConnectionManager::Impl>(shared_from_this());
- }
- std::weak_ptr<ConnectionManager::Impl const> weak() const
- {
- return std::static_pointer_cast<ConnectionManager::Impl const>(shared_from_this());
- }
-
std::atomic_bool isDestroying_ {false};
};
void
ConnectionManager::Impl::connectDeviceStartIce(
+ const std::shared_ptr<ConnectionInfo>& info,
const std::shared_ptr<dht::crypto::PublicKey>& devicePk,
const dht::Value::Id& vid,
const std::string& connType,
std::function<void(bool)> onConnected)
{
auto deviceId = devicePk->getLongId();
- auto info = getInfo(deviceId, vid);
if (!info) {
onConnected(false);
return;
@@ -542,17 +643,18 @@
std::chrono::steady_clock::now()
+ DHT_MSG_TIMEOUT);
info->waitForAnswer_->async_wait(
- std::bind(&ConnectionManager::Impl::onResponse, this, std::placeholders::_1, deviceId, vid));
+ std::bind(&ConnectionManager::Impl::onResponse, this, std::placeholders::_1, info, deviceId, vid));
}
void
ConnectionManager::Impl::onResponse(const asio::error_code& ec,
+ const std::weak_ptr<ConnectionInfo>& winfo,
const DeviceId& deviceId,
const dht::Value::Id& vid)
{
if (ec == asio::error::operation_aborted)
return;
- auto info = getInfo(deviceId, vid);
+ auto info = winfo.lock();
if (!info)
return;
@@ -587,12 +689,13 @@
bool
ConnectionManager::Impl::connectDeviceOnNegoDone(
+ const std::weak_ptr<DeviceInfo>& dinfo,
+ const std::shared_ptr<ConnectionInfo>& info,
const DeviceId& deviceId,
const std::string& name,
const dht::Value::Id& vid,
const std::shared_ptr<dht::crypto::Certificate>& cert)
{
- auto info = getInfo(deviceId, vid);
if (!info)
return false;
@@ -625,10 +728,10 @@
*cert);
info->tls_->setOnReady(
- [w = weak(), deviceId = std::move(deviceId), vid = std::move(vid), name = std::move(name)](
+ [w = weak_from_this(), dinfo, winfo=std::weak_ptr(info), deviceId = std::move(deviceId), vid = std::move(vid), name = std::move(name)](
bool ok) {
if (auto shared = w.lock())
- shared->onTlsNegotiationDone(ok, deviceId, vid, name);
+ shared->onTlsNegotiationDone(dinfo.lock(), winfo.lock(), ok, deviceId, vid, name);
});
return true;
}
@@ -650,7 +753,7 @@
return;
}
findCertificate(deviceId,
- [w = weak(),
+ [w = weak_from_this(),
deviceId,
name,
cb = std::move(cb),
@@ -695,7 +798,7 @@
return;
}
findCertificate(deviceId,
- [w = weak(),
+ [w = weak_from_this(),
deviceId,
name,
cb = std::move(cb),
@@ -734,7 +837,7 @@
const std::string& connType)
{
// Avoid dht operation in a DHT callback to avoid deadlocks
- dht::ThreadPool::computation().run([w = weak(),
+ dht::ThreadPool::computation().run([w = weak_from_this(),
name = std::move(name),
cert = std::move(cert),
cb = std::move(cb),
@@ -748,40 +851,35 @@
cb(nullptr, deviceId);
return;
}
+ auto di = sthis->infos_.createDeviceInfo(deviceId);
+ std::unique_lock<std::mutex> lk(di->mtx_);
+
dht::Value::Id vid;
- auto isConnectingToDevice = false;
{
- std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_);
- vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand);
- auto pendingsIt = sthis->pendingOperations_.find(deviceId);
- if (pendingsIt != sthis->pendingOperations_.end()) {
- const auto& pendings = pendingsIt->second;
- while (pendings.connecting.find(vid) != pendings.connecting.end()
- || pendings.waiting.find(vid) != pendings.waiting.end()) {
- vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand);
- }
- }
- // Check if already connecting
- isConnectingToDevice = pendingsIt != sthis->pendingOperations_.end();
- // Save current request for sendChannelRequest.
- // Note: do not return here, cause we can be in a state where first
- // socket is negotiated and first channel is pending
- // so return only after we checked the info
- if (isConnectingToDevice && !forceNewSocket)
- pendingsIt->second.waiting[vid] = PendingCb {name, std::move(cb)};
- else
- sthis->pendingOperations_[deviceId].connecting[vid] = PendingCb {name, std::move(cb)};
+ std::lock_guard<std::mutex> lkr(sthis->randMtx_);
+ vid = di->newId(sthis->rand_);
}
+ // Check if already connecting
+ auto isConnectingToDevice = di->isConnecting();
+ // Note: we can be in a state where first
+ // socket is negotiated and first channel is pending
+ // so return only after we checked the info
+ if (isConnectingToDevice && !forceNewSocket)
+ di->waiting[vid] = PendingCb {name, std::move(cb)};
+ else
+ di->connecting[vid] = PendingCb {name, std::move(cb)};
+
// Check if already negotiated
- CallbackId cbId(deviceId, vid);
- if (auto info = sthis->getConnectedInfo(deviceId)) {
- std::lock_guard<std::mutex> lk(info->mutex_);
- if (info->socket_) {
+ if (auto info = di->getConnectedInfo()) {
+ std::unique_lock<std::mutex> lkc(info->mutex_);
+ if (auto sock = info->socket_) {
+ info->cbIds_.emplace(vid);
+ lkc.unlock();
+ lk.unlock();
if (sthis->config_->logger)
sthis->config_->logger->debug("[device {}] Peer already connected. Add a new channel", deviceId);
- info->cbIds_.emplace(cbId);
- sthis->sendChannelRequest(info->socket_, name, deviceId, vid);
+ sthis->sendChannelRequest(di, sock, name, vid);
return;
}
}
@@ -793,18 +891,24 @@
}
if (noNewSocket) {
// If no new socket is specified, we don't try to generate a new socket
- sthis->executePendingOperations(deviceId, vid, nullptr);
+ di->executePendingOperations(lk, vid, nullptr);
return;
}
// Note: used when the ice negotiation fails to erase
// all stored structures.
- auto eraseInfo = [w, cbId] {
- if (auto shared = w.lock()) {
- // If no new socket is specified, we don't try to generate a new socket
- shared->executePendingOperations(cbId.first, cbId.second, nullptr);
- std::lock_guard<std::mutex> lk(shared->infosMtx_);
- shared->infos_.erase(cbId);
+ auto eraseInfo = [w, diw=std::weak_ptr(di), vid] {
+ if (auto di = diw.lock()) {
+ std::unique_lock<std::mutex> lk(di->mtx_);
+ di->info.erase(vid);
+ auto ops = di->extractPendingOperations(vid, nullptr);
+ if (di->empty()) {
+ if (auto shared = w.lock())
+ shared->infos_.removeDeviceInfo(di->deviceId);
+ }
+ lk.unlock();
+ for (const auto& op: ops)
+ op.cb(nullptr, di->deviceId);
}
};
@@ -812,6 +916,7 @@
sthis->getIceOptions([w,
deviceId = std::move(deviceId),
devicePk = std::move(devicePk),
+ diw=std::weak_ptr(di),
name = std::move(name),
cert = std::move(cert),
vid,
@@ -822,17 +927,22 @@
dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
return;
}
+ auto info = std::make_shared<ConnectionInfo>();
+ auto winfo = std::weak_ptr(info);
ice_config.tcpEnable = true;
ice_config.onInitDone = [w,
devicePk = std::move(devicePk),
name = std::move(name),
cert = std::move(cert),
+ diw,
+ winfo = std::weak_ptr(info),
vid,
connType,
eraseInfo](bool ok) {
dht::ThreadPool::io().run([w = std::move(w),
devicePk = std::move(devicePk),
- vid = std::move(vid),
+ vid,
+ winfo,
eraseInfo,
connType, ok] {
auto sthis = w.lock();
@@ -842,7 +952,7 @@
eraseInfo();
return;
}
- sthis->connectDeviceStartIce(devicePk, vid, connType, [=](bool ok) {
+ sthis->connectDeviceStartIce(winfo.lock(), devicePk, vid, connType, [=](bool ok) {
if (!ok) {
dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
}
@@ -853,27 +963,30 @@
deviceId,
name,
cert = std::move(cert),
+ diw,
+ winfo = std::weak_ptr(info),
vid,
eraseInfo](bool ok) {
dht::ThreadPool::io().run([w = std::move(w),
deviceId = std::move(deviceId),
name = std::move(name),
cert = std::move(cert),
+ diw = std::move(diw),
+ winfo = std::move(winfo),
vid = std::move(vid),
eraseInfo = std::move(eraseInfo),
ok] {
auto sthis = w.lock();
if (!ok && sthis && sthis->config_->logger)
sthis->config_->logger->error("[device {}] ICE negotiation failed.", deviceId);
- if (!sthis || !ok || !sthis->connectDeviceOnNegoDone(deviceId, name, vid, cert))
+ if (!sthis || !ok || !sthis->connectDeviceOnNegoDone(diw, winfo.lock(), deviceId, name, vid, cert))
eraseInfo();
});
};
- auto info = std::make_shared<ConnectionInfo>();
- {
- std::lock_guard<std::mutex> lk(sthis->infosMtx_);
- sthis->infos_[{deviceId, vid}] = info;
+ if (auto di = diw.lock()) {
+ std::lock_guard<std::mutex> lk(di->mtx_);
+ di->info[vid] = info;
}
std::unique_lock<std::mutex> lk {info->mutex_};
ice_config.master = false;
@@ -903,23 +1016,20 @@
}
void
-ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>& sock,
+ConnectionManager::Impl::sendChannelRequest(const std::weak_ptr<DeviceInfo>& dinfo,
+ const std::shared_ptr<MultiplexedSocket>& sock,
const std::string& name,
- const DeviceId& deviceId,
const dht::Value::Id& vid)
{
auto channelSock = sock->addChannel(name);
- channelSock->onShutdown([name, deviceId, vid, w = weak()] {
- auto shared = w.lock();
- if (auto shared = w.lock())
- shared->executePendingOperations(deviceId, vid, nullptr);
+ channelSock->onShutdown([dinfo, name, vid] {
+ if (auto info = dinfo.lock())
+ info->executePendingOperations(vid, nullptr);
});
channelSock->onReady(
- [wSock = std::weak_ptr<ChannelSocket>(channelSock), name, deviceId, vid, w = weak()](bool accepted) {
- auto shared = w.lock();
- auto channelSock = wSock.lock();
- if (shared)
- shared->executePendingOperations(deviceId, vid, accepted ? channelSock : nullptr, accepted);
+ [dinfo, wSock = std::weak_ptr(channelSock), name, vid](bool accepted) {
+ if (auto info = dinfo.lock())
+ info->executePendingOperations(vid, accepted ? wSock.lock() : nullptr, accepted);
});
ChannelRequest val;
@@ -937,7 +1047,7 @@
if (res < 0) {
// TODO check if we should handle errors here
if (config_->logger)
- config_->logger->error("[device {}] sendChannelRequest failed - error: {}", deviceId, ec.message());
+ config_->logger->error("sendChannelRequest failed - error: {}", ec.message());
}
}
@@ -945,7 +1055,7 @@
ConnectionManager::Impl::onPeerResponse(const PeerConnectionRequest& req)
{
auto device = req.owner->getLongId();
- if (auto info = getInfo(device, req.id)) {
+ if (auto info = infos_.getInfo(device, req.id)) {
if (config_->logger)
config_->logger->debug("[device {}] New response received", device);
std::lock_guard<std::mutex> lk {info->mutex_};
@@ -955,6 +1065,7 @@
info->waitForAnswer_->async_wait(std::bind(&ConnectionManager::Impl::onResponse,
this,
std::placeholders::_1,
+ std::weak_ptr(info),
device,
req.id));
} else {
@@ -970,7 +1081,7 @@
return;
dht()->listen<PeerConnectionRequest>(
dht::InfoHash::get(PeerConnectionRequest::key_prefix + devicePk.getId().toString()),
- [w = weak()](PeerConnectionRequest&& req) {
+ [w = weak_from_this()](PeerConnectionRequest&& req) {
auto shared = w.lock();
if (!shared)
return false;
@@ -1018,7 +1129,9 @@
}
void
-ConnectionManager::Impl::onTlsNegotiationDone(bool ok,
+ConnectionManager::Impl::onTlsNegotiationDone(const std::shared_ptr<DeviceInfo>& dinfo,
+ const std::shared_ptr<ConnectionInfo>& info,
+ bool ok,
const DeviceId& deviceId,
const dht::Value::Id& vid,
const std::string& name)
@@ -1044,7 +1157,7 @@
deviceId,
name,
vid);
- executePendingOperations(deviceId, vid, nullptr);
+ dinfo->executePendingOperations(vid, nullptr);
}
} else {
// The socket is ready, store it
@@ -1061,17 +1174,26 @@
vid);
}
- auto info = getInfo(deviceId, vid);
- addNewMultiplexedSocket({deviceId, vid}, info);
+ // Note: do not remove pending operations here; sendChannelRequest() does it
+ std::unique_lock<std::mutex> lk2 {dinfo->mtx_};
+ auto pendingIds = dinfo->getPendingIds();
+ lk2.unlock();
+ std::unique_lock<std::mutex> lk {info->mutex_};
+ addNewMultiplexedSocket(dinfo, deviceId, vid, info);
+ // socket_ is accessed under info->mutex_ elsewhere, so snapshot it
+ // before unlocking instead of re-reading info->socket_ afterwards.
+ auto sock = info->socket_;
// Finally, open the channel and launch pending callbacks
- if (info->socket_) {
- // Note: do not remove pending there it's done in sendChannelRequest
- for (const auto& [id, name] : getPendingIds(deviceId)) {
- if (config_->logger)
- config_->logger->debug("[device {}] Send request on TLS socket for channel {}",
- deviceId, name);
- sendChannelRequest(info->socket_, name, deviceId, id);
- }
+ lk.unlock();
+ // Keep the pre-refactor null check: sendChannelRequest() dereferences
+ // the socket it is given, and socket_ may still be unset here.
+ if (!sock)
+ return;
+ for (const auto& [id, channelName]: pendingIds) {
+ if (config_->logger)
+ config_->logger->debug("[device {}] Send request on TLS socket for channel {}",
+ deviceId, channelName);
+ sendChannelRequest(dinfo, sock, channelName, id);
}
}
}
@@ -1113,13 +1228,12 @@
}
bool
-ConnectionManager::Impl::onRequestStartIce(const PeerConnectionRequest& req)
+ConnectionManager::Impl::onRequestStartIce(const std::shared_ptr<ConnectionInfo>& info, const PeerConnectionRequest& req)
{
- auto deviceId = req.owner->getLongId();
- auto info = getInfo(deviceId, req.id);
if (!info)
return false;
+ auto deviceId = req.owner->getLongId();
std::unique_lock<std::mutex> lk {info->mutex_};
auto& ice = info->ice_;
if (!ice) {
@@ -1144,13 +1258,12 @@
}
bool
-ConnectionManager::Impl::onRequestOnNegoDone(const PeerConnectionRequest& req)
+ConnectionManager::Impl::onRequestOnNegoDone(const std::weak_ptr<DeviceInfo>& dinfo, const std::shared_ptr<ConnectionInfo>& info, const PeerConnectionRequest& req)
{
- auto deviceId = req.owner->getLongId();
- auto info = getInfo(deviceId, req.id);
if (!info)
return false;
+ auto deviceId = req.owner->getLongId();
std::unique_lock<std::mutex> lk {info->mutex_};
auto& ice = info->ice_;
if (!ice) {
@@ -1176,7 +1289,7 @@
config_->ioContext,
identity(),
dhParams(),
- [ph, deviceId, w=weak(), l=config_->logger](const dht::crypto::Certificate& cert) {
+ [ph, deviceId, w=weak_from_this(), l=config_->logger](const dht::crypto::Certificate& cert) {
auto shared = w.lock();
if (!shared)
return false;
@@ -1194,9 +1307,9 @@
});
info->tls_->setOnReady(
- [w = weak(), deviceId = std::move(deviceId), vid = std::move(req.id)](bool ok) {
+ [w = weak_from_this(), dinfo, winfo=std::weak_ptr(info), deviceId = std::move(deviceId), vid = std::move(req.id)](bool ok) {
if (auto shared = w.lock())
- shared->onTlsNegotiationDone(ok, deviceId, vid);
+ shared->onTlsNegotiationDone(dinfo.lock(), winfo.lock(), ok, deviceId, vid);
});
return true;
}
@@ -1215,25 +1328,41 @@
}
// Because the connection is accepted, create an ICE socket.
- getIceOptions([w = weak(), req, deviceId](auto&& ice_config) {
+ getIceOptions([w = weak_from_this(), req, deviceId](auto&& ice_config) {
auto shared = w.lock();
if (!shared)
return;
+
+ auto di = shared->infos_.createDeviceInfo(deviceId);
+ auto info = std::make_shared<ConnectionInfo>();
+ auto wdi = std::weak_ptr(di);
+ auto winfo = std::weak_ptr(info);
+
// Note: used when the ice negotiation fails to erase
// all stored structures.
- auto eraseInfo = [w, id = req.id, deviceId] {
- if (auto shared = w.lock()) {
- // If no new socket is specified, we don't try to generate a new socket
- shared->executePendingOperations(deviceId, id, nullptr);
- if (shared->connReadyCb_)
- shared->connReadyCb_(deviceId, "", nullptr);
- std::lock_guard<std::mutex> lk(shared->infosMtx_);
- shared->infos_.erase({deviceId, id});
+ auto eraseInfo = [w, wdi, id = req.id] {
+ auto shared = w.lock();
+ if (auto di = wdi.lock()) {
+ std::unique_lock<std::mutex> lk(di->mtx_);
+ di->info.erase(id);
+ auto ops = di->extractPendingOperations(id, nullptr);
+ if (di->empty()) {
+ if (shared)
+ shared->infos_.removeDeviceInfo(di->deviceId);
+ }
+ lk.unlock();
+ for (const auto& op: ops)
+ op.cb(nullptr, di->deviceId);
+ if (shared && shared->connReadyCb_)
+ shared->connReadyCb_(di->deviceId, "", nullptr);
}
};
+ ice_config.master = true;
+ ice_config.streamsCount = 1;
+ ice_config.compCountPerStream = 1; // TCP
ice_config.tcpEnable = true;
- ice_config.onInitDone = [w, req, eraseInfo](bool ok) {
+ ice_config.onInitDone = [w, winfo, req, eraseInfo](bool ok) {
auto shared = w.lock();
if (!shared)
return;
@@ -1245,16 +1374,15 @@
}
dht::ThreadPool::io().run(
- [w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] {
- auto shared = w.lock();
- if (!shared)
- return;
- if (!shared->onRequestStartIce(req))
- eraseInfo();
+ [w = std::move(w), winfo = std::move(winfo), req = std::move(req), eraseInfo = std::move(eraseInfo)] {
+ if (auto shared = w.lock()) {
+ if (!shared->onRequestStartIce(winfo.lock(), req))
+ eraseInfo();
+ }
});
};
- ice_config.onNegoDone = [w, req, eraseInfo](bool ok) {
+ ice_config.onNegoDone = [w, wdi, winfo, req, eraseInfo](bool ok) {
auto shared = w.lock();
if (!shared)
return;
@@ -1266,25 +1394,22 @@
}
dht::ThreadPool::io().run(
- [w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] {
+ [w = std::move(w), wdi = std::move(wdi), winfo = std::move(winfo), req = std::move(req), eraseInfo = std::move(eraseInfo)] {
if (auto shared = w.lock())
- if (!shared->onRequestOnNegoDone(req))
+ if (!shared->onRequestOnNegoDone(wdi.lock(), winfo.lock(), req))
eraseInfo();
});
};
// Negotiate a new ICE socket
- auto info = std::make_shared<ConnectionInfo>();
{
- std::lock_guard<std::mutex> lk(shared->infosMtx_);
- shared->infos_[{deviceId, req.id}] = info;
+ std::lock_guard<std::mutex> lk(di->mtx_);
+ di->info[req.id] = info;
}
+
if (shared->config_->logger)
shared->config_->logger->debug("[device {}] Accepting connection", deviceId);
std::unique_lock<std::mutex> lk {info->mutex_};
- ice_config.streamsCount = 1;
- ice_config.compCountPerStream = 1; // TCP
- ice_config.master = true;
info->ice_ = shared->config_->factory->createUTransport("");
if (not info->ice_) {
if (shared->config_->logger)
@@ -1307,16 +1432,16 @@
}
void
-ConnectionManager::Impl::addNewMultiplexedSocket(const CallbackId& id, const std::shared_ptr<ConnectionInfo>& info)
+ConnectionManager::Impl::addNewMultiplexedSocket(const std::weak_ptr<DeviceInfo>& dinfo, const DeviceId& deviceId, const dht::Value::Id& vid, const std::shared_ptr<ConnectionInfo>& info)
{
- info->socket_ = std::make_shared<MultiplexedSocket>(config_->ioContext, id.first, std::move(info->tls_), config_->logger);
+ info->socket_ = std::make_shared<MultiplexedSocket>(config_->ioContext, deviceId, std::move(info->tls_), config_->logger);
info->socket_->setOnReady(
- [w = weak()](const DeviceId& deviceId, const std::shared_ptr<ChannelSocket>& socket) {
+ [w = weak_from_this()](const DeviceId& deviceId, const std::shared_ptr<ChannelSocket>& socket) {
if (auto sthis = w.lock())
if (sthis->connReadyCb_)
sthis->connReadyCb_(deviceId, socket->name(), socket);
});
- info->socket_->setOnRequest([w = weak()](const std::shared_ptr<dht::crypto::Certificate>& peer,
+ info->socket_->setOnRequest([w = weak_from_this()](const std::shared_ptr<dht::crypto::Certificate>& peer,
const uint16_t&,
const std::string& name) {
if (auto sthis = w.lock())
@@ -1324,26 +1449,34 @@
return sthis->channelReqCb_(peer, name);
return false;
});
- info->socket_->onShutdown([w = weak(), deviceId=id.first, vid=id.second]() {
+ info->socket_->onShutdown([dinfo, wi=std::weak_ptr(info), vid]() {
// Cancel current outgoing connections
- dht::ThreadPool::io().run([w, deviceId, vid] {
- auto sthis = w.lock();
- if (!sthis)
- return;
-
- std::set<CallbackId> ids;
- if (auto info = sthis->getInfo(deviceId, vid)) {
+ dht::ThreadPool::io().run([dinfo, wi, vid] {
+ std::set<dht::Value::Id> ids;
+ if (auto info = wi.lock()) {
std::lock_guard<std::mutex> lk(info->mutex_);
if (info->socket_) {
ids = std::move(info->cbIds_);
info->socket_->shutdown();
}
}
- for (const auto& cbId : ids)
- sthis->executePendingOperations(cbId.first, cbId.second, nullptr);
-
- std::lock_guard<std::mutex> lk(sthis->infosMtx_);
- sthis->infos_.erase({deviceId, vid});
+ if (auto deviceInfo = dinfo.lock()) {
+ std::shared_ptr<ConnectionInfo> info;
+ std::vector<PendingCb> ops;
+ std::unique_lock<std::mutex> lk(deviceInfo->mtx_);
+ auto it = deviceInfo->info.find(vid);
+ if (it != deviceInfo->info.end()) {
+ info = std::move(it->second);
+ deviceInfo->info.erase(it);
+ }
+ for (const auto& cbId : ids) {
+ auto po = deviceInfo->extractPendingOperations(cbId, nullptr);
+ ops.insert(ops.end(), po.begin(), po.end());
+ }
+ lk.unlock();
+ for (auto& op : ops)
+ op.cb(nullptr, deviceInfo->deviceId);
+ }
});
});
}
@@ -1409,7 +1542,7 @@
void
ConnectionManager::Impl::saveTreatedMessages() const
{
- dht::ThreadPool::io().run([w = weak()]() {
+ dht::ThreadPool::io().run([w = weak_from_this()]() {
if (auto sthis = w.lock()) {
auto& this_ = *sthis;
std::lock_guard<std::mutex> lock(this_.messageMutex_);
@@ -1474,7 +1607,7 @@
void
ConnectionManager::Impl::storeActiveIpAddress(std::function<void()>&& cb)
{
- dht()->getPublicAddress([w=weak(), cb = std::move(cb)](std::vector<dht::SockAddr>&& results) {
+ dht()->getPublicAddress([w=weak_from_this(), cb = std::move(cb)](std::vector<dht::SockAddr>&& results) {
auto shared = w.lock();
if (!shared)
return;
@@ -1703,52 +1836,50 @@
bool
ConnectionManager::isConnecting(const DeviceId& deviceId, const std::string& name) const
{
- auto pending = pimpl_->getPendingIds(deviceId);
- return std::find_if(pending.begin(), pending.end(), [&](auto p) { return p.second == name; })
- != pending.end();
+ if (auto dinfo = pimpl_->infos_.getDeviceInfo(deviceId)) {
+ std::unique_lock<std::mutex> lk {dinfo->mtx_};
+ auto pending = dinfo->getPendingIds();
+ lk.unlock();
+ return std::find_if(pending.begin(), pending.end(), [&](const auto& p) { return p.second == name; })
+ != pending.end();
+ }
+ return false;
}
void
ConnectionManager::closeConnectionsWith(const std::string& peerUri)
{
- std::vector<std::shared_ptr<ConnectionInfo>> connInfos;
- std::set<DeviceId> peersDevices;
- {
- std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
- for (auto iter = pimpl_->infos_.begin(); iter != pimpl_->infos_.end();) {
- auto const& [key, value] = *iter;
- std::unique_lock<std::mutex> lkv {value->mutex_};
- auto deviceId = key.first;
- auto tls = value->tls_ ? value->tls_.get() : (value->socket_ ? value->socket_->endpoint() : nullptr);
+ std::vector<std::shared_ptr<DeviceInfo>> dInfos;
+ for (const auto& dinfo: pimpl_->infos_.getDeviceInfos()) {
+ std::unique_lock<std::mutex> lk(dinfo->mtx_);
+ bool isPeer = false;
+ for (auto const& [id, cinfo]: dinfo->info) {
+ std::lock_guard<std::mutex> lkv {cinfo->mutex_};
+ auto tls = cinfo->tls_ ? cinfo->tls_.get() : (cinfo->socket_ ? cinfo->socket_->endpoint() : nullptr);
auto cert = tls ? tls->peerCertificate() : nullptr;
if (not cert)
- cert = pimpl_->certStore().getCertificate(deviceId.toString());
+ cert = pimpl_->certStore().getCertificate(dinfo->deviceId.toString());
if (cert && cert->issuer && peerUri == cert->issuer->getId().toString()) {
- connInfos.emplace_back(value);
- peersDevices.emplace(deviceId);
- lkv.unlock();
- iter = pimpl_->infos_.erase(iter);
- } else {
- iter++;
+ isPeer = true;
+ break;
}
}
+ lk.unlock();
+ if (isPeer) {
+ dInfos.emplace_back(std::move(dinfo));
+ }
}
// Stop connections to all peers devices
- for (const auto& deviceId : peersDevices) {
- pimpl_->executePendingOperations(deviceId, 0, nullptr);
- // This will close the TLS Session
- pimpl_->removeUnusedConnections(deviceId);
- }
- for (auto& info : connInfos) {
- if (info->socket_)
- info->socket_->shutdown();
- if (info->waitForAnswer_)
- info->waitForAnswer_->cancel();
- if (info->ice_) {
- std::unique_lock<std::mutex> lk {info->mutex_};
- dht::ThreadPool::io().run(
- [ice = std::shared_ptr<IceTransport>(std::move(info->ice_))] {});
- }
+ for (const auto& dinfo : dInfos) {
+ std::unique_lock<std::mutex> lk {dinfo->mtx_};
+ auto unused = dinfo->extractUnusedConnections();
+ auto pending = dinfo->extractPendingOperations(0, nullptr);
+ pimpl_->infos_.removeDeviceInfo(dinfo->deviceId);
+ lk.unlock();
+ for (auto& op : unused)
+ op->shutdown();
+ for (auto& op : pending)
+ op.cb(nullptr, dinfo->deviceId);
}
}
@@ -1785,19 +1916,18 @@
std::size_t
ConnectionManager::activeSockets() const
{
- std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
- return pimpl_->infos_.size();
+ return pimpl_->infos_.getConnectedInfos().size();
}
void
ConnectionManager::monitor() const
{
- std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
auto logger = pimpl_->config_->logger;
if (!logger)
return;
logger->debug("ConnectionManager current status:");
- for (const auto& [_, ci] : pimpl_->infos_) {
+ for (const auto& ci : pimpl_->infos_.getConnectedInfos()) {
+ std::lock_guard<std::mutex> lk(ci->mutex_);
if (ci->socket_)
ci->socket_->monitor();
}
@@ -1807,8 +1937,8 @@
void
ConnectionManager::connectivityChanged()
{
- std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
- for (const auto& [_, ci] : pimpl_->infos_) {
+ for (const auto& ci : pimpl_->infos_.getConnectedInfos()) {
+ std::lock_guard<std::mutex> lk(ci->mutex_);
if (ci->socket_)
ci->socket_->sendBeacon();
}
@@ -1854,71 +1984,14 @@
ConnectionManager::getConnectionList(const DeviceId& device) const
{
std::vector<std::map<std::string, std::string>> connectionsList;
- std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
-
- for (const auto& [key, ci] : pimpl_->infos_) {
- if (device && key.first != device)
- continue;
- std::map<std::string, std::string> connectionInfo;
- connectionInfo["id"] = callbackIdToString(key.first, key.second);
- connectionInfo["device"] = key.first.toString();
- if (ci->tls_) {
- if (auto cert = ci->tls_->peerCertificate()) {
- connectionInfo["peer"] = cert->issuer->getId().toString();
- }
- }
- if (ci->socket_) {
- connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Connected));
- } else if (ci->tls_) {
- connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::TLS));
- } else if(ci->ice_)
- {
- connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::ICE));
- }
- if (ci->tls_) {
- std::string remoteAddress = ci->tls_->getRemoteAddress();
- std::string remoteAddressIp = remoteAddress.substr(0, remoteAddress.find(':'));
- std::string remoteAddressPort = remoteAddress.substr(remoteAddress.find(':') + 1);
- connectionInfo["remoteAdress"] = remoteAddressIp;
- connectionInfo["remotePort"] = remoteAddressPort;
- }
- connectionsList.emplace_back(std::move(connectionInfo));
- }
-
if (device) {
- auto it = pimpl_->pendingOperations_.find(device);
- if (it != pimpl_->pendingOperations_.end()) {
- const auto& po = it->second;
- for (const auto& [vid, ci] : po.connecting) {
- std::map<std::string, std::string> connectionInfo;
- connectionInfo["id"] = callbackIdToString(device, vid);
- connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Connecting));
- connectionsList.emplace_back(std::move(connectionInfo));
- }
-
- for (const auto& [vid, ci] : po.waiting) {
- std::map<std::string, std::string> connectionInfo;
- connectionInfo["id"] = callbackIdToString(device, vid);
- connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Waiting));
- connectionsList.emplace_back(std::move(connectionInfo));
- }
+ if (auto deviceInfo = pimpl_->infos_.getDeviceInfo(device)) {
+ connectionsList = deviceInfo->getConnectionList(pimpl_->certStore());
}
- }
- else {
- for (const auto& [key, po] : pimpl_->pendingOperations_) {
- for (const auto& [vid, ci] : po.connecting) {
- std::map<std::string, std::string> connectionInfo;
- connectionInfo["id"] = callbackIdToString(device, vid);
- connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Connecting));
- connectionsList.emplace_back(std::move(connectionInfo));
- }
-
- for (const auto& [vid, ci] : po.waiting) {
- std::map<std::string, std::string> connectionInfo;
- connectionInfo["id"] = callbackIdToString(device, vid);
- connectionInfo["status"] = std::to_string(static_cast<int>(ConnectionStatus::Waiting));
- connectionsList.emplace_back(std::move(connectionInfo));
- }
+ } else {
+ for (const auto& deviceInfo : pimpl_->infos_.getDeviceInfos()) {
+ auto cl = deviceInfo->getConnectionList(pimpl_->certStore());
+ connectionsList.insert(connectionsList.end(), std::make_move_iterator(cl.begin()), std::make_move_iterator(cl.end()));
}
}
return connectionsList;
@@ -1927,13 +2000,13 @@
std::vector<std::map<std::string, std::string>>
ConnectionManager::getChannelList(const std::string& connectionId) const
{
- std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
- CallbackId cbid = parseCallbackId(connectionId);
- if (pimpl_->infos_.count(cbid) > 0) {
- return pimpl_->infos_[cbid]->socket_->getChannelList();
- } else {
- return {};
+ auto [deviceId, valueId] = parseCallbackId(connectionId);
+ if (auto info = pimpl_->infos_.getInfo(deviceId, valueId)) {
+ std::lock_guard<std::mutex> lk(info->mutex_);
+ if (info->socket_)
+ return info->socket_->getChannelList();
}
+ return {};
}
} // namespace dhtnet