blob: 81a777cdf02e7ee564a437b8a6f0ce9fc80dde48 [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
Adrien Béraudcb753622023-07-17 22:32:49 -04002 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
Adrien Béraud612b55b2023-05-29 10:42:04 -04003 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "multiplexed_socket.h"
18#include "peer_connection.h"
19#include "ice_transport.h"
20#include "certstore.h"
21
22#include <opendht/logger.h>
23#include <opendht/thread_pool.h>
24
25#include <asio/io_context.hpp>
26#include <asio/steady_timer.hpp>
27
28#include <deque>
29
/// Number of bytes reserved in the unpacker and requested from the endpoint on each read.
static constexpr std::size_t IO_BUFFER_SIZE = 8192;
/// Protocol version this side announces to the peer in a VersionMsg.
static constexpr int MULTIPLEXED_SOCKET_VERSION = 1;
32
33struct ChanneledMessage
34{
35 uint16_t channel;
36 std::vector<uint8_t> data;
37 MSGPACK_DEFINE(channel, data)
38};
39
40struct BeaconMsg
41{
42 bool p;
43 MSGPACK_DEFINE_MAP(p)
44};
45
46struct VersionMsg
47{
48 int v;
49 MSGPACK_DEFINE_MAP(v)
50};
51
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040052namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040053
54using clock = std::chrono::steady_clock;
55using time_point = clock::time_point;
56
57class MultiplexedSocket::Impl
58{
59public:
60 Impl(MultiplexedSocket& parent,
61 std::shared_ptr<asio::io_context> ctx,
62 const DeviceId& deviceId,
Adrien Béraud55133cc2023-10-15 11:55:20 -040063 std::unique_ptr<TlsSocketEndpoint> ep,
Adrien Béraud5636f7c2023-09-14 14:34:57 -040064 std::shared_ptr<dht::log::Logger> logger)
Adrien Béraud612b55b2023-05-29 10:42:04 -040065 : parent_(parent)
Adrien Béraud612b55b2023-05-29 10:42:04 -040066 , ctx_(std::move(ctx))
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040067 , deviceId(deviceId)
Adrien Béraud55133cc2023-10-15 11:55:20 -040068 , endpoint(std::move(ep))
69 , nextChannel_(endpoint->isInitiator() ? 0x0001u : 0x8000u)
Adrien Béraud612b55b2023-05-29 10:42:04 -040070 , eventLoopThread_ {[this] {
71 try {
72 eventLoop();
73 } catch (const std::exception& e) {
74 if (logger_)
75 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
76 shutdown();
77 }
78 }}
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040079 , beaconTimer_(*ctx_)
Adrien Béraud612b55b2023-05-29 10:42:04 -040080 {}
81
82 ~Impl() {}
83
84 void join()
85 {
86 if (!isShutdown_) {
87 if (endpoint)
88 endpoint->setOnStateChange({});
89 shutdown();
90 } else {
91 clearSockets();
92 }
93 if (eventLoopThread_.joinable())
94 eventLoopThread_.join();
95 }
96
97 void clearSockets()
98 {
99 decltype(sockets) socks;
100 {
101 std::lock_guard<std::mutex> lkSockets(socketsMutex);
102 socks = std::move(sockets);
103 }
104 for (auto& socket : socks) {
105 // Just trigger onShutdown() to make client know
106 // No need to write the EOF for the channel, the write will fail because endpoint is
107 // already shutdown
108 if (socket.second)
109 socket.second->stop();
110 }
111 }
112
113 void shutdown()
114 {
115 if (isShutdown_)
116 return;
117 stop.store(true);
118 isShutdown_ = true;
119 beaconTimer_.cancel();
120 if (onShutdown_)
121 onShutdown_();
122 if (endpoint) {
123 std::unique_lock<std::mutex> lk(writeMtx);
124 endpoint->shutdown();
125 }
126 clearSockets();
127 }
128
129 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
130 uint16_t channel,
131 bool isInitiator = false)
132 {
133 auto& channelSocket = sockets[channel];
134 if (not channelSocket)
135 channelSocket = std::make_shared<ChannelSocket>(
136 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
137 // Remove socket in another thread to avoid any lock
138 dht::ThreadPool::io().run([w, channel]() {
139 if (auto shared = w.lock()) {
140 shared->eraseChannel(channel);
141 }
142 });
143 });
144 else {
145 if (logger_)
146 logger_->warn("A channel is already present on that socket, accepting "
147 "the request will close the previous one {}", name);
148 }
149 return channelSocket;
150 }
151
152 /**
153 * Handle packets on the TLS endpoint and parse RTP
154 */
155 void eventLoop();
156 /**
157 * Triggered when a new control packet is received
158 */
159 void handleControlPacket(std::vector<uint8_t>&& pkt);
160 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
161 bool handleProtocolMsg(const msgpack::object& o);
162 /**
163 * Triggered when a new packet on a channel is received
164 */
165 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
166 void onRequest(const std::string& name, uint16_t channel);
167 void onAccept(const std::string& name, uint16_t channel);
168
169 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
170 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
171
172 // Beacon
173 void sendBeacon(const std::chrono::milliseconds& timeout);
174 void handleBeaconRequest();
175 void handleBeaconResponse();
176 std::atomic_int beaconCounter_ {0};
177
178 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
179
180 msgpack::unpacker pac_ {};
181
182 MultiplexedSocket& parent_;
183
184 std::shared_ptr<Logger> logger_;
185 std::shared_ptr<asio::io_context> ctx_;
186
187 OnConnectionReadyCb onChannelReady_ {};
188 OnConnectionRequestCb onRequest_ {};
189 OnShutdownCb onShutdown_ {};
190
191 DeviceId deviceId {};
192 // Main socket
193 std::unique_ptr<TlsSocketEndpoint> endpoint {};
194
195 std::mutex socketsMutex {};
196 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
Adrien Béraud55133cc2023-10-15 11:55:20 -0400197 uint16_t nextChannel_;
Adrien Béraud612b55b2023-05-29 10:42:04 -0400198
199 // Main loop to parse incoming packets
200 std::atomic_bool stop {false};
201 std::thread eventLoopThread_ {};
202
203 std::atomic_bool isShutdown_ {false};
204
205 std::mutex writeMtx {};
206
207 time_point start_ {clock::now()};
208 //std::shared_ptr<Task> beaconTask_ {};
209 asio::steady_timer beaconTimer_;
210
211 // version related stuff
212 void sendVersion();
213 void onVersion(int version);
214 std::atomic_bool canSendBeacon_ {false};
215 std::atomic_bool answerBeacon_ {true};
216 int version_ {MULTIPLEXED_SOCKET_VERSION};
217 std::function<void(bool)> onBeaconCb_ {};
218 std::function<void(int)> onVersionCb_ {};
219};
220
221void
222MultiplexedSocket::Impl::eventLoop()
223{
224 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
225 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
226 if (logger_)
227 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
228 shutdown();
229 return false;
230 }
231 return true;
232 });
233 sendVersion();
234 std::error_code ec;
235 while (!stop) {
236 if (!endpoint) {
237 shutdown();
238 return;
239 }
240 pac_.reserve_buffer(IO_BUFFER_SIZE);
241 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
242 if (size < 0) {
243 if (ec && logger_)
244 logger_->error("Read error detected: {}", ec.message());
245 break;
246 }
247 if (size == 0) {
248 // We can close the socket
249 shutdown();
250 break;
251 }
252
253 pac_.buffer_consumed(size);
254 msgpack::object_handle oh;
255 while (pac_.next(oh) && !stop) {
256 try {
257 auto msg = oh.get().as<ChanneledMessage>();
258 if (msg.channel == CONTROL_CHANNEL)
259 handleControlPacket(std::move(msg.data));
260 else if (msg.channel == PROTOCOL_CHANNEL)
261 handleProtocolPacket(std::move(msg.data));
262 else
263 handleChannelPacket(msg.channel, std::move(msg.data));
264 } catch (const std::exception& e) {
265 if (logger_)
266 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
267 } catch (...) {
268 if (logger_)
269 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
270 }
271 }
272 }
273}
274
275void
276MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
277{
278 std::lock_guard<std::mutex> lkSockets(socketsMutex);
279 auto& socket = sockets[channel];
280 if (!socket) {
281 if (logger_)
282 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
283 return;
284 }
285
286 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400287 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400288 // Due to the callbacks that can take some time, onAccept can arrive after
289 // receiving all the data. In this case, the socket should be removed here
290 // as handle by onChannelReady_
291 if (socket->isRemovable())
292 sockets.erase(channel);
293 else
294 socket->answered();
295}
296
297void
298MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
299{
300 if (!canSendBeacon_)
301 return;
302 beaconCounter_++;
303 if (logger_)
304 logger_->debug("Send beacon to peer {}", deviceId);
305
306 msgpack::sbuffer buffer(8);
307 msgpack::packer<msgpack::sbuffer> pk(&buffer);
308 pk.pack(BeaconMsg {true});
309 if (!writeProtocolMessage(buffer))
310 return;
311 beaconTimer_.expires_after(timeout);
312 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
313 if (ec == asio::error::operation_aborted)
314 return;
315 if (auto shared = w.lock()) {
316 if (shared->pimpl_->beaconCounter_ != 0) {
317 if (shared->pimpl_->logger_)
318 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
319 shared->shutdown();
320 }
321 }
322 });
323}
324
325void
326MultiplexedSocket::Impl::handleBeaconRequest()
327{
328 if (!answerBeacon_)
329 return;
330 // Run this on dedicated thread because some callbacks can take time
331 dht::ThreadPool::io().run([w = parent_.weak()]() {
332 if (auto shared = w.lock()) {
333 msgpack::sbuffer buffer(8);
334 msgpack::packer<msgpack::sbuffer> pk(&buffer);
335 pk.pack(BeaconMsg {false});
336 if (shared->pimpl_->logger_)
337 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
338 shared->pimpl_->writeProtocolMessage(buffer);
339 }
340 });
341}
342
343void
344MultiplexedSocket::Impl::handleBeaconResponse()
345{
346 if (logger_)
347 logger_->debug("Get beacon response from peer {}", deviceId);
348 beaconCounter_--;
349}
350
351bool
352MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
353{
354 std::error_code ec;
355 int wr = parent_.write(PROTOCOL_CHANNEL,
356 (const unsigned char*) buffer.data(),
357 buffer.size(),
358 ec);
359 return wr > 0;
360}
361
362void
363MultiplexedSocket::Impl::sendVersion()
364{
365 dht::ThreadPool::io().run([w = parent_.weak()]() {
366 if (auto shared = w.lock()) {
367 auto version = shared->pimpl_->version_;
368 msgpack::sbuffer buffer(8);
369 msgpack::packer<msgpack::sbuffer> pk(&buffer);
370 pk.pack(VersionMsg {version});
371 shared->pimpl_->writeProtocolMessage(buffer);
372 }
373 });
374}
375
376void
377MultiplexedSocket::Impl::onVersion(int version)
378{
379 // Check if version > 1
380 if (version >= 1) {
381 if (logger_)
382 logger_->debug("Peer {} supports beacon", deviceId);
383 canSendBeacon_ = true;
384 } else {
385 if (logger_)
386 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
387 deviceId,
388 version);
389 canSendBeacon_ = false;
390 }
391}
392
393void
394MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
395{
396 auto accept = onRequest_(endpoint->peerCertificate(), channel, name);
397 std::shared_ptr<ChannelSocket> channelSocket;
398 if (accept) {
399 std::lock_guard<std::mutex> lkSockets(socketsMutex);
400 channelSocket = makeSocket(name, channel);
401 }
402
403 // Answer to ChannelRequest if accepted
404 ChannelRequest val;
405 val.channel = channel;
406 val.name = name;
407 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
408 msgpack::sbuffer buffer(512);
409 msgpack::pack(buffer, val);
410 std::error_code ec;
411 int wr = parent_.write(CONTROL_CHANNEL,
412 reinterpret_cast<const uint8_t*>(buffer.data()),
413 buffer.size(),
414 ec);
415 if (wr < 0) {
416 if (ec && logger_)
417 logger_->error("The write operation failed with error: {:s}", ec.message());
418 stop.store(true);
419 return;
420 }
421
422 if (accept) {
423 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400424 channelSocket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400425 }
426}
427
428void
429MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
430{
431 // Run this on dedicated thread because some callbacks can take time
432 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
433 auto shared = w.lock();
434 if (!shared)
435 return;
436 auto& pimpl = *shared->pimpl_;
437 try {
438 size_t off = 0;
439 while (off != pkt.size()) {
440 msgpack::unpacked result;
441 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
442 auto object = result.get();
443 if (pimpl.handleProtocolMsg(object))
444 continue;
445 auto req = object.as<ChannelRequest>();
446 if (req.state == ChannelRequestState::ACCEPT) {
447 pimpl.onAccept(req.name, req.channel);
448 } else if (req.state == ChannelRequestState::DECLINE) {
449 std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
450 auto channel = pimpl.sockets.find(req.channel);
451 if (channel != pimpl.sockets.end()) {
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400452 channel->second->ready(false);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400453 channel->second->stop();
454 pimpl.sockets.erase(channel);
455 }
456 } else if (pimpl.onRequest_) {
457 pimpl.onRequest(req.name, req.channel);
458 }
459 }
460 } catch (const std::exception& e) {
461 if (pimpl.logger_)
462 pimpl.logger_->error("Error on the control channel: {}", e.what());
463 }
464 });
465}
466
467void
468MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
469{
470 std::lock_guard<std::mutex> lkSockets(socketsMutex);
471 auto sockIt = sockets.find(channel);
472 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
473 if (pkt.size() == 0) {
474 sockIt->second->stop();
475 if (sockIt->second->isAnswered())
476 sockets.erase(sockIt);
477 else
478 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
479 // removed later.
480 } else {
481 sockIt->second->onRecv(std::move(pkt));
482 }
483 } else if (pkt.size() != 0) {
484 if (logger_)
485 logger_->warn("Non existing channel: {}", channel);
486 }
487}
488
489bool
490MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
491{
492 try {
493 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
494 auto key = o.via.map.ptr[0].key.as<std::string_view>();
495 if (key == "p") {
496 auto msg = o.as<BeaconMsg>();
497 if (msg.p)
498 handleBeaconRequest();
499 else
500 handleBeaconResponse();
501 if (onBeaconCb_)
502 onBeaconCb_(msg.p);
503 return true;
504 } else if (key == "v") {
505 auto msg = o.as<VersionMsg>();
506 onVersion(msg.v);
507 if (onVersionCb_)
508 onVersionCb_(msg.v);
509 return true;
510 } else {
511 if (logger_)
512 logger_->warn("Unknown message type");
513 }
514 }
515 } catch (const std::exception& e) {
516 if (logger_)
517 logger_->error("Error on the protocol channel: {}", e.what());
518 }
519 return false;
520}
521
522void
523MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
524{
525 // Run this on dedicated thread because some callbacks can take time
526 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
527 auto shared = w.lock();
528 if (!shared)
529 return;
530 try {
531 size_t off = 0;
532 while (off != pkt.size()) {
533 msgpack::unpacked result;
534 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
535 auto object = result.get();
536 if (shared->pimpl_->handleProtocolMsg(object))
537 return;
538 }
539 } catch (const std::exception& e) {
540 if (shared->pimpl_->logger_)
541 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
542 }
543 });
544}
545
546MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
Adrien Béraud5636f7c2023-09-14 14:34:57 -0400547 std::unique_ptr<TlsSocketEndpoint> endpoint, std::shared_ptr<dht::log::Logger> logger)
548 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint), logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -0400549{}
550
551MultiplexedSocket::~MultiplexedSocket() {}
552
553std::shared_ptr<ChannelSocket>
554MultiplexedSocket::addChannel(const std::string& name)
555{
Adrien Béraud612b55b2023-05-29 10:42:04 -0400556 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
Adrien Béraud55133cc2023-10-15 11:55:20 -0400557 if (pimpl_->sockets.size() < UINT16_MAX)
558 for (unsigned i = 0; i < UINT16_MAX; ++i) {
559 auto c = pimpl_->nextChannel_++;
560 if (c == CONTROL_CHANNEL
561 || c == PROTOCOL_CHANNEL
562 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
563 continue;
564 return pimpl_->makeSocket(name, c, true);
565 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400566 return {};
567}
568
569DeviceId
570MultiplexedSocket::deviceId() const
571{
572 return pimpl_->deviceId;
573}
574
575void
576MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
577{
578 pimpl_->onChannelReady_ = std::move(cb);
579}
580
581void
582MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
583{
584 pimpl_->onRequest_ = std::move(cb);
585}
586
587bool
588MultiplexedSocket::isReliable() const
589{
590 return true;
591}
592
593bool
594MultiplexedSocket::isInitiator() const
595{
596 if (!pimpl_->endpoint) {
597 if (pimpl_->logger_)
598 pimpl_->logger_->warn("No endpoint found for socket");
599 return false;
600 }
601 return pimpl_->endpoint->isInitiator();
602}
603
604int
605MultiplexedSocket::maxPayload() const
606{
607 if (!pimpl_->endpoint) {
608 if (pimpl_->logger_)
609 pimpl_->logger_->warn("No endpoint found for socket");
610 return 0;
611 }
612 return pimpl_->endpoint->maxPayload();
613}
614
615std::size_t
616MultiplexedSocket::write(const uint16_t& channel,
617 const uint8_t* buf,
618 std::size_t len,
619 std::error_code& ec)
620{
621 assert(nullptr != buf);
622
623 if (pimpl_->isShutdown_) {
624 ec = std::make_error_code(std::errc::broken_pipe);
625 return -1;
626 }
627 if (len > UINT16_MAX) {
628 ec = std::make_error_code(std::errc::message_size);
629 return -1;
630 }
631 bool oneShot = len < 8192;
632 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
633 msgpack::packer<msgpack::sbuffer> pk(&buffer);
634 pk.pack_array(2);
635 pk.pack(channel);
636 pk.pack_bin(len);
637 if (oneShot)
638 pk.pack_bin_body((const char*) buf, len);
639
640 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
641 if (!pimpl_->endpoint) {
642 if (pimpl_->logger_)
643 pimpl_->logger_->warn("No endpoint found for socket");
644 ec = std::make_error_code(std::errc::broken_pipe);
645 return -1;
646 }
647 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
648 if (not oneShot and res >= 0)
649 res = pimpl_->endpoint->write(buf, len, ec);
650 lk.unlock();
651 if (res < 0) {
652 if (ec && pimpl_->logger_)
653 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
654 shutdown();
655 }
656 return res;
657}
658
659void
660MultiplexedSocket::shutdown()
661{
662 pimpl_->shutdown();
663}
664
665void
666MultiplexedSocket::join()
667{
668 pimpl_->join();
669}
670
671void
672MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
673{
674 pimpl_->onShutdown_ = std::move(cb);
675 if (pimpl_->isShutdown_)
676 pimpl_->onShutdown_();
677}
678
679const std::shared_ptr<Logger>&
680MultiplexedSocket::logger()
681{
682 return pimpl_->logger_;
683}
684
685void
686MultiplexedSocket::monitor() const
687{
688 auto cert = peerCertificate();
689 if (!cert || !cert->issuer)
690 return;
691 auto now = clock::now();
692 if (!pimpl_->logger_)
693 return;
694 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
695 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
696 pimpl_->endpoint->monitor();
697 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
698 for (const auto& [_, channel] : pimpl_->sockets) {
699 if (channel)
700 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
701 fmt::ptr(channel.get()),
702 channel.use_count(),
703 channel->name(),
704 channel->isInitiator());
705 }
706}
707
708void
709MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
710{
711 pimpl_->sendBeacon(timeout);
712}
713
714std::shared_ptr<dht::crypto::Certificate>
715MultiplexedSocket::peerCertificate() const
716{
717 return pimpl_->endpoint->peerCertificate();
718}
719
Adrien Béraud6b6a5d32023-08-15 15:53:33 -0400720#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -0400721bool
722MultiplexedSocket::canSendBeacon() const
723{
724 return pimpl_->canSendBeacon_;
725}
726
727void
728MultiplexedSocket::answerToBeacon(bool value)
729{
730 pimpl_->answerBeacon_ = value;
731}
732
733void
734MultiplexedSocket::setVersion(int version)
735{
736 pimpl_->version_ = version;
737}
738
739void
740MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
741{
742 pimpl_->onBeaconCb_ = cb;
743}
744
745void
746MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
747{
748 pimpl_->onVersionCb_ = cb;
749}
750
751void
752MultiplexedSocket::sendVersion()
753{
754 pimpl_->sendVersion();
755}
756
Adrien Béraudac35e662023-07-19 09:37:29 -0400757#endif
758
Adrien Béraud612b55b2023-05-29 10:42:04 -0400759IpAddr
760MultiplexedSocket::getLocalAddress() const
761{
762 return pimpl_->endpoint->getLocalAddress();
763}
764
765IpAddr
766MultiplexedSocket::getRemoteAddress() const
767{
768 return pimpl_->endpoint->getRemoteAddress();
769}
770
Adrien Béraudafa8e282023-09-24 12:53:20 -0400771TlsSocketEndpoint*
772MultiplexedSocket::endpoint()
773{
774 return pimpl_->endpoint.get();
775}
776
Adrien Béraud612b55b2023-05-29 10:42:04 -0400777void
778MultiplexedSocket::eraseChannel(uint16_t channel)
779{
780 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
781 auto itSocket = pimpl_->sockets.find(channel);
782 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
783 pimpl_->sockets.erase(itSocket);
784}
785
786////////////////////////////////////////////////////////////////
787
788class ChannelSocket::Impl
789{
790public:
791 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
792 const std::string& name,
793 const uint16_t& channel,
794 bool isInitiator,
795 std::function<void()> rmFromMxSockCb)
796 : name(name)
797 , channel(channel)
798 , endpoint(std::move(endpoint))
799 , isInitiator_(isInitiator)
800 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
801 {}
802
803 ~Impl() {}
804
805 ChannelReadyCb readyCb_ {};
806 OnShutdownCb shutdownCb_ {};
807 std::atomic_bool isShutdown_ {false};
808 std::string name {};
809 uint16_t channel {};
810 std::weak_ptr<MultiplexedSocket> endpoint {};
811 bool isInitiator_ {false};
812 std::function<void()> rmFromMxSockCb_;
813
814 bool isAnswered_ {false};
815 bool isRemovable_ {false};
816
817 std::vector<uint8_t> buf {};
818 std::mutex mutex {};
819 std::condition_variable cv {};
820 GenericSocket<uint8_t>::RecvCb cb {};
821};
822
823ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
824 const DeviceId& deviceId,
825 const std::string& name,
826 const uint16_t& channel)
827 : pimpl_deviceId(deviceId)
828 , pimpl_name(name)
829 , pimpl_channel(channel)
830 , ioCtx_(*ctx)
831{}
832
833ChannelSocketTest::~ChannelSocketTest() {}
834
835void
836ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
837 const std::shared_ptr<ChannelSocketTest>& socket2)
838{
839 socket1->remote = socket2;
840 socket2->remote = socket1;
841}
842
843DeviceId
844ChannelSocketTest::deviceId() const
845{
846 return pimpl_deviceId;
847}
848
849std::string
850ChannelSocketTest::name() const
851{
852 return pimpl_name;
853}
854
855uint16_t
856ChannelSocketTest::channel() const
857{
858 return pimpl_channel;
859}
860
861void
862ChannelSocketTest::shutdown()
863{
864 {
865 std::unique_lock<std::mutex> lk {mutex};
866 if (!isShutdown_.exchange(true)) {
867 lk.unlock();
868 shutdownCb_();
869 }
870 cv.notify_all();
871 }
872
873 if (auto peer = remote.lock()) {
874 if (!peer->isShutdown_.exchange(true)) {
875 peer->shutdownCb_();
876 }
877 peer->cv.notify_all();
878 }
879}
880
881std::size_t
882ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
883{
884 std::size_t size = std::min(len, this->rx_buf.size());
885
886 for (std::size_t i = 0; i < size; ++i)
887 buf[i] = this->rx_buf[i];
888
889 if (size == this->rx_buf.size()) {
890 this->rx_buf.clear();
891 } else
892 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
893 return size;
894}
895
896std::size_t
897ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
898{
899 if (isShutdown_) {
900 ec = std::make_error_code(std::errc::broken_pipe);
901 return -1;
902 }
903 ec = {};
904 dht::ThreadPool::computation().run(
905 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
906 if (auto peer = r.lock())
907 peer->onRecv(std::move(data));
908 });
909 return len;
910}
911
912int
913ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
914{
915 std::unique_lock<std::mutex> lk {mutex};
916 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
917 return rx_buf.size();
918}
919
920void
921ChannelSocketTest::setOnRecv(RecvCb&& cb)
922{
923 std::lock_guard<std::mutex> lkSockets(mutex);
924 this->cb = std::move(cb);
925 if (!rx_buf.empty() && this->cb) {
926 this->cb(rx_buf.data(), rx_buf.size());
927 rx_buf.clear();
928 }
929}
930
931void
932ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
933{
934 std::lock_guard<std::mutex> lkSockets(mutex);
935 if (cb) {
936 cb(pkt.data(), pkt.size());
937 return;
938 }
939 rx_buf.insert(rx_buf.end(),
940 std::make_move_iterator(pkt.begin()),
941 std::make_move_iterator(pkt.end()));
942 cv.notify_all();
943}
944
945void
946ChannelSocketTest::onReady(ChannelReadyCb&& cb)
947{}
948
949void
950ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
951{
952 std::unique_lock<std::mutex> lk {mutex};
953 shutdownCb_ = std::move(cb);
954
955 if (isShutdown_) {
956 lk.unlock();
957 shutdownCb_();
958 }
959}
960
961ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
962 const std::string& name,
963 const uint16_t& channel,
964 bool isInitiator,
965 std::function<void()> rmFromMxSockCb)
966 : pimpl_ {
967 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
968{}
969
970ChannelSocket::~ChannelSocket() {}
971
972DeviceId
973ChannelSocket::deviceId() const
974{
975 if (auto ep = pimpl_->endpoint.lock()) {
976 return ep->deviceId();
977 }
978 return {};
979}
980
981std::string
982ChannelSocket::name() const
983{
984 return pimpl_->name;
985}
986
987uint16_t
988ChannelSocket::channel() const
989{
990 return pimpl_->channel;
991}
992
993bool
994ChannelSocket::isReliable() const
995{
996 if (auto ep = pimpl_->endpoint.lock()) {
997 return ep->isReliable();
998 }
999 return false;
1000}
1001
1002bool
1003ChannelSocket::isInitiator() const
1004{
1005 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1006 // because a multiplexed socket can have sockets from accepted requests
1007 // or made via connectDevice(). Here, isInitiator_ return if the socket
1008 // is from connectDevice.
1009 return pimpl_->isInitiator_;
1010}
1011
1012int
1013ChannelSocket::maxPayload() const
1014{
1015 if (auto ep = pimpl_->endpoint.lock()) {
1016 return ep->maxPayload();
1017 }
1018 return -1;
1019}
1020
1021void
1022ChannelSocket::setOnRecv(RecvCb&& cb)
1023{
1024 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1025 pimpl_->cb = std::move(cb);
1026 if (!pimpl_->buf.empty() && pimpl_->cb) {
1027 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1028 pimpl_->buf.clear();
1029 }
1030}
1031
1032void
1033ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1034{
1035 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1036 if (pimpl_->cb) {
1037 pimpl_->cb(&pkt[0], pkt.size());
1038 return;
1039 }
1040 pimpl_->buf.insert(pimpl_->buf.end(),
1041 std::make_move_iterator(pkt.begin()),
1042 std::make_move_iterator(pkt.end()));
1043 pimpl_->cv.notify_all();
1044}
1045
Adrien Béraud6b6a5d32023-08-15 15:53:33 -04001046#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -04001047std::shared_ptr<MultiplexedSocket>
1048ChannelSocket::underlyingSocket() const
1049{
1050 if (auto mtx = pimpl_->endpoint.lock())
1051 return mtx;
1052 return {};
1053}
1054#endif
1055
1056void
1057ChannelSocket::answered()
1058{
1059 pimpl_->isAnswered_ = true;
1060}
1061
1062void
1063ChannelSocket::removable()
1064{
1065 pimpl_->isRemovable_ = true;
1066}
1067
1068bool
1069ChannelSocket::isRemovable() const
1070{
1071 return pimpl_->isRemovable_;
1072}
1073
1074bool
1075ChannelSocket::isAnswered() const
1076{
1077 return pimpl_->isAnswered_;
1078}
1079
1080void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001081ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001082{
1083 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001084 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001085}
1086
1087void
1088ChannelSocket::stop()
1089{
1090 if (pimpl_->isShutdown_)
1091 return;
1092 pimpl_->isShutdown_ = true;
1093 if (pimpl_->shutdownCb_)
1094 pimpl_->shutdownCb_();
1095 pimpl_->cv.notify_all();
1096 // stop() can be called by ChannelSocket::shutdown()
1097 // In this case, the eventLoop is not used, but MxSock
1098 // must remove the channel from its list (so that the
1099 // channel can be destroyed and its shared_ptr invalidated).
1100 if (pimpl_->rmFromMxSockCb_)
1101 pimpl_->rmFromMxSockCb_();
1102}
1103
1104void
1105ChannelSocket::shutdown()
1106{
1107 if (pimpl_->isShutdown_)
1108 return;
1109 stop();
1110 if (auto ep = pimpl_->endpoint.lock()) {
1111 std::error_code ec;
1112 const uint8_t dummy = '\0';
1113 ep->write(pimpl_->channel, &dummy, 0, ec);
1114 }
1115}
1116
1117std::size_t
1118ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1119{
1120 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1121 std::size_t size = std::min(len, pimpl_->buf.size());
1122
1123 for (std::size_t i = 0; i < size; ++i)
1124 outBuf[i] = pimpl_->buf[i];
1125
1126 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1127 return size;
1128}
1129
1130std::size_t
1131ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1132{
1133 if (pimpl_->isShutdown_) {
1134 ec = std::make_error_code(std::errc::broken_pipe);
1135 return -1;
1136 }
1137 if (auto ep = pimpl_->endpoint.lock()) {
1138 std::size_t sent = 0;
1139 do {
1140 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1141 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1142 if (ec) {
1143 if (ep->logger())
1144 ep->logger()->error("Error when writing on channel: {}", ec.message());
1145 return res;
1146 }
1147 sent += toSend;
1148 } while (sent < len);
1149 return sent;
1150 }
1151 ec = std::make_error_code(std::errc::broken_pipe);
1152 return -1;
1153}
1154
1155int
1156ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1157{
1158 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1159 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1160 return pimpl_->buf.size();
1161}
1162
1163void
1164ChannelSocket::onShutdown(OnShutdownCb&& cb)
1165{
1166 pimpl_->shutdownCb_ = std::move(cb);
1167 if (pimpl_->isShutdown_) {
1168 pimpl_->shutdownCb_();
1169 }
1170}
1171
1172void
1173ChannelSocket::onReady(ChannelReadyCb&& cb)
1174{
1175 pimpl_->readyCb_ = std::move(cb);
1176}
1177
1178void
1179ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1180{
1181 if (auto ep = pimpl_->endpoint.lock()) {
1182 ep->sendBeacon(timeout);
1183 } else {
1184 shutdown();
1185 }
1186}
1187
1188std::shared_ptr<dht::crypto::Certificate>
1189ChannelSocket::peerCertificate() const
1190{
1191 if (auto ep = pimpl_->endpoint.lock())
1192 return ep->peerCertificate();
1193 return {};
1194}
1195
1196IpAddr
1197ChannelSocket::getLocalAddress() const
1198{
1199 if (auto ep = pimpl_->endpoint.lock())
1200 return ep->getLocalAddress();
1201 return {};
1202}
1203
1204IpAddr
1205ChannelSocket::getRemoteAddress() const
1206{
1207 if (auto ep = pimpl_->endpoint.lock())
1208 return ep->getRemoteAddress();
1209 return {};
1210}
1211
Amna31791e52023-08-03 12:40:57 -04001212std::vector<std::map<std::string, std::string>>
1213MultiplexedSocket::getChannelList() const
1214{
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001215 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
Amna31791e52023-08-03 12:40:57 -04001216 std::vector<std::map<std::string, std::string>> channelsList;
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001217 channelsList.reserve(pimpl_->sockets.size());
Amna31791e52023-08-03 12:40:57 -04001218 for (const auto& [_, channel] : pimpl_->sockets) {
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001219 channelsList.emplace_back(std::map<std::string, std::string> {
1220 {"id", fmt::format("{:x}", channel->channel())},
1221 {"name", channel->name()},
1222 });
Amna31791e52023-08-03 12:40:57 -04001223 }
Amna31791e52023-08-03 12:40:57 -04001224 return channelsList;
1225}
1226
Sébastien Blin464bdff2023-07-19 08:02:53 -04001227} // namespace dhtnet