blob: 3fbd8205e94fb7ee1e966f3eb4243904142fda4a [file] [log] [blame]
/*
 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "peer_connection.h"
18#include "tls_session.h"
19
20#include <opendht/thread_pool.h>
21#include <opendht/logger.h>
22
23#include <algorithm>
24#include <chrono>
25#include <future>
26#include <vector>
27#include <atomic>
28#include <stdexcept>
29#include <istream>
30#include <ostream>
31#include <unistd.h>
32#include <cstdio>
33
34#ifdef _WIN32
35#include <winsock2.h>
36#include <ws2tcpip.h>
37#else
38#include <sys/select.h>
39#endif
40
41#ifndef _MSC_VER
42#include <sys/time.h>
43#endif
44
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040045namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040046
47int
48init_crt(gnutls_session_t session, dht::crypto::Certificate& crt)
49{
50 // Support only x509 format
51 if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) {
52 return GNUTLS_E_CERTIFICATE_ERROR;
53 }
54
55 // Store verification status
56 unsigned int status = 0;
57 auto ret = gnutls_certificate_verify_peers2(session, &status);
58 if (ret < 0 or (status & GNUTLS_CERT_SIGNATURE_FAILURE) != 0) {
59 return GNUTLS_E_CERTIFICATE_ERROR;
60 }
61
62 unsigned int cert_list_size = 0;
63 auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size);
64 if (cert_list == nullptr) {
65 return GNUTLS_E_CERTIFICATE_ERROR;
66 }
67
68 // Check if received peer certificate is awaited
69 std::vector<std::pair<uint8_t*, uint8_t*>> crt_data;
70 crt_data.reserve(cert_list_size);
71 for (unsigned i = 0; i < cert_list_size; i++)
72 crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size);
73 crt = dht::crypto::Certificate {crt_data};
74
75 return GNUTLS_E_SUCCESS;
76}
77
78using lock = std::lock_guard<std::mutex>;
79
80//==============================================================================
81
82IceSocketEndpoint::IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender)
83 : ice_(std::move(ice))
84 , iceIsSender(isSender)
85{}
86
87IceSocketEndpoint::~IceSocketEndpoint()
88{
89 shutdown();
90 if (ice_)
91 dht::ThreadPool::io().run([ice = std::move(ice_)] {});
92}
93
94void
95IceSocketEndpoint::shutdown()
96{
97 // Sometimes the other peer never send any packet
98 // So, we cancel pending read to avoid to have
99 // any blocking operation.
100 if (ice_)
101 ice_->cancelOperations();
102}
103
104int
105IceSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
106{
107 if (ice_) {
108 if (!ice_->isRunning())
109 return -1;
110 return ice_->waitForData(compId_, timeout, ec);
111 }
112 return -1;
113}
114
115std::size_t
116IceSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
117{
118 if (ice_) {
119 if (!ice_->isRunning())
120 return 0;
121 try {
122 auto res = ice_->recvfrom(compId_, reinterpret_cast<char*>(buf), len, ec);
123 if (res < 0)
124 shutdown();
125 return res;
126 } catch (const std::exception& e) {
127 if (auto logger = ice_->logger())
128 logger->error("IceSocketEndpoint::read exception: %s", e.what());
129 }
130 return 0;
131 }
132 return -1;
133}
134
135std::size_t
136IceSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
137{
138 if (ice_) {
139 if (!ice_->isRunning())
140 return 0;
141 auto res = 0;
142 res = ice_->send(compId_, reinterpret_cast<const unsigned char*>(buf), len);
143 if (res < 0) {
144 ec.assign(errno, std::generic_category());
145 shutdown();
146 } else {
147 ec.clear();
148 }
149 return res;
150 }
151 return -1;
152}
153
154//==============================================================================
155
156class TlsSocketEndpoint::Impl
157{
158public:
159 static constexpr auto TLS_TIMEOUT = std::chrono::seconds(40);
160
161 Impl(std::unique_ptr<IceSocketEndpoint>&& ep,
162 tls::CertificateStore& certStore,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400163 const std::shared_ptr<asio::io_context>& ioContext,
Adrien Béraud612b55b2023-05-29 10:42:04 -0400164 const dht::crypto::Certificate& peer_cert,
165 const Identity& local_identity,
166 const std::shared_future<tls::DhParams>& dh_params)
167 : peerCertificate {peer_cert}
168 , ep_ {ep.get()}
169 {
170 tls::TlsSession::TlsSessionCallbacks tls_cbs
171 = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); },
172 /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); },
173 /*.onCertificatesUpdate = */
174 [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
175 onTlsCertificatesUpdate(l, r, n);
176 },
177 /*.verifyCertificate = */
178 [this](gnutls_session_t session) {
179 return verifyCertificate(session);
180 }};
181 tls::TlsParams tls_param = {
182 /*.ca_list = */ "",
183 /*.peer_ca = */ nullptr,
184 /*.cert = */ local_identity.second,
185 /*.cert_key = */ local_identity.first,
186 /*.dh_params = */ dh_params,
187 /*.certStore = */ certStore,
188 /*.timeout = */ TLS_TIMEOUT,
189 /*.cert_check = */ nullptr,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400190 /*.io_context = */ ioContext,
191 /* .logger = */ ep->underlyingICE()->logger()
Adrien Béraud612b55b2023-05-29 10:42:04 -0400192 };
193 tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
194 }
195
196 Impl(std::unique_ptr<IceSocketEndpoint>&& ep,
197 tls::CertificateStore& certStore,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400198 std::shared_ptr<asio::io_context> ioContext,
Adrien Béraud612b55b2023-05-29 10:42:04 -0400199 std::function<bool(const dht::crypto::Certificate&)>&& cert_check,
200 const Identity& local_identity,
201 const std::shared_future<tls::DhParams>& dh_params)
202 : peerCertificateCheckFunc {std::move(cert_check)}
203 , peerCertificate {null_cert}
204 , ep_ {ep.get()}
205 {
206 tls::TlsSession::TlsSessionCallbacks tls_cbs
207 = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); },
208 /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); },
209 /*.onCertificatesUpdate = */
210 [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
211 onTlsCertificatesUpdate(l, r, n);
212 },
213 /*.verifyCertificate = */
214 [this](gnutls_session_t session) {
215 return verifyCertificate(session);
216 }};
217 tls::TlsParams tls_param = {
218 /*.ca_list = */ "",
219 /*.peer_ca = */ nullptr,
220 /*.cert = */ local_identity.second,
221 /*.cert_key = */ local_identity.first,
222 /*.dh_params = */ dh_params,
223 /*.certStore = */ certStore,
224 /*.timeout = */ std::chrono::duration_cast<decltype(tls::TlsParams::timeout)>(TLS_TIMEOUT),
225 /*.cert_check = */ nullptr,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400226 /*.io_context = */ ioContext,
227 /* .logger = */ ep->underlyingICE()->logger()
Adrien Béraud612b55b2023-05-29 10:42:04 -0400228 };
229 tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
230 }
231
232 ~Impl()
233 {
234 {
235 std::lock_guard<std::mutex> lk(cbMtx_);
236 onStateChangeCb_ = {};
237 onReadyCb_ = {};
238 }
239 tls.reset();
240 }
241
242 std::shared_ptr<IceTransport> underlyingICE() const
243 {
244 if (ep_)
245 if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_))
246 return iceSocket->underlyingICE();
247 return {};
248 }
249
250 // TLS callbacks
251 int verifyCertificate(gnutls_session_t);
252 void onTlsStateChange(tls::TlsSessionState);
253 void onTlsRxData(std::vector<uint8_t>&&);
254 void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int);
255
256 std::mutex cbMtx_ {};
257 OnStateChangeCb onStateChangeCb_;
258 dht::crypto::Certificate null_cert;
259 std::function<bool(const dht::crypto::Certificate&)> peerCertificateCheckFunc;
260 const dht::crypto::Certificate& peerCertificate;
261 std::atomic_bool isReady_ {false};
262 OnReadyCb onReadyCb_;
263 std::unique_ptr<tls::TlsSession> tls;
264 const IceSocketEndpoint* ep_;
265};
266
267int
268TlsSocketEndpoint::Impl::verifyCertificate(gnutls_session_t session)
269{
270 dht::crypto::Certificate crt;
271 auto verified = init_crt(session, crt);
272 if (verified != GNUTLS_E_SUCCESS)
273 return verified;
274 if (peerCertificateCheckFunc) {
275 if (!peerCertificateCheckFunc(crt)) {
276 if (const auto& logger = tls->logger())
277 logger->error("[TLS-SOCKET] Refusing peer certificate");
278 return GNUTLS_E_CERTIFICATE_ERROR;
279 }
280
281 null_cert = std::move(crt);
282 } else {
283 if (crt.getPacked() != peerCertificate.getPacked()) {
284 if (const auto& logger = tls->logger())
285 logger->error("[TLS-SOCKET] Unexpected peer certificate");
286 return GNUTLS_E_CERTIFICATE_ERROR;
287 }
288 }
289
290 return GNUTLS_E_SUCCESS;
291}
292
293void
294TlsSocketEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state)
295{
296 std::lock_guard<std::mutex> lk(cbMtx_);
297 if ((state == tls::TlsSessionState::SHUTDOWN || state == tls::TlsSessionState::ESTABLISHED)
298 && !isReady_) {
299 isReady_ = true;
300 if (onReadyCb_)
301 onReadyCb_(state == tls::TlsSessionState::ESTABLISHED);
302 }
303 if (onStateChangeCb_ && !onStateChangeCb_(state))
304 onStateChangeCb_ = {};
305}
306
307void
308TlsSocketEndpoint::Impl::onTlsRxData([[maybe_unused]] std::vector<uint8_t>&& buf)
309{}
310
311void
312TlsSocketEndpoint::Impl::onTlsCertificatesUpdate([[maybe_unused]] const gnutls_datum_t* local_raw,
313 [[maybe_unused]] const gnutls_datum_t* remote_raw,
314 [[maybe_unused]] unsigned int remote_count)
315{}
316
317TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr,
318 tls::CertificateStore& certStore,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400319 const std::shared_ptr<asio::io_context>& ioContext,
Adrien Béraud612b55b2023-05-29 10:42:04 -0400320 const Identity& local_identity,
321 const std::shared_future<tls::DhParams>& dh_params,
322 const dht::crypto::Certificate& peer_cert)
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400323 : pimpl_ {std::make_unique<Impl>(std::move(tr), certStore, ioContext, peer_cert, local_identity, dh_params)}
Adrien Béraud612b55b2023-05-29 10:42:04 -0400324{}
325
326TlsSocketEndpoint::TlsSocketEndpoint(
327 std::unique_ptr<IceSocketEndpoint>&& tr,
328 tls::CertificateStore& certStore,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400329 const std::shared_ptr<asio::io_context>& ioContext,
Adrien Béraud612b55b2023-05-29 10:42:04 -0400330 const Identity& local_identity,
331 const std::shared_future<tls::DhParams>& dh_params,
332 std::function<bool(const dht::crypto::Certificate&)>&& cert_check)
333 : pimpl_ {
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400334 std::make_unique<Impl>(std::move(tr), certStore, ioContext, std::move(cert_check), local_identity, dh_params)}
Adrien Béraud612b55b2023-05-29 10:42:04 -0400335{}
336
337TlsSocketEndpoint::~TlsSocketEndpoint() {}
338
339bool
340TlsSocketEndpoint::isInitiator() const
341{
342 if (!pimpl_->tls) {
343 return false;
344 }
345 return pimpl_->tls->isInitiator();
346}
347
348int
349TlsSocketEndpoint::maxPayload() const
350{
351 if (!pimpl_->tls) {
352 return -1;
353 }
354 return pimpl_->tls->maxPayload();
355}
356
357std::size_t
358TlsSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
359{
360 if (!pimpl_->tls) {
361 ec = std::make_error_code(std::errc::broken_pipe);
362 return -1;
363 }
364 return pimpl_->tls->read(buf, len, ec);
365}
366
367std::size_t
368TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
369{
370 if (!pimpl_->tls) {
371 ec = std::make_error_code(std::errc::broken_pipe);
372 return -1;
373 }
374 return pimpl_->tls->write(buf, len, ec);
375}
376
377std::shared_ptr<dht::crypto::Certificate>
378TlsSocketEndpoint::peerCertificate() const
379{
380 if (!pimpl_->tls)
381 return {};
382 return pimpl_->tls->peerCertificate();
383}
384
Adrien Béraud612b55b2023-05-29 10:42:04 -0400385int
386TlsSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
387{
388 if (!pimpl_->tls) {
389 ec = std::make_error_code(std::errc::broken_pipe);
390 return -1;
391 }
392 return pimpl_->tls->waitForData(timeout, ec);
393}
394
395void
396TlsSocketEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb)
397{
398 std::lock_guard<std::mutex> lk(pimpl_->cbMtx_);
399 pimpl_->onStateChangeCb_ = std::move(cb);
400}
401
402void
403TlsSocketEndpoint::setOnReady(std::function<void(bool ok)>&& cb)
404{
405 std::lock_guard<std::mutex> lk(pimpl_->cbMtx_);
406 pimpl_->onReadyCb_ = std::move(cb);
407}
408
409void
410TlsSocketEndpoint::shutdown()
411{
412 pimpl_->tls->shutdown();
413 if (pimpl_->ep_) {
414 const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_);
415 if (iceSocket && iceSocket->underlyingICE())
416 iceSocket->underlyingICE()->cancelOperations();
417 }
418}
419
420void
421TlsSocketEndpoint::monitor() const
422{
423 if (auto ice = pimpl_->underlyingICE())
424 if (auto logger = ice->logger())
425 logger->debug("\t- Ice connection: {}", ice->link());
426}
427
428IpAddr
429TlsSocketEndpoint::getLocalAddress() const
430{
431 if (auto ice = pimpl_->underlyingICE())
432 return ice->getLocalAddress(ICE_COMP_ID_SIP_TRANSPORT);
433 return {};
434}
435
436IpAddr
437TlsSocketEndpoint::getRemoteAddress() const
438{
439 if (auto ice = pimpl_->underlyingICE())
440 return ice->getRemoteAddress(ICE_COMP_ID_SIP_TRANSPORT);
441 return {};
442}
443
Sébastien Blin464bdff2023-07-19 08:02:53 -0400444} // namespace dhtnet