blob: 910e31e87bcf9dfa043bfb710f3b6c291b004558 [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
Adrien Béraudcb753622023-07-17 22:32:49 -04002 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
Adrien Béraud612b55b2023-05-29 10:42:04 -04003 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "multiplexed_socket.h"
18#include "peer_connection.h"
19#include "ice_transport.h"
20#include "certstore.h"
21
22#include <opendht/logger.h>
23#include <opendht/thread_pool.h>
24
25#include <asio/io_context.hpp>
26#include <asio/steady_timer.hpp>
27
28#include <deque>
29
/// Number of bytes requested from the TLS endpoint per read() in the event loop.
constexpr std::size_t IO_BUFFER_SIZE = 8192;
/// Protocol version advertised to peers (sent as a VersionMsg).
constexpr int MULTIPLEXED_SOCKET_VERSION = 1;
32
33struct ChanneledMessage
34{
35 uint16_t channel;
36 std::vector<uint8_t> data;
37 MSGPACK_DEFINE(channel, data)
38};
39
40struct BeaconMsg
41{
42 bool p;
43 MSGPACK_DEFINE_MAP(p)
44};
45
46struct VersionMsg
47{
48 int v;
49 MSGPACK_DEFINE_MAP(v)
50};
51
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040052namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040053
54using clock = std::chrono::steady_clock;
55using time_point = clock::time_point;
56
57class MultiplexedSocket::Impl
58{
59public:
60 Impl(MultiplexedSocket& parent,
61 std::shared_ptr<asio::io_context> ctx,
62 const DeviceId& deviceId,
Adrien Béraud55133cc2023-10-15 11:55:20 -040063 std::unique_ptr<TlsSocketEndpoint> ep,
Adrien Béraud5636f7c2023-09-14 14:34:57 -040064 std::shared_ptr<dht::log::Logger> logger)
Adrien Béraud612b55b2023-05-29 10:42:04 -040065 : parent_(parent)
Adrien Béraud612b55b2023-05-29 10:42:04 -040066 , ctx_(std::move(ctx))
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040067 , deviceId(deviceId)
Adrien Béraud55133cc2023-10-15 11:55:20 -040068 , endpoint(std::move(ep))
69 , nextChannel_(endpoint->isInitiator() ? 0x0001u : 0x8000u)
Adrien Béraud18d344d2024-03-02 21:39:22 -050070 , logger_(std::move(logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -040071 , eventLoopThread_ {[this] {
72 try {
73 eventLoop();
74 } catch (const std::exception& e) {
75 if (logger_)
76 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
77 shutdown();
78 }
79 }}
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040080 , beaconTimer_(*ctx_)
Adrien Béraud612b55b2023-05-29 10:42:04 -040081 {}
82
83 ~Impl() {}
84
85 void join()
86 {
87 if (!isShutdown_) {
88 if (endpoint)
89 endpoint->setOnStateChange({});
90 shutdown();
91 } else {
92 clearSockets();
93 }
94 if (eventLoopThread_.joinable())
95 eventLoopThread_.join();
96 }
97
98 void clearSockets()
99 {
100 decltype(sockets) socks;
101 {
Adrien Béraud024c46f2024-03-02 23:53:18 -0500102 std::lock_guard lkSockets(socketsMutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400103 socks = std::move(sockets);
104 }
105 for (auto& socket : socks) {
106 // Just trigger onShutdown() to make client know
107 // No need to write the EOF for the channel, the write will fail because endpoint is
108 // already shutdown
109 if (socket.second)
110 socket.second->stop();
111 }
112 }
113
114 void shutdown()
115 {
116 if (isShutdown_)
117 return;
118 stop.store(true);
119 isShutdown_ = true;
120 beaconTimer_.cancel();
121 if (onShutdown_)
122 onShutdown_();
123 if (endpoint) {
Adrien Béraud024c46f2024-03-02 23:53:18 -0500124 std::unique_lock lk(writeMtx);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400125 endpoint->shutdown();
126 }
127 clearSockets();
128 }
129
Adrien Béraud1aaaa962024-09-26 12:41:58 -0400130 bool isRunning() const {
131 return !isShutdown_ && !stop;
132 }
133
Adrien Béraud612b55b2023-05-29 10:42:04 -0400134 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
135 uint16_t channel,
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400136 bool isInitiator)
Adrien Béraud612b55b2023-05-29 10:42:04 -0400137 {
138 auto& channelSocket = sockets[channel];
139 if (not channelSocket)
140 channelSocket = std::make_shared<ChannelSocket>(
141 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
142 // Remove socket in another thread to avoid any lock
143 dht::ThreadPool::io().run([w, channel]() {
144 if (auto shared = w.lock()) {
145 shared->eraseChannel(channel);
146 }
147 });
148 });
149 else {
150 if (logger_)
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400151 logger_->warn("Received request for existing channel {}", channel);
152 return {};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400153 }
154 return channelSocket;
155 }
156
157 /**
158 * Handle packets on the TLS endpoint and parse RTP
159 */
160 void eventLoop();
161 /**
162 * Triggered when a new control packet is received
163 */
164 void handleControlPacket(std::vector<uint8_t>&& pkt);
165 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
166 bool handleProtocolMsg(const msgpack::object& o);
167 /**
168 * Triggered when a new packet on a channel is received
169 */
170 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
171 void onRequest(const std::string& name, uint16_t channel);
172 void onAccept(const std::string& name, uint16_t channel);
173
174 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
175 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
176
177 // Beacon
178 void sendBeacon(const std::chrono::milliseconds& timeout);
179 void handleBeaconRequest();
180 void handleBeaconResponse();
181 std::atomic_int beaconCounter_ {0};
182
183 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
184
185 msgpack::unpacker pac_ {};
186
187 MultiplexedSocket& parent_;
188
189 std::shared_ptr<Logger> logger_;
190 std::shared_ptr<asio::io_context> ctx_;
191
192 OnConnectionReadyCb onChannelReady_ {};
193 OnConnectionRequestCb onRequest_ {};
194 OnShutdownCb onShutdown_ {};
195
196 DeviceId deviceId {};
197 // Main socket
198 std::unique_ptr<TlsSocketEndpoint> endpoint {};
199
200 std::mutex socketsMutex {};
201 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
Adrien Béraud55133cc2023-10-15 11:55:20 -0400202 uint16_t nextChannel_;
Adrien Béraud612b55b2023-05-29 10:42:04 -0400203
204 // Main loop to parse incoming packets
205 std::atomic_bool stop {false};
206 std::thread eventLoopThread_ {};
207
208 std::atomic_bool isShutdown_ {false};
209
210 std::mutex writeMtx {};
211
212 time_point start_ {clock::now()};
213 //std::shared_ptr<Task> beaconTask_ {};
214 asio::steady_timer beaconTimer_;
215
216 // version related stuff
217 void sendVersion();
218 void onVersion(int version);
219 std::atomic_bool canSendBeacon_ {false};
220 std::atomic_bool answerBeacon_ {true};
221 int version_ {MULTIPLEXED_SOCKET_VERSION};
222 std::function<void(bool)> onBeaconCb_ {};
223 std::function<void(int)> onVersionCb_ {};
224};
225
226void
227MultiplexedSocket::Impl::eventLoop()
228{
229 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
230 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
231 if (logger_)
232 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
233 shutdown();
234 return false;
235 }
236 return true;
237 });
238 sendVersion();
239 std::error_code ec;
240 while (!stop) {
241 if (!endpoint) {
242 shutdown();
243 return;
244 }
245 pac_.reserve_buffer(IO_BUFFER_SIZE);
246 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
247 if (size < 0) {
248 if (ec && logger_)
249 logger_->error("Read error detected: {}", ec.message());
250 break;
251 }
252 if (size == 0) {
253 // We can close the socket
254 shutdown();
255 break;
256 }
257
258 pac_.buffer_consumed(size);
259 msgpack::object_handle oh;
260 while (pac_.next(oh) && !stop) {
261 try {
262 auto msg = oh.get().as<ChanneledMessage>();
263 if (msg.channel == CONTROL_CHANNEL)
264 handleControlPacket(std::move(msg.data));
265 else if (msg.channel == PROTOCOL_CHANNEL)
266 handleProtocolPacket(std::move(msg.data));
267 else
268 handleChannelPacket(msg.channel, std::move(msg.data));
269 } catch (const std::exception& e) {
270 if (logger_)
271 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
272 } catch (...) {
273 if (logger_)
274 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
275 }
276 }
277 }
278}
279
280void
281MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
282{
Adrien Béraud024c46f2024-03-02 23:53:18 -0500283 std::lock_guard lkSockets(socketsMutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400284 auto& socket = sockets[channel];
285 if (!socket) {
286 if (logger_)
287 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
288 return;
289 }
290
291 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400292 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400293 // Due to the callbacks that can take some time, onAccept can arrive after
294 // receiving all the data. In this case, the socket should be removed here
295 // as handle by onChannelReady_
296 if (socket->isRemovable())
297 sockets.erase(channel);
298 else
299 socket->answered();
300}
301
302void
303MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
304{
305 if (!canSendBeacon_)
306 return;
307 beaconCounter_++;
308 if (logger_)
309 logger_->debug("Send beacon to peer {}", deviceId);
310
311 msgpack::sbuffer buffer(8);
312 msgpack::packer<msgpack::sbuffer> pk(&buffer);
313 pk.pack(BeaconMsg {true});
314 if (!writeProtocolMessage(buffer))
315 return;
316 beaconTimer_.expires_after(timeout);
317 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
318 if (ec == asio::error::operation_aborted)
319 return;
320 if (auto shared = w.lock()) {
321 if (shared->pimpl_->beaconCounter_ != 0) {
322 if (shared->pimpl_->logger_)
323 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
324 shared->shutdown();
325 }
326 }
327 });
328}
329
330void
331MultiplexedSocket::Impl::handleBeaconRequest()
332{
333 if (!answerBeacon_)
334 return;
335 // Run this on dedicated thread because some callbacks can take time
336 dht::ThreadPool::io().run([w = parent_.weak()]() {
337 if (auto shared = w.lock()) {
338 msgpack::sbuffer buffer(8);
339 msgpack::packer<msgpack::sbuffer> pk(&buffer);
340 pk.pack(BeaconMsg {false});
341 if (shared->pimpl_->logger_)
342 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
343 shared->pimpl_->writeProtocolMessage(buffer);
344 }
345 });
346}
347
348void
349MultiplexedSocket::Impl::handleBeaconResponse()
350{
351 if (logger_)
352 logger_->debug("Get beacon response from peer {}", deviceId);
353 beaconCounter_--;
354}
355
356bool
357MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
358{
359 std::error_code ec;
360 int wr = parent_.write(PROTOCOL_CHANNEL,
361 (const unsigned char*) buffer.data(),
362 buffer.size(),
363 ec);
364 return wr > 0;
365}
366
367void
368MultiplexedSocket::Impl::sendVersion()
369{
370 dht::ThreadPool::io().run([w = parent_.weak()]() {
371 if (auto shared = w.lock()) {
372 auto version = shared->pimpl_->version_;
373 msgpack::sbuffer buffer(8);
374 msgpack::packer<msgpack::sbuffer> pk(&buffer);
375 pk.pack(VersionMsg {version});
376 shared->pimpl_->writeProtocolMessage(buffer);
377 }
378 });
379}
380
381void
382MultiplexedSocket::Impl::onVersion(int version)
383{
384 // Check if version > 1
385 if (version >= 1) {
386 if (logger_)
387 logger_->debug("Peer {} supports beacon", deviceId);
388 canSendBeacon_ = true;
389 } else {
390 if (logger_)
391 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
392 deviceId,
393 version);
394 canSendBeacon_ = false;
395 }
396}
397
398void
399MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
400{
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400401 bool accept;
402 if (channel == CONTROL_CHANNEL || channel == PROTOCOL_CHANNEL) {
403 if (logger_)
404 logger_->warn("Channel {:d} is reserved, refusing request", channel);
405 accept = false;
406 } else
407 accept = onRequest_(endpoint->peerCertificate(), channel, name);
408
Adrien Béraud612b55b2023-05-29 10:42:04 -0400409 std::shared_ptr<ChannelSocket> channelSocket;
410 if (accept) {
Adrien Béraud024c46f2024-03-02 23:53:18 -0500411 std::lock_guard lkSockets(socketsMutex);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400412 channelSocket = makeSocket(name, channel, false);
413 if (not channelSocket) {
414 if (logger_)
415 logger_->error("Channel {:d} already exists, refusing request", channel);
416 accept = false;
417 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400418 }
419
420 // Answer to ChannelRequest if accepted
421 ChannelRequest val;
422 val.channel = channel;
423 val.name = name;
424 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
425 msgpack::sbuffer buffer(512);
426 msgpack::pack(buffer, val);
427 std::error_code ec;
428 int wr = parent_.write(CONTROL_CHANNEL,
429 reinterpret_cast<const uint8_t*>(buffer.data()),
430 buffer.size(),
431 ec);
432 if (wr < 0) {
433 if (ec && logger_)
434 logger_->error("The write operation failed with error: {:s}", ec.message());
435 stop.store(true);
436 return;
437 }
438
439 if (accept) {
440 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400441 channelSocket->ready(true);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400442 if (channelSocket->isRemovable()) {
Adrien Béraud024c46f2024-03-02 23:53:18 -0500443 std::lock_guard lkSockets(socketsMutex);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400444 sockets.erase(channel);
445 } else
446 channelSocket->answered();
Adrien Béraud612b55b2023-05-29 10:42:04 -0400447 }
448}
449
450void
451MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
452{
Adrien Béraudcd7a7bd2023-10-15 12:54:30 -0400453 try {
454 size_t off = 0;
455 while (off != pkt.size()) {
456 msgpack::unpacked result;
457 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
458 auto object = result.get();
459 if (handleProtocolMsg(object))
460 continue;
461 auto req = object.as<ChannelRequest>();
462 if (req.state == ChannelRequestState::REQUEST) {
463 dht::ThreadPool::io().run([w = parent_.weak(), req = std::move(req)]() {
464 if (auto shared = w.lock())
465 shared->pimpl_->onRequest(req.name, req.channel);
466 });
467 }
468 else if (req.state == ChannelRequestState::ACCEPT) {
469 onAccept(req.name, req.channel);
470 } else {
471 // DECLINE or unknown
Adrien Béraud024c46f2024-03-02 23:53:18 -0500472 std::lock_guard lkSockets(socketsMutex);
Adrien Béraudcd7a7bd2023-10-15 12:54:30 -0400473 auto channel = sockets.find(req.channel);
474 if (channel != sockets.end()) {
475 channel->second->ready(false);
476 channel->second->stop();
477 sockets.erase(channel);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400478 }
479 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400480 }
Adrien Béraudcd7a7bd2023-10-15 12:54:30 -0400481 } catch (const std::exception& e) {
482 if (logger_)
483 logger_->error("Error on the control channel: {}", e.what());
484 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400485}
486
487void
488MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
489{
Adrien Béraud024c46f2024-03-02 23:53:18 -0500490 std::lock_guard lkSockets(socketsMutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400491 auto sockIt = sockets.find(channel);
492 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
493 if (pkt.size() == 0) {
494 sockIt->second->stop();
495 if (sockIt->second->isAnswered())
496 sockets.erase(sockIt);
497 else
498 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
499 // removed later.
500 } else {
501 sockIt->second->onRecv(std::move(pkt));
502 }
503 } else if (pkt.size() != 0) {
504 if (logger_)
505 logger_->warn("Non existing channel: {}", channel);
506 }
507}
508
509bool
510MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
511{
512 try {
513 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
514 auto key = o.via.map.ptr[0].key.as<std::string_view>();
515 if (key == "p") {
516 auto msg = o.as<BeaconMsg>();
517 if (msg.p)
518 handleBeaconRequest();
519 else
520 handleBeaconResponse();
521 if (onBeaconCb_)
522 onBeaconCb_(msg.p);
523 return true;
524 } else if (key == "v") {
525 auto msg = o.as<VersionMsg>();
526 onVersion(msg.v);
527 if (onVersionCb_)
528 onVersionCb_(msg.v);
529 return true;
530 } else {
531 if (logger_)
532 logger_->warn("Unknown message type");
533 }
534 }
535 } catch (const std::exception& e) {
536 if (logger_)
537 logger_->error("Error on the protocol channel: {}", e.what());
538 }
539 return false;
540}
541
542void
543MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
544{
545 // Run this on dedicated thread because some callbacks can take time
546 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
547 auto shared = w.lock();
548 if (!shared)
549 return;
550 try {
551 size_t off = 0;
552 while (off != pkt.size()) {
553 msgpack::unpacked result;
554 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
555 auto object = result.get();
556 if (shared->pimpl_->handleProtocolMsg(object))
557 return;
558 }
559 } catch (const std::exception& e) {
560 if (shared->pimpl_->logger_)
561 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
562 }
563 });
564}
565
566MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
Adrien Béraud5636f7c2023-09-14 14:34:57 -0400567 std::unique_ptr<TlsSocketEndpoint> endpoint, std::shared_ptr<dht::log::Logger> logger)
568 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint), logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -0400569{}
570
571MultiplexedSocket::~MultiplexedSocket() {}
572
573std::shared_ptr<ChannelSocket>
574MultiplexedSocket::addChannel(const std::string& name)
575{
Adrien Béraud024c46f2024-03-02 23:53:18 -0500576 std::lock_guard lk(pimpl_->socketsMutex);
Adrien Béraud55133cc2023-10-15 11:55:20 -0400577 if (pimpl_->sockets.size() < UINT16_MAX)
578 for (unsigned i = 0; i < UINT16_MAX; ++i) {
579 auto c = pimpl_->nextChannel_++;
580 if (c == CONTROL_CHANNEL
581 || c == PROTOCOL_CHANNEL
582 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
583 continue;
584 return pimpl_->makeSocket(name, c, true);
585 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400586 return {};
587}
588
589DeviceId
590MultiplexedSocket::deviceId() const
591{
592 return pimpl_->deviceId;
593}
594
595void
596MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
597{
598 pimpl_->onChannelReady_ = std::move(cb);
599}
600
601void
602MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
603{
604 pimpl_->onRequest_ = std::move(cb);
605}
606
607bool
608MultiplexedSocket::isReliable() const
609{
610 return true;
611}
612
613bool
614MultiplexedSocket::isInitiator() const
615{
616 if (!pimpl_->endpoint) {
617 if (pimpl_->logger_)
618 pimpl_->logger_->warn("No endpoint found for socket");
619 return false;
620 }
621 return pimpl_->endpoint->isInitiator();
622}
623
624int
625MultiplexedSocket::maxPayload() const
626{
627 if (!pimpl_->endpoint) {
628 if (pimpl_->logger_)
629 pimpl_->logger_->warn("No endpoint found for socket");
630 return 0;
631 }
632 return pimpl_->endpoint->maxPayload();
633}
634
635std::size_t
636MultiplexedSocket::write(const uint16_t& channel,
637 const uint8_t* buf,
638 std::size_t len,
639 std::error_code& ec)
640{
641 assert(nullptr != buf);
642
643 if (pimpl_->isShutdown_) {
644 ec = std::make_error_code(std::errc::broken_pipe);
645 return -1;
646 }
647 if (len > UINT16_MAX) {
648 ec = std::make_error_code(std::errc::message_size);
649 return -1;
650 }
651 bool oneShot = len < 8192;
652 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
653 msgpack::packer<msgpack::sbuffer> pk(&buffer);
654 pk.pack_array(2);
655 pk.pack(channel);
656 pk.pack_bin(len);
657 if (oneShot)
658 pk.pack_bin_body((const char*) buf, len);
659
Adrien Béraud024c46f2024-03-02 23:53:18 -0500660 std::unique_lock lk(pimpl_->writeMtx);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400661 if (!pimpl_->endpoint) {
662 if (pimpl_->logger_)
663 pimpl_->logger_->warn("No endpoint found for socket");
664 ec = std::make_error_code(std::errc::broken_pipe);
665 return -1;
666 }
667 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
668 if (not oneShot and res >= 0)
669 res = pimpl_->endpoint->write(buf, len, ec);
670 lk.unlock();
671 if (res < 0) {
672 if (ec && pimpl_->logger_)
673 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
674 shutdown();
675 }
676 return res;
677}
678
679void
680MultiplexedSocket::shutdown()
681{
682 pimpl_->shutdown();
683}
684
Adrien Béraud1aaaa962024-09-26 12:41:58 -0400685bool
686MultiplexedSocket::isRunning() const
687{
688 return pimpl_->isRunning();
689}
690
Adrien Béraud612b55b2023-05-29 10:42:04 -0400691void
692MultiplexedSocket::join()
693{
694 pimpl_->join();
695}
696
697void
698MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
699{
700 pimpl_->onShutdown_ = std::move(cb);
701 if (pimpl_->isShutdown_)
702 pimpl_->onShutdown_();
703}
704
705const std::shared_ptr<Logger>&
706MultiplexedSocket::logger()
707{
708 return pimpl_->logger_;
709}
710
711void
712MultiplexedSocket::monitor() const
713{
714 auto cert = peerCertificate();
715 if (!cert || !cert->issuer)
716 return;
717 auto now = clock::now();
718 if (!pimpl_->logger_)
719 return;
720 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
721 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
722 pimpl_->endpoint->monitor();
Adrien Béraud024c46f2024-03-02 23:53:18 -0500723 std::lock_guard lk(pimpl_->socketsMutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400724 for (const auto& [_, channel] : pimpl_->sockets) {
725 if (channel)
726 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
727 fmt::ptr(channel.get()),
728 channel.use_count(),
729 channel->name(),
730 channel->isInitiator());
731 }
732}
733
734void
735MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
736{
737 pimpl_->sendBeacon(timeout);
738}
739
740std::shared_ptr<dht::crypto::Certificate>
741MultiplexedSocket::peerCertificate() const
742{
743 return pimpl_->endpoint->peerCertificate();
744}
745
Adrien Béraud6b6a5d32023-08-15 15:53:33 -0400746#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -0400747bool
748MultiplexedSocket::canSendBeacon() const
749{
750 return pimpl_->canSendBeacon_;
751}
752
753void
754MultiplexedSocket::answerToBeacon(bool value)
755{
756 pimpl_->answerBeacon_ = value;
757}
758
759void
760MultiplexedSocket::setVersion(int version)
761{
762 pimpl_->version_ = version;
763}
764
765void
766MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
767{
768 pimpl_->onBeaconCb_ = cb;
769}
770
771void
772MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
773{
774 pimpl_->onVersionCb_ = cb;
775}
776
777void
778MultiplexedSocket::sendVersion()
779{
780 pimpl_->sendVersion();
781}
782
Adrien Béraudac35e662023-07-19 09:37:29 -0400783#endif
784
Adrien Béraud612b55b2023-05-29 10:42:04 -0400785IpAddr
786MultiplexedSocket::getLocalAddress() const
787{
788 return pimpl_->endpoint->getLocalAddress();
789}
790
791IpAddr
792MultiplexedSocket::getRemoteAddress() const
793{
794 return pimpl_->endpoint->getRemoteAddress();
795}
796
Adrien Béraudafa8e282023-09-24 12:53:20 -0400797TlsSocketEndpoint*
798MultiplexedSocket::endpoint()
799{
800 return pimpl_->endpoint.get();
801}
802
Adrien Béraud612b55b2023-05-29 10:42:04 -0400803void
804MultiplexedSocket::eraseChannel(uint16_t channel)
805{
Adrien Béraud024c46f2024-03-02 23:53:18 -0500806 std::lock_guard lkSockets(pimpl_->socketsMutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400807 auto itSocket = pimpl_->sockets.find(channel);
808 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
809 pimpl_->sockets.erase(itSocket);
810}
811
812////////////////////////////////////////////////////////////////
813
814class ChannelSocket::Impl
815{
816public:
817 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
818 const std::string& name,
819 const uint16_t& channel,
820 bool isInitiator,
821 std::function<void()> rmFromMxSockCb)
822 : name(name)
823 , channel(channel)
824 , endpoint(std::move(endpoint))
825 , isInitiator_(isInitiator)
826 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
827 {}
828
829 ~Impl() {}
830
831 ChannelReadyCb readyCb_ {};
832 OnShutdownCb shutdownCb_ {};
833 std::atomic_bool isShutdown_ {false};
Adrien Béraud5fd233f2023-10-15 12:35:48 -0400834 const std::string name {};
835 const uint16_t channel {};
836 const std::weak_ptr<MultiplexedSocket> endpoint {};
837 const bool isInitiator_ {false};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400838 std::function<void()> rmFromMxSockCb_;
839
840 bool isAnswered_ {false};
841 bool isRemovable_ {false};
842
843 std::vector<uint8_t> buf {};
844 std::mutex mutex {};
845 std::condition_variable cv {};
846 GenericSocket<uint8_t>::RecvCb cb {};
847};
848
849ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
850 const DeviceId& deviceId,
851 const std::string& name,
852 const uint16_t& channel)
853 : pimpl_deviceId(deviceId)
854 , pimpl_name(name)
855 , pimpl_channel(channel)
856 , ioCtx_(*ctx)
857{}
858
859ChannelSocketTest::~ChannelSocketTest() {}
860
861void
862ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
863 const std::shared_ptr<ChannelSocketTest>& socket2)
864{
865 socket1->remote = socket2;
866 socket2->remote = socket1;
867}
868
869DeviceId
870ChannelSocketTest::deviceId() const
871{
872 return pimpl_deviceId;
873}
874
875std::string
876ChannelSocketTest::name() const
877{
878 return pimpl_name;
879}
880
881uint16_t
882ChannelSocketTest::channel() const
883{
884 return pimpl_channel;
885}
886
887void
888ChannelSocketTest::shutdown()
889{
890 {
Adrien Béraud024c46f2024-03-02 23:53:18 -0500891 std::unique_lock lk {mutex};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400892 if (!isShutdown_.exchange(true)) {
893 lk.unlock();
894 shutdownCb_();
895 }
896 cv.notify_all();
897 }
898
899 if (auto peer = remote.lock()) {
900 if (!peer->isShutdown_.exchange(true)) {
901 peer->shutdownCb_();
902 }
903 peer->cv.notify_all();
904 }
905}
906
907std::size_t
908ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
909{
910 std::size_t size = std::min(len, this->rx_buf.size());
911
912 for (std::size_t i = 0; i < size; ++i)
913 buf[i] = this->rx_buf[i];
914
915 if (size == this->rx_buf.size()) {
916 this->rx_buf.clear();
917 } else
918 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
919 return size;
920}
921
922std::size_t
923ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
924{
925 if (isShutdown_) {
926 ec = std::make_error_code(std::errc::broken_pipe);
927 return -1;
928 }
929 ec = {};
930 dht::ThreadPool::computation().run(
931 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
932 if (auto peer = r.lock())
933 peer->onRecv(std::move(data));
934 });
935 return len;
936}
937
938int
939ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
940{
Adrien Béraud024c46f2024-03-02 23:53:18 -0500941 std::unique_lock lk {mutex};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400942 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
943 return rx_buf.size();
944}
945
946void
947ChannelSocketTest::setOnRecv(RecvCb&& cb)
948{
Adrien Béraud024c46f2024-03-02 23:53:18 -0500949 std::lock_guard lkSockets(mutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400950 this->cb = std::move(cb);
951 if (!rx_buf.empty() && this->cb) {
952 this->cb(rx_buf.data(), rx_buf.size());
953 rx_buf.clear();
954 }
955}
956
957void
958ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
959{
Adrien Béraud024c46f2024-03-02 23:53:18 -0500960 std::lock_guard lkSockets(mutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400961 if (cb) {
962 cb(pkt.data(), pkt.size());
963 return;
964 }
965 rx_buf.insert(rx_buf.end(),
966 std::make_move_iterator(pkt.begin()),
967 std::make_move_iterator(pkt.end()));
968 cv.notify_all();
969}
970
971void
972ChannelSocketTest::onReady(ChannelReadyCb&& cb)
973{}
974
975void
976ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
977{
Adrien Béraud024c46f2024-03-02 23:53:18 -0500978 std::unique_lock lk {mutex};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400979 shutdownCb_ = std::move(cb);
980
981 if (isShutdown_) {
982 lk.unlock();
983 shutdownCb_();
984 }
985}
986
987ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
988 const std::string& name,
989 const uint16_t& channel,
990 bool isInitiator,
991 std::function<void()> rmFromMxSockCb)
992 : pimpl_ {
993 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
994{}
995
996ChannelSocket::~ChannelSocket() {}
997
998DeviceId
999ChannelSocket::deviceId() const
1000{
1001 if (auto ep = pimpl_->endpoint.lock()) {
1002 return ep->deviceId();
1003 }
1004 return {};
1005}
1006
1007std::string
1008ChannelSocket::name() const
1009{
1010 return pimpl_->name;
1011}
1012
1013uint16_t
1014ChannelSocket::channel() const
1015{
1016 return pimpl_->channel;
1017}
1018
1019bool
1020ChannelSocket::isReliable() const
1021{
1022 if (auto ep = pimpl_->endpoint.lock()) {
1023 return ep->isReliable();
1024 }
1025 return false;
1026}
1027
1028bool
1029ChannelSocket::isInitiator() const
1030{
1031 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1032 // because a multiplexed socket can have sockets from accepted requests
1033 // or made via connectDevice(). Here, isInitiator_ return if the socket
1034 // is from connectDevice.
1035 return pimpl_->isInitiator_;
1036}
1037
1038int
1039ChannelSocket::maxPayload() const
1040{
1041 if (auto ep = pimpl_->endpoint.lock()) {
1042 return ep->maxPayload();
1043 }
1044 return -1;
1045}
1046
1047void
1048ChannelSocket::setOnRecv(RecvCb&& cb)
1049{
Adrien Béraud024c46f2024-03-02 23:53:18 -05001050 std::lock_guard lkSockets(pimpl_->mutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001051 pimpl_->cb = std::move(cb);
1052 if (!pimpl_->buf.empty() && pimpl_->cb) {
1053 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1054 pimpl_->buf.clear();
1055 }
1056}
1057
1058void
1059ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1060{
Adrien Béraud024c46f2024-03-02 23:53:18 -05001061 std::lock_guard lkSockets(pimpl_->mutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001062 if (pimpl_->cb) {
1063 pimpl_->cb(&pkt[0], pkt.size());
1064 return;
1065 }
1066 pimpl_->buf.insert(pimpl_->buf.end(),
1067 std::make_move_iterator(pkt.begin()),
1068 std::make_move_iterator(pkt.end()));
1069 pimpl_->cv.notify_all();
1070}
1071
#ifdef DHTNET_TESTABLE
// Test-only accessor: the owning multiplexed socket, or null if it is gone.
std::shared_ptr<MultiplexedSocket>
ChannelSocket::underlyingSocket() const
{
    return pimpl_->endpoint.lock();
}
#endif
1081
1082void
1083ChannelSocket::answered()
1084{
1085 pimpl_->isAnswered_ = true;
1086}
1087
1088void
1089ChannelSocket::removable()
1090{
1091 pimpl_->isRemovable_ = true;
1092}
1093
1094bool
1095ChannelSocket::isRemovable() const
1096{
1097 return pimpl_->isRemovable_;
1098}
1099
1100bool
1101ChannelSocket::isAnswered() const
1102{
1103 return pimpl_->isAnswered_;
1104}
1105
1106void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001107ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001108{
1109 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001110 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001111}
1112
1113void
1114ChannelSocket::stop()
1115{
1116 if (pimpl_->isShutdown_)
1117 return;
1118 pimpl_->isShutdown_ = true;
1119 if (pimpl_->shutdownCb_)
1120 pimpl_->shutdownCb_();
1121 pimpl_->cv.notify_all();
1122 // stop() can be called by ChannelSocket::shutdown()
1123 // In this case, the eventLoop is not used, but MxSock
1124 // must remove the channel from its list (so that the
1125 // channel can be destroyed and its shared_ptr invalidated).
1126 if (pimpl_->rmFromMxSockCb_)
1127 pimpl_->rmFromMxSockCb_();
1128}
1129
1130void
1131ChannelSocket::shutdown()
1132{
1133 if (pimpl_->isShutdown_)
1134 return;
1135 stop();
1136 if (auto ep = pimpl_->endpoint.lock()) {
1137 std::error_code ec;
1138 const uint8_t dummy = '\0';
1139 ep->write(pimpl_->channel, &dummy, 0, ec);
1140 }
1141}
1142
1143std::size_t
1144ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1145{
Adrien Béraud024c46f2024-03-02 23:53:18 -05001146 std::lock_guard lkSockets(pimpl_->mutex);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001147 std::size_t size = std::min(len, pimpl_->buf.size());
1148
1149 for (std::size_t i = 0; i < size; ++i)
1150 outBuf[i] = pimpl_->buf[i];
1151
1152 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1153 return size;
1154}
1155
1156std::size_t
1157ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1158{
1159 if (pimpl_->isShutdown_) {
1160 ec = std::make_error_code(std::errc::broken_pipe);
1161 return -1;
1162 }
1163 if (auto ep = pimpl_->endpoint.lock()) {
1164 std::size_t sent = 0;
1165 do {
1166 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1167 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1168 if (ec) {
1169 if (ep->logger())
1170 ep->logger()->error("Error when writing on channel: {}", ec.message());
1171 return res;
1172 }
1173 sent += toSend;
1174 } while (sent < len);
1175 return sent;
1176 }
1177 ec = std::make_error_code(std::errc::broken_pipe);
1178 return -1;
1179}
1180
1181int
1182ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1183{
Adrien Béraud024c46f2024-03-02 23:53:18 -05001184 std::unique_lock lk {pimpl_->mutex};
Adrien Béraud612b55b2023-05-29 10:42:04 -04001185 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1186 return pimpl_->buf.size();
1187}
1188
1189void
1190ChannelSocket::onShutdown(OnShutdownCb&& cb)
1191{
1192 pimpl_->shutdownCb_ = std::move(cb);
1193 if (pimpl_->isShutdown_) {
1194 pimpl_->shutdownCb_();
1195 }
1196}
1197
1198void
1199ChannelSocket::onReady(ChannelReadyCb&& cb)
1200{
1201 pimpl_->readyCb_ = std::move(cb);
1202}
1203
1204void
1205ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1206{
1207 if (auto ep = pimpl_->endpoint.lock()) {
1208 ep->sendBeacon(timeout);
1209 } else {
1210 shutdown();
1211 }
1212}
1213
1214std::shared_ptr<dht::crypto::Certificate>
1215ChannelSocket::peerCertificate() const
1216{
1217 if (auto ep = pimpl_->endpoint.lock())
1218 return ep->peerCertificate();
1219 return {};
1220}
1221
1222IpAddr
1223ChannelSocket::getLocalAddress() const
1224{
1225 if (auto ep = pimpl_->endpoint.lock())
1226 return ep->getLocalAddress();
1227 return {};
1228}
1229
1230IpAddr
1231ChannelSocket::getRemoteAddress() const
1232{
1233 if (auto ep = pimpl_->endpoint.lock())
1234 return ep->getRemoteAddress();
1235 return {};
1236}
1237
Amna31791e52023-08-03 12:40:57 -04001238std::vector<std::map<std::string, std::string>>
1239MultiplexedSocket::getChannelList() const
1240{
Adrien Béraud024c46f2024-03-02 23:53:18 -05001241 std::lock_guard lkSockets(pimpl_->socketsMutex);
Amna31791e52023-08-03 12:40:57 -04001242 std::vector<std::map<std::string, std::string>> channelsList;
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001243 channelsList.reserve(pimpl_->sockets.size());
Amna31791e52023-08-03 12:40:57 -04001244 for (const auto& [_, channel] : pimpl_->sockets) {
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001245 channelsList.emplace_back(std::map<std::string, std::string> {
1246 {"id", fmt::format("{:x}", channel->channel())},
1247 {"name", channel->name()},
1248 });
Amna31791e52023-08-03 12:40:57 -04001249 }
Amna31791e52023-08-03 12:40:57 -04001250 return channelsList;
1251}
1252
Sébastien Blin464bdff2023-07-19 08:02:53 -04001253} // namespace dhtnet