| /* |
| * Copyright (C) 2004-2023 Savoir-faire Linux Inc. |
| * |
| * This program is free software: you can redistribute it and/or modify |
| * it under the terms of the GNU General Public License as published by |
| * the Free Software Foundation, either version 3 of the License, or |
| * (at your option) any later version. |
| * |
| * This program is distributed in the hope that it will be useful, |
| * but WITHOUT ANY WARRANTY; without even the implied warranty of |
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| * GNU General Public License for more details. |
| * |
| * You should have received a copy of the GNU General Public License |
| * along with this program. If not, see <https://www.gnu.org/licenses/>. |
| */ |
| #include "peer_connection.h" |
| #include "tls_session.h" |
| |
| #include <opendht/thread_pool.h> |
| #include <opendht/logger.h> |
| |
| #include <algorithm> |
| #include <chrono> |
| #include <future> |
| #include <vector> |
| #include <atomic> |
| #include <stdexcept> |
| #include <istream> |
| #include <ostream> |
| #include <unistd.h> |
| #include <cstdio> |
| |
| #ifdef _WIN32 |
| #include <winsock2.h> |
| #include <ws2tcpip.h> |
| #else |
| #include <sys/select.h> |
| #endif |
| |
| #ifndef _MSC_VER |
| #include <sys/time.h> |
| #endif |
| |
| namespace dhtnet { |
| |
| int |
| init_crt(gnutls_session_t session, dht::crypto::Certificate& crt) |
| { |
| // Support only x509 format |
| if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) { |
| return GNUTLS_E_CERTIFICATE_ERROR; |
| } |
| |
| // Store verification status |
| unsigned int status = 0; |
| auto ret = gnutls_certificate_verify_peers2(session, &status); |
| if (ret < 0 or (status & GNUTLS_CERT_SIGNATURE_FAILURE) != 0) { |
| return GNUTLS_E_CERTIFICATE_ERROR; |
| } |
| |
| unsigned int cert_list_size = 0; |
| auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size); |
| if (cert_list == nullptr) { |
| return GNUTLS_E_CERTIFICATE_ERROR; |
| } |
| |
| // Check if received peer certificate is awaited |
| std::vector<std::pair<uint8_t*, uint8_t*>> crt_data; |
| crt_data.reserve(cert_list_size); |
| for (unsigned i = 0; i < cert_list_size; i++) |
| crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size); |
| crt = dht::crypto::Certificate {crt_data}; |
| |
| return GNUTLS_E_SUCCESS; |
| } |
| |
| //============================================================================== |
| |
| IceSocketEndpoint::IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender) |
| : ice_(std::move(ice)) |
| , iceIsSender(isSender) |
| {} |
| |
| IceSocketEndpoint::~IceSocketEndpoint() |
| { |
| shutdown(); |
| if (ice_) |
| dht::ThreadPool::io().run([ice = std::move(ice_)] {}); |
| } |
| |
| void |
| IceSocketEndpoint::shutdown() |
| { |
| // Sometimes the other peer never send any packet |
| // So, we cancel pending read to avoid to have |
| // any blocking operation. |
| if (ice_) |
| ice_->cancelOperations(); |
| } |
| |
| int |
| IceSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const |
| { |
| if (ice_) { |
| if (!ice_->isRunning()) |
| return -1; |
| return ice_->waitForData(compId_, timeout, ec); |
| } |
| return -1; |
| } |
| |
| std::size_t |
| IceSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec) |
| { |
| if (ice_) { |
| if (!ice_->isRunning()) |
| return 0; |
| try { |
| auto res = ice_->recvfrom(compId_, reinterpret_cast<char*>(buf), len, ec); |
| if (res < 0) |
| shutdown(); |
| return res; |
| } catch (const std::exception& e) { |
| if (auto logger = ice_->logger()) |
| logger->error("IceSocketEndpoint::read exception: %s", e.what()); |
| } |
| return 0; |
| } |
| return -1; |
| } |
| |
| std::size_t |
| IceSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec) |
| { |
| if (ice_) { |
| if (!ice_->isRunning()) |
| return 0; |
| auto res = 0; |
| res = ice_->send(compId_, reinterpret_cast<const unsigned char*>(buf), len); |
| if (res < 0) { |
| ec.assign(errno, std::generic_category()); |
| shutdown(); |
| } else { |
| ec.clear(); |
| } |
| return res; |
| } |
| return -1; |
| } |
| |
| //============================================================================== |
| |
| class TlsSocketEndpoint::Impl |
| { |
| public: |
| static constexpr auto TLS_TIMEOUT = std::chrono::seconds(40); |
| |
| Impl(std::unique_ptr<IceSocketEndpoint>&& ep, |
| tls::CertificateStore& certStore, |
| const std::shared_ptr<asio::io_context>& ioContext, |
| const dht::crypto::Certificate& peer_cert, |
| const Identity& local_identity, |
| const std::shared_future<tls::DhParams>& dh_params) |
| : peerCertificate {peer_cert} |
| , ep_ {ep.get()} |
| { |
| tls::TlsSession::TlsSessionCallbacks tls_cbs |
| = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); }, |
| /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); }, |
| /*.onCertificatesUpdate = */ |
| [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) { |
| onTlsCertificatesUpdate(l, r, n); |
| }, |
| /*.verifyCertificate = */ |
| [this](gnutls_session_t session) { |
| return verifyCertificate(session); |
| }}; |
| tls::TlsParams tls_param = { |
| /*.ca_list = */ "", |
| /*.peer_ca = */ nullptr, |
| /*.cert = */ local_identity.second, |
| /*.cert_key = */ local_identity.first, |
| /*.dh_params = */ dh_params, |
| /*.certStore = */ certStore, |
| /*.timeout = */ TLS_TIMEOUT, |
| /*.cert_check = */ nullptr, |
| /*.io_context = */ ioContext, |
| /* .logger = */ ep->underlyingICE()->logger() |
| }; |
| tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs); |
| } |
| |
| Impl(std::unique_ptr<IceSocketEndpoint>&& ep, |
| tls::CertificateStore& certStore, |
| std::shared_ptr<asio::io_context> ioContext, |
| std::function<bool(const dht::crypto::Certificate&)>&& cert_check, |
| const Identity& local_identity, |
| const std::shared_future<tls::DhParams>& dh_params) |
| : peerCertificateCheckFunc {std::move(cert_check)} |
| , peerCertificate {null_cert} |
| , ep_ {ep.get()} |
| { |
| tls::TlsSession::TlsSessionCallbacks tls_cbs |
| = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); }, |
| /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); }, |
| /*.onCertificatesUpdate = */ |
| [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) { |
| onTlsCertificatesUpdate(l, r, n); |
| }, |
| /*.verifyCertificate = */ |
| [this](gnutls_session_t session) { |
| return verifyCertificate(session); |
| }}; |
| tls::TlsParams tls_param = { |
| /*.ca_list = */ "", |
| /*.peer_ca = */ nullptr, |
| /*.cert = */ local_identity.second, |
| /*.cert_key = */ local_identity.first, |
| /*.dh_params = */ dh_params, |
| /*.certStore = */ certStore, |
| /*.timeout = */ std::chrono::duration_cast<decltype(tls::TlsParams::timeout)>(TLS_TIMEOUT), |
| /*.cert_check = */ nullptr, |
| /*.io_context = */ ioContext, |
| /* .logger = */ ep->underlyingICE()->logger() |
| }; |
| tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs); |
| } |
| |
| ~Impl() |
| { |
| { |
| std::lock_guard lk(cbMtx_); |
| onStateChangeCb_ = {}; |
| onReadyCb_ = {}; |
| } |
| tls.reset(); |
| } |
| |
| std::shared_ptr<IceTransport> underlyingICE() const |
| { |
| if (ep_) |
| if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_)) |
| return iceSocket->underlyingICE(); |
| return {}; |
| } |
| |
| // TLS callbacks |
| int verifyCertificate(gnutls_session_t); |
| void onTlsStateChange(tls::TlsSessionState); |
| void onTlsRxData(std::vector<uint8_t>&&); |
| void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int); |
| |
| std::mutex cbMtx_ {}; |
| OnStateChangeCb onStateChangeCb_; |
| dht::crypto::Certificate null_cert; |
| std::function<bool(const dht::crypto::Certificate&)> peerCertificateCheckFunc; |
| const dht::crypto::Certificate& peerCertificate; |
| std::atomic_bool isReady_ {false}; |
| OnReadyCb onReadyCb_; |
| std::unique_ptr<tls::TlsSession> tls; |
| const IceSocketEndpoint* ep_; |
| }; |
| |
| int |
| TlsSocketEndpoint::Impl::verifyCertificate(gnutls_session_t session) |
| { |
| dht::crypto::Certificate crt; |
| auto verified = init_crt(session, crt); |
| if (verified != GNUTLS_E_SUCCESS) |
| return verified; |
| if (peerCertificateCheckFunc) { |
| if (!peerCertificateCheckFunc(crt)) { |
| if (const auto& logger = tls->logger()) |
| logger->error("[TLS-SOCKET] Refusing peer certificate"); |
| return GNUTLS_E_CERTIFICATE_ERROR; |
| } |
| |
| null_cert = std::move(crt); |
| } else { |
| if (crt.getPacked() != peerCertificate.getPacked()) { |
| if (const auto& logger = tls->logger()) |
| logger->error("[TLS-SOCKET] Unexpected peer certificate"); |
| return GNUTLS_E_CERTIFICATE_ERROR; |
| } |
| } |
| |
| return GNUTLS_E_SUCCESS; |
| } |
| |
| void |
| TlsSocketEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state) |
| { |
| std::lock_guard lk(cbMtx_); |
| if ((state == tls::TlsSessionState::SHUTDOWN || state == tls::TlsSessionState::ESTABLISHED) |
| && !isReady_) { |
| isReady_ = true; |
| if (onReadyCb_) |
| onReadyCb_(state == tls::TlsSessionState::ESTABLISHED); |
| } |
| if (onStateChangeCb_ && !onStateChangeCb_(state)) |
| onStateChangeCb_ = {}; |
| } |
| |
| void |
| TlsSocketEndpoint::Impl::onTlsRxData([[maybe_unused]] std::vector<uint8_t>&& buf) |
| {} |
| |
| void |
| TlsSocketEndpoint::Impl::onTlsCertificatesUpdate([[maybe_unused]] const gnutls_datum_t* local_raw, |
| [[maybe_unused]] const gnutls_datum_t* remote_raw, |
| [[maybe_unused]] unsigned int remote_count) |
| {} |
| |
| TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr, |
| tls::CertificateStore& certStore, |
| const std::shared_ptr<asio::io_context>& ioContext, |
| const Identity& local_identity, |
| const std::shared_future<tls::DhParams>& dh_params, |
| const dht::crypto::Certificate& peer_cert) |
| : pimpl_ {std::make_unique<Impl>(std::move(tr), certStore, ioContext, peer_cert, local_identity, dh_params)} |
| {} |
| |
| TlsSocketEndpoint::TlsSocketEndpoint( |
| std::unique_ptr<IceSocketEndpoint>&& tr, |
| tls::CertificateStore& certStore, |
| const std::shared_ptr<asio::io_context>& ioContext, |
| const Identity& local_identity, |
| const std::shared_future<tls::DhParams>& dh_params, |
| std::function<bool(const dht::crypto::Certificate&)>&& cert_check) |
| : pimpl_ { |
| std::make_unique<Impl>(std::move(tr), certStore, ioContext, std::move(cert_check), local_identity, dh_params)} |
| {} |
| |
| TlsSocketEndpoint::~TlsSocketEndpoint() {} |
| |
| bool |
| TlsSocketEndpoint::isInitiator() const |
| { |
| if (!pimpl_->tls) { |
| return false; |
| } |
| return pimpl_->tls->isInitiator(); |
| } |
| |
| int |
| TlsSocketEndpoint::maxPayload() const |
| { |
| if (!pimpl_->tls) { |
| return -1; |
| } |
| return pimpl_->tls->maxPayload(); |
| } |
| |
| std::size_t |
| TlsSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec) |
| { |
| if (!pimpl_->tls) { |
| ec = std::make_error_code(std::errc::broken_pipe); |
| return -1; |
| } |
| return pimpl_->tls->read(buf, len, ec); |
| } |
| |
| std::size_t |
| TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec) |
| { |
| if (!pimpl_->tls) { |
| ec = std::make_error_code(std::errc::broken_pipe); |
| return -1; |
| } |
| return pimpl_->tls->write(buf, len, ec); |
| } |
| |
| std::shared_ptr<dht::crypto::Certificate> |
| TlsSocketEndpoint::peerCertificate() const |
| { |
| if (!pimpl_->tls) |
| return {}; |
| return pimpl_->tls->peerCertificate(); |
| } |
| |
| int |
| TlsSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const |
| { |
| if (!pimpl_->tls) { |
| ec = std::make_error_code(std::errc::broken_pipe); |
| return -1; |
| } |
| return pimpl_->tls->waitForData(timeout, ec); |
| } |
| |
| void |
| TlsSocketEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb) |
| { |
| std::lock_guard lk(pimpl_->cbMtx_); |
| pimpl_->onStateChangeCb_ = std::move(cb); |
| } |
| |
| void |
| TlsSocketEndpoint::setOnReady(std::function<void(bool ok)>&& cb) |
| { |
| std::lock_guard lk(pimpl_->cbMtx_); |
| pimpl_->onReadyCb_ = std::move(cb); |
| } |
| |
| void |
| TlsSocketEndpoint::shutdown() |
| { |
| pimpl_->tls->shutdown(); |
| if (pimpl_->ep_) { |
| const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_); |
| if (iceSocket && iceSocket->underlyingICE()) |
| iceSocket->underlyingICE()->cancelOperations(); |
| } |
| } |
| |
| void |
| TlsSocketEndpoint::monitor() const |
| { |
| if (auto ice = pimpl_->underlyingICE()) |
| if (auto logger = ice->logger()) |
| logger->debug("\t- Ice connection: {}", ice->link()); |
| } |
| |
| IpAddr |
| TlsSocketEndpoint::getLocalAddress() const |
| { |
| if (auto ice = pimpl_->underlyingICE()) |
| return ice->getLocalAddress(ICE_COMP_ID_SIP_TRANSPORT); |
| return {}; |
| } |
| |
| IpAddr |
| TlsSocketEndpoint::getRemoteAddress() const |
| { |
| if (auto ice = pimpl_->underlyingICE()) |
| return ice->getRemoteAddress(ICE_COMP_ID_SIP_TRANSPORT); |
| return {}; |
| } |
| |
| } // namespace dhtnet |