reach_websocket/client/mod.rs

// SPDX-FileCopyrightText: 2024-2025 Michael Goldenberg <m@mgoldenberg.net>
// SPDX-FileCopyrightText: 2024-2025 eaon <eaon@posteo.net>
// SPDX-License-Identifier: EUPL-1.2

5use futures::channel::{mpsc, oneshot};
6use futures::stream::{Fuse, FusedStream};
7use futures::{Sink, SinkExt, StreamExt};
8use pin_project::pin_project;
9use rand_core::CryptoRngCore;
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::task::JoinHandle;
13use tokio_tungstenite::{
14    WebSocketStream,
15    tungstenite::{Error as TungsteniteError, protocol::Message},
16};
17
18use reach_core::communication::{
19    Communicable, CommunicableOwned, CommunicableTypes, Communication, MatchResponse,
20};
21
22#[cfg(feature = "local-client")]
23mod local;
24#[cfg(feature = "local-client")]
25pub use local::*;
26
27#[cfg(feature = "remote-client")]
28mod remote;
29#[cfg(feature = "remote-client")]
30pub use remote::*;
31
32mod sealed {
33    use super::*;
34
35    pub struct WebSocketClientContext<C, S, R>
36    where
37        C: CommunicableTypes,
38        C::Variant: Responders<C> + WebSocketClientExtension,
39    {
40        pub request_stream: Fuse<mpsc::Receiver<WebSocketSinkItem<C>>>,
41        pub web_socket: Fuse<WebSocketStream<S>>,
42        pub extension: <C::Variant as WebSocketClientExtension>::Extension,
43        pub responders: <C::Variant as Responders<C>>::Responders,
44        pub csprng: R,
45    }
46
47    pub trait WebSocketClientHandlers<C>
48    where
49        C: CommunicableTypes + Send + 'static,
50        C::Variant: Responders<C> + WebSocketClientExtension,
51    {
52        fn handle_request<S, R>(
53            item: WebSocketSinkItem<C>,
54            context: &mut WebSocketClientContext<C, S, R>,
55        ) -> impl Future<Output = Result<(), HandleRequestError<C::Req>>> + Send
56        where
57            S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
58            R: CryptoRngCore + Send + 'static;
59
60        fn handle_response<S, R>(
61            item: Result<Message, TungsteniteError>,
62            context: &mut WebSocketClientContext<C, S, R>,
63        ) -> impl Future<Output = Result<(), HandleResponseError<C::Resp>>> + Send
64        where
65            S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
66            R: CryptoRngCore + Send + 'static;
67    }
68}
69
70use sealed::*;
71
72#[derive(Debug, Error)]
73pub enum WebSocketClientError<Req, Resp> {
74    #[error("request channel: {0}")]
75    RequestChannel(#[from] mpsc::SendError),
76    #[error("request handler: {0}")]
77    RequestHandler(#[from] HandleRequestError<Req>),
78    #[error("response handler: {0}")]
79    ResponseHandler(#[from] HandleResponseError<Resp>),
80    #[error("response channel disconnected")]
81    ResponseChannelDisconnected(#[from] oneshot::Canceled),
82    #[error("response decoder: {0}")]
83    ResponseDecoder(#[from] reach_core::error::DecodeError),
84    #[error("response error: {0}")]
85    ResponseError(Resp),
86}
87
88type WebSocketClientResult<T, Req, Resp> = Result<T, WebSocketClientError<Req, Resp>>;
89
90#[derive(Debug, Error)]
91pub enum HandleRequestError<Req> {
92    #[error("permissible communication lengths not set")]
93    MissingPermissibleCommunicationLenghts,
94    #[error("responder already exists")]
95    ResponderAlreadyExists,
96    #[error("tag already in use: {0:?}")]
97    TagAlreadyInUse(Communication<Req>),
98    #[error("response channel disconnected")]
99    ResponseChannelDisconnected,
100    #[error("web socket: {0}")]
101    WebSocket(#[from] TungsteniteError),
102}
103
104impl<Req> HandleRequestError<Req> {
105    pub fn is_fatal(&self) -> bool {
106        matches!(self, Self::WebSocket(_))
107    }
108}
109
110#[derive(Debug, Error)]
111pub enum HandleResponseError<Resp> {
112    #[error("response decoder: {0}")]
113    ResponseDecoder(#[from] reach_core::error::DecodeError),
114    #[error("web socket: {0}")]
115    WebSocket(#[from] TungsteniteError),
116    #[error("message not supported: {0:?}")]
117    MessageNotSupported(Message),
118    #[error("response not tagged: {0:?}")]
119    ResponseNotTagged(Communication<Resp>),
120    #[error("responder not found: {0:?}")]
121    ResponderNotFound(Communication<Resp>),
122    #[error("response channel disconnected")]
123    ResponseChannelDisconnected,
124}
125
126impl<Resp> HandleResponseError<Resp> {
127    pub fn is_fatal(&self) -> bool {
128        matches!(self, Self::WebSocket(_))
129    }
130}
131
132type WebSocketClientTask<Req, Resp> = JoinHandle<WebSocketClientResult<(), Req, Resp>>;
133
134#[pin_project]
135pub struct WebSocketClient<C>
136where
137    C: CommunicableTypes,
138{
139    #[pin]
140    inner: mpsc::Sender<WebSocketSinkItem<C>>,
141    task: WebSocketClientTask<C::Req, C::Resp>,
142}
143
144pub struct WebSocketClientOptions<E, S, R, const B: usize> {
145    pub web_socket: WebSocketStream<S>,
146    pub csprng: R,
147    pub extension: E,
148}
149
150impl<C> WebSocketClient<C>
151where
152    C: CommunicableTypes<Req: Send + 'static, Resp: Send + 'static> + Send + 'static,
153    C::Variant: Responders<C>,
154    C::Variant: WebSocketClientHandlers<C> + WebSocketClientExtension,
155{
156    async fn send_communication_and_wait<Req>(
157        &mut self,
158        communication: Communication<C::Req>,
159    ) -> WebSocketClientResult<Req::Resp, C::Req, C::Resp>
160    where
161        Req: MatchResponse<C::Req, Resp: Communicable<C::Resp>>,
162    {
163        let (responder, receiver) = oneshot::channel();
164        let item = WebSocketSinkItem {
165            request: communication,
166            responder: Responder { inner: responder },
167        };
168        self.inner.send(item).await?;
169
170        let response = receiver.await??;
171
172        Req::Resp::try_from_communication(&response)
173            .map_err(|_| WebSocketClientError::ResponseError(response.r#type))
174    }
175
176    pub async fn send_and_wait<Req>(
177        &mut self,
178        request: &Req,
179    ) -> WebSocketClientResult<Req::Resp, C::Req, C::Resp>
180    where
181        Req: MatchResponse<C::Req, Resp: Communicable<C::Resp>> + Communicable<C::Req>,
182    {
183        self.send_communication_and_wait::<Req>(request.to_communication())
184            .await
185    }
186
187    pub async fn send_owned_and_wait<Req>(
188        &mut self,
189        request: Req,
190    ) -> WebSocketClientResult<Req::Resp, C::Req, C::Resp>
191    where
192        Req: MatchResponse<C::Req, Resp: Communicable<C::Resp>> + CommunicableOwned<C::Req>,
193    {
194        self.send_communication_and_wait::<Req>(request.to_communication())
195            .await
196    }
197
198    async fn start<S, R>(
199        mut context: WebSocketClientContext<C, S, R>,
200    ) -> WebSocketClientResult<(), C::Req, C::Resp>
201    where
202        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
203        R: CryptoRngCore + Send + 'static,
204    {
205        loop {
206            tokio::select! {
207                item = &mut context.request_stream.next() => match item {
208                    None => {
209                        if context.web_socket.is_terminated() || !context.responders.has_responder()
210                        {
211                            break;
212                        } else {
213                            continue;
214                        }
215                    },
216                    Some(request) => {
217                        if let Err(e) = C::Variant::handle_request(request, &mut context).await {
218                            if e.is_fatal() {
219                                return Err(e.into());
220                            } else {
221                                // TODO: log non-fatal errors here
222                                // log::error!("{e}");
223                                continue;
224                            }
225                        }
226                    }
227                },
228                item = &mut context.web_socket.next() => match item {
229                    None => break,
230                    Some(response) => {
231                        if let Err(e) = C::Variant::handle_response(response, &mut context).await {
232                            if e.is_fatal() {
233                                return Err(e.into());
234                            } else {
235                                // TODO: log non-fatal errors here
236                                // log::error!("{e}");
237                                continue;
238                            }
239                        }
240                    }
241                }
242            }
243        }
244
245        Ok(())
246    }
247}
248
249type ResponseSender<Req, Resp> =
250    oneshot::Sender<WebSocketClientResult<Communication<Resp>, Req, Resp>>;
251
252pub struct Responder<C>
253where
254    C: CommunicableTypes,
255{
256    inner: ResponseSender<C::Req, C::Resp>,
257}
258
259pub struct WebSocketSinkItem<C: CommunicableTypes> {
260    request: Communication<C::Req>,
261    responder: Responder<C>,
262}
263
264impl<C> Sink<WebSocketSinkItem<C>> for WebSocketClient<C>
265where
266    C: CommunicableTypes,
267{
268    type Error = WebSocketClientError<C::Req, C::Resp>;
269
270    fn poll_ready(
271        self: std::pin::Pin<&mut Self>,
272        cx: &mut std::task::Context<'_>,
273    ) -> std::task::Poll<Result<(), Self::Error>> {
274        self.project().inner.poll_ready(cx).map_err(Into::into)
275    }
276
277    fn start_send(
278        self: std::pin::Pin<&mut Self>,
279        item: WebSocketSinkItem<C>,
280    ) -> WebSocketClientResult<(), C::Req, C::Resp> {
281        self.project().inner.start_send(item).map_err(Into::into)
282    }
283
284    fn poll_flush(
285        self: std::pin::Pin<&mut Self>,
286        cx: &mut std::task::Context<'_>,
287    ) -> std::task::Poll<Result<(), Self::Error>> {
288        self.project().inner.poll_flush(cx).map_err(Into::into)
289    }
290
291    fn poll_close(
292        self: std::pin::Pin<&mut Self>,
293        cx: &mut std::task::Context<'_>,
294    ) -> std::task::Poll<Result<(), Self::Error>> {
295        self.project().inner.poll_close(cx).map_err(Into::into)
296    }
297}
298
/// Something that can report whether any responder is still outstanding.
pub trait ContainsResponders {
    /// Reports whether any responder is still outstanding.
    ///
    /// After its request stream ends, the client loop keeps running only while this returns
    /// `true` (and the web socket has not terminated).
    fn has_responder(&self) -> bool;
}

/// Associates a communication variant with the container it uses for pending responders.
pub trait Responders<C> {
    /// Container of pending responders.
    type Responders: ContainsResponders;
}

/// Associates a communication variant with extra state stored in the client context.
pub trait WebSocketClientExtension {
    /// Variant-specific extension state.
    type Extension;
}