blob: 97c9e1cd707137741f5430279f48a87319121b12 [file] [log] [blame]
/*
 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
#include "multiplexed_socket.h"
#include "peer_connection.h"
#include "ice_transport.h"
#include "certstore.h"

#include <opendht/logger.h>
#include <opendht/thread_pool.h>

#include <asio/io_context.hpp>
#include <asio/steady_timer.hpp>

#include <algorithm>
#include <deque>
29
/// Number of bytes reserved in the msgpack unpacker before each endpoint read.
static constexpr std::size_t IO_BUFFER_SIZE = 8192;
/// Protocol version announced to the peer; onVersion() treats versions >= 1 as beacon-capable.
static constexpr int MULTIPLEXED_SOCKET_VERSION = 1;
32
33struct ChanneledMessage
34{
35 uint16_t channel;
36 std::vector<uint8_t> data;
37 MSGPACK_DEFINE(channel, data)
38};
39
40struct BeaconMsg
41{
42 bool p;
43 MSGPACK_DEFINE_MAP(p)
44};
45
46struct VersionMsg
47{
48 int v;
49 MSGPACK_DEFINE_MAP(v)
50};
51
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040052namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040053
54using clock = std::chrono::steady_clock;
55using time_point = clock::time_point;
56
57class MultiplexedSocket::Impl
58{
59public:
60 Impl(MultiplexedSocket& parent,
61 std::shared_ptr<asio::io_context> ctx,
62 const DeviceId& deviceId,
63 std::unique_ptr<TlsSocketEndpoint> endpoint)
64 : parent_(parent)
65 , deviceId(deviceId)
66 , ctx_(std::move(ctx))
67 , beaconTimer_(*ctx_)
68 , endpoint(std::move(endpoint))
69 , eventLoopThread_ {[this] {
70 try {
71 eventLoop();
72 } catch (const std::exception& e) {
73 if (logger_)
74 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
75 shutdown();
76 }
77 }}
78 {}
79
80 ~Impl() {}
81
82 void join()
83 {
84 if (!isShutdown_) {
85 if (endpoint)
86 endpoint->setOnStateChange({});
87 shutdown();
88 } else {
89 clearSockets();
90 }
91 if (eventLoopThread_.joinable())
92 eventLoopThread_.join();
93 }
94
95 void clearSockets()
96 {
97 decltype(sockets) socks;
98 {
99 std::lock_guard<std::mutex> lkSockets(socketsMutex);
100 socks = std::move(sockets);
101 }
102 for (auto& socket : socks) {
103 // Just trigger onShutdown() to make client know
104 // No need to write the EOF for the channel, the write will fail because endpoint is
105 // already shutdown
106 if (socket.second)
107 socket.second->stop();
108 }
109 }
110
111 void shutdown()
112 {
113 if (isShutdown_)
114 return;
115 stop.store(true);
116 isShutdown_ = true;
117 beaconTimer_.cancel();
118 if (onShutdown_)
119 onShutdown_();
120 if (endpoint) {
121 std::unique_lock<std::mutex> lk(writeMtx);
122 endpoint->shutdown();
123 }
124 clearSockets();
125 }
126
127 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
128 uint16_t channel,
129 bool isInitiator = false)
130 {
131 auto& channelSocket = sockets[channel];
132 if (not channelSocket)
133 channelSocket = std::make_shared<ChannelSocket>(
134 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
135 // Remove socket in another thread to avoid any lock
136 dht::ThreadPool::io().run([w, channel]() {
137 if (auto shared = w.lock()) {
138 shared->eraseChannel(channel);
139 }
140 });
141 });
142 else {
143 if (logger_)
144 logger_->warn("A channel is already present on that socket, accepting "
145 "the request will close the previous one {}", name);
146 }
147 return channelSocket;
148 }
149
150 /**
151 * Handle packets on the TLS endpoint and parse RTP
152 */
153 void eventLoop();
154 /**
155 * Triggered when a new control packet is received
156 */
157 void handleControlPacket(std::vector<uint8_t>&& pkt);
158 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
159 bool handleProtocolMsg(const msgpack::object& o);
160 /**
161 * Triggered when a new packet on a channel is received
162 */
163 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
164 void onRequest(const std::string& name, uint16_t channel);
165 void onAccept(const std::string& name, uint16_t channel);
166
167 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
168 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
169
170 // Beacon
171 void sendBeacon(const std::chrono::milliseconds& timeout);
172 void handleBeaconRequest();
173 void handleBeaconResponse();
174 std::atomic_int beaconCounter_ {0};
175
176 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
177
178 msgpack::unpacker pac_ {};
179
180 MultiplexedSocket& parent_;
181
182 std::shared_ptr<Logger> logger_;
183 std::shared_ptr<asio::io_context> ctx_;
184
185 OnConnectionReadyCb onChannelReady_ {};
186 OnConnectionRequestCb onRequest_ {};
187 OnShutdownCb onShutdown_ {};
188
189 DeviceId deviceId {};
190 // Main socket
191 std::unique_ptr<TlsSocketEndpoint> endpoint {};
192
193 std::mutex socketsMutex {};
194 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
195
196 // Main loop to parse incoming packets
197 std::atomic_bool stop {false};
198 std::thread eventLoopThread_ {};
199
200 std::atomic_bool isShutdown_ {false};
201
202 std::mutex writeMtx {};
203
204 time_point start_ {clock::now()};
205 //std::shared_ptr<Task> beaconTask_ {};
206 asio::steady_timer beaconTimer_;
207
208 // version related stuff
209 void sendVersion();
210 void onVersion(int version);
211 std::atomic_bool canSendBeacon_ {false};
212 std::atomic_bool answerBeacon_ {true};
213 int version_ {MULTIPLEXED_SOCKET_VERSION};
214 std::function<void(bool)> onBeaconCb_ {};
215 std::function<void(int)> onVersionCb_ {};
216};
217
218void
219MultiplexedSocket::Impl::eventLoop()
220{
221 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
222 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
223 if (logger_)
224 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
225 shutdown();
226 return false;
227 }
228 return true;
229 });
230 sendVersion();
231 std::error_code ec;
232 while (!stop) {
233 if (!endpoint) {
234 shutdown();
235 return;
236 }
237 pac_.reserve_buffer(IO_BUFFER_SIZE);
238 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
239 if (size < 0) {
240 if (ec && logger_)
241 logger_->error("Read error detected: {}", ec.message());
242 break;
243 }
244 if (size == 0) {
245 // We can close the socket
246 shutdown();
247 break;
248 }
249
250 pac_.buffer_consumed(size);
251 msgpack::object_handle oh;
252 while (pac_.next(oh) && !stop) {
253 try {
254 auto msg = oh.get().as<ChanneledMessage>();
255 if (msg.channel == CONTROL_CHANNEL)
256 handleControlPacket(std::move(msg.data));
257 else if (msg.channel == PROTOCOL_CHANNEL)
258 handleProtocolPacket(std::move(msg.data));
259 else
260 handleChannelPacket(msg.channel, std::move(msg.data));
261 } catch (const std::exception& e) {
262 if (logger_)
263 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
264 } catch (...) {
265 if (logger_)
266 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
267 }
268 }
269 }
270}
271
272void
273MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
274{
275 std::lock_guard<std::mutex> lkSockets(socketsMutex);
276 auto& socket = sockets[channel];
277 if (!socket) {
278 if (logger_)
279 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
280 return;
281 }
282
283 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400284 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400285 // Due to the callbacks that can take some time, onAccept can arrive after
286 // receiving all the data. In this case, the socket should be removed here
287 // as handle by onChannelReady_
288 if (socket->isRemovable())
289 sockets.erase(channel);
290 else
291 socket->answered();
292}
293
294void
295MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
296{
297 if (!canSendBeacon_)
298 return;
299 beaconCounter_++;
300 if (logger_)
301 logger_->debug("Send beacon to peer {}", deviceId);
302
303 msgpack::sbuffer buffer(8);
304 msgpack::packer<msgpack::sbuffer> pk(&buffer);
305 pk.pack(BeaconMsg {true});
306 if (!writeProtocolMessage(buffer))
307 return;
308 beaconTimer_.expires_after(timeout);
309 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
310 if (ec == asio::error::operation_aborted)
311 return;
312 if (auto shared = w.lock()) {
313 if (shared->pimpl_->beaconCounter_ != 0) {
314 if (shared->pimpl_->logger_)
315 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
316 shared->shutdown();
317 }
318 }
319 });
320}
321
322void
323MultiplexedSocket::Impl::handleBeaconRequest()
324{
325 if (!answerBeacon_)
326 return;
327 // Run this on dedicated thread because some callbacks can take time
328 dht::ThreadPool::io().run([w = parent_.weak()]() {
329 if (auto shared = w.lock()) {
330 msgpack::sbuffer buffer(8);
331 msgpack::packer<msgpack::sbuffer> pk(&buffer);
332 pk.pack(BeaconMsg {false});
333 if (shared->pimpl_->logger_)
334 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
335 shared->pimpl_->writeProtocolMessage(buffer);
336 }
337 });
338}
339
340void
341MultiplexedSocket::Impl::handleBeaconResponse()
342{
343 if (logger_)
344 logger_->debug("Get beacon response from peer {}", deviceId);
345 beaconCounter_--;
346}
347
348bool
349MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
350{
351 std::error_code ec;
352 int wr = parent_.write(PROTOCOL_CHANNEL,
353 (const unsigned char*) buffer.data(),
354 buffer.size(),
355 ec);
356 return wr > 0;
357}
358
359void
360MultiplexedSocket::Impl::sendVersion()
361{
362 dht::ThreadPool::io().run([w = parent_.weak()]() {
363 if (auto shared = w.lock()) {
364 auto version = shared->pimpl_->version_;
365 msgpack::sbuffer buffer(8);
366 msgpack::packer<msgpack::sbuffer> pk(&buffer);
367 pk.pack(VersionMsg {version});
368 shared->pimpl_->writeProtocolMessage(buffer);
369 }
370 });
371}
372
373void
374MultiplexedSocket::Impl::onVersion(int version)
375{
376 // Check if version > 1
377 if (version >= 1) {
378 if (logger_)
379 logger_->debug("Peer {} supports beacon", deviceId);
380 canSendBeacon_ = true;
381 } else {
382 if (logger_)
383 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
384 deviceId,
385 version);
386 canSendBeacon_ = false;
387 }
388}
389
390void
391MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
392{
393 auto accept = onRequest_(endpoint->peerCertificate(), channel, name);
394 std::shared_ptr<ChannelSocket> channelSocket;
395 if (accept) {
396 std::lock_guard<std::mutex> lkSockets(socketsMutex);
397 channelSocket = makeSocket(name, channel);
398 }
399
400 // Answer to ChannelRequest if accepted
401 ChannelRequest val;
402 val.channel = channel;
403 val.name = name;
404 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
405 msgpack::sbuffer buffer(512);
406 msgpack::pack(buffer, val);
407 std::error_code ec;
408 int wr = parent_.write(CONTROL_CHANNEL,
409 reinterpret_cast<const uint8_t*>(buffer.data()),
410 buffer.size(),
411 ec);
412 if (wr < 0) {
413 if (ec && logger_)
414 logger_->error("The write operation failed with error: {:s}", ec.message());
415 stop.store(true);
416 return;
417 }
418
419 if (accept) {
420 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400421 channelSocket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400422 }
423}
424
425void
426MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
427{
428 // Run this on dedicated thread because some callbacks can take time
429 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
430 auto shared = w.lock();
431 if (!shared)
432 return;
433 auto& pimpl = *shared->pimpl_;
434 try {
435 size_t off = 0;
436 while (off != pkt.size()) {
437 msgpack::unpacked result;
438 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
439 auto object = result.get();
440 if (pimpl.handleProtocolMsg(object))
441 continue;
442 auto req = object.as<ChannelRequest>();
443 if (req.state == ChannelRequestState::ACCEPT) {
444 pimpl.onAccept(req.name, req.channel);
445 } else if (req.state == ChannelRequestState::DECLINE) {
446 std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
447 auto channel = pimpl.sockets.find(req.channel);
448 if (channel != pimpl.sockets.end()) {
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400449 channel->second->ready(false);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400450 channel->second->stop();
451 pimpl.sockets.erase(channel);
452 }
453 } else if (pimpl.onRequest_) {
454 pimpl.onRequest(req.name, req.channel);
455 }
456 }
457 } catch (const std::exception& e) {
458 if (pimpl.logger_)
459 pimpl.logger_->error("Error on the control channel: {}", e.what());
460 }
461 });
462}
463
464void
465MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
466{
467 std::lock_guard<std::mutex> lkSockets(socketsMutex);
468 auto sockIt = sockets.find(channel);
469 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
470 if (pkt.size() == 0) {
471 sockIt->second->stop();
472 if (sockIt->second->isAnswered())
473 sockets.erase(sockIt);
474 else
475 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
476 // removed later.
477 } else {
478 sockIt->second->onRecv(std::move(pkt));
479 }
480 } else if (pkt.size() != 0) {
481 if (logger_)
482 logger_->warn("Non existing channel: {}", channel);
483 }
484}
485
486bool
487MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
488{
489 try {
490 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
491 auto key = o.via.map.ptr[0].key.as<std::string_view>();
492 if (key == "p") {
493 auto msg = o.as<BeaconMsg>();
494 if (msg.p)
495 handleBeaconRequest();
496 else
497 handleBeaconResponse();
498 if (onBeaconCb_)
499 onBeaconCb_(msg.p);
500 return true;
501 } else if (key == "v") {
502 auto msg = o.as<VersionMsg>();
503 onVersion(msg.v);
504 if (onVersionCb_)
505 onVersionCb_(msg.v);
506 return true;
507 } else {
508 if (logger_)
509 logger_->warn("Unknown message type");
510 }
511 }
512 } catch (const std::exception& e) {
513 if (logger_)
514 logger_->error("Error on the protocol channel: {}", e.what());
515 }
516 return false;
517}
518
519void
520MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
521{
522 // Run this on dedicated thread because some callbacks can take time
523 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
524 auto shared = w.lock();
525 if (!shared)
526 return;
527 try {
528 size_t off = 0;
529 while (off != pkt.size()) {
530 msgpack::unpacked result;
531 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
532 auto object = result.get();
533 if (shared->pimpl_->handleProtocolMsg(object))
534 return;
535 }
536 } catch (const std::exception& e) {
537 if (shared->pimpl_->logger_)
538 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
539 }
540 });
541}
542
543MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
544 std::unique_ptr<TlsSocketEndpoint> endpoint)
545 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint)))
546{}
547
548MultiplexedSocket::~MultiplexedSocket() {}
549
550std::shared_ptr<ChannelSocket>
551MultiplexedSocket::addChannel(const std::string& name)
552{
553 // Note: because both sides can request the same channel number at the same time
554 // it's better to use a random channel number instead of just incrementing the request.
555 thread_local dht::crypto::random_device rd;
556 std::uniform_int_distribution<uint16_t> dist;
557 auto offset = dist(rd);
558 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
559 for (int i = 1; i < UINT16_MAX; ++i) {
560 auto c = (offset + i) % UINT16_MAX;
561 if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL
562 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
563 continue;
564 auto channel = pimpl_->makeSocket(name, c, true);
565 return channel;
566 }
567 return {};
568}
569
570DeviceId
571MultiplexedSocket::deviceId() const
572{
573 return pimpl_->deviceId;
574}
575
576void
577MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
578{
579 pimpl_->onChannelReady_ = std::move(cb);
580}
581
582void
583MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
584{
585 pimpl_->onRequest_ = std::move(cb);
586}
587
588bool
589MultiplexedSocket::isReliable() const
590{
591 return true;
592}
593
594bool
595MultiplexedSocket::isInitiator() const
596{
597 if (!pimpl_->endpoint) {
598 if (pimpl_->logger_)
599 pimpl_->logger_->warn("No endpoint found for socket");
600 return false;
601 }
602 return pimpl_->endpoint->isInitiator();
603}
604
605int
606MultiplexedSocket::maxPayload() const
607{
608 if (!pimpl_->endpoint) {
609 if (pimpl_->logger_)
610 pimpl_->logger_->warn("No endpoint found for socket");
611 return 0;
612 }
613 return pimpl_->endpoint->maxPayload();
614}
615
616std::size_t
617MultiplexedSocket::write(const uint16_t& channel,
618 const uint8_t* buf,
619 std::size_t len,
620 std::error_code& ec)
621{
622 assert(nullptr != buf);
623
624 if (pimpl_->isShutdown_) {
625 ec = std::make_error_code(std::errc::broken_pipe);
626 return -1;
627 }
628 if (len > UINT16_MAX) {
629 ec = std::make_error_code(std::errc::message_size);
630 return -1;
631 }
632 bool oneShot = len < 8192;
633 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
634 msgpack::packer<msgpack::sbuffer> pk(&buffer);
635 pk.pack_array(2);
636 pk.pack(channel);
637 pk.pack_bin(len);
638 if (oneShot)
639 pk.pack_bin_body((const char*) buf, len);
640
641 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
642 if (!pimpl_->endpoint) {
643 if (pimpl_->logger_)
644 pimpl_->logger_->warn("No endpoint found for socket");
645 ec = std::make_error_code(std::errc::broken_pipe);
646 return -1;
647 }
648 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
649 if (not oneShot and res >= 0)
650 res = pimpl_->endpoint->write(buf, len, ec);
651 lk.unlock();
652 if (res < 0) {
653 if (ec && pimpl_->logger_)
654 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
655 shutdown();
656 }
657 return res;
658}
659
660void
661MultiplexedSocket::shutdown()
662{
663 pimpl_->shutdown();
664}
665
666void
667MultiplexedSocket::join()
668{
669 pimpl_->join();
670}
671
672void
673MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
674{
675 pimpl_->onShutdown_ = std::move(cb);
676 if (pimpl_->isShutdown_)
677 pimpl_->onShutdown_();
678}
679
680const std::shared_ptr<Logger>&
681MultiplexedSocket::logger()
682{
683 return pimpl_->logger_;
684}
685
686void
687MultiplexedSocket::monitor() const
688{
689 auto cert = peerCertificate();
690 if (!cert || !cert->issuer)
691 return;
692 auto now = clock::now();
693 if (!pimpl_->logger_)
694 return;
695 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
696 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
697 pimpl_->endpoint->monitor();
698 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
699 for (const auto& [_, channel] : pimpl_->sockets) {
700 if (channel)
701 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
702 fmt::ptr(channel.get()),
703 channel.use_count(),
704 channel->name(),
705 channel->isInitiator());
706 }
707}
708
709void
710MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
711{
712 pimpl_->sendBeacon(timeout);
713}
714
715std::shared_ptr<dht::crypto::Certificate>
716MultiplexedSocket::peerCertificate() const
717{
718 return pimpl_->endpoint->peerCertificate();
719}
720
Adrien Béraud6b6a5d32023-08-15 15:53:33 -0400721#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -0400722bool
723MultiplexedSocket::canSendBeacon() const
724{
725 return pimpl_->canSendBeacon_;
726}
727
728void
729MultiplexedSocket::answerToBeacon(bool value)
730{
731 pimpl_->answerBeacon_ = value;
732}
733
734void
735MultiplexedSocket::setVersion(int version)
736{
737 pimpl_->version_ = version;
738}
739
740void
741MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
742{
743 pimpl_->onBeaconCb_ = cb;
744}
745
746void
747MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
748{
749 pimpl_->onVersionCb_ = cb;
750}
751
752void
753MultiplexedSocket::sendVersion()
754{
755 pimpl_->sendVersion();
756}
757
Adrien Béraudac35e662023-07-19 09:37:29 -0400758#endif
759
Adrien Béraud612b55b2023-05-29 10:42:04 -0400760IpAddr
761MultiplexedSocket::getLocalAddress() const
762{
763 return pimpl_->endpoint->getLocalAddress();
764}
765
766IpAddr
767MultiplexedSocket::getRemoteAddress() const
768{
769 return pimpl_->endpoint->getRemoteAddress();
770}
771
Adrien Béraud612b55b2023-05-29 10:42:04 -0400772void
773MultiplexedSocket::eraseChannel(uint16_t channel)
774{
775 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
776 auto itSocket = pimpl_->sockets.find(channel);
777 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
778 pimpl_->sockets.erase(itSocket);
779}
780
781////////////////////////////////////////////////////////////////
782
783class ChannelSocket::Impl
784{
785public:
786 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
787 const std::string& name,
788 const uint16_t& channel,
789 bool isInitiator,
790 std::function<void()> rmFromMxSockCb)
791 : name(name)
792 , channel(channel)
793 , endpoint(std::move(endpoint))
794 , isInitiator_(isInitiator)
795 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
796 {}
797
798 ~Impl() {}
799
800 ChannelReadyCb readyCb_ {};
801 OnShutdownCb shutdownCb_ {};
802 std::atomic_bool isShutdown_ {false};
803 std::string name {};
804 uint16_t channel {};
805 std::weak_ptr<MultiplexedSocket> endpoint {};
806 bool isInitiator_ {false};
807 std::function<void()> rmFromMxSockCb_;
808
809 bool isAnswered_ {false};
810 bool isRemovable_ {false};
811
812 std::vector<uint8_t> buf {};
813 std::mutex mutex {};
814 std::condition_variable cv {};
815 GenericSocket<uint8_t>::RecvCb cb {};
816};
817
818ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
819 const DeviceId& deviceId,
820 const std::string& name,
821 const uint16_t& channel)
822 : pimpl_deviceId(deviceId)
823 , pimpl_name(name)
824 , pimpl_channel(channel)
825 , ioCtx_(*ctx)
826{}
827
828ChannelSocketTest::~ChannelSocketTest() {}
829
830void
831ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
832 const std::shared_ptr<ChannelSocketTest>& socket2)
833{
834 socket1->remote = socket2;
835 socket2->remote = socket1;
836}
837
838DeviceId
839ChannelSocketTest::deviceId() const
840{
841 return pimpl_deviceId;
842}
843
844std::string
845ChannelSocketTest::name() const
846{
847 return pimpl_name;
848}
849
850uint16_t
851ChannelSocketTest::channel() const
852{
853 return pimpl_channel;
854}
855
856void
857ChannelSocketTest::shutdown()
858{
859 {
860 std::unique_lock<std::mutex> lk {mutex};
861 if (!isShutdown_.exchange(true)) {
862 lk.unlock();
863 shutdownCb_();
864 }
865 cv.notify_all();
866 }
867
868 if (auto peer = remote.lock()) {
869 if (!peer->isShutdown_.exchange(true)) {
870 peer->shutdownCb_();
871 }
872 peer->cv.notify_all();
873 }
874}
875
876std::size_t
877ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
878{
879 std::size_t size = std::min(len, this->rx_buf.size());
880
881 for (std::size_t i = 0; i < size; ++i)
882 buf[i] = this->rx_buf[i];
883
884 if (size == this->rx_buf.size()) {
885 this->rx_buf.clear();
886 } else
887 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
888 return size;
889}
890
891std::size_t
892ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
893{
894 if (isShutdown_) {
895 ec = std::make_error_code(std::errc::broken_pipe);
896 return -1;
897 }
898 ec = {};
899 dht::ThreadPool::computation().run(
900 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
901 if (auto peer = r.lock())
902 peer->onRecv(std::move(data));
903 });
904 return len;
905}
906
907int
908ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
909{
910 std::unique_lock<std::mutex> lk {mutex};
911 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
912 return rx_buf.size();
913}
914
915void
916ChannelSocketTest::setOnRecv(RecvCb&& cb)
917{
918 std::lock_guard<std::mutex> lkSockets(mutex);
919 this->cb = std::move(cb);
920 if (!rx_buf.empty() && this->cb) {
921 this->cb(rx_buf.data(), rx_buf.size());
922 rx_buf.clear();
923 }
924}
925
926void
927ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
928{
929 std::lock_guard<std::mutex> lkSockets(mutex);
930 if (cb) {
931 cb(pkt.data(), pkt.size());
932 return;
933 }
934 rx_buf.insert(rx_buf.end(),
935 std::make_move_iterator(pkt.begin()),
936 std::make_move_iterator(pkt.end()));
937 cv.notify_all();
938}
939
940void
941ChannelSocketTest::onReady(ChannelReadyCb&& cb)
942{}
943
944void
945ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
946{
947 std::unique_lock<std::mutex> lk {mutex};
948 shutdownCb_ = std::move(cb);
949
950 if (isShutdown_) {
951 lk.unlock();
952 shutdownCb_();
953 }
954}
955
956ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
957 const std::string& name,
958 const uint16_t& channel,
959 bool isInitiator,
960 std::function<void()> rmFromMxSockCb)
961 : pimpl_ {
962 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
963{}
964
965ChannelSocket::~ChannelSocket() {}
966
967DeviceId
968ChannelSocket::deviceId() const
969{
970 if (auto ep = pimpl_->endpoint.lock()) {
971 return ep->deviceId();
972 }
973 return {};
974}
975
976std::string
977ChannelSocket::name() const
978{
979 return pimpl_->name;
980}
981
982uint16_t
983ChannelSocket::channel() const
984{
985 return pimpl_->channel;
986}
987
988bool
989ChannelSocket::isReliable() const
990{
991 if (auto ep = pimpl_->endpoint.lock()) {
992 return ep->isReliable();
993 }
994 return false;
995}
996
997bool
998ChannelSocket::isInitiator() const
999{
1000 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1001 // because a multiplexed socket can have sockets from accepted requests
1002 // or made via connectDevice(). Here, isInitiator_ return if the socket
1003 // is from connectDevice.
1004 return pimpl_->isInitiator_;
1005}
1006
1007int
1008ChannelSocket::maxPayload() const
1009{
1010 if (auto ep = pimpl_->endpoint.lock()) {
1011 return ep->maxPayload();
1012 }
1013 return -1;
1014}
1015
1016void
1017ChannelSocket::setOnRecv(RecvCb&& cb)
1018{
1019 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1020 pimpl_->cb = std::move(cb);
1021 if (!pimpl_->buf.empty() && pimpl_->cb) {
1022 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1023 pimpl_->buf.clear();
1024 }
1025}
1026
1027void
1028ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1029{
1030 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1031 if (pimpl_->cb) {
1032 pimpl_->cb(&pkt[0], pkt.size());
1033 return;
1034 }
1035 pimpl_->buf.insert(pimpl_->buf.end(),
1036 std::make_move_iterator(pkt.begin()),
1037 std::make_move_iterator(pkt.end()));
1038 pimpl_->cv.notify_all();
1039}
1040
Adrien Béraud6b6a5d32023-08-15 15:53:33 -04001041#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -04001042std::shared_ptr<MultiplexedSocket>
1043ChannelSocket::underlyingSocket() const
1044{
1045 if (auto mtx = pimpl_->endpoint.lock())
1046 return mtx;
1047 return {};
1048}
1049#endif
1050
1051void
1052ChannelSocket::answered()
1053{
1054 pimpl_->isAnswered_ = true;
1055}
1056
1057void
1058ChannelSocket::removable()
1059{
1060 pimpl_->isRemovable_ = true;
1061}
1062
1063bool
1064ChannelSocket::isRemovable() const
1065{
1066 return pimpl_->isRemovable_;
1067}
1068
1069bool
1070ChannelSocket::isAnswered() const
1071{
1072 return pimpl_->isAnswered_;
1073}
1074
1075void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001076ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001077{
1078 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001079 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001080}
1081
1082void
1083ChannelSocket::stop()
1084{
1085 if (pimpl_->isShutdown_)
1086 return;
1087 pimpl_->isShutdown_ = true;
1088 if (pimpl_->shutdownCb_)
1089 pimpl_->shutdownCb_();
1090 pimpl_->cv.notify_all();
1091 // stop() can be called by ChannelSocket::shutdown()
1092 // In this case, the eventLoop is not used, but MxSock
1093 // must remove the channel from its list (so that the
1094 // channel can be destroyed and its shared_ptr invalidated).
1095 if (pimpl_->rmFromMxSockCb_)
1096 pimpl_->rmFromMxSockCb_();
1097}
1098
1099void
1100ChannelSocket::shutdown()
1101{
1102 if (pimpl_->isShutdown_)
1103 return;
1104 stop();
1105 if (auto ep = pimpl_->endpoint.lock()) {
1106 std::error_code ec;
1107 const uint8_t dummy = '\0';
1108 ep->write(pimpl_->channel, &dummy, 0, ec);
1109 }
1110}
1111
1112std::size_t
1113ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1114{
1115 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1116 std::size_t size = std::min(len, pimpl_->buf.size());
1117
1118 for (std::size_t i = 0; i < size; ++i)
1119 outBuf[i] = pimpl_->buf[i];
1120
1121 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1122 return size;
1123}
1124
1125std::size_t
1126ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1127{
1128 if (pimpl_->isShutdown_) {
1129 ec = std::make_error_code(std::errc::broken_pipe);
1130 return -1;
1131 }
1132 if (auto ep = pimpl_->endpoint.lock()) {
1133 std::size_t sent = 0;
1134 do {
1135 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1136 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1137 if (ec) {
1138 if (ep->logger())
1139 ep->logger()->error("Error when writing on channel: {}", ec.message());
1140 return res;
1141 }
1142 sent += toSend;
1143 } while (sent < len);
1144 return sent;
1145 }
1146 ec = std::make_error_code(std::errc::broken_pipe);
1147 return -1;
1148}
1149
1150int
1151ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1152{
1153 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1154 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1155 return pimpl_->buf.size();
1156}
1157
1158void
1159ChannelSocket::onShutdown(OnShutdownCb&& cb)
1160{
1161 pimpl_->shutdownCb_ = std::move(cb);
1162 if (pimpl_->isShutdown_) {
1163 pimpl_->shutdownCb_();
1164 }
1165}
1166
1167void
1168ChannelSocket::onReady(ChannelReadyCb&& cb)
1169{
1170 pimpl_->readyCb_ = std::move(cb);
1171}
1172
1173void
1174ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1175{
1176 if (auto ep = pimpl_->endpoint.lock()) {
1177 ep->sendBeacon(timeout);
1178 } else {
1179 shutdown();
1180 }
1181}
1182
1183std::shared_ptr<dht::crypto::Certificate>
1184ChannelSocket::peerCertificate() const
1185{
1186 if (auto ep = pimpl_->endpoint.lock())
1187 return ep->peerCertificate();
1188 return {};
1189}
1190
1191IpAddr
1192ChannelSocket::getLocalAddress() const
1193{
1194 if (auto ep = pimpl_->endpoint.lock())
1195 return ep->getLocalAddress();
1196 return {};
1197}
1198
1199IpAddr
1200ChannelSocket::getRemoteAddress() const
1201{
1202 if (auto ep = pimpl_->endpoint.lock())
1203 return ep->getRemoteAddress();
1204 return {};
1205}
1206
Amna31791e52023-08-03 12:40:57 -04001207std::vector<std::map<std::string, std::string>>
1208MultiplexedSocket::getChannelList() const
1209{
1210 std::vector<std::map<std::string, std::string>> channelsList;
1211
1212 for (const auto& [_, channel] : pimpl_->sockets) {
1213 if (channel) {
1214 std::map<std::string, std::string> channelMap;
1215 channelMap["channel"] = std::to_string(channel->channel());
1216 channelMap["channelName"]= channel->name();
1217 channelsList.emplace_back(std::move(channelMap));
1218 }
1219 }
1220
1221 return channelsList;
1222}
1223
Sébastien Blin464bdff2023-07-19 08:02:53 -04001224} // namespace dhtnet