1use std::str;
6
7use axum::{
8 body::Body,
9 extract::FromRequestParts,
10 extract::ws::rejection::{
11 ConnectionNotUpgradable, InvalidConnectionHeader, InvalidProtocolPseudoheader,
12 InvalidUpgradeHeader, InvalidWebSocketVersionHeader, MethodNotConnect, MethodNotGet,
13 WebSocketKeyHeaderMissing, WebSocketUpgradeRejection,
14 },
15 http::{HeaderValue, Method, Version, request::Parts},
16 response::Response,
17};
18use hyper::{
19 StatusCode, header,
20 upgrade::{OnUpgrade, Upgraded},
21};
22use hyper_util::rt::TokioIo;
23use rand::seq::SliceRandom;
24
25use reach_aliases::*;
26
27use super::*;
28
29impl<Incoming, Outgoing> IntoWebSocketContext<RemoteServerContext<Incoming, Outgoing>>
30 for RemoteServerOptions<Incoming, Outgoing>
31where
32 Incoming: CommunicableType + Send,
33 Outgoing: CommunicableType + From<GenericWebSocketError> + Send,
34{
35 type Channel = RawItemSender<Incoming, Outgoing>;
36
37 fn into_web_socket_context(
38 self,
39 channel: Self::Channel,
40 ) -> RemoteServerContext<Incoming, Outgoing> {
41 ServerContext {
42 web_socket: self.web_socket,
43 incoming_sink: channel,
44 outgoing_stream: Default::default(),
45 extensions: self.extensions,
46 hold: None,
47 }
48 }
49}
50
51pub struct WebSocketUpgrade<E, D, C>
52where
53 D: RequestDelegator<C>,
54 C: CommunicableTypes,
55 C::Resp: From<GenericWebSocketError>,
56{
57 server_context_extensions: E,
58 global_context: Arc<D::GlobalContext>,
59 on_upgrade: OnUpgrade,
60 sec_websocket_signed_key: Option<HeaderValue>,
61}
62
/// Returns early with `Err(<$error>::default().into())` when `$boolean` holds.
///
/// `$error` must name a type implementing `Default` whose value converts (via `Into`) into
/// the enclosing function's error type.
macro_rules! if_then_err {
    ($boolean:expr, $error:tt) => {
        if $boolean {
            return Err(<$error>::default().into());
        }
    };
}
70
/// Evaluates to `true` unless `$headers` has an entry for `$key` whose bytes equal
/// `$value` (ASCII case-insensitive). A missing entry counts as "not equal".
macro_rules! header_ne {
    ($headers:expr, $key:expr, $value:expr) => {
        match $headers.get($key) {
            Some(found) => !found.as_bytes().eq_ignore_ascii_case($value.as_bytes()),
            None => true,
        }
    };
}
78
79#[inline]
80fn validate_http_11(parts: &Parts) -> Result<(), WebSocketUpgradeRejection> {
81 if_then_err!(parts.method != Method::GET, MethodNotGet);
82
83 if_then_err!(
84 !parts.headers.get(header::CONNECTION).is_some_and(|h| {
85 str::from_utf8(h.as_bytes()).is_ok_and(|b| b.to_ascii_lowercase().contains("upgrade"))
86 }),
87 InvalidConnectionHeader
88 );
89
90 if_then_err!(
91 header_ne!(parts.headers, header::UPGRADE, "websocket"),
92 InvalidUpgradeHeader
93 );
94
95 Ok(())
96}
97
98#[inline]
99fn validate_http_2(parts: &Parts) -> Result<(), WebSocketUpgradeRejection> {
100 if_then_err!(parts.method != Method::CONNECT, MethodNotConnect);
101
102 if_then_err!(
103 parts
104 .extensions
105 .get::<hyper::ext::Protocol>()
106 .is_none_or(|p| p.as_str() != "websocket"),
107 InvalidProtocolPseudoheader
108 );
109
110 Ok(())
111}
112
113#[inline]
114fn sign(key: &HeaderValue) -> Result<HeaderValue, WebSocketUpgradeRejection> {
115 use sha1::{Digest, Sha1};
116
117 let mut hasher = Sha1::new();
118 hasher.update(key);
119 hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
120 use base64::prelude::*;
121
122 let x = BASE64_STANDARD.encode(hasher.finalize());
123
124 HeaderValue::from_str(&x).map_err(|_| ConnectionNotUpgradable::default().into())
125}
126
127impl<E, D, C> FromRequestParts<(E, Arc<D::GlobalContext>)> for WebSocketUpgrade<E, D, C>
128where
129 E: Clone + Send + Sync,
130 D: RequestDelegator<C>,
131 D::GlobalContext: Send + Sync,
132 C: CommunicableTypes,
133 C::Resp: From<GenericWebSocketError>,
134{
135 type Rejection = WebSocketUpgradeRejection;
136
137 async fn from_request_parts(
138 parts: &mut Parts,
139 (server_context_extensions, global_context): &(E, Arc<D::GlobalContext>),
140 ) -> Result<Self, Self::Rejection> {
141 let headers = &parts.headers;
142
143 let sec_websocket_signed_key = match parts.version {
144 Version::HTTP_2 => validate_http_2(parts).map(|_| None)?,
145 Version::HTTP_11 => {
146 validate_http_11(parts)?;
147
148 let key = headers
149 .get(header::SEC_WEBSOCKET_KEY)
150 .ok_or(WebSocketKeyHeaderMissing::default())?;
151
152 Some(sign(key)?)
153 }
154 _ => Err(ConnectionNotUpgradable::default())?,
155 };
156
157 if_then_err!(
158 header_ne!(headers, header::SEC_WEBSOCKET_VERSION, "13"),
159 InvalidWebSocketVersionHeader
160 );
161
162 Ok(Self {
163 global_context: global_context.clone(),
164 sec_websocket_signed_key,
165 server_context_extensions: server_context_extensions.clone(),
166 on_upgrade: parts
167 .extensions
168 .remove()
169 .ok_or(ConnectionNotUpgradable::default())?,
170 })
171 }
172}
173
/// Extensions carried by a [`RemoteServerContext`].
#[derive(Debug, Clone)]
pub struct RemoteServerContextExtensions {
    /// Candidate lengths that outgoing communications are padded to (via `pad_to`). A
    /// candidate is only used when it is strictly greater than the payload length.
    pub permissible_communication_lengths: Vec<usize>,
    /// Random number generator used to choose among `permissible_communication_lengths`.
    pub csprng: ReachRng,
}
179
180impl RemoteServerContextExtensions {
181 pub fn new(permissible_communication_lengths: Vec<usize>, csprng: ReachRng) -> Self {
182 Self {
183 permissible_communication_lengths,
184 csprng,
185 }
186 }
187}
188
/// Server options for a remote WebSocket server running over a hyper-upgraded connection
/// (`TokioIo<Upgraded>`) with [`RemoteServerContextExtensions`].
pub type RemoteServerOptions<Incoming, Outgoing> =
    ServerOptions<TokioIo<Upgraded>, Incoming, Outgoing, RemoteServerContextExtensions>;

/// Server context for a remote WebSocket server running over a hyper-upgraded connection
/// (`TokioIo<Upgraded>`) with [`RemoteServerContextExtensions`].
pub type RemoteServerContext<Incoming, Outgoing> =
    ServerContext<TokioIo<Upgraded>, Incoming, Outgoing, (), RemoteServerContextExtensions>;
194
195impl<Incoming, Outgoing> WebSocketContext for RemoteServerContext<Incoming, Outgoing>
196where
197 Incoming: CommunicableType + Send,
198 Outgoing: CommunicableType + From<GenericWebSocketError> + Send,
199{
200 type Options = RemoteServerOptions<Incoming, Outgoing>;
201 type Incoming = Incoming;
202 type Outgoing = Outgoing;
203 type Responder = ();
204
205 async fn handle_incoming(
206 &mut self,
207 incoming: Communication<Self::Incoming>,
208 ) -> Result<(), WebSocketError> {
209 let (item, receiver) = WebSocketItem::channel(incoming);
210 self.outgoing_stream.push(receiver);
211 if let Err(e) = self.incoming_sink.send(item).await {
212 if e.is_disconnected() {
213 return Err(e.into());
214 } else {
215 }
217 }
218
219 Ok(())
220 }
221
222 async fn handle_outgoing(
223 &mut self,
224 outgoing: WebSocketItem<Communication<Self::Outgoing>, Self::Responder>,
225 ) -> Result<(), WebSocketError> {
226 let length = loop {
229 match self
230 .extensions
231 .permissible_communication_lengths
232 .choose(&mut self.extensions.csprng)
233 {
234 Some(length) => {
235 if length > &outgoing.item.inner.len() {
236 break *length;
237 }
238 }
239 None => break 0,
240 }
241 };
242 self.web_socket.send(outgoing.item.pad_to(length)).await?;
243
244 Ok(())
245 }
246}
247
/// `Connection` header value sent with the `101 Switching Protocols` handshake response.
const UPGRADE_HEADER: HeaderValue = HeaderValue::from_static("upgrade");
/// `Upgrade` header value sent with the `101 Switching Protocols` handshake response.
const WEBSOCKET_HEADER: HeaderValue = HeaderValue::from_static("websocket");
250
251pub fn on_upgrade<const B: usize, E, D, C, W>(
252 upgrade: WebSocketUpgrade<E, D, C>,
253 session_context_init: impl FnOnce() -> Arc<RwLock<D::SessionContext>> + Clone + Send + 'static,
254) -> Response
255where
256 D: RequestDelegator<C> + 'static,
257 D::GlobalContext: Send + Sync + 'static,
258 D::SessionContext: Send + Sync + 'static,
259 C: CommunicableTypes + 'static,
260 C::Req: Send + Sync,
261 C::Resp: From<GenericWebSocketError> + Send + Sync,
262 E: Clone + Send + 'static,
263 W: WebSocketContext<Incoming = C::Req, Outgoing = C::Resp> + Send + Unpin + 'static,
264 W::Responder: Send + Sync,
265 W::Options: From<(WebSocketChannel<TokioIo<Upgraded>, C::Req, C::Resp>, E)>,
266 W::Options: IntoWebSocketContext<W, Channel = RawItemSender<W::Incoming, W::Outgoing>>,
267{
268 tokio::spawn(async move {
269 if let Ok(upgraded) = upgrade.on_upgrade.await {
270 let upgraded = TokioIo::new(upgraded);
271
272 let web_socket_stream =
273 WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
274
275 let _ = handle_connection::<B, _, E, D, C, W>(
276 web_socket_stream,
277 upgrade.server_context_extensions.clone(),
278 upgrade.global_context,
279 session_context_init,
280 )
281 .await;
282 }
283 });
284
285 let mut response = Response::new(Body::empty());
286
287 if let Some(signed_key) = upgrade.sec_websocket_signed_key {
288 let (mut parts, body) = response.into_parts();
289 parts.status = StatusCode::SWITCHING_PROTOCOLS;
290 parts.headers.extend([
291 (header::CONNECTION, UPGRADE_HEADER),
292 (header::UPGRADE, WEBSOCKET_HEADER),
293 (header::SEC_WEBSOCKET_ACCEPT, signed_key),
294 ]);
295 response = Response::from_parts(parts, body);
296 }
297
298 response
299}