reach_websocket/
channel.rs1use futures::{Sink, Stream};
5use pin_project::pin_project;
6use std::marker::PhantomData;
7use tokio::{
8 io::{AsyncRead, AsyncWrite},
9 net::{TcpStream, UnixStream},
10};
11use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message};
12
13use reach_core::communication::{CommunicableType, Communication};
14
15use crate::{WebSocketStreamError, WebSocketStreamResult};
16
17#[derive(Debug)]
18#[pin_project]
19pub struct WebSocketChannel<Inner, Incoming, Outgoing> {
20 #[pin]
21 inner: tokio_tungstenite::WebSocketStream<Inner>,
22 phantom: PhantomData<(Incoming, Outgoing)>,
23}
24
25pub type LocalWebSocketChannel<Incoming, Outgoing> =
26 WebSocketChannel<UnixStream, Incoming, Outgoing>;
27
28pub type RemoteWebSocketChannel<Incoming, Outgoing> =
29 WebSocketChannel<TcpStream, Incoming, Outgoing>;
30
31impl<Inner, Incoming, Outgoing> WebSocketChannel<Inner, Incoming, Outgoing> {
32 pub fn with_stream(stream: tokio_tungstenite::WebSocketStream<Inner>) -> Self {
33 Self {
34 inner: stream,
35 phantom: PhantomData,
36 }
37 }
38}
39
40impl<Inner, Incoming, Outgoing> Stream for WebSocketChannel<Inner, Incoming, Outgoing>
41where
42 Inner: AsyncRead + AsyncWrite + Unpin,
43 Incoming: CommunicableType,
44 Outgoing: CommunicableType,
45{
46 type Item = WebSocketStreamResult<Communication<Incoming>>;
47
48 fn poll_next(
49 self: std::pin::Pin<&mut Self>,
50 cx: &mut std::task::Context<'_>,
51 ) -> std::task::Poll<Option<Self::Item>> {
52 self.project().inner.poll_next(cx).map(|option| {
53 option.map(|result| match result {
54 Ok(Message::Binary(bytes)) => {
55 Communication::decode(bytes.as_ref()).map_err(Into::into)
56 }
57 Ok(_) => Err(WebSocketStreamError::UnsupportedWebSocketMessage),
58 Err(e) => Err(e.into()),
59 })
60 })
61 }
62}
63
64impl<Inner, Incoming, Outgoing> Sink<Communication<Outgoing>>
65 for WebSocketChannel<Inner, Incoming, Outgoing>
66where
67 Inner: AsyncRead + AsyncWrite + Unpin,
68 Incoming: CommunicableType,
69 Outgoing: CommunicableType,
70{
71 type Error = TungsteniteError;
72
73 fn poll_ready(
74 self: std::pin::Pin<&mut Self>,
75 cx: &mut std::task::Context<'_>,
76 ) -> std::task::Poll<Result<(), Self::Error>> {
77 self.project().inner.poll_ready(cx)
78 }
79
80 fn start_send(
81 self: std::pin::Pin<&mut Self>,
82 item: Communication<Outgoing>,
83 ) -> Result<(), Self::Error> {
84 self.project()
85 .inner
86 .start_send(Message::Binary(item.encode().into()))
87 }
88
89 fn poll_flush(
90 self: std::pin::Pin<&mut Self>,
91 cx: &mut std::task::Context<'_>,
92 ) -> std::task::Poll<Result<(), Self::Error>> {
93 self.project().inner.poll_flush(cx)
94 }
95
96 fn poll_close(
97 self: std::pin::Pin<&mut Self>,
98 cx: &mut std::task::Context<'_>,
99 ) -> std::task::Poll<Result<(), Self::Error>> {
100 self.project().inner.poll_close(cx)
101 }
102}