blob: a090a8c9a03830dcf18b27af054c7c8a5f44102e [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
2 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
3 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <https://www.gnu.org/licenses/>.
16 */
#pragma once

#include "ip_utils.h"
#include "generic_io.h"

#include <opendht/default_types.h>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
26
// Forward declarations: this header only uses these types through
// std::shared_ptr / references, so their full definitions are not needed here.
namespace asio {
class io_context;
} // namespace asio

namespace dht::log {
struct Logger;
} // namespace dht::log
36
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040037namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040038
39using Logger = dht::log::Logger;
40class IceTransport;
41class ChannelSocket;
42class TlsSocketEndpoint;
43
44using DeviceId = dht::PkId;
45using OnConnectionRequestCb
46 = std::function<bool(const std::shared_ptr<dht::crypto::Certificate>& /* peer */,
47 const uint16_t& /* id */,
48 const std::string& /* name */)>;
49using OnConnectionReadyCb
50 = std::function<void(const DeviceId& /* deviceId */, const std::shared_ptr<ChannelSocket>&)>;
Adrien Béraudc5b971d2023-06-13 19:41:25 -040051using ChannelReadyCb = std::function<void(bool)>;
Adrien Béraud612b55b2023-05-29 10:42:04 -040052using OnShutdownCb = std::function<void(void)>;
53
// Header-scope constants are `inline constexpr` (C++17) rather than `static constexpr`,
// so every translation unit shares a single definition instead of getting its own copy.

/** Default time sendBeacon() waits for an answer before closing the socket. */
inline constexpr auto SEND_BEACON_TIMEOUT = std::chrono::milliseconds(3000);
/** Channel id 0, named CONTROL_CHANNEL. */
inline constexpr uint16_t CONTROL_CHANNEL {0};
/** Channel id 0xffff, named PROTOCOL_CHANNEL. */
inline constexpr uint16_t PROTOCOL_CHANNEL {0xffff};
57
/**
 * State carried by a ChannelRequest.
 * The enum is registered with msgpack (MSGPACK_ADD_ENUM at the end of this header),
 * so the numeric values are part of the wire format and are spelled out explicitly.
 */
enum class ChannelRequestState {
    REQUEST = 0,
    ACCEPT = 1,
    DECLINE = 2,
};
63
64/**
65 * That msgpack structure is used to request a new channel (id, name)
66 * Transmitted over the TLS socket
67 */
68struct ChannelRequest
69{
70 std::string name {};
71 uint16_t channel {0};
72 ChannelRequestState state {ChannelRequestState::REQUEST};
73 MSGPACK_DEFINE(name, channel, state)
74};
75
76/**
77 * A socket divided in channels over a TLS session
78 */
79class MultiplexedSocket : public std::enable_shared_from_this<MultiplexedSocket>
80{
81public:
82 MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, std::unique_ptr<TlsSocketEndpoint> endpoint);
83 ~MultiplexedSocket();
84 std::shared_ptr<ChannelSocket> addChannel(const std::string& name);
85
86 std::shared_ptr<MultiplexedSocket> shared()
87 {
88 return std::static_pointer_cast<MultiplexedSocket>(shared_from_this());
89 }
90 std::shared_ptr<MultiplexedSocket const> shared() const
91 {
92 return std::static_pointer_cast<MultiplexedSocket const>(shared_from_this());
93 }
94 std::weak_ptr<MultiplexedSocket> weak()
95 {
96 return std::static_pointer_cast<MultiplexedSocket>(shared_from_this());
97 }
98 std::weak_ptr<MultiplexedSocket const> weak() const
99 {
100 return std::static_pointer_cast<MultiplexedSocket const>(shared_from_this());
101 }
102
103 DeviceId deviceId() const;
104 bool isReliable() const;
105 bool isInitiator() const;
106 int maxPayload() const;
107
108 /**
109 * Will be triggered when a new channel is ready
110 */
111 void setOnReady(OnConnectionReadyCb&& cb);
112 /**
113 * Will be triggered when the peer asks for a new channel
114 */
115 void setOnRequest(OnConnectionRequestCb&& cb);
116
117 std::size_t write(const uint16_t& channel,
118 const uint8_t* buf,
119 std::size_t len,
120 std::error_code& ec);
121
122 /**
123 * This will close all channels and send a TLS EOF on the main socket.
124 */
125 void shutdown();
126
127 /**
128 * This will wait that eventLoop is stopped and stop it if necessary
129 */
130 void join();
131
132 /**
133 * Will trigger that callback when shutdown() is called
134 */
135 void onShutdown(OnShutdownCb&& cb);
136
137 /**
138 * Get informations from socket (channels opened)
139 */
140 void monitor() const;
141
142 const std::shared_ptr<Logger>& logger();
143
144 /**
Amna31791e52023-08-03 12:40:57 -0400145 * Get the list of channels
146 */
147 std::vector<std::map<std::string, std::string>> getChannelList() const;
148
149 /**
Adrien Béraud612b55b2023-05-29 10:42:04 -0400150 * Send a beacon on the socket and close if no response come
151 * @param timeout
152 */
153 void sendBeacon(const std::chrono::milliseconds& timeout = SEND_BEACON_TIMEOUT);
154
155 /**
156 * Get peer's certificate
157 */
158 std::shared_ptr<dht::crypto::Certificate> peerCertificate() const;
159
160 IpAddr getLocalAddress() const;
161 IpAddr getRemoteAddress() const;
162
163 void eraseChannel(uint16_t channel);
164
Adrien Béraud6b6a5d32023-08-15 15:53:33 -0400165#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -0400166 /**
167 * Check if we can send beacon on the socket
168 */
169 bool canSendBeacon() const;
170
171 /**
172 * Decide if yes or not we answer to beacon
173 * @param value New value
174 */
175 void answerToBeacon(bool value);
176
177 /**
178 * Change version sent to the peer
179 */
180 void setVersion(int version);
181
182 /**
183 * Set a callback to detect beacon messages
184 */
185 void setOnBeaconCb(const std::function<void(bool)>& cb);
186
187 /**
188 * Set a callback to detect version messages
189 */
190 void setOnVersionCb(const std::function<void(int)>& cb);
191
192 /**
193 * Send the version
194 */
195 void sendVersion();
196#endif
197
198private:
199 class Impl;
200 std::unique_ptr<Impl> pimpl_;
201};
202
203class ChannelSocketInterface : public GenericSocket<uint8_t>
204{
205public:
206 using SocketType = GenericSocket<uint8_t>;
207
208 virtual DeviceId deviceId() const = 0;
209 virtual std::string name() const = 0;
210 virtual uint16_t channel() const = 0;
211 /**
212 * Triggered when a specific channel is ready
213 * Used by ConnectionManager::connectDevice()
214 */
215 virtual void onReady(ChannelReadyCb&& cb) = 0;
216 /**
217 * Will trigger that callback when shutdown() is called
218 */
219 virtual void onShutdown(OnShutdownCb&& cb) = 0;
220
221 virtual void onRecv(std::vector<uint8_t>&& pkt) = 0;
222};
223
224class ChannelSocketTest : public ChannelSocketInterface
225{
226public:
227 ChannelSocketTest(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId, const std::string& name, const uint16_t& channel);
228 ~ChannelSocketTest();
229
230 static void link(const std::shared_ptr<ChannelSocketTest>& socket1,
231 const std::shared_ptr<ChannelSocketTest>& socket2);
232
233 DeviceId deviceId() const override;
234 std::string name() const override;
235 uint16_t channel() const override;
236
237 bool isReliable() const override { return true; };
238 bool isInitiator() const override { return true; };
239 int maxPayload() const override { return 0; };
240
241 void shutdown() override;
242
243 std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override;
244 std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
245 int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override;
246 void setOnRecv(RecvCb&&) override;
247 void onRecv(std::vector<uint8_t>&& pkt) override;
248
249 /**
250 * Triggered when a specific channel is ready
251 * Used by ConnectionManager::connectDevice()
252 */
253 void onReady(ChannelReadyCb&& cb) override;
254 /**
255 * Will trigger that callback when shutdown() is called
256 */
257 void onShutdown(OnShutdownCb&& cb) override;
258
259 std::vector<uint8_t> rx_buf {};
260 mutable std::mutex mutex {};
261 mutable std::condition_variable cv {};
262 GenericSocket<uint8_t>::RecvCb cb {};
263
264private:
265 const DeviceId pimpl_deviceId;
266 const std::string pimpl_name;
267 const uint16_t pimpl_channel;
268 asio::io_context& ioCtx_;
269 std::weak_ptr<ChannelSocketTest> remote;
270 OnShutdownCb shutdownCb_ {[&] {
271 }};
272 std::atomic_bool isShutdown_ {false};
273};
274
275/**
276 * Represents a channel of the multiplexed socket (channel, name)
277 */
278class ChannelSocket : public ChannelSocketInterface
279{
280public:
281 ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
282 const std::string& name,
283 const uint16_t& channel,
284 bool isInitiator = false,
285 std::function<void()> rmFromMxSockCb = {});
286 ~ChannelSocket();
287
288 DeviceId deviceId() const override;
289 std::string name() const override;
290 uint16_t channel() const override;
291 bool isReliable() const override;
292 bool isInitiator() const override;
293 int maxPayload() const override;
294 /**
295 * Like shutdown, but don't send any packet on the socket.
296 * Used by Multiplexed Socket when the TLS endpoint is already shutting down
297 */
298 void stop();
299
300 /**
301 * This will send an empty buffer as a packet (equivalent to EOF)
302 * Will trigger onShutdown's callback
303 */
304 void shutdown() override;
305
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400306 void ready(bool accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400307 /**
308 * Triggered when a specific channel is ready
309 * Used by ConnectionManager::connectDevice()
310 */
311 void onReady(ChannelReadyCb&& cb) override;
312 /**
313 * Will trigger that callback when shutdown() is called
314 */
315 void onShutdown(OnShutdownCb&& cb) override;
316
317 std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override;
318 /**
319 * @note len should be < UINT8_MAX, else you will get ec = EMSGSIZE
320 */
321 std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
322 int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override;
323
324 /**
325 * set a callback when receiving data
326 * @note: this callback should take a little time and not block
327 * but you can move it in a thread
328 */
329 void setOnRecv(RecvCb&&) override;
330
331 void onRecv(std::vector<uint8_t>&& pkt) override;
332
333 /**
334 * Send a beacon on the socket and close if no response come
335 * @param timeout
336 */
337 void sendBeacon(const std::chrono::milliseconds& timeout = SEND_BEACON_TIMEOUT);
338
339 /**
340 * Get peer's certificate
341 */
342 std::shared_ptr<dht::crypto::Certificate> peerCertificate() const;
343
Adrien Béraud6b6a5d32023-08-15 15:53:33 -0400344#ifdef DHTNET_TESTABLE
Adrien Béraud612b55b2023-05-29 10:42:04 -0400345 std::shared_ptr<MultiplexedSocket> underlyingSocket() const;
346#endif
347
348 // Note: When a channel is accepted, it can receives data ASAP and when finished will be removed
349 // however, onAccept is it's own thread due to the callbacks. In this case, the channel must be
350 // deleted in the onAccept.
351 void answered();
352 bool isAnswered() const;
353 void removable();
354 bool isRemovable() const;
355
356 IpAddr getLocalAddress() const;
357 IpAddr getRemoteAddress() const;
358
359private:
360 class Impl;
361 std::unique_ptr<Impl> pimpl_;
362};
363
Sébastien Blin464bdff2023-07-19 08:02:53 -0400364} // namespace dhtnet
Adrien Béraud612b55b2023-05-29 10:42:04 -0400365
Adrien Béraud1ae60aa2023-07-07 09:55:09 -0400366MSGPACK_ADD_ENUM(dhtnet::ChannelRequestState);