sheave_core/handlers/stream_wrapper.rs

1use std::{
2    io::Result as IOResult,
3    pin::Pin,
4    sync::Arc,
5    task::{
6        Context,
7        Poll
8    }
9};
10use futures::ready;
11use tokio::io::{
12    AsyncRead,
13    AsyncWrite,
14    ReadBuf
15};
16use super::MeasureAcknowledgement;
17
/// A stream adapter that delegates all I/O to an inner stream and can count
/// the bytes that pass through its read side.
#[derive(Debug)]
pub struct StreamWrapper<RW>
where
    RW: Unpin,
{
    /// The wrapped reader/writer that every I/O call is forwarded to.
    stream: RW,
    /// Whether bytes read are currently being added to `current_amount`.
    is_measured: bool,
    /// Bytes counted since measuring was last (re)started.
    current_amount: u32,
}
25
26impl<RW: Unpin> StreamWrapper<RW> {
27    /// Constructs a wrapped stream.
28    pub fn new(stream: RW) -> Self {
29        Self {
30            stream,
31            is_measured: bool::default(),
32            current_amount: u32::default()
33        }
34    }
35
36    /// Makes this stream into *pinned* weak pointer.
37    ///
38    /// Currently upper APIs use this wrapper via `Arc`.
39    /// Because avoids problems which every RTMP's connection steps need same stream but can't borrow mutablly across scopes.
40    /// Therefore upper APIs wrap streams into `Arc` at first, then make them able to copy as weak pointers.
41    ///
42    /// # Examples
43    ///
44    /// ```rust
45    /// use std::sync::Arc;
46    /// use sheave_core::handlers::{
47    ///     StreamWrapper,
48    ///     VecStream
49    /// };
50    ///
51    /// Arc::new(StreamWrapper::new(VecStream::default())).make_weak_pin();
52    /// ```
53    pub fn make_weak_pin<'a>(self: &'a Arc<Self>) -> Pin<&'a mut Self> {
54        unsafe { Pin::new(&mut *(Arc::downgrade(self).as_ptr() as *mut Self)) }
55    }
56}
57
58impl<RW: Unpin> MeasureAcknowledgement for StreamWrapper<RW> {
59    fn begin_measuring(&mut self) {
60        self.current_amount = u32::default();
61        self.is_measured = true;
62    }
63
64    fn finish_measuring(&mut self) {
65        self.current_amount = u32::default();
66        self.is_measured = false;
67    }
68
69    fn add_amount(&mut self, amount: u32) {
70        self.current_amount += amount;
71    }
72
73    fn get_current_amount(&mut self) -> u32 {
74        self.current_amount
75    }
76}
77
78impl<R: AsyncRead + Unpin> AsyncRead for StreamWrapper<R> {
79    /// Wraps a stream to make it able to measure the amount of bytes.
80    ///
81    /// When bytes read exceeded some bandwidth limit, RTMP peers are required to send the `Acknowldgement` message to the other peer.
82    /// But prepared stream like Vec, slice, or TCP streams has no implementation above.
83    /// Therefore, StreamWrapper measures amounts of bytes read and writes `Acknowledgement` messages instead.
84    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<IOResult<()>> {
85        ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
86
87        if self.is_measured {
88            self.add_amount(buf.filled().len() as u32);
89        }
90
91        Poll::Ready(Ok(()))
92    }
93}
94
95impl<W: AsyncWrite + Unpin> AsyncWrite for StreamWrapper<W> {
96    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IOResult<usize>> {
97        Pin::new(&mut self.stream).poll_write(cx, buf)
98    }
99
100    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IOResult<()>> {
101        Pin::new(&mut self.stream).poll_flush(cx)
102    }
103
104    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IOResult<()>> {
105        Pin::new(&mut self.stream).poll_shutdown(cx)
106    }
107}