From f4015481312c709f1f83ad1ccc1bab86e513160a Mon Sep 17 00:00:00 2001 From: j-berman Date: Thu, 8 Jan 2026 16:10:36 -0800 Subject: [PATCH] p2p: connection patches - Make sure the server sends a complete response when the client includes the "Connection: close" header. - Make sure the server terminates in `m_strand` to avoid concurrent socket closure and ops processing. --- .../epee/include/net/abstract_tcp_server2.h | 3 + .../epee/include/net/abstract_tcp_server2.inl | 164 +++++++---- tests/unit_tests/epee_http_server.cpp | 272 ++++++++++++++++++ 3 files changed, 377 insertions(+), 62 deletions(-) create mode 100644 tests/unit_tests/epee_http_server.cpp diff --git a/contrib/epee/include/net/abstract_tcp_server2.h b/contrib/epee/include/net/abstract_tcp_server2.h index a9d7fce..33e50cd 100644 --- a/contrib/epee/include/net/abstract_tcp_server2.h +++ b/contrib/epee/include/net/abstract_tcp_server2.h @@ -128,6 +128,7 @@ namespace net_utils void start_handshake(); void start_read(); + void finish_read(size_t bytes_transferred); void start_write(); void start_shutdown(); void cancel_socket(); @@ -139,6 +140,7 @@ namespace net_utils void terminate(); void on_terminating(); + void terminate_async(); bool send(epee::byte_slice message); bool start_internal( @@ -192,6 +194,7 @@ namespace net_utils bool wait_read; bool handle_read; bool cancel_read; + bool shutdown_read; bool wait_write; bool handle_write; diff --git a/contrib/epee/include/net/abstract_tcp_server2.inl b/contrib/epee/include/net/abstract_tcp_server2.inl index c830de2..a95e67e 100644 --- a/contrib/epee/include/net/abstract_tcp_server2.inl +++ b/contrib/epee/include/net/abstract_tcp_server2.inl @@ -171,7 +171,7 @@ namespace net_utils return; m_state.timers.general.wait_expire = true; auto self = connection::shared_from_this(); - m_timers.general.async_wait([this, self](const ec_t & ec){ + auto on_wait = [this, self] { std::lock_guard guard(m_state.lock); m_state.timers.general.wait_expire = false; if (m_state.timers.general.cancel_expire) { @@ -189,6 +189,9 @@ namespace net_utils interrupt(); else if (m_state.status == status_t::INTERRUPTED) terminate(); + }; + m_timers.general.async_wait([this, self, on_wait](const ec_t & ec){ + boost::asio::post(m_strand, on_wait); }); } @@ -242,27 +245,7 @@ namespace net_utils ) ) { m_state.ssl.enabled = false; - m_state.socket.handle_read = true; - boost::asio::post( - connection_basic::strand_, - [this, self, bytes_transferred]{ - bool success = m_handler.handle_recv( - reinterpret_cast(m_state.data.read.buffer.data()), - bytes_transferred - ); - std::lock_guard guard(m_state.lock); - m_state.socket.handle_read = false; - if (m_state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (m_state.status == status_t::TERMINATING) - on_terminating(); - else if (!success) - interrupt(); - else { - start_read(); - } - } - ); + finish_read(bytes_transferred); } else { m_state.ssl.detected = true; @@ -322,7 +305,7 @@ namespace net_utils void connection::start_read() { if (m_state.timers.throttle.in.wait_expire || m_state.socket.wait_read || - m_state.socket.handle_read + m_state.socket.handle_read || m_state.socket.shutdown_read ) { return; } @@ -346,7 +329,7 @@ namespace net_utils if (duration > duration_t{}) { m_timers.throttle.in.expires_after(duration); m_state.timers.throttle.in.wait_expire = true; - m_timers.throttle.in.async_wait([this, self](const ec_t &ec){ + auto on_wait = [this, self](const ec_t &ec){ std::lock_guard guard(m_state.lock); m_state.timers.throttle.in.wait_expire = false; if (m_state.timers.throttle.in.cancel_expire) { @@ -355,8 +338,16 @@ namespace net_utils } else if (ec.value()) interrupt(); - else + }; + m_timers.throttle.in.async_wait([this, self, on_wait](const ec_t &ec){ + std::lock_guard guard(m_state.lock); + const bool error_status = m_state.timers.throttle.in.cancel_expire || ec.value(); + if (error_status) + boost::asio::post(m_strand, std::bind(on_wait, ec)); + else { + m_state.timers.throttle.in.wait_expire = false; start_read(); + } }); return; } @@ -392,33 +383,7 @@ namespace net_utils m_conn_context.m_recv_cnt += bytes_transferred; start_timer(get_timeout_from_bytes_read(bytes_transferred), true); } - - // Post handle_recv to a separate `strand_`, distinct from `m_strand` - // which is listening for reads/writes. This avoids a circular dep. - // handle_recv can queue many writes, and `m_strand` will process those - // writes until the connection terminates without deadlocking waiting - // for handle_recv. - m_state.socket.handle_read = true; - boost::asio::post( - connection_basic::strand_, - [this, self, bytes_transferred]{ - bool success = m_handler.handle_recv( - reinterpret_cast(m_state.data.read.buffer.data()), - bytes_transferred - ); - std::lock_guard guard(m_state.lock); - m_state.socket.handle_read = false; - if (m_state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (m_state.status == status_t::TERMINATING) - on_terminating(); - else if (!success) - interrupt(); - else { - start_read(); - } - } - ); + finish_read(bytes_transferred); } }; if (!m_state.ssl.enabled) @@ -444,6 +409,62 @@ namespace net_utils ); } + template + void connection::finish_read(size_t bytes_transferred) + { + // Post handle_recv to a separate `strand_`, distinct from `m_strand` + // which is listening for reads/writes. This avoids a circular dep. + // handle_recv can queue many writes, and `m_strand` will process those + // writes until the connection terminates without deadlocking waiting + // for handle_recv. + m_state.socket.handle_read = true; + auto self = connection::shared_from_this(); + boost::asio::post( + connection_basic::strand_, + [this, self, bytes_transferred]{ + bool success = m_handler.handle_recv( + reinterpret_cast(m_state.data.read.buffer.data()), + bytes_transferred + ); + std::lock_guard guard(m_state.lock); + const bool error_status = m_state.status == status_t::INTERRUPTED + || m_state.status == status_t::TERMINATING + || !success; + if (!error_status) { + m_state.socket.handle_read = false; + start_read(); + return; + } + boost::asio::post( + m_strand, + [this, self, success]{ + // expect error_status == true + std::lock_guard guard(m_state.lock); + m_state.socket.handle_read = false; + if (m_state.status == status_t::INTERRUPTED) + on_interrupted(); + else if (m_state.status == status_t::TERMINATING) + on_terminating(); + else if (!success) { + ec_t ec; + if (m_state.socket.wait_write) { + // Allow the already queued writes time to finish, but no more new reads + connection_basic::socket_.next_layer().shutdown( + socket_t::shutdown_receive, + ec + ); + m_state.socket.shutdown_read = true; + } + if (!m_state.socket.wait_write || ec.value()) { + interrupt(); + } + } + } + ); + } + ); + } + template void connection::start_write() { @@ -475,7 +496,7 @@ namespace net_utils if (duration > duration_t{}) { m_timers.throttle.out.expires_after(duration); m_state.timers.throttle.out.wait_expire = true; - m_timers.throttle.out.async_wait([this, self](const ec_t &ec){ + auto on_wait = [this, self](const ec_t &ec){ std::lock_guard guard(m_state.lock); m_state.timers.throttle.out.wait_expire = false; if (m_state.timers.throttle.out.cancel_expire) { @@ -484,8 +505,16 @@ namespace net_utils } else if (ec.value()) interrupt(); - else + }; + m_timers.throttle.out.async_wait([this, self, on_wait](const ec_t &ec){ + std::lock_guard guard(m_state.lock); + const bool error_status = m_state.timers.throttle.out.cancel_expire || ec.value(); + if (error_status) + boost::asio::post(m_strand, std::bind(on_wait, ec)); + else { + m_state.timers.throttle.out.wait_expire = false; start_write(); + } }); } } @@ -533,7 +562,12 @@ namespace net_utils m_state.data.write.total_bytes -= std::min(m_state.data.write.total_bytes, byte_count); m_state.condition.notify_all(); - start_write(); + if (m_state.data.write.queue.empty() && m_state.socket.shutdown_read) { + // All writes have been sent and reads shutdown already, connection can be closed + interrupt(); + } else { + start_write(); + } } }; if (!m_state.ssl.enabled) @@ -762,6 +796,17 @@ namespace net_utils m_state.status = status_t::WASTED; } + template + void connection::terminate_async() + { + // synchronize with intermediate writes on `m_strand` + auto self = connection::shared_from_this(); + boost::asio::post(m_strand, [this, self] { + std::lock_guard guard(m_state.lock); + terminate(); + }); + } + template bool connection::send(epee::byte_slice message) { @@ -814,12 +859,7 @@ namespace net_utils ); m_state.data.write.wait_consume = false; if (!success) { - // synchronize with intermediate writes on `m_strand` - auto self = connection::shared_from_this(); - boost::asio::post(m_strand, [this, self] { - std::lock_guard guard(m_state.lock); - terminate(); - }); + terminate_async(); return false; } else @@ -1093,7 +1133,7 @@ namespace net_utils std::lock_guard guard(m_state.lock); if (m_state.status != status_t::RUNNING) return false; - terminate(); + terminate_async(); return true; } diff --git a/tests/unit_tests/epee_http_server.cpp b/tests/unit_tests/epee_http_server.cpp new file mode 100644 index 0000000..b67f3f6 --- /dev/null +++ b/tests/unit_tests/epee_http_server.cpp @@ -0,0 +1,272 @@ +// Copyright (c) 2014-2024, The Monero Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// + +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "net/http_server_handlers_map2.h" +#include "net/http_server_impl_base.h" +#include "storages/portable_storage_template_helper.h" + +namespace +{ + constexpr const std::size_t payload_size = 26 * 1024 * 1024; + constexpr const std::size_t max_private_ips = 25; + struct dummy + { + struct request + { + BEGIN_KV_SERIALIZE_MAP() + END_KV_SERIALIZE_MAP() + }; + + struct response + { + BEGIN_KV_SERIALIZE_MAP() + KV_SERIALIZE(payload) + END_KV_SERIALIZE_MAP() + + std::string payload; + }; + }; + + std::string make_payload() + { + dummy::request body{}; + const auto body_serialized = epee::serialization::store_t_to_binary(body); + return std::string{ + reinterpret_cast(body_serialized.data()), + body_serialized.size() + }; + } + + struct http_server : epee::http_server_impl_base + { + using connection_context = epee::net_utils::connection_context_base; + + http_server() + : epee::http_server_impl_base(), + dummy_size(payload_size) + {} + + CHAIN_HTTP_TO_MAP2(connection_context); //forward http requests to uri map + + BEGIN_URI_MAP2() + MAP_URI_AUTO_BIN2("/dummy", on_dummy, dummy) + END_URI_MAP2() + + bool on_dummy(const dummy::request&, dummy::response& res, const connection_context *ctx = NULL) + { + res.payload.resize(dummy_size.load(), 'f'); + return true; + } + + std::atomic dummy_size; + }; +} // anonymous + +TEST(http_server, response_soft_limit) +{ + namespace http = boost::beast::http; + + http_server server{}; + server.init(nullptr, "8080"); + server.run(1, false); + + boost::system::error_code error{}; + boost::asio::io_context context{}; + boost::asio::ip::tcp::socket stream{context}; + stream.connect( + boost::asio::ip::tcp::endpoint{ + boost::asio::ip::make_address("127.0.0.1"), 8080 + }, + error + ); + EXPECT_FALSE(bool(error)); + + http::request req{http::verb::get, "/dummy", 11}; + req.set(http::field::host, "127.0.0.1"); + req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); + req.body() = make_payload(); + req.prepare_payload(); + http::write(stream, req, error); + EXPECT_FALSE(bool(error)); + + { + dummy::response payload{}; + boost::beast::flat_buffer buffer; + http::response_parser> parser; + parser.body_limit(payload_size + 1024); + http::read(stream, buffer, parser, error); + EXPECT_FALSE(bool(error)); + ASSERT_TRUE(parser.is_done()); + const auto res = parser.release(); + EXPECT_EQ(200u, res.result_int()); + EXPECT_TRUE(epee::serialization::load_t_from_binary(payload, res.body())); + EXPECT_EQ(payload_size, std::count(payload.payload.begin(), payload.payload.end(), 'f')); + } + + while (!error) + http::write(stream, req, error); + server.send_stop_signal(); +} + +TEST(http_server, private_ip_limit) +{ + namespace http = boost::beast::http; + + http_server server{}; + server.dummy_size = 1; + server.init(nullptr, "8080"); + server.run(1, false); + + boost::system::error_code error{}; + boost::asio::io_context context{}; + + http::request req{http::verb::get, "/dummy", 11}; + req.set(http::field::host, "127.0.0.1"); + req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); + req.body() = make_payload(); + req.prepare_payload(); + + std::vector streams{}; + for (std::size_t i = 0; i < max_private_ips; ++i) + { + streams.emplace_back(context); + streams.back().connect( + boost::asio::ip::tcp::endpoint{ + boost::asio::ip::make_address("127.0.0.1"), 8080 + }, + error + ); + http::write(streams.back(), req, error); + EXPECT_FALSE(bool(error)); + + boost::beast::flat_buffer buffer; + http::response_parser> parser; + parser.body_limit(payload_size + 1024); + + http::read(streams.back(), buffer, parser, error); + EXPECT_FALSE(bool(error)); + EXPECT_TRUE(parser.is_done()); + } + + boost::asio::ip::tcp::socket stream{context}; + stream.connect( + boost::asio::ip::tcp::endpoint{ + boost::asio::ip::make_address("127.0.0.1"), 8080 + }, + error + ); + bool failed = bool(error); + http::write(stream, req, error); + failed |= bool(error); + { + boost::beast::flat_buffer buffer; + http::response_parser> parser; + parser.body_limit(payload_size + 1024); + + // make sure server ran async_accept code + http::read(stream, buffer, parser, error); + } + failed |= bool(error); + EXPECT_TRUE(failed); +} + +TEST(http_server, read_then_close) +{ + namespace http = boost::beast::http; + + http_server server{}; + server.dummy_size = 200000; + server.init(nullptr, "8080"); + server.run(2, false); // need at least 2 threads to trigger issues + + bool failed_read = false; + bool closed_all_connections = true; + for (std::size_t j = 0; j < 1000; ++j) + { + boost::system::error_code error{}; + boost::asio::io_context context{}; + boost::asio::ip::tcp::socket stream{context}; + stream.connect( + boost::asio::ip::tcp::endpoint{ + boost::asio::ip::make_address("127.0.0.1"), 8080 + }, + error + ); + EXPECT_FALSE(bool(error)); + + http::request req{http::verb::get, "/dummy", 11}; + req.set(http::field::host, "127.0.0.1"); + req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); + req.set(http::field::connection, "close"); // tell server to close connection after sending all data to the client + req.body() = make_payload(); + req.prepare_payload(); + + dummy::response payload{}; + boost::beast::flat_buffer buffer; + http::response_parser> parser; + parser.body_limit(server.dummy_size + 1024); + + http::write(stream, req, error); + EXPECT_FALSE(bool(error)); + + http::read(stream, buffer, parser, error); + + // If the read fails, continue the loop still just to make sure the server can handle it + failed_read |= bool(error); + if (failed_read) + continue; + failed_read |= !(parser.is_done()); + if (failed_read) + continue; + const auto res = parser.release(); + failed_read |= res.result_int() != 200u + || !(epee::serialization::load_t_from_binary(payload, res.body())) + || (server.dummy_size != std::count(payload.payload.begin(), payload.payload.end(), 'f')); + + // See if the server closes the connection after handling the resp + char buf[1]; + stream.read_some(boost::asio::buffer(buf), error); + closed_all_connections &= error == boost::asio::error::eof; + } + + // The client should have been able to read all data sent by the server across all requests + EXPECT_FALSE(failed_read); + + // The server should have closed all connections + EXPECT_TRUE(closed_all_connections); + + server.send_stop_signal(); +}