reach_websocket/server/
local.rs

1// SPDX-FileCopyrightText: 2025 Michael Goldenberg <m@mgoldenberg.net>
2// SPDX-FileCopyrightText: 2025 eaon <eaon@posteo.net>
3// SPDX-License-Identifier: EUPL-1.2
4
5use std::path::Path;
6
7use tokio::net::{UnixListener, UnixStream};
8
9use super::*;
10
11pub type LocalServerOptions<Incoming, Outgoing> = ServerOptions<UnixStream, Incoming, Outgoing, ()>;
12
13impl<Incoming, Outgoing> IntoWebSocketContext<LocalServerContext<Incoming, Outgoing>>
14    for LocalServerOptions<Incoming, Outgoing>
15where
16    Incoming: CommunicableType + Send,
17    Outgoing: CommunicableType + From<GenericWebSocketError> + Send,
18{
19    type Channel = RawItemSender<Incoming, Outgoing>;
20
21    fn into_web_socket_context(
22        self,
23        channel: Self::Channel,
24    ) -> LocalServerContext<Incoming, Outgoing> {
25        ServerContext {
26            web_socket: self.web_socket,
27            incoming_sink: channel,
28            outgoing_stream: Default::default(),
29            extensions: self.extensions,
30            hold: None,
31        }
32    }
33}
34
35pub type LocalServerContext<Incoming, Outgoing> =
36    ServerContext<UnixStream, Incoming, Outgoing, Vec<u8>, ()>;
37
38impl<Incoming, Outgoing> WebSocketContext for LocalServerContext<Incoming, Outgoing>
39where
40    Incoming: CommunicableType + Send,
41    Outgoing: CommunicableType + From<GenericWebSocketError> + Send,
42{
43    type Options = LocalServerOptions<Incoming, Outgoing>;
44    type Incoming = Incoming;
45    type Outgoing = Outgoing;
46    type Responder = ();
47
48    async fn handle_incoming(
49        &mut self,
50        incoming: Communication<Self::Incoming>,
51    ) -> Result<(), WebSocketError> {
52        match incoming.augmentation {
53            Some(ref tag) => {
54                let tag = tag.clone();
55                let (item, receiver) = WebSocketItem::tagged_channel(incoming, tag);
56                self.outgoing_stream.push(receiver);
57                if let Err(e) = self.incoming_sink.send(item).await {
58                    if e.is_disconnected() {
59                        return Err(e.into());
60                    } else {
61                        // NOTE: If item is not forwarded, then `responder` is dropped,
62                        // which means that the receiver will resolve to `Canceled` the
63                        // next time it is polled and will be properly forwarded over
64                        // the wire.
65                        //
66                        // TODO: log error here
67                    }
68                }
69            }
70            None => {
71                // TODO: Should this be a different variant? Perhaps `Untagged`?
72                self.web_socket
73                    .send(GenericWebSocketError::Unsupported.into_communication())
74                    .await?
75            }
76        };
77        Ok(())
78    }
79
80    async fn handle_outgoing(
81        &mut self,
82        outgoing: WebSocketItem<Communication<Self::Outgoing>, Self::Responder>,
83    ) -> Result<(), WebSocketError> {
84        self.web_socket.send(outgoing.item).await?;
85        Ok(())
86    }
87}
88
89pub async fn start<const B: usize, E, D, C, W>(
90    socket_path: impl AsRef<Path>,
91    server_context_extensions: E,
92    global_context: D::GlobalContext,
93    session_context_init: impl Into<Arc<RwLock<D::SessionContext>>> + Clone + Send + 'static,
94) -> std::io::Result<()>
95where
96    E: Clone + Send + 'static,
97    D: RequestDelegator<C> + 'static,
98    D::GlobalContext: Sync + Send + 'static,
99    D::SessionContext: Sync + Send + 'static,
100    C: CommunicableTypes<Req = W::Incoming, Resp = W::Outgoing> + 'static,
101    C::Resp: From<GenericWebSocketError>,
102    W: WebSocketContext + Unpin + Send + 'static,
103    W::Incoming: Send,
104    W::Outgoing: Send,
105    W::Responder: Send,
106    W::Options: IntoWebSocketContext<W, Channel = RawItemSender<W::Incoming, W::Outgoing>>,
107    W::Options: From<(WebSocketChannel<UnixStream, W::Incoming, W::Outgoing>, E)>,
108    D::SessionContext: Sync + Send + 'static,
109{
110    let listener = UnixListener::bind(socket_path)?;
111    log::info!("Listening at {:?}", listener.local_addr()?);
112
113    let global_context = Arc::new(global_context);
114
115    let mut tasks = JoinSet::new();
116    loop {
117        match listener.accept().await {
118            Ok((unix_stream, _)) => {
119                log::info!("Client connected");
120                let web_socket_stream =
121                    WebSocketStream::from_raw_socket(unix_stream, Role::Server, None).await;
122                let session_context_init = session_context_init.clone();
123                let _ = tasks.spawn(handle_connection::<B, UnixStream, E, D, C, W>(
124                    web_socket_stream,
125                    server_context_extensions.clone(),
126                    global_context.clone(),
127                    move || session_context_init.into(),
128                ));
129            }
130            Err(e) => {
131                log::error!("Error: {e}, shutting down ...");
132                tasks.shutdown().await;
133                return Err(e);
134            }
135        }
136    }
137}