reach_websocket/
channel.rs

1// SPDX-FileCopyrightText: 2025 Michael Goldenberg <m@mgoldenberg.net>
2// SPDX-License-Identifier: EUPL-1.2
3
4use futures::{Sink, Stream};
5use pin_project::pin_project;
6use std::marker::PhantomData;
7use tokio::{
8    io::{AsyncRead, AsyncWrite},
9    net::{TcpStream, UnixStream},
10};
11use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message};
12
13use reach_core::communication::{CommunicableType, Communication};
14
15use crate::{WebSocketStreamError, WebSocketStreamResult};
16
17#[derive(Debug)]
18#[pin_project]
19pub struct WebSocketChannel<Inner, Incoming, Outgoing> {
20    #[pin]
21    inner: tokio_tungstenite::WebSocketStream<Inner>,
22    phantom: PhantomData<(Incoming, Outgoing)>,
23}
24
25pub type LocalWebSocketChannel<Incoming, Outgoing> =
26    WebSocketChannel<UnixStream, Incoming, Outgoing>;
27
28pub type RemoteWebSocketChannel<Incoming, Outgoing> =
29    WebSocketChannel<TcpStream, Incoming, Outgoing>;
30
31impl<Inner, Incoming, Outgoing> WebSocketChannel<Inner, Incoming, Outgoing> {
32    pub fn with_stream(stream: tokio_tungstenite::WebSocketStream<Inner>) -> Self {
33        Self {
34            inner: stream,
35            phantom: PhantomData,
36        }
37    }
38}
39
40impl<Inner, Incoming, Outgoing> Stream for WebSocketChannel<Inner, Incoming, Outgoing>
41where
42    Inner: AsyncRead + AsyncWrite + Unpin,
43    Incoming: CommunicableType,
44    Outgoing: CommunicableType,
45{
46    type Item = WebSocketStreamResult<Communication<Incoming>>;
47
48    fn poll_next(
49        self: std::pin::Pin<&mut Self>,
50        cx: &mut std::task::Context<'_>,
51    ) -> std::task::Poll<Option<Self::Item>> {
52        self.project().inner.poll_next(cx).map(|option| {
53            option.map(|result| match result {
54                Ok(Message::Binary(bytes)) => {
55                    Communication::decode(bytes.as_ref()).map_err(Into::into)
56                }
57                Ok(_) => Err(WebSocketStreamError::UnsupportedWebSocketMessage),
58                Err(e) => Err(e.into()),
59            })
60        })
61    }
62}
63
64impl<Inner, Incoming, Outgoing> Sink<Communication<Outgoing>>
65    for WebSocketChannel<Inner, Incoming, Outgoing>
66where
67    Inner: AsyncRead + AsyncWrite + Unpin,
68    Incoming: CommunicableType,
69    Outgoing: CommunicableType,
70{
71    type Error = TungsteniteError;
72
73    fn poll_ready(
74        self: std::pin::Pin<&mut Self>,
75        cx: &mut std::task::Context<'_>,
76    ) -> std::task::Poll<Result<(), Self::Error>> {
77        self.project().inner.poll_ready(cx)
78    }
79
80    fn start_send(
81        self: std::pin::Pin<&mut Self>,
82        item: Communication<Outgoing>,
83    ) -> Result<(), Self::Error> {
84        self.project()
85            .inner
86            .start_send(Message::Binary(item.encode().into()))
87    }
88
89    fn poll_flush(
90        self: std::pin::Pin<&mut Self>,
91        cx: &mut std::task::Context<'_>,
92    ) -> std::task::Poll<Result<(), Self::Error>> {
93        self.project().inner.poll_flush(cx)
94    }
95
96    fn poll_close(
97        self: std::pin::Pin<&mut Self>,
98        cx: &mut std::task::Context<'_>,
99    ) -> std::task::Poll<Result<(), Self::Error>> {
100        self.project().inner.poll_close(cx)
101    }
102}