1use std::{
5 marker::PhantomData,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use futures::{FutureExt, channel::oneshot};
11
12use reach_core::{
13 communication::{
14 Communicable, CommunicableType, Communication, ErrorSubset, IntoAugmentation, MatchResponse,
15 },
16 error::{self, GenericWebSocketError},
17};
18
/// An item travelling through the WebSocket layer, bundled with the value used to answer it.
///
/// `Item` is the payload; `Responder` is whatever is needed to reply to it. In this file the
/// responder is a `RawResponder`, an [`EncodingResponder`], or `()` when no reply is attached.
#[derive(Debug)]
pub struct WebSocketItem<Item, Responder> {
    /// The payload carried by this item.
    pub item: Item,
    /// Kept private so replies can only be sent through the methods defined in this module.
    responder: Responder,
}
24
25pub type RawWebSocketItem<I, R> = WebSocketItem<Communication<I>, RawResponder<R>>;
26pub(crate) type RawResponder<R> = oneshot::Sender<Communication<R>>;
27pub(crate) type DecodedWebSocketItem<Item, I, R> =
28 WebSocketItem<Item, EncodingResponder<Item, I, R>>;
29
30impl<I, R> RawWebSocketItem<I, R>
31where
32 I: CommunicableType,
33 R: CommunicableType + From<GenericWebSocketError>,
34{
35 pub fn channel(item: Communication<I>) -> (Self, WebSocketItemReceiver<R>) {
36 Self::tagged_channel(item, ())
37 }
38
39 pub fn tagged_channel<Tag: IntoAugmentation>(
40 item: Communication<I>,
41 tag: Tag,
42 ) -> (Self, WebSocketItemReceiver<R, Tag>) {
43 let (responder, receiver) = oneshot::channel();
44 (
45 Self::with_raw_responder(item, responder),
46 WebSocketItemReceiver::with_tag_and_receiver(tag, receiver),
47 )
48 }
49
50 pub fn with_raw_responder(item: Communication<I>, responder: RawResponder<R>) -> Self {
51 Self { item, responder }
52 }
53
54 pub fn decode<Item>(self) -> Result<DecodedWebSocketItem<Item, I, R>, error::DecodeError>
55 where
56 Item: Communicable<I>,
57 {
58 match Item::try_from_communication(&self.item) {
59 Ok(item) => Ok(WebSocketItem {
60 item,
61 responder: self.responder.into(),
62 }),
63 Err(e) => {
64 let _ = self
65 .responder
66 .send(GenericWebSocketError::Decode.into_communication());
67 Err(e)
68 }
69 }
70 }
71
72 pub fn unsupported(self) {
73 let _ = self
74 .responder
75 .send(GenericWebSocketError::Unsupported.into_communication());
76 }
77}
78
79impl<I> WebSocketItem<Communication<I>, ()>
80where
81 I: CommunicableType,
82{
83 pub fn without_responder(item: Communication<I>) -> Self {
84 Self {
85 item,
86 responder: (),
87 }
88 }
89}
90
91impl<Item, I, R> DecodedWebSocketItem<Item, I, R>
92where
93 Item: Communicable<I> + MatchResponse<I>,
94 Item::Resp: Communicable<R>,
95 I: CommunicableType,
96 R: CommunicableType,
97{
98 pub fn split(self) -> (Item, EncodingResponder<Item, I, R>) {
99 (self.item, self.responder)
100 }
101}
102
103#[derive(Debug)]
104pub struct EncodingResponder<Item, I, R> {
105 responder: RawResponder<R>,
106 item: PhantomData<(Item, I)>,
107}
108
109impl<Item, I, R> From<RawResponder<R>> for EncodingResponder<Item, I, R>
110where
111 R: CommunicableType,
112{
113 fn from(responder: oneshot::Sender<Communication<R>>) -> Self {
114 Self {
115 responder,
116 item: PhantomData,
117 }
118 }
119}
120
121impl<Item, I, R> EncodingResponder<Item, I, R>
122where
123 Item: Communicable<I> + MatchResponse<I>,
124 Item::Resp: Communicable<R>,
125 I: CommunicableType,
126 R: CommunicableType,
127{
128 pub fn send(self, response: Item::Resp) -> Result<(), Item::Resp> {
129 self.responder
130 .send(response.to_communication())
131 .map_err(|_| response)
132 }
133}
134
135impl<Item, I, R> EncodingResponder<Item, I, R>
136where
137 R: CommunicableType + From<GenericWebSocketError>,
138{
139 pub fn send_generic_error(
140 self,
141 error: GenericWebSocketError,
142 ) -> Result<(), GenericWebSocketError> {
143 self.responder
144 .send(error.into_communication::<R>())
145 .map_err(|_| error)
146 }
147}
148
149impl<Item, I, R> EncodingResponder<Item, I, R>
150where
151 R: CommunicableType + ErrorSubset,
152{
153 pub fn send_error(self, error: <R as ErrorSubset>::Error) -> Result<(), oneshot::Canceled> {
154 self.responder
155 .send(Communication::new(R::from(error)))
156 .map_err(|_| oneshot::Canceled)
157 }
158}
159
160#[derive(Debug)]
161pub struct WebSocketItemReceiver<T, Tag = ()> {
162 pub receiver: oneshot::Receiver<Communication<T>>,
163 pub tag: Tag,
164}
165
166impl<T, Tag: IntoAugmentation> WebSocketItemReceiver<T, Tag> {
167 pub fn with_tag_and_receiver(tag: Tag, receiver: oneshot::Receiver<Communication<T>>) -> Self {
168 Self { receiver, tag }
169 }
170}
171
172impl<T> WebSocketItemReceiver<T> {
173 pub fn with_receiver(receiver: oneshot::Receiver<Communication<T>>) -> Self {
174 Self { receiver, tag: () }
175 }
176}
177
178pub type TaggedWebSocketItemReceiver<T> = WebSocketItemReceiver<T, Vec<u8>>;
179
180impl<T, Tag> Future for WebSocketItemReceiver<T, Tag>
181where
182 T: CommunicableType + From<GenericWebSocketError>,
183 Tag: IntoAugmentation + Default,
184 Self: Unpin,
185{
186 type Output = Communication<T>;
187
188 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189 self.receiver.poll_unpin(cx).map(|inner| {
190 let communication = match inner {
191 Ok(response) => response,
192 Err(_) => GenericWebSocketError::Internal.into_communication(),
193 };
194 let tag = std::mem::take(&mut self.tag);
195 communication.with_augmentation(tag.into_augmentation())
196 })
197 }
198}