sheave_core/readers/message_header.rs

1use std::{
2    future::Future,
3    io::Result as IOResult,
4    pin::Pin,
5    task::{
6        Context as FutureContext,
7        Poll
8    },
9    time::Duration
10};
11use futures::ready;
12use tokio::io::{
13    AsyncRead,
14    ReadBuf
15};
16use crate::messages::headers::{
17    MessageHeader,
18    MessageFormat,
19    MessageType
20};
21
22#[doc(hidden)]
23#[derive(Debug)]
24pub struct MessageHeaderReader<'a, R: AsyncRead> {
25    reader: Pin<&'a mut R>,
26    message_format: MessageFormat
27}
28
29#[doc(hidden)]
30impl<R: AsyncRead> MessageHeaderReader<'_, R> {
31    fn read_timestamp(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<Duration>> {
32        let mut timestamp_bytes: [u8; 4] = [0; 4];
33        let mut buf = ReadBuf::new(&mut timestamp_bytes[1..]);
34        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
35        Poll::Ready(Ok(Duration::from_millis(u32::from_be_bytes(timestamp_bytes) as u64)))
36    }
37
38    fn read_message_length(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<u32>> {
39        let mut message_length_bytes: [u8; 4] = [0; 4];
40        let mut buf = ReadBuf::new(&mut message_length_bytes[1..]);
41        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
42        Poll::Ready(Ok(u32::from_be_bytes(message_length_bytes)))
43    }
44
45    fn read_message_type(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<MessageType>> {
46        let mut message_type_byte: [u8; 1] = [0; 1];
47        let mut buf = ReadBuf::new(&mut message_type_byte);
48        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
49        Poll::Ready(Ok(u8::from_be_bytes(message_type_byte).into()))
50    }
51
52    fn read_message_id(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<u32>> {
53        let mut message_id_bytes: [u8; 4] = [0; 4];
54        let mut buf = ReadBuf::new(&mut message_id_bytes);
55        ready!(self.reader.as_mut().poll_read(cx, &mut buf))?;
56        Poll::Ready(Ok(u32::from_le_bytes(message_id_bytes)))
57    }
58
59    fn read_new(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<MessageHeader>> {
60        let timestamp = ready!(self.read_timestamp(cx))?;
61        let message_length = ready!(self.read_message_length(cx))?;
62        let message_type = ready!(self.read_message_type(cx))?;
63        let message_id = ready!(self.read_message_id(cx))?;
64        Poll::Ready(Ok((timestamp, message_length, message_type, message_id).into()))
65    }
66
67    fn read_same_source(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<MessageHeader>> {
68        let timestamp = ready!(self.read_timestamp(cx))?;
69        let message_length = ready!(self.read_message_length(cx))?;
70        let message_type = ready!(self.read_message_type(cx))?;
71        Poll::Ready(Ok((timestamp, message_length, message_type).into()))
72    }
73
74    fn read_timer_change(&mut self, cx: &mut FutureContext<'_>) -> Poll<IOResult<MessageHeader>> {
75        let timestamp = ready!(self.read_timestamp(cx))?;
76        Poll::Ready(Ok(timestamp.into()))
77    }
78
79    fn read_continue(&mut self, _cx: &mut FutureContext<'_>) -> Poll<IOResult<MessageHeader>> {
80        Poll::Ready(Ok(().into()))
81    }
82}
83
84#[doc(hidden)]
85impl<R: AsyncRead> Future for MessageHeaderReader<'_, R> {
86    type Output = IOResult<MessageHeader>;
87
88    fn poll(mut self: Pin<&mut Self>, cx: &mut FutureContext<'_>) -> Poll<Self::Output> {
89        match self.message_format {
90            MessageFormat::New => self.read_new(cx),
91            MessageFormat::SameSource => self.read_same_source(cx),
92            MessageFormat::TimerChange => self.read_timer_change(cx),
93            MessageFormat::Continue => self.read_continue(cx)
94        }
95    }
96}
97
98/// Reads a message header from streams.
99///
100/// # Examples
101///
102/// ```rust
103/// use std::{
104///     io::Result as IOResult,
105///     pin::pin,
106///     time::Duration
107/// };
108/// use rand::random;
109/// use sheave_core::{
110///     messages::headers::{
111///         MessageHeader,
112///         MessageFormat::*,
113///         MessageType
114///     },
115///     readers::read_message_header
116/// };
117///
118/// #[tokio::main]
119/// async fn main() -> IOResult<()> {
120///     // In case of 11 bytes.
121///     let mut reader: [u8; 11] = [0; 11];
122///     let timestamp = random::<u32>() << 8 >> 8;
123///     let message_length = random::<u32>() << 8 >> 8;
124///     let message_type = random::<u8>();
125///     let message_id = random::<u32>();
126///     reader[..3].copy_from_slice(&timestamp.to_be_bytes()[1..]);
127///     reader[3..6].copy_from_slice(&message_length.to_be_bytes()[1..]);
128///     reader[6] = message_type;
129///     reader[7..].copy_from_slice(&message_id.to_le_bytes());
130///     let result = read_message_header(pin!(reader.as_slice()), New).await?;
131///     assert_eq!(Duration::from_millis(timestamp as u64), result.get_timestamp().unwrap());
132///     assert_eq!(message_length, result.get_message_length().unwrap());
133///     assert_eq!(MessageType::from(message_type), result.get_message_type().unwrap());
134///     assert_eq!(message_id, result.get_message_id().unwrap());
135///
136///     // In case of 7 bytes.
137///     let mut reader: [u8; 7] = [0; 7];
138///     let timestamp = random::<u32>() << 8 >> 8;
139///     let message_length = random::<u32>() << 8 >> 8;
140///     let message_type = random::<u8>();
141///     reader[..3].copy_from_slice(&timestamp.to_be_bytes()[1..]);
142///     reader[3..6].copy_from_slice(&message_length.to_be_bytes()[1..]);
143///     reader[6] = message_type;
144///     let result = read_message_header(pin!(reader.as_slice()), SameSource).await?;
145///     assert_eq!(Duration::from_millis(timestamp as u64), result.get_timestamp().unwrap());
146///     assert_eq!(message_length, result.get_message_length().unwrap());
147///     assert_eq!(MessageType::from(message_type), result.get_message_type().unwrap());
148///
149///     // In case of 3 bytes.
150///     let mut reader: [u8; 3] = [0; 3];
151///     let timestamp = random::<u32>() << 8 >> 8;
152///     reader.copy_from_slice(&timestamp.to_be_bytes()[1..]);
153///     let result = read_message_header(pin!(reader.as_slice()), TimerChange).await?;
154///     assert_eq!(Duration::from_millis(timestamp as u64), result.get_timestamp().unwrap());
155///
156///     // In case of 0 bytes. (Continue)
157///     let mut reader: [u8; 0] = [0; 0];
158///     let result = read_message_header(pin!(reader.as_slice()), Continue).await?;
159///     assert!(result.get_timestamp().is_none());
160///     Ok(())
161/// }
162/// ```
163pub fn read_message_header<R: AsyncRead>(reader: Pin<&mut R>, message_format: MessageFormat) -> MessageHeaderReader<'_, R> {
164    MessageHeaderReader { reader, message_format }
165}
166
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use rand::random;
    use crate::messages::headers::MessageFormat::*;
    use super::*;

    /// Returns a random value that fits in a 3-byte header field (the low 24 bits only).
    fn random_u24() -> u32 {
        random::<u32>() & 0x00ff_ffff
    }

    #[tokio::test]
    async fn read_new() {
        let timestamp = random_u24();
        let message_length = random_u24();
        let message_type = random::<u8>();
        let message_id = random::<u32>();
        let bytes = [
            &timestamp.to_be_bytes()[1..],
            &message_length.to_be_bytes()[1..],
            &[message_type][..],
            &message_id.to_le_bytes()[..]
        ].concat();

        let header = read_message_header(pin!(bytes.as_slice()), New).await.unwrap();
        assert_eq!(Some(Duration::from_millis(u64::from(timestamp))), header.get_timestamp());
        assert_eq!(Some(message_length), header.get_message_length());
        assert_eq!(Some(MessageType::from(message_type)), header.get_message_type());
        assert_eq!(Some(message_id), header.get_message_id())
    }

    #[tokio::test]
    async fn read_same_source() {
        let timestamp = random_u24();
        let message_length = random_u24();
        let message_type = random::<u8>();
        let bytes = [
            &timestamp.to_be_bytes()[1..],
            &message_length.to_be_bytes()[1..],
            &[message_type][..]
        ].concat();

        let header = read_message_header(pin!(bytes.as_slice()), SameSource).await.unwrap();
        assert_eq!(Some(Duration::from_millis(u64::from(timestamp))), header.get_timestamp());
        assert_eq!(Some(message_length), header.get_message_length());
        assert_eq!(Some(MessageType::from(message_type)), header.get_message_type())
    }

    #[tokio::test]
    async fn read_timer_change() {
        let timestamp = random_u24();
        let bytes = timestamp.to_be_bytes();

        let header = read_message_header(pin!(&bytes[1..]), TimerChange).await.unwrap();
        assert_eq!(Some(Duration::from_millis(u64::from(timestamp))), header.get_timestamp())
    }

    #[tokio::test]
    async fn read_continue() {
        let empty: [u8; 0] = [];
        let header = read_message_header(pin!(empty.as_slice()), Continue).await.unwrap();
        assert!(header.get_timestamp().is_none())
    }
}
236}