1use std::{
2 future::Future,
3 io::Result as IOResult,
4 pin::Pin,
5 task::{
6 Context as FutureContext,
7 Poll
8 },
9 time::Duration
10};
11use futures::ready;
12use tokio::io::AsyncWrite;
13use crate::{
14 U24_MAX,
15 messages::headers::{
16 MessageHeader,
17 MessageType
18 }
19};
20
21#[doc(hidden)]
22#[derive(Debug)]
23pub struct MessageHeaderWriter<'a, W: AsyncWrite> {
24 writer: Pin<&'a mut W>,
25 message_header: &'a MessageHeader
26}
27
28#[doc(hidden)]
29impl<W: AsyncWrite> MessageHeaderWriter<'_, W> {
30 fn write_timestamp(&mut self, cx: &mut FutureContext<'_>, timestamp: Duration) -> Poll<IOResult<()>> {
31 assert!(timestamp.as_millis() <= U24_MAX as u128);
32 self.writer.as_mut().poll_write(cx, &(timestamp.as_millis() as u32).to_be_bytes()[1..]).map_ok(|_| ())
33 }
34
35 fn write_message_length(&mut self, cx: &mut FutureContext<'_>, message_length: u32) -> Poll<IOResult<()>> {
36 assert!(message_length <= U24_MAX);
37 self.writer.as_mut().poll_write(cx, &message_length.to_be_bytes()[1..]).map_ok(|_| ())
38 }
39
40 fn write_message_type(&mut self, cx: &mut FutureContext<'_>, message_type: MessageType) -> Poll<IOResult<()>> {
41 self.writer.as_mut().poll_write(cx, &u8::from(message_type).to_be_bytes()).map_ok(|_| ())
42 }
43
44 fn write_message_id(&mut self, cx: &mut FutureContext<'_>, message_id: u32) -> Poll<IOResult<()>> {
45 self.writer.as_mut().poll_write(cx, &message_id.to_le_bytes()).map_ok(|_| ())
46 }
47
48 fn write_new(&mut self, cx: &mut FutureContext<'_>, (timestamp, message_length, message_type, message_id): (Duration, u32, MessageType, u32)) -> Poll<IOResult<()>> {
49 ready!(self.write_timestamp(cx, timestamp))?;
50 ready!(self.write_message_length(cx, message_length))?;
51 ready!(self.write_message_type(cx, message_type))?;
52 ready!(self.write_message_id(cx, message_id))?;
53 Poll::Ready(Ok(()))
54 }
55
56 fn write_same_source(&mut self, cx: &mut FutureContext<'_>, (timestamp, message_length, message_type): (Duration, u32, MessageType)) -> Poll<IOResult<()>> {
57 ready!(self.write_timestamp(cx, timestamp))?;
58 ready!(self.write_message_length(cx, message_length))?;
59 ready!(self.write_message_type(cx, message_type))?;
60 Poll::Ready(Ok(()))
61 }
62
63 fn write_timer_change(&mut self, cx: &mut FutureContext<'_>, timestamp: Duration) -> Poll<IOResult<()>> {
64 ready!(self.write_timestamp(cx, timestamp))?;
65 Poll::Ready(Ok(()))
66 }
67
68 fn write_continue(&mut self, _cx: &mut FutureContext<'_>) -> Poll<IOResult<()>> {
69 Poll::Ready(Ok(()))
70 }
71}
72
73#[doc(hidden)]
74impl<W: AsyncWrite> Future for MessageHeaderWriter<'_, W> {
75 type Output = IOResult<()>;
76
77 fn poll(mut self: Pin<&mut Self>, cx: &mut FutureContext<'_>) -> Poll<Self::Output> {
78 let fields: (Option<Duration>, Option<u32>, Option<MessageType>, Option<u32>) = (*self.message_header).into();
79
80 if fields.3.is_some() {
81 self.write_new(cx, (fields.0.unwrap(), fields.1.unwrap(), fields.2.unwrap(), fields.3.unwrap()))
82 } else if fields.2.is_some() && fields.1.is_some() {
83 self.write_same_source(cx, (fields.0.unwrap(), fields.1.unwrap(), fields.2.unwrap()))
84 } else if fields.0.is_some() {
85 self.write_timer_change(cx, fields.0.unwrap())
86 } else {
87 self.write_continue(cx)
88 }
89 }
90}
91
92pub fn write_message_header<'a, W: AsyncWrite>(writer: Pin<&'a mut W>, message_header: &'a MessageHeader) -> MessageHeaderWriter<'a, W> {
182 MessageHeaderWriter { writer, message_header }
183}
184
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use rand::random;
    use super::*;

    /// Draws a value uniformly from `0..=U24_MAX`.
    ///
    /// Clamping a random `u32` with `min(U24_MAX, …)` yields exactly `U24_MAX` for roughly 255 of
    /// every 256 draws, so the 24-bit encoders would almost never see any other value.
    fn random_u24() -> u32 {
        random::<u32>() % (U24_MAX + 1)
    }

    /// Decodes a 3-byte big-endian field as produced by the header writer.
    fn read_u24(bytes: &[u8]) -> u32 {
        let mut widened: [u8; 4] = [0; 4];
        widened[1..].copy_from_slice(bytes);
        u32::from_be_bytes(widened)
    }

    #[tokio::test]
    async fn write_new() {
        let mut writer: Pin<&mut Vec<u8>> = pin!(Vec::new());
        let timestamp = Duration::from_millis(u64::from(random_u24()));
        let message_length = random_u24();
        let message_type = random::<u8>();
        let message_id = random::<u32>();
        let message_header: MessageHeader = (timestamp, message_length, message_type.into(), message_id).into();
        write_message_header(writer.as_mut(), &message_header).await.expect("writing to a Vec cannot fail");
        // timestamp (3) + message length (3) + message type (1) + message id (4)
        assert_eq!(writer.len(), 11);
        assert_eq!(Duration::from_millis(u64::from(read_u24(&writer[..3]))), message_header.get_timestamp().unwrap());
        assert_eq!(read_u24(&writer[3..6]), message_header.get_message_length().unwrap());
        assert_eq!(MessageType::from(writer[6]), message_header.get_message_type().unwrap());
        let mut message_id_bytes: [u8; 4] = [0; 4];
        message_id_bytes.copy_from_slice(&writer[7..]);
        assert_eq!(u32::from_le_bytes(message_id_bytes), message_header.get_message_id().unwrap())
    }

    #[tokio::test]
    async fn write_same_source() {
        let mut writer: Pin<&mut Vec<u8>> = pin!(Vec::new());
        let timestamp = Duration::from_millis(u64::from(random_u24()));
        let message_length = random_u24();
        let message_type = random::<u8>();
        let message_header: MessageHeader = (timestamp, message_length, message_type.into()).into();
        write_message_header(writer.as_mut(), &message_header).await.expect("writing to a Vec cannot fail");
        // timestamp (3) + message length (3) + message type (1); no message id
        assert_eq!(writer.len(), 7);
        assert_eq!(Duration::from_millis(u64::from(read_u24(&writer[..3]))), message_header.get_timestamp().unwrap());
        assert_eq!(read_u24(&writer[3..6]), message_header.get_message_length().unwrap());
        assert_eq!(MessageType::from(writer[6]), message_header.get_message_type().unwrap())
    }

    #[tokio::test]
    async fn write_timer_change() {
        let mut writer: Pin<&mut Vec<u8>> = pin!(Vec::new());
        let timestamp = Duration::from_millis(u64::from(random_u24()));
        let message_header: MessageHeader = timestamp.into();
        write_message_header(writer.as_mut(), &message_header).await.expect("writing to a Vec cannot fail");
        // Only the 3-byte timestamp is written.
        assert_eq!(writer.len(), 3);
        assert_eq!(Duration::from_millis(u64::from(read_u24(&writer[..]))), message_header.get_timestamp().unwrap())
    }

    #[tokio::test]
    async fn write_continue() {
        let mut writer: Pin<&mut Vec<u8>> = pin!(Vec::new());
        let message_header: MessageHeader = ().into();
        write_message_header(writer.as_mut(), &message_header).await.expect("writing to a Vec cannot fail");
        assert!(writer.is_empty())
    }
}