blob: cea117af8965a689f8ccaf2616ef5323701a85f7 [file] [log] [blame]
/*
* Copyright (C) 2004-2023 Savoir-faire Linux Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include "certstore.h"
#include "security_const.h"
#include "fileutils.h"
#include <opendht/thread_pool.h>
#include <opendht/logger.h>
#include <gnutls/ocsp.h>
#if __has_include(<fmt/std.h>)
#include <fmt/std.h>
#else
#include <fmt/ostream.h>
#endif
#include <thread>
#include <sstream>
#include <fmt/format.h>
namespace dhtnet {
namespace tls {
/**
 * Builds a certificate store rooted at @p path.
 *
 * Three sub-directories of @p path are used: "certificates", "crls" and
 * "oscp". Each one is handed to fileutils::check_dir() (presumably creating
 * it when missing - confirm in fileutils), then the certificates already
 * present on disk are loaded through loadLocalCertificates().
 *
 * @param path   Base directory of the store.
 * @param logger Logger used for diagnostics; may be null (uses are null-checked).
 */
CertificateStore::CertificateStore(const std::filesystem::path& path, std::shared_ptr<Logger> logger)
: logger_(std::move(logger))
, certPath_(path / "certificates")
, crlPath_(path /"crls")
// NOTE(review): "oscp" looks like a transposition of "ocsp"; it names an
// on-disk directory, so renaming it would orphan existing data.
, ocspPath_(path /"oscp")
{
fileutils::check_dir(certPath_);
fileutils::check_dir(crlPath_);
fileutils::check_dir(ocspPath_);
loadLocalCertificates();
}
/**
 * Loads every certificate file found in the local certificate directory.
 *
 * A file is accepted only when its name equals the short or the long id of
 * the certificate it contains. Accepted certificates, and every issuer
 * attached to them, are registered in certs_ under both ids and their
 * revocation data is loaded. A file that cannot be loaded or whose name does
 * not match is logged and deleted.
 *
 * @return Number of certificates registered, chain members included.
 */
unsigned
CertificateStore::loadLocalCertificates()
{
    std::lock_guard<std::mutex> l(lock_);
    if (logger_)
        logger_->debug("CertificateStore: loading certificates from {}", certPath_);

    unsigned n = 0;
    std::error_code ec;
    for (const auto& crtPath : std::filesystem::directory_iterator(certPath_, ec)) {
        const auto& path = crtPath.path();
        auto fileName = path.filename().string();
        try {
            auto crt = std::make_shared<crypto::Certificate>(
                fileutils::loadFile(crtPath));
            auto id = crt->getId().toString();
            auto longId = crt->getLongId().toString();
            if (id != fileName && longId != fileName)
                throw std::logic_error("Certificate id mismatch");
            // Register the certificate and walk up its issuer chain.
            while (crt) {
                id = crt->getId().toString();
                longId = crt->getLongId().toString();
                certs_.emplace(std::move(id), crt);
                certs_.emplace(std::move(longId), crt);
                loadRevocations(*crt);
                crt = crt->issuer;
                ++n;
            }
        } catch (const std::exception& e) {
            if (logger_)
                logger_->warn("loadLocalCertificates: error loading {}: {}", path, e.what());
            // Use the non-throwing overload: a failed removal must not escape
            // this handler and abort the whole load (and the constructor).
            std::error_code removeError;
            std::filesystem::remove(path, removeError);
            if (removeError && logger_)
                logger_->warn("loadLocalCertificates: unable to remove {}: {}",
                              path,
                              removeError.message());
        }
    }
    if (logger_)
        logger_->debug("CertificateStore: loaded {} local certificates.", n);
    return n;
}
/**
 * Loads the revocation data stored on disk for @p crt.
 *
 * Every file under `<crlPath_>/<certificate id>` is loaded as a revocation
 * list and attached to the certificate. Then, under
 * `<ocspPath_>/<certificate id>`, the file whose name is the hexadecimal
 * serial number of the certificate is loaded as its OCSP response.
 * Failures are logged and do not interrupt the scan.
 *
 * @param crt Certificate to update.
 */
void
CertificateStore::loadRevocations(crypto::Certificate& crt) const
{
    std::error_code ec;
    auto dir = crlPath_ / crt.getId().toString();
    for (const auto& crl : std::filesystem::directory_iterator(dir, ec)) {
        try {
            crt.addRevocationList(std::make_shared<crypto::RevocationList>(
                fileutils::loadFile(crl)));
        } catch (const std::exception& e) {
            // warn() takes fmt-style placeholders; a printf-style "%s" would be
            // printed verbatim and the error text lost.
            if (logger_)
                logger_->warn("Can't load revocation list: {:s}", e.what());
        }
    }

    auto ocsp_dir = ocspPath_ / crt.getId().toString();
    for (const auto& ocsp_filepath : std::filesystem::directory_iterator(ocsp_dir, ec)) {
        try {
            auto ocsp = ocsp_filepath.path().filename().string();
            if (logger_) logger_->debug("Found {}", ocsp_filepath.path());
            // Only the response named after this certificate's serial is relevant.
            auto serial = crt.getSerialNumber();
            if (dht::toHex(serial.data(), serial.size()) != ocsp)
                continue;
            // Save the response
            auto ocspBlob = fileutils::loadFile(ocsp_filepath);
            crt.ocspResponse = std::make_shared<dht::crypto::OcspResponse>(ocspBlob.data(),
                                                                           ocspBlob.size());
            unsigned int status = crt.ocspResponse->getCertificateStatus();
            if (status == GNUTLS_OCSP_CERT_GOOD) {
                if (logger_) logger_->debug("Certificate {:s} has good OCSP status", crt.getId());
            } else if (status == GNUTLS_OCSP_CERT_REVOKED) {
                if (logger_) logger_->error("Certificate {:s} has revoked OCSP status", crt.getId());
            } else if (status == GNUTLS_OCSP_CERT_UNKNOWN) {
                if (logger_) logger_->error("Certificate {:s} has unknown OCSP status", crt.getId());
            } else {
                if (logger_) logger_->error("Certificate {:s} has invalid OCSP status", crt.getId());
            }
        } catch (const std::exception& e) {
            if (logger_)
                logger_->warn("Can't load OCSP revocation status: {:s}", e.what());
        }
    }
}
/**
 * Lists the keys under which certificates are currently registered.
 *
 * @return A copy of every key of certs_, taken while holding lock_.
 */
std::vector<std::string>
CertificateStore::getPinnedCertificates() const
{
    std::lock_guard<std::mutex> guard(lock_);

    std::vector<std::string> keys;
    keys.reserve(certs_.size());
    for (auto it = certs_.cbegin(); it != certs_.cend(); ++it)
        keys.push_back(it->first);
    return keys;
}
std::shared_ptr<crypto::Certificate>
CertificateStore::getCertificate(const std::string& k)
{
auto getCertificate_ = [this](const std::string& k) -> std::shared_ptr<crypto::Certificate> {
auto cit = certs_.find(k);
return cit != certs_.cend() ? cit->second : std::shared_ptr<crypto::Certificate>{};
};
std::unique_lock<std::mutex> l(lock_);
auto crt = getCertificate_(k);
// Check if certificate is complete
// If the certificate has been splitted, reconstruct it
auto top_issuer = crt;
while (top_issuer && top_issuer->getUID() != top_issuer->getIssuerUID()) {
if (top_issuer->issuer) {
top_issuer = top_issuer->issuer;
} else if (auto cert = getCertificate_(top_issuer->getIssuerUID())) {
top_issuer->issuer = cert;
top_issuer = cert;
} else {
// In this case, a certificate was not found
if (logger_)
logger_->warn("Incomplete certificate detected {:s}", k);
break;
}
}
return crt;
}
/**
 * Imports a certificate from the legacy location
 * `<dataDir>/certificates/<k>` and pins it locally.
 *
 * @param dataDir Directory containing the legacy "certificates" folder.
 * @param k       File name of the certificate inside that folder.
 * @return The pinned certificate, or null when the file does not exist or
 *         cannot be loaded (the error is logged).
 */
std::shared_ptr<crypto::Certificate>
CertificateStore::getCertificateLegacy(const std::string& dataDir, const std::string& k)
{
    try {
        auto oldPath = fmt::format("{}/certificates/{}", dataDir, k);
        if (fileutils::isFile(oldPath)) {
            // Build the certificate from the file content, as
            // loadLocalCertificates() does. Passing the path string itself
            // would make the constructor parse the path text as certificate
            // data (assumes the OpenDHT std::string constructor takes PEM
            // data - confirm against crypto.h).
            auto crt = std::make_shared<crypto::Certificate>(fileutils::loadFile(oldPath));
            pinCertificate(crt, true);
            return crt;
        }
    } catch (const std::exception& e) {
        if (logger_)
            logger_->warn("Can't load certificate: {:s}", e.what());
    }
    return {};
}
/**
 * Searches the store for a certificate matching @p name.
 *
 * A certificate matches when its name equals @p name or, provided @p type is
 * not NameType::UNKNOWN, when one of its alternative names has that type and
 * equals @p name.
 *
 * @param name Name to look for.
 * @param type Alternative-name type to consider.
 * @return The first match found while iterating certs_, or null.
 */
std::shared_ptr<crypto::Certificate>
CertificateStore::findCertificateByName(const std::string& name, crypto::NameType type) const
{
    std::unique_lock<std::mutex> guard(lock_);
    const bool checkAltNames = type != crypto::NameType::UNKNOWN;

    for (auto& entry : certs_) {
        const auto& candidate = entry.second;
        if (candidate->getName() == name)
            return candidate;
        if (!checkAltNames)
            continue;
        for (const auto& alt : candidate->getAltNames()) {
            if (alt.first == type && alt.second == name)
                return candidate;
        }
    }
    return {};
}
/**
 * Searches the store for a certificate whose UID equals @p uid.
 *
 * @param uid UID to look for.
 * @return The first match found while iterating certs_, or null.
 */
std::shared_ptr<crypto::Certificate>
CertificateStore::findCertificateByUID(const std::string& uid) const
{
    std::unique_lock<std::mutex> guard(lock_);
    for (auto it = certs_.begin(); it != certs_.end(); ++it) {
        const auto& candidate = it->second;
        if (candidate->getUID() == uid)
            return candidate;
    }
    return {};
}
/**
 * Finds the certificate that issued @p crt and checks the signature.
 *
 * The issuer is searched by issuer UID first (using the issuer already
 * attached to @p crt when its UID matches), then by issuer name. The
 * candidate is returned only if gnutls_x509_crt_verify() succeeds and does
 * not flag @p crt as invalid.
 *
 * @param crt Certificate whose issuer is wanted; dereferenced unconditionally.
 * @return The verified issuer, or null.
 */
std::shared_ptr<crypto::Certificate>
CertificateStore::findIssuer(const std::shared_ptr<crypto::Certificate>& crt) const
{
    std::shared_ptr<crypto::Certificate> issuer;

    // First attempt: by UID.
    auto key = crt->getIssuerUID();
    if (!key.empty()) {
        const bool attachedMatches = crt->issuer && crt->issuer->getUID() == key;
        if (attachedMatches)
            issuer = crt->issuer;
        else
            issuer = findCertificateByUID(key);
    }

    // Second attempt: by name.
    if (!issuer) {
        key = crt->getIssuerName();
        if (!key.empty())
            issuer = findCertificateByName(key);
    }
    if (!issuer)
        return issuer;

    // Make sure the candidate really signed the certificate.
    unsigned verifyFlags = 0;
    const int rc = gnutls_x509_crt_verify(crt->cert, &issuer->cert, 1, 0, &verifyFlags);
    if (rc != GNUTLS_E_SUCCESS) {
        if (logger_)
            logger_->warn("gnutls_x509_crt_verify failed: {:s}", gnutls_strerror(rc));
        return {};
    }
    if ((verifyFlags & GNUTLS_CERT_INVALID) != 0)
        return {};
    return issuer;
}
/**
 * Reads every PEM certificate found at @p path.
 *
 * When @p path is a directory, its entries are processed recursively.
 * Otherwise the file is loaded and parsed as a list of PEM certificates.
 * Files that cannot be read or parsed contribute nothing to the result.
 *
 * @param path     File or directory to read.
 * @param crl_path Not used by this function; only forwarded on recursion.
 * @return The certificates that could be imported.
 */
static std::vector<crypto::Certificate>
readCertificates(const std::filesystem::path& path, const std::string& crl_path)
{
    std::vector<crypto::Certificate> ret;
    std::error_code ec;
    // Non-throwing overload: a status error is treated as "not a directory"
    // and handled by the file branch below.
    if (std::filesystem::is_directory(path, ec)) {
        for (const auto& file : std::filesystem::directory_iterator(path, ec)) {
            auto certs = readCertificates(file, crl_path);
            ret.insert(std::end(ret),
                       std::make_move_iterator(std::begin(certs)),
                       std::make_move_iterator(std::end(certs)));
        }
    } else {
        try {
            auto data = fileutils::loadFile(path);
            const gnutls_datum_t dt {data.data(), static_cast<unsigned>(data.size())};
            gnutls_x509_crt_t* certs {nullptr};
            unsigned cert_num {0};
            const int err = gnutls_x509_crt_list_import2(&certs,
                                                         &cert_num,
                                                         &dt,
                                                         GNUTLS_X509_FMT_PEM,
                                                         0);
            // The output array and count are only meaningful on success, so
            // they are neither read nor freed after a failed import.
            if (err >= 0) {
                for (unsigned i = 0; i < cert_num; i++)
                    ret.emplace_back(certs[i]);
                gnutls_free(certs);
            }
        } catch (const std::exception&) {
            // Best effort: unreadable files are skipped.
        }
    }
    return ret;
}
void
CertificateStore::pinCertificatePath(const std::string& path,
std::function<void(const std::vector<std::string>&)> cb)
{
dht::ThreadPool::computation().run([&, path, cb]() {
auto certs = readCertificates(path, crlPath_.string());
std::vector<std::string> ids;
std::vector<std::weak_ptr<crypto::Certificate>> scerts;
ids.reserve(certs.size());
scerts.reserve(certs.size());
{
std::lock_guard<std::mutex> l(lock_);
for (auto& cert : certs) {
try {
auto shared = std::make_shared<crypto::Certificate>(std::move(cert));
scerts.emplace_back(shared);
auto e = certs_.emplace(shared->getId().toString(), shared);
ids.emplace_back(e.first->first);
e = certs_.emplace(shared->getLongId().toString(), shared);
ids.emplace_back(e.first->first);
} catch (const std::exception& e) {
if (logger_)
logger_->warn("Can't load certificate: {:s}", e.what());
}
}
paths_.emplace(path, std::move(scerts));
}
if (logger_) logger_->d("CertificateStore: loaded %zu certificates from %s.", certs.size(), path.c_str());
if (cb)
cb(ids);
//emitSignal<libdhtnet::ConfigurationSignal::CertificatePathPinned>(path, ids);
});
}
/**
 * Unpins the certificates previously pinned from @p path.
 *
 * Each certificate recorded for @p path that is still alive is removed from
 * certs_, then the path record itself is dropped.
 *
 * @param path Path given earlier to pinCertificatePath().
 * @return Number of certificates removed; 0 when the path is unknown.
 */
unsigned
CertificateStore::unpinCertificatePath(const std::string& path)
{
    std::lock_guard<std::mutex> l(lock_);
    auto certs = paths_.find(path);
    if (certs == std::end(paths_))
        return 0;
    unsigned n = 0;
    for (const auto& wcert : certs->second) {
        if (auto cert = wcert.lock()) {
            // pinCertificatePath() registers each certificate under both its
            // short and long id; both keys must go, otherwise the certificate
            // stays reachable through its long id.
            certs_.erase(cert->getId().toString());
            certs_.erase(cert->getLongId().toString());
            ++n;
        }
    }
    paths_.erase(certs);
    return n;
}
/**
 * Pins a certificate given as raw bytes.
 *
 * @param cert  Serialized certificate.
 * @param local Forwarded to the other pinCertificate() overloads.
 * @return The ids returned by the pinning overload, or an empty list when a
 *         std::exception is raised while decoding or pinning.
 */
std::vector<std::string>
CertificateStore::pinCertificate(const std::vector<uint8_t>& cert, bool local) noexcept
{
    std::vector<std::string> ids;
    try {
        ids = pinCertificate(crypto::Certificate(cert), local);
    } catch (const std::exception&) {
        // Deliberately silent: failure is reported through the empty result.
    }
    return ids;
}
/**
 * Pins a certificate passed by rvalue.
 *
 * The certificate is moved into a shared_ptr and forwarded to the
 * shared_ptr overload of pinCertificate().
 *
 * @param cert  Certificate to pin (moved from).
 * @param local Forwarded unchanged.
 * @return Whatever the shared_ptr overload returns.
 */
std::vector<std::string>
CertificateStore::pinCertificate(crypto::Certificate&& cert, bool local)
{
return pinCertificate(std::make_shared<crypto::Certificate>(std::move(cert)), local);
}
/**
 * Pins a certificate and every issuer attached to it.
 *
 * Each certificate of the chain is stored in certs_ under its short and long
 * id, replacing any previous entry. When @p local is true, the revocation
 * lists of each certificate are written to disk, and the packed certificate
 * is saved under the long id of @p cert if at least one long-id key was
 * newly inserted.
 *
 * @param cert  Leaf certificate; a null pointer yields an empty result.
 * @param local Whether to persist data on disk.
 * @return For each certificate of the chain, its long id followed by its
 *         short id, leaf first.
 */
std::vector<std::string>
CertificateStore::pinCertificate(const std::shared_ptr<crypto::Certificate>& cert, bool local)
{
    // Inserts or overwrites one entry; returns true when the key was new.
    const auto upsert = [this](const std::string& key,
                               const std::shared_ptr<crypto::Certificate>& value) {
        auto result = certs_.emplace(key, value);
        if (not result.second)
            result.first->second = value;
        return result.second;
    };

    bool sig {false};
    std::vector<std::string> ids;
    {
        auto node = cert;
        std::lock_guard<std::mutex> guard(lock_);
        for (; node; node = node->issuer) {
            auto id = node->getId().toString();
            auto longId = node->getLongId().toString();
            upsert(id, node);
            // Only the long-id insertion decides whether the file is saved.
            const bool newLongId = upsert(longId, node);
            if (local) {
                for (const auto& crl : node->getRevocationLists())
                    pinRevocationList(id, *crl);
            }
            ids.emplace_back(longId);
            ids.emplace_back(id);
            sig |= newLongId;
        }
        if (local && sig)
            fileutils::saveFile(certPath_ / ids.front(), cert->getPacked());
    }
    //for (const auto& id : ids)
    //    emitSignal<libdhtnet::ConfigurationSignal::CertificatePinned>(id);
    return ids;
}
/**
 * Removes the entry @p id from the store and deletes its file.
 *
 * @param id Key of the certificate, also used as file name under certPath_.
 * @return Result of the file removal.
 */
bool
CertificateStore::unpinCertificate(const std::string& id)
{
    std::lock_guard<std::mutex> guard(lock_);
    certs_.erase(id);
    const auto file = certPath_ / id;
    return std::filesystem::remove(file);
}
/**
 * Adds a certificate to, or removes it from, the trusted list.
 *
 * With TrustStatus::TRUSTED the certificate found for @p id is appended to
 * trustedCerts_. With any other status the first trusted certificate whose
 * short id equals @p id is removed.
 *
 * @param id     Certificate id.
 * @param status Requested trust status.
 * @return true when the list was modified, false otherwise.
 */
bool
CertificateStore::setTrustedCertificate(const std::string& id, TrustStatus status)
{
    if (status == TrustStatus::TRUSTED) {
        auto crt = getCertificate(id);
        if (!crt)
            return false;
        trustedCerts_.emplace_back(crt);
        return true;
    }

    const auto hasId = [&](const std::shared_ptr<crypto::Certificate>& crt) {
        return crt->getId().toString() == id;
    };
    const auto pos = std::find_if(trustedCerts_.begin(), trustedCerts_.end(), hasId);
    if (pos == trustedCerts_.end())
        return false;
    trustedCerts_.erase(pos);
    return true;
}
std::vector<gnutls_x509_crt_t>
CertificateStore::getTrustedCertificates() const
{
std::vector<gnutls_x509_crt_t> crts;
crts.reserve(trustedCerts_.size());
for (auto& crt : trustedCerts_)
crts.emplace_back(crt->getCopy());
return crts;
}
/**
 * Attaches a revocation list to a stored certificate and saves it to disk.
 *
 * @param id  Id of the certificate the list belongs to.
 * @param crl Revocation list; dereferenced unconditionally.
 *
 * Any exception is swallowed after logging a warning.
 */
void
CertificateStore::pinRevocationList(const std::string& id,
                                    const std::shared_ptr<dht::crypto::RevocationList>& crl)
{
    try {
        auto owner = getCertificate(id);
        if (owner)
            owner->addRevocationList(crl);
        // The list is written to disk even when the certificate is unknown.
        pinRevocationList(id, *crl);
    } catch (...) {
        if (logger_)
            logger_->warn("Can't add revocation list");
    }
}
/**
 * Writes a revocation list to disk.
 *
 * The packed list is saved as `<crlPath_>/<id>/<hex of the CRL number>`.
 *
 * @param id  Id of the certificate the list belongs to (directory name).
 * @param crl Revocation list to save.
 */
void
CertificateStore::pinRevocationList(const std::string& id, const dht::crypto::RevocationList& crl)
{
fileutils::check_dir(crlPath_ / id);
fileutils::saveFile(crlPath_ / id / dht::toHex(crl.getNumber()),
crl.getPacked());
}
/**
 * Stores the OCSP response attached to @p cert.
 *
 * Nothing is done when the certificate has no response or when its status
 * cannot be read. Otherwise, the response is copied to the matching
 * certificate of the local store (same id and serial number, different
 * object) and written asynchronously to
 * `<ocspPath_>/<certificate id>/<hex serial>`.
 *
 * @param cert Certificate carrying the OCSP response.
 */
void
CertificateStore::pinOcspResponse(const dht::crypto::Certificate& cert)
{
    if (not cert.ocspResponse)
        return;
    try {
        // Called only to make sure the status is readable.
        cert.ocspResponse->getCertificateStatus();
    } catch (dht::crypto::CryptoException& e) {
        if (logger_) logger_->error("Failed to read certificate status of OCSP response: {:s}", e.what());
        return;
    }
    auto id = cert.getId().toString();
    auto serial = cert.getSerialNumber();
    auto serialhex = dht::toHex(serial);
    auto dir = ocspPath_ / id;
    // Build the file path before the lambda: the initialization order of
    // init-captures is unspecified, so computing it inside the capture list
    // could read `dir` and `serialhex` after they have been moved from.
    auto path = dir / serialhex;

    if (auto localCert = getCertificate(id)) {
        // Update certificate in the local store if relevant
        if (localCert.get() != &cert && serial == localCert->getSerialNumber()) {
            if (logger_) logger_->d("Updating OCSP for certificate %s in the local store", id.c_str());
            localCert->ocspResponse = cert.ocspResponse;
        }
    }
    dht::ThreadPool::io().run([l = logger_,
                               path = std::move(path),
                               dir = std::move(dir),
                               id = std::move(id),
                               serialhex = std::move(serialhex),
                               ocspResponse = cert.ocspResponse] {
        if (l) l->d("Saving OCSP Response of device %s with serial %s", id.c_str(), serialhex.c_str());
        std::lock_guard<std::mutex> lock(fileutils::getFileLock(path));
        fileutils::check_dir(dir);
        fileutils::saveFile(path, ocspResponse->pack());
    });
}
/**
 * Converts a textual permission status to its enum value.
 *
 * @param str NUL-terminated status string; must not be null.
 * @return ALLOWED or BANNED when @p str equals the corresponding constant,
 *         UNDEFINED for any other text.
 */
TrustStore::PermissionStatus
TrustStore::statusFromStr(const char* str)
{
    if (std::strcmp(str, libdhtnet::Certificate::Status::ALLOWED) == 0)
        return PermissionStatus::ALLOWED;
    if (std::strcmp(str, libdhtnet::Certificate::Status::BANNED) == 0)
        return PermissionStatus::BANNED;
    return PermissionStatus::UNDEFINED;
}
/**
 * Converts a permission status to its textual constant.
 *
 * @param s Status to convert.
 * @return The ALLOWED or BANNED constant, or the UNDEFINED constant for
 *         every other value.
 */
const char*
TrustStore::statusToStr(TrustStore::PermissionStatus s)
{
    if (s == PermissionStatus::ALLOWED)
        return libdhtnet::Certificate::Status::ALLOWED;
    if (s == PermissionStatus::BANNED)
        return libdhtnet::Certificate::Status::BANNED;
    return libdhtnet::Certificate::Status::UNDEFINED;
}
/**
 * Converts a textual trust status to its enum value.
 *
 * @param str NUL-terminated status string; must not be null.
 * @return TRUSTED when @p str equals the TRUSTED constant, UNTRUSTED otherwise.
 */
TrustStatus
trustStatusFromStr(const char* str)
{
    const bool trusted = std::strcmp(str, libdhtnet::Certificate::TrustStatus::TRUSTED) == 0;
    return trusted ? TrustStatus::TRUSTED : TrustStatus::UNTRUSTED;
}
/**
 * Converts a trust status to its textual constant.
 *
 * @param s Status to convert.
 * @return The TRUSTED constant for TrustStatus::TRUSTED, the UNTRUSTED
 *         constant for every other value.
 */
const char*
statusToStr(TrustStatus s)
{
    if (s == TrustStatus::TRUSTED)
        return libdhtnet::Certificate::TrustStatus::TRUSTED;
    return libdhtnet::Certificate::TrustStatus::UNTRUSTED;
}
/**
 * Adds a revocation list to the trust list held in allowed_.
 *
 * @param crl Revocation list to add. It is passed on as an lvalue, so it is
 *            not moved from here.
 * @return Always true.
 */
bool
TrustStore::addRevocationList(dht::crypto::RevocationList&& crl)
{
allowed_.add(crl);
return true;
}
/**
 * Sets the permission status of a certificate known only by its id.
 *
 * Forwards to the four-argument overload with no certificate object and
 * local set to false.
 *
 * @param cert_id Id of the certificate.
 * @param status  New permission status.
 * @return Result of the four-argument overload.
 */
bool
TrustStore::setCertificateStatus(const std::string& cert_id,
const TrustStore::PermissionStatus status)
{
return setCertificateStatus(nullptr, cert_id, status, false);
}
/**
 * Sets the permission status of a certificate object.
 *
 * Forwards to the four-argument overload, using the short id of @p cert.
 *
 * @param cert   Certificate; dereferenced unconditionally, so it must not be null.
 * @param status New permission status.
 * @param local  Forwarded unchanged.
 * @return Result of the four-argument overload.
 */
bool
TrustStore::setCertificateStatus(const std::shared_ptr<crypto::Certificate>& cert,
const TrustStore::PermissionStatus status,
bool local)
{
return setCertificateStatus(cert, cert->getId().toString(), status, local);
}
/**
 * Sets, changes or clears the permission status of a certificate.
 *
 * - UNDEFINED removes any status recorded for @p cert_id.
 * - ALLOWED / BANNED records the status. When the certificate object is
 *   available (given, or found in the certificate store) the status goes to
 *   certStatus_ and the trust list is updated; otherwise it is kept in
 *   unknownCertStatus_ until the certificate becomes known.
 *
 * @param cert    Certificate object, may be null. When not null it is pinned
 *                in the certificate store first.
 * @param cert_id Id under which the status is recorded.
 * @param status  New permission status.
 * @param local   Forwarded to CertificateStore::pinCertificate().
 * @return Always true.
 */
bool
TrustStore::setCertificateStatus(std::shared_ptr<crypto::Certificate> cert,
const std::string& cert_id,
const TrustStore::PermissionStatus status,
bool local)
{
// Pinning is done before mutex_ is taken.
if (cert)
certStore_.pinCertificate(cert, local);
std::lock_guard<std::recursive_mutex> lk(mutex_);
// Promote pending statuses whose certificate is now in the store.
updateKnownCerts();
// Set when the trust list has to be rebuilt from certStatus_.
bool dirty {false};
if (status == PermissionStatus::UNDEFINED) {
unknownCertStatus_.erase(cert_id);
// erase() returns the number of removed entries: rebuild only if one existed.
dirty = certStatus_.erase(cert_id);
} else {
bool allowed = (status == PermissionStatus::ALLOWED);
auto s = certStatus_.find(cert_id);
if (s == std::end(certStatus_)) {
// Certificate state is currently undefined
if (not cert)
cert = certStore_.getCertificate(cert_id);
if (cert) {
unknownCertStatus_.erase(cert_id);
auto& crt_status = certStatus_[cert_id];
if (not crt_status.first)
crt_status.first = cert;
crt_status.second.allowed = allowed;
setStoreCertStatus(*cert, allowed);
} else {
// Can't find certificate
unknownCertStatus_[cert_id].allowed = allowed;
}
} else {
// Certificate is already allowed or banned
if (s->second.second.allowed != allowed) {
s->second.second.allowed = allowed;
if (allowed) // Certificate is re-added after ban, rebuild needed
dirty = true;
else
allowed_.remove(*s->second.first, false);
}
}
}
if (dirty)
rebuildTrust();
return true;
}
/**
 * Computes the permission status of a certificate from its whole chain.
 *
 * The chain is walked from the certificate up to a certificate whose UID
 * equals its issuer UID, following attached issuers or looking them up in
 * the certificate store. A banned certificate anywhere in the chain yields
 * BANNED. Otherwise the result is ALLOWED when at least one certificate of
 * the chain has a recorded status, and UNDEFINED when none has.
 *
 * @param cert_id Id of the certificate.
 * @return The resulting status; UNDEFINED when the certificate is unknown.
 */
TrustStore::PermissionStatus
TrustStore::getCertificateStatus(const std::string& cert_id) const
{
    std::lock_guard<std::recursive_mutex> lk(mutex_);
    auto cert = certStore_.getCertificate(cert_id);
    if (!cert)
        return PermissionStatus::UNDEFINED;

    bool sawAllowed = false;
    while (cert) {
        const auto id = cert->getId().toString();

        // Look for a recorded status: known certificates first, pending ones next.
        bool recorded = false;
        bool permitted = false;
        const auto known = certStatus_.find(id);
        if (known != std::end(certStatus_)) {
            recorded = true;
            permitted = known->second.second.allowed;
        } else {
            const auto pending = unknownCertStatus_.find(id);
            if (pending != std::end(unknownCertStatus_)) {
                recorded = true;
                permitted = pending->second.allowed;
            }
        }

        if (recorded) {
            if (!permitted)
                return PermissionStatus::BANNED;
            sawAllowed = true; // we need to find at least a certificate
        }

        if (cert->getUID() == cert->getIssuerUID())
            break;
        cert = cert->issuer ? cert->issuer : certStore_.getCertificate(cert->getIssuerUID());
    }
    return sawAllowed ? PermissionStatus::ALLOWED : PermissionStatus::UNDEFINED;
}
/**
 * Lists the ids of the certificates recorded with a given status.
 *
 * Both known certificates (certStatus_) and pending ones
 * (unknownCertStatus_) are considered.
 *
 * @param status Status to filter on.
 * @return Ids recorded as allowed when @p status is ALLOWED, ids recorded as
 *         banned when it is BANNED, and an empty list for UNDEFINED.
 */
std::vector<std::string>
TrustStore::getCertificatesByStatus(TrustStore::PermissionStatus status) const
{
    std::lock_guard<std::recursive_mutex> lk(mutex_);
    std::vector<std::string> ret;
    // Every recorded entry is either allowed or banned. Without this guard a
    // request for UNDEFINED would match allowed == false and return the banned ids.
    if (status == TrustStore::PermissionStatus::UNDEFINED)
        return ret;

    const bool wantAllowed = (status == TrustStore::PermissionStatus::ALLOWED);
    for (const auto& i : certStatus_)
        if (i.second.second.allowed == wantAllowed)
            ret.emplace_back(i.first);
    for (const auto& i : unknownCertStatus_)
        if (i.second.allowed == wantAllowed)
            ret.emplace_back(i.first);
    return ret;
}
/**
 * Tells whether a certificate is accepted by this trust store.
 *
 * Two checks are combined:
 *  - pinning: any certificate of the chain recorded as banned rejects it,
 *    and one recorded as allowed accepts it;
 *  - chain verification against the trust list. A failure rejects the
 *    certificate, except that an unknown-signer result (and only that) is
 *    tolerated when @p allowPublic is true.
 *
 * @param crt         Certificate to check.
 * @param allowPublic Whether certificates with no recorded status are accepted.
 * @return true when the certificate is accepted.
 */
bool
TrustStore::isAllowed(const crypto::Certificate& crt, bool allowPublic)
{
    // Match by certificate pinning
    std::lock_guard<std::recursive_mutex> lk(mutex_);
    bool allowed {allowPublic};
    for (auto c = &crt; c; c = c->issuer.get()) {
        auto status = getCertificateStatus(c->getId().toString()); // lock mutex_
        if (status == PermissionStatus::ALLOWED)
            allowed = true;
        else if (status == PermissionStatus::BANNED)
            return false;
    }

    // Match by certificate chain
    updateKnownCerts();
    auto ret = allowed_.verify(crt);
    // Unknown issuer (only that) are accepted if allowPublic is true
    if (not ret
        and !(allowPublic and ret.result == (GNUTLS_CERT_INVALID | GNUTLS_CERT_SIGNER_NOT_FOUND))) {
        // warn() takes fmt-style placeholders; a printf-style "%s" would be
        // printed verbatim and the verification message lost.
        if (certStore_.logger())
            certStore_.logger()->warn("{:s}", ret.toString());
        return false;
    }

    return allowed;
}
/**
 * Promotes pending statuses whose certificate is now available.
 *
 * Each entry of unknownCertStatus_ whose certificate can be found in the
 * certificate store is moved to certStatus_ and applied to the trust list.
 * Entries whose certificate is still missing are left untouched.
 */
void
TrustStore::updateKnownCerts()
{
    for (auto it = std::begin(unknownCertStatus_); it != std::end(unknownCertStatus_);) {
        auto crt = certStore_.getCertificate(it->first);
        if (!crt) {
            ++it;
            continue;
        }
        certStatus_.emplace(it->first, std::make_pair(crt, it->second));
        setStoreCertStatus(*crt, it->second.allowed);
        it = unknownCertStatus_.erase(it);
    }
}
/**
 * Applies a permission to the trust list held in allowed_.
 *
 * @param crt    Certificate concerned.
 * @param status true adds the certificate to allowed_, false removes it.
 */
void
TrustStore::setStoreCertStatus(const crypto::Certificate& crt, bool status)
{
if (status)
allowed_.add(crt);
else
allowed_.remove(crt, false);
}
/**
 * Rebuilds the trust list from scratch.
 *
 * allowed_ is reset, then the status of every entry of certStatus_ is
 * applied again.
 */
void
TrustStore::rebuildTrust()
{
    allowed_ = {};
    for (auto it = certStatus_.cbegin(); it != certStatus_.cend(); ++it) {
        const auto& entry = it->second;
        setStoreCertStatus(*entry.first, entry.second.allowed);
    }
}
} // namespace tls
} // namespace dhtnet