sheave_core/readers/
message_header.rs

use std::{
    future::Future,
    io::{
        Error as IOError,
        ErrorKind,
        Result as IOResult
    },
    pin::Pin,
    task::{
        Context as FutureContext,
        Poll
    },
    time::Duration
};
use futures::ready;
use tokio::io::{
    AsyncRead,
    ReadBuf
};
use crate::messages::headers::{
    MessageHeader,
    New,
    SameSource,
    TimerChange,
    MessageFormat,
    MessageType
};
24
25#[doc(hidden)]
26#[derive(Debug)]
27pub struct MessageHeaderReader<'a, R: AsyncRead> {
28    reader: Pin<&'a mut R>,
29    message_format: MessageFormat
30}
31
32#[doc(hidden)]
33impl<R: AsyncRead> MessageHeaderReader<'_, R> {
34    fn read_timestamp(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<Duration>> {
35        let mut timestamp_bytes: [u8; 4] = [0; 4];
36        let mut buf = ReadBuf::new(&mut timestamp_bytes[1..]);
37        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
38        Poll::Ready(Ok(Duration::from_millis(u32::from_be_bytes(timestamp_bytes) as u64)))
39    }
40
41    fn read_message_length(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<u32>> {
42        let mut message_length_bytes: [u8; 4] = [0; 4];
43        let mut buf = ReadBuf::new(&mut message_length_bytes[1..]);
44        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
45        Poll::Ready(Ok(u32::from_be_bytes(message_length_bytes)))
46    }
47
48    fn read_message_type(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<MessageType>> {
49        let mut message_type_byte: [u8; 1] = [0; 1];
50        let mut buf = ReadBuf::new(&mut message_type_byte);
51        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
52        Poll::Ready(Ok(u8::from_be_bytes(message_type_byte).into()))
53    }
54
55    fn read_message_id(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<u32>> {
56        let mut message_id_bytes: [u8; 4] = [0; 4];
57        let mut buf = ReadBuf::new(&mut message_id_bytes);
58        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
59        Poll::Ready(Ok(u32::from_le_bytes(message_id_bytes)))
60    }
61
62    fn read_new(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<New>> {
63        let timestamp = ready!(self.read_timestamp(cx))?;
64        let message_length = ready!(self.read_message_length(cx))?;
65        let message_type = ready!(self.read_message_type(cx))?;
66        let message_id = ready!(self.read_message_id(cx))?;
67        Poll::Ready(Ok((timestamp, message_length, message_type, message_id).into()))
68    }
69
70    fn read_same_source(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<SameSource>> {
71        let timestamp = ready!(self.read_timestamp(cx))?;
72        let message_length = ready!(self.read_message_length(cx))?;
73        let message_type = ready!(self.read_message_type(cx))?;
74        Poll::Ready(Ok((timestamp, message_length, message_type).into()))
75    }
76
77    fn read_timer_change(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<TimerChange>> {
78        let timestamp = ready!(self.read_timestamp(cx))?;
79        Poll::Ready(Ok(timestamp.into()))
80    }
81}
82
83#[doc(hidden)]
84impl<R: AsyncRead> Future for MessageHeaderReader<'_, R> {
85    type Output = IOResult<MessageHeader>;
86
87    fn poll(mut self: Pin<&mut Self>, cx: &mut FutureContext<'_>) -> Poll<Self::Output> {
88        match self.message_format {
89            MessageFormat::New => self.read_new(cx).map_ok(MessageHeader::New),
90            MessageFormat::SameSource => self.read_same_source(cx).map_ok(MessageHeader::SameSource),
91            MessageFormat::TimerChange => self.read_timer_change(cx).map_ok(MessageHeader::TimerChange),
92            MessageFormat::Continue => Poll::Ready(Ok(MessageHeader::Continue))
93        }
94    }
95}
96
97/// Reads a message header from streams.
98///
99/// # Examples
100///
101/// ```rust
102/// use std::{
103///     io::Result as IOResult,
104///     pin::pin,
105///     time::Duration
106/// };
107/// use rand::random;
108/// use sheave_core::{
109///     messages::headers::{
110///         MessageHeader,
111///         MessageFormat::*,
112///         MessageType
113///     },
114///     readers::read_message_header
115/// };
116///
117/// #[tokio::main]
118/// async fn main() -> IOResult<()> {
119///     // In case of 11 bytes.
120///     let mut reader: [u8; 11] = [0; 11];
121///     let timestamp = random::<u32>() << 8 >> 8;
122///     let message_length = random::<u32>() << 8 >> 8;
123///     let message_type = random::<u8>();
124///     let message_id = random::<u32>();
125///     reader[..3].copy_from_slice(&timestamp.to_be_bytes()[1..]);
126///     reader[3..6].copy_from_slice(&message_length.to_be_bytes()[1..]);
127///     reader[6] = message_type;
128///     reader[7..].copy_from_slice(&message_id.to_le_bytes());
129///     let result = read_message_header(pin!(reader.as_slice()), New).await?;
130///     assert_eq!(Duration::from_millis(timestamp as u64), result.get_timestamp().unwrap());
131///     assert_eq!(message_length, result.get_message_length().unwrap());
132///     assert_eq!(MessageType::from(message_type), result.get_message_type().unwrap());
133///     assert_eq!(message_id, result.get_message_id().unwrap());
134///
135///     // In case of 7 bytes.
136///     let mut reader: [u8; 7] = [0; 7];
137///     let timestamp = random::<u32>() << 8 >> 8;
138///     let message_length = random::<u32>() << 8 >> 8;
139///     let message_type = random::<u8>();
140///     reader[..3].copy_from_slice(&timestamp.to_be_bytes()[1..]);
141///     reader[3..6].copy_from_slice(&message_length.to_be_bytes()[1..]);
142///     reader[6] = message_type;
143///     let result = read_message_header(pin!(reader.as_slice()), SameSource).await?;
144///     assert_eq!(Duration::from_millis(timestamp as u64), result.get_timestamp().unwrap());
145///     assert_eq!(message_length, result.get_message_length().unwrap());
146///     assert_eq!(MessageType::from(message_type), result.get_message_type().unwrap());
147///
148///     // In case of 3 bytes.
149///     let mut reader: [u8; 3] = [0; 3];
150///     let timestamp = random::<u32>() << 8 >> 8;
151///     reader.copy_from_slice(&timestamp.to_be_bytes()[1..]);
152///     let result = read_message_header(pin!(reader.as_slice()), TimerChange).await?;
153///     assert_eq!(Duration::from_millis(timestamp as u64), result.get_timestamp().unwrap());
154///
155///     // In case of 0 bytes. (Continue)
156///     let mut reader: [u8; 0] = [0; 0];
157///     let result = read_message_header(pin!(reader.as_slice()), Continue).await?;
158///     assert!(result.get_timestamp().is_none());
159///     Ok(())
160/// }
161/// ```
162///
163/// If message format is 3 (Continue), you are unnecessary to read message header because it has no data.
164pub fn read_message_header<R: AsyncRead>(reader: Pin<&mut R>, message_format: MessageFormat) -> MessageHeaderReader<'_, R> {
165    MessageHeaderReader { reader, message_format }
166}
167
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use rand::random;
    use crate::messages::headers::MessageFormat::*;
    use super::*;

    /// Random value restricted to the low 24 bits, as carried by the 3-byte header fields.
    fn random_u24() -> u32 {
        random::<u32>() & 0x00ff_ffff
    }

    /// Lower three bytes of `value` in big-endian order.
    fn u24_be(value: u32) -> [u8; 3] {
        let [_, high, middle, low] = value.to_be_bytes();
        [high, middle, low]
    }

    #[tokio::test]
    async fn read_new() {
        let (timestamp, message_length) = (random_u24(), random_u24());
        let message_type: u8 = random();
        let message_id: u32 = random();

        let mut bytes = [0u8; 11];
        bytes[..3].copy_from_slice(&u24_be(timestamp));
        bytes[3..6].copy_from_slice(&u24_be(message_length));
        bytes[6] = message_type;
        bytes[7..].copy_from_slice(&message_id.to_le_bytes());

        let header = read_message_header(pin!(bytes.as_slice()), New).await.unwrap();
        assert_eq!(Some(Duration::from_millis(u64::from(timestamp))), header.get_timestamp());
        assert_eq!(Some(message_length), header.get_message_length());
        assert_eq!(Some(MessageType::from(message_type)), header.get_message_type());
        assert_eq!(Some(message_id), header.get_message_id())
    }

    #[tokio::test]
    async fn read_same_source() {
        let (timestamp, message_length) = (random_u24(), random_u24());
        let message_type: u8 = random();

        let mut bytes = [0u8; 7];
        bytes[..3].copy_from_slice(&u24_be(timestamp));
        bytes[3..6].copy_from_slice(&u24_be(message_length));
        bytes[6] = message_type;

        let header = read_message_header(pin!(bytes.as_slice()), SameSource).await.unwrap();
        assert_eq!(Some(Duration::from_millis(u64::from(timestamp))), header.get_timestamp());
        assert_eq!(Some(message_length), header.get_message_length());
        assert_eq!(Some(MessageType::from(message_type)), header.get_message_type())
    }

    #[tokio::test]
    async fn read_timer_change() {
        let timestamp = random_u24();
        let bytes = u24_be(timestamp);

        let header = read_message_header(pin!(bytes.as_slice()), TimerChange).await.unwrap();
        assert_eq!(Some(Duration::from_millis(u64::from(timestamp))), header.get_timestamp())
    }

    #[tokio::test]
    async fn read_continue() {
        let empty: [u8; 0] = [];
        let header = read_message_header(pin!(empty.as_slice()), Continue).await.unwrap();
        assert_eq!(None, header.get_timestamp())
    }
}