blob: bf0948e6bcfc51f4ae624d49b452ecae97163de0 [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
Adrien Béraudcb753622023-07-17 22:32:49 -04002 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
Adrien Béraud612b55b2023-05-29 10:42:04 -04003 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "multiplexed_socket.h"
18#include "peer_connection.h"
19#include "ice_transport.h"
20#include "certstore.h"
21
22#include <opendht/logger.h>
23#include <opendht/thread_pool.h>
24
25#include <asio/io_context.hpp>
26#include <asio/steady_timer.hpp>
27
28#include <deque>
29
30static constexpr std::size_t IO_BUFFER_SIZE {8192}; ///< Size of char buffer used by IO operations
31static constexpr int MULTIPLEXED_SOCKET_VERSION {1};
32
33struct ChanneledMessage
34{
35 uint16_t channel;
36 std::vector<uint8_t> data;
37 MSGPACK_DEFINE(channel, data)
38};
39
40struct BeaconMsg
41{
42 bool p;
43 MSGPACK_DEFINE_MAP(p)
44};
45
46struct VersionMsg
47{
48 int v;
49 MSGPACK_DEFINE_MAP(v)
50};
51
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040052namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040053
54using clock = std::chrono::steady_clock;
55using time_point = clock::time_point;
56
57class MultiplexedSocket::Impl
58{
59public:
60 Impl(MultiplexedSocket& parent,
61 std::shared_ptr<asio::io_context> ctx,
62 const DeviceId& deviceId,
Adrien Béraud55133cc2023-10-15 11:55:20 -040063 std::unique_ptr<TlsSocketEndpoint> ep,
Adrien Béraud5636f7c2023-09-14 14:34:57 -040064 std::shared_ptr<dht::log::Logger> logger)
Adrien Béraud612b55b2023-05-29 10:42:04 -040065 : parent_(parent)
Adrien Béraud612b55b2023-05-29 10:42:04 -040066 , ctx_(std::move(ctx))
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040067 , deviceId(deviceId)
Adrien Béraud55133cc2023-10-15 11:55:20 -040068 , endpoint(std::move(ep))
69 , nextChannel_(endpoint->isInitiator() ? 0x0001u : 0x8000u)
Adrien Béraud612b55b2023-05-29 10:42:04 -040070 , eventLoopThread_ {[this] {
71 try {
72 eventLoop();
73 } catch (const std::exception& e) {
74 if (logger_)
75 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
76 shutdown();
77 }
78 }}
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040079 , beaconTimer_(*ctx_)
Adrien Béraud612b55b2023-05-29 10:42:04 -040080 {}
81
82 ~Impl() {}
83
84 void join()
85 {
86 if (!isShutdown_) {
87 if (endpoint)
88 endpoint->setOnStateChange({});
89 shutdown();
90 } else {
91 clearSockets();
92 }
93 if (eventLoopThread_.joinable())
94 eventLoopThread_.join();
95 }
96
97 void clearSockets()
98 {
99 decltype(sockets) socks;
100 {
101 std::lock_guard<std::mutex> lkSockets(socketsMutex);
102 socks = std::move(sockets);
103 }
104 for (auto& socket : socks) {
105 // Just trigger onShutdown() to make client know
106 // No need to write the EOF for the channel, the write will fail because endpoint is
107 // already shutdown
108 if (socket.second)
109 socket.second->stop();
110 }
111 }
112
113 void shutdown()
114 {
115 if (isShutdown_)
116 return;
117 stop.store(true);
118 isShutdown_ = true;
119 beaconTimer_.cancel();
120 if (onShutdown_)
121 onShutdown_();
122 if (endpoint) {
123 std::unique_lock<std::mutex> lk(writeMtx);
124 endpoint->shutdown();
125 }
126 clearSockets();
127 }
128
129 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
130 uint16_t channel,
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400131 bool isInitiator)
Adrien Béraud612b55b2023-05-29 10:42:04 -0400132 {
133 auto& channelSocket = sockets[channel];
134 if (not channelSocket)
135 channelSocket = std::make_shared<ChannelSocket>(
136 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
137 // Remove socket in another thread to avoid any lock
138 dht::ThreadPool::io().run([w, channel]() {
139 if (auto shared = w.lock()) {
140 shared->eraseChannel(channel);
141 }
142 });
143 });
144 else {
145 if (logger_)
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400146 logger_->warn("Received request for existing channel {}", channel);
147 return {};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400148 }
149 return channelSocket;
150 }
151
152 /**
153 * Handle packets on the TLS endpoint and parse RTP
154 */
155 void eventLoop();
156 /**
157 * Triggered when a new control packet is received
158 */
159 void handleControlPacket(std::vector<uint8_t>&& pkt);
160 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
161 bool handleProtocolMsg(const msgpack::object& o);
162 /**
163 * Triggered when a new packet on a channel is received
164 */
165 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
166 void onRequest(const std::string& name, uint16_t channel);
167 void onAccept(const std::string& name, uint16_t channel);
168
169 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
170 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
171
172 // Beacon
173 void sendBeacon(const std::chrono::milliseconds& timeout);
174 void handleBeaconRequest();
175 void handleBeaconResponse();
176 std::atomic_int beaconCounter_ {0};
177
178 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
179
180 msgpack::unpacker pac_ {};
181
182 MultiplexedSocket& parent_;
183
184 std::shared_ptr<Logger> logger_;
185 std::shared_ptr<asio::io_context> ctx_;
186
187 OnConnectionReadyCb onChannelReady_ {};
188 OnConnectionRequestCb onRequest_ {};
189 OnShutdownCb onShutdown_ {};
190
191 DeviceId deviceId {};
192 // Main socket
193 std::unique_ptr<TlsSocketEndpoint> endpoint {};
194
195 std::mutex socketsMutex {};
196 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
Adrien Béraud55133cc2023-10-15 11:55:20 -0400197 uint16_t nextChannel_;
Adrien Béraud612b55b2023-05-29 10:42:04 -0400198
199 // Main loop to parse incoming packets
200 std::atomic_bool stop {false};
201 std::thread eventLoopThread_ {};
202
203 std::atomic_bool isShutdown_ {false};
204
205 std::mutex writeMtx {};
206
207 time_point start_ {clock::now()};
208 //std::shared_ptr<Task> beaconTask_ {};
209 asio::steady_timer beaconTimer_;
210
211 // version related stuff
212 void sendVersion();
213 void onVersion(int version);
214 std::atomic_bool canSendBeacon_ {false};
215 std::atomic_bool answerBeacon_ {true};
216 int version_ {MULTIPLEXED_SOCKET_VERSION};
217 std::function<void(bool)> onBeaconCb_ {};
218 std::function<void(int)> onVersionCb_ {};
219};
220
221void
222MultiplexedSocket::Impl::eventLoop()
223{
224 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
225 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
226 if (logger_)
227 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
228 shutdown();
229 return false;
230 }
231 return true;
232 });
233 sendVersion();
234 std::error_code ec;
235 while (!stop) {
236 if (!endpoint) {
237 shutdown();
238 return;
239 }
240 pac_.reserve_buffer(IO_BUFFER_SIZE);
241 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
242 if (size < 0) {
243 if (ec && logger_)
244 logger_->error("Read error detected: {}", ec.message());
245 break;
246 }
247 if (size == 0) {
248 // We can close the socket
249 shutdown();
250 break;
251 }
252
253 pac_.buffer_consumed(size);
254 msgpack::object_handle oh;
255 while (pac_.next(oh) && !stop) {
256 try {
257 auto msg = oh.get().as<ChanneledMessage>();
258 if (msg.channel == CONTROL_CHANNEL)
259 handleControlPacket(std::move(msg.data));
260 else if (msg.channel == PROTOCOL_CHANNEL)
261 handleProtocolPacket(std::move(msg.data));
262 else
263 handleChannelPacket(msg.channel, std::move(msg.data));
264 } catch (const std::exception& e) {
265 if (logger_)
266 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
267 } catch (...) {
268 if (logger_)
269 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
270 }
271 }
272 }
273}
274
275void
276MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
277{
278 std::lock_guard<std::mutex> lkSockets(socketsMutex);
279 auto& socket = sockets[channel];
280 if (!socket) {
281 if (logger_)
282 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
283 return;
284 }
285
286 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400287 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400288 // Due to the callbacks that can take some time, onAccept can arrive after
289 // receiving all the data. In this case, the socket should be removed here
290 // as handle by onChannelReady_
291 if (socket->isRemovable())
292 sockets.erase(channel);
293 else
294 socket->answered();
295}
296
297void
298MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
299{
300 if (!canSendBeacon_)
301 return;
302 beaconCounter_++;
303 if (logger_)
304 logger_->debug("Send beacon to peer {}", deviceId);
305
306 msgpack::sbuffer buffer(8);
307 msgpack::packer<msgpack::sbuffer> pk(&buffer);
308 pk.pack(BeaconMsg {true});
309 if (!writeProtocolMessage(buffer))
310 return;
311 beaconTimer_.expires_after(timeout);
312 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
313 if (ec == asio::error::operation_aborted)
314 return;
315 if (auto shared = w.lock()) {
316 if (shared->pimpl_->beaconCounter_ != 0) {
317 if (shared->pimpl_->logger_)
318 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
319 shared->shutdown();
320 }
321 }
322 });
323}
324
325void
326MultiplexedSocket::Impl::handleBeaconRequest()
327{
328 if (!answerBeacon_)
329 return;
330 // Run this on dedicated thread because some callbacks can take time
331 dht::ThreadPool::io().run([w = parent_.weak()]() {
332 if (auto shared = w.lock()) {
333 msgpack::sbuffer buffer(8);
334 msgpack::packer<msgpack::sbuffer> pk(&buffer);
335 pk.pack(BeaconMsg {false});
336 if (shared->pimpl_->logger_)
337 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
338 shared->pimpl_->writeProtocolMessage(buffer);
339 }
340 });
341}
342
343void
344MultiplexedSocket::Impl::handleBeaconResponse()
345{
346 if (logger_)
347 logger_->debug("Get beacon response from peer {}", deviceId);
348 beaconCounter_--;
349}
350
351bool
352MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
353{
354 std::error_code ec;
355 int wr = parent_.write(PROTOCOL_CHANNEL,
356 (const unsigned char*) buffer.data(),
357 buffer.size(),
358 ec);
359 return wr > 0;
360}
361
362void
363MultiplexedSocket::Impl::sendVersion()
364{
365 dht::ThreadPool::io().run([w = parent_.weak()]() {
366 if (auto shared = w.lock()) {
367 auto version = shared->pimpl_->version_;
368 msgpack::sbuffer buffer(8);
369 msgpack::packer<msgpack::sbuffer> pk(&buffer);
370 pk.pack(VersionMsg {version});
371 shared->pimpl_->writeProtocolMessage(buffer);
372 }
373 });
374}
375
376void
377MultiplexedSocket::Impl::onVersion(int version)
378{
379 // Check if version > 1
380 if (version >= 1) {
381 if (logger_)
382 logger_->debug("Peer {} supports beacon", deviceId);
383 canSendBeacon_ = true;
384 } else {
385 if (logger_)
386 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
387 deviceId,
388 version);
389 canSendBeacon_ = false;
390 }
391}
392
393void
394MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
395{
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400396 bool accept;
397 if (channel == CONTROL_CHANNEL || channel == PROTOCOL_CHANNEL) {
398 if (logger_)
399 logger_->warn("Channel {:d} is reserved, refusing request", channel);
400 accept = false;
401 } else
402 accept = onRequest_(endpoint->peerCertificate(), channel, name);
403
Adrien Béraud612b55b2023-05-29 10:42:04 -0400404 std::shared_ptr<ChannelSocket> channelSocket;
405 if (accept) {
406 std::lock_guard<std::mutex> lkSockets(socketsMutex);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400407 channelSocket = makeSocket(name, channel, false);
408 if (not channelSocket) {
409 if (logger_)
410 logger_->error("Channel {:d} already exists, refusing request", channel);
411 accept = false;
412 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400413 }
414
415 // Answer to ChannelRequest if accepted
416 ChannelRequest val;
417 val.channel = channel;
418 val.name = name;
419 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
420 msgpack::sbuffer buffer(512);
421 msgpack::pack(buffer, val);
422 std::error_code ec;
423 int wr = parent_.write(CONTROL_CHANNEL,
424 reinterpret_cast<const uint8_t*>(buffer.data()),
425 buffer.size(),
426 ec);
427 if (wr < 0) {
428 if (ec && logger_)
429 logger_->error("The write operation failed with error: {:s}", ec.message());
430 stop.store(true);
431 return;
432 }
433
434 if (accept) {
435 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400436 channelSocket->ready(true);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400437 if (channelSocket->isRemovable()) {
438 std::lock_guard<std::mutex> lkSockets(socketsMutex);
439 sockets.erase(channel);
440 } else
441 channelSocket->answered();
Adrien Béraud612b55b2023-05-29 10:42:04 -0400442 }
443}
444
445void
446MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
447{
Adrien Béraudcd7a7bd2023-10-15 12:54:30 -0400448 try {
449 size_t off = 0;
450 while (off != pkt.size()) {
451 msgpack::unpacked result;
452 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
453 auto object = result.get();
454 if (handleProtocolMsg(object))
455 continue;
456 auto req = object.as<ChannelRequest>();
457 if (req.state == ChannelRequestState::REQUEST) {
458 dht::ThreadPool::io().run([w = parent_.weak(), req = std::move(req)]() {
459 if (auto shared = w.lock())
460 shared->pimpl_->onRequest(req.name, req.channel);
461 });
462 }
463 else if (req.state == ChannelRequestState::ACCEPT) {
464 onAccept(req.name, req.channel);
465 } else {
466 // DECLINE or unknown
467 std::lock_guard<std::mutex> lkSockets(socketsMutex);
468 auto channel = sockets.find(req.channel);
469 if (channel != sockets.end()) {
470 channel->second->ready(false);
471 channel->second->stop();
472 sockets.erase(channel);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400473 }
474 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400475 }
Adrien Béraudcd7a7bd2023-10-15 12:54:30 -0400476 } catch (const std::exception& e) {
477 if (logger_)
478 logger_->error("Error on the control channel: {}", e.what());
479 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400480}
481
482void
483MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
484{
485 std::lock_guard<std::mutex> lkSockets(socketsMutex);
486 auto sockIt = sockets.find(channel);
487 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
488 if (pkt.size() == 0) {
489 sockIt->second->stop();
490 if (sockIt->second->isAnswered())
491 sockets.erase(sockIt);
492 else
493 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
494 // removed later.
495 } else {
496 sockIt->second->onRecv(std::move(pkt));
497 }
498 } else if (pkt.size() != 0) {
499 if (logger_)
500 logger_->warn("Non existing channel: {}", channel);
501 }
502}
503
504bool
505MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
506{
507 try {
508 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
509 auto key = o.via.map.ptr[0].key.as<std::string_view>();
510 if (key == "p") {
511 auto msg = o.as<BeaconMsg>();
512 if (msg.p)
513 handleBeaconRequest();
514 else
515 handleBeaconResponse();
516 if (onBeaconCb_)
517 onBeaconCb_(msg.p);
518 return true;
519 } else if (key == "v") {
520 auto msg = o.as<VersionMsg>();
521 onVersion(msg.v);
522 if (onVersionCb_)
523 onVersionCb_(msg.v);
524 return true;
525 } else {
526 if (logger_)
527 logger_->warn("Unknown message type");
528 }
529 }
530 } catch (const std::exception& e) {
531 if (logger_)
532 logger_->error("Error on the protocol channel: {}", e.what());
533 }
534 return false;
535}
536
537void
538MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
539{
540 // Run this on dedicated thread because some callbacks can take time
541 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
542 auto shared = w.lock();
543 if (!shared)
544 return;
545 try {
546 size_t off = 0;
547 while (off != pkt.size()) {
548 msgpack::unpacked result;
549 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
550 auto object = result.get();
551 if (shared->pimpl_->handleProtocolMsg(object))
552 return;
553 }
554 } catch (const std::exception& e) {
555 if (shared->pimpl_->logger_)
556 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
557 }
558 });
559}
560
561MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
Adrien Béraud5636f7c2023-09-14 14:34:57 -0400562 std::unique_ptr<TlsSocketEndpoint> endpoint, std::shared_ptr<dht::log::Logger> logger)
563 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint), logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -0400564{}
565
566MultiplexedSocket::~MultiplexedSocket() {}
567
568std::shared_ptr<ChannelSocket>
569MultiplexedSocket::addChannel(const std::string& name)
570{
Adrien Béraud612b55b2023-05-29 10:42:04 -0400571 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
Adrien Béraud55133cc2023-10-15 11:55:20 -0400572 if (pimpl_->sockets.size() < UINT16_MAX)
573 for (unsigned i = 0; i < UINT16_MAX; ++i) {
574 auto c = pimpl_->nextChannel_++;
575 if (c == CONTROL_CHANNEL
576 || c == PROTOCOL_CHANNEL
577 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
578 continue;
579 return pimpl_->makeSocket(name, c, true);
580 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400581 return {};
582}
583
584DeviceId
585MultiplexedSocket::deviceId() const
586{
587 return pimpl_->deviceId;
588}
589
590void
591MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
592{
593 pimpl_->onChannelReady_ = std::move(cb);
594}
595
596void
597MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
598{
599 pimpl_->onRequest_ = std::move(cb);
600}
601
602bool
603MultiplexedSocket::isReliable() const
604{
605 return true;
606}
607
608bool
609MultiplexedSocket::isInitiator() const
610{
611 if (!pimpl_->endpoint) {
612 if (pimpl_->logger_)
613 pimpl_->logger_->warn("No endpoint found for socket");
614 return false;
615 }
616 return pimpl_->endpoint->isInitiator();
617}
618
619int
620MultiplexedSocket::maxPayload() const
621{
622 if (!pimpl_->endpoint) {
623 if (pimpl_->logger_)
624 pimpl_->logger_->warn("No endpoint found for socket");
625 return 0;
626 }
627 return pimpl_->endpoint->maxPayload();
628}
629
630std::size_t
631MultiplexedSocket::write(const uint16_t& channel,
632 const uint8_t* buf,
633 std::size_t len,
634 std::error_code& ec)
635{
636 assert(nullptr != buf);
637
638 if (pimpl_->isShutdown_) {
639 ec = std::make_error_code(std::errc::broken_pipe);
640 return -1;
641 }
642 if (len > UINT16_MAX) {
643 ec = std::make_error_code(std::errc::message_size);
644 return -1;
645 }
646 bool oneShot = len < 8192;
647 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
648 msgpack::packer<msgpack::sbuffer> pk(&buffer);
649 pk.pack_array(2);
650 pk.pack(channel);
651 pk.pack_bin(len);
652 if (oneShot)
653 pk.pack_bin_body((const char*) buf, len);
654
655 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
656 if (!pimpl_->endpoint) {
657 if (pimpl_->logger_)
658 pimpl_->logger_->warn("No endpoint found for socket");
659 ec = std::make_error_code(std::errc::broken_pipe);
660 return -1;
661 }
662 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
663 if (not oneShot and res >= 0)
664 res = pimpl_->endpoint->write(buf, len, ec);
665 lk.unlock();
666 if (res < 0) {
667 if (ec && pimpl_->logger_)
668 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
669 shutdown();
670 }
671 return res;
672}
673
674void
675MultiplexedSocket::shutdown()
676{
677 pimpl_->shutdown();
678}
679
680void
681MultiplexedSocket::join()
682{
683 pimpl_->join();
684}
685
686void
687MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
688{
689 pimpl_->onShutdown_ = std::move(cb);
690 if (pimpl_->isShutdown_)
691 pimpl_->onShutdown_();
692}
693
694const std::shared_ptr<Logger>&
695MultiplexedSocket::logger()
696{
697 return pimpl_->logger_;
698}
699
700void
701MultiplexedSocket::monitor() const
702{
703 auto cert = peerCertificate();
704 if (!cert || !cert->issuer)
705 return;
706 auto now = clock::now();
707 if (!pimpl_->logger_)
708 return;
709 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
710 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
711 pimpl_->endpoint->monitor();
712 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
713 for (const auto& [_, channel] : pimpl_->sockets) {
714 if (channel)
715 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
716 fmt::ptr(channel.get()),
717 channel.use_count(),
718 channel->name(),
719 channel->isInitiator());
720 }
721}
722
723void
724MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
725{
726 pimpl_->sendBeacon(timeout);
727}
728
729std::shared_ptr<dht::crypto::Certificate>
730MultiplexedSocket::peerCertificate() const
731{
732 return pimpl_->endpoint->peerCertificate();
733}
734
#ifdef DHTNET_TESTABLE
/// Test hook: whether the peer's version allows beacons.
bool
MultiplexedSocket::canSendBeacon() const
{
    return pimpl_->canSendBeacon_.load();
}

/// Test hook: enable or disable answering beacon requests.
void
MultiplexedSocket::answerToBeacon(bool value)
{
    pimpl_->answerBeacon_.store(value);
}

/// Test hook: override the protocol version advertised to the peer.
void
MultiplexedSocket::setVersion(int version)
{
    auto& impl = *pimpl_;
    impl.version_ = version;
}

/// Test hook: observe every beacon message (true = request, false = response).
void
MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
{
    auto& impl = *pimpl_;
    impl.onBeaconCb_ = cb;
}

/// Test hook: observe every version message received.
void
MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
{
    auto& impl = *pimpl_;
    impl.onVersionCb_ = cb;
}

/// Test hook: (re)send the version message.
void
MultiplexedSocket::sendVersion()
{
    auto& impl = *pimpl_;
    impl.sendVersion();
}

#endif
773
Adrien Béraud612b55b2023-05-29 10:42:04 -0400774IpAddr
775MultiplexedSocket::getLocalAddress() const
776{
777 return pimpl_->endpoint->getLocalAddress();
778}
779
780IpAddr
781MultiplexedSocket::getRemoteAddress() const
782{
783 return pimpl_->endpoint->getRemoteAddress();
784}
785
Adrien Béraudafa8e282023-09-24 12:53:20 -0400786TlsSocketEndpoint*
787MultiplexedSocket::endpoint()
788{
789 return pimpl_->endpoint.get();
790}
791
Adrien Béraud612b55b2023-05-29 10:42:04 -0400792void
793MultiplexedSocket::eraseChannel(uint16_t channel)
794{
795 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
796 auto itSocket = pimpl_->sockets.find(channel);
797 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
798 pimpl_->sockets.erase(itSocket);
799}
800
801////////////////////////////////////////////////////////////////
802
803class ChannelSocket::Impl
804{
805public:
806 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
807 const std::string& name,
808 const uint16_t& channel,
809 bool isInitiator,
810 std::function<void()> rmFromMxSockCb)
811 : name(name)
812 , channel(channel)
813 , endpoint(std::move(endpoint))
814 , isInitiator_(isInitiator)
815 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
816 {}
817
818 ~Impl() {}
819
820 ChannelReadyCb readyCb_ {};
821 OnShutdownCb shutdownCb_ {};
822 std::atomic_bool isShutdown_ {false};
Adrien Béraud5fd233f2023-10-15 12:35:48 -0400823 const std::string name {};
824 const uint16_t channel {};
825 const std::weak_ptr<MultiplexedSocket> endpoint {};
826 const bool isInitiator_ {false};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400827 std::function<void()> rmFromMxSockCb_;
828
829 bool isAnswered_ {false};
830 bool isRemovable_ {false};
831
832 std::vector<uint8_t> buf {};
833 std::mutex mutex {};
834 std::condition_variable cv {};
835 GenericSocket<uint8_t>::RecvCb cb {};
836};
837
838ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
839 const DeviceId& deviceId,
840 const std::string& name,
841 const uint16_t& channel)
842 : pimpl_deviceId(deviceId)
843 , pimpl_name(name)
844 , pimpl_channel(channel)
845 , ioCtx_(*ctx)
846{}
847
848ChannelSocketTest::~ChannelSocketTest() {}
849
850void
851ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
852 const std::shared_ptr<ChannelSocketTest>& socket2)
853{
854 socket1->remote = socket2;
855 socket2->remote = socket1;
856}
857
858DeviceId
859ChannelSocketTest::deviceId() const
860{
861 return pimpl_deviceId;
862}
863
864std::string
865ChannelSocketTest::name() const
866{
867 return pimpl_name;
868}
869
870uint16_t
871ChannelSocketTest::channel() const
872{
873 return pimpl_channel;
874}
875
876void
877ChannelSocketTest::shutdown()
878{
879 {
880 std::unique_lock<std::mutex> lk {mutex};
881 if (!isShutdown_.exchange(true)) {
882 lk.unlock();
883 shutdownCb_();
884 }
885 cv.notify_all();
886 }
887
888 if (auto peer = remote.lock()) {
889 if (!peer->isShutdown_.exchange(true)) {
890 peer->shutdownCb_();
891 }
892 peer->cv.notify_all();
893 }
894}
895
896std::size_t
897ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
898{
899 std::size_t size = std::min(len, this->rx_buf.size());
900
901 for (std::size_t i = 0; i < size; ++i)
902 buf[i] = this->rx_buf[i];
903
904 if (size == this->rx_buf.size()) {
905 this->rx_buf.clear();
906 } else
907 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
908 return size;
909}
910
911std::size_t
912ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
913{
914 if (isShutdown_) {
915 ec = std::make_error_code(std::errc::broken_pipe);
916 return -1;
917 }
918 ec = {};
919 dht::ThreadPool::computation().run(
920 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
921 if (auto peer = r.lock())
922 peer->onRecv(std::move(data));
923 });
924 return len;
925}
926
927int
928ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
929{
930 std::unique_lock<std::mutex> lk {mutex};
931 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
932 return rx_buf.size();
933}
934
935void
936ChannelSocketTest::setOnRecv(RecvCb&& cb)
937{
938 std::lock_guard<std::mutex> lkSockets(mutex);
939 this->cb = std::move(cb);
940 if (!rx_buf.empty() && this->cb) {
941 this->cb(rx_buf.data(), rx_buf.size());
942 rx_buf.clear();
943 }
944}
945
946void
947ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
948{
949 std::lock_guard<std::mutex> lkSockets(mutex);
950 if (cb) {
951 cb(pkt.data(), pkt.size());
952 return;
953 }
954 rx_buf.insert(rx_buf.end(),
955 std::make_move_iterator(pkt.begin()),
956 std::make_move_iterator(pkt.end()));
957 cv.notify_all();
958}
959
960void
961ChannelSocketTest::onReady(ChannelReadyCb&& cb)
962{}
963
964void
965ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
966{
967 std::unique_lock<std::mutex> lk {mutex};
968 shutdownCb_ = std::move(cb);
969
970 if (isShutdown_) {
971 lk.unlock();
972 shutdownCb_();
973 }
974}
975
976ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
977 const std::string& name,
978 const uint16_t& channel,
979 bool isInitiator,
980 std::function<void()> rmFromMxSockCb)
981 : pimpl_ {
982 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
983{}
984
985ChannelSocket::~ChannelSocket() {}
986
987DeviceId
988ChannelSocket::deviceId() const
989{
990 if (auto ep = pimpl_->endpoint.lock()) {
991 return ep->deviceId();
992 }
993 return {};
994}
995
996std::string
997ChannelSocket::name() const
998{
999 return pimpl_->name;
1000}
1001
1002uint16_t
1003ChannelSocket::channel() const
1004{
1005 return pimpl_->channel;
1006}
1007
1008bool
1009ChannelSocket::isReliable() const
1010{
1011 if (auto ep = pimpl_->endpoint.lock()) {
1012 return ep->isReliable();
1013 }
1014 return false;
1015}
1016
1017bool
1018ChannelSocket::isInitiator() const
1019{
1020 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1021 // because a multiplexed socket can have sockets from accepted requests
1022 // or made via connectDevice(). Here, isInitiator_ return if the socket
1023 // is from connectDevice.
1024 return pimpl_->isInitiator_;
1025}
1026
1027int
1028ChannelSocket::maxPayload() const
1029{
1030 if (auto ep = pimpl_->endpoint.lock()) {
1031 return ep->maxPayload();
1032 }
1033 return -1;
1034}
1035
1036void
1037ChannelSocket::setOnRecv(RecvCb&& cb)
1038{
1039 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1040 pimpl_->cb = std::move(cb);
1041 if (!pimpl_->buf.empty() && pimpl_->cb) {
1042 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1043 pimpl_->buf.clear();
1044 }
1045}
1046
1047void
1048ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1049{
1050 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1051 if (pimpl_->cb) {
1052 pimpl_->cb(&pkt[0], pkt.size());
1053 return;
1054 }
1055 pimpl_->buf.insert(pimpl_->buf.end(),
1056 std::make_move_iterator(pkt.begin()),
1057 std::make_move_iterator(pkt.end()));
1058 pimpl_->cv.notify_all();
1059}
1060
#ifdef DHTNET_TESTABLE
/// Test hook: the owning multiplexed socket, or null once it has been destroyed.
std::shared_ptr<MultiplexedSocket>
ChannelSocket::underlyingSocket() const
{
    return pimpl_->endpoint.lock();
}
#endif
1070
1071void
1072ChannelSocket::answered()
1073{
1074 pimpl_->isAnswered_ = true;
1075}
1076
1077void
1078ChannelSocket::removable()
1079{
1080 pimpl_->isRemovable_ = true;
1081}
1082
1083bool
1084ChannelSocket::isRemovable() const
1085{
1086 return pimpl_->isRemovable_;
1087}
1088
1089bool
1090ChannelSocket::isAnswered() const
1091{
1092 return pimpl_->isAnswered_;
1093}
1094
1095void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001096ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001097{
1098 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001099 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001100}
1101
1102void
1103ChannelSocket::stop()
1104{
1105 if (pimpl_->isShutdown_)
1106 return;
1107 pimpl_->isShutdown_ = true;
1108 if (pimpl_->shutdownCb_)
1109 pimpl_->shutdownCb_();
1110 pimpl_->cv.notify_all();
1111 // stop() can be called by ChannelSocket::shutdown()
1112 // In this case, the eventLoop is not used, but MxSock
1113 // must remove the channel from its list (so that the
1114 // channel can be destroyed and its shared_ptr invalidated).
1115 if (pimpl_->rmFromMxSockCb_)
1116 pimpl_->rmFromMxSockCb_();
1117}
1118
1119void
1120ChannelSocket::shutdown()
1121{
1122 if (pimpl_->isShutdown_)
1123 return;
1124 stop();
1125 if (auto ep = pimpl_->endpoint.lock()) {
1126 std::error_code ec;
1127 const uint8_t dummy = '\0';
1128 ep->write(pimpl_->channel, &dummy, 0, ec);
1129 }
1130}
1131
1132std::size_t
1133ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1134{
1135 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1136 std::size_t size = std::min(len, pimpl_->buf.size());
1137
1138 for (std::size_t i = 0; i < size; ++i)
1139 outBuf[i] = pimpl_->buf[i];
1140
1141 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1142 return size;
1143}
1144
1145std::size_t
1146ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1147{
1148 if (pimpl_->isShutdown_) {
1149 ec = std::make_error_code(std::errc::broken_pipe);
1150 return -1;
1151 }
1152 if (auto ep = pimpl_->endpoint.lock()) {
1153 std::size_t sent = 0;
1154 do {
1155 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1156 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1157 if (ec) {
1158 if (ep->logger())
1159 ep->logger()->error("Error when writing on channel: {}", ec.message());
1160 return res;
1161 }
1162 sent += toSend;
1163 } while (sent < len);
1164 return sent;
1165 }
1166 ec = std::make_error_code(std::errc::broken_pipe);
1167 return -1;
1168}
1169
1170int
1171ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1172{
1173 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1174 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1175 return pimpl_->buf.size();
1176}
1177
1178void
1179ChannelSocket::onShutdown(OnShutdownCb&& cb)
1180{
1181 pimpl_->shutdownCb_ = std::move(cb);
1182 if (pimpl_->isShutdown_) {
1183 pimpl_->shutdownCb_();
1184 }
1185}
1186
1187void
1188ChannelSocket::onReady(ChannelReadyCb&& cb)
1189{
1190 pimpl_->readyCb_ = std::move(cb);
1191}
1192
1193void
1194ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1195{
1196 if (auto ep = pimpl_->endpoint.lock()) {
1197 ep->sendBeacon(timeout);
1198 } else {
1199 shutdown();
1200 }
1201}
1202
1203std::shared_ptr<dht::crypto::Certificate>
1204ChannelSocket::peerCertificate() const
1205{
1206 if (auto ep = pimpl_->endpoint.lock())
1207 return ep->peerCertificate();
1208 return {};
1209}
1210
1211IpAddr
1212ChannelSocket::getLocalAddress() const
1213{
1214 if (auto ep = pimpl_->endpoint.lock())
1215 return ep->getLocalAddress();
1216 return {};
1217}
1218
1219IpAddr
1220ChannelSocket::getRemoteAddress() const
1221{
1222 if (auto ep = pimpl_->endpoint.lock())
1223 return ep->getRemoteAddress();
1224 return {};
1225}
1226
Amna31791e52023-08-03 12:40:57 -04001227std::vector<std::map<std::string, std::string>>
1228MultiplexedSocket::getChannelList() const
1229{
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001230 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
Amna31791e52023-08-03 12:40:57 -04001231 std::vector<std::map<std::string, std::string>> channelsList;
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001232 channelsList.reserve(pimpl_->sockets.size());
Amna31791e52023-08-03 12:40:57 -04001233 for (const auto& [_, channel] : pimpl_->sockets) {
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001234 channelsList.emplace_back(std::map<std::string, std::string> {
1235 {"id", fmt::format("{:x}", channel->channel())},
1236 {"name", channel->name()},
1237 });
Amna31791e52023-08-03 12:40:57 -04001238 }
Amna31791e52023-08-03 12:40:57 -04001239 return channelsList;
1240}
1241
Sébastien Blin464bdff2023-07-19 08:02:53 -04001242} // namespace dhtnet