sheave_core/readers/
basic_header.rs

1use std::{
2    future::Future,
3    io::Result as IOResult,
4    pin::Pin,
5    task::{
6        Context as FutureContext,
7        Poll
8    }
9};
10use futures::ready;
11use tokio::io::{
12    AsyncRead,
13    ReadBuf
14};
15use crate::messages::headers::{
16    BasicHeader,
17    MessageFormat
18};
19
20#[doc(hidden)]
21#[derive(Debug)]
22pub struct BasicHeaderReader<'a, R: AsyncRead> {
23    reader: Pin<&'a mut R>
24}
25
26#[doc(hidden)]
27impl<R: AsyncRead> Future for BasicHeaderReader<'_, R> {
28    type Output = IOResult<BasicHeader>;
29
30    fn poll(mut self: Pin<&mut Self>, cx: &mut FutureContext<'_>) -> Poll<Self::Output> {
31        let mut first_byte: [u8; 1] = [0];
32        let mut buf = ReadBuf::new(&mut first_byte);
33        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
34        let message_format: MessageFormat = (first_byte[0] >> 6).into();
35        let chunk_id = match first_byte[0] << 2 >> 2 {
36            1 => {
37                let mut chunk_id_bytes: [u8; 2] = [0; 2];
38                let mut buf = ReadBuf::new(&mut chunk_id_bytes);
39                ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
40                u16::from_le_bytes(chunk_id_bytes) + 64
41            },
42            0 => {
43                let mut chunk_id_bytes: [u8; 2] = [0; 2];
44                let mut buf = ReadBuf::new(&mut chunk_id_bytes[1..]);
45                ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
46                u16::from_be_bytes(chunk_id_bytes) + 64
47            },
48            chunk_id => chunk_id as u16
49        };
50        Poll::Ready(Ok(BasicHeader::new(message_format, chunk_id)))
51    }
52}
53
54/// Reads a basic header from streams.
55///
56/// # Examples
57///
58/// ```rust
59/// use std::{
60///     io::Result as IOResult,
61///     pin::pin
62/// };
63/// use sheave_core::{
64///     messages::headers::{
65///         BasicHeader,
66///         MessageFormat::*
67///     },
68///     readers::read_basic_header
69/// };
70///
71/// #[tokio::main]
72/// async fn main() -> IOResult<()> {
73///     // In case of 1 byte.
74///     let reader = ((New as u8) << 6 | 2).to_be_bytes();
75///     let result = read_basic_header(pin!(reader.as_slice())).await?;
76///     assert_eq!(New, result.get_message_format());
77///     assert_eq!(2, result.get_chunk_id());
78///
79///     // In case of 2 bytes.
80///     let mut reader: [u8; 2] = [0; 2];
81///     reader[0] = (New as u8) << 6;
82///     reader[1] = 2;
83///     let result = read_basic_header(pin!(reader.as_slice())).await?;
84///     assert_eq!(New, result.get_message_format());
85///     assert_eq!(66, result.get_chunk_id());
86///
87///     // In case of 3 bytes.
88///     let mut reader: [u8; 3] = [0; 3];
89///     reader[0] = (New as u8) << 6 | 1;
90///     reader[1..].copy_from_slice((2 as u16).to_le_bytes().as_slice());
91///     let result = read_basic_header(pin!(reader.as_slice())).await?;
92///     assert_eq!(New, result.get_message_format());
93///     assert_eq!(66, result.get_chunk_id());
94///     Ok(())
95/// }
96/// ```
97pub fn read_basic_header<R: AsyncRead>(reader: Pin<&mut R>) -> BasicHeaderReader<'_, R> {
98    BasicHeaderReader { reader }
99}
100
#[cfg(test)]
mod tests {
    use std::{
        cmp::max,
        pin::pin
    };
    use rand::random;
    use crate::messages::headers::MessageFormat;
    use super::*;

    #[tokio::test]
    async fn read_one_byte() {
        let message_format_bits = random::<u8>() & 0xc0;
        // 0 and 1 are escape values for the longer forms, so the inline id is 2..=63.
        let chunk_id_bits = max(2, random::<u8>() & 0x3f);
        let reader: [u8; 1] = [message_format_bits | chunk_id_bits];
        let result = read_basic_header(pin!(reader.as_slice())).await;
        assert!(result.is_ok());
        let basic_header = result.unwrap();
        assert_eq!(MessageFormat::from(message_format_bits >> 6), basic_header.get_message_format());
        assert_eq!(chunk_id_bits as u16, basic_header.get_chunk_id())
    }

    #[tokio::test]
    async fn read_two_bytes() {
        let message_format_bits = random::<u8>() & 0xc0;
        let is_two_bytes: u8 = 0;
        let chunk_id_byte = random::<u8>();
        let mut reader: [u8; 2] = [0; 2];
        reader[0] = message_format_bits | is_two_bytes;
        reader[1] = chunk_id_byte;
        let result = read_basic_header(pin!(reader.as_slice())).await;
        assert!(result.is_ok());
        let basic_header = result.unwrap();
        assert_eq!(MessageFormat::from(message_format_bits >> 6), basic_header.get_message_format());
        assert_eq!((chunk_id_byte as u16) + 64, basic_header.get_chunk_id())
    }

    #[tokio::test]
    async fn read_three_bytes() {
        let message_format_bits = random::<u8>() & 0xc0;
        let is_three_bytes: u8 = 1;
        // Bound the encoded value so that adding the 64 offset stays within `u16`;
        // an unbounded value overflows (and panics in debug) for the top 64 values.
        let chunk_id_bytes = random::<u16>() % (u16::MAX - 63);
        let mut reader: [u8; 3] = [0; 3];
        reader[0] = message_format_bits | is_three_bytes;
        reader[1..].copy_from_slice(chunk_id_bytes.to_le_bytes().as_slice());
        let result = read_basic_header(pin!(reader.as_slice())).await;
        assert!(result.is_ok());
        let basic_header = result.unwrap();
        assert_eq!(MessageFormat::from(message_format_bits >> 6), basic_header.get_message_format());
        assert_eq!(chunk_id_bytes + 64, basic_header.get_chunk_id())
    }
}