0
0
Fork 0
mirror of https://github.com/bitcoin/bitcoin.git synced 2025-02-10 10:52:31 -05:00

Merge #20056: net: Use Span in ReceiveMsgBytes

fa5ed3b4ca net: Use Span in ReceiveMsgBytes (MarcoFalke)

Pull request description:

  Pass a data pointer and a size as span in `ReceiveMsgBytes` to get the benefits of a span

ACKs for top commit:
  jonatack:
    ACK fa5ed3b4ca code review, rebased to current master 12a1c3ad1a, debug build, unit tests, ran bitcoind/-netinfo/getpeerinfo
  theStack:
    ACK fa5ed3b4ca

Tree-SHA512: 89bf111323148d6e6e50185ad20ab39f73ab3a58a27e46319e3a08bcf5dcf9d6aa84faff0fd6afb90cb892ac2f557a237c144560986063bc736a69ace353ab9d
This commit is contained in:
Wladimir J. van der Laan 2020-11-20 06:05:42 +01:00
commit fdd068507d
No known key found for this signature in database
GPG key ID: 1E4AED62986CD25D
6 changed files with 43 additions and 45 deletions

View file

@ -629,34 +629,21 @@ void CNode::copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap)
} }
#undef X #undef X
/** bool CNode::ReceiveMsgBytes(Span<const char> msg_bytes, bool& complete)
* Receive bytes from the buffer and deserialize them into messages.
*
* @param[in] pch A pointer to the raw data
* @param[in] nBytes Size of the data
* @param[out] complete Set True if at least one message has been
* deserialized and is ready to be processed
* @return True if the peer should stay connected,
* False if the peer should be disconnected from.
*/
bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete)
{ {
complete = false; complete = false;
const auto time = GetTime<std::chrono::microseconds>(); const auto time = GetTime<std::chrono::microseconds>();
LOCK(cs_vRecv); LOCK(cs_vRecv);
nLastRecv = std::chrono::duration_cast<std::chrono::seconds>(time).count(); nLastRecv = std::chrono::duration_cast<std::chrono::seconds>(time).count();
nRecvBytes += nBytes; nRecvBytes += msg_bytes.size();
while (nBytes > 0) { while (msg_bytes.size() > 0) {
// absorb network data // absorb network data
int handled = m_deserializer->Read(pch, nBytes); int handled = m_deserializer->Read(msg_bytes);
if (handled < 0) { if (handled < 0) {
// Serious header problem, disconnect from the peer. // Serious header problem, disconnect from the peer.
return false; return false;
} }
pch += handled;
nBytes -= handled;
if (m_deserializer->Complete()) { if (m_deserializer->Complete()) {
// decompose a transport agnostic CNetMessage from the deserializer // decompose a transport agnostic CNetMessage from the deserializer
uint32_t out_err_raw_size{0}; uint32_t out_err_raw_size{0};
@ -686,13 +673,13 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
return true; return true;
} }
int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes) int V1TransportDeserializer::readHeader(Span<const char> msg_bytes)
{ {
// copy data to temporary parsing buffer // copy data to temporary parsing buffer
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos; unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
unsigned int nCopy = std::min(nRemaining, nBytes); unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
memcpy(&hdrbuf[nHdrPos], pch, nCopy); memcpy(&hdrbuf[nHdrPos], msg_bytes.data(), nCopy);
nHdrPos += nCopy; nHdrPos += nCopy;
// if header incomplete, exit // if header incomplete, exit
@ -726,18 +713,18 @@ int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes)
return nCopy; return nCopy;
} }
int V1TransportDeserializer::readData(const char *pch, unsigned int nBytes) int V1TransportDeserializer::readData(Span<const char> msg_bytes)
{ {
unsigned int nRemaining = hdr.nMessageSize - nDataPos; unsigned int nRemaining = hdr.nMessageSize - nDataPos;
unsigned int nCopy = std::min(nRemaining, nBytes); unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
if (vRecv.size() < nDataPos + nCopy) { if (vRecv.size() < nDataPos + nCopy) {
// Allocate up to 256 KiB ahead, but never more than the total message size. // Allocate up to 256 KiB ahead, but never more than the total message size.
vRecv.resize(std::min(hdr.nMessageSize, nDataPos + nCopy + 256 * 1024)); vRecv.resize(std::min(hdr.nMessageSize, nDataPos + nCopy + 256 * 1024));
} }
hasher.Write({(const unsigned char*)pch, nCopy}); hasher.Write(MakeUCharSpan(msg_bytes.first(nCopy)));
memcpy(&vRecv[nDataPos], pch, nCopy); memcpy(&vRecv[nDataPos], msg_bytes.data(), nCopy);
nDataPos += nCopy; nDataPos += nCopy;
return nCopy; return nCopy;
@ -1487,7 +1474,7 @@ void CConnman::SocketHandler()
if (nBytes > 0) if (nBytes > 0)
{ {
bool notify = false; bool notify = false;
if (!pnode->ReceiveMsgBytes(pchBuf, nBytes, notify)) if (!pnode->ReceiveMsgBytes(Span<const char>(pchBuf, nBytes), notify))
pnode->CloseSocketDisconnect(); pnode->CloseSocketDisconnect();
RecordBytesRecv(nBytes); RecordBytesRecv(nBytes);
if (notify) { if (notify) {

View file

@ -757,8 +757,8 @@ public:
virtual bool Complete() const = 0; virtual bool Complete() const = 0;
// set the serialization context version // set the serialization context version
virtual void SetVersion(int version) = 0; virtual void SetVersion(int version) = 0;
// read and deserialize data /** read and deserialize data, advances msg_bytes data pointer */
virtual int Read(const char *data, unsigned int bytes) = 0; virtual int Read(Span<const char>& msg_bytes) = 0;
// decomposes a message from the context // decomposes a message from the context
virtual Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0; virtual Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0;
virtual ~TransportDeserializer() {} virtual ~TransportDeserializer() {}
@ -779,8 +779,8 @@ private:
unsigned int nDataPos; unsigned int nDataPos;
const uint256& GetMessageHash() const; const uint256& GetMessageHash() const;
int readHeader(const char *pch, unsigned int nBytes); int readHeader(Span<const char> msg_bytes);
int readData(const char *pch, unsigned int nBytes); int readData(Span<const char> msg_bytes);
void Reset() { void Reset() {
vRecv.clear(); vRecv.clear();
@ -814,9 +814,14 @@ public:
hdrbuf.SetVersion(nVersionIn); hdrbuf.SetVersion(nVersionIn);
vRecv.SetVersion(nVersionIn); vRecv.SetVersion(nVersionIn);
} }
int Read(const char *pch, unsigned int nBytes) override { int Read(Span<const char>& msg_bytes) override
int ret = in_data ? readData(pch, nBytes) : readHeader(pch, nBytes); {
if (ret < 0) Reset(); int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes);
if (ret < 0) {
Reset();
} else {
msg_bytes = msg_bytes.subspan(ret);
}
return ret; return ret;
} }
Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err_raw_size) override; Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err_raw_size) override;
@ -1118,7 +1123,16 @@ public:
return nRefCount; return nRefCount;
} }
bool ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete); /**
* Receive bytes from the buffer and deserialize them into messages.
*
* @param[in] msg_bytes The raw data
* @param[out] complete Set True if at least one message has been
* deserialized and is ready to be processed
* @return True if the peer should stay connected,
* False if the peer should be disconnected from.
*/
bool ReceiveMsgBytes(Span<const char> msg_bytes, bool& complete);
void SetCommonVersion(int greatest_common_version) void SetCommonVersion(int greatest_common_version)
{ {

View file

@ -128,7 +128,7 @@ void test_one_input(const std::vector<uint8_t>& buffer)
case 11: { case 11: {
const std::vector<uint8_t> b = ConsumeRandomLengthByteVector(fuzzed_data_provider); const std::vector<uint8_t> b = ConsumeRandomLengthByteVector(fuzzed_data_provider);
bool complete; bool complete;
node.ReceiveMsgBytes((const char*)b.data(), b.size(), complete); node.ReceiveMsgBytes({(const char*)b.data(), b.size()}, complete);
break; break;
} }
} }

View file

@ -21,15 +21,12 @@ void test_one_input(const std::vector<uint8_t>& buffer)
{ {
// Construct deserializer, with a dummy NodeId // Construct deserializer, with a dummy NodeId
V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION};
const char* pch = (const char*)buffer.data(); Span<const char> msg_bytes{(const char*)buffer.data(), buffer.size()};
size_t n_bytes = buffer.size(); while (msg_bytes.size() > 0) {
while (n_bytes > 0) { const int handled = deserializer.Read(msg_bytes);
const int handled = deserializer.Read(pch, n_bytes);
if (handled < 0) { if (handled < 0) {
break; break;
} }
pch += handled;
n_bytes -= handled;
if (deserializer.Complete()) { if (deserializer.Complete()) {
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()}; const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
uint32_t out_err_raw_size{0}; uint32_t out_err_raw_size{0};

View file

@ -7,9 +7,9 @@
#include <chainparams.h> #include <chainparams.h>
#include <net.h> #include <net.h>
void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, const char* pch, unsigned int nBytes, bool& complete) const void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span<const char> msg_bytes, bool& complete) const
{ {
assert(node.ReceiveMsgBytes(pch, nBytes, complete)); assert(node.ReceiveMsgBytes(msg_bytes, complete));
if (complete) { if (complete) {
size_t nSizeAdded = 0; size_t nSizeAdded = 0;
auto it(node.vRecvMsg.begin()); auto it(node.vRecvMsg.begin());
@ -33,7 +33,7 @@ bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) con
node.m_serializer->prepareForTransport(ser_msg, ser_msg_header); node.m_serializer->prepareForTransport(ser_msg, ser_msg_header);
bool complete; bool complete;
NodeReceiveMsgBytes(node, (const char*)ser_msg_header.data(), ser_msg_header.size(), complete); NodeReceiveMsgBytes(node, {(const char*)ser_msg_header.data(), ser_msg_header.size()}, complete);
NodeReceiveMsgBytes(node, (const char*)ser_msg.data.data(), ser_msg.data.size(), complete); NodeReceiveMsgBytes(node, {(const char*)ser_msg.data.data(), ser_msg.data.size()}, complete);
return complete; return complete;
} }

View file

@ -25,7 +25,7 @@ struct ConnmanTestMsg : public CConnman {
void ProcessMessagesOnce(CNode& node) { m_msgproc->ProcessMessages(&node, flagInterruptMsgProc); } void ProcessMessagesOnce(CNode& node) { m_msgproc->ProcessMessages(&node, flagInterruptMsgProc); }
void NodeReceiveMsgBytes(CNode& node, const char* pch, unsigned int nBytes, bool& complete) const; void NodeReceiveMsgBytes(CNode& node, Span<const char> msg_bytes, bool& complete) const;
bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const; bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const;
}; };