reach_websocket/
item.rs

1// SPDX-FileCopyrightText: 2025 Michael Goldenberg <m@mgoldenberg.net>
2// SPDX-License-Identifier: EUPL-1.2
3
4use std::{
5    marker::PhantomData,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use futures::{FutureExt, channel::oneshot};
11
12use reach_core::{
13    communication::{
14        Communicable, CommunicableType, Communication, ErrorSubset, IntoAugmentation, MatchResponse,
15    },
16    error::{self, GenericWebSocketError},
17};
18
19#[derive(Debug)]
20pub struct WebSocketItem<Item, Responder> {
21    pub item: Item,
22    responder: Responder,
23}
24
25pub type RawWebSocketItem<I, R> = WebSocketItem<Communication<I>, RawResponder<R>>;
26pub(crate) type RawResponder<R> = oneshot::Sender<Communication<R>>;
27pub(crate) type DecodedWebSocketItem<Item, I, R> =
28    WebSocketItem<Item, EncodingResponder<Item, I, R>>;
29
30impl<I, R> RawWebSocketItem<I, R>
31where
32    I: CommunicableType,
33    R: CommunicableType + From<GenericWebSocketError>,
34{
35    pub fn channel(item: Communication<I>) -> (Self, WebSocketItemReceiver<R>) {
36        Self::tagged_channel(item, ())
37    }
38
39    pub fn tagged_channel<Tag: IntoAugmentation>(
40        item: Communication<I>,
41        tag: Tag,
42    ) -> (Self, WebSocketItemReceiver<R, Tag>) {
43        let (responder, receiver) = oneshot::channel();
44        (
45            Self::with_raw_responder(item, responder),
46            WebSocketItemReceiver::with_tag_and_receiver(tag, receiver),
47        )
48    }
49
50    pub fn with_raw_responder(item: Communication<I>, responder: RawResponder<R>) -> Self {
51        Self { item, responder }
52    }
53
54    pub fn decode<Item>(self) -> Result<DecodedWebSocketItem<Item, I, R>, error::DecodeError>
55    where
56        Item: Communicable<I>,
57    {
58        match Item::try_from_communication(&self.item) {
59            Ok(item) => Ok(WebSocketItem {
60                item,
61                responder: self.responder.into(),
62            }),
63            Err(e) => {
64                let _ = self
65                    .responder
66                    .send(GenericWebSocketError::Decode.into_communication());
67                Err(e)
68            }
69        }
70    }
71
72    pub fn unsupported(self) {
73        let _ = self
74            .responder
75            .send(GenericWebSocketError::Unsupported.into_communication());
76    }
77}
78
79impl<I> WebSocketItem<Communication<I>, ()>
80where
81    I: CommunicableType,
82{
83    pub fn without_responder(item: Communication<I>) -> Self {
84        Self {
85            item,
86            responder: (),
87        }
88    }
89}
90
91impl<Item, I, R> DecodedWebSocketItem<Item, I, R>
92where
93    Item: Communicable<I> + MatchResponse<I>,
94    Item::Resp: Communicable<R>,
95    I: CommunicableType,
96    R: CommunicableType,
97{
98    pub fn split(self) -> (Item, EncodingResponder<Item, I, R>) {
99        (self.item, self.responder)
100    }
101}
102
103#[derive(Debug)]
104pub struct EncodingResponder<Item, I, R> {
105    responder: RawResponder<R>,
106    item: PhantomData<(Item, I)>,
107}
108
109impl<Item, I, R> From<RawResponder<R>> for EncodingResponder<Item, I, R>
110where
111    R: CommunicableType,
112{
113    fn from(responder: oneshot::Sender<Communication<R>>) -> Self {
114        Self {
115            responder,
116            item: PhantomData,
117        }
118    }
119}
120
121impl<Item, I, R> EncodingResponder<Item, I, R>
122where
123    Item: Communicable<I> + MatchResponse<I>,
124    Item::Resp: Communicable<R>,
125    I: CommunicableType,
126    R: CommunicableType,
127{
128    pub fn send(self, response: Item::Resp) -> Result<(), Item::Resp> {
129        self.responder
130            .send(response.to_communication())
131            .map_err(|_| response)
132    }
133}
134
135impl<Item, I, R> EncodingResponder<Item, I, R>
136where
137    R: CommunicableType + From<GenericWebSocketError>,
138{
139    pub fn send_generic_error(
140        self,
141        error: GenericWebSocketError,
142    ) -> Result<(), GenericWebSocketError> {
143        self.responder
144            .send(error.into_communication::<R>())
145            .map_err(|_| error)
146    }
147}
148
149impl<Item, I, R> EncodingResponder<Item, I, R>
150where
151    R: CommunicableType + ErrorSubset,
152{
153    pub fn send_error(self, error: <R as ErrorSubset>::Error) -> Result<(), oneshot::Canceled> {
154        self.responder
155            .send(Communication::new(R::from(error)))
156            .map_err(|_| oneshot::Canceled)
157    }
158}
159
160#[derive(Debug)]
161pub struct WebSocketItemReceiver<T, Tag = ()> {
162    pub receiver: oneshot::Receiver<Communication<T>>,
163    pub tag: Tag,
164}
165
166impl<T, Tag: IntoAugmentation> WebSocketItemReceiver<T, Tag> {
167    pub fn with_tag_and_receiver(tag: Tag, receiver: oneshot::Receiver<Communication<T>>) -> Self {
168        Self { receiver, tag }
169    }
170}
171
172impl<T> WebSocketItemReceiver<T> {
173    pub fn with_receiver(receiver: oneshot::Receiver<Communication<T>>) -> Self {
174        Self { receiver, tag: () }
175    }
176}
177
178pub type TaggedWebSocketItemReceiver<T> = WebSocketItemReceiver<T, Vec<u8>>;
179
180impl<T, Tag> Future for WebSocketItemReceiver<T, Tag>
181where
182    T: CommunicableType + From<GenericWebSocketError>,
183    Tag: IntoAugmentation + Default,
184    Self: Unpin,
185{
186    type Output = Communication<T>;
187
188    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189        self.receiver.poll_unpin(cx).map(|inner| {
190            let communication = match inner {
191                Ok(response) => response,
192                Err(_) => GenericWebSocketError::Internal.into_communication(),
193            };
194            let tag = std::mem::take(&mut self.tag);
195            communication.with_augmentation(tag.into_augmentation())
196        })
197    }
198}