blob: 5abf7427ed0ce70bc27af87784851f6eca0967ff [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
2 * Copyright (C) 2019-2023 Savoir-faire Linux Inc.
3 * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program. If not, see <https://www.gnu.org/licenses/>.
17 */
18
19#include "multiplexed_socket.h"
20#include "peer_connection.h"
21#include "ice_transport.h"
22#include "certstore.h"
23
24#include <opendht/logger.h>
25#include <opendht/thread_pool.h>
26
27#include <asio/io_context.hpp>
28#include <asio/steady_timer.hpp>
29
30#include <deque>
31
/// Number of bytes reserved in the msgpack unpacker for each read from the TLS endpoint.
static constexpr std::size_t IO_BUFFER_SIZE = 8192;
/// Protocol version announced to the peer; peers reporting version >= 1 are sent beacons.
static constexpr int MULTIPLEXED_SOCKET_VERSION = 1;
34
/**
 * Frame carried on the multiplexed TLS stream: the channel number a payload
 * belongs to, and the payload bytes. MSGPACK_DEFINE serializes it as a
 * two-element msgpack array [channel, data]; MultiplexedSocket::write() builds
 * the same layout by hand. The argument order is therefore part of the wire format.
 */
struct ChanneledMessage
{
    uint16_t channel;
    std::vector<uint8_t> data;
    MSGPACK_DEFINE(channel, data)
};
41
/**
 * Keep-alive message exchanged on the protocol channel.
 * p == true is a beacon request and p == false is the answer
 * (see MultiplexedSocket::Impl::handleProtocolMsg).
 * MSGPACK_DEFINE_MAP encodes it as a map keyed by the member name, so the
 * field name "p" is part of the wire format and must not be renamed.
 */
struct BeaconMsg
{
    bool p;
    MSGPACK_DEFINE_MAP(p)
};
47
/**
 * Protocol version announcement sent on the protocol channel when the event
 * loop starts. MSGPACK_DEFINE_MAP encodes it as a map keyed by the member
 * name, so the field name "v" is part of the wire format and must not be renamed.
 */
struct VersionMsg
{
    int v;
    MSGPACK_DEFINE_MAP(v)
};
53
54namespace jami {
55
56using clock = std::chrono::steady_clock;
57using time_point = clock::time_point;
58
59class MultiplexedSocket::Impl
60{
61public:
62 Impl(MultiplexedSocket& parent,
63 std::shared_ptr<asio::io_context> ctx,
64 const DeviceId& deviceId,
65 std::unique_ptr<TlsSocketEndpoint> endpoint)
66 : parent_(parent)
67 , deviceId(deviceId)
68 , ctx_(std::move(ctx))
69 , beaconTimer_(*ctx_)
70 , endpoint(std::move(endpoint))
71 , eventLoopThread_ {[this] {
72 try {
73 eventLoop();
74 } catch (const std::exception& e) {
75 if (logger_)
76 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
77 shutdown();
78 }
79 }}
80 {}
81
82 ~Impl() {}
83
84 void join()
85 {
86 if (!isShutdown_) {
87 if (endpoint)
88 endpoint->setOnStateChange({});
89 shutdown();
90 } else {
91 clearSockets();
92 }
93 if (eventLoopThread_.joinable())
94 eventLoopThread_.join();
95 }
96
97 void clearSockets()
98 {
99 decltype(sockets) socks;
100 {
101 std::lock_guard<std::mutex> lkSockets(socketsMutex);
102 socks = std::move(sockets);
103 }
104 for (auto& socket : socks) {
105 // Just trigger onShutdown() to make client know
106 // No need to write the EOF for the channel, the write will fail because endpoint is
107 // already shutdown
108 if (socket.second)
109 socket.second->stop();
110 }
111 }
112
113 void shutdown()
114 {
115 if (isShutdown_)
116 return;
117 stop.store(true);
118 isShutdown_ = true;
119 beaconTimer_.cancel();
120 if (onShutdown_)
121 onShutdown_();
122 if (endpoint) {
123 std::unique_lock<std::mutex> lk(writeMtx);
124 endpoint->shutdown();
125 }
126 clearSockets();
127 }
128
129 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
130 uint16_t channel,
131 bool isInitiator = false)
132 {
133 auto& channelSocket = sockets[channel];
134 if (not channelSocket)
135 channelSocket = std::make_shared<ChannelSocket>(
136 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
137 // Remove socket in another thread to avoid any lock
138 dht::ThreadPool::io().run([w, channel]() {
139 if (auto shared = w.lock()) {
140 shared->eraseChannel(channel);
141 }
142 });
143 });
144 else {
145 if (logger_)
146 logger_->warn("A channel is already present on that socket, accepting "
147 "the request will close the previous one {}", name);
148 }
149 return channelSocket;
150 }
151
152 /**
153 * Handle packets on the TLS endpoint and parse RTP
154 */
155 void eventLoop();
156 /**
157 * Triggered when a new control packet is received
158 */
159 void handleControlPacket(std::vector<uint8_t>&& pkt);
160 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
161 bool handleProtocolMsg(const msgpack::object& o);
162 /**
163 * Triggered when a new packet on a channel is received
164 */
165 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
166 void onRequest(const std::string& name, uint16_t channel);
167 void onAccept(const std::string& name, uint16_t channel);
168
169 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
170 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
171
172 // Beacon
173 void sendBeacon(const std::chrono::milliseconds& timeout);
174 void handleBeaconRequest();
175 void handleBeaconResponse();
176 std::atomic_int beaconCounter_ {0};
177
178 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
179
180 msgpack::unpacker pac_ {};
181
182 MultiplexedSocket& parent_;
183
184 std::shared_ptr<Logger> logger_;
185 std::shared_ptr<asio::io_context> ctx_;
186
187 OnConnectionReadyCb onChannelReady_ {};
188 OnConnectionRequestCb onRequest_ {};
189 OnShutdownCb onShutdown_ {};
190
191 DeviceId deviceId {};
192 // Main socket
193 std::unique_ptr<TlsSocketEndpoint> endpoint {};
194
195 std::mutex socketsMutex {};
196 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
197
198 // Main loop to parse incoming packets
199 std::atomic_bool stop {false};
200 std::thread eventLoopThread_ {};
201
202 std::atomic_bool isShutdown_ {false};
203
204 std::mutex writeMtx {};
205
206 time_point start_ {clock::now()};
207 //std::shared_ptr<Task> beaconTask_ {};
208 asio::steady_timer beaconTimer_;
209
210 // version related stuff
211 void sendVersion();
212 void onVersion(int version);
213 std::atomic_bool canSendBeacon_ {false};
214 std::atomic_bool answerBeacon_ {true};
215 int version_ {MULTIPLEXED_SOCKET_VERSION};
216 std::function<void(bool)> onBeaconCb_ {};
217 std::function<void(int)> onVersionCb_ {};
218};
219
220void
221MultiplexedSocket::Impl::eventLoop()
222{
223 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
224 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
225 if (logger_)
226 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
227 shutdown();
228 return false;
229 }
230 return true;
231 });
232 sendVersion();
233 std::error_code ec;
234 while (!stop) {
235 if (!endpoint) {
236 shutdown();
237 return;
238 }
239 pac_.reserve_buffer(IO_BUFFER_SIZE);
240 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
241 if (size < 0) {
242 if (ec && logger_)
243 logger_->error("Read error detected: {}", ec.message());
244 break;
245 }
246 if (size == 0) {
247 // We can close the socket
248 shutdown();
249 break;
250 }
251
252 pac_.buffer_consumed(size);
253 msgpack::object_handle oh;
254 while (pac_.next(oh) && !stop) {
255 try {
256 auto msg = oh.get().as<ChanneledMessage>();
257 if (msg.channel == CONTROL_CHANNEL)
258 handleControlPacket(std::move(msg.data));
259 else if (msg.channel == PROTOCOL_CHANNEL)
260 handleProtocolPacket(std::move(msg.data));
261 else
262 handleChannelPacket(msg.channel, std::move(msg.data));
263 } catch (const std::exception& e) {
264 if (logger_)
265 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
266 } catch (...) {
267 if (logger_)
268 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
269 }
270 }
271 }
272}
273
274void
275MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
276{
277 std::lock_guard<std::mutex> lkSockets(socketsMutex);
278 auto& socket = sockets[channel];
279 if (!socket) {
280 if (logger_)
281 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
282 return;
283 }
284
285 onChannelReady_(deviceId, socket);
286 socket->ready();
287 // Due to the callbacks that can take some time, onAccept can arrive after
288 // receiving all the data. In this case, the socket should be removed here
289 // as handle by onChannelReady_
290 if (socket->isRemovable())
291 sockets.erase(channel);
292 else
293 socket->answered();
294}
295
296void
297MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
298{
299 if (!canSendBeacon_)
300 return;
301 beaconCounter_++;
302 if (logger_)
303 logger_->debug("Send beacon to peer {}", deviceId);
304
305 msgpack::sbuffer buffer(8);
306 msgpack::packer<msgpack::sbuffer> pk(&buffer);
307 pk.pack(BeaconMsg {true});
308 if (!writeProtocolMessage(buffer))
309 return;
310 beaconTimer_.expires_after(timeout);
311 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
312 if (ec == asio::error::operation_aborted)
313 return;
314 if (auto shared = w.lock()) {
315 if (shared->pimpl_->beaconCounter_ != 0) {
316 if (shared->pimpl_->logger_)
317 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
318 shared->shutdown();
319 }
320 }
321 });
322}
323
324void
325MultiplexedSocket::Impl::handleBeaconRequest()
326{
327 if (!answerBeacon_)
328 return;
329 // Run this on dedicated thread because some callbacks can take time
330 dht::ThreadPool::io().run([w = parent_.weak()]() {
331 if (auto shared = w.lock()) {
332 msgpack::sbuffer buffer(8);
333 msgpack::packer<msgpack::sbuffer> pk(&buffer);
334 pk.pack(BeaconMsg {false});
335 if (shared->pimpl_->logger_)
336 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
337 shared->pimpl_->writeProtocolMessage(buffer);
338 }
339 });
340}
341
342void
343MultiplexedSocket::Impl::handleBeaconResponse()
344{
345 if (logger_)
346 logger_->debug("Get beacon response from peer {}", deviceId);
347 beaconCounter_--;
348}
349
350bool
351MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
352{
353 std::error_code ec;
354 int wr = parent_.write(PROTOCOL_CHANNEL,
355 (const unsigned char*) buffer.data(),
356 buffer.size(),
357 ec);
358 return wr > 0;
359}
360
361void
362MultiplexedSocket::Impl::sendVersion()
363{
364 dht::ThreadPool::io().run([w = parent_.weak()]() {
365 if (auto shared = w.lock()) {
366 auto version = shared->pimpl_->version_;
367 msgpack::sbuffer buffer(8);
368 msgpack::packer<msgpack::sbuffer> pk(&buffer);
369 pk.pack(VersionMsg {version});
370 shared->pimpl_->writeProtocolMessage(buffer);
371 }
372 });
373}
374
375void
376MultiplexedSocket::Impl::onVersion(int version)
377{
378 // Check if version > 1
379 if (version >= 1) {
380 if (logger_)
381 logger_->debug("Peer {} supports beacon", deviceId);
382 canSendBeacon_ = true;
383 } else {
384 if (logger_)
385 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
386 deviceId,
387 version);
388 canSendBeacon_ = false;
389 }
390}
391
392void
393MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
394{
395 auto accept = onRequest_(endpoint->peerCertificate(), channel, name);
396 std::shared_ptr<ChannelSocket> channelSocket;
397 if (accept) {
398 std::lock_guard<std::mutex> lkSockets(socketsMutex);
399 channelSocket = makeSocket(name, channel);
400 }
401
402 // Answer to ChannelRequest if accepted
403 ChannelRequest val;
404 val.channel = channel;
405 val.name = name;
406 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
407 msgpack::sbuffer buffer(512);
408 msgpack::pack(buffer, val);
409 std::error_code ec;
410 int wr = parent_.write(CONTROL_CHANNEL,
411 reinterpret_cast<const uint8_t*>(buffer.data()),
412 buffer.size(),
413 ec);
414 if (wr < 0) {
415 if (ec && logger_)
416 logger_->error("The write operation failed with error: {:s}", ec.message());
417 stop.store(true);
418 return;
419 }
420
421 if (accept) {
422 onChannelReady_(deviceId, channelSocket);
423 channelSocket->ready();
424 }
425}
426
427void
428MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
429{
430 // Run this on dedicated thread because some callbacks can take time
431 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
432 auto shared = w.lock();
433 if (!shared)
434 return;
435 auto& pimpl = *shared->pimpl_;
436 try {
437 size_t off = 0;
438 while (off != pkt.size()) {
439 msgpack::unpacked result;
440 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
441 auto object = result.get();
442 if (pimpl.handleProtocolMsg(object))
443 continue;
444 auto req = object.as<ChannelRequest>();
445 if (req.state == ChannelRequestState::ACCEPT) {
446 pimpl.onAccept(req.name, req.channel);
447 } else if (req.state == ChannelRequestState::DECLINE) {
448 std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
449 auto channel = pimpl.sockets.find(req.channel);
450 if (channel != pimpl.sockets.end()) {
451 channel->second->stop();
452 pimpl.sockets.erase(channel);
453 }
454 } else if (pimpl.onRequest_) {
455 pimpl.onRequest(req.name, req.channel);
456 }
457 }
458 } catch (const std::exception& e) {
459 if (pimpl.logger_)
460 pimpl.logger_->error("Error on the control channel: {}", e.what());
461 }
462 });
463}
464
465void
466MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
467{
468 std::lock_guard<std::mutex> lkSockets(socketsMutex);
469 auto sockIt = sockets.find(channel);
470 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
471 if (pkt.size() == 0) {
472 sockIt->second->stop();
473 if (sockIt->second->isAnswered())
474 sockets.erase(sockIt);
475 else
476 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
477 // removed later.
478 } else {
479 sockIt->second->onRecv(std::move(pkt));
480 }
481 } else if (pkt.size() != 0) {
482 if (logger_)
483 logger_->warn("Non existing channel: {}", channel);
484 }
485}
486
487bool
488MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
489{
490 try {
491 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
492 auto key = o.via.map.ptr[0].key.as<std::string_view>();
493 if (key == "p") {
494 auto msg = o.as<BeaconMsg>();
495 if (msg.p)
496 handleBeaconRequest();
497 else
498 handleBeaconResponse();
499 if (onBeaconCb_)
500 onBeaconCb_(msg.p);
501 return true;
502 } else if (key == "v") {
503 auto msg = o.as<VersionMsg>();
504 onVersion(msg.v);
505 if (onVersionCb_)
506 onVersionCb_(msg.v);
507 return true;
508 } else {
509 if (logger_)
510 logger_->warn("Unknown message type");
511 }
512 }
513 } catch (const std::exception& e) {
514 if (logger_)
515 logger_->error("Error on the protocol channel: {}", e.what());
516 }
517 return false;
518}
519
520void
521MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
522{
523 // Run this on dedicated thread because some callbacks can take time
524 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
525 auto shared = w.lock();
526 if (!shared)
527 return;
528 try {
529 size_t off = 0;
530 while (off != pkt.size()) {
531 msgpack::unpacked result;
532 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
533 auto object = result.get();
534 if (shared->pimpl_->handleProtocolMsg(object))
535 return;
536 }
537 } catch (const std::exception& e) {
538 if (shared->pimpl_->logger_)
539 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
540 }
541 });
542}
543
544MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
545 std::unique_ptr<TlsSocketEndpoint> endpoint)
546 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint)))
547{}
548
549MultiplexedSocket::~MultiplexedSocket() {}
550
551std::shared_ptr<ChannelSocket>
552MultiplexedSocket::addChannel(const std::string& name)
553{
554 // Note: because both sides can request the same channel number at the same time
555 // it's better to use a random channel number instead of just incrementing the request.
556 thread_local dht::crypto::random_device rd;
557 std::uniform_int_distribution<uint16_t> dist;
558 auto offset = dist(rd);
559 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
560 for (int i = 1; i < UINT16_MAX; ++i) {
561 auto c = (offset + i) % UINT16_MAX;
562 if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL
563 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
564 continue;
565 auto channel = pimpl_->makeSocket(name, c, true);
566 return channel;
567 }
568 return {};
569}
570
571DeviceId
572MultiplexedSocket::deviceId() const
573{
574 return pimpl_->deviceId;
575}
576
577void
578MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
579{
580 pimpl_->onChannelReady_ = std::move(cb);
581}
582
583void
584MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
585{
586 pimpl_->onRequest_ = std::move(cb);
587}
588
589bool
590MultiplexedSocket::isReliable() const
591{
592 return true;
593}
594
595bool
596MultiplexedSocket::isInitiator() const
597{
598 if (!pimpl_->endpoint) {
599 if (pimpl_->logger_)
600 pimpl_->logger_->warn("No endpoint found for socket");
601 return false;
602 }
603 return pimpl_->endpoint->isInitiator();
604}
605
606int
607MultiplexedSocket::maxPayload() const
608{
609 if (!pimpl_->endpoint) {
610 if (pimpl_->logger_)
611 pimpl_->logger_->warn("No endpoint found for socket");
612 return 0;
613 }
614 return pimpl_->endpoint->maxPayload();
615}
616
617std::size_t
618MultiplexedSocket::write(const uint16_t& channel,
619 const uint8_t* buf,
620 std::size_t len,
621 std::error_code& ec)
622{
623 assert(nullptr != buf);
624
625 if (pimpl_->isShutdown_) {
626 ec = std::make_error_code(std::errc::broken_pipe);
627 return -1;
628 }
629 if (len > UINT16_MAX) {
630 ec = std::make_error_code(std::errc::message_size);
631 return -1;
632 }
633 bool oneShot = len < 8192;
634 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
635 msgpack::packer<msgpack::sbuffer> pk(&buffer);
636 pk.pack_array(2);
637 pk.pack(channel);
638 pk.pack_bin(len);
639 if (oneShot)
640 pk.pack_bin_body((const char*) buf, len);
641
642 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
643 if (!pimpl_->endpoint) {
644 if (pimpl_->logger_)
645 pimpl_->logger_->warn("No endpoint found for socket");
646 ec = std::make_error_code(std::errc::broken_pipe);
647 return -1;
648 }
649 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
650 if (not oneShot and res >= 0)
651 res = pimpl_->endpoint->write(buf, len, ec);
652 lk.unlock();
653 if (res < 0) {
654 if (ec && pimpl_->logger_)
655 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
656 shutdown();
657 }
658 return res;
659}
660
661void
662MultiplexedSocket::shutdown()
663{
664 pimpl_->shutdown();
665}
666
667void
668MultiplexedSocket::join()
669{
670 pimpl_->join();
671}
672
673void
674MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
675{
676 pimpl_->onShutdown_ = std::move(cb);
677 if (pimpl_->isShutdown_)
678 pimpl_->onShutdown_();
679}
680
681const std::shared_ptr<Logger>&
682MultiplexedSocket::logger()
683{
684 return pimpl_->logger_;
685}
686
687void
688MultiplexedSocket::monitor() const
689{
690 auto cert = peerCertificate();
691 if (!cert || !cert->issuer)
692 return;
693 auto now = clock::now();
694 if (!pimpl_->logger_)
695 return;
696 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
697 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
698 pimpl_->endpoint->monitor();
699 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
700 for (const auto& [_, channel] : pimpl_->sockets) {
701 if (channel)
702 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
703 fmt::ptr(channel.get()),
704 channel.use_count(),
705 channel->name(),
706 channel->isInitiator());
707 }
708}
709
710void
711MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
712{
713 pimpl_->sendBeacon(timeout);
714}
715
716std::shared_ptr<dht::crypto::Certificate>
717MultiplexedSocket::peerCertificate() const
718{
719 return pimpl_->endpoint->peerCertificate();
720}
721
722#ifdef LIBJAMI_TESTABLE
723bool
724MultiplexedSocket::canSendBeacon() const
725{
726 return pimpl_->canSendBeacon_;
727}
728
729void
730MultiplexedSocket::answerToBeacon(bool value)
731{
732 pimpl_->answerBeacon_ = value;
733}
734
735void
736MultiplexedSocket::setVersion(int version)
737{
738 pimpl_->version_ = version;
739}
740
741void
742MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
743{
744 pimpl_->onBeaconCb_ = cb;
745}
746
747void
748MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
749{
750 pimpl_->onVersionCb_ = cb;
751}
752
753void
754MultiplexedSocket::sendVersion()
755{
756 pimpl_->sendVersion();
757}
758
759IpAddr
760MultiplexedSocket::getLocalAddress() const
761{
762 return pimpl_->endpoint->getLocalAddress();
763}
764
765IpAddr
766MultiplexedSocket::getRemoteAddress() const
767{
768 return pimpl_->endpoint->getRemoteAddress();
769}
770
771#endif
772
773void
774MultiplexedSocket::eraseChannel(uint16_t channel)
775{
776 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
777 auto itSocket = pimpl_->sockets.find(channel);
778 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
779 pimpl_->sockets.erase(itSocket);
780}
781
782////////////////////////////////////////////////////////////////
783
784class ChannelSocket::Impl
785{
786public:
787 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
788 const std::string& name,
789 const uint16_t& channel,
790 bool isInitiator,
791 std::function<void()> rmFromMxSockCb)
792 : name(name)
793 , channel(channel)
794 , endpoint(std::move(endpoint))
795 , isInitiator_(isInitiator)
796 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
797 {}
798
799 ~Impl() {}
800
801 ChannelReadyCb readyCb_ {};
802 OnShutdownCb shutdownCb_ {};
803 std::atomic_bool isShutdown_ {false};
804 std::string name {};
805 uint16_t channel {};
806 std::weak_ptr<MultiplexedSocket> endpoint {};
807 bool isInitiator_ {false};
808 std::function<void()> rmFromMxSockCb_;
809
810 bool isAnswered_ {false};
811 bool isRemovable_ {false};
812
813 std::vector<uint8_t> buf {};
814 std::mutex mutex {};
815 std::condition_variable cv {};
816 GenericSocket<uint8_t>::RecvCb cb {};
817};
818
819ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
820 const DeviceId& deviceId,
821 const std::string& name,
822 const uint16_t& channel)
823 : pimpl_deviceId(deviceId)
824 , pimpl_name(name)
825 , pimpl_channel(channel)
826 , ioCtx_(*ctx)
827{}
828
829ChannelSocketTest::~ChannelSocketTest() {}
830
831void
832ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
833 const std::shared_ptr<ChannelSocketTest>& socket2)
834{
835 socket1->remote = socket2;
836 socket2->remote = socket1;
837}
838
839DeviceId
840ChannelSocketTest::deviceId() const
841{
842 return pimpl_deviceId;
843}
844
845std::string
846ChannelSocketTest::name() const
847{
848 return pimpl_name;
849}
850
851uint16_t
852ChannelSocketTest::channel() const
853{
854 return pimpl_channel;
855}
856
857void
858ChannelSocketTest::shutdown()
859{
860 {
861 std::unique_lock<std::mutex> lk {mutex};
862 if (!isShutdown_.exchange(true)) {
863 lk.unlock();
864 shutdownCb_();
865 }
866 cv.notify_all();
867 }
868
869 if (auto peer = remote.lock()) {
870 if (!peer->isShutdown_.exchange(true)) {
871 peer->shutdownCb_();
872 }
873 peer->cv.notify_all();
874 }
875}
876
877std::size_t
878ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
879{
880 std::size_t size = std::min(len, this->rx_buf.size());
881
882 for (std::size_t i = 0; i < size; ++i)
883 buf[i] = this->rx_buf[i];
884
885 if (size == this->rx_buf.size()) {
886 this->rx_buf.clear();
887 } else
888 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
889 return size;
890}
891
892std::size_t
893ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
894{
895 if (isShutdown_) {
896 ec = std::make_error_code(std::errc::broken_pipe);
897 return -1;
898 }
899 ec = {};
900 dht::ThreadPool::computation().run(
901 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
902 if (auto peer = r.lock())
903 peer->onRecv(std::move(data));
904 });
905 return len;
906}
907
908int
909ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
910{
911 std::unique_lock<std::mutex> lk {mutex};
912 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
913 return rx_buf.size();
914}
915
916void
917ChannelSocketTest::setOnRecv(RecvCb&& cb)
918{
919 std::lock_guard<std::mutex> lkSockets(mutex);
920 this->cb = std::move(cb);
921 if (!rx_buf.empty() && this->cb) {
922 this->cb(rx_buf.data(), rx_buf.size());
923 rx_buf.clear();
924 }
925}
926
927void
928ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
929{
930 std::lock_guard<std::mutex> lkSockets(mutex);
931 if (cb) {
932 cb(pkt.data(), pkt.size());
933 return;
934 }
935 rx_buf.insert(rx_buf.end(),
936 std::make_move_iterator(pkt.begin()),
937 std::make_move_iterator(pkt.end()));
938 cv.notify_all();
939}
940
941void
942ChannelSocketTest::onReady(ChannelReadyCb&& cb)
943{}
944
945void
946ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
947{
948 std::unique_lock<std::mutex> lk {mutex};
949 shutdownCb_ = std::move(cb);
950
951 if (isShutdown_) {
952 lk.unlock();
953 shutdownCb_();
954 }
955}
956
957ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
958 const std::string& name,
959 const uint16_t& channel,
960 bool isInitiator,
961 std::function<void()> rmFromMxSockCb)
962 : pimpl_ {
963 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
964{}
965
966ChannelSocket::~ChannelSocket() {}
967
968DeviceId
969ChannelSocket::deviceId() const
970{
971 if (auto ep = pimpl_->endpoint.lock()) {
972 return ep->deviceId();
973 }
974 return {};
975}
976
977std::string
978ChannelSocket::name() const
979{
980 return pimpl_->name;
981}
982
983uint16_t
984ChannelSocket::channel() const
985{
986 return pimpl_->channel;
987}
988
989bool
990ChannelSocket::isReliable() const
991{
992 if (auto ep = pimpl_->endpoint.lock()) {
993 return ep->isReliable();
994 }
995 return false;
996}
997
998bool
999ChannelSocket::isInitiator() const
1000{
1001 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1002 // because a multiplexed socket can have sockets from accepted requests
1003 // or made via connectDevice(). Here, isInitiator_ return if the socket
1004 // is from connectDevice.
1005 return pimpl_->isInitiator_;
1006}
1007
1008int
1009ChannelSocket::maxPayload() const
1010{
1011 if (auto ep = pimpl_->endpoint.lock()) {
1012 return ep->maxPayload();
1013 }
1014 return -1;
1015}
1016
1017void
1018ChannelSocket::setOnRecv(RecvCb&& cb)
1019{
1020 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1021 pimpl_->cb = std::move(cb);
1022 if (!pimpl_->buf.empty() && pimpl_->cb) {
1023 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1024 pimpl_->buf.clear();
1025 }
1026}
1027
1028void
1029ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1030{
1031 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1032 if (pimpl_->cb) {
1033 pimpl_->cb(&pkt[0], pkt.size());
1034 return;
1035 }
1036 pimpl_->buf.insert(pimpl_->buf.end(),
1037 std::make_move_iterator(pkt.begin()),
1038 std::make_move_iterator(pkt.end()));
1039 pimpl_->cv.notify_all();
1040}
1041
#ifdef LIBJAMI_TESTABLE
/** The owning multiplexer, or an empty pointer once it is gone. */
std::shared_ptr<MultiplexedSocket>
ChannelSocket::underlyingSocket() const
{
    return pimpl_->endpoint.lock();
}
#endif
1051
1052void
1053ChannelSocket::answered()
1054{
1055 pimpl_->isAnswered_ = true;
1056}
1057
1058void
1059ChannelSocket::removable()
1060{
1061 pimpl_->isRemovable_ = true;
1062}
1063
1064bool
1065ChannelSocket::isRemovable() const
1066{
1067 return pimpl_->isRemovable_;
1068}
1069
1070bool
1071ChannelSocket::isAnswered() const
1072{
1073 return pimpl_->isAnswered_;
1074}
1075
1076void
1077ChannelSocket::ready()
1078{
1079 if (pimpl_->readyCb_)
1080 pimpl_->readyCb_();
1081}
1082
1083void
1084ChannelSocket::stop()
1085{
1086 if (pimpl_->isShutdown_)
1087 return;
1088 pimpl_->isShutdown_ = true;
1089 if (pimpl_->shutdownCb_)
1090 pimpl_->shutdownCb_();
1091 pimpl_->cv.notify_all();
1092 // stop() can be called by ChannelSocket::shutdown()
1093 // In this case, the eventLoop is not used, but MxSock
1094 // must remove the channel from its list (so that the
1095 // channel can be destroyed and its shared_ptr invalidated).
1096 if (pimpl_->rmFromMxSockCb_)
1097 pimpl_->rmFromMxSockCb_();
1098}
1099
1100void
1101ChannelSocket::shutdown()
1102{
1103 if (pimpl_->isShutdown_)
1104 return;
1105 stop();
1106 if (auto ep = pimpl_->endpoint.lock()) {
1107 std::error_code ec;
1108 const uint8_t dummy = '\0';
1109 ep->write(pimpl_->channel, &dummy, 0, ec);
1110 }
1111}
1112
1113std::size_t
1114ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1115{
1116 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1117 std::size_t size = std::min(len, pimpl_->buf.size());
1118
1119 for (std::size_t i = 0; i < size; ++i)
1120 outBuf[i] = pimpl_->buf[i];
1121
1122 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1123 return size;
1124}
1125
1126std::size_t
1127ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1128{
1129 if (pimpl_->isShutdown_) {
1130 ec = std::make_error_code(std::errc::broken_pipe);
1131 return -1;
1132 }
1133 if (auto ep = pimpl_->endpoint.lock()) {
1134 std::size_t sent = 0;
1135 do {
1136 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1137 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1138 if (ec) {
1139 if (ep->logger())
1140 ep->logger()->error("Error when writing on channel: {}", ec.message());
1141 return res;
1142 }
1143 sent += toSend;
1144 } while (sent < len);
1145 return sent;
1146 }
1147 ec = std::make_error_code(std::errc::broken_pipe);
1148 return -1;
1149}
1150
1151int
1152ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1153{
1154 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1155 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1156 return pimpl_->buf.size();
1157}
1158
1159void
1160ChannelSocket::onShutdown(OnShutdownCb&& cb)
1161{
1162 pimpl_->shutdownCb_ = std::move(cb);
1163 if (pimpl_->isShutdown_) {
1164 pimpl_->shutdownCb_();
1165 }
1166}
1167
1168void
1169ChannelSocket::onReady(ChannelReadyCb&& cb)
1170{
1171 pimpl_->readyCb_ = std::move(cb);
1172}
1173
1174void
1175ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1176{
1177 if (auto ep = pimpl_->endpoint.lock()) {
1178 ep->sendBeacon(timeout);
1179 } else {
1180 shutdown();
1181 }
1182}
1183
1184std::shared_ptr<dht::crypto::Certificate>
1185ChannelSocket::peerCertificate() const
1186{
1187 if (auto ep = pimpl_->endpoint.lock())
1188 return ep->peerCertificate();
1189 return {};
1190}
1191
1192IpAddr
1193ChannelSocket::getLocalAddress() const
1194{
1195 if (auto ep = pimpl_->endpoint.lock())
1196 return ep->getLocalAddress();
1197 return {};
1198}
1199
1200IpAddr
1201ChannelSocket::getRemoteAddress() const
1202{
1203 if (auto ep = pimpl_->endpoint.lock())
1204 return ep->getRemoteAddress();
1205 return {};
1206}
1207
1208} // namespace jami