blob: 42bfbc011e921bf204df3bcb912a128951602e11 [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
2 * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
3 *
Adrien Béraudcb753622023-07-17 22:32:49 -04004 * This program is free software: you can redistribute it and/or modify
Adrien Béraud612b55b2023-05-29 10:42:04 -04005 * it under the terms of the GNU General Public License as published by
Adrien Béraudcb753622023-07-17 22:32:49 -04006 * the Free Software Foundation, either version 3 of the License, or
Adrien Béraud612b55b2023-05-29 10:42:04 -04007 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Adrien Béraudcb753622023-07-17 22:32:49 -040011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Adrien Béraud612b55b2023-05-29 10:42:04 -040012 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
Adrien Béraudcb753622023-07-17 22:32:49 -040015 * along with this program. If not, see <https://www.gnu.org/licenses/>.
Adrien Béraud612b55b2023-05-29 10:42:04 -040016 */
Adrien Béraud612b55b2023-05-29 10:42:04 -040017#include "peer_connection.h"
18#include "tls_session.h"
19
20#include <opendht/thread_pool.h>
21#include <opendht/logger.h>
22
23#include <algorithm>
24#include <chrono>
25#include <future>
26#include <vector>
27#include <atomic>
28#include <stdexcept>
29#include <istream>
30#include <ostream>
31#include <unistd.h>
32#include <cstdio>
33
34#ifdef _WIN32
35#include <winsock2.h>
36#include <ws2tcpip.h>
37#else
38#include <sys/select.h>
39#endif
40
41#ifndef _MSC_VER
42#include <sys/time.h>
43#endif
44
// ICE component identifier passed to the underlying transport when querying
// the local and remote addresses of a TLS socket endpoint.
static constexpr int ICE_COMP_ID_SIP_TRANSPORT = 1;
46
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040047namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040048
49int
50init_crt(gnutls_session_t session, dht::crypto::Certificate& crt)
51{
52 // Support only x509 format
53 if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) {
54 return GNUTLS_E_CERTIFICATE_ERROR;
55 }
56
57 // Store verification status
58 unsigned int status = 0;
59 auto ret = gnutls_certificate_verify_peers2(session, &status);
60 if (ret < 0 or (status & GNUTLS_CERT_SIGNATURE_FAILURE) != 0) {
61 return GNUTLS_E_CERTIFICATE_ERROR;
62 }
63
64 unsigned int cert_list_size = 0;
65 auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size);
66 if (cert_list == nullptr) {
67 return GNUTLS_E_CERTIFICATE_ERROR;
68 }
69
70 // Check if received peer certificate is awaited
71 std::vector<std::pair<uint8_t*, uint8_t*>> crt_data;
72 crt_data.reserve(cert_list_size);
73 for (unsigned i = 0; i < cert_list_size; i++)
74 crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size);
75 crt = dht::crypto::Certificate {crt_data};
76
77 return GNUTLS_E_SUCCESS;
78}
79
80using lock = std::lock_guard<std::mutex>;
81
82//==============================================================================
83
84IceSocketEndpoint::IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender)
85 : ice_(std::move(ice))
86 , iceIsSender(isSender)
87{}
88
89IceSocketEndpoint::~IceSocketEndpoint()
90{
91 shutdown();
92 if (ice_)
93 dht::ThreadPool::io().run([ice = std::move(ice_)] {});
94}
95
96void
97IceSocketEndpoint::shutdown()
98{
99 // Sometimes the other peer never send any packet
100 // So, we cancel pending read to avoid to have
101 // any blocking operation.
102 if (ice_)
103 ice_->cancelOperations();
104}
105
106int
107IceSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
108{
109 if (ice_) {
110 if (!ice_->isRunning())
111 return -1;
112 return ice_->waitForData(compId_, timeout, ec);
113 }
114 return -1;
115}
116
117std::size_t
118IceSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
119{
120 if (ice_) {
121 if (!ice_->isRunning())
122 return 0;
123 try {
124 auto res = ice_->recvfrom(compId_, reinterpret_cast<char*>(buf), len, ec);
125 if (res < 0)
126 shutdown();
127 return res;
128 } catch (const std::exception& e) {
129 if (auto logger = ice_->logger())
130 logger->error("IceSocketEndpoint::read exception: %s", e.what());
131 }
132 return 0;
133 }
134 return -1;
135}
136
137std::size_t
138IceSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
139{
140 if (ice_) {
141 if (!ice_->isRunning())
142 return 0;
143 auto res = 0;
144 res = ice_->send(compId_, reinterpret_cast<const unsigned char*>(buf), len);
145 if (res < 0) {
146 ec.assign(errno, std::generic_category());
147 shutdown();
148 } else {
149 ec.clear();
150 }
151 return res;
152 }
153 return -1;
154}
155
156//==============================================================================
157
158class TlsSocketEndpoint::Impl
159{
160public:
161 static constexpr auto TLS_TIMEOUT = std::chrono::seconds(40);
162
163 Impl(std::unique_ptr<IceSocketEndpoint>&& ep,
164 tls::CertificateStore& certStore,
165 const dht::crypto::Certificate& peer_cert,
166 const Identity& local_identity,
167 const std::shared_future<tls::DhParams>& dh_params)
168 : peerCertificate {peer_cert}
169 , ep_ {ep.get()}
170 {
171 tls::TlsSession::TlsSessionCallbacks tls_cbs
172 = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); },
173 /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); },
174 /*.onCertificatesUpdate = */
175 [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
176 onTlsCertificatesUpdate(l, r, n);
177 },
178 /*.verifyCertificate = */
179 [this](gnutls_session_t session) {
180 return verifyCertificate(session);
181 }};
182 tls::TlsParams tls_param = {
183 /*.ca_list = */ "",
184 /*.peer_ca = */ nullptr,
185 /*.cert = */ local_identity.second,
186 /*.cert_key = */ local_identity.first,
187 /*.dh_params = */ dh_params,
188 /*.certStore = */ certStore,
189 /*.timeout = */ TLS_TIMEOUT,
190 /*.cert_check = */ nullptr,
191 };
192 tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
193 }
194
195 Impl(std::unique_ptr<IceSocketEndpoint>&& ep,
196 tls::CertificateStore& certStore,
197 std::function<bool(const dht::crypto::Certificate&)>&& cert_check,
198 const Identity& local_identity,
199 const std::shared_future<tls::DhParams>& dh_params)
200 : peerCertificateCheckFunc {std::move(cert_check)}
201 , peerCertificate {null_cert}
202 , ep_ {ep.get()}
203 {
204 tls::TlsSession::TlsSessionCallbacks tls_cbs
205 = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); },
206 /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); },
207 /*.onCertificatesUpdate = */
208 [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
209 onTlsCertificatesUpdate(l, r, n);
210 },
211 /*.verifyCertificate = */
212 [this](gnutls_session_t session) {
213 return verifyCertificate(session);
214 }};
215 tls::TlsParams tls_param = {
216 /*.ca_list = */ "",
217 /*.peer_ca = */ nullptr,
218 /*.cert = */ local_identity.second,
219 /*.cert_key = */ local_identity.first,
220 /*.dh_params = */ dh_params,
221 /*.certStore = */ certStore,
222 /*.timeout = */ std::chrono::duration_cast<decltype(tls::TlsParams::timeout)>(TLS_TIMEOUT),
223 /*.cert_check = */ nullptr,
224 };
225 tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
226 }
227
228 ~Impl()
229 {
230 {
231 std::lock_guard<std::mutex> lk(cbMtx_);
232 onStateChangeCb_ = {};
233 onReadyCb_ = {};
234 }
235 tls.reset();
236 }
237
238 std::shared_ptr<IceTransport> underlyingICE() const
239 {
240 if (ep_)
241 if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_))
242 return iceSocket->underlyingICE();
243 return {};
244 }
245
246 // TLS callbacks
247 int verifyCertificate(gnutls_session_t);
248 void onTlsStateChange(tls::TlsSessionState);
249 void onTlsRxData(std::vector<uint8_t>&&);
250 void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int);
251
252 std::mutex cbMtx_ {};
253 OnStateChangeCb onStateChangeCb_;
254 dht::crypto::Certificate null_cert;
255 std::function<bool(const dht::crypto::Certificate&)> peerCertificateCheckFunc;
256 const dht::crypto::Certificate& peerCertificate;
257 std::atomic_bool isReady_ {false};
258 OnReadyCb onReadyCb_;
259 std::unique_ptr<tls::TlsSession> tls;
260 const IceSocketEndpoint* ep_;
261};
262
263int
264TlsSocketEndpoint::Impl::verifyCertificate(gnutls_session_t session)
265{
266 dht::crypto::Certificate crt;
267 auto verified = init_crt(session, crt);
268 if (verified != GNUTLS_E_SUCCESS)
269 return verified;
270 if (peerCertificateCheckFunc) {
271 if (!peerCertificateCheckFunc(crt)) {
272 if (const auto& logger = tls->logger())
273 logger->error("[TLS-SOCKET] Refusing peer certificate");
274 return GNUTLS_E_CERTIFICATE_ERROR;
275 }
276
277 null_cert = std::move(crt);
278 } else {
279 if (crt.getPacked() != peerCertificate.getPacked()) {
280 if (const auto& logger = tls->logger())
281 logger->error("[TLS-SOCKET] Unexpected peer certificate");
282 return GNUTLS_E_CERTIFICATE_ERROR;
283 }
284 }
285
286 return GNUTLS_E_SUCCESS;
287}
288
289void
290TlsSocketEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state)
291{
292 std::lock_guard<std::mutex> lk(cbMtx_);
293 if ((state == tls::TlsSessionState::SHUTDOWN || state == tls::TlsSessionState::ESTABLISHED)
294 && !isReady_) {
295 isReady_ = true;
296 if (onReadyCb_)
297 onReadyCb_(state == tls::TlsSessionState::ESTABLISHED);
298 }
299 if (onStateChangeCb_ && !onStateChangeCb_(state))
300 onStateChangeCb_ = {};
301}
302
303void
304TlsSocketEndpoint::Impl::onTlsRxData([[maybe_unused]] std::vector<uint8_t>&& buf)
305{}
306
307void
308TlsSocketEndpoint::Impl::onTlsCertificatesUpdate([[maybe_unused]] const gnutls_datum_t* local_raw,
309 [[maybe_unused]] const gnutls_datum_t* remote_raw,
310 [[maybe_unused]] unsigned int remote_count)
311{}
312
313TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr,
314 tls::CertificateStore& certStore,
315 const Identity& local_identity,
316 const std::shared_future<tls::DhParams>& dh_params,
317 const dht::crypto::Certificate& peer_cert)
318 : pimpl_ {std::make_unique<Impl>(std::move(tr), certStore, peer_cert, local_identity, dh_params)}
319{}
320
321TlsSocketEndpoint::TlsSocketEndpoint(
322 std::unique_ptr<IceSocketEndpoint>&& tr,
323 tls::CertificateStore& certStore,
324 const Identity& local_identity,
325 const std::shared_future<tls::DhParams>& dh_params,
326 std::function<bool(const dht::crypto::Certificate&)>&& cert_check)
327 : pimpl_ {
328 std::make_unique<Impl>(std::move(tr), certStore, std::move(cert_check), local_identity, dh_params)}
329{}
330
331TlsSocketEndpoint::~TlsSocketEndpoint() {}
332
333bool
334TlsSocketEndpoint::isInitiator() const
335{
336 if (!pimpl_->tls) {
337 return false;
338 }
339 return pimpl_->tls->isInitiator();
340}
341
342int
343TlsSocketEndpoint::maxPayload() const
344{
345 if (!pimpl_->tls) {
346 return -1;
347 }
348 return pimpl_->tls->maxPayload();
349}
350
351std::size_t
352TlsSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
353{
354 if (!pimpl_->tls) {
355 ec = std::make_error_code(std::errc::broken_pipe);
356 return -1;
357 }
358 return pimpl_->tls->read(buf, len, ec);
359}
360
361std::size_t
362TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
363{
364 if (!pimpl_->tls) {
365 ec = std::make_error_code(std::errc::broken_pipe);
366 return -1;
367 }
368 return pimpl_->tls->write(buf, len, ec);
369}
370
371std::shared_ptr<dht::crypto::Certificate>
372TlsSocketEndpoint::peerCertificate() const
373{
374 if (!pimpl_->tls)
375 return {};
376 return pimpl_->tls->peerCertificate();
377}
378
379void
380TlsSocketEndpoint::waitForReady(const std::chrono::milliseconds& timeout)
381{
382 if (!pimpl_->tls) {
383 return;
384 }
385 pimpl_->tls->waitForReady(timeout);
386}
387
388int
389TlsSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
390{
391 if (!pimpl_->tls) {
392 ec = std::make_error_code(std::errc::broken_pipe);
393 return -1;
394 }
395 return pimpl_->tls->waitForData(timeout, ec);
396}
397
398void
399TlsSocketEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb)
400{
401 std::lock_guard<std::mutex> lk(pimpl_->cbMtx_);
402 pimpl_->onStateChangeCb_ = std::move(cb);
403}
404
405void
406TlsSocketEndpoint::setOnReady(std::function<void(bool ok)>&& cb)
407{
408 std::lock_guard<std::mutex> lk(pimpl_->cbMtx_);
409 pimpl_->onReadyCb_ = std::move(cb);
410}
411
412void
413TlsSocketEndpoint::shutdown()
414{
415 pimpl_->tls->shutdown();
416 if (pimpl_->ep_) {
417 const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_);
418 if (iceSocket && iceSocket->underlyingICE())
419 iceSocket->underlyingICE()->cancelOperations();
420 }
421}
422
423void
424TlsSocketEndpoint::monitor() const
425{
426 if (auto ice = pimpl_->underlyingICE())
427 if (auto logger = ice->logger())
428 logger->debug("\t- Ice connection: {}", ice->link());
429}
430
431IpAddr
432TlsSocketEndpoint::getLocalAddress() const
433{
434 if (auto ice = pimpl_->underlyingICE())
435 return ice->getLocalAddress(ICE_COMP_ID_SIP_TRANSPORT);
436 return {};
437}
438
439IpAddr
440TlsSocketEndpoint::getRemoteAddress() const
441{
442 if (auto ice = pimpl_->underlyingICE())
443 return ice->getRemoteAddress(ICE_COMP_ID_SIP_TRANSPORT);
444 return {};
445}
446
Sébastien Blin464bdff2023-07-19 08:02:53 -0400447} // namespace dhtnet