reach_websocket/server/mod.rs

// SPDX-FileCopyrightText: 2025 Michael Goldenberg <m@mgoldenberg.net>
// SPDX-FileCopyrightText: 2025 eaon <eaon@posteo.net>
// SPDX-License-Identifier: EUPL-1.2
4
5use std::{
6    marker::PhantomData,
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10};
11
12use futures::{SinkExt, Stream, StreamExt, channel::mpsc, stream::FuturesUnordered};
13use pin_project::pin_project;
14use tokio::{
15    io::{AsyncRead, AsyncWrite},
16    sync::RwLock,
17    task::{JoinHandle, JoinSet},
18};
19use tokio_tungstenite::{
20    WebSocketStream,
21    tungstenite::{Error as TungsteniteError, protocol::Role},
22};
23
24use reach_core::{
25    communication::*,
26    error::{DecodeError, GenericWebSocketError},
27};
28
29use crate::*;
30
31#[cfg(feature = "local-server")]
32mod local;
33#[cfg(feature = "local-server")]
34pub use local::*;
35
36pub mod macros;
37
38#[cfg(feature = "remote-server")]
39mod remote;
40#[cfg(feature = "remote-server")]
41pub use remote::*;
42
43type RawItemSender<I, O> = mpsc::Sender<RawWebSocketItem<I, O>>;
44
45#[derive(Debug)]
46pub struct ServerOptions<Transport, Incoming, Outgoing, Extensions> {
47    pub web_socket: WebSocketChannel<Transport, Incoming, Outgoing>,
48    pub extensions: Extensions,
49}
50
51impl<Transport, Incoming, Outgoing, Extensions>
52    From<(WebSocketChannel<Transport, Incoming, Outgoing>, Extensions)>
53    for ServerOptions<Transport, Incoming, Outgoing, Extensions>
54{
55    fn from(
56        (web_socket, extensions): (WebSocketChannel<Transport, Incoming, Outgoing>, Extensions),
57    ) -> Self {
58        Self {
59            web_socket,
60            extensions,
61        }
62    }
63}
64
65#[derive(Debug)]
66#[pin_project]
67pub struct ServerContext<Transport, Incoming, Outgoing, Tag, Extensions> {
68    #[pin]
69    web_socket: WebSocketChannel<Transport, Incoming, Outgoing>,
70    incoming_sink: RawItemSender<Incoming, Outgoing>,
71    #[pin]
72    outgoing_stream: FuturesUnordered<WebSocketItemReceiver<Outgoing, Tag>>,
73    extensions: Extensions,
74    hold: Option<WebSocketStreamResult<Next<Incoming, Outgoing, ()>>>,
75}
76
77impl<Transport, Incoming, Outgoing, Tag, Extensions> Stream
78    for ServerContext<Transport, Incoming, Outgoing, Tag, Extensions>
79where
80    Transport: AsyncRead + AsyncWrite + Unpin,
81    Incoming: CommunicableType,
82    Outgoing: CommunicableType + From<GenericWebSocketError>,
83    Tag: IntoAugmentation + Default,
84    WebSocketItemReceiver<Outgoing, Tag>: Unpin,
85{
86    type Item = WebSocketStreamResult<Next<Incoming, Outgoing, ()>>;
87
88    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89        let mut this = self.project();
90        if this.hold.is_some() {
91            Poll::Ready(this.hold.take())
92        } else {
93            let incoming = this.web_socket.poll_next_unpin(cx);
94            let outgoing = this.outgoing_stream.poll_next_unpin(cx);
95            match (incoming, outgoing) {
96                (Poll::Ready(Some(incoming)), Poll::Ready(Some(outgoing))) => {
97                    let hold = Ok(Next::Outgoing(WebSocketItem::without_responder(outgoing)));
98                    this.hold.replace(hold);
99                    Poll::Ready(Some(incoming.map(Next::Incoming)))
100                }
101                (Poll::Ready(Some(incoming)), _) => Poll::Ready(Some(incoming.map(Next::Incoming))),
102                (_, Poll::Ready(Some(outgoing))) => Poll::Ready(Some(Ok(Next::Outgoing(
103                    WebSocketItem::without_responder(outgoing),
104                )))),
105                (Poll::Ready(None), Poll::Ready(None)) => Poll::Ready(None),
106                (_, _) => Poll::Pending,
107            }
108        }
109    }
110}
111
112impl<Transport, Incoming, Outgoing, Tag, Extensions> WebSocketErrorResponder
113    for ServerContext<Transport, Incoming, Outgoing, Tag, Extensions>
114where
115    Transport: AsyncRead + AsyncWrite + Unpin + Send,
116    Incoming: CommunicableType + Send,
117    Outgoing: CommunicableType + From<GenericWebSocketError> + Send,
118    Tag: Send,
119    Extensions: Send,
120{
121    async fn respond_with_error(
122        &mut self,
123        error: GenericWebSocketError,
124    ) -> WebSocketStreamResult<()> {
125        self.web_socket
126            .send(error.into_communication())
127            .await
128            .map_err(Into::into)
129    }
130}
131
132#[pin_project]
133pub struct WebSocketServer<W: WebSocketContext> {
134    #[pin]
135    inner: mpsc::Receiver<RawWebSocketItem<W::Incoming, W::Outgoing>>,
136    task: JoinHandle<Result<(), WebSocketError>>,
137    context: PhantomData<W>,
138}
139
140impl<W> WebSocketServer<W>
141where
142    W: WebSocketContext + Unpin + Send + 'static,
143    W::Options: IntoWebSocketContext<W, Channel = RawItemSender<W::Incoming, W::Outgoing>>,
144    W::Incoming: CommunicableType + Send + 'static,
145    W::Outgoing: CommunicableType + From<GenericWebSocketError> + Send + 'static,
146    W::Responder: Send,
147{
148    pub fn with_options<const B: usize>(options: W::Options) -> Self {
149        let (sink, stream) = mpsc::channel(B);
150        Self {
151            inner: stream,
152            task: tokio::spawn(Self::start(W::with_options(options, sink))),
153            context: PhantomData,
154        }
155    }
156
157    async fn start(mut context: W) -> Result<(), WebSocketError> {
158        while let Some(result) = context.next().await {
159            match result {
160                Ok(next) => match next {
161                    Next::Incoming(incoming) => context.handle_incoming(incoming).await?,
162                    Next::Outgoing(outgoing) => context.handle_outgoing(outgoing).await?,
163                },
164                Err(e) => {
165                    context
166                        .respond_with_error(GenericWebSocketError::from(&e))
167                        .await?
168                }
169            }
170        }
171
172        Ok(())
173    }
174}
175
176impl<W: WebSocketContext> Stream for WebSocketServer<W> {
177    type Item = RawWebSocketItem<W::Incoming, W::Outgoing>;
178
179    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
180        self.project().inner.poll_next(cx)
181    }
182}
183
184async fn handle_connection<const B: usize, S, E, D, C, W>(
185    web_socket_stream: WebSocketStream<S>,
186    server_context_extensions: E,
187    global_context: Arc<D::GlobalContext>,
188    session_context_init: impl FnOnce() -> Arc<RwLock<D::SessionContext>> + Clone,
189) -> Result<(), TungsteniteError>
190where
191    S: AsyncRead + AsyncWrite + Unpin,
192    D: RequestDelegator<C> + 'static,
193    D::GlobalContext: Sync + Send,
194    D::SessionContext: Sync + Send,
195    C: CommunicableTypes<Req = W::Incoming, Resp = W::Outgoing> + 'static,
196    C::Resp: From<GenericWebSocketError>,
197    W: WebSocketContext + Unpin + Send + Sized + 'static,
198    W::Incoming: Send,
199    W::Outgoing: Send,
200    W::Responder: Send,
201    W::Options: IntoWebSocketContext<W, Channel = RawItemSender<W::Incoming, W::Outgoing>>,
202    W::Options: From<(WebSocketChannel<S, W::Incoming, W::Outgoing>, E)>,
203{
204    let session_context = session_context_init();
205
206    let web_socket_channel = WebSocketChannel::with_stream(web_socket_stream);
207    let mut connection = WebSocketServer::<W>::with_options::<B>(
208        (web_socket_channel, server_context_extensions).into(),
209    );
210    let mut tasks = JoinSet::new();
211    while let Some(item) = connection.next().await {
212        let _ = tasks.spawn(D::delegate(
213            item,
214            global_context.clone(),
215            session_context.clone(),
216        ));
217    }
218    let _ = tasks.join_all().await;
219
220    Ok(())
221}
222
223#[derive(Debug, thiserror::Error)]
224pub enum RespondToItemError {
225    #[error("decode: {0}")]
226    Decode(#[from] DecodeError),
227    #[error("reponse failure")]
228    ResponseFailure,
229    #[error("unsupported")]
230    Unsupported,
231}
232
233pub trait RequestHandler<C>
234where
235    C: CommunicableTypes,
236    C::Resp: ErrorSubset,
237    Self: MatchResponse<C::Req>,
238{
239    type GlobalContext;
240    type SessionContext;
241
242    fn handle(
243        request: Self,
244        global_context: Arc<Self::GlobalContext>,
245        session_context: Arc<RwLock<Self::SessionContext>>,
246    ) -> impl Future<
247        Output = Result<
248            <Self as MatchResponse<C::Req>>::Resp,
249            impl Into<<C::Resp as ErrorSubset>::Error>,
250        >,
251    > + Send;
252}
253
254pub trait RequestDelegator<C>
255where
256    C: CommunicableTypes,
257    C::Resp: From<GenericWebSocketError>,
258{
259    type GlobalContext;
260    type SessionContext;
261
262    fn delegate(
263        item: RawWebSocketItem<C::Req, C::Resp>,
264        global_context: Arc<Self::GlobalContext>,
265        session_context: Arc<RwLock<Self::SessionContext>>,
266    ) -> impl Future<Output = Result<(), RespondToItemError>> + Send;
267}
268
269pub async fn respond_to_item<C, Request>(
270    item: RawWebSocketItem<C::Req, C::Resp>,
271    global_context: Arc<Request::GlobalContext>,
272    session_context: Arc<RwLock<Request::SessionContext>>,
273) -> Result<(), RespondToItemError>
274where
275    C: CommunicableTypes,
276    C::Resp: From<GenericWebSocketError> + ErrorSubset,
277    Request: Communicable<C::Req> + MatchResponse<C::Req> + RequestHandler<C>,
278    Request::Resp: Communicable<C::Resp>,
279{
280    let (request, responder) = item.decode::<Request>()?.split();
281
282    match Request::handle(request, global_context, session_context).await {
283        Ok(response) => responder
284            .send(response)
285            .map_err(|_| RespondToItemError::ResponseFailure),
286        Err(error) => responder
287            .send_error(error.into())
288            .map_err(|_| RespondToItemError::ResponseFailure),
289    }
290}