reach_websocket/client/local.rs

1// SPDX-FileCopyrightText: 2024-2025 Michael Goldenberg <m@mgoldenberg.net>
2// SPDX-FileCopyrightText: 2024-2025 eaon <eaon@posteo.net>
3// SPDX-License-Identifier: EUPL-1.2
4
5use std::num::NonZeroUsize;
6
7use lru::LruCache;
8use tokio_tungstenite::tungstenite::protocol::Role;
9
10use reach_core::communication::LocalCommunication;
11
12use super::*;
13
14pub async fn websocketstream_from_raw_socket<S>(stream: S) -> WebSocketStream<S>
15where
16    S: AsyncRead + AsyncWrite + Unpin,
17{
18    WebSocketStream::from_raw_socket(stream, Role::Client, None).await
19}
20
21impl<C> WebSocketClientHandlers<C> for LocalCommunication
22where
23    C: CommunicableTypes<Req: Send + 'static, Resp: Send + 'static> + Send + 'static,
24    C::Variant: Responders<C, Responders = LruCache<Vec<u8>, Responder<C>>>,
25    C::Variant: WebSocketClientExtension<Extension: Send + 'static>,
26{
27    async fn handle_request<S, R>(
28        item: WebSocketSinkItem<C>,
29        context: &mut WebSocketClientContext<C, S, R>,
30    ) -> Result<(), HandleRequestError<C::Req>>
31    where
32        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
33        R: CryptoRngCore + Send + 'static,
34    {
35        let (communication, tag) = match item.request.augmentation.clone() {
36            Some(tag) => (item.request, tag),
37            None => item.request.tag(&mut context.csprng),
38        };
39        if context.responders.contains(&tag) {
40            let response = HandleRequestError::TagAlreadyInUse(communication);
41            item.responder
42                .inner
43                .send(Err(response.into()))
44                .map_err(|_| HandleRequestError::ResponseChannelDisconnected)?;
45            return Ok(());
46        }
47        context
48            .web_socket
49            .send(Message::binary(communication.encode()))
50            .await?;
51        context.responders.put(tag, item.responder);
52
53        Ok(())
54    }
55
56    async fn handle_response<S, R>(
57        item: Result<Message, TungsteniteError>,
58        context: &mut WebSocketClientContext<C, S, R>,
59    ) -> Result<(), HandleResponseError<C::Resp>>
60    where
61        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
62        R: CryptoRngCore + Send + 'static,
63    {
64        match item {
65            Ok(Message::Binary(bytes)) => {
66                let communication = Communication::decode(bytes.as_ref())?;
67                match communication.augmentation {
68                    None => Err(HandleResponseError::ResponseNotTagged(communication)),
69                    Some(ref tag) => match context.responders.pop(tag) {
70                        None => Err(HandleResponseError::ResponderNotFound(communication)),
71                        Some(responder) => responder
72                            .inner
73                            .send(Ok(communication))
74                            .map_err(|_| HandleResponseError::ResponseChannelDisconnected),
75                    },
76                }
77            }
78            Ok(message) => Err(HandleResponseError::MessageNotSupported(message)),
79            Err(e) => Err(HandleResponseError::WebSocket(e)),
80        }
81    }
82}
83
84impl<C> Responders<C> for LocalCommunication
85where
86    C: CommunicableTypes,
87{
88    type Responders = LruCache<Vec<u8>, Responder<C>>;
89}
90
91impl<T> ContainsResponders for LruCache<Vec<u8>, T> {
92    fn has_responder(&self) -> bool {
93        !self.is_empty()
94    }
95}
96
/// Local communication carries no extra per-client state, so its WebSocket
/// client extension is the unit type.
impl WebSocketClientExtension for LocalCommunication {
    type Extension = ();
}
100
101impl<S, R, C, const B: usize> From<WebSocketClientOptions<NonZeroUsize, S, R, B>>
102    for WebSocketClient<C>
103where
104    C: CommunicableTypes<Req: Send + 'static, Resp: Send + 'static> + Send + 'static,
105    C::Variant: WebSocketClientExtension<Extension = ()>,
106    C::Variant: WebSocketClientHandlers<C>,
107    C::Variant: Responders<C, Responders = LruCache<Vec<u8>, Responder<C>>>,
108    <C::Variant as Responders<C>>::Responders: ContainsResponders + Send + 'static,
109    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
110    R: CryptoRngCore + Send + 'static,
111{
112    fn from(options: WebSocketClientOptions<NonZeroUsize, S, R, B>) -> Self {
113        let (sink, stream) = mpsc::channel(B);
114        Self {
115            inner: sink,
116            task: tokio::spawn(Self::start(WebSocketClientContext {
117                request_stream: stream.fuse(),
118                web_socket: options.web_socket.fuse(),
119                extension: (),
120                responders: LruCache::new(options.extension),
121                csprng: options.csprng,
122            })),
123        }
124    }
125}
126
127impl<S, R, const B: usize> WebSocketClientOptions<NonZeroUsize, S, R, B> {
128    pub fn new(
129        web_socket: WebSocketStream<S>,
130        processing_cache_size: NonZeroUsize,
131        csprng: R,
132    ) -> Self {
133        Self {
134            web_socket,
135            csprng,
136            extension: processing_cache_size,
137        }
138    }
139}
140
141pub type WebSocketLocalClientOptions<S, R, const B: usize> =
142    WebSocketClientOptions<NonZeroUsize, S, R, B>;