1use std::{
6 marker::PhantomData,
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll},
10};
11
12use futures::{SinkExt, Stream, StreamExt, channel::mpsc, stream::FuturesUnordered};
13use pin_project::pin_project;
14use tokio::{
15 io::{AsyncRead, AsyncWrite},
16 sync::RwLock,
17 task::{JoinHandle, JoinSet},
18};
19use tokio_tungstenite::{
20 WebSocketStream,
21 tungstenite::{Error as TungsteniteError, protocol::Role},
22};
23
24use reach_core::{
25 communication::*,
26 error::{DecodeError, GenericWebSocketError},
27};
28
29use crate::*;
30
31#[cfg(feature = "local-server")]
32mod local;
33#[cfg(feature = "local-server")]
34pub use local::*;
35
36pub mod macros;
37
38#[cfg(feature = "remote-server")]
39mod remote;
40#[cfg(feature = "remote-server")]
41pub use remote::*;
42
/// Bounded `futures` mpsc sender carrying raw WebSocket items
/// (`RawWebSocketItem<I, O>`) from a context to its consumer.
type RawItemSender<I, O> = mpsc::Sender<RawWebSocketItem<I, O>>;
44
/// Options bundle pairing a WebSocket channel with caller-defined extension
/// data.
///
/// Can be built from a `(channel, extensions)` tuple via `From`.
#[derive(Debug)]
pub struct ServerOptions<Transport, Incoming, Outgoing, Extensions> {
    /// The WebSocket connection to be served.
    pub web_socket: WebSocketChannel<Transport, Incoming, Outgoing>,
    /// Arbitrary extra data carried alongside the channel; its meaning is
    /// defined by whoever supplies `Extensions`.
    pub extensions: Extensions,
}
50
51impl<Transport, Incoming, Outgoing, Extensions>
52 From<(WebSocketChannel<Transport, Incoming, Outgoing>, Extensions)>
53 for ServerOptions<Transport, Incoming, Outgoing, Extensions>
54{
55 fn from(
56 (web_socket, extensions): (WebSocketChannel<Transport, Incoming, Outgoing>, Extensions),
57 ) -> Self {
58 Self {
59 web_socket,
60 extensions,
61 }
62 }
63}
64
/// Per-connection context that merges two item sources into one `Stream`.
///
/// Frames read from `web_socket` are yielded as `Next::Incoming`, and
/// completed receivers in `outgoing_stream` are yielded as `Next::Outgoing`.
/// NOTE(review): the `WebSocketContext` impl for this type is not in this
/// section; presumably it is the `W` driven by `WebSocketServer` — confirm.
#[derive(Debug)]
#[pin_project]
pub struct ServerContext<Transport, Incoming, Outgoing, Tag, Extensions> {
    /// The client connection. It is polled for incoming items and used to
    /// send error frames back to the peer.
    #[pin]
    web_socket: WebSocketChannel<Transport, Incoming, Outgoing>,
    /// Sender half of the raw-item channel. Not read in this section;
    /// presumably used by `handle_incoming` in the context impl — TODO confirm.
    incoming_sink: RawItemSender<Incoming, Outgoing>,
    /// Pending response receivers. Each one that completes is emitted as
    /// `Next::Outgoing` (without a responder).
    #[pin]
    outgoing_stream: FuturesUnordered<WebSocketItemReceiver<Outgoing, Tag>>,
    /// Caller-supplied extension data; not read in this section.
    extensions: Extensions,
    /// An outgoing item parked when both sources became ready in the same
    /// poll. It is returned by the next `poll_next` call before anything is
    /// polled again.
    hold: Option<WebSocketStreamResult<Next<Incoming, Outgoing, ()>>>,
}
76
77impl<Transport, Incoming, Outgoing, Tag, Extensions> Stream
78 for ServerContext<Transport, Incoming, Outgoing, Tag, Extensions>
79where
80 Transport: AsyncRead + AsyncWrite + Unpin,
81 Incoming: CommunicableType,
82 Outgoing: CommunicableType + From<GenericWebSocketError>,
83 Tag: IntoAugmentation + Default,
84 WebSocketItemReceiver<Outgoing, Tag>: Unpin,
85{
86 type Item = WebSocketStreamResult<Next<Incoming, Outgoing, ()>>;
87
88 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89 let mut this = self.project();
90 if this.hold.is_some() {
91 Poll::Ready(this.hold.take())
92 } else {
93 let incoming = this.web_socket.poll_next_unpin(cx);
94 let outgoing = this.outgoing_stream.poll_next_unpin(cx);
95 match (incoming, outgoing) {
96 (Poll::Ready(Some(incoming)), Poll::Ready(Some(outgoing))) => {
97 let hold = Ok(Next::Outgoing(WebSocketItem::without_responder(outgoing)));
98 this.hold.replace(hold);
99 Poll::Ready(Some(incoming.map(Next::Incoming)))
100 }
101 (Poll::Ready(Some(incoming)), _) => Poll::Ready(Some(incoming.map(Next::Incoming))),
102 (_, Poll::Ready(Some(outgoing))) => Poll::Ready(Some(Ok(Next::Outgoing(
103 WebSocketItem::without_responder(outgoing),
104 )))),
105 (Poll::Ready(None), Poll::Ready(None)) => Poll::Ready(None),
106 (_, _) => Poll::Pending,
107 }
108 }
109 }
110}
111
112impl<Transport, Incoming, Outgoing, Tag, Extensions> WebSocketErrorResponder
113 for ServerContext<Transport, Incoming, Outgoing, Tag, Extensions>
114where
115 Transport: AsyncRead + AsyncWrite + Unpin + Send,
116 Incoming: CommunicableType + Send,
117 Outgoing: CommunicableType + From<GenericWebSocketError> + Send,
118 Tag: Send,
119 Extensions: Send,
120{
121 async fn respond_with_error(
122 &mut self,
123 error: GenericWebSocketError,
124 ) -> WebSocketStreamResult<()> {
125 self.web_socket
126 .send(error.into_communication())
127 .await
128 .map_err(Into::into)
129 }
130}
131
/// A stream of raw WebSocket items fed by a background task.
///
/// `WebSocketServer::with_options` creates a bounded channel, hands its
/// sender to a `W` context, and spawns a Tokio task that drives that context.
/// This type yields the items sent through that sender.
#[pin_project]
pub struct WebSocketServer<W: WebSocketContext> {
    /// Receiver half of the item channel, polled by the `Stream` impl.
    #[pin]
    inner: mpsc::Receiver<RawWebSocketItem<W::Incoming, W::Outgoing>>,
    /// Handle to the spawned task running the context loop.
    /// NOTE(review): it is never awaited or aborted in this section. Dropping
    /// a Tokio `JoinHandle` detaches the task rather than cancelling it.
    task: JoinHandle<Result<(), WebSocketError>>,
    /// Ties the struct to its context type `W` without storing one.
    context: PhantomData<W>,
}
139
140impl<W> WebSocketServer<W>
141where
142 W: WebSocketContext + Unpin + Send + 'static,
143 W::Options: IntoWebSocketContext<W, Channel = RawItemSender<W::Incoming, W::Outgoing>>,
144 W::Incoming: CommunicableType + Send + 'static,
145 W::Outgoing: CommunicableType + From<GenericWebSocketError> + Send + 'static,
146 W::Responder: Send,
147{
148 pub fn with_options<const B: usize>(options: W::Options) -> Self {
149 let (sink, stream) = mpsc::channel(B);
150 Self {
151 inner: stream,
152 task: tokio::spawn(Self::start(W::with_options(options, sink))),
153 context: PhantomData,
154 }
155 }
156
157 async fn start(mut context: W) -> Result<(), WebSocketError> {
158 while let Some(result) = context.next().await {
159 match result {
160 Ok(next) => match next {
161 Next::Incoming(incoming) => context.handle_incoming(incoming).await?,
162 Next::Outgoing(outgoing) => context.handle_outgoing(outgoing).await?,
163 },
164 Err(e) => {
165 context
166 .respond_with_error(GenericWebSocketError::from(&e))
167 .await?
168 }
169 }
170 }
171
172 Ok(())
173 }
174}
175
176impl<W: WebSocketContext> Stream for WebSocketServer<W> {
177 type Item = RawWebSocketItem<W::Incoming, W::Outgoing>;
178
179 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
180 self.project().inner.poll_next(cx)
181 }
182}
183
184async fn handle_connection<const B: usize, S, E, D, C, W>(
185 web_socket_stream: WebSocketStream<S>,
186 server_context_extensions: E,
187 global_context: Arc<D::GlobalContext>,
188 session_context_init: impl FnOnce() -> Arc<RwLock<D::SessionContext>> + Clone,
189) -> Result<(), TungsteniteError>
190where
191 S: AsyncRead + AsyncWrite + Unpin,
192 D: RequestDelegator<C> + 'static,
193 D::GlobalContext: Sync + Send,
194 D::SessionContext: Sync + Send,
195 C: CommunicableTypes<Req = W::Incoming, Resp = W::Outgoing> + 'static,
196 C::Resp: From<GenericWebSocketError>,
197 W: WebSocketContext + Unpin + Send + Sized + 'static,
198 W::Incoming: Send,
199 W::Outgoing: Send,
200 W::Responder: Send,
201 W::Options: IntoWebSocketContext<W, Channel = RawItemSender<W::Incoming, W::Outgoing>>,
202 W::Options: From<(WebSocketChannel<S, W::Incoming, W::Outgoing>, E)>,
203{
204 let session_context = session_context_init();
205
206 let web_socket_channel = WebSocketChannel::with_stream(web_socket_stream);
207 let mut connection = WebSocketServer::<W>::with_options::<B>(
208 (web_socket_channel, server_context_extensions).into(),
209 );
210 let mut tasks = JoinSet::new();
211 while let Some(item) = connection.next().await {
212 let _ = tasks.spawn(D::delegate(
213 item,
214 global_context.clone(),
215 session_context.clone(),
216 ));
217 }
218 let _ = tasks.join_all().await;
219
220 Ok(())
221}
222
223#[derive(Debug, thiserror::Error)]
224pub enum RespondToItemError {
225 #[error("decode: {0}")]
226 Decode(#[from] DecodeError),
227 #[error("reponse failure")]
228 ResponseFailure,
229 #[error("unsupported")]
230 Unsupported,
231}
232
233pub trait RequestHandler<C>
234where
235 C: CommunicableTypes,
236 C::Resp: ErrorSubset,
237 Self: MatchResponse<C::Req>,
238{
239 type GlobalContext;
240 type SessionContext;
241
242 fn handle(
243 request: Self,
244 global_context: Arc<Self::GlobalContext>,
245 session_context: Arc<RwLock<Self::SessionContext>>,
246 ) -> impl Future<
247 Output = Result<
248 <Self as MatchResponse<C::Req>>::Resp,
249 impl Into<<C::Resp as ErrorSubset>::Error>,
250 >,
251 > + Send;
252}
253
254pub trait RequestDelegator<C>
255where
256 C: CommunicableTypes,
257 C::Resp: From<GenericWebSocketError>,
258{
259 type GlobalContext;
260 type SessionContext;
261
262 fn delegate(
263 item: RawWebSocketItem<C::Req, C::Resp>,
264 global_context: Arc<Self::GlobalContext>,
265 session_context: Arc<RwLock<Self::SessionContext>>,
266 ) -> impl Future<Output = Result<(), RespondToItemError>> + Send;
267}
268
269pub async fn respond_to_item<C, Request>(
270 item: RawWebSocketItem<C::Req, C::Resp>,
271 global_context: Arc<Request::GlobalContext>,
272 session_context: Arc<RwLock<Request::SessionContext>>,
273) -> Result<(), RespondToItemError>
274where
275 C: CommunicableTypes,
276 C::Resp: From<GenericWebSocketError> + ErrorSubset,
277 Request: Communicable<C::Req> + MatchResponse<C::Req> + RequestHandler<C>,
278 Request::Resp: Communicable<C::Resp>,
279{
280 let (request, responder) = item.decode::<Request>()?.split();
281
282 match Request::handle(request, global_context, session_context).await {
283 Ok(response) => responder
284 .send(response)
285 .map_err(|_| RespondToItemError::ResponseFailure),
286 Err(error) => responder
287 .send_error(error.into())
288 .map_err(|_| RespondToItemError::ResponseFailure),
289 }
290}