reach_websocket/client/
remote.rs

1// SPDX-FileCopyrightText: 2024-2025 Michael Goldenberg <m@mgoldenberg.net>
2// SPDX-FileCopyrightText: 2024-2025 eaon <eaon@posteo.net>
3// SPDX-License-Identifier: EUPL-1.2
4
5use rand::seq::SliceRandom;
6use tokio::net::TcpStream;
7use tokio_tungstenite::{MaybeTlsStream, connect_async, tungstenite::client::IntoClientRequest};
8
9use reach_core::communication::RemoteCommunication;
10
11use crate::WebSocketError;
12
13use super::*;
14
15pub async fn websocketstream_from_request<R>(
16    request: R,
17) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, WebSocketError>
18where
19    R: IntoClientRequest + Unpin,
20{
21    Ok(connect_async(request).await.map(|(stream, _)| stream)?)
22}
23
24impl<C> WebSocketClientHandlers<C> for RemoteCommunication
25where
26    C: CommunicableTypes<Req: Send + 'static, Resp: Send + 'static> + Send + 'static,
27    C::Variant: Responders<C, Responders = Option<Responder<C>>>,
28    C::Variant: WebSocketClientExtension<Extension = Vec<usize>>,
29{
30    async fn handle_request<S, R>(
31        item: WebSocketSinkItem<C>,
32        context: &mut WebSocketClientContext<C, S, R>,
33    ) -> Result<(), HandleRequestError<C::Req>>
34    where
35        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
36        R: CryptoRngCore + Send + 'static,
37    {
38        if context.responders.is_some() {
39            let response = HandleRequestError::ResponderAlreadyExists;
40            item.responder
41                .inner
42                .send(Err(response.into()))
43                .map_err(|_| HandleRequestError::ResponseChannelDisconnected)?;
44            return Ok(());
45        }
46
47        let mut length = 0;
48        // TODO: take maximum deviation from the base into account.
49        // We don't want to blow up tiny messages to humongous sizes
50        while length < item.request.inner.len() {
51            length = match context.extension.choose(&mut context.csprng) {
52                Some(length) => *length,
53                None => {
54                    let response = HandleRequestError::MissingPermissibleCommunicationLenghts;
55                    item.responder
56                        .inner
57                        .send(Err(response.into()))
58                        .map_err(|_| HandleRequestError::ResponseChannelDisconnected)?;
59                    return Ok(());
60                }
61            };
62        }
63        let communication = item.request.pad_to(length);
64
65        context
66            .web_socket
67            .send(Message::binary(communication.encode()))
68            .await?;
69        context.responders = Some(item.responder);
70
71        Ok(())
72    }
73
74    async fn handle_response<S, R>(
75        item: Result<Message, TungsteniteError>,
76        context: &mut WebSocketClientContext<C, S, R>,
77    ) -> Result<(), HandleResponseError<C::Resp>>
78    where
79        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
80        R: CryptoRngCore + Send + 'static,
81    {
82        match item {
83            Ok(Message::Binary(bytes)) => {
84                let communication = Communication::decode(bytes.as_ref())?;
85                match context.responders.take() {
86                    None => Err(HandleResponseError::ResponderNotFound(communication)),
87                    Some(responder) => responder
88                        .inner
89                        .send(Ok(communication))
90                        .map_err(|_| HandleResponseError::ResponseChannelDisconnected),
91                }
92            }
93            Ok(message) => Err(HandleResponseError::MessageNotSupported(message)),
94            Err(e) => Err(HandleResponseError::WebSocket(e)),
95        }
96    }
97}
98
99impl<C> Responders<C> for RemoteCommunication
100where
101    C: CommunicableTypes,
102{
103    type Responders = Option<Responder<C>>;
104}
105
106impl<C> ContainsResponders for Option<Responder<C>>
107where
108    C: CommunicableTypes,
109{
110    fn has_responder(&self) -> bool {
111        self.is_some()
112    }
113}
114
115impl WebSocketClientExtension for RemoteCommunication {
116    type Extension = Vec<usize>;
117}
118
119impl<S, R, C, const B: usize> From<WebSocketClientOptions<Vec<usize>, S, R, B>>
120    for WebSocketClient<C>
121where
122    C: CommunicableTypes<Req: Send + 'static, Resp: Send + 'static> + Send + 'static,
123    C::Variant: WebSocketClientExtension<Extension = Vec<usize>>,
124    C::Variant: WebSocketClientHandlers<C> + Responders<C, Responders = Option<Responder<C>>>,
125    C::Variant: Responders<C, Responders = Option<Responder<C>>>,
126    <C::Variant as Responders<C>>::Responders: ContainsResponders + Send + 'static,
127    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
128    R: CryptoRngCore + Send + 'static,
129{
130    fn from(options: WebSocketClientOptions<Vec<usize>, S, R, B>) -> Self {
131        let (sink, stream) = mpsc::channel(B);
132        Self {
133            inner: sink,
134            task: tokio::spawn(Self::start(WebSocketClientContext {
135                request_stream: stream.fuse(),
136                web_socket: options.web_socket.fuse(),
137                extension: options.extension,
138                responders: None,
139                csprng: options.csprng,
140            })),
141        }
142    }
143}
144
145impl<S, R, const B: usize> WebSocketClientOptions<Vec<usize>, S, R, B> {
146    pub fn new(
147        web_socket: WebSocketStream<S>,
148        permissible_communication_lengths: Vec<usize>,
149        csprng: R,
150    ) -> Self {
151        Self {
152            web_socket,
153            csprng,
154            extension: permissible_communication_lengths,
155        }
156    }
157}
158
159pub type WebSocketRemoteClientOptions<S, R> = WebSocketClientOptions<Vec<usize>, S, R, 1>;