blob: db4b0cba4d369ed2ed9c501bb273dc09010a98e7 [file] [log] [blame]
/*
 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "multiplexed_socket.h"
18#include "peer_connection.h"
19#include "ice_transport.h"
20#include "certstore.h"
21
22#include <opendht/logger.h>
23#include <opendht/thread_pool.h>
24
25#include <asio/io_context.hpp>
26#include <asio/steady_timer.hpp>
27
28#include <deque>
29
30static constexpr std::size_t IO_BUFFER_SIZE {8192}; ///< Size of char buffer used by IO operations
31static constexpr int MULTIPLEXED_SOCKET_VERSION {1};
32
33struct ChanneledMessage
34{
35 uint16_t channel;
36 std::vector<uint8_t> data;
37 MSGPACK_DEFINE(channel, data)
38};
39
40struct BeaconMsg
41{
42 bool p;
43 MSGPACK_DEFINE_MAP(p)
44};
45
46struct VersionMsg
47{
48 int v;
49 MSGPACK_DEFINE_MAP(v)
50};
51
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040052namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040053
54using clock = std::chrono::steady_clock;
55using time_point = clock::time_point;
56
57class MultiplexedSocket::Impl
58{
59public:
60 Impl(MultiplexedSocket& parent,
61 std::shared_ptr<asio::io_context> ctx,
62 const DeviceId& deviceId,
Adrien Béraud55133cc2023-10-15 11:55:20 -040063 std::unique_ptr<TlsSocketEndpoint> ep,
Adrien Béraud5636f7c2023-09-14 14:34:57 -040064 std::shared_ptr<dht::log::Logger> logger)
Adrien Béraud612b55b2023-05-29 10:42:04 -040065 : parent_(parent)
Adrien Béraud612b55b2023-05-29 10:42:04 -040066 , ctx_(std::move(ctx))
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040067 , deviceId(deviceId)
Adrien Béraud55133cc2023-10-15 11:55:20 -040068 , endpoint(std::move(ep))
69 , nextChannel_(endpoint->isInitiator() ? 0x0001u : 0x8000u)
Adrien Béraud612b55b2023-05-29 10:42:04 -040070 , eventLoopThread_ {[this] {
71 try {
72 eventLoop();
73 } catch (const std::exception& e) {
74 if (logger_)
75 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
76 shutdown();
77 }
78 }}
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040079 , beaconTimer_(*ctx_)
Adrien Béraud612b55b2023-05-29 10:42:04 -040080 {}
81
82 ~Impl() {}
83
84 void join()
85 {
86 if (!isShutdown_) {
87 if (endpoint)
88 endpoint->setOnStateChange({});
89 shutdown();
90 } else {
91 clearSockets();
92 }
93 if (eventLoopThread_.joinable())
94 eventLoopThread_.join();
95 }
96
97 void clearSockets()
98 {
99 decltype(sockets) socks;
100 {
101 std::lock_guard<std::mutex> lkSockets(socketsMutex);
102 socks = std::move(sockets);
103 }
104 for (auto& socket : socks) {
105 // Just trigger onShutdown() to make client know
106 // No need to write the EOF for the channel, the write will fail because endpoint is
107 // already shutdown
108 if (socket.second)
109 socket.second->stop();
110 }
111 }
112
113 void shutdown()
114 {
115 if (isShutdown_)
116 return;
117 stop.store(true);
118 isShutdown_ = true;
119 beaconTimer_.cancel();
120 if (onShutdown_)
121 onShutdown_();
122 if (endpoint) {
123 std::unique_lock<std::mutex> lk(writeMtx);
124 endpoint->shutdown();
125 }
126 clearSockets();
127 }
128
129 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
130 uint16_t channel,
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400131 bool isInitiator)
Adrien Béraud612b55b2023-05-29 10:42:04 -0400132 {
133 auto& channelSocket = sockets[channel];
134 if (not channelSocket)
135 channelSocket = std::make_shared<ChannelSocket>(
136 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
137 // Remove socket in another thread to avoid any lock
138 dht::ThreadPool::io().run([w, channel]() {
139 if (auto shared = w.lock()) {
140 shared->eraseChannel(channel);
141 }
142 });
143 });
144 else {
145 if (logger_)
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400146 logger_->warn("Received request for existing channel {}", channel);
147 return {};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400148 }
149 return channelSocket;
150 }
151
152 /**
153 * Handle packets on the TLS endpoint and parse RTP
154 */
155 void eventLoop();
156 /**
157 * Triggered when a new control packet is received
158 */
159 void handleControlPacket(std::vector<uint8_t>&& pkt);
160 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
161 bool handleProtocolMsg(const msgpack::object& o);
162 /**
163 * Triggered when a new packet on a channel is received
164 */
165 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
166 void onRequest(const std::string& name, uint16_t channel);
167 void onAccept(const std::string& name, uint16_t channel);
168
169 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
170 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
171
172 // Beacon
173 void sendBeacon(const std::chrono::milliseconds& timeout);
174 void handleBeaconRequest();
175 void handleBeaconResponse();
176 std::atomic_int beaconCounter_ {0};
177
178 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
179
180 msgpack::unpacker pac_ {};
181
182 MultiplexedSocket& parent_;
183
184 std::shared_ptr<Logger> logger_;
185 std::shared_ptr<asio::io_context> ctx_;
186
187 OnConnectionReadyCb onChannelReady_ {};
188 OnConnectionRequestCb onRequest_ {};
189 OnShutdownCb onShutdown_ {};
190
191 DeviceId deviceId {};
192 // Main socket
193 std::unique_ptr<TlsSocketEndpoint> endpoint {};
194
195 std::mutex socketsMutex {};
196 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
Adrien Béraud55133cc2023-10-15 11:55:20 -0400197 uint16_t nextChannel_;
Adrien Béraud612b55b2023-05-29 10:42:04 -0400198
199 // Main loop to parse incoming packets
200 std::atomic_bool stop {false};
201 std::thread eventLoopThread_ {};
202
203 std::atomic_bool isShutdown_ {false};
204
205 std::mutex writeMtx {};
206
207 time_point start_ {clock::now()};
208 //std::shared_ptr<Task> beaconTask_ {};
209 asio::steady_timer beaconTimer_;
210
211 // version related stuff
212 void sendVersion();
213 void onVersion(int version);
214 std::atomic_bool canSendBeacon_ {false};
215 std::atomic_bool answerBeacon_ {true};
216 int version_ {MULTIPLEXED_SOCKET_VERSION};
217 std::function<void(bool)> onBeaconCb_ {};
218 std::function<void(int)> onVersionCb_ {};
219};
220
221void
222MultiplexedSocket::Impl::eventLoop()
223{
224 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
225 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
226 if (logger_)
227 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
228 shutdown();
229 return false;
230 }
231 return true;
232 });
233 sendVersion();
234 std::error_code ec;
235 while (!stop) {
236 if (!endpoint) {
237 shutdown();
238 return;
239 }
240 pac_.reserve_buffer(IO_BUFFER_SIZE);
241 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
242 if (size < 0) {
243 if (ec && logger_)
244 logger_->error("Read error detected: {}", ec.message());
245 break;
246 }
247 if (size == 0) {
248 // We can close the socket
249 shutdown();
250 break;
251 }
252
253 pac_.buffer_consumed(size);
254 msgpack::object_handle oh;
255 while (pac_.next(oh) && !stop) {
256 try {
257 auto msg = oh.get().as<ChanneledMessage>();
258 if (msg.channel == CONTROL_CHANNEL)
259 handleControlPacket(std::move(msg.data));
260 else if (msg.channel == PROTOCOL_CHANNEL)
261 handleProtocolPacket(std::move(msg.data));
262 else
263 handleChannelPacket(msg.channel, std::move(msg.data));
264 } catch (const std::exception& e) {
265 if (logger_)
266 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
267 } catch (...) {
268 if (logger_)
269 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
270 }
271 }
272 }
273}
274
275void
276MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
277{
278 std::lock_guard<std::mutex> lkSockets(socketsMutex);
279 auto& socket = sockets[channel];
280 if (!socket) {
281 if (logger_)
282 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
283 return;
284 }
285
286 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400287 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400288 // Due to the callbacks that can take some time, onAccept can arrive after
289 // receiving all the data. In this case, the socket should be removed here
290 // as handle by onChannelReady_
291 if (socket->isRemovable())
292 sockets.erase(channel);
293 else
294 socket->answered();
295}
296
297void
298MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
299{
300 if (!canSendBeacon_)
301 return;
302 beaconCounter_++;
303 if (logger_)
304 logger_->debug("Send beacon to peer {}", deviceId);
305
306 msgpack::sbuffer buffer(8);
307 msgpack::packer<msgpack::sbuffer> pk(&buffer);
308 pk.pack(BeaconMsg {true});
309 if (!writeProtocolMessage(buffer))
310 return;
311 beaconTimer_.expires_after(timeout);
312 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
313 if (ec == asio::error::operation_aborted)
314 return;
315 if (auto shared = w.lock()) {
316 if (shared->pimpl_->beaconCounter_ != 0) {
317 if (shared->pimpl_->logger_)
318 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
319 shared->shutdown();
320 }
321 }
322 });
323}
324
325void
326MultiplexedSocket::Impl::handleBeaconRequest()
327{
328 if (!answerBeacon_)
329 return;
330 // Run this on dedicated thread because some callbacks can take time
331 dht::ThreadPool::io().run([w = parent_.weak()]() {
332 if (auto shared = w.lock()) {
333 msgpack::sbuffer buffer(8);
334 msgpack::packer<msgpack::sbuffer> pk(&buffer);
335 pk.pack(BeaconMsg {false});
336 if (shared->pimpl_->logger_)
337 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
338 shared->pimpl_->writeProtocolMessage(buffer);
339 }
340 });
341}
342
343void
344MultiplexedSocket::Impl::handleBeaconResponse()
345{
346 if (logger_)
347 logger_->debug("Get beacon response from peer {}", deviceId);
348 beaconCounter_--;
349}
350
351bool
352MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
353{
354 std::error_code ec;
355 int wr = parent_.write(PROTOCOL_CHANNEL,
356 (const unsigned char*) buffer.data(),
357 buffer.size(),
358 ec);
359 return wr > 0;
360}
361
362void
363MultiplexedSocket::Impl::sendVersion()
364{
365 dht::ThreadPool::io().run([w = parent_.weak()]() {
366 if (auto shared = w.lock()) {
367 auto version = shared->pimpl_->version_;
368 msgpack::sbuffer buffer(8);
369 msgpack::packer<msgpack::sbuffer> pk(&buffer);
370 pk.pack(VersionMsg {version});
371 shared->pimpl_->writeProtocolMessage(buffer);
372 }
373 });
374}
375
376void
377MultiplexedSocket::Impl::onVersion(int version)
378{
379 // Check if version > 1
380 if (version >= 1) {
381 if (logger_)
382 logger_->debug("Peer {} supports beacon", deviceId);
383 canSendBeacon_ = true;
384 } else {
385 if (logger_)
386 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
387 deviceId,
388 version);
389 canSendBeacon_ = false;
390 }
391}
392
393void
394MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
395{
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400396 bool accept;
397 if (channel == CONTROL_CHANNEL || channel == PROTOCOL_CHANNEL) {
398 if (logger_)
399 logger_->warn("Channel {:d} is reserved, refusing request", channel);
400 accept = false;
401 } else
402 accept = onRequest_(endpoint->peerCertificate(), channel, name);
403
Adrien Béraud612b55b2023-05-29 10:42:04 -0400404 std::shared_ptr<ChannelSocket> channelSocket;
405 if (accept) {
406 std::lock_guard<std::mutex> lkSockets(socketsMutex);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400407 channelSocket = makeSocket(name, channel, false);
408 if (not channelSocket) {
409 if (logger_)
410 logger_->error("Channel {:d} already exists, refusing request", channel);
411 accept = false;
412 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400413 }
414
415 // Answer to ChannelRequest if accepted
416 ChannelRequest val;
417 val.channel = channel;
418 val.name = name;
419 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
420 msgpack::sbuffer buffer(512);
421 msgpack::pack(buffer, val);
422 std::error_code ec;
423 int wr = parent_.write(CONTROL_CHANNEL,
424 reinterpret_cast<const uint8_t*>(buffer.data()),
425 buffer.size(),
426 ec);
427 if (wr < 0) {
428 if (ec && logger_)
429 logger_->error("The write operation failed with error: {:s}", ec.message());
430 stop.store(true);
431 return;
432 }
433
434 if (accept) {
435 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400436 channelSocket->ready(true);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400437 if (channelSocket->isRemovable()) {
438 std::lock_guard<std::mutex> lkSockets(socketsMutex);
439 sockets.erase(channel);
440 } else
441 channelSocket->answered();
Adrien Béraud612b55b2023-05-29 10:42:04 -0400442 }
443}
444
445void
446MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
447{
448 // Run this on dedicated thread because some callbacks can take time
449 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
450 auto shared = w.lock();
451 if (!shared)
452 return;
453 auto& pimpl = *shared->pimpl_;
454 try {
455 size_t off = 0;
456 while (off != pkt.size()) {
457 msgpack::unpacked result;
458 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
459 auto object = result.get();
460 if (pimpl.handleProtocolMsg(object))
461 continue;
462 auto req = object.as<ChannelRequest>();
Adrien Béraud40371792023-10-17 14:39:16 -0400463 if (req.state == ChannelRequestState::REQUEST) {
464 pimpl.onRequest(req.name, req.channel);
465 }
466 else if (req.state == ChannelRequestState::ACCEPT) {
Adrien Béraud612b55b2023-05-29 10:42:04 -0400467 pimpl.onAccept(req.name, req.channel);
Adrien Béraud40371792023-10-17 14:39:16 -0400468 } else {
469 // DECLINE or unknown
Adrien Béraud612b55b2023-05-29 10:42:04 -0400470 std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
471 auto channel = pimpl.sockets.find(req.channel);
472 if (channel != pimpl.sockets.end()) {
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400473 channel->second->ready(false);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400474 channel->second->stop();
475 pimpl.sockets.erase(channel);
476 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400477 }
478 }
479 } catch (const std::exception& e) {
480 if (pimpl.logger_)
481 pimpl.logger_->error("Error on the control channel: {}", e.what());
482 }
483 });
484}
485
486void
487MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
488{
489 std::lock_guard<std::mutex> lkSockets(socketsMutex);
490 auto sockIt = sockets.find(channel);
491 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
492 if (pkt.size() == 0) {
493 sockIt->second->stop();
494 if (sockIt->second->isAnswered())
495 sockets.erase(sockIt);
496 else
497 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
498 // removed later.
499 } else {
500 sockIt->second->onRecv(std::move(pkt));
501 }
502 } else if (pkt.size() != 0) {
503 if (logger_)
504 logger_->warn("Non existing channel: {}", channel);
505 }
506}
507
508bool
509MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
510{
511 try {
512 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
513 auto key = o.via.map.ptr[0].key.as<std::string_view>();
514 if (key == "p") {
515 auto msg = o.as<BeaconMsg>();
516 if (msg.p)
517 handleBeaconRequest();
518 else
519 handleBeaconResponse();
520 if (onBeaconCb_)
521 onBeaconCb_(msg.p);
522 return true;
523 } else if (key == "v") {
524 auto msg = o.as<VersionMsg>();
525 onVersion(msg.v);
526 if (onVersionCb_)
527 onVersionCb_(msg.v);
528 return true;
529 } else {
530 if (logger_)
531 logger_->warn("Unknown message type");
532 }
533 }
534 } catch (const std::exception& e) {
535 if (logger_)
536 logger_->error("Error on the protocol channel: {}", e.what());
537 }
538 return false;
539}
540
541void
542MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
543{
544 // Run this on dedicated thread because some callbacks can take time
545 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
546 auto shared = w.lock();
547 if (!shared)
548 return;
549 try {
550 size_t off = 0;
551 while (off != pkt.size()) {
552 msgpack::unpacked result;
553 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
554 auto object = result.get();
555 if (shared->pimpl_->handleProtocolMsg(object))
556 return;
557 }
558 } catch (const std::exception& e) {
559 if (shared->pimpl_->logger_)
560 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
561 }
562 });
563}
564
565MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
Adrien Béraud5636f7c2023-09-14 14:34:57 -0400566 std::unique_ptr<TlsSocketEndpoint> endpoint, std::shared_ptr<dht::log::Logger> logger)
567 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint), logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -0400568{}
569
570MultiplexedSocket::~MultiplexedSocket() {}
571
572std::shared_ptr<ChannelSocket>
573MultiplexedSocket::addChannel(const std::string& name)
574{
Adrien Béraud612b55b2023-05-29 10:42:04 -0400575 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
Adrien Béraud55133cc2023-10-15 11:55:20 -0400576 if (pimpl_->sockets.size() < UINT16_MAX)
577 for (unsigned i = 0; i < UINT16_MAX; ++i) {
578 auto c = pimpl_->nextChannel_++;
579 if (c == CONTROL_CHANNEL
580 || c == PROTOCOL_CHANNEL
581 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
582 continue;
583 return pimpl_->makeSocket(name, c, true);
584 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400585 return {};
586}
587
588DeviceId
589MultiplexedSocket::deviceId() const
590{
591 return pimpl_->deviceId;
592}
593
594void
595MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
596{
597 pimpl_->onChannelReady_ = std::move(cb);
598}
599
600void
601MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
602{
603 pimpl_->onRequest_ = std::move(cb);
604}
605
606bool
607MultiplexedSocket::isReliable() const
608{
609 return true;
610}
611
612bool
613MultiplexedSocket::isInitiator() const
614{
615 if (!pimpl_->endpoint) {
616 if (pimpl_->logger_)
617 pimpl_->logger_->warn("No endpoint found for socket");
618 return false;
619 }
620 return pimpl_->endpoint->isInitiator();
621}
622
623int
624MultiplexedSocket::maxPayload() const
625{
626 if (!pimpl_->endpoint) {
627 if (pimpl_->logger_)
628 pimpl_->logger_->warn("No endpoint found for socket");
629 return 0;
630 }
631 return pimpl_->endpoint->maxPayload();
632}
633
634std::size_t
635MultiplexedSocket::write(const uint16_t& channel,
636 const uint8_t* buf,
637 std::size_t len,
638 std::error_code& ec)
639{
640 assert(nullptr != buf);
641
642 if (pimpl_->isShutdown_) {
643 ec = std::make_error_code(std::errc::broken_pipe);
644 return -1;
645 }
646 if (len > UINT16_MAX) {
647 ec = std::make_error_code(std::errc::message_size);
648 return -1;
649 }
650 bool oneShot = len < 8192;
651 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
652 msgpack::packer<msgpack::sbuffer> pk(&buffer);
653 pk.pack_array(2);
654 pk.pack(channel);
655 pk.pack_bin(len);
656 if (oneShot)
657 pk.pack_bin_body((const char*) buf, len);
658
659 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
660 if (!pimpl_->endpoint) {
661 if (pimpl_->logger_)
662 pimpl_->logger_->warn("No endpoint found for socket");
663 ec = std::make_error_code(std::errc::broken_pipe);
664 return -1;
665 }
666 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
667 if (not oneShot and res >= 0)
668 res = pimpl_->endpoint->write(buf, len, ec);
669 lk.unlock();
670 if (res < 0) {
671 if (ec && pimpl_->logger_)
672 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
673 shutdown();
674 }
675 return res;
676}
677
678void
679MultiplexedSocket::shutdown()
680{
681 pimpl_->shutdown();
682}
683
684void
685MultiplexedSocket::join()
686{
687 pimpl_->join();
688}
689
690void
691MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
692{
693 pimpl_->onShutdown_ = std::move(cb);
694 if (pimpl_->isShutdown_)
695 pimpl_->onShutdown_();
696}
697
698const std::shared_ptr<Logger>&
699MultiplexedSocket::logger()
700{
701 return pimpl_->logger_;
702}
703
704void
705MultiplexedSocket::monitor() const
706{
707 auto cert = peerCertificate();
708 if (!cert || !cert->issuer)
709 return;
710 auto now = clock::now();
711 if (!pimpl_->logger_)
712 return;
713 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
714 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
715 pimpl_->endpoint->monitor();
716 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
717 for (const auto& [_, channel] : pimpl_->sockets) {
718 if (channel)
719 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
720 fmt::ptr(channel.get()),
721 channel.use_count(),
722 channel->name(),
723 channel->isInitiator());
724 }
725}
726
727void
728MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
729{
730 pimpl_->sendBeacon(timeout);
731}
732
733std::shared_ptr<dht::crypto::Certificate>
734MultiplexedSocket::peerCertificate() const
735{
736 return pimpl_->endpoint->peerCertificate();
737}
738
#ifdef DHTNET_TESTABLE
/// Whether the peer's announced version allows sending beacons.
bool
MultiplexedSocket::canSendBeacon() const
{
    const auto& impl = *pimpl_;
    return impl.canSendBeacon_;
}

/// Test hook: when false, incoming beacon requests are left unanswered.
void
MultiplexedSocket::answerToBeacon(bool value)
{
    auto& impl = *pimpl_;
    impl.answerBeacon_ = value;
}

/// Test hook: overrides the version announced by sendVersion().
void
MultiplexedSocket::setVersion(int version)
{
    auto& impl = *pimpl_;
    impl.version_ = version;
}

/// Test hook: observes each beacon message (true = request, false = response).
void
MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
{
    auto& impl = *pimpl_;
    impl.onBeaconCb_ = cb;
}

/// Test hook: observes each version announcement received from the peer.
void
MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
{
    auto& impl = *pimpl_;
    impl.onVersionCb_ = cb;
}

/// Test hook: re-announces our version to the peer.
void
MultiplexedSocket::sendVersion()
{
    auto& impl = *pimpl_;
    impl.sendVersion();
}

#endif
777
Adrien Béraud612b55b2023-05-29 10:42:04 -0400778IpAddr
779MultiplexedSocket::getLocalAddress() const
780{
781 return pimpl_->endpoint->getLocalAddress();
782}
783
784IpAddr
785MultiplexedSocket::getRemoteAddress() const
786{
787 return pimpl_->endpoint->getRemoteAddress();
788}
789
Adrien Béraudafa8e282023-09-24 12:53:20 -0400790TlsSocketEndpoint*
791MultiplexedSocket::endpoint()
792{
793 return pimpl_->endpoint.get();
794}
795
Adrien Béraud612b55b2023-05-29 10:42:04 -0400796void
797MultiplexedSocket::eraseChannel(uint16_t channel)
798{
799 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
800 auto itSocket = pimpl_->sockets.find(channel);
801 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
802 pimpl_->sockets.erase(itSocket);
803}
804
805////////////////////////////////////////////////////////////////
806
807class ChannelSocket::Impl
808{
809public:
810 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
811 const std::string& name,
812 const uint16_t& channel,
813 bool isInitiator,
814 std::function<void()> rmFromMxSockCb)
815 : name(name)
816 , channel(channel)
817 , endpoint(std::move(endpoint))
818 , isInitiator_(isInitiator)
819 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
820 {}
821
822 ~Impl() {}
823
824 ChannelReadyCb readyCb_ {};
825 OnShutdownCb shutdownCb_ {};
826 std::atomic_bool isShutdown_ {false};
827 std::string name {};
828 uint16_t channel {};
829 std::weak_ptr<MultiplexedSocket> endpoint {};
830 bool isInitiator_ {false};
831 std::function<void()> rmFromMxSockCb_;
832
833 bool isAnswered_ {false};
834 bool isRemovable_ {false};
835
836 std::vector<uint8_t> buf {};
837 std::mutex mutex {};
838 std::condition_variable cv {};
839 GenericSocket<uint8_t>::RecvCb cb {};
840};
841
842ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
843 const DeviceId& deviceId,
844 const std::string& name,
845 const uint16_t& channel)
846 : pimpl_deviceId(deviceId)
847 , pimpl_name(name)
848 , pimpl_channel(channel)
849 , ioCtx_(*ctx)
850{}
851
852ChannelSocketTest::~ChannelSocketTest() {}
853
854void
855ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
856 const std::shared_ptr<ChannelSocketTest>& socket2)
857{
858 socket1->remote = socket2;
859 socket2->remote = socket1;
860}
861
862DeviceId
863ChannelSocketTest::deviceId() const
864{
865 return pimpl_deviceId;
866}
867
868std::string
869ChannelSocketTest::name() const
870{
871 return pimpl_name;
872}
873
874uint16_t
875ChannelSocketTest::channel() const
876{
877 return pimpl_channel;
878}
879
880void
881ChannelSocketTest::shutdown()
882{
883 {
884 std::unique_lock<std::mutex> lk {mutex};
885 if (!isShutdown_.exchange(true)) {
886 lk.unlock();
887 shutdownCb_();
888 }
889 cv.notify_all();
890 }
891
892 if (auto peer = remote.lock()) {
893 if (!peer->isShutdown_.exchange(true)) {
894 peer->shutdownCb_();
895 }
896 peer->cv.notify_all();
897 }
898}
899
900std::size_t
901ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
902{
903 std::size_t size = std::min(len, this->rx_buf.size());
904
905 for (std::size_t i = 0; i < size; ++i)
906 buf[i] = this->rx_buf[i];
907
908 if (size == this->rx_buf.size()) {
909 this->rx_buf.clear();
910 } else
911 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
912 return size;
913}
914
915std::size_t
916ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
917{
918 if (isShutdown_) {
919 ec = std::make_error_code(std::errc::broken_pipe);
920 return -1;
921 }
922 ec = {};
923 dht::ThreadPool::computation().run(
924 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
925 if (auto peer = r.lock())
926 peer->onRecv(std::move(data));
927 });
928 return len;
929}
930
931int
932ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
933{
934 std::unique_lock<std::mutex> lk {mutex};
935 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
936 return rx_buf.size();
937}
938
939void
940ChannelSocketTest::setOnRecv(RecvCb&& cb)
941{
942 std::lock_guard<std::mutex> lkSockets(mutex);
943 this->cb = std::move(cb);
944 if (!rx_buf.empty() && this->cb) {
945 this->cb(rx_buf.data(), rx_buf.size());
946 rx_buf.clear();
947 }
948}
949
950void
951ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
952{
953 std::lock_guard<std::mutex> lkSockets(mutex);
954 if (cb) {
955 cb(pkt.data(), pkt.size());
956 return;
957 }
958 rx_buf.insert(rx_buf.end(),
959 std::make_move_iterator(pkt.begin()),
960 std::make_move_iterator(pkt.end()));
961 cv.notify_all();
962}
963
964void
965ChannelSocketTest::onReady(ChannelReadyCb&& cb)
966{}
967
968void
969ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
970{
971 std::unique_lock<std::mutex> lk {mutex};
972 shutdownCb_ = std::move(cb);
973
974 if (isShutdown_) {
975 lk.unlock();
976 shutdownCb_();
977 }
978}
979
980ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
981 const std::string& name,
982 const uint16_t& channel,
983 bool isInitiator,
984 std::function<void()> rmFromMxSockCb)
985 : pimpl_ {
986 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
987{}
988
989ChannelSocket::~ChannelSocket() {}
990
991DeviceId
992ChannelSocket::deviceId() const
993{
994 if (auto ep = pimpl_->endpoint.lock()) {
995 return ep->deviceId();
996 }
997 return {};
998}
999
1000std::string
1001ChannelSocket::name() const
1002{
1003 return pimpl_->name;
1004}
1005
1006uint16_t
1007ChannelSocket::channel() const
1008{
1009 return pimpl_->channel;
1010}
1011
1012bool
1013ChannelSocket::isReliable() const
1014{
1015 if (auto ep = pimpl_->endpoint.lock()) {
1016 return ep->isReliable();
1017 }
1018 return false;
1019}
1020
1021bool
1022ChannelSocket::isInitiator() const
1023{
1024 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1025 // because a multiplexed socket can have sockets from accepted requests
1026 // or made via connectDevice(). Here, isInitiator_ return if the socket
1027 // is from connectDevice.
1028 return pimpl_->isInitiator_;
1029}
1030
1031int
1032ChannelSocket::maxPayload() const
1033{
1034 if (auto ep = pimpl_->endpoint.lock()) {
1035 return ep->maxPayload();
1036 }
1037 return -1;
1038}
1039
1040void
1041ChannelSocket::setOnRecv(RecvCb&& cb)
1042{
1043 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1044 pimpl_->cb = std::move(cb);
1045 if (!pimpl_->buf.empty() && pimpl_->cb) {
1046 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1047 pimpl_->buf.clear();
1048 }
1049}
1050
1051void
1052ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1053{
1054 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1055 if (pimpl_->cb) {
1056 pimpl_->cb(&pkt[0], pkt.size());
1057 return;
1058 }
1059 pimpl_->buf.insert(pimpl_->buf.end(),
1060 std::make_move_iterator(pkt.begin()),
1061 std::make_move_iterator(pkt.end()));
1062 pimpl_->cv.notify_all();
1063}
1064
#ifdef DHTNET_TESTABLE
/// The owning MultiplexedSocket, or an empty pointer once it has been destroyed.
std::shared_ptr<MultiplexedSocket>
ChannelSocket::underlyingSocket() const
{
    return pimpl_->endpoint.lock();
}
#endif
1074
1075void
1076ChannelSocket::answered()
1077{
1078 pimpl_->isAnswered_ = true;
1079}
1080
1081void
1082ChannelSocket::removable()
1083{
1084 pimpl_->isRemovable_ = true;
1085}
1086
1087bool
1088ChannelSocket::isRemovable() const
1089{
1090 return pimpl_->isRemovable_;
1091}
1092
1093bool
1094ChannelSocket::isAnswered() const
1095{
1096 return pimpl_->isAnswered_;
1097}
1098
1099void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001100ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001101{
1102 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001103 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001104}
1105
1106void
1107ChannelSocket::stop()
1108{
1109 if (pimpl_->isShutdown_)
1110 return;
1111 pimpl_->isShutdown_ = true;
1112 if (pimpl_->shutdownCb_)
1113 pimpl_->shutdownCb_();
1114 pimpl_->cv.notify_all();
1115 // stop() can be called by ChannelSocket::shutdown()
1116 // In this case, the eventLoop is not used, but MxSock
1117 // must remove the channel from its list (so that the
1118 // channel can be destroyed and its shared_ptr invalidated).
1119 if (pimpl_->rmFromMxSockCb_)
1120 pimpl_->rmFromMxSockCb_();
1121}
1122
1123void
1124ChannelSocket::shutdown()
1125{
1126 if (pimpl_->isShutdown_)
1127 return;
1128 stop();
1129 if (auto ep = pimpl_->endpoint.lock()) {
1130 std::error_code ec;
1131 const uint8_t dummy = '\0';
1132 ep->write(pimpl_->channel, &dummy, 0, ec);
1133 }
1134}
1135
1136std::size_t
1137ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1138{
1139 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1140 std::size_t size = std::min(len, pimpl_->buf.size());
1141
1142 for (std::size_t i = 0; i < size; ++i)
1143 outBuf[i] = pimpl_->buf[i];
1144
1145 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1146 return size;
1147}
1148
1149std::size_t
1150ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1151{
1152 if (pimpl_->isShutdown_) {
1153 ec = std::make_error_code(std::errc::broken_pipe);
1154 return -1;
1155 }
1156 if (auto ep = pimpl_->endpoint.lock()) {
1157 std::size_t sent = 0;
1158 do {
1159 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1160 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1161 if (ec) {
1162 if (ep->logger())
1163 ep->logger()->error("Error when writing on channel: {}", ec.message());
1164 return res;
1165 }
1166 sent += toSend;
1167 } while (sent < len);
1168 return sent;
1169 }
1170 ec = std::make_error_code(std::errc::broken_pipe);
1171 return -1;
1172}
1173
1174int
1175ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1176{
1177 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1178 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1179 return pimpl_->buf.size();
1180}
1181
1182void
1183ChannelSocket::onShutdown(OnShutdownCb&& cb)
1184{
1185 pimpl_->shutdownCb_ = std::move(cb);
1186 if (pimpl_->isShutdown_) {
1187 pimpl_->shutdownCb_();
1188 }
1189}
1190
1191void
1192ChannelSocket::onReady(ChannelReadyCb&& cb)
1193{
1194 pimpl_->readyCb_ = std::move(cb);
1195}
1196
1197void
1198ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1199{
1200 if (auto ep = pimpl_->endpoint.lock()) {
1201 ep->sendBeacon(timeout);
1202 } else {
1203 shutdown();
1204 }
1205}
1206
1207std::shared_ptr<dht::crypto::Certificate>
1208ChannelSocket::peerCertificate() const
1209{
1210 if (auto ep = pimpl_->endpoint.lock())
1211 return ep->peerCertificate();
1212 return {};
1213}
1214
1215IpAddr
1216ChannelSocket::getLocalAddress() const
1217{
1218 if (auto ep = pimpl_->endpoint.lock())
1219 return ep->getLocalAddress();
1220 return {};
1221}
1222
1223IpAddr
1224ChannelSocket::getRemoteAddress() const
1225{
1226 if (auto ep = pimpl_->endpoint.lock())
1227 return ep->getRemoteAddress();
1228 return {};
1229}
1230
Amna31791e52023-08-03 12:40:57 -04001231std::vector<std::map<std::string, std::string>>
1232MultiplexedSocket::getChannelList() const
1233{
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001234 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
Amna31791e52023-08-03 12:40:57 -04001235 std::vector<std::map<std::string, std::string>> channelsList;
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001236 channelsList.reserve(pimpl_->sockets.size());
Amna31791e52023-08-03 12:40:57 -04001237 for (const auto& [_, channel] : pimpl_->sockets) {
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001238 channelsList.emplace_back(std::map<std::string, std::string> {
1239 {"id", fmt::format("{:x}", channel->channel())},
1240 {"name", channel->name()},
1241 });
Amna31791e52023-08-03 12:40:57 -04001242 }
Amna31791e52023-08-03 12:40:57 -04001243 return channelsList;
1244}
1245
Sébastien Blin464bdff2023-07-19 08:02:53 -04001246} // namespace dhtnet