diff --git a/src/net.cpp b/src/net.cpp index 338831bb48b..1545e36e684 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -824,8 +824,13 @@ CNetMessage V1Transport::GetReceivedMessage(const std::chrono::microseconds time return msg; } -void V1Transport::prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const +bool V1Transport::SetMessageToSend(CSerializedNetMsg& msg) noexcept { + AssertLockNotHeld(m_send_mutex); + // Determine whether a new message can be set. + LOCK(m_send_mutex); + if (m_sending_header || m_bytes_sent < m_message_to_send.data.size()) return false; + // create dbl-sha256 checksum uint256 hash = Hash(msg.data); @@ -834,8 +839,50 @@ void V1Transport::prepareForTransport(CSerializedNetMsg& msg, std::vector CConnman::SocketSendData(CNode& node) const @@ -2910,27 +2957,40 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) msg.data.data() ); - // make sure we use the appropriate network transport format - std::vector serializedHeader; - pnode->m_transport->prepareForTransport(msg, serializedHeader); - size_t nTotalSize = nMessageSize + serializedHeader.size(); - size_t nBytesSent = 0; { LOCK(pnode->cs_vSend); - bool optimisticSend(pnode->vSendMsg.empty()); + const bool queue_was_empty{pnode->vSendMsg.empty()}; - //log total amount of bytes per message type - pnode->AccountForSentBytes(msg.m_type, nTotalSize); - pnode->nSendSize += nTotalSize; + // Give the message to the transport, and add all bytes it wants us to send out as byte + // vectors to vSendMsg. This is temporary code that exists to support the new transport + // sending interface using the old way of queueing data. In a future commit vSendMsg will + // be replaced with a queue of CSerializedNetMsg objects to be sent instead, and this code + // will disappear. + bool queued = pnode->m_transport->SetMessageToSend(msg); + assert(queued); + // In the current transport (V1Transport), GetBytesToSend first returns a header to send, + // and then the payload data (if any), necessitating a loop. + while (true) { + const auto& [bytes, _more, msg_type] = pnode->m_transport->GetBytesToSend(); + if (bytes.empty()) break; + // Update statistics per message type. + pnode->AccountForSentBytes(msg_type, bytes.size()); + // Update number of bytes in the send buffer. + pnode->nSendSize += bytes.size(); + if (pnode->nSendSize > nSendBufferMaxSize) pnode->fPauseSend = true; + pnode->vSendMsg.push_back({bytes.begin(), bytes.end()}); + // Notify transport that bytes have been processed (they're not actually sent yet, + // but pushed onto the vSendMsg queue of bytes to send). + pnode->m_transport->MarkBytesSent(bytes.size()); + } - if (pnode->nSendSize > nSendBufferMaxSize) pnode->fPauseSend = true; - pnode->vSendMsg.push_back(std::move(serializedHeader)); - if (nMessageSize) pnode->vSendMsg.push_back(std::move(msg.data)); - - // If write queue empty, attempt "optimistic write" - bool data_left; - if (optimisticSend) std::tie(nBytesSent, data_left) = SocketSendData(*pnode); + // If the write queue was empty before and isn't now, attempt "optimistic write": + // because the poll/select loop may pause for SELECT_TIMEOUT_MILLISECONDS before actually + // doing a send, try sending from the calling thread if the queue was empty before. + if (queue_was_empty && !pnode->vSendMsg.empty()) { + std::tie(nBytesSent, std::ignore) = SocketSendData(*pnode); + } } if (nBytesSent) RecordBytesSent(nBytesSent); } diff --git a/src/net.h b/src/net.h index a17ca36652a..83deb4afed2 100644 --- a/src/net.h +++ b/src/net.h @@ -270,10 +270,49 @@ public: /** Retrieve a completed message from transport (only when ReceivedMessageComplete). */ virtual CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) = 0; - // 2. Sending side functions: + // 2. Sending side functions, for converting messages into bytes to be sent over the wire. - // prepare message for transport (header construction, error-correction computation, payload encryption, etc.) - virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const = 0; + /** Set the next message to send. + * + * If no message can currently be set (perhaps because the previous one is not yet done being + * sent), returns false, and msg will be unmodified. Otherwise msg is enqueued (and + * possibly moved-from) and true is returned. + */ + virtual bool SetMessageToSend(CSerializedNetMsg& msg) noexcept = 0; + + /** Return type for GetBytesToSend, consisting of: + * - Span to_send: span of bytes to be sent over the wire (possibly empty). + * - bool more: whether there will be more bytes to be sent after the ones in to_send are + * all sent (as signaled by MarkBytesSent()). + * - const std::string& m_type: message type on behalf of which this is being sent. + */ + using BytesToSend = std::tuple< + Span /*to_send*/, + bool /*more*/, + const std::string& /*m_type*/ + >; + + /** Get bytes to send on the wire. + * + * As a const function, it does not modify the transport's observable state, and is thus safe + * to be called multiple times. + * + * The bytes returned by this function act as a stream which can only be appended to. This + * means that with the exception of MarkBytesSent, operations on the transport can only append + * to what is being returned. + * + * Note that m_type and to_send refer to data that is internal to the transport, and calling + * any non-const function on this object may invalidate them. + */ + virtual BytesToSend GetBytesToSend() const noexcept = 0; + + /** Report how many bytes returned by the last GetBytesToSend() have been sent. + * + * bytes_sent cannot exceed to_send.size() of the last GetBytesToSend() result. + * + * If bytes_sent=0, this call has no effect. + */ + virtual void MarkBytesSent(size_t bytes_sent) noexcept = 0; }; class V1Transport final : public Transport @@ -314,6 +353,17 @@ private: return hdr.nMessageSize == nDataPos; } + /** Lock for sending state. */ + mutable Mutex m_send_mutex; + /** The header of the message currently being sent. */ + std::vector m_header_to_send GUARDED_BY(m_send_mutex); + /** The data of the message currently being sent. */ + CSerializedNetMsg m_message_to_send GUARDED_BY(m_send_mutex); + /** Whether we're currently sending header bytes or message bytes. */ + bool m_sending_header GUARDED_BY(m_send_mutex) {false}; + /** How many bytes have been sent so far (from m_header_to_send, or from m_message_to_send.data). */ + size_t m_bytes_sent GUARDED_BY(m_send_mutex) {0}; + public: V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn) : m_chain_params(chain_params), @@ -354,7 +404,9 @@ public: CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); - void prepareForTransport(CSerializedNetMsg& msg, std::vector& header) const override; + bool SetMessageToSend(CSerializedNetMsg& msg) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + BytesToSend GetBytesToSend() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + void MarkBytesSent(size_t bytes_sent) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); }; struct CNodeOptions @@ -369,7 +421,8 @@ struct CNodeOptions class CNode { public: - /** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv. */ + /** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv, while + * the sending side functions are only called under cs_vSend. */ const std::unique_ptr m_transport; const NetPermissionFlags m_permission_flags; diff --git a/src/test/fuzz/p2p_transport_serialization.cpp b/src/test/fuzz/p2p_transport_serialization.cpp index dcf7529918c..d96215e8e04 100644 --- a/src/test/fuzz/p2p_transport_serialization.cpp +++ b/src/test/fuzz/p2p_transport_serialization.cpp @@ -79,7 +79,16 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial std::vector header; auto msg2 = CNetMsgMaker{msg.m_recv.GetVersion()}.Make(msg.m_type, Span{msg.m_recv}); - send_transport.prepareForTransport(msg2, header); + bool queued = send_transport.SetMessageToSend(msg2); + assert(queued); + std::optional known_more; + while (true) { + const auto& [to_send, more, _msg_type] = send_transport.GetBytesToSend(); + if (known_more) assert(!to_send.empty() == *known_more); + if (to_send.empty()) break; + send_transport.MarkBytesSent(to_send.size()); + known_more = more; + } } } } diff --git a/src/test/fuzz/process_messages.cpp b/src/test/fuzz/process_messages.cpp index 2617be3fa88..98962fceb5b 100644 --- a/src/test/fuzz/process_messages.cpp +++ b/src/test/fuzz/process_messages.cpp @@ -67,7 +67,7 @@ FUZZ_TARGET(process_messages, .init = initialize_process_messages) CNode& random_node = *PickValue(fuzzed_data_provider, peers); - (void)connman.ReceiveMsgFrom(random_node, net_msg); + (void)connman.ReceiveMsgFrom(random_node, std::move(net_msg)); random_node.fPauseSend = false; try { diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp index 00317700286..c071355bc0f 100644 --- a/src/test/util/net.cpp +++ b/src/test/util/net.cpp @@ -41,7 +41,7 @@ void ConnmanTestMsg::Handshake(CNode& node, relay_txs), }; - (void)connman.ReceiveMsgFrom(node, msg_version); + (void)connman.ReceiveMsgFrom(node, std::move(msg_version)); node.fPauseSend = false; connman.ProcessMessagesOnce(node); peerman.SendMessages(&node); @@ -54,7 +54,7 @@ void ConnmanTestMsg::Handshake(CNode& node, assert(statestats.their_services == remote_services); if (successfully_connected) { CSerializedNetMsg msg_verack{mm.Make(NetMsgType::VERACK)}; - (void)connman.ReceiveMsgFrom(node, msg_verack); + (void)connman.ReceiveMsgFrom(node, std::move(msg_verack)); node.fPauseSend = false; connman.ProcessMessagesOnce(node); peerman.SendMessages(&node); @@ -70,14 +70,17 @@ void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span msg_by } } -bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const +bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg&& ser_msg) const { - std::vector ser_msg_header; - node.m_transport->prepareForTransport(ser_msg, ser_msg_header); - - bool complete; - NodeReceiveMsgBytes(node, ser_msg_header, complete); - NodeReceiveMsgBytes(node, ser_msg.data, complete); + bool queued = node.m_transport->SetMessageToSend(ser_msg); + assert(queued); + bool complete{false}; + while (true) { + const auto& [to_send, _more, _msg_type] = node.m_transport->GetBytesToSend(); + if (to_send.empty()) break; + NodeReceiveMsgBytes(node, to_send, complete); + node.m_transport->MarkBytesSent(to_send.size()); + } return complete; } diff --git a/src/test/util/net.h b/src/test/util/net.h index b2f6ebb1637..687ce1e813a 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -54,7 +54,7 @@ struct ConnmanTestMsg : public CConnman { void NodeReceiveMsgBytes(CNode& node, Span msg_bytes, bool& complete) const; - bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const; + bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg&& ser_msg) const; }; constexpr ServiceFlags ALL_SERVICE_FLAGS[]{