/*
 *  Copyright (C) 2004-2023 Savoir-faire Linux Inc.
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
#include "tls_session.h"
#include "threadloop.h"
#include "certstore.h"

#include <gnutls/gnutls.h>
#include <gnutls/dtls.h>
#include <gnutls/abstract.h>

#include <gnutls/crypto.h>
#include <gnutls/ocsp.h>
#include <opendht/http.h>
#include <opendht/logger.h>

#include <list>
#include <mutex>
#include <condition_variable>
#include <utility>
#include <map>
#include <atomic>
#include <iterator>
#include <stdexcept>
#include <algorithm>
#include <cstring> // std::memset

#include <cstdlib>
#include <unistd.h>

namespace dhtnet {
namespace tls {

static constexpr const char* DTLS_CERT_PRIORITY_STRING {
    "SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"};
static constexpr const char* DTLS_FULL_PRIORITY_STRING {
    "SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_"
    "PRECEDENCE:%SAFE_RENEGOTIATION"};
// Note: -GROUP-FFDHE4096:-GROUP-FFDHE6144:-GROUP-FFDHE8192:+GROUP-X25519:
// is added after gnutls 3.6.7, because some safety checks were introduced for FFDHE resulting in a
// performance drop for our usage (2/3s of delay) This performance drop is visible on mobiles devices.

// Benchmark result (on a computer)
// $gnutls-cli --benchmark-tls-kx
// (TLS1.3)-(DHE-FFDHE3072)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM)  20.48 transactions/sec
//            (avg. handshake time: 48.45 ms, sample variance: 0.68)
// (TLS1.3)-(ECDHE-SECP256R1)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM)  208.14 transactions/sec
//            (avg. handshake time: 4.01 ms, sample variance: 0.01)
// (TLS1.3)-(ECDHE-X25519)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM)  240.93 transactions/sec
//            (avg. handshake time: 4.00 ms, sample variance: 0.00)
static constexpr const char* TLS_CERT_PRIORITY_STRING {
    "SECURE192:-RSA:-GROUP-FFDHE4096:-GROUP-FFDHE6144:-GROUP-FFDHE8192:+GROUP-X25519:%SERVER_"
    "PRECEDENCE:%SAFE_RENEGOTIATION"};
static constexpr const char* TLS_FULL_PRIORITY_STRING {
    "SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-RSA:-GROUP-FFDHE4096:-GROUP-FFDHE6144:-"
    "GROUP-FFDHE8192:+GROUP-X25519:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"};
static constexpr uint32_t RX_MAX_SIZE {64 * 1024}; // 64k = max size of a UDP packet
static constexpr std::size_t INPUT_MAX_SIZE {
    1000}; // Maximum number of packets to store before dropping (pkt size = DTLS_MTU)
static constexpr ssize_t FLOOD_THRESHOLD {4 * 1024};
static constexpr auto FLOOD_PAUSE = std::chrono::milliseconds(
    100); // Time to wait after an invalid cookie packet (anti flood attack)
static constexpr size_t HANDSHAKE_MAX_RETRY {64};
static constexpr auto DTLS_RETRANSMIT_TIMEOUT = std::chrono::milliseconds(
    1000); // Delay between two handshake request on DTLS
static constexpr auto COOKIE_TIMEOUT = std::chrono::seconds(
    10); // Time to wait for a cookie packet from client
static constexpr int MIN_MTU {
    512 - 20 - 8}; // minimal payload size of a DTLS packet carried by an IPv4 packet
static constexpr uint8_t HEARTBEAT_TRIES = 1; // Number of tries at each heartbeat ping send
static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(
    700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds)
static constexpr auto HEARTBEAT_TOTAL_TIMEOUT
    = HEARTBEAT_RETRANS_TIMEOUT
      * HEARTBEAT_TRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds)
static constexpr int MISS_ORDERING_LIMIT
    = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type)
static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(1500);
static constexpr int ASYMETRIC_TRANSPORT_MTU_OFFSET
    = 20; // when client, if your local IP is IPV4 and server is IPV6; you must reduce your MTU to
          // avoid packet too big error on server side. the offset is the difference in size of IP headers
static constexpr auto OCSP_REQUEST_TIMEOUT = std::chrono::seconds(
    2); // Time to wait for an ocsp-request

// Helper to cast any duration into an integer number of milliseconds
template<class Rep, class Period>
static std::chrono::milliseconds::rep
duration2ms(std::chrono::duration<Rep, Period> d)
{
    return std::chrono::duration_cast<std::chrono::milliseconds>(d).count();
}

static inline uint64_t
array2uint(const std::array<uint8_t, 8>& a)
{
    uint64_t res = 0;
    for (int i = 0; i < 8; ++i)
        res = (res << 8) + a[i];
    return res;
}

//==============================================================================

namespace {

class TlsCertificateCredendials
{
    using T = gnutls_certificate_credentials_t;

public:
    TlsCertificateCredendials()
    {
        int ret = gnutls_certificate_allocate_credentials(&creds_);
        if (ret < 0) {
            //if (params_.logger)
            //    params_.logger->e("gnutls_certificate_allocate_credentials() failed with ret=%d", ret);
            throw std::bad_alloc();
        }
    }

    ~TlsCertificateCredendials() { gnutls_certificate_free_credentials(creds_); }

    operator T() { return creds_; }

private:
    TlsCertificateCredendials(const TlsCertificateCredendials&) = delete;
    TlsCertificateCredendials& operator=(const TlsCertificateCredendials&) = delete;
    T creds_;
};

class TlsAnonymousClientCredendials
{
    using T = gnutls_anon_client_credentials_t;

public:
    TlsAnonymousClientCredendials()
    {
        int ret = gnutls_anon_allocate_client_credentials(&creds_);
        if (ret < 0) {
            //if (params_.logger)
            //    params_.logger->e("gnutls_anon_allocate_client_credentials() failed with ret=%d", ret);
            throw std::bad_alloc();
        }
    }

    ~TlsAnonymousClientCredendials() { gnutls_anon_free_client_credentials(creds_); }

    operator T() { return creds_; }

private:
    TlsAnonymousClientCredendials(const TlsAnonymousClientCredendials&) = delete;
    TlsAnonymousClientCredendials& operator=(const TlsAnonymousClientCredendials&) = delete;
    T creds_;
};

class TlsAnonymousServerCredendials
{
    using T = gnutls_anon_server_credentials_t;

public:
    TlsAnonymousServerCredendials()
    {
        int ret = gnutls_anon_allocate_server_credentials(&creds_);
        if (ret < 0) {
            //if (params_.logger)
            //    params_.logger->e("gnutls_anon_allocate_server_credentials() failed with ret=%d", ret);
            throw std::bad_alloc();
        }
    }

    ~TlsAnonymousServerCredendials() { gnutls_anon_free_server_credentials(creds_); }

    operator T() { return creds_; }

private:
    TlsAnonymousServerCredendials(const TlsAnonymousServerCredendials&) = delete;
    TlsAnonymousServerCredendials& operator=(const TlsAnonymousServerCredendials&) = delete;
    T creds_;
};

} // namespace

//==============================================================================

class TlsSession::TlsSessionImpl
{
public:
    using clock = std::chrono::steady_clock;
    using StateHandler = std::function<TlsSessionState(TlsSessionState state)>;
    using OcspVerification = std::function<void(const int status)>;
    using HttpResponse = std::function<void(const dht::http::Response& response)>;

    // Constants (ctor init.)
    const bool isServer_;
    const TlsParams params_;
    const TlsSessionCallbacks callbacks_;
    const bool anonymous_;

    TlsSessionImpl(std::unique_ptr<SocketType>&& transport,
                   const TlsParams& params,
                   const TlsSessionCallbacks& cbs,
                   bool anonymous);

    ~TlsSessionImpl();

    const char* typeName() const;

    std::unique_ptr<SocketType> transport_;

    // State protectors
    std::mutex stateMutex_;
    std::condition_variable stateCondition_;

    // State machine
    TlsSessionState handleStateSetup(TlsSessionState state);
    TlsSessionState handleStateCookie(TlsSessionState state);
    TlsSessionState handleStateHandshake(TlsSessionState state);
    TlsSessionState handleStateMtuDiscovery(TlsSessionState state);
    TlsSessionState handleStateEstablished(TlsSessionState state);
    TlsSessionState handleStateShutdown(TlsSessionState state);
    std::map<TlsSessionState, StateHandler> fsmHandlers_ {};
    std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP};
    std::atomic<TlsSessionState> newState_ {TlsSessionState::NONE};
    std::atomic<int> maxPayload_ {-1};

    // IO GnuTLS <-> ICE
    std::mutex rxMutex_ {};
    std::condition_variable rxCv_ {};
    std::list<std::vector<ValueType>> rxQueue_ {};

    bool flushProcessing_ {false};     ///< protect against recursive call to flushRxQueue
    std::vector<ValueType> rawPktBuf_; ///< gnutls incoming packet buffer
    uint64_t baseSeq_ {0};   ///< sequence number of first application data packet received
    uint64_t lastRxSeq_ {0}; ///< last received and valid packet sequence number
    uint64_t gapOffset_ {0}; ///< offset of first byte not received yet
    clock::time_point lastReadTime_;
    std::map<uint64_t, std::vector<ValueType>> reorderBuffer_ {};
    std::list<clock::time_point> nextFlush_ {};

    std::size_t send(const ValueType*, std::size_t, std::error_code&);
    ssize_t sendRaw(const void*, size_t);
    ssize_t sendRawVec(const giovec_t*, int);
    ssize_t recvRaw(void*, size_t);
    int waitForRawData(std::chrono::milliseconds);

    bool initFromRecordState(int offset = 0);
    void handleDataPacket(std::vector<ValueType>&&, uint64_t);
    void flushRxQueue(std::unique_lock<std::mutex>&);

    // Statistics
    std::atomic<std::size_t> stRxRawPacketCnt_ {0};
    std::atomic<std::size_t> stRxRawBytesCnt_ {0};
    std::atomic<std::size_t> stRxRawPacketDropCnt_ {0};
    std::atomic<std::size_t> stTxRawPacketCnt_ {0};
    std::atomic<std::size_t> stTxRawBytesCnt_ {0};
    void dump_io_stats() const;

    std::unique_ptr<TlsAnonymousClientCredendials> cacred_; // ctor init.
    std::unique_ptr<TlsAnonymousServerCredendials> sacred_; // ctor init.
    std::unique_ptr<TlsCertificateCredendials> xcred_;      // ctor init.
    std::mutex sessionReadMutex_;
    std::mutex sessionWriteMutex_;
    gnutls_session_t session_ {nullptr};
    gnutls_datum_t cookie_key_ {nullptr, 0};
    gnutls_dtls_prestate_st prestate_ {};
    ssize_t cookie_count_ {0};

    TlsSessionState setupClient();
    TlsSessionState setupServer();
    void initAnonymous();
    void initCredentials();
    bool commonSessionInit();

    std::shared_ptr<dht::crypto::Certificate> peerCertificate(gnutls_session_t session) const;

    /*
     * Implicit certificate validations.
     */
    int verifyCertificateWrapper(gnutls_session_t session);
    /*
     * Verify OCSP (Online Certificate Service Protocol):
     */
    void verifyOcsp(const std::string& url,
                    dht::crypto::Certificate& cert,
                    gnutls_x509_crt_t issuer,
                    OcspVerification cb);
    /*
     * Send OCSP Request to the specified URI.
     */
    void sendOcspRequest(const std::string& uri,
                         std::string body,
                         std::chrono::seconds timeout,
                         HttpResponse cb = {});

    // FSM thread (TLS states)
    ThreadLoop thread_; // ctor init.
    bool setup();
    void process();
    void cleanup();

    // Path mtu discovery
    std::array<int, 3> MTUS_;
    int mtuProbe_;
    int hbPingRecved_ {0};
    bool pmtudOver_ {false};
    void pathMtuHeartbeat();

    std::mutex requestsMtx_;
    std::set<std::shared_ptr<dht::http::Request>> requests_;
    std::shared_ptr<dht::crypto::Certificate> pCert_ {};
};

TlsSession::TlsSessionImpl::TlsSessionImpl(std::unique_ptr<SocketType>&& transport,
                                           const TlsParams& params,
                                           const TlsSessionCallbacks& cbs,
                                           bool anonymous)
    : isServer_(not transport->isInitiator())
    , params_(params)
    , callbacks_(cbs)
    , anonymous_(anonymous)
    , transport_ {std::move(transport)}
    , cacred_(nullptr)
    , sacred_(nullptr)
    , xcred_(nullptr)
    , thread_(params.logger, [this] { return setup(); }, [this] { process(); }, [this] { cleanup(); })
{
    if (not transport_->isReliable()) {
        transport_->setOnRecv([this](const ValueType* buf, size_t len) {
            std::lock_guard lk {rxMutex_};
            if (rxQueue_.size() == INPUT_MAX_SIZE) {
                rxQueue_.pop_front(); // drop oldest packet if input buffer is full
                ++stRxRawPacketDropCnt_;
            }
            rxQueue_.emplace_back(buf, buf + len);
            ++stRxRawPacketCnt_;
            stRxRawBytesCnt_ += len;
            rxCv_.notify_one();
            return len;
        });
    }

    // Run FSM into dedicated thread
    thread_.start();
}

TlsSession::TlsSessionImpl::~TlsSessionImpl()
{
    state_ = TlsSessionState::SHUTDOWN;
    stateCondition_.notify_all();
    rxCv_.notify_all();
    {
        std::lock_guard lock(requestsMtx_);
        // requests_ store a shared_ptr, so we need to cancel requests
        // to not be stuck in verifyCertificateWrapper
        for (auto& request : requests_)
            request->cancel();
        requests_.clear();
    }
    thread_.join();
    if (not transport_->isReliable())
        transport_->setOnRecv(nullptr);
}

const char*
TlsSession::TlsSessionImpl::typeName() const
{
    return isServer_ ? "server" : "client";
}

void
TlsSession::TlsSessionImpl::dump_io_stats() const
{
    if (params_.logger)
        params_.logger->debug("[TLS] RxRawPkt={:d} ({:d} bytes) - TxRawPkt={:d} ({:d} bytes)",
             stRxRawPacketCnt_.load(),
             stRxRawBytesCnt_.load(),
             stTxRawPacketCnt_.load(),
             stTxRawBytesCnt_.load());
}

TlsSessionState
TlsSession::TlsSessionImpl::setupClient()
{
    int ret;

    if (not transport_->isReliable()) {
        ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM);
        // uncoment to reactivate PMTUD
        // if (params_.logger)
            params_.logger->d("[TLS] set heartbeat reception for retrocompatibility check on server");
        // gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND);
    } else {
        ret = gnutls_init(&session_, GNUTLS_CLIENT);
    }

    if (ret != GNUTLS_E_SUCCESS) {
        if (params_.logger)
            params_.logger->e("[TLS] session init failed: %s", gnutls_strerror(ret));
        return TlsSessionState::SHUTDOWN;
    }

    if (not commonSessionInit()) {
        return TlsSessionState::SHUTDOWN;
    }

    return TlsSessionState::HANDSHAKE;
}

TlsSessionState
TlsSession::TlsSessionImpl::setupServer()
{
    int ret;

    if (not transport_->isReliable()) {
        ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM);

        // uncoment to reactivate PMTUD
        // if (params_.logger)
            params_.logger->d("[TLS] set heartbeat reception");
        // gnutls_heartbeat_enable(session_, GNUTLS_HB_PEER_ALLOWED_TO_SEND);

        gnutls_dtls_prestate_set(session_, &prestate_);
    } else {
        ret = gnutls_init(&session_, GNUTLS_SERVER);
    }

    if (ret != GNUTLS_E_SUCCESS) {
        if (params_.logger)
            params_.logger->e("[TLS] session init failed: %s", gnutls_strerror(ret));
        return TlsSessionState::SHUTDOWN;
    }

    gnutls_certificate_server_set_request(session_, GNUTLS_CERT_REQUIRE);

    if (not commonSessionInit())
        return TlsSessionState::SHUTDOWN;

    return TlsSessionState::HANDSHAKE;
}

void
TlsSession::TlsSessionImpl::initAnonymous()
{
    // credentials for handshaking and transmission
    if (isServer_)
        sacred_.reset(new TlsAnonymousServerCredendials());
    else
        cacred_.reset(new TlsAnonymousClientCredendials());

    // Setup DH-params for anonymous authentification
    if (isServer_) {
        if (const auto& dh_params = params_.dh_params.get().get())
            gnutls_anon_set_server_dh_params(*sacred_, dh_params);
        else
            if (params_.logger)
                params_.logger->w("[TLS] DH params unavailable");
    }
}

void
TlsSession::TlsSessionImpl::initCredentials()
{
    int ret;

    // credentials for handshaking and transmission
    xcred_.reset(new TlsCertificateCredendials());

    gnutls_certificate_set_verify_function(*xcred_, [](gnutls_session_t session) -> int {
        auto this_ = reinterpret_cast<TlsSessionImpl*>(gnutls_session_get_ptr(session));
        return this_->verifyCertificateWrapper(session);
    });

    // Load user-given CA list
    if (not params_.ca_list.empty()) {
        // Try PEM format first
        ret = gnutls_certificate_set_x509_trust_file(*xcred_,
                                                     params_.ca_list.c_str(),
                                                     GNUTLS_X509_FMT_PEM);

        // Then DER format
        if (ret < 0)
            ret = gnutls_certificate_set_x509_trust_file(*xcred_,
                                                         params_.ca_list.c_str(),
                                                         GNUTLS_X509_FMT_DER);
        if (ret < 0)
            throw std::runtime_error("Unable to load CA " + params_.ca_list + ": "
                                     + std::string(gnutls_strerror(ret)));

        if (params_.logger)
            params_.logger->d("[TLS] CA list %s loadev", params_.ca_list.c_str());
    }
    if (params_.peer_ca) {
        auto chain = params_.peer_ca->getChainWithRevocations();
        auto ret = gnutls_certificate_set_x509_trust(*xcred_,
                                                     chain.first.data(),
                                                     chain.first.size());
        if (not chain.second.empty())
            gnutls_certificate_set_x509_crl(*xcred_, chain.second.data(), chain.second.size());
        if (params_.logger)
            params_.logger->debug("[TLS] Peer CA list {:d} ({:d} CRLs): {:d}",
                 chain.first.size(),
                 chain.second.size(),
                 ret);
    }

    // Load user-given identity (key and passwd)
    if (params_.cert) {
        std::vector<gnutls_x509_crt_t> certs;
        certs.reserve(3);
        auto crt = params_.cert;
        while (crt) {
            certs.emplace_back(crt->cert);
            crt = crt->issuer;
        }

        ret = gnutls_certificate_set_x509_key(*xcred_,
                                              certs.data(),
                                              certs.size(),
                                              params_.cert_key->x509_key);
        if (ret < 0)
            throw std::runtime_error("Unable to load certificate: " + std::string(gnutls_strerror(ret)));

        if (params_.logger)
            params_.logger->d("[TLS] User identity loaded");
    }

    // Setup DH-params (server only, may block on dh_params.get())
    if (isServer_) {
        if (const auto& dh_params = params_.dh_params.get().get())
            gnutls_certificate_set_dh_params(*xcred_, dh_params);
        else
            if (params_.logger)
                params_.logger->w("[TLS] DH params unavailable"); // YOMGUI: need to stop?
    }
}

bool
TlsSession::TlsSessionImpl::commonSessionInit()
{
    int ret;

    if (anonymous_) {
        // Force anonymous connection, see handleStateHandshake how we handle failures
        ret = gnutls_priority_set_direct(session_,
                                         transport_->isReliable() ? TLS_FULL_PRIORITY_STRING
                                                                  : DTLS_FULL_PRIORITY_STRING,
                                         nullptr);
        if (ret != GNUTLS_E_SUCCESS) {
            if (params_.logger)
                params_.logger->e("[TLS] TLS priority set failed: %s", gnutls_strerror(ret));
            return false;
        }

        // Add anonymous credentials
        if (isServer_)
            ret = gnutls_credentials_set(session_, GNUTLS_CRD_ANON, *sacred_);
        else
            ret = gnutls_credentials_set(session_, GNUTLS_CRD_ANON, *cacred_);

        if (ret != GNUTLS_E_SUCCESS) {
            if (params_.logger)
                params_.logger->e("[TLS] anonymous credential set failed: %s", gnutls_strerror(ret));
            return false;
        }
    } else {
        // Use a classic non-encrypted CERTIFICATE exchange method (less anonymous)
        ret = gnutls_priority_set_direct(session_,
                                         transport_->isReliable() ? TLS_CERT_PRIORITY_STRING
                                                                  : DTLS_CERT_PRIORITY_STRING,
                                         nullptr);
        if (ret != GNUTLS_E_SUCCESS) {
            if (params_.logger)
                params_.logger->e("[TLS] TLS priority set failed: %s", gnutls_strerror(ret));
            return false;
        }
    }

    // Add certificate credentials
    ret = gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, *xcred_);
    if (ret != GNUTLS_E_SUCCESS) {
        if (params_.logger)
            params_.logger->e("[TLS] certificate credential set failed: %s", gnutls_strerror(ret));
        return false;
    }
    gnutls_certificate_send_x509_rdn_sequence(session_, 0);

    if (not transport_->isReliable()) {
        // DTLS hanshake timeouts
        auto re_tx_timeout = duration2ms(DTLS_RETRANSMIT_TIMEOUT);
        gnutls_dtls_set_timeouts(session_,
                                 re_tx_timeout,
                                 std::max(duration2ms(params_.timeout), re_tx_timeout));

        // gnutls DTLS mtu = maximum payload size given by transport
        gnutls_dtls_set_mtu(session_, transport_->maxPayload());
    }

    // Stuff for transport callbacks
    gnutls_session_set_ptr(session_, this);
    gnutls_transport_set_ptr(session_, this);
    gnutls_transport_set_vec_push_function(session_,
                                           [](gnutls_transport_ptr_t t,
                                              const giovec_t* iov,
                                              int iovcnt) -> ssize_t {
                                               auto this_ = reinterpret_cast<TlsSessionImpl*>(t);
                                               return this_->sendRawVec(iov, iovcnt);
                                           });
    gnutls_transport_set_pull_function(session_,
                                       [](gnutls_transport_ptr_t t, void* d, size_t s) -> ssize_t {
                                           auto this_ = reinterpret_cast<TlsSessionImpl*>(t);
                                           return this_->recvRaw(d, s);
                                       });
    gnutls_transport_set_pull_timeout_function(session_,
                                               [](gnutls_transport_ptr_t t, unsigned ms) -> int {
                                                   auto this_ = reinterpret_cast<TlsSessionImpl*>(t);
                                                   return this_->waitForRawData(
                                                       std::chrono::milliseconds(ms));
                                               });
    // TODO -1 = default else set value
    if (transport_->isReliable())
        gnutls_handshake_set_timeout(session_, duration2ms(params_.timeout));
    return true;
}

std::string
getOcspUrl(gnutls_x509_crt_t cert)
{
    int ret;
    gnutls_datum_t aia;
    unsigned int seq = 0;
    do {
        // Extracts the Authority Information Access (AIA) extension, see RFC 5280 section 4.2.2.1
        ret = gnutls_x509_crt_get_authority_info_access(cert, seq++, GNUTLS_IA_OCSP_URI, &aia, NULL);
    } while (ret < 0 && ret != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE);
    // could also try the issuer if we include ocsp uri into there
    if (ret < 0) {
        return {};
    }
    std::string url((const char*) aia.data, (size_t) aia.size);
    gnutls_free(aia.data);
    return url;
}

int
TlsSession::TlsSessionImpl::verifyCertificateWrapper(gnutls_session_t session)
{
    // Perform user-set verification first to avoid flooding with ocsp-requests if peer is denied
    int verified;
    if (callbacks_.verifyCertificate) {
        auto this_ = reinterpret_cast<TlsSessionImpl*>(gnutls_session_get_ptr(session));
        verified = this_->callbacks_.verifyCertificate(session);
        if (verified != GNUTLS_E_SUCCESS)
            return verified;
    } else {
        verified = GNUTLS_E_SUCCESS;
    }
    /*
     * Support only x509 format
     */
    if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509)
        return GNUTLS_E_CERTIFICATE_ERROR;

    pCert_ = peerCertificate(session);
    if (!pCert_)
        return GNUTLS_E_CERTIFICATE_ERROR;

    std::string ocspUrl = getOcspUrl(pCert_->cert);
    if (ocspUrl.empty()) {
        // Skipping OCSP verification: AIA not found
        return verified;
    }

    // OCSP (Online Certificate Service Protocol) {
    std::promise<int> v;
    std::future<int> f = v.get_future();

    gnutls_x509_crt_t issuer_crt = pCert_->issuer ? pCert_->issuer->cert : nullptr;
    verifyOcsp(ocspUrl, *pCert_, issuer_crt, [&](const int status) {
        if (status == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) {
            // OCSP URI is absent, don't fail the verification by overwritting the user-set one.
            if (params_.logger)
                params_.logger->w("Skipping OCSP verification %s: request failed", pCert_->getUID().c_str());
            v.set_value(verified);
        } else {
            if (status != GNUTLS_E_SUCCESS) {
                if (params_.logger)
                    params_.logger->e("OCSP verification failed for %s: %s (%i)",
                         pCert_->getUID().c_str(),
                         gnutls_strerror(status),
                         status);
            }
            v.set_value(status);
        }
    });
    f.wait();

    return f.get();
}

void
TlsSession::TlsSessionImpl::verifyOcsp(const std::string& aia_uri,
                                       dht::crypto::Certificate& cert,
                                       gnutls_x509_crt_t issuer,
                                       OcspVerification cb)
{
    if (params_.logger)
        params_.logger->d("Certificate's AIA URI: %s", aia_uri.c_str());

    // Generate OCSP request
    std::pair<std::string, dht::Blob> ocsp_req;
    try {
        ocsp_req = cert.generateOcspRequest(issuer);
    } catch (dht::crypto::CryptoException& e) {
        if (params_.logger)
            params_.logger->e("Failed to generate OCSP request: %s", e.what());
        if (cb)
            cb(GNUTLS_E_INVALID_REQUEST);
        return;
    }

    sendOcspRequest(aia_uri,
                    std::move(ocsp_req.first),
                    OCSP_REQUEST_TIMEOUT,
                    [cb = std::move(cb), &cert, nonce = std::move(ocsp_req.second), this](
                        const dht::http::Response& r) {
                        // Prepare response data
                        // Verify response validity
                        if (r.status_code != 200) {
                            if (params_.logger)
                                params_.logger->w("HTTP OCSP Request Failed with code %i", r.status_code);
                            if (cb)
                                cb(GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE);
                            return;
                        }
                        if (params_.logger)
                            params_.logger->d("HTTP OCSP Request done!");
                        gnutls_ocsp_cert_status_t verify = GNUTLS_OCSP_CERT_UNKNOWN;
                        try {
                            cert.ocspResponse = std::make_shared<dht::crypto::OcspResponse>(
                                (const uint8_t*) r.body.data(), r.body.size());
                            if (params_.logger)
                                params_.logger->d("%s", cert.ocspResponse->toString().c_str());
                            verify = cert.ocspResponse->verifyDirect(cert, nonce);
                        } catch (dht::crypto::CryptoException& e) {
                            if (params_.logger)
                                params_.logger->e("Failed to verify OCSP response: %s", e.what());
                        }
                        if (verify == GNUTLS_OCSP_CERT_UNKNOWN) {
                            // Soft-fail
                            if (cb)
                                cb(GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE);
                            return;
                        }
                        int status = GNUTLS_E_SUCCESS;
                        if (verify == GNUTLS_OCSP_CERT_GOOD) {
                            if (params_.logger)
                                params_.logger->d("OCSP verification success!");
                        } else {
                            status = GNUTLS_E_CERTIFICATE_ERROR;
                            if (params_.logger)
                                params_.logger->e("OCSP verification: certificate is revoked!");
                        }
                        // Save response into the certificate store
                        try {
                            params_.certStore.pinOcspResponse(cert);
                        } catch (std::exception& e) {
                            if (params_.logger)
                                params_.logger->error("{}", e.what());
                        }
                        if (cb)
                            cb(status);
                    });
}

void
TlsSession::TlsSessionImpl::sendOcspRequest(const std::string& uri,
                                            std::string body,
                                            std::chrono::seconds timeout,
                                            HttpResponse cb)
{
    using namespace dht;
    auto request = std::make_shared<http::Request>(*params_.io_context,
                                                   uri); //, logger);
    request->set_method(restinio::http_method_post());
    request->set_header_field(restinio::http_field_t::user_agent, "Jami");
    request->set_header_field(restinio::http_field_t::accept, "*/*");
    request->set_header_field(restinio::http_field_t::content_type, "application/ocsp-request");
    request->set_body(std::move(body));
    request->set_connection_type(restinio::http_connection_header_t::close);
    request->timeout(timeout, [request,l=params_.logger](const asio::error_code& ec) {
        if (ec and ec != asio::error::operation_aborted)
            if (l) l->error("HTTP OCSP Request timeout with error: {:s}", ec.message());
        request->cancel();
    });
    request->add_on_state_change_callback([this, cb = std::move(cb)](const http::Request::State state,
                                                    const http::Response response) {
        if (params_.logger)
            params_.logger->d("HTTP OCSP Request state=%i status_code=%i",
                 (unsigned int) state,
                 response.status_code);
        if (state != http::Request::State::DONE)
            return;
        if (cb)
            cb(response);
        if (auto request = response.request.lock()) {
            std::lock_guard lock(requestsMtx_);
            requests_.erase(request);
        }
    });
    {
        std::lock_guard lock(requestsMtx_);
        requests_.emplace(request);
    }
    request->send();
}

std::shared_ptr<dht::crypto::Certificate>
TlsSession::TlsSessionImpl::peerCertificate(gnutls_session_t session) const
{
    if (!session)
        return {};
    /*
     * Get the peer's raw certificate (chain) as sent by the peer.
     * The first certificate in the list is the peer's certificate, following the issuer's cert. etc.
     */
    unsigned int cert_list_size = 0;
    auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size);

    if (cert_list == nullptr)
        return {};
    std::vector<std::pair<uint8_t*, uint8_t*>> crt_data;
    crt_data.reserve(cert_list_size);
    for (unsigned i = 0; i < cert_list_size; i++)
        crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size);
    return std::make_shared<dht::crypto::Certificate>(crt_data);
}

std::size_t
TlsSession::TlsSessionImpl::send(const ValueType* tx_data, std::size_t tx_size, std::error_code& ec)
{
    std::lock_guard lk(sessionWriteMutex_);
    if (state_ != TlsSessionState::ESTABLISHED) {
        ec = std::error_code(GNUTLS_E_INVALID_SESSION, std::system_category());
        return 0;
    }

    std::size_t total_written = 0;
    std::size_t max_tx_sz;

    if (transport_->isReliable())
        max_tx_sz = tx_size;
    else
        max_tx_sz = gnutls_dtls_get_data_mtu(session_);

    // Split incoming data into chunck suitable for the underlying transport
    while (total_written < tx_size) {
        auto chunck_sz = std::min(max_tx_sz, tx_size - total_written);
        auto data_seq = tx_data + total_written;
        ssize_t nwritten;
        do {
            nwritten = gnutls_record_send(session_, data_seq, chunck_sz);
        } while ((nwritten == GNUTLS_E_INTERRUPTED and state_ != TlsSessionState::SHUTDOWN)
                 or nwritten == GNUTLS_E_AGAIN);
        if (nwritten < 0) {
            /* Normally we would have to retry record_send but our internal
             * state has not changed, so we have to ask for more data first.
             * We will just try again later, although this should never happen.
             */
            if (params_.logger)
                params_.logger->error("[TLS] send failed (only {} bytes sent): {}", total_written, gnutls_strerror(nwritten));
            ec = std::error_code(nwritten, std::system_category());
            return 0;
        }

        total_written += nwritten;
    }

    ec.clear();
    return total_written;
}

// Called by GNUTLS to send encrypted packet to low-level transport.
// Should return a positive number indicating the bytes sent, and -1 on error.
ssize_t
TlsSession::TlsSessionImpl::sendRaw(const void* buf, size_t size)
{
    std::error_code ec;
    unsigned retry_count = 0;
    do {
        auto n = transport_->write(reinterpret_cast<const ValueType*>(buf), size, ec);
        if (!ec) {
            // log only on success
            ++stTxRawPacketCnt_;
            stTxRawBytesCnt_ += n;
            return n;
        }

        if (ec.value() == EAGAIN) {
            if (params_.logger)
                params_.logger->warn("[TLS] EAGAIN from transport, retry#{:d}", ++retry_count);
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
            if (retry_count == 100) {
                if (params_.logger)
                    params_.logger->e("[TLS] excessive retry detected, aborting");
                ec.assign(EIO, std::system_category());
            }
        }
    } while (ec.value() == EAGAIN);

    // Must be called to pass errno value to GnuTLS on Windows (cf. GnuTLS doc)
    gnutls_transport_set_errno(session_, ec.value());
    if (params_.logger)
        params_.logger->error("[TLS] transport failure on tx: errno = {:d}: {:s}", ec.value(), strerror(ec.value()));
    return -1;
}

// Called by GNUTLS to send encrypted packet to low-level transport.
// Should return a positive number indicating the bytes sent, and -1 on error.
ssize_t
TlsSession::TlsSessionImpl::sendRawVec(const giovec_t* iov, int iovcnt)
{
    ssize_t sent = 0;
    for (int i = 0; i < iovcnt; ++i) {
        const giovec_t& dat = iov[i];
        ssize_t ret = sendRaw(dat.iov_base, dat.iov_len);
        if (ret < 0)
            return -1;
        sent += ret;
    }
    return sent;
}

// Called by GNUTLS to receive encrypted packet from low-level transport.
// Should return 0 on connection termination,
// a positive number indicating the number of bytes received,
// and -1 on error.
ssize_t
TlsSession::TlsSessionImpl::recvRaw(void* buf, size_t size)
{
    if (transport_->isReliable()) {
        std::error_code ec;
        auto count = transport_->read(reinterpret_cast<ValueType*>(buf), size, ec);
        if (!ec)
            return count;
        gnutls_transport_set_errno(session_, ec.value());
        return -1;
    }

    std::lock_guard lk {rxMutex_};
    if (rxQueue_.empty()) {
        gnutls_transport_set_errno(session_, EAGAIN);
        return -1;
    }

    const auto& pkt = rxQueue_.front();
    const std::size_t count = std::min(pkt.size(), size);
    std::copy_n(pkt.begin(), count, reinterpret_cast<ValueType*>(buf));
    rxQueue_.pop_front();
    return count;
}

// Called by GNUTLS to wait for encrypted packet from low-level transport.
// 'timeout' is in milliseconds.
// Should return 0 on timeout, a positive number if data are available for read, or -1 on error.
int
TlsSession::TlsSessionImpl::waitForRawData(std::chrono::milliseconds timeout)
{
    if (transport_->isReliable()) {
        std::error_code ec;
        auto err = transport_->waitForData(timeout, ec);
        if (err <= 0) {
            // shutdown?
            if (state_ == TlsSessionState::SHUTDOWN) {
                gnutls_transport_set_errno(session_, EINTR);
                return -1;
            }
            if (ec) {
                gnutls_transport_set_errno(session_, ec.value());
                return -1;
            }
            return 0;
        }
        return 1;
    }

    // non-reliable uses callback installed with setOnRecv()
    std::unique_lock lk {rxMutex_};
    rxCv_.wait_for(lk, timeout, [this] {
        return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN;
    });
    if (state_ == TlsSessionState::SHUTDOWN) {
        gnutls_transport_set_errno(session_, EINTR);
        return -1;
    }
    if (rxQueue_.empty()) {
        if (params_.logger)
            params_.logger->error("[TLS] waitForRawData: timeout after {}", timeout);
        return 0;
    }
    return 1;
}

bool
TlsSession::TlsSessionImpl::initFromRecordState(int offset)
{
    std::array<uint8_t, 8> seq;
    if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0])
        != GNUTLS_E_SUCCESS) {
        if (params_.logger)
            params_.logger->e("[TLS] Fatal-error Unable to read initial state");
        return false;
    }

    baseSeq_ = array2uint(seq) + offset;
    gapOffset_ = baseSeq_;
    lastRxSeq_ = baseSeq_ - 1;
    if (params_.logger)
        params_.logger->debug("[TLS] Initial sequence number: {:d}", baseSeq_);
    return true;
}

bool
TlsSession::TlsSessionImpl::setup()
{
    // Setup FSM
    fsmHandlers_[TlsSessionState::SETUP] = [this](TlsSessionState s) {
        return handleStateSetup(s);
    };
    fsmHandlers_[TlsSessionState::COOKIE] = [this](TlsSessionState s) {
        return handleStateCookie(s);
    };
    fsmHandlers_[TlsSessionState::HANDSHAKE] = [this](TlsSessionState s) {
        return handleStateHandshake(s);
    };
    fsmHandlers_[TlsSessionState::MTU_DISCOVERY] = [this](TlsSessionState s) {
        return handleStateMtuDiscovery(s);
    };
    fsmHandlers_[TlsSessionState::ESTABLISHED] = [this](TlsSessionState s) {
        return handleStateEstablished(s);
    };
    fsmHandlers_[TlsSessionState::SHUTDOWN] = [this](TlsSessionState s) {
        return handleStateShutdown(s);
    };

    return true;
}

void
TlsSession::TlsSessionImpl::cleanup()
{
    state_ = TlsSessionState::SHUTDOWN; // be sure to block any user operations
    stateCondition_.notify_all();

    {
        std::lock_guard lk1(sessionReadMutex_);
        std::lock_guard lk2(sessionWriteMutex_);
        if (session_) {
            if (transport_->isReliable())
                gnutls_bye(session_, GNUTLS_SHUT_RDWR);
            else
                gnutls_bye(session_, GNUTLS_SHUT_WR); // not wait for a peer answer
            gnutls_deinit(session_);
            session_ = nullptr;
        }
    }

    if (cookie_key_.data)
        gnutls_free(cookie_key_.data);

    transport_->shutdown();
}

TlsSessionState
TlsSession::TlsSessionImpl::handleStateSetup([[maybe_unused]] TlsSessionState state)
{
    if (params_.logger)
        params_.logger->d("[TLS] Start %s session", typeName());

    try {
        if (anonymous_)
            initAnonymous();
        initCredentials();
    } catch (const std::exception& e) {
        if (params_.logger)
            params_.logger->e("[TLS] authentifications init failed: %s", e.what());
        return TlsSessionState::SHUTDOWN;
    }

    if (not isServer_)
        return setupClient();

    // Extra step for DTLS-like transports
    if (transport_ and not transport_->isReliable()) {
        gnutls_key_generate(&cookie_key_, GNUTLS_COOKIE_KEY_SIZE);
        return TlsSessionState::COOKIE;
    }
    return setupServer();
}

TlsSessionState
TlsSession::TlsSessionImpl::handleStateCookie(TlsSessionState state)
{
    if (params_.logger)
        params_.logger->d("[TLS] SYN cookie");

    std::size_t count;
    {
        // block until rx packet or shutdown
        std::unique_lock lk {rxMutex_};
        if (!rxCv_.wait_for(lk, COOKIE_TIMEOUT, [this] {
                return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN;
            })) {
            if (params_.logger)
                params_.logger->e("[TLS] SYN cookie failed: timeout");
            return TlsSessionState::SHUTDOWN;
        }
        // Shutdown state?
        if (rxQueue_.empty())
            return TlsSessionState::SHUTDOWN;
        count = rxQueue_.front().size();
    }

    // Total bytes rx during cookie checking (see flood protection below)
    cookie_count_ += count;

    int ret;

    // Peek and verify front packet
    {
        std::lock_guard lk {rxMutex_};
        auto& pkt = rxQueue_.front();
        std::memset(&prestate_, 0, sizeof(prestate_));
        ret = gnutls_dtls_cookie_verify(&cookie_key_, nullptr, 0, pkt.data(), pkt.size(), &prestate_);
    }

    if (ret < 0) {
        gnutls_dtls_cookie_send(&cookie_key_,
                                nullptr,
                                0,
                                &prestate_,
                                this,
                                [](gnutls_transport_ptr_t t, const void* d, size_t s) -> ssize_t {
                                    auto this_ = reinterpret_cast<TlsSessionImpl*>(t);
                                    return this_->sendRaw(d, s);
                                });

        // Drop front packet
        {
            std::lock_guard lk {rxMutex_};
            rxQueue_.pop_front();
        }

        // Cookie may be sent on multiple network packets
        // So we retry until we get a valid cookie.
        // To protect against a flood attack we delay each retry after FLOOD_THRESHOLD rx bytes.
        if (cookie_count_ >= FLOOD_THRESHOLD) {
            if (params_.logger)
                params_.logger->warn("[TLS] flood threshold reach (retry in {})", FLOOD_PAUSE);
            dump_io_stats();
            std::this_thread::sleep_for(FLOOD_PAUSE); // flood attack protection
        }
        return state;
    }

    if (params_.logger)
        params_.logger->d("[TLS] cookie ok");

    return setupServer();
}

TlsSessionState
TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state)
{
    int ret;
    size_t retry_count = 0;
    if (params_.logger)
        params_.logger->debug("[TLS] handshake");
    do {
        ret = gnutls_handshake(session_);
    } while ((ret == GNUTLS_E_INTERRUPTED or ret == GNUTLS_E_AGAIN)
             and ++retry_count < HANDSHAKE_MAX_RETRY
             and state_.load() != TlsSessionState::SHUTDOWN);
    if (retry_count > 0) {
        if (params_.logger)
            params_.logger->error("[TLS] handshake retried count: {}", retry_count);
    }

    // Stop on fatal error
    if (gnutls_error_is_fatal(ret) || state_.load() == TlsSessionState::SHUTDOWN) {
        if (params_.logger)
            params_.logger->error("[TLS] handshake failed: {:s}", gnutls_strerror(ret));
        return TlsSessionState::SHUTDOWN;
    }

    // Continue handshaking on non-fatal error
    if (ret != GNUTLS_E_SUCCESS) {
        // TODO: handle GNUTLS_E_LARGE_PACKET (MTU must be lowered)
        if (ret != GNUTLS_E_AGAIN)
            if (params_.logger)
                params_.logger->debug("[TLS] non-fatal handshake error: {:s}", gnutls_strerror(ret));
        return state;
    }

    // Safe-Renegotiation status shall always be true to prevent MiM attack
    // Following https://www.gnutls.org/manual/html_node/Safe-renegotiation.html
    // "Unlike TLS 1.2, the server is not allowed to change identities"
    // So, we don't have to check the status if we are the client
    bool isTLS1_3 = gnutls_protocol_get_version(session_) == GNUTLS_TLS1_3;
    if (!isTLS1_3 || (isTLS1_3 && isServer_)) {
        if (!gnutls_safe_renegotiation_status(session_)) {
            if (params_.logger)
                params_.logger->error("[TLS] server identity changed! MiM attack?");
            return TlsSessionState::SHUTDOWN;
        }
    }

    auto desc = gnutls_session_get_desc(session_);
    if (params_.logger)
        params_.logger->debug("[TLS] session established: {:s}", desc);
    gnutls_free(desc);

    // Anonymous connection? rehandshake immediately with certificate authentification forced
    auto cred = gnutls_auth_get_type(session_);
    if (cred == GNUTLS_CRD_ANON) {
        if (params_.logger)
            params_.logger->debug("[TLS] renogotiate with certificate authentification");

        // Re-setup TLS algorithms priority list with only certificate based cipher suites
        ret = gnutls_priority_set_direct(session_,
                                         transport_ and transport_->isReliable()
                                             ? TLS_CERT_PRIORITY_STRING
                                             : DTLS_CERT_PRIORITY_STRING,
                                         nullptr);
        if (ret != GNUTLS_E_SUCCESS) {
            if (params_.logger)
                params_.logger->error("[TLS] session TLS cert-only priority set failed: {:s}", gnutls_strerror(ret));
            return TlsSessionState::SHUTDOWN;
        }

        // remove anon credentials and re-enable certificate ones
        gnutls_credentials_clear(session_);
        ret = gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, *xcred_);
        if (ret != GNUTLS_E_SUCCESS) {
            if (params_.logger)
                params_.logger->error("[TLS] session credential set failed: {:s}", gnutls_strerror(ret));
            return TlsSessionState::SHUTDOWN;
        }

        return state; // handshake

    } else if (cred != GNUTLS_CRD_CERTIFICATE) {
        if (params_.logger)
            params_.logger->error("[TLS] spurious session credential ({})", (int)cred);
        return TlsSessionState::SHUTDOWN;
    }

    // Aware about certificates updates
    if (callbacks_.onCertificatesUpdate) {
        unsigned int remote_count;
        auto local = gnutls_certificate_get_ours(session_);
        auto remote = gnutls_certificate_get_peers(session_, &remote_count);
        callbacks_.onCertificatesUpdate(local, remote, remote_count);
    }

    return transport_ and transport_->isReliable() ? TlsSessionState::ESTABLISHED
                                                   : TlsSessionState::MTU_DISCOVERY;
}

TlsSessionState
TlsSession::TlsSessionImpl::handleStateMtuDiscovery([[maybe_unused]] TlsSessionState state)
{
    if (!transport_) {
        if (params_.logger)
            params_.logger->w("No transport available when discovering the MTU");
        return TlsSessionState::SHUTDOWN;
    }
    mtuProbe_ = transport_->maxPayload();
    assert(mtuProbe_ >= MIN_MTU);
    MTUS_ = {MIN_MTU, std::max((mtuProbe_ + MIN_MTU) / 2, MIN_MTU), mtuProbe_};

    // retrocompatibility check
    if (gnutls_heartbeat_allowed(session_, GNUTLS_HB_LOCAL_ALLOWED_TO_SEND) == 1) {
        if (!isServer_) {
            pathMtuHeartbeat();
            if (state_ == TlsSessionState::SHUTDOWN) {
                if (params_.logger)
                    params_.logger->e("[TLS] session destroyed while performing PMTUD, shuting down");
                return TlsSessionState::SHUTDOWN;
            }
            pmtudOver_ = true;
        }
    } else {
        if (params_.logger)
            params_.logger->e("[TLS] PEER HEARTBEAT DISABLED: using transport MTU value ", mtuProbe_);
        pmtudOver_ = true;
    }

    gnutls_dtls_set_mtu(session_, mtuProbe_);
    maxPayload_ = gnutls_dtls_get_data_mtu(session_);

    if (pmtudOver_) {
        if (params_.logger)
            params_.logger->d("[TLS] maxPayload: ", maxPayload_.load());
        if (!initFromRecordState())
            return TlsSessionState::SHUTDOWN;
    }

    return TlsSessionState::ESTABLISHED;
}

/*
 * Path MTU discovery heuristic
 * heuristic description:
 * The two members of the current tls connection will exchange dtls heartbeat messages
 * of increasing size until the heartbeat times out which will be considered as a packet
 * drop from the network due to the size of the packet. (one retry to test for a buffer issue)
 * when timeout happens or all the values have been tested, the mtu will be returned.
 * In case of unexpected error the first (and minimal) value of the mtu array
 */
void
TlsSession::TlsSessionImpl::pathMtuHeartbeat()
{
    if (params_.logger)
        params_.logger->debug("[TLS] PMTUD: starting probing with {} of retransmission timeout", HEARTBEAT_RETRANS_TIMEOUT);

    gnutls_heartbeat_set_timeouts(session_,
                                  HEARTBEAT_RETRANS_TIMEOUT.count(),
                                  HEARTBEAT_TOTAL_TIMEOUT.count());

    int errno_send = GNUTLS_E_SUCCESS;
    int mtuOffset = 0;

    // when the remote (server) has a IPV6 interface selected by ICE, and local (client) has a IPV4
    // selected, the path MTU discovery triggers errors for packets too big on server side because
    // of different IP headers overhead. Hence we have to signal to the TLS session to reduce the
    // MTU on client size accordingly.
    if (transport_ and transport_->localAddr().isIpv4() and transport_->remoteAddr().isIpv6()) {
        mtuOffset = ASYMETRIC_TRANSPORT_MTU_OFFSET;
        if (params_.logger)
            params_.logger->w("[TLS] local/remote IP protocol version not alike, use an MTU offset of {} bytes to compensate", ASYMETRIC_TRANSPORT_MTU_OFFSET);
    }

    mtuProbe_ = MTUS_[0];

    for (auto mtu : MTUS_) {
        gnutls_dtls_set_mtu(session_, mtu);
        auto data_mtu = gnutls_dtls_get_data_mtu(session_);
        if (params_.logger)
            params_.logger->debug("[TLS] PMTUD: mtu {}, payload {}", mtu,  data_mtu);
        auto bytesToSend = data_mtu - mtuOffset - 3; // want to know why -3? ask gnutls!

        do {
            errno_send = gnutls_heartbeat_ping(session_,
                                               bytesToSend,
                                               HEARTBEAT_TRIES,
                                               GNUTLS_HEARTBEAT_WAIT);
        } while (errno_send == GNUTLS_E_AGAIN
                 || (errno_send == GNUTLS_E_INTERRUPTED && state_ != TlsSessionState::SHUTDOWN));

        if (errno_send != GNUTLS_E_SUCCESS) {
            if (params_.logger)
                params_.logger->debug("[TLS] PMTUD: mtu {} [FAILED]", mtu);
            break;
        }

        mtuProbe_ = mtu;
        if (params_.logger)
            params_.logger->debug("[TLS] PMTUD: mtu {} [OK]", mtu);
    }

    if (errno_send == GNUTLS_E_TIMEDOUT) { // timeout is considered as a packet loss, then the good
                                           // mtu is the precedent
        if (mtuProbe_ == MTUS_[0]) {
            if (params_.logger)
                params_.logger->warn("[TLS] PMTUD: no response on first ping, using minimal MTU value {}", mtuProbe_);
        } else {
            if (params_.logger)
                params_.logger->warn("[TLS] PMTUD: timed out, using last working mtu {}", mtuProbe_);
        }
    } else if (errno_send != GNUTLS_E_SUCCESS) {
        if (params_.logger)
            params_.logger->error("[TLS] PMTUD: failed with gnutls error '{}'", gnutls_strerror(errno_send));
    } else {
        if (params_.logger)
            params_.logger->debug("[TLS] PMTUD: reached maximal value");
    }
}

void
TlsSession::TlsSessionImpl::handleDataPacket(std::vector<ValueType>&& buf, uint64_t pkt_seq)
{
    // Check for a valid seq. num. delta
    int64_t seq_delta = pkt_seq - lastRxSeq_;
    if (seq_delta > 0) {
        lastRxSeq_ = pkt_seq;
    } else {
        // too old?
        if (seq_delta <= -MISS_ORDERING_LIMIT) {
            if (params_.logger)
                params_.logger->warn("[TLS] drop old pkt: 0x{:x}", pkt_seq);
            return;
        }

        // No duplicate check as DTLS prevents that for us (replay protection)

        // accept Out-Of-Order pkt - will be reordered by queue flush operation
        if (params_.logger)
            params_.logger->warn("[TLS] OOO pkt: 0x{:x}", pkt_seq);
    }

    std::unique_lock lk {rxMutex_};
    auto now = clock::now();
    if (reorderBuffer_.empty())
        lastReadTime_ = now;
    reorderBuffer_.emplace(pkt_seq, std::move(buf));
    nextFlush_.emplace_back(now + RX_OOO_TIMEOUT);
    rxCv_.notify_one();
    // Try to flush right now as a new packet is available
    flushRxQueue(lk);
}

///
/// Reorder and push received packet to upper layer
///
/// \note This method must be called continuously, faster than RX_OOO_TIMEOUT
///
void
TlsSession::TlsSessionImpl::flushRxQueue(std::unique_lock<std::mutex>& lk)
{
    // RAII bool swap
    class GuardedBoolSwap
    {
    public:
        explicit GuardedBoolSwap(bool& var)
            : var_ {var}
        {
            var_ = !var_;
        }
        ~GuardedBoolSwap() { var_ = !var_; }

    private:
        bool& var_;
    };

    if (reorderBuffer_.empty())
        return;

    // Prevent re-entrant access as the callbacks_.onRxData() is called in unprotected region
    if (flushProcessing_)
        return;

    GuardedBoolSwap swap_flush_processing {flushProcessing_};

    auto now = clock::now();

    auto item = std::begin(reorderBuffer_);
    auto next_offset = item->first;

    // Wait for next continuous packet until timeout
    if ((now - lastReadTime_) >= RX_OOO_TIMEOUT) {
        // OOO packet timeout - consider waited packets as lost
        if (auto lost = next_offset - gapOffset_) {
            if (params_.logger)
                params_.logger->warn("[TLS] {:d} lost since 0x{:x}", lost, gapOffset_);
        } else if (params_.logger)
            params_.logger->warn("[TLS] slow flush");
    } else if (next_offset != gapOffset_)
        return;

    // Loop on offset-ordered received packet until a discontinuity in sequence number
    while (item != std::end(reorderBuffer_) and item->first <= next_offset) {
        auto pkt_offset = item->first;
        auto pkt = std::move(item->second);

        // Remove item before unlocking to not trash the item' relationship
        next_offset = pkt_offset + 1;
        item = reorderBuffer_.erase(item);

        if (callbacks_.onRxData) {
            lk.unlock();
            callbacks_.onRxData(std::move(pkt));
            lk.lock();
        }
    }

    gapOffset_ = std::max(gapOffset_, next_offset);
    lastReadTime_ = now;
}

TlsSessionState
TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state)
{
    // Nothing to do in reliable mode, so just wait for state change
    if (transport_ and transport_->isReliable()) {
        auto disconnected = [this]() -> bool {
            return state_.load() != TlsSessionState::ESTABLISHED
                   or newState_.load() != TlsSessionState::NONE;
        };
        std::unique_lock lk(stateMutex_);
        stateCondition_.wait(lk, disconnected);
        auto oldState = state_.load();
        if (oldState == TlsSessionState::ESTABLISHED) {
            auto newState = newState_.load();
            if (newState != TlsSessionState::NONE) {
                newState_ = TlsSessionState::NONE;
                return newState;
            }
        }
        return oldState;
    }

    // block until rx packet or state change
    {
        std::unique_lock lk {rxMutex_};
        if (nextFlush_.empty())
            rxCv_.wait(lk, [this] {
                return state_ != TlsSessionState::ESTABLISHED or not rxQueue_.empty()
                       or not nextFlush_.empty();
            });
        else
            rxCv_.wait_until(lk, nextFlush_.front(), [this] {
                return state_ != TlsSessionState::ESTABLISHED or !rxQueue_.empty();
            });
        state = state_.load();
        if (state != TlsSessionState::ESTABLISHED)
            return state;

        if (not nextFlush_.empty()) {
            auto now = clock::now();
            if (nextFlush_.front() <= now) {
                while (nextFlush_.front() <= now)
                    nextFlush_.pop_front();
                flushRxQueue(lk);
                return state;
            }
        }
    }

    std::array<uint8_t, 8> seq;
    rawPktBuf_.resize(RX_MAX_SIZE);
    auto ret = gnutls_record_recv_seq(session_, rawPktBuf_.data(), rawPktBuf_.size(), &seq[0]);

    if (ret > 0) {
        // Are we in PMTUD phase?
        if (!pmtudOver_) {
            mtuProbe_ = MTUS_[std::max(0, hbPingRecved_ - 1)];
            gnutls_dtls_set_mtu(session_, mtuProbe_);
            maxPayload_ = gnutls_dtls_get_data_mtu(session_);
            pmtudOver_ = true;
            if (params_.logger)
                params_.logger->debug("[TLS] maxPayload: {}", maxPayload_.load());

            if (!initFromRecordState(-1))
                return TlsSessionState::SHUTDOWN;
        }

        rawPktBuf_.resize(ret);
        handleDataPacket(std::move(rawPktBuf_), array2uint(seq));
        // no state change
    } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) {
        if (params_.logger)
            params_.logger->d("[TLS] PMTUD: ping received sending pong");
        auto errno_send = gnutls_heartbeat_pong(session_, 0);

        if (errno_send != GNUTLS_E_SUCCESS) {
            if (params_.logger)
                params_.logger->e("[TLS] PMTUD: failed on pong with error %d: %s",
                     errno_send,
                     gnutls_strerror(errno_send));
        } else {
            ++hbPingRecved_;
        }
        // no state change
    } else if (ret == 0) {
        if (params_.logger)
            params_.logger->d("[TLS] eof");
        state = TlsSessionState::SHUTDOWN;
    } else if (ret == GNUTLS_E_REHANDSHAKE) {
        if (params_.logger)
            params_.logger->d("[TLS] re-handshake");
        state = TlsSessionState::HANDSHAKE;
    } else if (gnutls_error_is_fatal(ret)) {
        if (params_.logger)
            params_.logger->e("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
        state = TlsSessionState::SHUTDOWN;
    } // else non-fatal error... let's continue

    return state;
}

TlsSessionState
TlsSession::TlsSessionImpl::handleStateShutdown(TlsSessionState state)
{
    if (params_.logger)
        params_.logger->d("[TLS] shutdown");

    // Stop ourself
    thread_.stop();
    return state;
}

void
TlsSession::TlsSessionImpl::process()
{
    auto old_state = state_.load();
    auto new_state = fsmHandlers_[old_state](old_state);

    // update state_ with taking care for external state change
    if (not std::atomic_compare_exchange_strong(&state_, &old_state, new_state))
        new_state = old_state;

    if (old_state != new_state)
        stateCondition_.notify_all();

    if (old_state != new_state and callbacks_.onStateChange)
        callbacks_.onStateChange(new_state);
}

//==============================================================================

TlsSession::TlsSession(std::unique_ptr<SocketType>&& transport,
                       const TlsParams& params,
                       const TlsSessionCallbacks& cbs,
                       bool anonymous)

    : pimpl_ {std::make_unique<TlsSessionImpl>(std::move(transport), params, cbs, anonymous)}
{}

TlsSession::~TlsSession() {}

bool
TlsSession::isInitiator() const
{
    return !pimpl_->isServer_;
}

bool
TlsSession::isReliable() const
{
    if (!pimpl_->transport_)
        return false;
    return pimpl_->transport_->isReliable();
}

int
TlsSession::maxPayload() const
{
    if (pimpl_->state_ == TlsSessionState::SHUTDOWN)
        throw std::runtime_error("Getting maxPayload from non-valid TLS session");
    if (!pimpl_->transport_)
        return 0;
    return pimpl_->transport_->maxPayload();
}

// Called by anyone to stop the connection and the FSM thread
void
TlsSession::shutdown()
{
    pimpl_->newState_ = TlsSessionState::SHUTDOWN;
    pimpl_->stateCondition_.notify_all();
    pimpl_->rxCv_.notify_one(); // unblock waiting FSM
}

std::size_t
TlsSession::write(const ValueType* data, std::size_t size, std::error_code& ec)
{
    return pimpl_->send(data, size, ec);
}

std::size_t
TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec)
{
    std::errc error;

    if (pimpl_->state_ != TlsSessionState::ESTABLISHED) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return 0;
    }

    while (true) {
        ssize_t ret;
        {
            std::lock_guard lk(pimpl_->sessionReadMutex_);
            if (!pimpl_->session_)
                return 0;
            ret = gnutls_record_recv(pimpl_->session_, data, size);
        }
        if (ret > 0) {
            ec.clear();
            return ret;
        }

        std::lock_guard lk(pimpl_->stateMutex_);
        if (ret == 0) {
            if (pimpl_) {
                if (pimpl_->params_.logger)
                    pimpl_->params_.logger->d("[TLS] eof");
                pimpl_->newState_ = TlsSessionState::SHUTDOWN;
                pimpl_->stateCondition_.notify_all();
                pimpl_->rxCv_.notify_one(); // unblock waiting FSM
            }
            error = std::errc::broken_pipe;
            break;
        } else if (ret == GNUTLS_E_REHANDSHAKE) {
            if (pimpl_->params_.logger)
                pimpl_->params_.logger->d("[TLS] re-handshake");
            pimpl_->newState_ = TlsSessionState::HANDSHAKE;
            pimpl_->rxCv_.notify_one(); // unblock waiting FSM
            pimpl_->stateCondition_.notify_all();
        } else if (gnutls_error_is_fatal(ret)) {
            if (pimpl_ && pimpl_->state_ != TlsSessionState::SHUTDOWN) {
                if (pimpl_->params_.logger)
                    pimpl_->params_.logger->e("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
                pimpl_->newState_ = TlsSessionState::SHUTDOWN;
                pimpl_->stateCondition_.notify_all();
                pimpl_->rxCv_.notify_one(); // unblock waiting FSM
            }
            error = std::errc::io_error;
            break;
        }
    }

    ec = std::make_error_code(error);
    return 0;
}

int
TlsSession::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
{
    if (!pimpl_->transport_) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }
    if (!pimpl_->transport_->waitForData(timeout, ec))
        return 0;
    return 1;
}

std::shared_ptr<dht::crypto::Certificate>
TlsSession::peerCertificate() const
{
    return pimpl_->pCert_;
}

const std::shared_ptr<dht::log::Logger>&
TlsSession::logger() const
{
    return pimpl_->params_.logger;
}

} // namespace tls
} // namespace dhtnet
