blob: b34bf07f17bb9346a4715d7c76ecb2e53a295534 [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
Adrien Béraudcb753622023-07-17 22:32:49 -04002 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
Adrien Béraud612b55b2023-05-29 10:42:04 -04003 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "multiplexed_socket.h"
18#include "peer_connection.h"
19#include "ice_transport.h"
20#include "certstore.h"
21
22#include <opendht/logger.h>
23#include <opendht/thread_pool.h>
24
25#include <asio/io_context.hpp>
26#include <asio/steady_timer.hpp>
27
28#include <deque>
29
/// Number of bytes requested from the TLS endpoint per read in the event loop.
static constexpr std::size_t IO_BUFFER_SIZE = 8192;
/// Protocol version advertised to the peer on the protocol channel.
static constexpr int MULTIPLEXED_SOCKET_VERSION = 1;
32
33struct ChanneledMessage
34{
35 uint16_t channel;
36 std::vector<uint8_t> data;
37 MSGPACK_DEFINE(channel, data)
38};
39
40struct BeaconMsg
41{
42 bool p;
43 MSGPACK_DEFINE_MAP(p)
44};
45
46struct VersionMsg
47{
48 int v;
49 MSGPACK_DEFINE_MAP(v)
50};
51
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040052namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040053
54using clock = std::chrono::steady_clock;
55using time_point = clock::time_point;
56
57class MultiplexedSocket::Impl
58{
59public:
60 Impl(MultiplexedSocket& parent,
61 std::shared_ptr<asio::io_context> ctx,
62 const DeviceId& deviceId,
Adrien Béraud5636f7c2023-09-14 14:34:57 -040063 std::unique_ptr<TlsSocketEndpoint> endpoint,
64 std::shared_ptr<dht::log::Logger> logger)
Adrien Béraud612b55b2023-05-29 10:42:04 -040065 : parent_(parent)
Adrien Béraud612b55b2023-05-29 10:42:04 -040066 , ctx_(std::move(ctx))
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040067 , deviceId(deviceId)
Adrien Béraud612b55b2023-05-29 10:42:04 -040068 , endpoint(std::move(endpoint))
69 , eventLoopThread_ {[this] {
70 try {
71 eventLoop();
72 } catch (const std::exception& e) {
73 if (logger_)
74 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
75 shutdown();
76 }
77 }}
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040078 , beaconTimer_(*ctx_)
Adrien Béraud612b55b2023-05-29 10:42:04 -040079 {}
80
81 ~Impl() {}
82
83 void join()
84 {
85 if (!isShutdown_) {
86 if (endpoint)
87 endpoint->setOnStateChange({});
88 shutdown();
89 } else {
90 clearSockets();
91 }
92 if (eventLoopThread_.joinable())
93 eventLoopThread_.join();
94 }
95
96 void clearSockets()
97 {
98 decltype(sockets) socks;
99 {
100 std::lock_guard<std::mutex> lkSockets(socketsMutex);
101 socks = std::move(sockets);
102 }
103 for (auto& socket : socks) {
104 // Just trigger onShutdown() to make client know
105 // No need to write the EOF for the channel, the write will fail because endpoint is
106 // already shutdown
107 if (socket.second)
108 socket.second->stop();
109 }
110 }
111
112 void shutdown()
113 {
114 if (isShutdown_)
115 return;
116 stop.store(true);
117 isShutdown_ = true;
118 beaconTimer_.cancel();
119 if (onShutdown_)
120 onShutdown_();
121 if (endpoint) {
122 std::unique_lock<std::mutex> lk(writeMtx);
123 endpoint->shutdown();
124 }
125 clearSockets();
126 }
127
128 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
129 uint16_t channel,
130 bool isInitiator = false)
131 {
132 auto& channelSocket = sockets[channel];
133 if (not channelSocket)
134 channelSocket = std::make_shared<ChannelSocket>(
135 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
136 // Remove socket in another thread to avoid any lock
137 dht::ThreadPool::io().run([w, channel]() {
138 if (auto shared = w.lock()) {
139 shared->eraseChannel(channel);
140 }
141 });
142 });
143 else {
144 if (logger_)
145 logger_->warn("A channel is already present on that socket, accepting "
146 "the request will close the previous one {}", name);
147 }
148 return channelSocket;
149 }
150
151 /**
152 * Handle packets on the TLS endpoint and parse RTP
153 */
154 void eventLoop();
155 /**
156 * Triggered when a new control packet is received
157 */
158 void handleControlPacket(std::vector<uint8_t>&& pkt);
159 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
160 bool handleProtocolMsg(const msgpack::object& o);
161 /**
162 * Triggered when a new packet on a channel is received
163 */
164 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
165 void onRequest(const std::string& name, uint16_t channel);
166 void onAccept(const std::string& name, uint16_t channel);
167
168 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
169 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
170
171 // Beacon
172 void sendBeacon(const std::chrono::milliseconds& timeout);
173 void handleBeaconRequest();
174 void handleBeaconResponse();
175 std::atomic_int beaconCounter_ {0};
176
177 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
178
179 msgpack::unpacker pac_ {};
180
181 MultiplexedSocket& parent_;
182
183 std::shared_ptr<Logger> logger_;
184 std::shared_ptr<asio::io_context> ctx_;
185
186 OnConnectionReadyCb onChannelReady_ {};
187 OnConnectionRequestCb onRequest_ {};
188 OnShutdownCb onShutdown_ {};
189
190 DeviceId deviceId {};
191 // Main socket
192 std::unique_ptr<TlsSocketEndpoint> endpoint {};
193
194 std::mutex socketsMutex {};
195 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
196
197 // Main loop to parse incoming packets
198 std::atomic_bool stop {false};
199 std::thread eventLoopThread_ {};
200
201 std::atomic_bool isShutdown_ {false};
202
203 std::mutex writeMtx {};
204
205 time_point start_ {clock::now()};
206 //std::shared_ptr<Task> beaconTask_ {};
207 asio::steady_timer beaconTimer_;
208
209 // version related stuff
210 void sendVersion();
211 void onVersion(int version);
212 std::atomic_bool canSendBeacon_ {false};
213 std::atomic_bool answerBeacon_ {true};
214 int version_ {MULTIPLEXED_SOCKET_VERSION};
215 std::function<void(bool)> onBeaconCb_ {};
216 std::function<void(int)> onVersionCb_ {};
217};
218
219void
220MultiplexedSocket::Impl::eventLoop()
221{
222 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
223 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
224 if (logger_)
225 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
226 shutdown();
227 return false;
228 }
229 return true;
230 });
231 sendVersion();
232 std::error_code ec;
233 while (!stop) {
234 if (!endpoint) {
235 shutdown();
236 return;
237 }
238 pac_.reserve_buffer(IO_BUFFER_SIZE);
239 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
240 if (size < 0) {
241 if (ec && logger_)
242 logger_->error("Read error detected: {}", ec.message());
243 break;
244 }
245 if (size == 0) {
246 // We can close the socket
247 shutdown();
248 break;
249 }
250
251 pac_.buffer_consumed(size);
252 msgpack::object_handle oh;
253 while (pac_.next(oh) && !stop) {
254 try {
255 auto msg = oh.get().as<ChanneledMessage>();
256 if (msg.channel == CONTROL_CHANNEL)
257 handleControlPacket(std::move(msg.data));
258 else if (msg.channel == PROTOCOL_CHANNEL)
259 handleProtocolPacket(std::move(msg.data));
260 else
261 handleChannelPacket(msg.channel, std::move(msg.data));
262 } catch (const std::exception& e) {
263 if (logger_)
264 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
265 } catch (...) {
266 if (logger_)
267 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
268 }
269 }
270 }
271}
272
273void
274MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
275{
276 std::lock_guard<std::mutex> lkSockets(socketsMutex);
277 auto& socket = sockets[channel];
278 if (!socket) {
279 if (logger_)
280 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
281 return;
282 }
283
284 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400285 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400286 // Due to the callbacks that can take some time, onAccept can arrive after
287 // receiving all the data. In this case, the socket should be removed here
288 // as handle by onChannelReady_
289 if (socket->isRemovable())
290 sockets.erase(channel);
291 else
292 socket->answered();
293}
294
295void
296MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
297{
298 if (!canSendBeacon_)
299 return;
300 beaconCounter_++;
301 if (logger_)
302 logger_->debug("Send beacon to peer {}", deviceId);
303
304 msgpack::sbuffer buffer(8);
305 msgpack::packer<msgpack::sbuffer> pk(&buffer);
306 pk.pack(BeaconMsg {true});
307 if (!writeProtocolMessage(buffer))
308 return;
309 beaconTimer_.expires_after(timeout);
310 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
311 if (ec == asio::error::operation_aborted)
312 return;
313 if (auto shared = w.lock()) {
314 if (shared->pimpl_->beaconCounter_ != 0) {
315 if (shared->pimpl_->logger_)
316 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
317 shared->shutdown();
318 }
319 }
320 });
321}
322
323void
324MultiplexedSocket::Impl::handleBeaconRequest()
325{
326 if (!answerBeacon_)
327 return;
328 // Run this on dedicated thread because some callbacks can take time
329 dht::ThreadPool::io().run([w = parent_.weak()]() {
330 if (auto shared = w.lock()) {
331 msgpack::sbuffer buffer(8);
332 msgpack::packer<msgpack::sbuffer> pk(&buffer);
333 pk.pack(BeaconMsg {false});
334 if (shared->pimpl_->logger_)
335 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
336 shared->pimpl_->writeProtocolMessage(buffer);
337 }
338 });
339}
340
341void
342MultiplexedSocket::Impl::handleBeaconResponse()
343{
344 if (logger_)
345 logger_->debug("Get beacon response from peer {}", deviceId);
346 beaconCounter_--;
347}
348
349bool
350MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
351{
352 std::error_code ec;
353 int wr = parent_.write(PROTOCOL_CHANNEL,
354 (const unsigned char*) buffer.data(),
355 buffer.size(),
356 ec);
357 return wr > 0;
358}
359
360void
361MultiplexedSocket::Impl::sendVersion()
362{
363 dht::ThreadPool::io().run([w = parent_.weak()]() {
364 if (auto shared = w.lock()) {
365 auto version = shared->pimpl_->version_;
366 msgpack::sbuffer buffer(8);
367 msgpack::packer<msgpack::sbuffer> pk(&buffer);
368 pk.pack(VersionMsg {version});
369 shared->pimpl_->writeProtocolMessage(buffer);
370 }
371 });
372}
373
374void
375MultiplexedSocket::Impl::onVersion(int version)
376{
377 // Check if version > 1
378 if (version >= 1) {
379 if (logger_)
380 logger_->debug("Peer {} supports beacon", deviceId);
381 canSendBeacon_ = true;
382 } else {
383 if (logger_)
384 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
385 deviceId,
386 version);
387 canSendBeacon_ = false;
388 }
389}
390
391void
392MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
393{
394 auto accept = onRequest_(endpoint->peerCertificate(), channel, name);
395 std::shared_ptr<ChannelSocket> channelSocket;
396 if (accept) {
397 std::lock_guard<std::mutex> lkSockets(socketsMutex);
398 channelSocket = makeSocket(name, channel);
399 }
400
401 // Answer to ChannelRequest if accepted
402 ChannelRequest val;
403 val.channel = channel;
404 val.name = name;
405 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
406 msgpack::sbuffer buffer(512);
407 msgpack::pack(buffer, val);
408 std::error_code ec;
409 int wr = parent_.write(CONTROL_CHANNEL,
410 reinterpret_cast<const uint8_t*>(buffer.data()),
411 buffer.size(),
412 ec);
413 if (wr < 0) {
414 if (ec && logger_)
415 logger_->error("The write operation failed with error: {:s}", ec.message());
416 stop.store(true);
417 return;
418 }
419
420 if (accept) {
421 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400422 channelSocket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400423 }
424}
425
426void
427MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
428{
429 // Run this on dedicated thread because some callbacks can take time
430 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
431 auto shared = w.lock();
432 if (!shared)
433 return;
434 auto& pimpl = *shared->pimpl_;
435 try {
436 size_t off = 0;
437 while (off != pkt.size()) {
438 msgpack::unpacked result;
439 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
440 auto object = result.get();
441 if (pimpl.handleProtocolMsg(object))
442 continue;
443 auto req = object.as<ChannelRequest>();
444 if (req.state == ChannelRequestState::ACCEPT) {
445 pimpl.onAccept(req.name, req.channel);
446 } else if (req.state == ChannelRequestState::DECLINE) {
447 std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
448 auto channel = pimpl.sockets.find(req.channel);
449 if (channel != pimpl.sockets.end()) {
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400450 channel->second->ready(false);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400451 channel->second->stop();
452 pimpl.sockets.erase(channel);
453 }
454 } else if (pimpl.onRequest_) {
455 pimpl.onRequest(req.name, req.channel);
456 }
457 }
458 } catch (const std::exception& e) {
459 if (pimpl.logger_)
460 pimpl.logger_->error("Error on the control channel: {}", e.what());
461 }
462 });
463}
464
465void
466MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
467{
468 std::lock_guard<std::mutex> lkSockets(socketsMutex);
469 auto sockIt = sockets.find(channel);
470 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
471 if (pkt.size() == 0) {
472 sockIt->second->stop();
473 if (sockIt->second->isAnswered())
474 sockets.erase(sockIt);
475 else
476 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
477 // removed later.
478 } else {
479 sockIt->second->onRecv(std::move(pkt));
480 }
481 } else if (pkt.size() != 0) {
482 if (logger_)
483 logger_->warn("Non existing channel: {}", channel);
484 }
485}
486
487bool
488MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
489{
490 try {
491 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
492 auto key = o.via.map.ptr[0].key.as<std::string_view>();
493 if (key == "p") {
494 auto msg = o.as<BeaconMsg>();
495 if (msg.p)
496 handleBeaconRequest();
497 else
498 handleBeaconResponse();
499 if (onBeaconCb_)
500 onBeaconCb_(msg.p);
501 return true;
502 } else if (key == "v") {
503 auto msg = o.as<VersionMsg>();
504 onVersion(msg.v);
505 if (onVersionCb_)
506 onVersionCb_(msg.v);
507 return true;
508 } else {
509 if (logger_)
510 logger_->warn("Unknown message type");
511 }
512 }
513 } catch (const std::exception& e) {
514 if (logger_)
515 logger_->error("Error on the protocol channel: {}", e.what());
516 }
517 return false;
518}
519
520void
521MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
522{
523 // Run this on dedicated thread because some callbacks can take time
524 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
525 auto shared = w.lock();
526 if (!shared)
527 return;
528 try {
529 size_t off = 0;
530 while (off != pkt.size()) {
531 msgpack::unpacked result;
532 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
533 auto object = result.get();
534 if (shared->pimpl_->handleProtocolMsg(object))
535 return;
536 }
537 } catch (const std::exception& e) {
538 if (shared->pimpl_->logger_)
539 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
540 }
541 });
542}
543
544MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
Adrien Béraud5636f7c2023-09-14 14:34:57 -0400545 std::unique_ptr<TlsSocketEndpoint> endpoint, std::shared_ptr<dht::log::Logger> logger)
546 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint), logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -0400547{}
548
549MultiplexedSocket::~MultiplexedSocket() {}
550
551std::shared_ptr<ChannelSocket>
552MultiplexedSocket::addChannel(const std::string& name)
553{
554 // Note: because both sides can request the same channel number at the same time
555 // it's better to use a random channel number instead of just incrementing the request.
556 thread_local dht::crypto::random_device rd;
557 std::uniform_int_distribution<uint16_t> dist;
558 auto offset = dist(rd);
559 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
560 for (int i = 1; i < UINT16_MAX; ++i) {
561 auto c = (offset + i) % UINT16_MAX;
562 if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL
563 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
564 continue;
565 auto channel = pimpl_->makeSocket(name, c, true);
566 return channel;
567 }
568 return {};
569}
570
571DeviceId
572MultiplexedSocket::deviceId() const
573{
574 return pimpl_->deviceId;
575}
576
577void
578MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
579{
580 pimpl_->onChannelReady_ = std::move(cb);
581}
582
583void
584MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
585{
586 pimpl_->onRequest_ = std::move(cb);
587}
588
589bool
590MultiplexedSocket::isReliable() const
591{
592 return true;
593}
594
595bool
596MultiplexedSocket::isInitiator() const
597{
598 if (!pimpl_->endpoint) {
599 if (pimpl_->logger_)
600 pimpl_->logger_->warn("No endpoint found for socket");
601 return false;
602 }
603 return pimpl_->endpoint->isInitiator();
604}
605
606int
607MultiplexedSocket::maxPayload() const
608{
609 if (!pimpl_->endpoint) {
610 if (pimpl_->logger_)
611 pimpl_->logger_->warn("No endpoint found for socket");
612 return 0;
613 }
614 return pimpl_->endpoint->maxPayload();
615}
616
617std::size_t
618MultiplexedSocket::write(const uint16_t& channel,
619 const uint8_t* buf,
620 std::size_t len,
621 std::error_code& ec)
622{
623 assert(nullptr != buf);
624
625 if (pimpl_->isShutdown_) {
626 ec = std::make_error_code(std::errc::broken_pipe);
627 return -1;
628 }
629 if (len > UINT16_MAX) {
630 ec = std::make_error_code(std::errc::message_size);
631 return -1;
632 }
633 bool oneShot = len < 8192;
634 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
635 msgpack::packer<msgpack::sbuffer> pk(&buffer);
636 pk.pack_array(2);
637 pk.pack(channel);
638 pk.pack_bin(len);
639 if (oneShot)
640 pk.pack_bin_body((const char*) buf, len);
641
642 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
643 if (!pimpl_->endpoint) {
644 if (pimpl_->logger_)
645 pimpl_->logger_->warn("No endpoint found for socket");
646 ec = std::make_error_code(std::errc::broken_pipe);
647 return -1;
648 }
649 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
650 if (not oneShot and res >= 0)
651 res = pimpl_->endpoint->write(buf, len, ec);
652 lk.unlock();
653 if (res < 0) {
654 if (ec && pimpl_->logger_)
655 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
656 shutdown();
657 }
658 return res;
659}
660
661void
662MultiplexedSocket::shutdown()
663{
664 pimpl_->shutdown();
665}
666
667void
668MultiplexedSocket::join()
669{
670 pimpl_->join();
671}
672
673void
674MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
675{
676 pimpl_->onShutdown_ = std::move(cb);
677 if (pimpl_->isShutdown_)
678 pimpl_->onShutdown_();
679}
680
681const std::shared_ptr<Logger>&
682MultiplexedSocket::logger()
683{
684 return pimpl_->logger_;
685}
686
687void
688MultiplexedSocket::monitor() const
689{
690 auto cert = peerCertificate();
691 if (!cert || !cert->issuer)
692 return;
693 auto now = clock::now();
694 if (!pimpl_->logger_)
695 return;
696 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
697 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
698 pimpl_->endpoint->monitor();
699 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
700 for (const auto& [_, channel] : pimpl_->sockets) {
701 if (channel)
702 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
703 fmt::ptr(channel.get()),
704 channel.use_count(),
705 channel->name(),
706 channel->isInitiator());
707 }
708}
709
710void
711MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
712{
713 pimpl_->sendBeacon(timeout);
714}
715
716std::shared_ptr<dht::crypto::Certificate>
717MultiplexedSocket::peerCertificate() const
718{
719 return pimpl_->endpoint->peerCertificate();
720}
721
Adrien Béraud6b6a5d32023-08-15 15:53:33 -0400722#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -0400723bool
724MultiplexedSocket::canSendBeacon() const
725{
726 return pimpl_->canSendBeacon_;
727}
728
729void
730MultiplexedSocket::answerToBeacon(bool value)
731{
732 pimpl_->answerBeacon_ = value;
733}
734
735void
736MultiplexedSocket::setVersion(int version)
737{
738 pimpl_->version_ = version;
739}
740
741void
742MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
743{
744 pimpl_->onBeaconCb_ = cb;
745}
746
747void
748MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
749{
750 pimpl_->onVersionCb_ = cb;
751}
752
753void
754MultiplexedSocket::sendVersion()
755{
756 pimpl_->sendVersion();
757}
758
Adrien Béraudac35e662023-07-19 09:37:29 -0400759#endif
760
Adrien Béraud612b55b2023-05-29 10:42:04 -0400761IpAddr
762MultiplexedSocket::getLocalAddress() const
763{
764 return pimpl_->endpoint->getLocalAddress();
765}
766
767IpAddr
768MultiplexedSocket::getRemoteAddress() const
769{
770 return pimpl_->endpoint->getRemoteAddress();
771}
772
Adrien Béraudafa8e282023-09-24 12:53:20 -0400773TlsSocketEndpoint*
774MultiplexedSocket::endpoint()
775{
776 return pimpl_->endpoint.get();
777}
778
Adrien Béraud612b55b2023-05-29 10:42:04 -0400779void
780MultiplexedSocket::eraseChannel(uint16_t channel)
781{
782 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
783 auto itSocket = pimpl_->sockets.find(channel);
784 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
785 pimpl_->sockets.erase(itSocket);
786}
787
788////////////////////////////////////////////////////////////////
789
790class ChannelSocket::Impl
791{
792public:
793 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
794 const std::string& name,
795 const uint16_t& channel,
796 bool isInitiator,
797 std::function<void()> rmFromMxSockCb)
798 : name(name)
799 , channel(channel)
800 , endpoint(std::move(endpoint))
801 , isInitiator_(isInitiator)
802 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
803 {}
804
805 ~Impl() {}
806
807 ChannelReadyCb readyCb_ {};
808 OnShutdownCb shutdownCb_ {};
809 std::atomic_bool isShutdown_ {false};
810 std::string name {};
811 uint16_t channel {};
812 std::weak_ptr<MultiplexedSocket> endpoint {};
813 bool isInitiator_ {false};
814 std::function<void()> rmFromMxSockCb_;
815
816 bool isAnswered_ {false};
817 bool isRemovable_ {false};
818
819 std::vector<uint8_t> buf {};
820 std::mutex mutex {};
821 std::condition_variable cv {};
822 GenericSocket<uint8_t>::RecvCb cb {};
823};
824
825ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
826 const DeviceId& deviceId,
827 const std::string& name,
828 const uint16_t& channel)
829 : pimpl_deviceId(deviceId)
830 , pimpl_name(name)
831 , pimpl_channel(channel)
832 , ioCtx_(*ctx)
833{}
834
835ChannelSocketTest::~ChannelSocketTest() {}
836
837void
838ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
839 const std::shared_ptr<ChannelSocketTest>& socket2)
840{
841 socket1->remote = socket2;
842 socket2->remote = socket1;
843}
844
845DeviceId
846ChannelSocketTest::deviceId() const
847{
848 return pimpl_deviceId;
849}
850
851std::string
852ChannelSocketTest::name() const
853{
854 return pimpl_name;
855}
856
857uint16_t
858ChannelSocketTest::channel() const
859{
860 return pimpl_channel;
861}
862
863void
864ChannelSocketTest::shutdown()
865{
866 {
867 std::unique_lock<std::mutex> lk {mutex};
868 if (!isShutdown_.exchange(true)) {
869 lk.unlock();
870 shutdownCb_();
871 }
872 cv.notify_all();
873 }
874
875 if (auto peer = remote.lock()) {
876 if (!peer->isShutdown_.exchange(true)) {
877 peer->shutdownCb_();
878 }
879 peer->cv.notify_all();
880 }
881}
882
883std::size_t
884ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
885{
886 std::size_t size = std::min(len, this->rx_buf.size());
887
888 for (std::size_t i = 0; i < size; ++i)
889 buf[i] = this->rx_buf[i];
890
891 if (size == this->rx_buf.size()) {
892 this->rx_buf.clear();
893 } else
894 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
895 return size;
896}
897
898std::size_t
899ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
900{
901 if (isShutdown_) {
902 ec = std::make_error_code(std::errc::broken_pipe);
903 return -1;
904 }
905 ec = {};
906 dht::ThreadPool::computation().run(
907 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
908 if (auto peer = r.lock())
909 peer->onRecv(std::move(data));
910 });
911 return len;
912}
913
914int
915ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
916{
917 std::unique_lock<std::mutex> lk {mutex};
918 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
919 return rx_buf.size();
920}
921
922void
923ChannelSocketTest::setOnRecv(RecvCb&& cb)
924{
925 std::lock_guard<std::mutex> lkSockets(mutex);
926 this->cb = std::move(cb);
927 if (!rx_buf.empty() && this->cb) {
928 this->cb(rx_buf.data(), rx_buf.size());
929 rx_buf.clear();
930 }
931}
932
933void
934ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
935{
936 std::lock_guard<std::mutex> lkSockets(mutex);
937 if (cb) {
938 cb(pkt.data(), pkt.size());
939 return;
940 }
941 rx_buf.insert(rx_buf.end(),
942 std::make_move_iterator(pkt.begin()),
943 std::make_move_iterator(pkt.end()));
944 cv.notify_all();
945}
946
947void
948ChannelSocketTest::onReady(ChannelReadyCb&& cb)
949{}
950
951void
952ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
953{
954 std::unique_lock<std::mutex> lk {mutex};
955 shutdownCb_ = std::move(cb);
956
957 if (isShutdown_) {
958 lk.unlock();
959 shutdownCb_();
960 }
961}
962
963ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
964 const std::string& name,
965 const uint16_t& channel,
966 bool isInitiator,
967 std::function<void()> rmFromMxSockCb)
968 : pimpl_ {
969 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
970{}
971
972ChannelSocket::~ChannelSocket() {}
973
974DeviceId
975ChannelSocket::deviceId() const
976{
977 if (auto ep = pimpl_->endpoint.lock()) {
978 return ep->deviceId();
979 }
980 return {};
981}
982
983std::string
984ChannelSocket::name() const
985{
986 return pimpl_->name;
987}
988
989uint16_t
990ChannelSocket::channel() const
991{
992 return pimpl_->channel;
993}
994
995bool
996ChannelSocket::isReliable() const
997{
998 if (auto ep = pimpl_->endpoint.lock()) {
999 return ep->isReliable();
1000 }
1001 return false;
1002}
1003
1004bool
1005ChannelSocket::isInitiator() const
1006{
1007 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1008 // because a multiplexed socket can have sockets from accepted requests
1009 // or made via connectDevice(). Here, isInitiator_ return if the socket
1010 // is from connectDevice.
1011 return pimpl_->isInitiator_;
1012}
1013
1014int
1015ChannelSocket::maxPayload() const
1016{
1017 if (auto ep = pimpl_->endpoint.lock()) {
1018 return ep->maxPayload();
1019 }
1020 return -1;
1021}
1022
1023void
1024ChannelSocket::setOnRecv(RecvCb&& cb)
1025{
1026 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1027 pimpl_->cb = std::move(cb);
1028 if (!pimpl_->buf.empty() && pimpl_->cb) {
1029 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1030 pimpl_->buf.clear();
1031 }
1032}
1033
1034void
1035ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1036{
1037 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1038 if (pimpl_->cb) {
1039 pimpl_->cb(&pkt[0], pkt.size());
1040 return;
1041 }
1042 pimpl_->buf.insert(pimpl_->buf.end(),
1043 std::make_move_iterator(pkt.begin()),
1044 std::make_move_iterator(pkt.end()));
1045 pimpl_->cv.notify_all();
1046}
1047
Adrien Béraud6b6a5d32023-08-15 15:53:33 -04001048#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -04001049std::shared_ptr<MultiplexedSocket>
1050ChannelSocket::underlyingSocket() const
1051{
1052 if (auto mtx = pimpl_->endpoint.lock())
1053 return mtx;
1054 return {};
1055}
1056#endif
1057
1058void
1059ChannelSocket::answered()
1060{
1061 pimpl_->isAnswered_ = true;
1062}
1063
1064void
1065ChannelSocket::removable()
1066{
1067 pimpl_->isRemovable_ = true;
1068}
1069
1070bool
1071ChannelSocket::isRemovable() const
1072{
1073 return pimpl_->isRemovable_;
1074}
1075
1076bool
1077ChannelSocket::isAnswered() const
1078{
1079 return pimpl_->isAnswered_;
1080}
1081
1082void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001083ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001084{
1085 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001086 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001087}
1088
1089void
1090ChannelSocket::stop()
1091{
1092 if (pimpl_->isShutdown_)
1093 return;
1094 pimpl_->isShutdown_ = true;
1095 if (pimpl_->shutdownCb_)
1096 pimpl_->shutdownCb_();
1097 pimpl_->cv.notify_all();
1098 // stop() can be called by ChannelSocket::shutdown()
1099 // In this case, the eventLoop is not used, but MxSock
1100 // must remove the channel from its list (so that the
1101 // channel can be destroyed and its shared_ptr invalidated).
1102 if (pimpl_->rmFromMxSockCb_)
1103 pimpl_->rmFromMxSockCb_();
1104}
1105
1106void
1107ChannelSocket::shutdown()
1108{
1109 if (pimpl_->isShutdown_)
1110 return;
1111 stop();
1112 if (auto ep = pimpl_->endpoint.lock()) {
1113 std::error_code ec;
1114 const uint8_t dummy = '\0';
1115 ep->write(pimpl_->channel, &dummy, 0, ec);
1116 }
1117}
1118
1119std::size_t
1120ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1121{
1122 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1123 std::size_t size = std::min(len, pimpl_->buf.size());
1124
1125 for (std::size_t i = 0; i < size; ++i)
1126 outBuf[i] = pimpl_->buf[i];
1127
1128 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1129 return size;
1130}
1131
1132std::size_t
1133ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1134{
1135 if (pimpl_->isShutdown_) {
1136 ec = std::make_error_code(std::errc::broken_pipe);
1137 return -1;
1138 }
1139 if (auto ep = pimpl_->endpoint.lock()) {
1140 std::size_t sent = 0;
1141 do {
1142 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1143 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1144 if (ec) {
1145 if (ep->logger())
1146 ep->logger()->error("Error when writing on channel: {}", ec.message());
1147 return res;
1148 }
1149 sent += toSend;
1150 } while (sent < len);
1151 return sent;
1152 }
1153 ec = std::make_error_code(std::errc::broken_pipe);
1154 return -1;
1155}
1156
1157int
1158ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1159{
1160 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1161 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1162 return pimpl_->buf.size();
1163}
1164
1165void
1166ChannelSocket::onShutdown(OnShutdownCb&& cb)
1167{
1168 pimpl_->shutdownCb_ = std::move(cb);
1169 if (pimpl_->isShutdown_) {
1170 pimpl_->shutdownCb_();
1171 }
1172}
1173
1174void
1175ChannelSocket::onReady(ChannelReadyCb&& cb)
1176{
1177 pimpl_->readyCb_ = std::move(cb);
1178}
1179
1180void
1181ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1182{
1183 if (auto ep = pimpl_->endpoint.lock()) {
1184 ep->sendBeacon(timeout);
1185 } else {
1186 shutdown();
1187 }
1188}
1189
1190std::shared_ptr<dht::crypto::Certificate>
1191ChannelSocket::peerCertificate() const
1192{
1193 if (auto ep = pimpl_->endpoint.lock())
1194 return ep->peerCertificate();
1195 return {};
1196}
1197
1198IpAddr
1199ChannelSocket::getLocalAddress() const
1200{
1201 if (auto ep = pimpl_->endpoint.lock())
1202 return ep->getLocalAddress();
1203 return {};
1204}
1205
1206IpAddr
1207ChannelSocket::getRemoteAddress() const
1208{
1209 if (auto ep = pimpl_->endpoint.lock())
1210 return ep->getRemoteAddress();
1211 return {};
1212}
1213
Amna31791e52023-08-03 12:40:57 -04001214std::vector<std::map<std::string, std::string>>
1215MultiplexedSocket::getChannelList() const
1216{
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001217 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
Amna31791e52023-08-03 12:40:57 -04001218 std::vector<std::map<std::string, std::string>> channelsList;
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001219 channelsList.reserve(pimpl_->sockets.size());
Amna31791e52023-08-03 12:40:57 -04001220 for (const auto& [_, channel] : pimpl_->sockets) {
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001221 channelsList.emplace_back(std::map<std::string, std::string> {
1222 {"id", fmt::format("{:x}", channel->channel())},
1223 {"name", channel->name()},
1224 });
Amna31791e52023-08-03 12:40:57 -04001225 }
Amna31791e52023-08-03 12:40:57 -04001226 return channelsList;
1227}
1228
Sébastien Blin464bdff2023-07-19 08:02:53 -04001229} // namespace dhtnet