diff --git a/src/network/message.rs b/src/network/message.rs index 2486a25a60..f7f1909a57 100644 --- a/src/network/message.rs +++ b/src/network/message.rs @@ -64,6 +64,20 @@ impl Decodable for CommandString { } } +#[derive(Debug)] +/// Struct used to configure stream reader function +pub struct StreamReaderConfig { + /// Number of attempts to read data from the stream if the reader returns 0 bytes + pub iterations: usize, + /// Size of allocated buffer for a single read opetaion + pub buffer_size: usize +} + +/// Defining default values +impl Default for StreamReaderConfig { + fn default() -> Self { Self { iterations: 16, buffer_size: 64 * 1024 } } +} + #[derive(Debug)] /// A Network message pub struct RawNetworkMessage { @@ -150,28 +164,34 @@ impl RawNetworkMessage { /// Reads stream from a TCP socket and parses first message from it, returing /// the rest of the unparsed buffer for later usage. - pub fn from_stream(stream: &mut Read, remaining_part: &mut Vec) -> Result { - let mut max_iterations = 16; - while max_iterations > 0 { - max_iterations -= 1; + pub fn from_stream(stream: &mut Read, remaining_part: &mut Vec, + StreamReaderConfig { iterations, buffer_size }: StreamReaderConfig) -> Result { + println!("Called with {} iterations and {} ubffer size", iterations, buffer_size); + let mut iterations = iterations; + while iterations > 0 { + iterations -= 1; if remaining_part.len() > 0 { match encode::deserialize_partial::(&remaining_part) { // In this case we just have an incomplete data, so we need to read more Err(encode::Error::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof => (), + // All other types of errors should be passed up to the caller Err(err) => return Err(err), + // We have successfully read from the buffer Ok((message, index)) => { + println!("Deserialized {} bytes", index); remaining_part.drain(..index); return Ok(message) }, } } - let mut new_data = vec![0u8; 1024]; + let mut new_data = vec![0u8; buffer_size]; let count = stream.read(&mut new_data)?; if count > 0 { remaining_part.extend(new_data[0..count].iter()); } + println!("Read {} bytes, remaining part now is {} bytes length", count, remaining_part.len()); } Err(encode::Error::ParseFailed("Zero-length input")) } @@ -394,7 +414,7 @@ mod test { tmpfile.seek(SeekFrom::Start(0)).unwrap(); let mut buffer = vec![]; - let msg = RawNetworkMessage::from_stream(&mut tmpfile, &mut buffer); + let msg = RawNetworkMessage::from_stream(&mut tmpfile, &mut buffer, Default::default()); assert!(buffer.len() > 0); assert!(msg.is_err()); } @@ -430,7 +450,7 @@ mod test { tmpfile.seek(SeekFrom::Start(0)).unwrap(); let mut buffer = vec![]; - let msg = RawNetworkMessage::from_stream(&mut tmpfile, &mut buffer).unwrap(); + let msg = RawNetworkMessage::from_stream(&mut tmpfile, &mut buffer, Default::default()).unwrap(); assert!(buffer.len() > 0); assert_eq!(msg.magic, 0xd9b4bef9); if let NetworkMessage::Version(version_msg) = msg.payload { @@ -446,7 +466,7 @@ mod test { } println!("{:?}", &buffer); - let msg = RawNetworkMessage::from_stream(&mut tmpfile, &mut buffer).unwrap(); + let msg = RawNetworkMessage::from_stream(&mut tmpfile, &mut buffer,Default::default()).unwrap(); assert_eq!(buffer.len(), 0); assert_eq!(msg.magic, 0xd9b4bef9); if let NetworkMessage::Ping(nonce) = msg.payload {