From 0de48fe858a1ffcced340eef2c849165216141c8 Mon Sep 17 00:00:00 2001 From: Pieter Wuille Date: Fri, 21 Jul 2023 16:31:59 -0400 Subject: [PATCH] net: abstract sending side of transport serialization further This makes the sending side of P2P transports mirror the receiver side: caller provides message (consisting of type and payload) to be sent, and then asks what bytes must be sent. Once the message has been fully sent, a new message can be provided. This removes the assumption that P2P serialization of messages follows a strict structure of header (a function of type and payload), followed by (unmodified) payload, and instead lets transports decide the structure themselves. It also removes the assumption that a message must always be sent at once, or that no bytes are even sent on the wire when there is no message. This opens the door for supporting traffic shaping mechanisms in the future. --- src/net.cpp | 98 +++++++++++++++---- src/net.h | 63 +++++++++++- src/test/fuzz/p2p_transport_serialization.cpp | 11 ++- src/test/fuzz/process_messages.cpp | 2 +- src/test/util/net.cpp | 21 ++-- src/test/util/net.h | 2 +- 6 files changed, 161 insertions(+), 36 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 338831bb48..1545e36e68 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -824,8 +824,13 @@ CNetMessage V1Transport::GetReceivedMessage(const std::chrono::microseconds time return msg; } -void V1Transport::prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const +bool V1Transport::SetMessageToSend(CSerializedNetMsg& msg) noexcept { + AssertLockNotHeld(m_send_mutex); + // Determine whether a new message can be set. + LOCK(m_send_mutex); + if (m_sending_header || m_bytes_sent < m_message_to_send.data.size()) return false; + // create dbl-sha256 checksum uint256 hash = Hash(msg.data); @@ -834,8 +839,50 @@ void V1Transport::prepareForTransport(CSerializedNetMsg& msg, std::vector CConnman::SocketSendData(CNode& node) const @@ -2910,27 +2957,40 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) msg.data.data() ); - // make sure we use the appropriate network transport format - std::vector serializedHeader; - pnode->m_transport->prepareForTransport(msg, serializedHeader); - size_t nTotalSize = nMessageSize + serializedHeader.size(); - size_t nBytesSent = 0; { LOCK(pnode->cs_vSend); - bool optimisticSend(pnode->vSendMsg.empty()); + const bool queue_was_empty{pnode->vSendMsg.empty()}; - //log total amount of bytes per message type - pnode->AccountForSentBytes(msg.m_type, nTotalSize); - pnode->nSendSize += nTotalSize; + // Give the message to the transport, and add all bytes it wants us to send out as byte + // vectors to vSendMsg. This is temporary code that exists to support the new transport + // sending interface using the old way of queueing data. In a future commit vSendMsg will + // be replaced with a queue of CSerializedNetMsg objects to be sent instead, and this code + // will disappear. + bool queued = pnode->m_transport->SetMessageToSend(msg); + assert(queued); + // In the current transport (V1Transport), GetBytesToSend first returns a header to send, + // and then the payload data (if any), necessitating a loop. + while (true) { + const auto& [bytes, _more, msg_type] = pnode->m_transport->GetBytesToSend(); + if (bytes.empty()) break; + // Update statistics per message type. + pnode->AccountForSentBytes(msg_type, bytes.size()); + // Update number of bytes in the send buffer. + pnode->nSendSize += bytes.size(); + if (pnode->nSendSize > nSendBufferMaxSize) pnode->fPauseSend = true; + pnode->vSendMsg.push_back({bytes.begin(), bytes.end()}); + // Notify transport that bytes have been processed (they're not actually sent yet, + // but pushed onto the vSendMsg queue of bytes to send). + pnode->m_transport->MarkBytesSent(bytes.size()); + } - if (pnode->nSendSize > nSendBufferMaxSize) pnode->fPauseSend = true; - pnode->vSendMsg.push_back(std::move(serializedHeader)); - if (nMessageSize) pnode->vSendMsg.push_back(std::move(msg.data)); - - // If write queue empty, attempt "optimistic write" - bool data_left; - if (optimisticSend) std::tie(nBytesSent, data_left) = SocketSendData(*pnode); + // If the write queue was empty before and isn't now, attempt "optimistic write": + // because the poll/select loop may pause for SELECT_TIMEOUT_MILLISECONDS before actually + // doing a send, try sending from the calling thread if the queue was empty before. + if (queue_was_empty && !pnode->vSendMsg.empty()) { + std::tie(nBytesSent, std::ignore) = SocketSendData(*pnode); + } } if (nBytesSent) RecordBytesSent(nBytesSent); } diff --git a/src/net.h b/src/net.h index a17ca36652..83deb4afed 100644 --- a/src/net.h +++ b/src/net.h @@ -270,10 +270,49 @@ public: /** Retrieve a completed message from transport (only when ReceivedMessageComplete). */ virtual CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) = 0; - // 2. Sending side functions: + // 2. Sending side functions, for converting messages into bytes to be sent over the wire. - // prepare message for transport (header construction, error-correction computation, payload encryption, etc.) - virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const = 0; + /** Set the next message to send. + * + * If no message can currently be set (perhaps because the previous one is not yet done being + * sent), returns false, and msg will be unmodified. Otherwise msg is enqueued (and + * possibly moved-from) and true is returned. + */ + virtual bool SetMessageToSend(CSerializedNetMsg& msg) noexcept = 0; + + /** Return type for GetBytesToSend, consisting of: + * - Span to_send: span of bytes to be sent over the wire (possibly empty). + * - bool more: whether there will be more bytes to be sent after the ones in to_send are + * all sent (as signaled by MarkBytesSent()). + * - const std::string& m_type: message type on behalf of which this is being sent. + */ + using BytesToSend = std::tuple< + Span /*to_send*/, + bool /*more*/, + const std::string& /*m_type*/ + >; + + /** Get bytes to send on the wire. + * + * As a const function, it does not modify the transport's observable state, and is thus safe + * to be called multiple times. + * + * The bytes returned by this function act as a stream which can only be appended to. This + * means that with the exception of MarkBytesSent, operations on the transport can only append + * to what is being returned. + * + * Note that m_type and to_send refer to data that is internal to the transport, and calling + * any non-const function on this object may invalidate them. + */ + virtual BytesToSend GetBytesToSend() const noexcept = 0; + + /** Report how many bytes returned by the last GetBytesToSend() have been sent. + * + * bytes_sent cannot exceed to_send.size() of the last GetBytesToSend() result. + * + * If bytes_sent=0, this call has no effect. + */ + virtual void MarkBytesSent(size_t bytes_sent) noexcept = 0; }; class V1Transport final : public Transport @@ -314,6 +353,17 @@ private: return hdr.nMessageSize == nDataPos; } + /** Lock for sending state. */ + mutable Mutex m_send_mutex; + /** The header of the message currently being sent. */ + std::vector m_header_to_send GUARDED_BY(m_send_mutex); + /** The data of the message currently being sent. */ + CSerializedNetMsg m_message_to_send GUARDED_BY(m_send_mutex); + /** Whether we're currently sending header bytes or message bytes. */ + bool m_sending_header GUARDED_BY(m_send_mutex) {false}; + /** How many bytes have been sent so far (from m_header_to_send, or from m_message_to_send.data). */ + size_t m_bytes_sent GUARDED_BY(m_send_mutex) {0}; + public: V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn) : m_chain_params(chain_params), @@ -354,7 +404,9 @@ public: CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); - void prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const override; + bool SetMessageToSend(CSerializedNetMsg& msg) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + BytesToSend GetBytesToSend() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + void MarkBytesSent(size_t bytes_sent) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); }; struct CNodeOptions @@ -369,7 +421,8 @@ struct CNodeOptions class CNode { public: - /** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv. */ + /** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv, while + * the sending side functions are only called under cs_vSend. */ const std::unique_ptr m_transport; const NetPermissionFlags m_permission_flags; diff --git a/src/test/fuzz/p2p_transport_serialization.cpp b/src/test/fuzz/p2p_transport_serialization.cpp index dcf7529918..d96215e8e0 100644 --- a/src/test/fuzz/p2p_transport_serialization.cpp +++ b/src/test/fuzz/p2p_transport_serialization.cpp @@ -79,7 +79,16 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial std::vector header; auto msg2 = CNetMsgMaker{msg.m_recv.GetVersion()}.Make(msg.m_type, Span{msg.m_recv}); - send_transport.prepareForTransport(msg2, header); + bool queued = send_transport.SetMessageToSend(msg2); + assert(queued); + std::optional known_more; + while (true) { + const auto& [to_send, more, _msg_type] = send_transport.GetBytesToSend(); + if (known_more) assert(!to_send.empty() == *known_more); + if (to_send.empty()) break; + send_transport.MarkBytesSent(to_send.size()); + known_more = more; + } } } } diff --git a/src/test/fuzz/process_messages.cpp b/src/test/fuzz/process_messages.cpp index 2617be3fa8..98962fceb5 100644 --- a/src/test/fuzz/process_messages.cpp +++ b/src/test/fuzz/process_messages.cpp @@ -67,7 +67,7 @@ FUZZ_TARGET(process_messages, .init = initialize_process_messages) CNode& random_node = *PickValue(fuzzed_data_provider, peers); - (void)connman.ReceiveMsgFrom(random_node, net_msg); + (void)connman.ReceiveMsgFrom(random_node, std::move(net_msg)); random_node.fPauseSend = false; try { diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp index 0031770028..c071355bc0 100644 --- a/src/test/util/net.cpp +++ b/src/test/util/net.cpp @@ -41,7 +41,7 @@ void ConnmanTestMsg::Handshake(CNode& node, relay_txs), }; - (void)connman.ReceiveMsgFrom(node, msg_version); + (void)connman.ReceiveMsgFrom(node, std::move(msg_version)); node.fPauseSend = false; connman.ProcessMessagesOnce(node); peerman.SendMessages(&node); @@ -54,7 +54,7 @@ void ConnmanTestMsg::Handshake(CNode& node, assert(statestats.their_services == remote_services); if (successfully_connected) { CSerializedNetMsg msg_verack{mm.Make(NetMsgType::VERACK)}; - (void)connman.ReceiveMsgFrom(node, msg_verack); + (void)connman.ReceiveMsgFrom(node, std::move(msg_verack)); node.fPauseSend = false; connman.ProcessMessagesOnce(node); peerman.SendMessages(&node); @@ -70,14 +70,17 @@ void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span msg_by } } -bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const +bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg&& ser_msg) const { - std::vector ser_msg_header; - node.m_transport->prepareForTransport(ser_msg, ser_msg_header); - - bool complete; - NodeReceiveMsgBytes(node, ser_msg_header, complete); - NodeReceiveMsgBytes(node, ser_msg.data, complete); + bool queued = node.m_transport->SetMessageToSend(ser_msg); + assert(queued); + bool complete{false}; + while (true) { + const auto& [to_send, _more, _msg_type] = node.m_transport->GetBytesToSend(); + if (to_send.empty()) break; + NodeReceiveMsgBytes(node, to_send, complete); + node.m_transport->MarkBytesSent(to_send.size()); + } return complete; } diff --git a/src/test/util/net.h b/src/test/util/net.h index b2f6ebb163..687ce1e813 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -54,7 +54,7 @@ struct ConnmanTestMsg : public CConnman { void NodeReceiveMsgBytes(CNode& node, Span msg_bytes, bool& complete) const; - bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const; + bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg&& ser_msg) const; }; constexpr ServiceFlags ALL_SERVICE_FLAGS[]{