blob: f95a1b0d66c9c5c3c0c5077e1f14b4a692e6c204 [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
Adrien Béraudcb753622023-07-17 22:32:49 -04002 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
Adrien Béraud612b55b2023-05-29 10:42:04 -04003 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "multiplexed_socket.h"
18#include "peer_connection.h"
19#include "ice_transport.h"
20#include "certstore.h"
21
22#include <opendht/logger.h>
23#include <opendht/thread_pool.h>
24
25#include <asio/io_context.hpp>
26#include <asio/steady_timer.hpp>
27
28#include <deque>
29
30static constexpr std::size_t IO_BUFFER_SIZE {8192}; ///< Size of char buffer used by IO operations
31static constexpr int MULTIPLEXED_SOCKET_VERSION {1};
32
33struct ChanneledMessage
34{
35 uint16_t channel;
36 std::vector<uint8_t> data;
37 MSGPACK_DEFINE(channel, data)
38};
39
40struct BeaconMsg
41{
42 bool p;
43 MSGPACK_DEFINE_MAP(p)
44};
45
46struct VersionMsg
47{
48 int v;
49 MSGPACK_DEFINE_MAP(v)
50};
51
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040052namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040053
54using clock = std::chrono::steady_clock;
55using time_point = clock::time_point;
56
57class MultiplexedSocket::Impl
58{
59public:
60 Impl(MultiplexedSocket& parent,
61 std::shared_ptr<asio::io_context> ctx,
62 const DeviceId& deviceId,
Adrien Béraud55133cc2023-10-15 11:55:20 -040063 std::unique_ptr<TlsSocketEndpoint> ep,
Adrien Béraud5636f7c2023-09-14 14:34:57 -040064 std::shared_ptr<dht::log::Logger> logger)
Adrien Béraud612b55b2023-05-29 10:42:04 -040065 : parent_(parent)
Adrien Béraud612b55b2023-05-29 10:42:04 -040066 , ctx_(std::move(ctx))
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040067 , deviceId(deviceId)
Adrien Béraud55133cc2023-10-15 11:55:20 -040068 , endpoint(std::move(ep))
69 , nextChannel_(endpoint->isInitiator() ? 0x0001u : 0x8000u)
Adrien Béraud18d344d2024-03-02 21:39:22 -050070 , logger_(std::move(logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -040071 , eventLoopThread_ {[this] {
72 try {
73 eventLoop();
74 } catch (const std::exception& e) {
75 if (logger_)
76 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
77 shutdown();
78 }
79 }}
Adrien Béraudd78d1ac2023-08-25 10:43:33 -040080 , beaconTimer_(*ctx_)
Adrien Béraud612b55b2023-05-29 10:42:04 -040081 {}
82
83 ~Impl() {}
84
85 void join()
86 {
87 if (!isShutdown_) {
88 if (endpoint)
89 endpoint->setOnStateChange({});
90 shutdown();
91 } else {
92 clearSockets();
93 }
94 if (eventLoopThread_.joinable())
95 eventLoopThread_.join();
96 }
97
98 void clearSockets()
99 {
100 decltype(sockets) socks;
101 {
102 std::lock_guard<std::mutex> lkSockets(socketsMutex);
103 socks = std::move(sockets);
104 }
105 for (auto& socket : socks) {
106 // Just trigger onShutdown() to make client know
107 // No need to write the EOF for the channel, the write will fail because endpoint is
108 // already shutdown
109 if (socket.second)
110 socket.second->stop();
111 }
112 }
113
114 void shutdown()
115 {
116 if (isShutdown_)
117 return;
118 stop.store(true);
119 isShutdown_ = true;
120 beaconTimer_.cancel();
121 if (onShutdown_)
122 onShutdown_();
123 if (endpoint) {
124 std::unique_lock<std::mutex> lk(writeMtx);
125 endpoint->shutdown();
126 }
127 clearSockets();
128 }
129
130 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
131 uint16_t channel,
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400132 bool isInitiator)
Adrien Béraud612b55b2023-05-29 10:42:04 -0400133 {
134 auto& channelSocket = sockets[channel];
135 if (not channelSocket)
136 channelSocket = std::make_shared<ChannelSocket>(
137 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
138 // Remove socket in another thread to avoid any lock
139 dht::ThreadPool::io().run([w, channel]() {
140 if (auto shared = w.lock()) {
141 shared->eraseChannel(channel);
142 }
143 });
144 });
145 else {
146 if (logger_)
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400147 logger_->warn("Received request for existing channel {}", channel);
148 return {};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400149 }
150 return channelSocket;
151 }
152
153 /**
154 * Handle packets on the TLS endpoint and parse RTP
155 */
156 void eventLoop();
157 /**
158 * Triggered when a new control packet is received
159 */
160 void handleControlPacket(std::vector<uint8_t>&& pkt);
161 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
162 bool handleProtocolMsg(const msgpack::object& o);
163 /**
164 * Triggered when a new packet on a channel is received
165 */
166 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
167 void onRequest(const std::string& name, uint16_t channel);
168 void onAccept(const std::string& name, uint16_t channel);
169
170 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
171 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
172
173 // Beacon
174 void sendBeacon(const std::chrono::milliseconds& timeout);
175 void handleBeaconRequest();
176 void handleBeaconResponse();
177 std::atomic_int beaconCounter_ {0};
178
179 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
180
181 msgpack::unpacker pac_ {};
182
183 MultiplexedSocket& parent_;
184
185 std::shared_ptr<Logger> logger_;
186 std::shared_ptr<asio::io_context> ctx_;
187
188 OnConnectionReadyCb onChannelReady_ {};
189 OnConnectionRequestCb onRequest_ {};
190 OnShutdownCb onShutdown_ {};
191
192 DeviceId deviceId {};
193 // Main socket
194 std::unique_ptr<TlsSocketEndpoint> endpoint {};
195
196 std::mutex socketsMutex {};
197 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
Adrien Béraud55133cc2023-10-15 11:55:20 -0400198 uint16_t nextChannel_;
Adrien Béraud612b55b2023-05-29 10:42:04 -0400199
200 // Main loop to parse incoming packets
201 std::atomic_bool stop {false};
202 std::thread eventLoopThread_ {};
203
204 std::atomic_bool isShutdown_ {false};
205
206 std::mutex writeMtx {};
207
208 time_point start_ {clock::now()};
209 //std::shared_ptr<Task> beaconTask_ {};
210 asio::steady_timer beaconTimer_;
211
212 // version related stuff
213 void sendVersion();
214 void onVersion(int version);
215 std::atomic_bool canSendBeacon_ {false};
216 std::atomic_bool answerBeacon_ {true};
217 int version_ {MULTIPLEXED_SOCKET_VERSION};
218 std::function<void(bool)> onBeaconCb_ {};
219 std::function<void(int)> onVersionCb_ {};
220};
221
222void
223MultiplexedSocket::Impl::eventLoop()
224{
225 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
226 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
227 if (logger_)
228 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
229 shutdown();
230 return false;
231 }
232 return true;
233 });
234 sendVersion();
235 std::error_code ec;
236 while (!stop) {
237 if (!endpoint) {
238 shutdown();
239 return;
240 }
241 pac_.reserve_buffer(IO_BUFFER_SIZE);
242 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
243 if (size < 0) {
244 if (ec && logger_)
245 logger_->error("Read error detected: {}", ec.message());
246 break;
247 }
248 if (size == 0) {
249 // We can close the socket
250 shutdown();
251 break;
252 }
253
254 pac_.buffer_consumed(size);
255 msgpack::object_handle oh;
256 while (pac_.next(oh) && !stop) {
257 try {
258 auto msg = oh.get().as<ChanneledMessage>();
259 if (msg.channel == CONTROL_CHANNEL)
260 handleControlPacket(std::move(msg.data));
261 else if (msg.channel == PROTOCOL_CHANNEL)
262 handleProtocolPacket(std::move(msg.data));
263 else
264 handleChannelPacket(msg.channel, std::move(msg.data));
265 } catch (const std::exception& e) {
266 if (logger_)
267 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
268 } catch (...) {
269 if (logger_)
270 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
271 }
272 }
273 }
274}
275
276void
277MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
278{
279 std::lock_guard<std::mutex> lkSockets(socketsMutex);
280 auto& socket = sockets[channel];
281 if (!socket) {
282 if (logger_)
283 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
284 return;
285 }
286
287 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400288 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400289 // Due to the callbacks that can take some time, onAccept can arrive after
290 // receiving all the data. In this case, the socket should be removed here
291 // as handle by onChannelReady_
292 if (socket->isRemovable())
293 sockets.erase(channel);
294 else
295 socket->answered();
296}
297
298void
299MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
300{
301 if (!canSendBeacon_)
302 return;
303 beaconCounter_++;
304 if (logger_)
305 logger_->debug("Send beacon to peer {}", deviceId);
306
307 msgpack::sbuffer buffer(8);
308 msgpack::packer<msgpack::sbuffer> pk(&buffer);
309 pk.pack(BeaconMsg {true});
310 if (!writeProtocolMessage(buffer))
311 return;
312 beaconTimer_.expires_after(timeout);
313 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
314 if (ec == asio::error::operation_aborted)
315 return;
316 if (auto shared = w.lock()) {
317 if (shared->pimpl_->beaconCounter_ != 0) {
318 if (shared->pimpl_->logger_)
319 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
320 shared->shutdown();
321 }
322 }
323 });
324}
325
326void
327MultiplexedSocket::Impl::handleBeaconRequest()
328{
329 if (!answerBeacon_)
330 return;
331 // Run this on dedicated thread because some callbacks can take time
332 dht::ThreadPool::io().run([w = parent_.weak()]() {
333 if (auto shared = w.lock()) {
334 msgpack::sbuffer buffer(8);
335 msgpack::packer<msgpack::sbuffer> pk(&buffer);
336 pk.pack(BeaconMsg {false});
337 if (shared->pimpl_->logger_)
338 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
339 shared->pimpl_->writeProtocolMessage(buffer);
340 }
341 });
342}
343
344void
345MultiplexedSocket::Impl::handleBeaconResponse()
346{
347 if (logger_)
348 logger_->debug("Get beacon response from peer {}", deviceId);
349 beaconCounter_--;
350}
351
352bool
353MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
354{
355 std::error_code ec;
356 int wr = parent_.write(PROTOCOL_CHANNEL,
357 (const unsigned char*) buffer.data(),
358 buffer.size(),
359 ec);
360 return wr > 0;
361}
362
363void
364MultiplexedSocket::Impl::sendVersion()
365{
366 dht::ThreadPool::io().run([w = parent_.weak()]() {
367 if (auto shared = w.lock()) {
368 auto version = shared->pimpl_->version_;
369 msgpack::sbuffer buffer(8);
370 msgpack::packer<msgpack::sbuffer> pk(&buffer);
371 pk.pack(VersionMsg {version});
372 shared->pimpl_->writeProtocolMessage(buffer);
373 }
374 });
375}
376
377void
378MultiplexedSocket::Impl::onVersion(int version)
379{
380 // Check if version > 1
381 if (version >= 1) {
382 if (logger_)
383 logger_->debug("Peer {} supports beacon", deviceId);
384 canSendBeacon_ = true;
385 } else {
386 if (logger_)
387 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
388 deviceId,
389 version);
390 canSendBeacon_ = false;
391 }
392}
393
394void
395MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
396{
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400397 bool accept;
398 if (channel == CONTROL_CHANNEL || channel == PROTOCOL_CHANNEL) {
399 if (logger_)
400 logger_->warn("Channel {:d} is reserved, refusing request", channel);
401 accept = false;
402 } else
403 accept = onRequest_(endpoint->peerCertificate(), channel, name);
404
Adrien Béraud612b55b2023-05-29 10:42:04 -0400405 std::shared_ptr<ChannelSocket> channelSocket;
406 if (accept) {
407 std::lock_guard<std::mutex> lkSockets(socketsMutex);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400408 channelSocket = makeSocket(name, channel, false);
409 if (not channelSocket) {
410 if (logger_)
411 logger_->error("Channel {:d} already exists, refusing request", channel);
412 accept = false;
413 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400414 }
415
416 // Answer to ChannelRequest if accepted
417 ChannelRequest val;
418 val.channel = channel;
419 val.name = name;
420 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
421 msgpack::sbuffer buffer(512);
422 msgpack::pack(buffer, val);
423 std::error_code ec;
424 int wr = parent_.write(CONTROL_CHANNEL,
425 reinterpret_cast<const uint8_t*>(buffer.data()),
426 buffer.size(),
427 ec);
428 if (wr < 0) {
429 if (ec && logger_)
430 logger_->error("The write operation failed with error: {:s}", ec.message());
431 stop.store(true);
432 return;
433 }
434
435 if (accept) {
436 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400437 channelSocket->ready(true);
Adrien Béraudf4b34e82023-10-16 12:02:23 -0400438 if (channelSocket->isRemovable()) {
439 std::lock_guard<std::mutex> lkSockets(socketsMutex);
440 sockets.erase(channel);
441 } else
442 channelSocket->answered();
Adrien Béraud612b55b2023-05-29 10:42:04 -0400443 }
444}
445
446void
447MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
448{
Adrien Béraudcd7a7bd2023-10-15 12:54:30 -0400449 try {
450 size_t off = 0;
451 while (off != pkt.size()) {
452 msgpack::unpacked result;
453 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
454 auto object = result.get();
455 if (handleProtocolMsg(object))
456 continue;
457 auto req = object.as<ChannelRequest>();
458 if (req.state == ChannelRequestState::REQUEST) {
459 dht::ThreadPool::io().run([w = parent_.weak(), req = std::move(req)]() {
460 if (auto shared = w.lock())
461 shared->pimpl_->onRequest(req.name, req.channel);
462 });
463 }
464 else if (req.state == ChannelRequestState::ACCEPT) {
465 onAccept(req.name, req.channel);
466 } else {
467 // DECLINE or unknown
468 std::lock_guard<std::mutex> lkSockets(socketsMutex);
469 auto channel = sockets.find(req.channel);
470 if (channel != sockets.end()) {
471 channel->second->ready(false);
472 channel->second->stop();
473 sockets.erase(channel);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400474 }
475 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400476 }
Adrien Béraudcd7a7bd2023-10-15 12:54:30 -0400477 } catch (const std::exception& e) {
478 if (logger_)
479 logger_->error("Error on the control channel: {}", e.what());
480 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400481}
482
483void
484MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
485{
486 std::lock_guard<std::mutex> lkSockets(socketsMutex);
487 auto sockIt = sockets.find(channel);
488 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
489 if (pkt.size() == 0) {
490 sockIt->second->stop();
491 if (sockIt->second->isAnswered())
492 sockets.erase(sockIt);
493 else
494 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
495 // removed later.
496 } else {
497 sockIt->second->onRecv(std::move(pkt));
498 }
499 } else if (pkt.size() != 0) {
500 if (logger_)
501 logger_->warn("Non existing channel: {}", channel);
502 }
503}
504
505bool
506MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
507{
508 try {
509 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
510 auto key = o.via.map.ptr[0].key.as<std::string_view>();
511 if (key == "p") {
512 auto msg = o.as<BeaconMsg>();
513 if (msg.p)
514 handleBeaconRequest();
515 else
516 handleBeaconResponse();
517 if (onBeaconCb_)
518 onBeaconCb_(msg.p);
519 return true;
520 } else if (key == "v") {
521 auto msg = o.as<VersionMsg>();
522 onVersion(msg.v);
523 if (onVersionCb_)
524 onVersionCb_(msg.v);
525 return true;
526 } else {
527 if (logger_)
528 logger_->warn("Unknown message type");
529 }
530 }
531 } catch (const std::exception& e) {
532 if (logger_)
533 logger_->error("Error on the protocol channel: {}", e.what());
534 }
535 return false;
536}
537
538void
539MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
540{
541 // Run this on dedicated thread because some callbacks can take time
542 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
543 auto shared = w.lock();
544 if (!shared)
545 return;
546 try {
547 size_t off = 0;
548 while (off != pkt.size()) {
549 msgpack::unpacked result;
550 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
551 auto object = result.get();
552 if (shared->pimpl_->handleProtocolMsg(object))
553 return;
554 }
555 } catch (const std::exception& e) {
556 if (shared->pimpl_->logger_)
557 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
558 }
559 });
560}
561
562MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
Adrien Béraud5636f7c2023-09-14 14:34:57 -0400563 std::unique_ptr<TlsSocketEndpoint> endpoint, std::shared_ptr<dht::log::Logger> logger)
564 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint), logger))
Adrien Béraud612b55b2023-05-29 10:42:04 -0400565{}
566
567MultiplexedSocket::~MultiplexedSocket() {}
568
569std::shared_ptr<ChannelSocket>
570MultiplexedSocket::addChannel(const std::string& name)
571{
Adrien Béraud612b55b2023-05-29 10:42:04 -0400572 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
Adrien Béraud55133cc2023-10-15 11:55:20 -0400573 if (pimpl_->sockets.size() < UINT16_MAX)
574 for (unsigned i = 0; i < UINT16_MAX; ++i) {
575 auto c = pimpl_->nextChannel_++;
576 if (c == CONTROL_CHANNEL
577 || c == PROTOCOL_CHANNEL
578 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
579 continue;
580 return pimpl_->makeSocket(name, c, true);
581 }
Adrien Béraud612b55b2023-05-29 10:42:04 -0400582 return {};
583}
584
585DeviceId
586MultiplexedSocket::deviceId() const
587{
588 return pimpl_->deviceId;
589}
590
591void
592MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
593{
594 pimpl_->onChannelReady_ = std::move(cb);
595}
596
597void
598MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
599{
600 pimpl_->onRequest_ = std::move(cb);
601}
602
603bool
604MultiplexedSocket::isReliable() const
605{
606 return true;
607}
608
609bool
610MultiplexedSocket::isInitiator() const
611{
612 if (!pimpl_->endpoint) {
613 if (pimpl_->logger_)
614 pimpl_->logger_->warn("No endpoint found for socket");
615 return false;
616 }
617 return pimpl_->endpoint->isInitiator();
618}
619
620int
621MultiplexedSocket::maxPayload() const
622{
623 if (!pimpl_->endpoint) {
624 if (pimpl_->logger_)
625 pimpl_->logger_->warn("No endpoint found for socket");
626 return 0;
627 }
628 return pimpl_->endpoint->maxPayload();
629}
630
631std::size_t
632MultiplexedSocket::write(const uint16_t& channel,
633 const uint8_t* buf,
634 std::size_t len,
635 std::error_code& ec)
636{
637 assert(nullptr != buf);
638
639 if (pimpl_->isShutdown_) {
640 ec = std::make_error_code(std::errc::broken_pipe);
641 return -1;
642 }
643 if (len > UINT16_MAX) {
644 ec = std::make_error_code(std::errc::message_size);
645 return -1;
646 }
647 bool oneShot = len < 8192;
648 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
649 msgpack::packer<msgpack::sbuffer> pk(&buffer);
650 pk.pack_array(2);
651 pk.pack(channel);
652 pk.pack_bin(len);
653 if (oneShot)
654 pk.pack_bin_body((const char*) buf, len);
655
656 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
657 if (!pimpl_->endpoint) {
658 if (pimpl_->logger_)
659 pimpl_->logger_->warn("No endpoint found for socket");
660 ec = std::make_error_code(std::errc::broken_pipe);
661 return -1;
662 }
663 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
664 if (not oneShot and res >= 0)
665 res = pimpl_->endpoint->write(buf, len, ec);
666 lk.unlock();
667 if (res < 0) {
668 if (ec && pimpl_->logger_)
669 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
670 shutdown();
671 }
672 return res;
673}
674
675void
676MultiplexedSocket::shutdown()
677{
678 pimpl_->shutdown();
679}
680
681void
682MultiplexedSocket::join()
683{
684 pimpl_->join();
685}
686
687void
688MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
689{
690 pimpl_->onShutdown_ = std::move(cb);
691 if (pimpl_->isShutdown_)
692 pimpl_->onShutdown_();
693}
694
695const std::shared_ptr<Logger>&
696MultiplexedSocket::logger()
697{
698 return pimpl_->logger_;
699}
700
701void
702MultiplexedSocket::monitor() const
703{
704 auto cert = peerCertificate();
705 if (!cert || !cert->issuer)
706 return;
707 auto now = clock::now();
708 if (!pimpl_->logger_)
709 return;
710 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
711 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
712 pimpl_->endpoint->monitor();
713 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
714 for (const auto& [_, channel] : pimpl_->sockets) {
715 if (channel)
716 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
717 fmt::ptr(channel.get()),
718 channel.use_count(),
719 channel->name(),
720 channel->isInitiator());
721 }
722}
723
724void
725MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
726{
727 pimpl_->sendBeacon(timeout);
728}
729
730std::shared_ptr<dht::crypto::Certificate>
731MultiplexedSocket::peerCertificate() const
732{
733 return pimpl_->endpoint->peerCertificate();
734}
735
#ifdef DHTNET_TESTABLE
// Test-only hooks that expose or override the beacon/version state of Impl.

/// Whether the peer announced a version that supports beacons.
bool
MultiplexedSocket::canSendBeacon() const
{
    return pimpl_->canSendBeacon_.load();
}

/// Enable or disable answering incoming beacon requests.
void
MultiplexedSocket::answerToBeacon(bool value)
{
    pimpl_->answerBeacon_.store(value);
}

/// Override the protocol version announced to the peer.
void
MultiplexedSocket::setVersion(int version)
{
    auto& impl = *pimpl_;
    impl.version_ = version;
}

/// Observe every received beacon (true = request, false = response).
void
MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
{
    auto& impl = *pimpl_;
    impl.onBeaconCb_ = cb;
}

/// Observe every version announcement received from the peer.
void
MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
{
    auto& impl = *pimpl_;
    impl.onVersionCb_ = cb;
}

/// Re-send our version announcement.
void
MultiplexedSocket::sendVersion()
{
    auto& impl = *pimpl_;
    impl.sendVersion();
}

#endif
774
Adrien Béraud612b55b2023-05-29 10:42:04 -0400775IpAddr
776MultiplexedSocket::getLocalAddress() const
777{
778 return pimpl_->endpoint->getLocalAddress();
779}
780
781IpAddr
782MultiplexedSocket::getRemoteAddress() const
783{
784 return pimpl_->endpoint->getRemoteAddress();
785}
786
Adrien Béraudafa8e282023-09-24 12:53:20 -0400787TlsSocketEndpoint*
788MultiplexedSocket::endpoint()
789{
790 return pimpl_->endpoint.get();
791}
792
Adrien Béraud612b55b2023-05-29 10:42:04 -0400793void
794MultiplexedSocket::eraseChannel(uint16_t channel)
795{
796 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
797 auto itSocket = pimpl_->sockets.find(channel);
798 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
799 pimpl_->sockets.erase(itSocket);
800}
801
802////////////////////////////////////////////////////////////////
803
804class ChannelSocket::Impl
805{
806public:
807 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
808 const std::string& name,
809 const uint16_t& channel,
810 bool isInitiator,
811 std::function<void()> rmFromMxSockCb)
812 : name(name)
813 , channel(channel)
814 , endpoint(std::move(endpoint))
815 , isInitiator_(isInitiator)
816 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
817 {}
818
819 ~Impl() {}
820
821 ChannelReadyCb readyCb_ {};
822 OnShutdownCb shutdownCb_ {};
823 std::atomic_bool isShutdown_ {false};
Adrien Béraud5fd233f2023-10-15 12:35:48 -0400824 const std::string name {};
825 const uint16_t channel {};
826 const std::weak_ptr<MultiplexedSocket> endpoint {};
827 const bool isInitiator_ {false};
Adrien Béraud612b55b2023-05-29 10:42:04 -0400828 std::function<void()> rmFromMxSockCb_;
829
830 bool isAnswered_ {false};
831 bool isRemovable_ {false};
832
833 std::vector<uint8_t> buf {};
834 std::mutex mutex {};
835 std::condition_variable cv {};
836 GenericSocket<uint8_t>::RecvCb cb {};
837};
838
839ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
840 const DeviceId& deviceId,
841 const std::string& name,
842 const uint16_t& channel)
843 : pimpl_deviceId(deviceId)
844 , pimpl_name(name)
845 , pimpl_channel(channel)
846 , ioCtx_(*ctx)
847{}
848
849ChannelSocketTest::~ChannelSocketTest() {}
850
851void
852ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
853 const std::shared_ptr<ChannelSocketTest>& socket2)
854{
855 socket1->remote = socket2;
856 socket2->remote = socket1;
857}
858
859DeviceId
860ChannelSocketTest::deviceId() const
861{
862 return pimpl_deviceId;
863}
864
865std::string
866ChannelSocketTest::name() const
867{
868 return pimpl_name;
869}
870
871uint16_t
872ChannelSocketTest::channel() const
873{
874 return pimpl_channel;
875}
876
877void
878ChannelSocketTest::shutdown()
879{
880 {
881 std::unique_lock<std::mutex> lk {mutex};
882 if (!isShutdown_.exchange(true)) {
883 lk.unlock();
884 shutdownCb_();
885 }
886 cv.notify_all();
887 }
888
889 if (auto peer = remote.lock()) {
890 if (!peer->isShutdown_.exchange(true)) {
891 peer->shutdownCb_();
892 }
893 peer->cv.notify_all();
894 }
895}
896
897std::size_t
898ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
899{
900 std::size_t size = std::min(len, this->rx_buf.size());
901
902 for (std::size_t i = 0; i < size; ++i)
903 buf[i] = this->rx_buf[i];
904
905 if (size == this->rx_buf.size()) {
906 this->rx_buf.clear();
907 } else
908 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
909 return size;
910}
911
912std::size_t
913ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
914{
915 if (isShutdown_) {
916 ec = std::make_error_code(std::errc::broken_pipe);
917 return -1;
918 }
919 ec = {};
920 dht::ThreadPool::computation().run(
921 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
922 if (auto peer = r.lock())
923 peer->onRecv(std::move(data));
924 });
925 return len;
926}
927
928int
929ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
930{
931 std::unique_lock<std::mutex> lk {mutex};
932 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
933 return rx_buf.size();
934}
935
936void
937ChannelSocketTest::setOnRecv(RecvCb&& cb)
938{
939 std::lock_guard<std::mutex> lkSockets(mutex);
940 this->cb = std::move(cb);
941 if (!rx_buf.empty() && this->cb) {
942 this->cb(rx_buf.data(), rx_buf.size());
943 rx_buf.clear();
944 }
945}
946
947void
948ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
949{
950 std::lock_guard<std::mutex> lkSockets(mutex);
951 if (cb) {
952 cb(pkt.data(), pkt.size());
953 return;
954 }
955 rx_buf.insert(rx_buf.end(),
956 std::make_move_iterator(pkt.begin()),
957 std::make_move_iterator(pkt.end()));
958 cv.notify_all();
959}
960
961void
962ChannelSocketTest::onReady(ChannelReadyCb&& cb)
963{}
964
965void
966ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
967{
968 std::unique_lock<std::mutex> lk {mutex};
969 shutdownCb_ = std::move(cb);
970
971 if (isShutdown_) {
972 lk.unlock();
973 shutdownCb_();
974 }
975}
976
977ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
978 const std::string& name,
979 const uint16_t& channel,
980 bool isInitiator,
981 std::function<void()> rmFromMxSockCb)
982 : pimpl_ {
983 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
984{}
985
986ChannelSocket::~ChannelSocket() {}
987
988DeviceId
989ChannelSocket::deviceId() const
990{
991 if (auto ep = pimpl_->endpoint.lock()) {
992 return ep->deviceId();
993 }
994 return {};
995}
996
997std::string
998ChannelSocket::name() const
999{
1000 return pimpl_->name;
1001}
1002
1003uint16_t
1004ChannelSocket::channel() const
1005{
1006 return pimpl_->channel;
1007}
1008
1009bool
1010ChannelSocket::isReliable() const
1011{
1012 if (auto ep = pimpl_->endpoint.lock()) {
1013 return ep->isReliable();
1014 }
1015 return false;
1016}
1017
1018bool
1019ChannelSocket::isInitiator() const
1020{
1021 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1022 // because a multiplexed socket can have sockets from accepted requests
1023 // or made via connectDevice(). Here, isInitiator_ return if the socket
1024 // is from connectDevice.
1025 return pimpl_->isInitiator_;
1026}
1027
1028int
1029ChannelSocket::maxPayload() const
1030{
1031 if (auto ep = pimpl_->endpoint.lock()) {
1032 return ep->maxPayload();
1033 }
1034 return -1;
1035}
1036
1037void
1038ChannelSocket::setOnRecv(RecvCb&& cb)
1039{
1040 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1041 pimpl_->cb = std::move(cb);
1042 if (!pimpl_->buf.empty() && pimpl_->cb) {
1043 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1044 pimpl_->buf.clear();
1045 }
1046}
1047
1048void
1049ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1050{
1051 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1052 if (pimpl_->cb) {
1053 pimpl_->cb(&pkt[0], pkt.size());
1054 return;
1055 }
1056 pimpl_->buf.insert(pimpl_->buf.end(),
1057 std::make_move_iterator(pkt.begin()),
1058 std::make_move_iterator(pkt.end()));
1059 pimpl_->cv.notify_all();
1060}
1061
#ifdef DHTNET_TESTABLE
/// Owning handle to the multiplexed socket, empty if it has been destroyed.
std::shared_ptr<MultiplexedSocket>
ChannelSocket::underlyingSocket() const
{
    return pimpl_->endpoint.lock();
}
#endif
1071
1072void
1073ChannelSocket::answered()
1074{
1075 pimpl_->isAnswered_ = true;
1076}
1077
1078void
1079ChannelSocket::removable()
1080{
1081 pimpl_->isRemovable_ = true;
1082}
1083
1084bool
1085ChannelSocket::isRemovable() const
1086{
1087 return pimpl_->isRemovable_;
1088}
1089
1090bool
1091ChannelSocket::isAnswered() const
1092{
1093 return pimpl_->isAnswered_;
1094}
1095
1096void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001097ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001098{
1099 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001100 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001101}
1102
1103void
1104ChannelSocket::stop()
1105{
1106 if (pimpl_->isShutdown_)
1107 return;
1108 pimpl_->isShutdown_ = true;
1109 if (pimpl_->shutdownCb_)
1110 pimpl_->shutdownCb_();
1111 pimpl_->cv.notify_all();
1112 // stop() can be called by ChannelSocket::shutdown()
1113 // In this case, the eventLoop is not used, but MxSock
1114 // must remove the channel from its list (so that the
1115 // channel can be destroyed and its shared_ptr invalidated).
1116 if (pimpl_->rmFromMxSockCb_)
1117 pimpl_->rmFromMxSockCb_();
1118}
1119
1120void
1121ChannelSocket::shutdown()
1122{
1123 if (pimpl_->isShutdown_)
1124 return;
1125 stop();
1126 if (auto ep = pimpl_->endpoint.lock()) {
1127 std::error_code ec;
1128 const uint8_t dummy = '\0';
1129 ep->write(pimpl_->channel, &dummy, 0, ec);
1130 }
1131}
1132
1133std::size_t
1134ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1135{
1136 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1137 std::size_t size = std::min(len, pimpl_->buf.size());
1138
1139 for (std::size_t i = 0; i < size; ++i)
1140 outBuf[i] = pimpl_->buf[i];
1141
1142 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1143 return size;
1144}
1145
1146std::size_t
1147ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1148{
1149 if (pimpl_->isShutdown_) {
1150 ec = std::make_error_code(std::errc::broken_pipe);
1151 return -1;
1152 }
1153 if (auto ep = pimpl_->endpoint.lock()) {
1154 std::size_t sent = 0;
1155 do {
1156 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1157 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1158 if (ec) {
1159 if (ep->logger())
1160 ep->logger()->error("Error when writing on channel: {}", ec.message());
1161 return res;
1162 }
1163 sent += toSend;
1164 } while (sent < len);
1165 return sent;
1166 }
1167 ec = std::make_error_code(std::errc::broken_pipe);
1168 return -1;
1169}
1170
1171int
1172ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1173{
1174 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1175 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1176 return pimpl_->buf.size();
1177}
1178
1179void
1180ChannelSocket::onShutdown(OnShutdownCb&& cb)
1181{
1182 pimpl_->shutdownCb_ = std::move(cb);
1183 if (pimpl_->isShutdown_) {
1184 pimpl_->shutdownCb_();
1185 }
1186}
1187
1188void
1189ChannelSocket::onReady(ChannelReadyCb&& cb)
1190{
1191 pimpl_->readyCb_ = std::move(cb);
1192}
1193
1194void
1195ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1196{
1197 if (auto ep = pimpl_->endpoint.lock()) {
1198 ep->sendBeacon(timeout);
1199 } else {
1200 shutdown();
1201 }
1202}
1203
1204std::shared_ptr<dht::crypto::Certificate>
1205ChannelSocket::peerCertificate() const
1206{
1207 if (auto ep = pimpl_->endpoint.lock())
1208 return ep->peerCertificate();
1209 return {};
1210}
1211
1212IpAddr
1213ChannelSocket::getLocalAddress() const
1214{
1215 if (auto ep = pimpl_->endpoint.lock())
1216 return ep->getLocalAddress();
1217 return {};
1218}
1219
1220IpAddr
1221ChannelSocket::getRemoteAddress() const
1222{
1223 if (auto ep = pimpl_->endpoint.lock())
1224 return ep->getRemoteAddress();
1225 return {};
1226}
1227
Amna31791e52023-08-03 12:40:57 -04001228std::vector<std::map<std::string, std::string>>
1229MultiplexedSocket::getChannelList() const
1230{
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001231 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
Amna31791e52023-08-03 12:40:57 -04001232 std::vector<std::map<std::string, std::string>> channelsList;
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001233 channelsList.reserve(pimpl_->sockets.size());
Amna31791e52023-08-03 12:40:57 -04001234 for (const auto& [_, channel] : pimpl_->sockets) {
Adrien Béraudd8e666d2023-10-13 12:12:09 -04001235 channelsList.emplace_back(std::map<std::string, std::string> {
1236 {"id", fmt::format("{:x}", channel->channel())},
1237 {"name", channel->name()},
1238 });
Amna31791e52023-08-03 12:40:57 -04001239 }
Amna31791e52023-08-03 12:40:57 -04001240 return channelsList;
1241}
1242
Sébastien Blin464bdff2023-07-19 08:02:53 -04001243} // namespace dhtnet