blob: 2079df5e0249c9f9d2985d95ee21b4105ef2e5f5 [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
2 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
3 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 */
17#pragma once
18
19#include "ip_utils.h"
20#include "generic_io.h"
21
22#include <opendht/default_types.h>
23#include <condition_variable>
24
25#include <cstdint>
26
27namespace asio {
28class io_context;
29}
30
31namespace dht {
32namespace log {
Adrien Béraud9132a812023-07-21 11:20:40 -040033struct Logger;
Adrien Béraud612b55b2023-05-29 10:42:04 -040034}
35}
36
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040037namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040038
39using Logger = dht::log::Logger;
40class IceTransport;
41class ChannelSocket;
42class TlsSocketEndpoint;
43
44using DeviceId = dht::PkId;
45using OnConnectionRequestCb
46 = std::function<bool(const std::shared_ptr<dht::crypto::Certificate>& /* peer */,
47 const uint16_t& /* id */,
48 const std::string& /* name */)>;
49using OnConnectionReadyCb
50 = std::function<void(const DeviceId& /* deviceId */, const std::shared_ptr<ChannelSocket>&)>;
Adrien Béraudc5b971d2023-06-13 19:41:25 -040051using ChannelReadyCb = std::function<void(bool)>;
Adrien Béraud612b55b2023-05-29 10:42:04 -040052using OnShutdownCb = std::function<void(void)>;
53
/// Default timeout used by sendBeacon() when the caller does not supply one (3 seconds).
static constexpr std::chrono::milliseconds SEND_BEACON_TIMEOUT {3000};
// Reserved channel ids at both ends of the uint16_t range.
// NOTE(review): consider `inline constexpr` (C++17) so these have a single definition across TUs.
static constexpr uint16_t CONTROL_CHANNEL = 0;
static constexpr uint16_t PROTOCOL_CHANNEL = 0xFFFF;
57
/**
 * State carried by ChannelRequest::state.
 * This enum is registered with msgpack via MSGPACK_ADD_ENUM, so the numeric values are
 * spelled out explicitly: they are what goes on the wire and must not change.
 */
enum class ChannelRequestState {
    REQUEST = 0,
    ACCEPT = 1,
    DECLINE = 2
};
63
/**
 * Msgpack message used to negotiate a channel (name, id) with the peer.
 * Transmitted over the TLS socket.
 *
 * The argument order of MSGPACK_DEFINE fixes the serialized layout; reordering
 * the fields would change the wire format seen by peers.
 */
struct ChannelRequest
{
    std::string name {};   // channel name
    uint16_t channel {0};  // channel id
    ChannelRequestState state {ChannelRequestState::REQUEST}; // defaults to REQUEST; see ChannelRequestState
    MSGPACK_DEFINE(name, channel, state)
};
75
76/**
77 * A socket divided in channels over a TLS session
78 */
79class MultiplexedSocket : public std::enable_shared_from_this<MultiplexedSocket>
80{
81public:
Adrien Béraud5636f7c2023-09-14 14:34:57 -040082 MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, std::unique_ptr<TlsSocketEndpoint> endpoint, std::shared_ptr<dht::log::Logger> logger = {});
Adrien Béraud612b55b2023-05-29 10:42:04 -040083 ~MultiplexedSocket();
84 std::shared_ptr<ChannelSocket> addChannel(const std::string& name);
85
86 std::shared_ptr<MultiplexedSocket> shared()
87 {
88 return std::static_pointer_cast<MultiplexedSocket>(shared_from_this());
89 }
90 std::shared_ptr<MultiplexedSocket const> shared() const
91 {
92 return std::static_pointer_cast<MultiplexedSocket const>(shared_from_this());
93 }
94 std::weak_ptr<MultiplexedSocket> weak()
95 {
96 return std::static_pointer_cast<MultiplexedSocket>(shared_from_this());
97 }
98 std::weak_ptr<MultiplexedSocket const> weak() const
99 {
100 return std::static_pointer_cast<MultiplexedSocket const>(shared_from_this());
101 }
102
103 DeviceId deviceId() const;
104 bool isReliable() const;
105 bool isInitiator() const;
106 int maxPayload() const;
107
108 /**
109 * Will be triggered when a new channel is ready
110 */
111 void setOnReady(OnConnectionReadyCb&& cb);
112 /**
113 * Will be triggered when the peer asks for a new channel
114 */
115 void setOnRequest(OnConnectionRequestCb&& cb);
116
117 std::size_t write(const uint16_t& channel,
118 const uint8_t* buf,
119 std::size_t len,
120 std::error_code& ec);
121
122 /**
123 * This will close all channels and send a TLS EOF on the main socket.
124 */
125 void shutdown();
126
127 /**
128 * This will wait that eventLoop is stopped and stop it if necessary
129 */
130 void join();
131
132 /**
133 * Will trigger that callback when shutdown() is called
134 */
135 void onShutdown(OnShutdownCb&& cb);
136
137 /**
138 * Get informations from socket (channels opened)
139 */
140 void monitor() const;
141
142 const std::shared_ptr<Logger>& logger();
143
144 /**
Amna31791e52023-08-03 12:40:57 -0400145 * Get the list of channels
146 */
147 std::vector<std::map<std::string, std::string>> getChannelList() const;
148
149 /**
Adrien Béraud612b55b2023-05-29 10:42:04 -0400150 * Send a beacon on the socket and close if no response come
151 * @param timeout
152 */
153 void sendBeacon(const std::chrono::milliseconds& timeout = SEND_BEACON_TIMEOUT);
154
155 /**
156 * Get peer's certificate
157 */
158 std::shared_ptr<dht::crypto::Certificate> peerCertificate() const;
159
160 IpAddr getLocalAddress() const;
161 IpAddr getRemoteAddress() const;
162
163 void eraseChannel(uint16_t channel);
164
Adrien Béraudafa8e282023-09-24 12:53:20 -0400165 TlsSocketEndpoint* endpoint();
166
Adrien Béraud6b6a5d32023-08-15 15:53:33 -0400167#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -0400168 /**
169 * Check if we can send beacon on the socket
170 */
171 bool canSendBeacon() const;
172
173 /**
174 * Decide if yes or not we answer to beacon
175 * @param value New value
176 */
177 void answerToBeacon(bool value);
178
179 /**
180 * Change version sent to the peer
181 */
182 void setVersion(int version);
183
184 /**
185 * Set a callback to detect beacon messages
186 */
187 void setOnBeaconCb(const std::function<void(bool)>& cb);
188
189 /**
190 * Set a callback to detect version messages
191 */
192 void setOnVersionCb(const std::function<void(int)>& cb);
193
194 /**
195 * Send the version
196 */
197 void sendVersion();
198#endif
199
200private:
201 class Impl;
202 std::unique_ptr<Impl> pimpl_;
203};
204
/**
 * Abstract channel: a GenericSocket<uint8_t> bound to a device, a name and a channel id.
 * Implemented in this header by ChannelSocket and ChannelSocketTest.
 * NOTE: declaration order of the virtual functions determines the vtable layout; keep it stable.
 */
class ChannelSocketInterface : public GenericSocket<uint8_t>
{
public:
    using SocketType = GenericSocket<uint8_t>;

    /** Device this channel belongs to. */
    virtual DeviceId deviceId() const = 0;
    /** Channel name. */
    virtual std::string name() const = 0;
    /** Channel id. */
    virtual uint16_t channel() const = 0;
    /**
     * Triggered when a specific channel is ready
     * Used by ConnectionManager::connectDevice()
     */
    virtual void onReady(ChannelReadyCb&& cb) = 0;
    /**
     * Will trigger that callback when shutdown() is called
     */
    virtual void onShutdown(OnShutdownCb&& cb) = 0;

    /** Delivers a received packet to this channel (takes ownership of @p pkt). */
    virtual void onRecv(std::vector<uint8_t>&& pkt) = 0;
};
225
226class ChannelSocketTest : public ChannelSocketInterface
227{
228public:
229 ChannelSocketTest(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, const std::string& name, const uint16_t& channel);
230 ~ChannelSocketTest();
231
232 static void link(const std::shared_ptr<ChannelSocketTest>& socket1,
233 const std::shared_ptr<ChannelSocketTest>& socket2);
234
235 DeviceId deviceId() const override;
236 std::string name() const override;
237 uint16_t channel() const override;
238
239 bool isReliable() const override { return true; };
240 bool isInitiator() const override { return true; };
241 int maxPayload() const override { return 0; };
242
243 void shutdown() override;
244
245 std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override;
246 std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
247 int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override;
248 void setOnRecv(RecvCb&&) override;
249 void onRecv(std::vector<uint8_t>&& pkt) override;
250
251 /**
252 * Triggered when a specific channel is ready
253 * Used by ConnectionManager::connectDevice()
254 */
255 void onReady(ChannelReadyCb&& cb) override;
256 /**
257 * Will trigger that callback when shutdown() is called
258 */
259 void onShutdown(OnShutdownCb&& cb) override;
260
261 std::vector<uint8_t> rx_buf {};
262 mutable std::mutex mutex {};
263 mutable std::condition_variable cv {};
264 GenericSocket<uint8_t>::RecvCb cb {};
265
266private:
267 const DeviceId pimpl_deviceId;
268 const std::string pimpl_name;
269 const uint16_t pimpl_channel;
270 asio::io_context& ioCtx_;
271 std::weak_ptr<ChannelSocketTest> remote;
272 OnShutdownCb shutdownCb_ {[&] {
273 }};
274 std::atomic_bool isShutdown_ {false};
275};
276
/**
 * Represents a channel of the multiplexed socket (channel, name).
 * NOTE: declaration order of the virtual overrides affects the vtable layout; keep it stable.
 */
class ChannelSocket : public ChannelSocketInterface
{
public:
    /**
     * @param endpoint        owning multiplexed socket, held weakly
     * @param name            channel name
     * @param channel         channel id
     * @param isInitiator     whether this side initiated the channel
     * @param rmFromMxSockCb  presumably invoked to remove this channel from its MultiplexedSocket — confirm
     */
    ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
                  const std::string& name,
                  const uint16_t& channel,
                  bool isInitiator = false,
                  std::function<void()> rmFromMxSockCb = {});
    ~ChannelSocket();

    DeviceId deviceId() const override;
    std::string name() const override;
    uint16_t channel() const override;
    bool isReliable() const override;
    bool isInitiator() const override;
    int maxPayload() const override;
    /**
     * Like shutdown(), but does not send any packet on the socket.
     * Used by the multiplexed socket when the TLS endpoint is already shutting down.
     */
    void stop();

    /**
     * Sends an empty buffer as a packet (equivalent to EOF).
     * Triggers the onShutdown callback.
     */
    void shutdown() override;

    /** Marks the channel as ready; @p accepted tells whether the request was accepted. */
    void ready(bool accepted);
    /**
     * Triggered when a specific channel is ready.
     * Used by ConnectionManager::connectDevice().
     */
    void onReady(ChannelReadyCb&& cb) override;
    /**
     * Sets the callback triggered when shutdown() is called.
     */
    void onShutdown(OnShutdownCb&& cb) override;

    std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override;
    /**
     * @note len should be < UINT8_MAX, else you will get ec = EMSGSIZE
     * NOTE(review): confirm this limit against the implementation (channel ids are uint16_t,
     *               and maxPayload() exists) — the bound may be out of date.
     */
    std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
    int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override;

    /**
     * Sets a callback invoked when data is received.
     * @note The callback should be quick and must not block;
     *       move heavy work to another thread if needed.
     */
    void setOnRecv(RecvCb&&) override;

    void onRecv(std::vector<uint8_t>&& pkt) override;

    /**
     * Sends a beacon on the socket and closes it if no response comes back.
     * @param timeout  how long to wait for the response
     */
    void sendBeacon(const std::chrono::milliseconds& timeout = SEND_BEACON_TIMEOUT);

    /**
     * Returns the peer's certificate.
     */
    std::shared_ptr<dht::crypto::Certificate> peerCertificate() const;

#ifdef DHTNET_TESTABLE
    std::shared_ptr<MultiplexedSocket> underlyingSocket() const;
#endif

    // Note: once a channel is accepted, it can receive data immediately and will be removed
    // when finished. However, onAccept runs in its own thread because of the callbacks, so in
    // that case the channel must be deleted in onAccept. The flags below track this state.
    void answered();
    bool isAnswered() const;
    void removable();
    bool isRemovable() const;

    IpAddr getLocalAddress() const;
    IpAddr getRemoteAddress() const;

private:
    class Impl;
    std::unique_ptr<Impl> pimpl_;
};
365
Sébastien Blin464bdff2023-07-19 08:02:53 -0400366} // namespace dhtnet
Adrien Béraud612b55b2023-05-29 10:42:04 -0400367
// Registers ChannelRequestState with msgpack so ChannelRequest::state can be serialized.
// Placed at global scope, outside namespace dhtnet, as msgpack-c's adaptor macros expect.
MSGPACK_ADD_ENUM(dhtnet::ChannelRequestState);