1use futures::channel::{mpsc, oneshot};
6use futures::stream::{Fuse, FusedStream};
7use futures::{Sink, SinkExt, StreamExt};
8use pin_project::pin_project;
9use rand_core::CryptoRngCore;
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::task::JoinHandle;
13use tokio_tungstenite::{
14 WebSocketStream,
15 tungstenite::{Error as TungsteniteError, protocol::Message},
16};
17
18use reach_core::communication::{
19 Communicable, CommunicableOwned, CommunicableTypes, Communication, MatchResponse,
20};
21
22#[cfg(feature = "local-client")]
23mod local;
24#[cfg(feature = "local-client")]
25pub use local::*;
26
27#[cfg(feature = "remote-client")]
28mod remote;
29#[cfg(feature = "remote-client")]
30pub use remote::*;
31
32mod sealed {
33 use super::*;
34
35 pub struct WebSocketClientContext<C, S, R>
36 where
37 C: CommunicableTypes,
38 C::Variant: Responders<C> + WebSocketClientExtension,
39 {
40 pub request_stream: Fuse<mpsc::Receiver<WebSocketSinkItem<C>>>,
41 pub web_socket: Fuse<WebSocketStream<S>>,
42 pub extension: <C::Variant as WebSocketClientExtension>::Extension,
43 pub responders: <C::Variant as Responders<C>>::Responders,
44 pub csprng: R,
45 }
46
47 pub trait WebSocketClientHandlers<C>
48 where
49 C: CommunicableTypes + Send + 'static,
50 C::Variant: Responders<C> + WebSocketClientExtension,
51 {
52 fn handle_request<S, R>(
53 item: WebSocketSinkItem<C>,
54 context: &mut WebSocketClientContext<C, S, R>,
55 ) -> impl Future<Output = Result<(), HandleRequestError<C::Req>>> + Send
56 where
57 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
58 R: CryptoRngCore + Send + 'static;
59
60 fn handle_response<S, R>(
61 item: Result<Message, TungsteniteError>,
62 context: &mut WebSocketClientContext<C, S, R>,
63 ) -> impl Future<Output = Result<(), HandleResponseError<C::Resp>>> + Send
64 where
65 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
66 R: CryptoRngCore + Send + 'static;
67 }
68}
69
70use sealed::*;
71
72#[derive(Debug, Error)]
73pub enum WebSocketClientError<Req, Resp> {
74 #[error("request channel: {0}")]
75 RequestChannel(#[from] mpsc::SendError),
76 #[error("request handler: {0}")]
77 RequestHandler(#[from] HandleRequestError<Req>),
78 #[error("response handler: {0}")]
79 ResponseHandler(#[from] HandleResponseError<Resp>),
80 #[error("response channel disconnected")]
81 ResponseChannelDisconnected(#[from] oneshot::Canceled),
82 #[error("response decoder: {0}")]
83 ResponseDecoder(#[from] reach_core::error::DecodeError),
84 #[error("response error: {0}")]
85 ResponseError(Resp),
86}
87
88type WebSocketClientResult<T, Req, Resp> = Result<T, WebSocketClientError<Req, Resp>>;
89
90#[derive(Debug, Error)]
91pub enum HandleRequestError<Req> {
92 #[error("permissible communication lengths not set")]
93 MissingPermissibleCommunicationLenghts,
94 #[error("responder already exists")]
95 ResponderAlreadyExists,
96 #[error("tag already in use: {0:?}")]
97 TagAlreadyInUse(Communication<Req>),
98 #[error("response channel disconnected")]
99 ResponseChannelDisconnected,
100 #[error("web socket: {0}")]
101 WebSocket(#[from] TungsteniteError),
102}
103
104impl<Req> HandleRequestError<Req> {
105 pub fn is_fatal(&self) -> bool {
106 matches!(self, Self::WebSocket(_))
107 }
108}
109
110#[derive(Debug, Error)]
111pub enum HandleResponseError<Resp> {
112 #[error("response decoder: {0}")]
113 ResponseDecoder(#[from] reach_core::error::DecodeError),
114 #[error("web socket: {0}")]
115 WebSocket(#[from] TungsteniteError),
116 #[error("message not supported: {0:?}")]
117 MessageNotSupported(Message),
118 #[error("response not tagged: {0:?}")]
119 ResponseNotTagged(Communication<Resp>),
120 #[error("responder not found: {0:?}")]
121 ResponderNotFound(Communication<Resp>),
122 #[error("response channel disconnected")]
123 ResponseChannelDisconnected,
124}
125
126impl<Resp> HandleResponseError<Resp> {
127 pub fn is_fatal(&self) -> bool {
128 matches!(self, Self::WebSocket(_))
129 }
130}
131
132type WebSocketClientTask<Req, Resp> = JoinHandle<WebSocketClientResult<(), Req, Resp>>;
133
134#[pin_project]
135pub struct WebSocketClient<C>
136where
137 C: CommunicableTypes,
138{
139 #[pin]
140 inner: mpsc::Sender<WebSocketSinkItem<C>>,
141 task: WebSocketClientTask<C::Req, C::Resp>,
142}
143
144pub struct WebSocketClientOptions<E, S, R, const B: usize> {
145 pub web_socket: WebSocketStream<S>,
146 pub csprng: R,
147 pub extension: E,
148}
149
150impl<C> WebSocketClient<C>
151where
152 C: CommunicableTypes<Req: Send + 'static, Resp: Send + 'static> + Send + 'static,
153 C::Variant: Responders<C>,
154 C::Variant: WebSocketClientHandlers<C> + WebSocketClientExtension,
155{
156 async fn send_communication_and_wait<Req>(
157 &mut self,
158 communication: Communication<C::Req>,
159 ) -> WebSocketClientResult<Req::Resp, C::Req, C::Resp>
160 where
161 Req: MatchResponse<C::Req, Resp: Communicable<C::Resp>>,
162 {
163 let (responder, receiver) = oneshot::channel();
164 let item = WebSocketSinkItem {
165 request: communication,
166 responder: Responder { inner: responder },
167 };
168 self.inner.send(item).await?;
169
170 let response = receiver.await??;
171
172 Req::Resp::try_from_communication(&response)
173 .map_err(|_| WebSocketClientError::ResponseError(response.r#type))
174 }
175
176 pub async fn send_and_wait<Req>(
177 &mut self,
178 request: &Req,
179 ) -> WebSocketClientResult<Req::Resp, C::Req, C::Resp>
180 where
181 Req: MatchResponse<C::Req, Resp: Communicable<C::Resp>> + Communicable<C::Req>,
182 {
183 self.send_communication_and_wait::<Req>(request.to_communication())
184 .await
185 }
186
187 pub async fn send_owned_and_wait<Req>(
188 &mut self,
189 request: Req,
190 ) -> WebSocketClientResult<Req::Resp, C::Req, C::Resp>
191 where
192 Req: MatchResponse<C::Req, Resp: Communicable<C::Resp>> + CommunicableOwned<C::Req>,
193 {
194 self.send_communication_and_wait::<Req>(request.to_communication())
195 .await
196 }
197
198 async fn start<S, R>(
199 mut context: WebSocketClientContext<C, S, R>,
200 ) -> WebSocketClientResult<(), C::Req, C::Resp>
201 where
202 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
203 R: CryptoRngCore + Send + 'static,
204 {
205 loop {
206 tokio::select! {
207 item = &mut context.request_stream.next() => match item {
208 None => {
209 if context.web_socket.is_terminated() || !context.responders.has_responder()
210 {
211 break;
212 } else {
213 continue;
214 }
215 },
216 Some(request) => {
217 if let Err(e) = C::Variant::handle_request(request, &mut context).await {
218 if e.is_fatal() {
219 return Err(e.into());
220 } else {
221 continue;
224 }
225 }
226 }
227 },
228 item = &mut context.web_socket.next() => match item {
229 None => break,
230 Some(response) => {
231 if let Err(e) = C::Variant::handle_response(response, &mut context).await {
232 if e.is_fatal() {
233 return Err(e.into());
234 } else {
235 continue;
238 }
239 }
240 }
241 }
242 }
243 }
244
245 Ok(())
246 }
247}
248
249type ResponseSender<Req, Resp> =
250 oneshot::Sender<WebSocketClientResult<Communication<Resp>, Req, Resp>>;
251
252pub struct Responder<C>
253where
254 C: CommunicableTypes,
255{
256 inner: ResponseSender<C::Req, C::Resp>,
257}
258
259pub struct WebSocketSinkItem<C: CommunicableTypes> {
260 request: Communication<C::Req>,
261 responder: Responder<C>,
262}
263
264impl<C> Sink<WebSocketSinkItem<C>> for WebSocketClient<C>
265where
266 C: CommunicableTypes,
267{
268 type Error = WebSocketClientError<C::Req, C::Resp>;
269
270 fn poll_ready(
271 self: std::pin::Pin<&mut Self>,
272 cx: &mut std::task::Context<'_>,
273 ) -> std::task::Poll<Result<(), Self::Error>> {
274 self.project().inner.poll_ready(cx).map_err(Into::into)
275 }
276
277 fn start_send(
278 self: std::pin::Pin<&mut Self>,
279 item: WebSocketSinkItem<C>,
280 ) -> WebSocketClientResult<(), C::Req, C::Resp> {
281 self.project().inner.start_send(item).map_err(Into::into)
282 }
283
284 fn poll_flush(
285 self: std::pin::Pin<&mut Self>,
286 cx: &mut std::task::Context<'_>,
287 ) -> std::task::Poll<Result<(), Self::Error>> {
288 self.project().inner.poll_flush(cx).map_err(Into::into)
289 }
290
291 fn poll_close(
292 self: std::pin::Pin<&mut Self>,
293 cx: &mut std::task::Context<'_>,
294 ) -> std::task::Poll<Result<(), Self::Error>> {
295 self.project().inner.poll_close(cx).map_err(Into::into)
296 }
297}
298
299pub trait Responders<C> {
300 type Responders: ContainsResponders;
301}
302
/// Query implemented by responder collections.
pub trait ContainsResponders {
    /// Returns `true` while at least one responder is held.
    ///
    /// The client's event loop uses this to decide whether it may stop after the request
    /// stream has ended.
    fn has_responder(&self) -> bool;
}
306
/// Associates a client variant with extra data stored in the event loop's context.
pub trait WebSocketClientExtension {
    /// Variant-specific extension data.
    type Extension;
}