diff --git a/src/merge_mining_client_tari.cpp b/src/merge_mining_client_tari.cpp index b6d139d..d24c319 100644 --- a/src/merge_mining_client_tari.cpp +++ b/src/merge_mining_client_tari.cpp @@ -870,6 +870,10 @@ bool MergeMiningClientTari::TariServer::connect_upstream(TariClient* downstream) upstream->m_pairedClient = downstream; upstream->m_pairedClientSavedResetCounter = downstream->m_resetCounter; + upstream->m_connectionPending = true; + + downstream->m_pairedClient = upstream; + downstream->m_pairedClientSavedResetCounter = upstream->m_resetCounter; return true; } @@ -885,18 +889,29 @@ const char* MergeMiningClientTari::TariServer::get_log_category() const MergeMiningClientTari::TariClient::TariClient() : Client(m_buf, sizeof(m_buf)) + , m_connectionPending(false) , m_pairedClient(nullptr) , m_pairedClientSavedResetCounter(std::numeric_limits::max()) { m_buf[0] = '\0'; } +void MergeMiningClientTari::TariClient::reset() +{ + m_pendingData.clear(); + m_connectionPending = false; + + break_pairing(); +} + void MergeMiningClientTari::TariClient::break_pairing() { if (is_paired()) { m_pairedClient->m_pairedClient = nullptr; m_pairedClient->m_pairedClientSavedResetCounter = std::numeric_limits::max(); - m_pairedClient->close(); + if (!m_pairedClient->m_connectionPending) { + m_pairedClient->close(); + } } m_pairedClient = nullptr; m_pairedClientSavedResetCounter = std::numeric_limits::max(); @@ -914,10 +929,10 @@ bool MergeMiningClientTari::TariClient::on_connect() } else { // The outgoing connection is ready now + m_connectionPending = false; + // Check if the incoming connection (downstream) has already sent something that needs to be relayed TariClient* downstream = m_pairedClient; - downstream->m_pairedClient = this; - downstream->m_pairedClientSavedResetCounter = m_resetCounter; const std::vector& v = downstream->m_pendingData; @@ -948,7 +963,7 @@ bool MergeMiningClientTari::TariClient::on_read(const char* data, uint32_t size) return false; } - if (!is_paired()) { + if (!is_paired() || m_pairedClient->m_connectionPending) { LOGWARN(5, "Read " << size << " bytes from " << static_cast(m_addrString) << " but it's not paired yet. Buffering it."); m_pendingData.insert(m_pendingData.end(), data, data + size); return true; diff --git a/src/merge_mining_client_tari.h b/src/merge_mining_client_tari.h index aef1ae3..82a26f5 100644 --- a/src/merge_mining_client_tari.h +++ b/src/merge_mining_client_tari.h @@ -113,7 +113,7 @@ private: static Client* allocate() { return new TariClient(); } virtual size_t size() const override { return sizeof(TariClient); } - void reset() override { break_pairing(); } + void reset() override; [[nodiscard]] bool on_connect() override; [[nodiscard]] bool on_read(const char* data, uint32_t size) override; void on_connect_failed(int /*err*/) override { break_pairing(); } @@ -121,6 +121,7 @@ private: alignas(8) char m_buf[BUF_SIZE]; std::vector m_pendingData; + bool m_connectionPending; bool is_paired() const { return m_pairedClient && (m_pairedClient->m_resetCounter == m_pairedClientSavedResetCounter); }