blob: 24712560993184591b3d0ad39d37efc73b98cafc [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
Adrien Béraudcb753622023-07-17 22:32:49 -04002 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
Adrien Béraud612b55b2023-05-29 10:42:04 -04003 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "multiplexed_socket.h"
18#include "peer_connection.h"
19#include "ice_transport.h"
20#include "certstore.h"
21
22#include <opendht/logger.h>
23#include <opendht/thread_pool.h>
24
25#include <asio/io_context.hpp>
26#include <asio/steady_timer.hpp>
27
28#include <deque>
29
/// Size of the buffer reserved in the msgpack unpacker before each read of the event loop.
static constexpr std::size_t IO_BUFFER_SIZE = 8192;
/// Default protocol version announced to the peer through the protocol channel.
static constexpr int MULTIPLEXED_SOCKET_VERSION = 1;
32
33struct ChanneledMessage
34{
35 uint16_t channel;
36 std::vector<uint8_t> data;
37 MSGPACK_DEFINE(channel, data)
38};
39
40struct BeaconMsg
41{
42 bool p;
43 MSGPACK_DEFINE_MAP(p)
44};
45
46struct VersionMsg
47{
48 int v;
49 MSGPACK_DEFINE_MAP(v)
50};
51
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040052namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040053
54using clock = std::chrono::steady_clock;
55using time_point = clock::time_point;
56
57class MultiplexedSocket::Impl
58{
59public:
60 Impl(MultiplexedSocket& parent,
61 std::shared_ptr<asio::io_context> ctx,
62 const DeviceId& deviceId,
Adrien Béraud55133cc2023-10-15 11:55:20 -040063 std::unique_ptr<TlsSocketEndpoint> ep,
Adrien Béraud5636f7c2023-09-14 14:34:57 -040064 std::shared_ptr<dht::log::Logger> logger)
Adrien Béraud612b55b2023-05-29 10:42:04 -040065 : parent_(parent)
Adrien Béraud612b55b2023-05-29 10:42:04 -040066 , ctx_(std::move(ctx))
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040067 , deviceId(deviceId)
Adrien Béraud55133cc2023-10-15 11:55:20 -040068 , endpoint(std::move(ep))
69 , nextChannel_(endpoint->isInitiator() ? 0x0001u : 0x8000u)
Adrien Béraud612b55b2023-05-29 10:42:04 -040070 , eventLoopThread_ {[this] {
71 try {
72 eventLoop();
73 } catch (const std::exception& e) {
74 if (logger_)
75 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
76 shutdown();
77 }
78 }}
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040079 , beaconTimer_(*ctx_)
Adrien Béraud612b55b2023-05-29 10:42:04 -040080 {}
81
82 ~Impl() {}
83
84 void join()
85 {
86 if (!isShutdown_) {
87 if (endpoint)
88 endpoint->setOnStateChange({});
89 shutdown();
90 } else {
91 clearSockets();
92 }
93 if (eventLoopThread_.joinable())
94 eventLoopThread_.join();
95 }
96
97 void clearSockets()
98 {
99 decltype(sockets) socks;
100 {
101 std::lock_guard<std::mutex> lkSockets(socketsMutex);
102 socks = std::move(sockets);
103 }
104 for (auto& socket : socks) {
105 // Just trigger onShutdown() to make client know
106 // No need to write the EOF for the channel, the write will fail because endpoint is
107 // already shutdown
108 if (socket.second)
109 socket.second->stop();
110 }
111 }
112
113 void shutdown()
114 {
115 if (isShutdown_)
116 return;
117 stop.store(true);
118 isShutdown_ = true;
119 beaconTimer_.cancel();
120 if (onShutdown_)
121 onShutdown_();
122 if (endpoint) {
123 std::unique_lock<std::mutex> lk(writeMtx);
124 endpoint->shutdown();
125 }
126 clearSockets();
127 }
128
129 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
130 uint16_t channel,
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400131 bool isInitiator)
Adrien Béraud612b55b2023-05-29 10:42:04 -0400132 {
133 auto& channelSocket = sockets[channel];
134 if (not channelSocket)
135 channelSocket = std::make_shared<ChannelSocket>(
136 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
137 // Remove socket in another thread to avoid any lock
138 dht::ThreadPool::io().run([w, channel]() {
139 if (auto shared = w.lock()) {
140 shared->eraseChannel(channel);
141 }
142 });
143 });
144 else {
145 if (logger_)
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400146 logger_->warn("Received request for existing channel {}", channel);
147 return {};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400148 }
149 return channelSocket;
150 }
151
152 /**
153 * Handle packets on the TLS endpoint and parse RTP
154 */
155 void eventLoop();
156 /**
157 * Triggered when a new control packet is received
158 */
159 void handleControlPacket(std::vector<uint8_t>&& pkt);
160 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
161 bool handleProtocolMsg(const msgpack::object& o);
162 /**
163 * Triggered when a new packet on a channel is received
164 */
165 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
166 void onRequest(const std::string& name, uint16_t channel);
167 void onAccept(const std::string& name, uint16_t channel);
168
169 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
170 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
171
172 // Beacon
173 void sendBeacon(const std::chrono::milliseconds& timeout);
174 void handleBeaconRequest();
175 void handleBeaconResponse();
176 std::atomic_int beaconCounter_ {0};
177
178 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
179
180 msgpack::unpacker pac_ {};
181
182 MultiplexedSocket& parent_;
183
184 std::shared_ptr<Logger> logger_;
185 std::shared_ptr<asio::io_context> ctx_;
186
187 OnConnectionReadyCb onChannelReady_ {};
188 OnConnectionRequestCb onRequest_ {};
189 OnShutdownCb onShutdown_ {};
190
191 DeviceId deviceId {};
192 // Main socket
193 std::unique_ptr<TlsSocketEndpoint> endpoint {};
194
195 std::mutex socketsMutex {};
196 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
Adrien Béraud55133cc2023-10-15 11:55:20 -0400197 uint16_t nextChannel_;
Adrien Béraud612b55b2023-05-29 10:42:04 -0400198
199 // Main loop to parse incoming packets
200 std::atomic_bool stop {false};
201 std::thread eventLoopThread_ {};
202
203 std::atomic_bool isShutdown_ {false};
204
205 std::mutex writeMtx {};
206
207 time_point start_ {clock::now()};
208 //std::shared_ptr<Task> beaconTask_ {};
209 asio::steady_timer beaconTimer_;
210
211 // version related stuff
212 void sendVersion();
213 void onVersion(int version);
214 std::atomic_bool canSendBeacon_ {false};
215 std::atomic_bool answerBeacon_ {true};
216 int version_ {MULTIPLEXED_SOCKET_VERSION};
217 std::function<void(bool)> onBeaconCb_ {};
218 std::function<void(int)> onVersionCb_ {};
219};
220
221void
222MultiplexedSocket::Impl::eventLoop()
223{
224 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
225 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
226 if (logger_)
227 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
228 shutdown();
229 return false;
230 }
231 return true;
232 });
233 sendVersion();
234 std::error_code ec;
235 while (!stop) {
236 if (!endpoint) {
237 shutdown();
238 return;
239 }
240 pac_.reserve_buffer(IO_BUFFER_SIZE);
241 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
242 if (size < 0) {
243 if (ec && logger_)
244 logger_->error("Read error detected: {}", ec.message());
245 break;
246 }
247 if (size == 0) {
248 // We can close the socket
249 shutdown();
250 break;
251 }
252
253 pac_.buffer_consumed(size);
254 msgpack::object_handle oh;
255 while (pac_.next(oh) && !stop) {
256 try {
257 auto msg = oh.get().as<ChanneledMessage>();
258 if (msg.channel == CONTROL_CHANNEL)
259 handleControlPacket(std::move(msg.data));
260 else if (msg.channel == PROTOCOL_CHANNEL)
261 handleProtocolPacket(std::move(msg.data));
262 else
263 handleChannelPacket(msg.channel, std::move(msg.data));
264 } catch (const std::exception& e) {
265 if (logger_)
266 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
267 } catch (...) {
268 if (logger_)
269 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
270 }
271 }
272 }
273}
274
275void
276MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
277{
278 std::lock_guard<std::mutex> lkSockets(socketsMutex);
279 auto& socket = sockets[channel];
280 if (!socket) {
281 if (logger_)
282 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
283 return;
284 }
285
286 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400287 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400288 // Due to the callbacks that can take some time, onAccept can arrive after
289 // receiving all the data. In this case, the socket should be removed here
290 // as handle by onChannelReady_
291 if (socket->isRemovable())
292 sockets.erase(channel);
293 else
294 socket->answered();
295}
296
297void
298MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
299{
300 if (!canSendBeacon_)
301 return;
302 beaconCounter_++;
303 if (logger_)
304 logger_->debug("Send beacon to peer {}", deviceId);
305
306 msgpack::sbuffer buffer(8);
307 msgpack::packer<msgpack::sbuffer> pk(&buffer);
308 pk.pack(BeaconMsg {true});
309 if (!writeProtocolMessage(buffer))
310 return;
311 beaconTimer_.expires_after(timeout);
312 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
313 if (ec == asio::error::operation_aborted)
314 return;
315 if (auto shared = w.lock()) {
316 if (shared->pimpl_->beaconCounter_ != 0) {
317 if (shared->pimpl_->logger_)
318 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
319 shared->shutdown();
320 }
321 }
322 });
323}
324
325void
326MultiplexedSocket::Impl::handleBeaconRequest()
327{
328 if (!answerBeacon_)
329 return;
330 // Run this on dedicated thread because some callbacks can take time
331 dht::ThreadPool::io().run([w = parent_.weak()]() {
332 if (auto shared = w.lock()) {
333 msgpack::sbuffer buffer(8);
334 msgpack::packer<msgpack::sbuffer> pk(&buffer);
335 pk.pack(BeaconMsg {false});
336 if (shared->pimpl_->logger_)
337 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
338 shared->pimpl_->writeProtocolMessage(buffer);
339 }
340 });
341}
342
343void
344MultiplexedSocket::Impl::handleBeaconResponse()
345{
346 if (logger_)
347 logger_->debug("Get beacon response from peer {}", deviceId);
348 beaconCounter_--;
349}
350
351bool
352MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
353{
354 std::error_code ec;
355 int wr = parent_.write(PROTOCOL_CHANNEL,
356 (const unsigned char*) buffer.data(),
357 buffer.size(),
358 ec);
359 return wr > 0;
360}
361
362void
363MultiplexedSocket::Impl::sendVersion()
364{
365 dht::ThreadPool::io().run([w = parent_.weak()]() {
366 if (auto shared = w.lock()) {
367 auto version = shared->pimpl_->version_;
368 msgpack::sbuffer buffer(8);
369 msgpack::packer<msgpack::sbuffer> pk(&buffer);
370 pk.pack(VersionMsg {version});
371 shared->pimpl_->writeProtocolMessage(buffer);
372 }
373 });
374}
375
376void
377MultiplexedSocket::Impl::onVersion(int version)
378{
379 // Check if version > 1
380 if (version >= 1) {
381 if (logger_)
382 logger_->debug("Peer {} supports beacon", deviceId);
383 canSendBeacon_ = true;
384 } else {
385 if (logger_)
386 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
387 deviceId,
388 version);
389 canSendBeacon_ = false;
390 }
391}
392
393void
394MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
395{
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400396 bool accept;
397 if (channel == CONTROL_CHANNEL || channel == PROTOCOL_CHANNEL) {
398 if (logger_)
399 logger_->warn("Channel {:d} is reserved, refusing request", channel);
400 accept = false;
401 } else
402 accept = onRequest_(endpoint->peerCertificate(), channel, name);
403
Adrien Béraud612b55b2023-05-29 10:42:04 -0400404 std::shared_ptr<ChannelSocket> channelSocket;
405 if (accept) {
406 std::lock_guard<std::mutex> lkSockets(socketsMutex);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400407 channelSocket = makeSocket(name, channel, false);
408 if (not channelSocket) {
409 if (logger_)
410 logger_->error("Channel {:d} already exists, refusing request", channel);
411 accept = false;
412 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400413 }
414
415 // Answer to ChannelRequest if accepted
416 ChannelRequest val;
417 val.channel = channel;
418 val.name = name;
419 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
420 msgpack::sbuffer buffer(512);
421 msgpack::pack(buffer, val);
422 std::error_code ec;
423 int wr = parent_.write(CONTROL_CHANNEL,
424 reinterpret_cast<const uint8_t*>(buffer.data()),
425 buffer.size(),
426 ec);
427 if (wr < 0) {
428 if (ec && logger_)
429 logger_->error("The write operation failed with error: {:s}", ec.message());
430 stop.store(true);
431 return;
432 }
433
434 if (accept) {
435 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400436 channelSocket->ready(true);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400437 if (channelSocket->isRemovable()) {
438 std::lock_guard<std::mutex> lkSockets(socketsMutex);
439 sockets.erase(channel);
440 } else
441 channelSocket->answered();
Adrien Béraud612b55b2023-05-29 10:42:04 -0400442 }
443}
444
445void
446MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
447{
448 // Run this on dedicated thread because some callbacks can take time
449 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
450 auto shared = w.lock();
451 if (!shared)
452 return;
453 auto& pimpl = *shared->pimpl_;
454 try {
455 size_t off = 0;
456 while (off != pkt.size()) {
457 msgpack::unpacked result;
458 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
459 auto object = result.get();
460 if (pimpl.handleProtocolMsg(object))
461 continue;
462 auto req = object.as<ChannelRequest>();
463 if (req.state == ChannelRequestState::ACCEPT) {
464 pimpl.onAccept(req.name, req.channel);
465 } else if (req.state == ChannelRequestState::DECLINE) {
466 std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
467 auto channel = pimpl.sockets.find(req.channel);
468 if (channel != pimpl.sockets.end()) {
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400469 channel->second->ready(false);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400470 channel->second->stop();
471 pimpl.sockets.erase(channel);
472 }
473 } else if (pimpl.onRequest_) {
474 pimpl.onRequest(req.name, req.channel);
475 }
476 }
477 } catch (const std::exception& e) {
478 if (pimpl.logger_)
479 pimpl.logger_->error("Error on the control channel: {}", e.what());
480 }
481 });
482}
483
484void
485MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
486{
487 std::lock_guard<std::mutex> lkSockets(socketsMutex);
488 auto sockIt = sockets.find(channel);
489 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
490 if (pkt.size() == 0) {
491 sockIt->second->stop();
492 if (sockIt->second->isAnswered())
493 sockets.erase(sockIt);
494 else
495 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
496 // removed later.
497 } else {
498 sockIt->second->onRecv(std::move(pkt));
499 }
500 } else if (pkt.size() != 0) {
501 if (logger_)
502 logger_->warn("Non existing channel: {}", channel);
503 }
504}
505
506bool
507MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
508{
509 try {
510 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
511 auto key = o.via.map.ptr[0].key.as<std::string_view>();
512 if (key == "p") {
513 auto msg = o.as<BeaconMsg>();
514 if (msg.p)
515 handleBeaconRequest();
516 else
517 handleBeaconResponse();
518 if (onBeaconCb_)
519 onBeaconCb_(msg.p);
520 return true;
521 } else if (key == "v") {
522 auto msg = o.as<VersionMsg>();
523 onVersion(msg.v);
524 if (onVersionCb_)
525 onVersionCb_(msg.v);
526 return true;
527 } else {
528 if (logger_)
529 logger_->warn("Unknown message type");
530 }
531 }
532 } catch (const std::exception& e) {
533 if (logger_)
534 logger_->error("Error on the protocol channel: {}", e.what());
535 }
536 return false;
537}
538
539void
540MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
541{
542 // Run this on dedicated thread because some callbacks can take time
543 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
544 auto shared = w.lock();
545 if (!shared)
546 return;
547 try {
548 size_t off = 0;
549 while (off != pkt.size()) {
550 msgpack::unpacked result;
551 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
552 auto object = result.get();
553 if (shared->pimpl_->handleProtocolMsg(object))
554 return;
555 }
556 } catch (const std::exception& e) {
557 if (shared->pimpl_->logger_)
558 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
559 }
560 });
561}
562
563MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
Adrien Béraud5636f7c2023-09-14 14:34:57 -0400564 std::unique_ptr<TlsSocketEndpoint> endpoint, std::shared_ptr<dht::log::Logger> logger)
565 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint), logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -0400566{}
567
568MultiplexedSocket::~MultiplexedSocket() {}
569
570std::shared_ptr<ChannelSocket>
571MultiplexedSocket::addChannel(const std::string& name)
572{
Adrien Béraud612b55b2023-05-29 10:42:04 -0400573 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
Adrien Béraud55133cc2023-10-15 11:55:20 -0400574 if (pimpl_->sockets.size() < UINT16_MAX)
575 for (unsigned i = 0; i < UINT16_MAX; ++i) {
576 auto c = pimpl_->nextChannel_++;
577 if (c == CONTROL_CHANNEL
578 || c == PROTOCOL_CHANNEL
579 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
580 continue;
581 return pimpl_->makeSocket(name, c, true);
582 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400583 return {};
584}
585
586DeviceId
587MultiplexedSocket::deviceId() const
588{
589 return pimpl_->deviceId;
590}
591
592void
593MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
594{
595 pimpl_->onChannelReady_ = std::move(cb);
596}
597
598void
599MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
600{
601 pimpl_->onRequest_ = std::move(cb);
602}
603
604bool
605MultiplexedSocket::isReliable() const
606{
607 return true;
608}
609
610bool
611MultiplexedSocket::isInitiator() const
612{
613 if (!pimpl_->endpoint) {
614 if (pimpl_->logger_)
615 pimpl_->logger_->warn("No endpoint found for socket");
616 return false;
617 }
618 return pimpl_->endpoint->isInitiator();
619}
620
621int
622MultiplexedSocket::maxPayload() const
623{
624 if (!pimpl_->endpoint) {
625 if (pimpl_->logger_)
626 pimpl_->logger_->warn("No endpoint found for socket");
627 return 0;
628 }
629 return pimpl_->endpoint->maxPayload();
630}
631
632std::size_t
633MultiplexedSocket::write(const uint16_t& channel,
634 const uint8_t* buf,
635 std::size_t len,
636 std::error_code& ec)
637{
638 assert(nullptr != buf);
639
640 if (pimpl_->isShutdown_) {
641 ec = std::make_error_code(std::errc::broken_pipe);
642 return -1;
643 }
644 if (len > UINT16_MAX) {
645 ec = std::make_error_code(std::errc::message_size);
646 return -1;
647 }
648 bool oneShot = len < 8192;
649 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
650 msgpack::packer<msgpack::sbuffer> pk(&buffer);
651 pk.pack_array(2);
652 pk.pack(channel);
653 pk.pack_bin(len);
654 if (oneShot)
655 pk.pack_bin_body((const char*) buf, len);
656
657 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
658 if (!pimpl_->endpoint) {
659 if (pimpl_->logger_)
660 pimpl_->logger_->warn("No endpoint found for socket");
661 ec = std::make_error_code(std::errc::broken_pipe);
662 return -1;
663 }
664 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
665 if (not oneShot and res >= 0)
666 res = pimpl_->endpoint->write(buf, len, ec);
667 lk.unlock();
668 if (res < 0) {
669 if (ec && pimpl_->logger_)
670 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
671 shutdown();
672 }
673 return res;
674}
675
676void
677MultiplexedSocket::shutdown()
678{
679 pimpl_->shutdown();
680}
681
682void
683MultiplexedSocket::join()
684{
685 pimpl_->join();
686}
687
688void
689MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
690{
691 pimpl_->onShutdown_ = std::move(cb);
692 if (pimpl_->isShutdown_)
693 pimpl_->onShutdown_();
694}
695
696const std::shared_ptr<Logger>&
697MultiplexedSocket::logger()
698{
699 return pimpl_->logger_;
700}
701
702void
703MultiplexedSocket::monitor() const
704{
705 auto cert = peerCertificate();
706 if (!cert || !cert->issuer)
707 return;
708 auto now = clock::now();
709 if (!pimpl_->logger_)
710 return;
711 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
712 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
713 pimpl_->endpoint->monitor();
714 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
715 for (const auto& [_, channel] : pimpl_->sockets) {
716 if (channel)
717 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
718 fmt::ptr(channel.get()),
719 channel.use_count(),
720 channel->name(),
721 channel->isInitiator());
722 }
723}
724
725void
726MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
727{
728 pimpl_->sendBeacon(timeout);
729}
730
731std::shared_ptr<dht::crypto::Certificate>
732MultiplexedSocket::peerCertificate() const
733{
734 return pimpl_->endpoint->peerCertificate();
735}
736
#ifdef DHTNET_TESTABLE
// Test hooks: expose and override the beacon/version state of the socket.

bool
MultiplexedSocket::canSendBeacon() const
{
    return pimpl_->canSendBeacon_.load();
}

void
MultiplexedSocket::answerToBeacon(bool value)
{
    pimpl_->answerBeacon_.store(value);
}

void
MultiplexedSocket::setVersion(int version)
{
    auto& impl = *pimpl_;
    impl.version_ = version;
}

void
MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
{
    auto& impl = *pimpl_;
    impl.onBeaconCb_ = cb;
}

void
MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
{
    auto& impl = *pimpl_;
    impl.onVersionCb_ = cb;
}

void
MultiplexedSocket::sendVersion()
{
    auto& impl = *pimpl_;
    impl.sendVersion();
}

#endif
775
Adrien Béraud612b55b2023-05-29 10:42:04 -0400776IpAddr
777MultiplexedSocket::getLocalAddress() const
778{
779 return pimpl_->endpoint->getLocalAddress();
780}
781
782IpAddr
783MultiplexedSocket::getRemoteAddress() const
784{
785 return pimpl_->endpoint->getRemoteAddress();
786}
787
Adrien Béraudafa8e282023-09-24 12:53:20 -0400788TlsSocketEndpoint*
789MultiplexedSocket::endpoint()
790{
791 return pimpl_->endpoint.get();
792}
793
Adrien Béraud612b55b2023-05-29 10:42:04 -0400794void
795MultiplexedSocket::eraseChannel(uint16_t channel)
796{
797 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
798 auto itSocket = pimpl_->sockets.find(channel);
799 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
800 pimpl_->sockets.erase(itSocket);
801}
802
803////////////////////////////////////////////////////////////////
804
805class ChannelSocket::Impl
806{
807public:
808 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
809 const std::string& name,
810 const uint16_t& channel,
811 bool isInitiator,
812 std::function<void()> rmFromMxSockCb)
813 : name(name)
814 , channel(channel)
815 , endpoint(std::move(endpoint))
816 , isInitiator_(isInitiator)
817 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
818 {}
819
820 ~Impl() {}
821
822 ChannelReadyCb readyCb_ {};
823 OnShutdownCb shutdownCb_ {};
824 std::atomic_bool isShutdown_ {false};
825 std::string name {};
826 uint16_t channel {};
827 std::weak_ptr<MultiplexedSocket> endpoint {};
828 bool isInitiator_ {false};
829 std::function<void()> rmFromMxSockCb_;
830
831 bool isAnswered_ {false};
832 bool isRemovable_ {false};
833
834 std::vector<uint8_t> buf {};
835 std::mutex mutex {};
836 std::condition_variable cv {};
837 GenericSocket<uint8_t>::RecvCb cb {};
838};
839
840ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
841 const DeviceId& deviceId,
842 const std::string& name,
843 const uint16_t& channel)
844 : pimpl_deviceId(deviceId)
845 , pimpl_name(name)
846 , pimpl_channel(channel)
847 , ioCtx_(*ctx)
848{}
849
850ChannelSocketTest::~ChannelSocketTest() {}
851
852void
853ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
854 const std::shared_ptr<ChannelSocketTest>& socket2)
855{
856 socket1->remote = socket2;
857 socket2->remote = socket1;
858}
859
860DeviceId
861ChannelSocketTest::deviceId() const
862{
863 return pimpl_deviceId;
864}
865
866std::string
867ChannelSocketTest::name() const
868{
869 return pimpl_name;
870}
871
872uint16_t
873ChannelSocketTest::channel() const
874{
875 return pimpl_channel;
876}
877
878void
879ChannelSocketTest::shutdown()
880{
881 {
882 std::unique_lock<std::mutex> lk {mutex};
883 if (!isShutdown_.exchange(true)) {
884 lk.unlock();
885 shutdownCb_();
886 }
887 cv.notify_all();
888 }
889
890 if (auto peer = remote.lock()) {
891 if (!peer->isShutdown_.exchange(true)) {
892 peer->shutdownCb_();
893 }
894 peer->cv.notify_all();
895 }
896}
897
898std::size_t
899ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
900{
901 std::size_t size = std::min(len, this->rx_buf.size());
902
903 for (std::size_t i = 0; i < size; ++i)
904 buf[i] = this->rx_buf[i];
905
906 if (size == this->rx_buf.size()) {
907 this->rx_buf.clear();
908 } else
909 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
910 return size;
911}
912
913std::size_t
914ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
915{
916 if (isShutdown_) {
917 ec = std::make_error_code(std::errc::broken_pipe);
918 return -1;
919 }
920 ec = {};
921 dht::ThreadPool::computation().run(
922 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
923 if (auto peer = r.lock())
924 peer->onRecv(std::move(data));
925 });
926 return len;
927}
928
929int
930ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
931{
932 std::unique_lock<std::mutex> lk {mutex};
933 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
934 return rx_buf.size();
935}
936
937void
938ChannelSocketTest::setOnRecv(RecvCb&& cb)
939{
940 std::lock_guard<std::mutex> lkSockets(mutex);
941 this->cb = std::move(cb);
942 if (!rx_buf.empty() && this->cb) {
943 this->cb(rx_buf.data(), rx_buf.size());
944 rx_buf.clear();
945 }
946}
947
948void
949ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
950{
951 std::lock_guard<std::mutex> lkSockets(mutex);
952 if (cb) {
953 cb(pkt.data(), pkt.size());
954 return;
955 }
956 rx_buf.insert(rx_buf.end(),
957 std::make_move_iterator(pkt.begin()),
958 std::make_move_iterator(pkt.end()));
959 cv.notify_all();
960}
961
962void
963ChannelSocketTest::onReady(ChannelReadyCb&& cb)
964{}
965
966void
967ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
968{
969 std::unique_lock<std::mutex> lk {mutex};
970 shutdownCb_ = std::move(cb);
971
972 if (isShutdown_) {
973 lk.unlock();
974 shutdownCb_();
975 }
976}
977
978ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
979 const std::string& name,
980 const uint16_t& channel,
981 bool isInitiator,
982 std::function<void()> rmFromMxSockCb)
983 : pimpl_ {
984 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
985{}
986
987ChannelSocket::~ChannelSocket() {}
988
989DeviceId
990ChannelSocket::deviceId() const
991{
992 if (auto ep = pimpl_->endpoint.lock()) {
993 return ep->deviceId();
994 }
995 return {};
996}
997
998std::string
999ChannelSocket::name() const
1000{
1001 return pimpl_->name;
1002}
1003
1004uint16_t
1005ChannelSocket::channel() const
1006{
1007 return pimpl_->channel;
1008}
1009
1010bool
1011ChannelSocket::isReliable() const
1012{
1013 if (auto ep = pimpl_->endpoint.lock()) {
1014 return ep->isReliable();
1015 }
1016 return false;
1017}
1018
1019bool
1020ChannelSocket::isInitiator() const
1021{
1022 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1023 // because a multiplexed socket can have sockets from accepted requests
1024 // or made via connectDevice(). Here, isInitiator_ return if the socket
1025 // is from connectDevice.
1026 return pimpl_->isInitiator_;
1027}
1028
1029int
1030ChannelSocket::maxPayload() const
1031{
1032 if (auto ep = pimpl_->endpoint.lock()) {
1033 return ep->maxPayload();
1034 }
1035 return -1;
1036}
1037
1038void
1039ChannelSocket::setOnRecv(RecvCb&& cb)
1040{
1041 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1042 pimpl_->cb = std::move(cb);
1043 if (!pimpl_->buf.empty() && pimpl_->cb) {
1044 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1045 pimpl_->buf.clear();
1046 }
1047}
1048
1049void
1050ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1051{
1052 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1053 if (pimpl_->cb) {
1054 pimpl_->cb(&pkt[0], pkt.size());
1055 return;
1056 }
1057 pimpl_->buf.insert(pimpl_->buf.end(),
1058 std::make_move_iterator(pkt.begin()),
1059 std::make_move_iterator(pkt.end()));
1060 pimpl_->cv.notify_all();
1061}
1062
#ifdef DHTNET_TESTABLE
/** Parent multiplexed socket, or an empty pointer if it no longer exists. */
std::shared_ptr<MultiplexedSocket>
ChannelSocket::underlyingSocket() const
{
    return pimpl_->endpoint.lock();
}
#endif
1072
1073void
1074ChannelSocket::answered()
1075{
1076 pimpl_->isAnswered_ = true;
1077}
1078
1079void
1080ChannelSocket::removable()
1081{
1082 pimpl_->isRemovable_ = true;
1083}
1084
1085bool
1086ChannelSocket::isRemovable() const
1087{
1088 return pimpl_->isRemovable_;
1089}
1090
1091bool
1092ChannelSocket::isAnswered() const
1093{
1094 return pimpl_->isAnswered_;
1095}
1096
1097void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001098ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001099{
1100 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001101 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001102}
1103
1104void
1105ChannelSocket::stop()
1106{
1107 if (pimpl_->isShutdown_)
1108 return;
1109 pimpl_->isShutdown_ = true;
1110 if (pimpl_->shutdownCb_)
1111 pimpl_->shutdownCb_();
1112 pimpl_->cv.notify_all();
1113 // stop() can be called by ChannelSocket::shutdown()
1114 // In this case, the eventLoop is not used, but MxSock
1115 // must remove the channel from its list (so that the
1116 // channel can be destroyed and its shared_ptr invalidated).
1117 if (pimpl_->rmFromMxSockCb_)
1118 pimpl_->rmFromMxSockCb_();
1119}
1120
1121void
1122ChannelSocket::shutdown()
1123{
1124 if (pimpl_->isShutdown_)
1125 return;
1126 stop();
1127 if (auto ep = pimpl_->endpoint.lock()) {
1128 std::error_code ec;
1129 const uint8_t dummy = '\0';
1130 ep->write(pimpl_->channel, &dummy, 0, ec);
1131 }
1132}
1133
1134std::size_t
1135ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1136{
1137 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1138 std::size_t size = std::min(len, pimpl_->buf.size());
1139
1140 for (std::size_t i = 0; i < size; ++i)
1141 outBuf[i] = pimpl_->buf[i];
1142
1143 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1144 return size;
1145}
1146
1147std::size_t
1148ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1149{
1150 if (pimpl_->isShutdown_) {
1151 ec = std::make_error_code(std::errc::broken_pipe);
1152 return -1;
1153 }
1154 if (auto ep = pimpl_->endpoint.lock()) {
1155 std::size_t sent = 0;
1156 do {
1157 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1158 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1159 if (ec) {
1160 if (ep->logger())
1161 ep->logger()->error("Error when writing on channel: {}", ec.message());
1162 return res;
1163 }
1164 sent += toSend;
1165 } while (sent < len);
1166 return sent;
1167 }
1168 ec = std::make_error_code(std::errc::broken_pipe);
1169 return -1;
1170}
1171
1172int
1173ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1174{
1175 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1176 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1177 return pimpl_->buf.size();
1178}
1179
1180void
1181ChannelSocket::onShutdown(OnShutdownCb&& cb)
1182{
1183 pimpl_->shutdownCb_ = std::move(cb);
1184 if (pimpl_->isShutdown_) {
1185 pimpl_->shutdownCb_();
1186 }
1187}
1188
1189void
1190ChannelSocket::onReady(ChannelReadyCb&& cb)
1191{
1192 pimpl_->readyCb_ = std::move(cb);
1193}
1194
1195void
1196ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1197{
1198 if (auto ep = pimpl_->endpoint.lock()) {
1199 ep->sendBeacon(timeout);
1200 } else {
1201 shutdown();
1202 }
1203}
1204
1205std::shared_ptr<dht::crypto::Certificate>
1206ChannelSocket::peerCertificate() const
1207{
1208 if (auto ep = pimpl_->endpoint.lock())
1209 return ep->peerCertificate();
1210 return {};
1211}
1212
1213IpAddr
1214ChannelSocket::getLocalAddress() const
1215{
1216 if (auto ep = pimpl_->endpoint.lock())
1217 return ep->getLocalAddress();
1218 return {};
1219}
1220
1221IpAddr
1222ChannelSocket::getRemoteAddress() const
1223{
1224 if (auto ep = pimpl_->endpoint.lock())
1225 return ep->getRemoteAddress();
1226 return {};
1227}
1228
Amna31791e52023-08-03 12:40:57 -04001229std::vector<std::map<std::string, std::string>>
1230MultiplexedSocket::getChannelList() const
1231{
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001232 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
Amna31791e52023-08-03 12:40:57 -04001233 std::vector<std::map<std::string, std::string>> channelsList;
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001234 channelsList.reserve(pimpl_->sockets.size());
Amna31791e52023-08-03 12:40:57 -04001235 for (const auto& [_, channel] : pimpl_->sockets) {
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001236 channelsList.emplace_back(std::map<std::string, std::string> {
1237 {"id", fmt::format("{:x}", channel->channel())},
1238 {"name", channel->name()},
1239 });
Amna31791e52023-08-03 12:40:57 -04001240 }
Amna31791e52023-08-03 12:40:57 -04001241 return channelsList;
1242}
1243
Sébastien Blin464bdff2023-07-19 08:02:53 -04001244} // namespace dhtnet