blob: c1bf14e96b6d4073db44f7e5ea73a6dea99ef839 [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
2 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
3 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
Adrien Béraudcb753622023-07-17 22:32:49 -040015 * along with this program. If not, see <https://www.gnu.org/licenses/>.
Adrien Béraud612b55b2023-05-29 10:42:04 -040016 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "peer_connection.h"
18#include "tls_session.h"
19
20#include <opendht/thread_pool.h>
21#include <opendht/logger.h>
22
23#include <algorithm>
24#include <chrono>
25#include <future>
26#include <vector>
27#include <atomic>
28#include <stdexcept>
29#include <istream>
30#include <ostream>
31#include <unistd.h>
32#include <cstdio>
33
34#ifdef _WIN32
35#include <winsock2.h>
36#include <ws2tcpip.h>
37#else
38#include <sys/select.h>
39#endif
40
41#ifndef _MSC_VER
42#include <sys/time.h>
43#endif
44
// ICE component id queried for the local/remote addresses of a TLS socket endpoint.
static constexpr int ICE_COMP_ID_SIP_TRANSPORT = 1;
46
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040047namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040048
49int
50init_crt(gnutls_session_t session, dht::crypto::Certificate& crt)
51{
52 // Support only x509 format
53 if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) {
54 return GNUTLS_E_CERTIFICATE_ERROR;
55 }
56
57 // Store verification status
58 unsigned int status = 0;
59 auto ret = gnutls_certificate_verify_peers2(session, &status);
60 if (ret < 0 or (status & GNUTLS_CERT_SIGNATURE_FAILURE) != 0) {
61 return GNUTLS_E_CERTIFICATE_ERROR;
62 }
63
64 unsigned int cert_list_size = 0;
65 auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size);
66 if (cert_list == nullptr) {
67 return GNUTLS_E_CERTIFICATE_ERROR;
68 }
69
70 // Check if received peer certificate is awaited
71 std::vector<std::pair<uint8_t*, uint8_t*>> crt_data;
72 crt_data.reserve(cert_list_size);
73 for (unsigned i = 0; i < cert_list_size; i++)
74 crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size);
75 crt = dht::crypto::Certificate {crt_data};
76
77 return GNUTLS_E_SUCCESS;
78}
79
// Shorthand for a scoped std::mutex guard in this translation unit.
using lock = std::lock_guard<std::mutex>;
81
82//==============================================================================
83
84IceSocketEndpoint::IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender)
85 : ice_(std::move(ice))
86 , iceIsSender(isSender)
87{}
88
89IceSocketEndpoint::~IceSocketEndpoint()
90{
91 shutdown();
92 if (ice_)
93 dht::ThreadPool::io().run([ice = std::move(ice_)] {});
94}
95
96void
97IceSocketEndpoint::shutdown()
98{
99 // Sometimes the other peer never send any packet
100 // So, we cancel pending read to avoid to have
101 // any blocking operation.
102 if (ice_)
103 ice_->cancelOperations();
104}
105
106int
107IceSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
108{
109 if (ice_) {
110 if (!ice_->isRunning())
111 return -1;
112 return ice_->waitForData(compId_, timeout, ec);
113 }
114 return -1;
115}
116
117std::size_t
118IceSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
119{
120 if (ice_) {
121 if (!ice_->isRunning())
122 return 0;
123 try {
124 auto res = ice_->recvfrom(compId_, reinterpret_cast<char*>(buf), len, ec);
125 if (res < 0)
126 shutdown();
127 return res;
128 } catch (const std::exception& e) {
129 if (auto logger = ice_->logger())
130 logger->error("IceSocketEndpoint::read exception: %s", e.what());
131 }
132 return 0;
133 }
134 return -1;
135}
136
137std::size_t
138IceSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
139{
140 if (ice_) {
141 if (!ice_->isRunning())
142 return 0;
143 auto res = 0;
144 res = ice_->send(compId_, reinterpret_cast<const unsigned char*>(buf), len);
145 if (res < 0) {
146 ec.assign(errno, std::generic_category());
147 shutdown();
148 } else {
149 ec.clear();
150 }
151 return res;
152 }
153 return -1;
154}
155
156//==============================================================================
157
158class TlsSocketEndpoint::Impl
159{
160public:
161 static constexpr auto TLS_TIMEOUT = std::chrono::seconds(40);
162
163 Impl(std::unique_ptr<IceSocketEndpoint>&& ep,
164 tls::CertificateStore& certStore,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400165 const std::shared_ptr<asio::io_context>& ioContext,
Adrien Béraud612b55b2023-05-29 10:42:04 -0400166 const dht::crypto::Certificate& peer_cert,
167 const Identity& local_identity,
168 const std::shared_future<tls::DhParams>& dh_params)
169 : peerCertificate {peer_cert}
170 , ep_ {ep.get()}
171 {
172 tls::TlsSession::TlsSessionCallbacks tls_cbs
173 = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); },
174 /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); },
175 /*.onCertificatesUpdate = */
176 [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
177 onTlsCertificatesUpdate(l, r, n);
178 },
179 /*.verifyCertificate = */
180 [this](gnutls_session_t session) {
181 return verifyCertificate(session);
182 }};
183 tls::TlsParams tls_param = {
184 /*.ca_list = */ "",
185 /*.peer_ca = */ nullptr,
186 /*.cert = */ local_identity.second,
187 /*.cert_key = */ local_identity.first,
188 /*.dh_params = */ dh_params,
189 /*.certStore = */ certStore,
190 /*.timeout = */ TLS_TIMEOUT,
191 /*.cert_check = */ nullptr,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400192 /*.io_context = */ ioContext,
193 /* .logger = */ ep->underlyingICE()->logger()
Adrien Béraud612b55b2023-05-29 10:42:04 -0400194 };
195 tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
196 }
197
198 Impl(std::unique_ptr<IceSocketEndpoint>&& ep,
199 tls::CertificateStore& certStore,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400200 std::shared_ptr<asio::io_context> ioContext,
Adrien Béraud612b55b2023-05-29 10:42:04 -0400201 std::function<bool(const dht::crypto::Certificate&)>&& cert_check,
202 const Identity& local_identity,
203 const std::shared_future<tls::DhParams>& dh_params)
204 : peerCertificateCheckFunc {std::move(cert_check)}
205 , peerCertificate {null_cert}
206 , ep_ {ep.get()}
207 {
208 tls::TlsSession::TlsSessionCallbacks tls_cbs
209 = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); },
210 /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); },
211 /*.onCertificatesUpdate = */
212 [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
213 onTlsCertificatesUpdate(l, r, n);
214 },
215 /*.verifyCertificate = */
216 [this](gnutls_session_t session) {
217 return verifyCertificate(session);
218 }};
219 tls::TlsParams tls_param = {
220 /*.ca_list = */ "",
221 /*.peer_ca = */ nullptr,
222 /*.cert = */ local_identity.second,
223 /*.cert_key = */ local_identity.first,
224 /*.dh_params = */ dh_params,
225 /*.certStore = */ certStore,
226 /*.timeout = */ std::chrono::duration_cast<decltype(tls::TlsParams::timeout)>(TLS_TIMEOUT),
227 /*.cert_check = */ nullptr,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400228 /*.io_context = */ ioContext,
229 /* .logger = */ ep->underlyingICE()->logger()
Adrien Béraud612b55b2023-05-29 10:42:04 -0400230 };
231 tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
232 }
233
234 ~Impl()
235 {
236 {
237 std::lock_guard<std::mutex> lk(cbMtx_);
238 onStateChangeCb_ = {};
239 onReadyCb_ = {};
240 }
241 tls.reset();
242 }
243
244 std::shared_ptr<IceTransport> underlyingICE() const
245 {
246 if (ep_)
247 if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_))
248 return iceSocket->underlyingICE();
249 return {};
250 }
251
252 // TLS callbacks
253 int verifyCertificate(gnutls_session_t);
254 void onTlsStateChange(tls::TlsSessionState);
255 void onTlsRxData(std::vector<uint8_t>&&);
256 void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int);
257
258 std::mutex cbMtx_ {};
259 OnStateChangeCb onStateChangeCb_;
260 dht::crypto::Certificate null_cert;
261 std::function<bool(const dht::crypto::Certificate&)> peerCertificateCheckFunc;
262 const dht::crypto::Certificate& peerCertificate;
263 std::atomic_bool isReady_ {false};
264 OnReadyCb onReadyCb_;
265 std::unique_ptr<tls::TlsSession> tls;
266 const IceSocketEndpoint* ep_;
267};
268
269int
270TlsSocketEndpoint::Impl::verifyCertificate(gnutls_session_t session)
271{
272 dht::crypto::Certificate crt;
273 auto verified = init_crt(session, crt);
274 if (verified != GNUTLS_E_SUCCESS)
275 return verified;
276 if (peerCertificateCheckFunc) {
277 if (!peerCertificateCheckFunc(crt)) {
278 if (const auto& logger = tls->logger())
279 logger->error("[TLS-SOCKET] Refusing peer certificate");
280 return GNUTLS_E_CERTIFICATE_ERROR;
281 }
282
283 null_cert = std::move(crt);
284 } else {
285 if (crt.getPacked() != peerCertificate.getPacked()) {
286 if (const auto& logger = tls->logger())
287 logger->error("[TLS-SOCKET] Unexpected peer certificate");
288 return GNUTLS_E_CERTIFICATE_ERROR;
289 }
290 }
291
292 return GNUTLS_E_SUCCESS;
293}
294
295void
296TlsSocketEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state)
297{
298 std::lock_guard<std::mutex> lk(cbMtx_);
299 if ((state == tls::TlsSessionState::SHUTDOWN || state == tls::TlsSessionState::ESTABLISHED)
300 && !isReady_) {
301 isReady_ = true;
302 if (onReadyCb_)
303 onReadyCb_(state == tls::TlsSessionState::ESTABLISHED);
304 }
305 if (onStateChangeCb_ && !onStateChangeCb_(state))
306 onStateChangeCb_ = {};
307}
308
309void
310TlsSocketEndpoint::Impl::onTlsRxData([[maybe_unused]] std::vector<uint8_t>&& buf)
311{}
312
313void
314TlsSocketEndpoint::Impl::onTlsCertificatesUpdate([[maybe_unused]] const gnutls_datum_t* local_raw,
315 [[maybe_unused]] const gnutls_datum_t* remote_raw,
316 [[maybe_unused]] unsigned int remote_count)
317{}
318
319TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr,
320 tls::CertificateStore& certStore,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400321 const std::shared_ptr<asio::io_context>& ioContext,
Adrien Béraud612b55b2023-05-29 10:42:04 -0400322 const Identity& local_identity,
323 const std::shared_future<tls::DhParams>& dh_params,
324 const dht::crypto::Certificate& peer_cert)
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400325 : pimpl_ {std::make_unique<Impl>(std::move(tr), certStore, ioContext, peer_cert, local_identity, dh_params)}
Adrien Béraud612b55b2023-05-29 10:42:04 -0400326{}
327
328TlsSocketEndpoint::TlsSocketEndpoint(
329 std::unique_ptr<IceSocketEndpoint>&& tr,
330 tls::CertificateStore& certStore,
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400331 const std::shared_ptr<asio::io_context>& ioContext,
Adrien Béraud612b55b2023-05-29 10:42:04 -0400332 const Identity& local_identity,
333 const std::shared_future<tls::DhParams>& dh_params,
334 std::function<bool(const dht::crypto::Certificate&)>&& cert_check)
335 : pimpl_ {
Adrien Béraud3f93ddf2023-07-21 14:46:22 -0400336 std::make_unique<Impl>(std::move(tr), certStore, ioContext, std::move(cert_check), local_identity, dh_params)}
Adrien Béraud612b55b2023-05-29 10:42:04 -0400337{}
338
339TlsSocketEndpoint::~TlsSocketEndpoint() {}
340
341bool
342TlsSocketEndpoint::isInitiator() const
343{
344 if (!pimpl_->tls) {
345 return false;
346 }
347 return pimpl_->tls->isInitiator();
348}
349
350int
351TlsSocketEndpoint::maxPayload() const
352{
353 if (!pimpl_->tls) {
354 return -1;
355 }
356 return pimpl_->tls->maxPayload();
357}
358
359std::size_t
360TlsSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
361{
362 if (!pimpl_->tls) {
363 ec = std::make_error_code(std::errc::broken_pipe);
364 return -1;
365 }
366 return pimpl_->tls->read(buf, len, ec);
367}
368
369std::size_t
370TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
371{
372 if (!pimpl_->tls) {
373 ec = std::make_error_code(std::errc::broken_pipe);
374 return -1;
375 }
376 return pimpl_->tls->write(buf, len, ec);
377}
378
379std::shared_ptr<dht::crypto::Certificate>
380TlsSocketEndpoint::peerCertificate() const
381{
382 if (!pimpl_->tls)
383 return {};
384 return pimpl_->tls->peerCertificate();
385}
386
387void
388TlsSocketEndpoint::waitForReady(const std::chrono::milliseconds& timeout)
389{
390 if (!pimpl_->tls) {
391 return;
392 }
393 pimpl_->tls->waitForReady(timeout);
394}
395
396int
397TlsSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
398{
399 if (!pimpl_->tls) {
400 ec = std::make_error_code(std::errc::broken_pipe);
401 return -1;
402 }
403 return pimpl_->tls->waitForData(timeout, ec);
404}
405
406void
407TlsSocketEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb)
408{
409 std::lock_guard<std::mutex> lk(pimpl_->cbMtx_);
410 pimpl_->onStateChangeCb_ = std::move(cb);
411}
412
413void
414TlsSocketEndpoint::setOnReady(std::function<void(bool ok)>&& cb)
415{
416 std::lock_guard<std::mutex> lk(pimpl_->cbMtx_);
417 pimpl_->onReadyCb_ = std::move(cb);
418}
419
420void
421TlsSocketEndpoint::shutdown()
422{
423 pimpl_->tls->shutdown();
424 if (pimpl_->ep_) {
425 const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_);
426 if (iceSocket && iceSocket->underlyingICE())
427 iceSocket->underlyingICE()->cancelOperations();
428 }
429}
430
431void
432TlsSocketEndpoint::monitor() const
433{
434 if (auto ice = pimpl_->underlyingICE())
435 if (auto logger = ice->logger())
436 logger->debug("\t- Ice connection: {}", ice->link());
437}
438
439IpAddr
440TlsSocketEndpoint::getLocalAddress() const
441{
442 if (auto ice = pimpl_->underlyingICE())
443 return ice->getLocalAddress(ICE_COMP_ID_SIP_TRANSPORT);
444 return {};
445}
446
447IpAddr
448TlsSocketEndpoint::getRemoteAddress() const
449{
450 if (auto ice = pimpl_->underlyingICE())
451 return ice->getRemoteAddress(ICE_COMP_ID_SIP_TRANSPORT);
452 return {};
453}
454
Sébastien Blin464bdff2023-07-19 08:02:53 -0400455} // namespace dhtnet