reach_websocket/server/
remote.rs

1// SPDX-FileCopyrightText: 2025 Michael Goldenberg <m@mgoldenberg.net>
2// SPDX-FileCopyrightText: 2025 eaon <eaon@posteo.net>
3// SPDX-License-Identifier: EUPL-1.2
4
5use std::str;
6
7use axum::{
8    body::Body,
9    extract::FromRequestParts,
10    extract::ws::rejection::{
11        ConnectionNotUpgradable, InvalidConnectionHeader, InvalidProtocolPseudoheader,
12        InvalidUpgradeHeader, InvalidWebSocketVersionHeader, MethodNotConnect, MethodNotGet,
13        WebSocketKeyHeaderMissing, WebSocketUpgradeRejection,
14    },
15    http::{HeaderValue, Method, Version, request::Parts},
16    response::Response,
17};
18use hyper::{
19    StatusCode, header,
20    upgrade::{OnUpgrade, Upgraded},
21};
22use hyper_util::rt::TokioIo;
23use rand::seq::SliceRandom;
24
25use reach_aliases::*;
26
27use super::*;
28
29impl<Incoming, Outgoing> IntoWebSocketContext<RemoteServerContext<Incoming, Outgoing>>
30    for RemoteServerOptions<Incoming, Outgoing>
31where
32    Incoming: CommunicableType + Send,
33    Outgoing: CommunicableType + From<GenericWebSocketError> + Send,
34{
35    type Channel = RawItemSender<Incoming, Outgoing>;
36
37    fn into_web_socket_context(
38        self,
39        channel: Self::Channel,
40    ) -> RemoteServerContext<Incoming, Outgoing> {
41        ServerContext {
42            web_socket: self.web_socket,
43            incoming_sink: channel,
44            outgoing_stream: Default::default(),
45            extensions: self.extensions,
46            hold: None,
47        }
48    }
49}
50
51pub struct WebSocketUpgrade<E, D, C>
52where
53    D: RequestDelegator<C>,
54    C: CommunicableTypes,
55    C::Resp: From<GenericWebSocketError>,
56{
57    server_context_extensions: E,
58    global_context: Arc<D::GlobalContext>,
59    on_upgrade: OnUpgrade,
60    sec_websocket_signed_key: Option<HeaderValue>,
61}
62
/// Returns early from the enclosing function when `$boolean` is `true`.
///
/// The returned value is `Err` holding `$error::default()`, converted with
/// `Into` into the function's error type.
macro_rules! if_then_err {
    ($boolean:expr, $error:tt) => {
        if $boolean {
            return Err($error::default().into());
        }
    };
}
70
/// Evaluates to `true` unless `$headers.get($key)` returns a value whose bytes
/// equal `$value` under ASCII case-insensitive comparison.
///
/// A missing header therefore also counts as "not equal".
macro_rules! header_ne {
    ($headers:expr, $key:expr, $value:expr) => {
        match $headers.get($key) {
            Some(found) => !found.as_bytes().eq_ignore_ascii_case($value.as_bytes()),
            None => true,
        }
    };
}
78
79#[inline]
80fn validate_http_11(parts: &Parts) -> Result<(), WebSocketUpgradeRejection> {
81    if_then_err!(parts.method != Method::GET, MethodNotGet);
82
83    if_then_err!(
84        !parts.headers.get(header::CONNECTION).is_some_and(|h| {
85            str::from_utf8(h.as_bytes()).is_ok_and(|b| b.to_ascii_lowercase().contains("upgrade"))
86        }),
87        InvalidConnectionHeader
88    );
89
90    if_then_err!(
91        header_ne!(parts.headers, header::UPGRADE, "websocket"),
92        InvalidUpgradeHeader
93    );
94
95    Ok(())
96}
97
98#[inline]
99fn validate_http_2(parts: &Parts) -> Result<(), WebSocketUpgradeRejection> {
100    if_then_err!(parts.method != Method::CONNECT, MethodNotConnect);
101
102    if_then_err!(
103        parts
104            .extensions
105            .get::<hyper::ext::Protocol>()
106            .is_none_or(|p| p.as_str() != "websocket"),
107        InvalidProtocolPseudoheader
108    );
109
110    Ok(())
111}
112
113#[inline]
114fn sign(key: &HeaderValue) -> Result<HeaderValue, WebSocketUpgradeRejection> {
115    use sha1::{Digest, Sha1};
116
117    let mut hasher = Sha1::new();
118    hasher.update(key);
119    hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
120    use base64::prelude::*;
121
122    let x = BASE64_STANDARD.encode(hasher.finalize());
123
124    HeaderValue::from_str(&x).map_err(|_| ConnectionNotUpgradable::default().into())
125}
126
127impl<E, D, C> FromRequestParts<(E, Arc<D::GlobalContext>)> for WebSocketUpgrade<E, D, C>
128where
129    E: Clone + Send + Sync,
130    D: RequestDelegator<C>,
131    D::GlobalContext: Send + Sync,
132    C: CommunicableTypes,
133    C::Resp: From<GenericWebSocketError>,
134{
135    type Rejection = WebSocketUpgradeRejection;
136
137    async fn from_request_parts(
138        parts: &mut Parts,
139        (server_context_extensions, global_context): &(E, Arc<D::GlobalContext>),
140    ) -> Result<Self, Self::Rejection> {
141        let headers = &parts.headers;
142
143        let sec_websocket_signed_key = match parts.version {
144            Version::HTTP_2 => validate_http_2(parts).map(|_| None)?,
145            Version::HTTP_11 => {
146                validate_http_11(parts)?;
147
148                let key = headers
149                    .get(header::SEC_WEBSOCKET_KEY)
150                    .ok_or(WebSocketKeyHeaderMissing::default())?;
151
152                Some(sign(key)?)
153            }
154            _ => Err(ConnectionNotUpgradable::default())?,
155        };
156
157        if_then_err!(
158            header_ne!(headers, header::SEC_WEBSOCKET_VERSION, "13"),
159            InvalidWebSocketVersionHeader
160        );
161
162        Ok(Self {
163            global_context: global_context.clone(),
164            sec_websocket_signed_key,
165            server_context_extensions: server_context_extensions.clone(),
166            on_upgrade: parts
167                .extensions
168                .remove()
169                .ok_or(ConnectionNotUpgradable::default())?,
170        })
171    }
172}
173
174#[derive(Debug, Clone)]
175pub struct RemoteServerContextExtensions {
176    pub permissible_communication_lengths: Vec<usize>,
177    pub csprng: ReachRng,
178}
179
180impl RemoteServerContextExtensions {
181    pub fn new(permissible_communication_lengths: Vec<usize>, csprng: ReachRng) -> Self {
182        Self {
183            permissible_communication_lengths,
184            csprng,
185        }
186    }
187}
188
189pub type RemoteServerOptions<Incoming, Outgoing> =
190    ServerOptions<TokioIo<Upgraded>, Incoming, Outgoing, RemoteServerContextExtensions>;
191
192pub type RemoteServerContext<Incoming, Outgoing> =
193    ServerContext<TokioIo<Upgraded>, Incoming, Outgoing, (), RemoteServerContextExtensions>;
194
195impl<Incoming, Outgoing> WebSocketContext for RemoteServerContext<Incoming, Outgoing>
196where
197    Incoming: CommunicableType + Send,
198    Outgoing: CommunicableType + From<GenericWebSocketError> + Send,
199{
200    type Options = RemoteServerOptions<Incoming, Outgoing>;
201    type Incoming = Incoming;
202    type Outgoing = Outgoing;
203    type Responder = ();
204
205    async fn handle_incoming(
206        &mut self,
207        incoming: Communication<Self::Incoming>,
208    ) -> Result<(), WebSocketError> {
209        let (item, receiver) = WebSocketItem::channel(incoming);
210        self.outgoing_stream.push(receiver);
211        if let Err(e) = self.incoming_sink.send(item).await {
212            if e.is_disconnected() {
213                return Err(e.into());
214            } else {
215                // TODO: log error here
216            }
217        }
218
219        Ok(())
220    }
221
222    async fn handle_outgoing(
223        &mut self,
224        outgoing: WebSocketItem<Communication<Self::Outgoing>, Self::Responder>,
225    ) -> Result<(), WebSocketError> {
226        // TODO: take maximum deviation from the base into account.
227        // We don't want to blow up tiny messages to humongous sizes
228        let length = loop {
229            match self
230                .extensions
231                .permissible_communication_lengths
232                .choose(&mut self.extensions.csprng)
233            {
234                Some(length) => {
235                    if length > &outgoing.item.inner.len() {
236                        break *length;
237                    }
238                }
239                None => break 0,
240            }
241        };
242        self.web_socket.send(outgoing.item.pad_to(length)).await?;
243
244        Ok(())
245    }
246}
247
248const UPGRADE_HEADER: HeaderValue = HeaderValue::from_static("upgrade");
249const WEBSOCKET_HEADER: HeaderValue = HeaderValue::from_static("websocket");
250
251pub fn on_upgrade<const B: usize, E, D, C, W>(
252    upgrade: WebSocketUpgrade<E, D, C>,
253    session_context_init: impl FnOnce() -> Arc<RwLock<D::SessionContext>> + Clone + Send + 'static,
254) -> Response
255where
256    D: RequestDelegator<C> + 'static,
257    D::GlobalContext: Send + Sync + 'static,
258    D::SessionContext: Send + Sync + 'static,
259    C: CommunicableTypes + 'static,
260    C::Req: Send + Sync,
261    C::Resp: From<GenericWebSocketError> + Send + Sync,
262    E: Clone + Send + 'static,
263    W: WebSocketContext<Incoming = C::Req, Outgoing = C::Resp> + Send + Unpin + 'static,
264    W::Responder: Send + Sync,
265    W::Options: From<(WebSocketChannel<TokioIo<Upgraded>, C::Req, C::Resp>, E)>,
266    W::Options: IntoWebSocketContext<W, Channel = RawItemSender<W::Incoming, W::Outgoing>>,
267{
268    tokio::spawn(async move {
269        if let Ok(upgraded) = upgrade.on_upgrade.await {
270            let upgraded = TokioIo::new(upgraded);
271
272            let web_socket_stream =
273                WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
274
275            let _ = handle_connection::<B, _, E, D, C, W>(
276                web_socket_stream,
277                upgrade.server_context_extensions.clone(),
278                upgrade.global_context,
279                session_context_init,
280            )
281            .await;
282        }
283    });
284
285    let mut response = Response::new(Body::empty());
286
287    if let Some(signed_key) = upgrade.sec_websocket_signed_key {
288        let (mut parts, body) = response.into_parts();
289        parts.status = StatusCode::SWITCHING_PROTOCOLS;
290        parts.headers.extend([
291            (header::CONNECTION, UPGRADE_HEADER),
292            (header::UPGRADE, WEBSOCKET_HEADER),
293            (header::SEC_WEBSOCKET_ACCEPT, signed_key),
294        ]);
295        response = Response::from_parts(parts, body);
296    }
297
298    response
299}