From 2716647ebf60cea05fc9edce6a18dcce4e7727ad Mon Sep 17 00:00:00 2001 From: Troy Giorshev Date: Mon, 29 Jun 2020 14:09:42 -0400 Subject: [PATCH 1/6] Give V1TransportDeserializer an m_node_id member This is intended to only be used for logging. This will allow log messages in the following commits to keep recording the peer's ID, even when logging is moved into V1TransportDeserializer. --- src/net.cpp | 7 ++++--- src/net.h | 9 +++++++-- src/test/fuzz/p2p_transport_deserializer.cpp | 3 ++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index e7d3a146ffa..73029655ce5 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -699,10 +699,11 @@ CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageSta msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) == 0); if (!msg.m_valid_checksum) { - LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s\n", + LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n", SanitizeString(msg.m_command), msg.m_message_size, HexStr(Span(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)), - HexStr(hdr.pchChecksum)); + HexStr(hdr.pchChecksum), + m_node_id); } // store receive time @@ -2828,7 +2829,7 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn LogPrint(BCLog::NET, "Added connection peer=%d\n", id); } - m_deserializer = MakeUnique(V1TransportDeserializer(Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION)); + m_deserializer = MakeUnique(V1TransportDeserializer(Params().MessageStart(), GetId(), SER_NETWORK, INIT_PROTO_VERSION)); m_serializer = MakeUnique(V1TransportSerializer()); } diff --git a/src/net.h b/src/net.h index 0366fa0f5b7..bda6007e733 100644 --- a/src/net.h +++ b/src/net.h @@ -739,6 +739,7 @@ public: class V1TransportDeserializer final : public TransportDeserializer { private: + const NodeId m_node_id; // Only for logging mutable CHash256 hasher; mutable uint256 data_hash; bool in_data; // parsing header (false) or data (true) @@ -764,8 +765,12 @@ private: } public: - - V1TransportDeserializer(const CMessageHeader::MessageStartChars& pchMessageStartIn, int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) { + V1TransportDeserializer(const CMessageHeader::MessageStartChars& pchMessageStartIn, const NodeId node_id, int nTypeIn, int nVersionIn) + : m_node_id(node_id), + hdrbuf(nTypeIn, nVersionIn), + hdr(pchMessageStartIn), + vRecv(nTypeIn, nVersionIn) + { Reset(); } diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp index 6fba2bfabaa..732136330b5 100644 --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -19,7 +19,8 @@ void initialize() void test_one_input(const std::vector& buffer) { - V1TransportDeserializer deserializer{Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION}; + // Construct deserializer, with a dummy NodeId + V1TransportDeserializer deserializer{Params().MessageStart(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; const char* pch = (const char*)buffer.data(); size_t n_bytes = buffer.size(); while (n_bytes > 0) { From 890b1d7c2b8312d41d048d2db124586c5dbc8a49 Mon Sep 17 00:00:00 2001 From: Troy Giorshev Date: Mon, 29 Jun 2020 14:15:06 -0400 Subject: [PATCH 2/6] Move checksum check from net_processing to net This removes the m_valid_checksum member from CNetMessage. Instead, GetMessage() returns an Optional. Additionally, GetMessage() has been given an out parameter to be used to hold error information. For now it is specifically a uint32_t used to hold the raw size of the corrupt message. The checksum check is now done in GetMessage. --- src/net.cpp | 47 +++++++++++--------- src/net.h | 8 ++-- src/net_processing.cpp | 11 +---- src/test/fuzz/p2p_transport_deserializer.cpp | 23 +++++----- 4 files changed, 45 insertions(+), 44 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 73029655ce5..3e015a68108 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -595,25 +595,33 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete while (nBytes > 0) { // absorb network data int handled = m_deserializer->Read(pch, nBytes); - if (handled < 0) return false; + if (handled < 0) { + return false; + } pch += handled; nBytes -= handled; if (m_deserializer->Complete()) { // decompose a transport agnostic CNetMessage from the deserializer - CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), time); + uint32_t out_err_raw_size{0}; + Optional result{m_deserializer->GetMessage(Params().MessageStart(), time, out_err_raw_size)}; + if (!result) { + // store the size of the corrupt message + mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size; + continue; + } //store received bytes per message command //to prevent a memory DOS, only allow valid commands - mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.m_command); + mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(result->m_command); if (i == mapRecvBytesPerMsgCmd.end()) i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER); assert(i != mapRecvBytesPerMsgCmd.end()); - i->second += msg.m_raw_message_size; + i->second += result->m_raw_message_size; // push the message to the process queue, - vRecvMsg.push_back(std::move(msg)); + vRecvMsg.push_back(std::move(*result)); complete = true; } @@ -679,37 +687,36 @@ const uint256& V1TransportDeserializer::GetMessageHash() const return data_hash; } -CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time) +Optional V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time, uint32_t& out_err_raw_size) { // decompose a single CNetMessage from the TransportDeserializer - CNetMessage msg(std::move(vRecv)); + Optional msg(std::move(vRecv)); // store state about valid header, netmagic and checksum - msg.m_valid_header = hdr.IsValid(message_start); - msg.m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0); + msg->m_valid_header = hdr.IsValid(message_start); + msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0); uint256 hash = GetMessageHash(); - // store command string, payload size - msg.m_command = hdr.GetCommand(); - msg.m_message_size = hdr.nMessageSize; - msg.m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE; + // store command string, time, and sizes + msg->m_command = hdr.GetCommand(); + msg->m_time = time; + msg->m_message_size = hdr.nMessageSize; + msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE; // We just received a message off the wire, harvest entropy from the time (and the message checksum) RandAddEvent(ReadLE32(hash.begin())); - msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) == 0); - if (!msg.m_valid_checksum) { + if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) { LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n", - SanitizeString(msg.m_command), msg.m_message_size, + SanitizeString(msg->m_command), msg->m_message_size, HexStr(Span(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)), HexStr(hdr.pchChecksum), m_node_id); + out_err_raw_size = msg->m_raw_message_size; + msg = nullopt; } - // store receive time - msg.m_time = time; - - // reset the network deserializer (prepare for the next message) + // Always reset the network deserializer (prepare for the next message) Reset(); return msg; } diff --git a/src/net.h b/src/net.h index bda6007e733..f581ce8ff96 100644 --- a/src/net.h +++ b/src/net.h @@ -14,8 +14,9 @@ #include #include #include -#include #include +#include +#include #include #include #include @@ -706,7 +707,6 @@ public: std::chrono::microseconds m_time{0}; //!< time of message receipt bool m_valid_netmagic = false; bool m_valid_header = false; - bool m_valid_checksum = false; uint32_t m_message_size{0}; //!< size of the payload uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) std::string m_command; @@ -732,7 +732,7 @@ public: // read and deserialize data virtual int Read(const char *data, unsigned int bytes) = 0; // decomposes a message from the context - virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) = 0; + virtual Optional GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err) = 0; virtual ~TransportDeserializer() {} }; @@ -790,7 +790,7 @@ public: if (ret < 0) Reset(); return ret; } - CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) override; + Optional GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err_raw_size) override; }; /** The TransportSerializer prepares messages for the network transport diff --git a/src/net_processing.cpp b/src/net_processing.cpp index 690b59476b3..d9d32cded66 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -3886,17 +3886,8 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic& interruptMsgP // Message size unsigned int nMessageSize = msg.m_message_size; - // Checksum - CDataStream& vRecv = msg.m_recv; - if (!msg.m_valid_checksum) - { - LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n", __func__, - SanitizeString(msg_type), nMessageSize, pfrom->GetId()); - return fMoreWork; - } - try { - ProcessMessage(*pfrom, msg_type, vRecv, msg.m_time, interruptMsgProc); + ProcessMessage(*pfrom, msg_type, msg.m_recv, msg.m_time, interruptMsgProc); if (interruptMsgProc) return false; if (!pfrom->vRecvGetData.empty()) diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp index 732136330b5..3e9cd3af38a 100644 --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -32,16 +32,19 @@ void test_one_input(const std::vector& buffer) n_bytes -= handled; if (deserializer.Complete()) { const std::chrono::microseconds m_time{std::numeric_limits::max()}; - const CNetMessage msg = deserializer.GetMessage(Params().MessageStart(), m_time); - assert(msg.m_command.size() <= CMessageHeader::COMMAND_SIZE); - assert(msg.m_raw_message_size <= buffer.size()); - assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size); - assert(msg.m_time == m_time); - if (msg.m_valid_header) { - assert(msg.m_valid_netmagic); - } - if (!msg.m_valid_netmagic) { - assert(!msg.m_valid_header); + uint32_t out_err_raw_size{0}; + Optional result{deserializer.GetMessage(Params().MessageStart(), m_time, out_err_raw_size)}; + if (result) { + assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE); + assert(result->m_raw_message_size <= buffer.size()); + assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size); + assert(result->m_time == m_time); + if (result->m_valid_header) { + assert(result->m_valid_netmagic); + } + if (!result->m_valid_netmagic) { + assert(!result->m_valid_header); + } } } } From 1ca20c1af8f08f07c407c3183c37b467ddf0f413 Mon Sep 17 00:00:00 2001 From: Troy Giorshev Date: Tue, 26 May 2020 16:01:03 -0400 Subject: [PATCH 3/6] Add doxygen comment for ReceiveMsgBytes --- src/net.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/net.cpp b/src/net.cpp index 3e015a68108..fdb76d3b83d 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -585,6 +585,16 @@ void CNode::copyStats(CNodeStats &stats, const std::vector &m_asmap) } #undef X +/** + * Receive bytes from the buffer and deserialize them into messages. + * + * @param[in] pch A pointer to the raw data + * @param[in] nBytes Size of the data + * @param[out] complete Set True if at least one message has been + * deserialized and is ready to be processed + * @return True if the peer should stay connected, + * False if the peer should be disconnected from. + */ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete) { complete = false; From 5bceef6b12fa16d20287693be377dace3dfec3e5 Mon Sep 17 00:00:00 2001 From: Troy Giorshev Date: Mon, 8 Jun 2020 22:37:55 -0400 Subject: [PATCH 4/6] Change CMessageHeader Constructor This commit removes the single-parameter contructor of CMessageHeader and replaces it with a default constructor. The single parameter contructor isn't used anywhere except for tests. There is no reason to initialize a CMessageHeader with a particular messagestart. This messagestart should always be replaced when deserializing an actual message header so that we can run checks on it. The default constructor initializes it to zero, just like the command and checksum. This also removes a parameter of a V1TransportDeserializer constructor, as it was only used for this purpose. --- src/net.cpp | 2 +- src/net.h | 3 +-- src/protocol.cpp | 4 ++-- src/protocol.h | 2 +- src/test/fuzz/deserialize.cpp | 2 +- src/test/fuzz/p2p_transport_deserializer.cpp | 2 +- 6 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index fdb76d3b83d..1ae4b8fe085 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -2846,7 +2846,7 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn LogPrint(BCLog::NET, "Added connection peer=%d\n", id); } - m_deserializer = MakeUnique(V1TransportDeserializer(Params().MessageStart(), GetId(), SER_NETWORK, INIT_PROTO_VERSION)); + m_deserializer = MakeUnique(V1TransportDeserializer(GetId(), SER_NETWORK, INIT_PROTO_VERSION)); m_serializer = MakeUnique(V1TransportSerializer()); } diff --git a/src/net.h b/src/net.h index f581ce8ff96..cec201c5d24 100644 --- a/src/net.h +++ b/src/net.h @@ -765,10 +765,9 @@ private: } public: - V1TransportDeserializer(const CMessageHeader::MessageStartChars& pchMessageStartIn, const NodeId node_id, int nTypeIn, int nVersionIn) + V1TransportDeserializer(const NodeId node_id, int nTypeIn, int nVersionIn) : m_node_id(node_id), hdrbuf(nTypeIn, nVersionIn), - hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) { Reset(); diff --git a/src/protocol.cpp b/src/protocol.cpp index 48ca0c6df6a..6b4de68ce98 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -84,9 +84,9 @@ const static std::string allNetMessageTypes[] = { }; const static std::vector allNetMessageTypesVec(allNetMessageTypes, allNetMessageTypes+ARRAYLEN(allNetMessageTypes)); -CMessageHeader::CMessageHeader(const MessageStartChars& pchMessageStartIn) +CMessageHeader::CMessageHeader() { - memcpy(pchMessageStart, pchMessageStartIn, MESSAGE_START_SIZE); + memset(pchMessageStart, 0, MESSAGE_START_SIZE); memset(pchCommand, 0, sizeof(pchCommand)); nMessageSize = -1; memset(pchChecksum, 0, CHECKSUM_SIZE); diff --git a/src/protocol.h b/src/protocol.h index 7fb84cddf11..3bf0797ca40 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -37,7 +37,7 @@ public: static constexpr size_t HEADER_SIZE = MESSAGE_START_SIZE + COMMAND_SIZE + MESSAGE_SIZE_SIZE + CHECKSUM_SIZE; typedef unsigned char MessageStartChars[MESSAGE_START_SIZE]; - explicit CMessageHeader(const MessageStartChars& pchMessageStartIn); + explicit CMessageHeader(); /** Construct a P2P message header from message-start characters, a command and the size of the message. * @note Passing in a `pszCommand` longer than COMMAND_SIZE will result in a run-time assertion error. diff --git a/src/test/fuzz/deserialize.cpp b/src/test/fuzz/deserialize.cpp index 54793c890fe..f87b7576a42 100644 --- a/src/test/fuzz/deserialize.cpp +++ b/src/test/fuzz/deserialize.cpp @@ -190,7 +190,7 @@ void test_one_input(const std::vector& buffer) AssertEqualAfterSerializeDeserialize(s); #elif MESSAGEHEADER_DESERIALIZE const CMessageHeader::MessageStartChars pchMessageStart = {0x00, 0x00, 0x00, 0x00}; - CMessageHeader mh(pchMessageStart); + CMessageHeader mh; DeserializeFromFuzzingInput(buffer, mh); (void)mh.IsValid(pchMessageStart); #elif ADDRESS_DESERIALIZE diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp index 3e9cd3af38a..5349fd3f688 100644 --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -20,7 +20,7 @@ void initialize() void test_one_input(const std::vector& buffer) { // Construct deserializer, with a dummy NodeId - V1TransportDeserializer deserializer{Params().MessageStart(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; + V1TransportDeserializer deserializer{(NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; const char* pch = (const char*)buffer.data(); size_t n_bytes = buffer.size(); while (n_bytes > 0) { From 52d4ae46ab822d0f54e246a6f2364415cda149bd Mon Sep 17 00:00:00 2001 From: Troy Giorshev Date: Mon, 8 Jun 2020 22:26:22 -0400 Subject: [PATCH 5/6] Give V1TransportDeserializer CChainParams& member This adds a CChainParams& member to V1TransportDeserializer member, and use it in place of many Params() calls. In addition to reducing the number of calls to a global, this removes a parameter from GetMessage (and will later allow us to remove one from CMessageHeader::IsValid()) --- src/net.cpp | 14 +++++++------- src/net.h | 11 +++++++---- src/test/fuzz/p2p_transport_deserializer.cpp | 4 ++-- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 1ae4b8fe085..941ea3c4acc 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -615,7 +614,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete if (m_deserializer->Complete()) { // decompose a transport agnostic CNetMessage from the deserializer uint32_t out_err_raw_size{0}; - Optional result{m_deserializer->GetMessage(Params().MessageStart(), time, out_err_raw_size)}; + Optional result{m_deserializer->GetMessage(time, out_err_raw_size)}; if (!result) { // store the size of the corrupt message mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size; @@ -697,15 +696,14 @@ const uint256& V1TransportDeserializer::GetMessageHash() const return data_hash; } -Optional V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time, uint32_t& out_err_raw_size) +Optional V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, uint32_t& out_err_raw_size) { // decompose a single CNetMessage from the TransportDeserializer Optional msg(std::move(vRecv)); // store state about valid header, netmagic and checksum - msg->m_valid_header = hdr.IsValid(message_start); - msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0); - uint256 hash = GetMessageHash(); + msg->m_valid_header = hdr.IsValid(m_chain_params.MessageStart()); + msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) == 0); // store command string, time, and sizes msg->m_command = hdr.GetCommand(); @@ -713,6 +711,8 @@ Optional V1TransportDeserializer::GetMessage(const CMessageHeader:: msg->m_message_size = hdr.nMessageSize; msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE; + uint256 hash = GetMessageHash(); + // We just received a message off the wire, harvest entropy from the time (and the message checksum) RandAddEvent(ReadLE32(hash.begin())); @@ -2846,7 +2846,7 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn LogPrint(BCLog::NET, "Added connection peer=%d\n", id); } - m_deserializer = MakeUnique(V1TransportDeserializer(GetId(), SER_NETWORK, INIT_PROTO_VERSION)); + m_deserializer = MakeUnique(V1TransportDeserializer(Params(), GetId(), SER_NETWORK, INIT_PROTO_VERSION)); m_serializer = MakeUnique(V1TransportSerializer()); } diff --git a/src/net.h b/src/net.h index cec201c5d24..29941b9622b 100644 --- a/src/net.h +++ b/src/net.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -732,13 +733,14 @@ public: // read and deserialize data virtual int Read(const char *data, unsigned int bytes) = 0; // decomposes a message from the context - virtual Optional GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err) = 0; + virtual Optional GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0; virtual ~TransportDeserializer() {} }; class V1TransportDeserializer final : public TransportDeserializer { private: + const CChainParams& m_chain_params; const NodeId m_node_id; // Only for logging mutable CHash256 hasher; mutable uint256 data_hash; @@ -765,8 +767,9 @@ private: } public: - V1TransportDeserializer(const NodeId node_id, int nTypeIn, int nVersionIn) - : m_node_id(node_id), + V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn) + : m_chain_params(chain_params), + m_node_id(node_id), hdrbuf(nTypeIn, nVersionIn), vRecv(nTypeIn, nVersionIn) { @@ -789,7 +792,7 @@ public: if (ret < 0) Reset(); return ret; } - Optional GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err_raw_size) override; + Optional GetMessage(std::chrono::microseconds time, uint32_t& out_err_raw_size) override; }; /** The TransportSerializer prepares messages for the network transport diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp index 5349fd3f688..6252b8e91ba 100644 --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -20,7 +20,7 @@ void initialize() void test_one_input(const std::vector& buffer) { // Construct deserializer, with a dummy NodeId - V1TransportDeserializer deserializer{(NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; + V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; const char* pch = (const char*)buffer.data(); size_t n_bytes = buffer.size(); while (n_bytes > 0) { @@ -33,7 +33,7 @@ void test_one_input(const std::vector& buffer) if (deserializer.Complete()) { const std::chrono::microseconds m_time{std::numeric_limits::max()}; uint32_t out_err_raw_size{0}; - Optional result{deserializer.GetMessage(Params().MessageStart(), m_time, out_err_raw_size)}; + Optional result{deserializer.GetMessage(m_time, out_err_raw_size)}; if (result) { assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE); assert(result->m_raw_message_size <= buffer.size()); From deb52711a17236d0fca302701b5af585341ab42a Mon Sep 17 00:00:00 2001 From: Troy Giorshev Date: Tue, 26 May 2020 17:01:57 -0400 Subject: [PATCH 6/6] Remove header checks out of net_processing This moves header size and netmagic checking out of net_processing and into net. This check now runs in ReadHeader, so that net can exit early out of receiving bytes from the peer. IsValid is now slimmed down, so it no longer needs a MessageStartChars& parameter. Additionally this removes the rest of the m_valid_* members from CNetMessage. --- src/net.cpp | 20 +++++++++++--- src/net.h | 6 ++-- src/net_processing.cpp | 21 -------------- src/protocol.cpp | 29 ++++++-------------- src/protocol.h | 2 +- src/test/fuzz/deserialize.cpp | 3 +- src/test/fuzz/p2p_transport_deserializer.cpp | 6 ---- test/functional/p2p_invalid_messages.py | 7 ++--- 8 files changed, 32 insertions(+), 62 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 941ea3c4acc..633f9a2f7f1 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -605,6 +605,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete // absorb network data int handled = m_deserializer->Read(pch, nBytes); if (handled < 0) { + // Serious header problem, disconnect from the peer. return false; } @@ -616,6 +617,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete uint32_t out_err_raw_size{0}; Optional result{m_deserializer->GetMessage(time, out_err_raw_size)}; if (!result) { + // Message deserialization failed. Drop the message but don't disconnect the peer. // store the size of the corrupt message mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size; continue; @@ -657,11 +659,19 @@ int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes) hdrbuf >> hdr; } catch (const std::exception&) { + LogPrint(BCLog::NET, "HEADER ERROR - UNABLE TO DESERIALIZE, peer=%d\n", m_node_id); + return -1; + } + + // Check start string, network magic + if (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) != 0) { + LogPrint(BCLog::NET, "HEADER ERROR - MESSAGESTART (%s, %u bytes), received %s, peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, HexStr(hdr.pchMessageStart), m_node_id); return -1; } // reject messages larger than MAX_SIZE or MAX_PROTOCOL_MESSAGE_LENGTH if (hdr.nMessageSize > MAX_SIZE || hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) { + LogPrint(BCLog::NET, "HEADER ERROR - SIZE (%s, %u bytes), peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, m_node_id); return -1; } @@ -701,10 +711,6 @@ Optional V1TransportDeserializer::GetMessage(const std::chrono::mic // decompose a single CNetMessage from the TransportDeserializer Optional msg(std::move(vRecv)); - // store state about valid header, netmagic and checksum - msg->m_valid_header = hdr.IsValid(m_chain_params.MessageStart()); - msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) == 0); - // store command string, time, and sizes msg->m_command = hdr.GetCommand(); msg->m_time = time; @@ -716,6 +722,7 @@ Optional V1TransportDeserializer::GetMessage(const std::chrono::mic // We just received a message off the wire, harvest entropy from the time (and the message checksum) RandAddEvent(ReadLE32(hash.begin())); + // Check checksum and header command string if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) { LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n", SanitizeString(msg->m_command), msg->m_message_size, @@ -724,6 +731,11 @@ Optional V1TransportDeserializer::GetMessage(const std::chrono::mic m_node_id); out_err_raw_size = msg->m_raw_message_size; msg = nullopt; + } else if (!hdr.IsCommandValid()) { + LogPrint(BCLog::NET, "HEADER ERROR - COMMAND (%s, %u bytes), peer=%d\n", + hdr.GetCommand(), msg->m_message_size, m_node_id); + out_err_raw_size = msg->m_raw_message_size; + msg = nullopt; } // Always reset the network deserializer (prepare for the next message) diff --git a/src/net.h b/src/net.h index 29941b9622b..9a92f805119 100644 --- a/src/net.h +++ b/src/net.h @@ -706,10 +706,8 @@ class CNetMessage { public: CDataStream m_recv; //!< received message data std::chrono::microseconds m_time{0}; //!< time of message receipt - bool m_valid_netmagic = false; - bool m_valid_header = false; - uint32_t m_message_size{0}; //!< size of the payload - uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) + uint32_t m_message_size{0}; //!< size of the payload + uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) std::string m_command; CNetMessage(CDataStream&& recv_in) : m_recv(std::move(recv_in)) {} diff --git a/src/net_processing.cpp b/src/net_processing.cpp index d9d32cded66..920e7a1abf3 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -3820,14 +3820,6 @@ bool PeerManager::MaybeDiscourageAndDisconnect(CNode& pnode) bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic& interruptMsgProc) { - // - // Message format - // (4) message start - // (12) command - // (4) size - // (4) checksum - // (x) data - // bool fMoreWork = false; if (!pfrom->vRecvGetData.empty()) @@ -3868,19 +3860,6 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic& interruptMsgP CNetMessage& msg(msgs.front()); msg.SetVersion(pfrom->GetCommonVersion()); - // Check network magic - if (!msg.m_valid_netmagic) { - LogPrint(BCLog::NET, "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId()); - pfrom->fDisconnect = true; - return false; - } - - // Check header - if (!msg.m_valid_header) - { - LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId()); - return fMoreWork; - } const std::string& msg_type = msg.m_command; // Message size diff --git a/src/protocol.cpp b/src/protocol.cpp index 6b4de68ce98..84b6e96aee9 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -111,31 +111,20 @@ std::string CMessageHeader::GetCommand() const return std::string(pchCommand, pchCommand + strnlen(pchCommand, COMMAND_SIZE)); } -bool CMessageHeader::IsValid(const MessageStartChars& pchMessageStartIn) const +bool CMessageHeader::IsCommandValid() const { - // Check start string - if (memcmp(pchMessageStart, pchMessageStartIn, MESSAGE_START_SIZE) != 0) - return false; - // Check the command string for errors - for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; p1++) - { - if (*p1 == 0) - { + for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; ++p1) { + if (*p1 == 0) { // Must be all zeros after the first zero - for (; p1 < pchCommand + COMMAND_SIZE; p1++) - if (*p1 != 0) + for (; p1 < pchCommand + COMMAND_SIZE; ++p1) { + if (*p1 != 0) { return false; - } - else if (*p1 < ' ' || *p1 > 0x7E) + } + } + } else if (*p1 < ' ' || *p1 > 0x7E) { return false; - } - - // Message size - if (nMessageSize > MAX_SIZE) - { - LogPrintf("CMessageHeader::IsValid(): (%s, %u bytes) nMessageSize > MAX_SIZE\n", GetCommand(), nMessageSize); - return false; + } } return true; diff --git a/src/protocol.h b/src/protocol.h index 3bf0797ca40..9a44a1626c5 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -45,7 +45,7 @@ public: CMessageHeader(const MessageStartChars& pchMessageStartIn, const char* pszCommand, unsigned int nMessageSizeIn); std::string GetCommand() const; - bool IsValid(const MessageStartChars& messageStart) const; + bool IsCommandValid() const; SERIALIZE_METHODS(CMessageHeader, obj) { READWRITE(obj.pchMessageStart, obj.pchCommand, obj.nMessageSize, obj.pchChecksum); } diff --git a/src/test/fuzz/deserialize.cpp b/src/test/fuzz/deserialize.cpp index f87b7576a42..b799d3b43b8 100644 --- a/src/test/fuzz/deserialize.cpp +++ b/src/test/fuzz/deserialize.cpp @@ -189,10 +189,9 @@ void test_one_input(const std::vector& buffer) DeserializeFromFuzzingInput(buffer, s); AssertEqualAfterSerializeDeserialize(s); #elif MESSAGEHEADER_DESERIALIZE - const CMessageHeader::MessageStartChars pchMessageStart = {0x00, 0x00, 0x00, 0x00}; CMessageHeader mh; DeserializeFromFuzzingInput(buffer, mh); - (void)mh.IsValid(pchMessageStart); + (void)mh.IsCommandValid(); #elif ADDRESS_DESERIALIZE CAddress a; DeserializeFromFuzzingInput(buffer, a); diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp index 6252b8e91ba..7e216e16feb 100644 --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -39,12 +39,6 @@ void test_one_input(const std::vector& buffer) assert(result->m_raw_message_size <= buffer.size()); assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size); assert(result->m_time == m_time); - if (result->m_valid_header) { - assert(result->m_valid_netmagic); - } - if (!result->m_valid_netmagic) { - assert(!result->m_valid_header); - } } } } diff --git a/test/functional/p2p_invalid_messages.py b/test/functional/p2p_invalid_messages.py index fe57057a83a..78a9d2e8523 100755 --- a/test/functional/p2p_invalid_messages.py +++ b/test/functional/p2p_invalid_messages.py @@ -81,7 +81,7 @@ class InvalidMessagesTest(BitcoinTestFramework): def test_magic_bytes(self): self.log.info("Test message with invalid magic bytes disconnects peer") conn = self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['PROCESSMESSAGE: INVALID MESSAGESTART badmsg']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - MESSAGESTART (badmsg, 2 bytes), received ffffffff']): msg = conn.build_message(msg_unrecognized(str_data="d")) # modify magic bytes msg = b'\xff' * 4 + msg[4:] @@ -105,7 +105,7 @@ class InvalidMessagesTest(BitcoinTestFramework): def test_size(self): self.log.info("Test message with oversized payload disconnects peer") conn = self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - SIZE (badmsg, 4000001 bytes)']): msg = msg_unrecognized(str_data="d" * (VALID_DATA_LIMIT + 1)) msg = conn.build_message(msg) conn.send_raw_message(msg) @@ -115,9 +115,8 @@ class InvalidMessagesTest(BitcoinTestFramework): def test_msgtype(self): self.log.info("Test message with invalid message type logs an error") conn = self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['PROCESSMESSAGE: ERRORS IN HEADER']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - COMMAND']): msg = msg_unrecognized(str_data="d") - msg.msgtype = b'\xff' * 12 msg = conn.build_message(msg) # Modify msgtype msg = msg[:7] + b'\x00' + msg[7 + 1:]