1use std::{path::PathBuf, sync::Arc};
6
7use axum::response::Response as AxumResponse;
8use redb::{Database, ReadableTable};
9use tokio::sync::RwLock;
10
11use reach_core::{
12 AuditAuthenticationAssurance, RandomFromRng, error,
13 storage::{ReachablePublicKeyRing, Storable},
14 wire::{
15 AttestantVerifyingKeys, AuthenticationAssurance, AuthenticationChallenge, ErrorResponse,
16 Init, InitAuthenticatedSession, ReachCommunication, ReachOk, ReachableVerifyingKeys,
17 Request, Response, Salts,
18 },
19};
20use reach_proc_macros::request_handler_macro;
21use reach_signatures::{Digestible, Verifiable};
22use reach_websocket::{
23 RemoteServerContext, RemoteServerContextExtensions, WebSocketServer, WebSocketUpgrade,
24 macros::*, on_upgrade as core_on_upgrade,
25};
26
27use crate::db::{REACHABLE_PUBLIC_KEYS, REACHABLE_VERIFYING_KEYS};
28
29use super::memory::rng;
30
31pub struct ReachGlobalContext {
32 pub db: Database,
33 pub attestant_verifying_keys: AttestantVerifyingKeys,
34 pub salts: Salts,
35}
36
37impl ReachGlobalContext {
38 pub fn new(db: Database, config_path: PathBuf) -> Result<Self, Error> {
39 let attestant_verifying_keys =
40 AttestantVerifyingKeys::load(&config_path.join("attestant_verifying_keys"))?;
41 let salts = Salts::load(&config_path.join("salts"))?;
42
43 Ok(Self {
44 db,
45 attestant_verifying_keys,
46 salts,
47 })
48 }
49}
50
51pub struct ReachSessionContext {
52 pub(crate) authentication: Authentication,
53}
54
55impl Default for ReachSessionContext {
56 fn default() -> Self {
57 Self {
58 authentication: Authentication::Unauthenticated,
59 }
60 }
61}
62
63impl ReachSessionContext {
64 pub fn new() -> Self {
65 Self::default()
66 }
67}
68
69#[derive(Clone)]
70pub(crate) enum Authentication {
71 Unauthenticated,
72 Challenged(AuthenticationChallenge),
73 AuthenticatedAttestant,
74 Authenticated(Box<ReachableVerifyingKeys>),
75}
76
77impl Authentication {
78 pub fn is_authenticated(&self) -> bool {
79 matches!(
80 self,
81 Authentication::Authenticated(_) | Authentication::AuthenticatedAttestant
82 )
83 }
84}
85
86pub type ReachConnectionContext = RemoteServerContext<Request, Response>;
87pub type ReachConnection = WebSocketServer<ReachConnectionContext>;
88
89#[derive(thiserror::Error, Debug)]
90pub enum Error {
91 #[error("unsupported protocol version: {0}")]
92 UnsupportedProtocolVersion(u8),
93 #[error("rng error: {0}")]
94 RngError(#[from] rand::Error),
95 #[error("storage error: {0}")]
96 StorageError(#[from] error::StorageError),
97 #[error("bad request")]
98 BadRequest,
99 #[error("database commit: {0}")]
100 DatabaseCommit(#[from] redb::CommitError),
101 #[error("database storage: {0}")]
102 DatabaseStorage(#[from] redb::StorageError),
103 #[error("database transaction: {0}")]
104 DatabaseTransaction(#[from] redb::TransactionError),
105 #[error("database table: {0}")]
106 DatabaseTable(#[from] redb::TableError),
107 #[error("authentication failed")]
108 AuthenticationFailed,
109}
110
111request_wrapper!(ReachRequest, Request);
112
113request_delegator!(
114 ReachDelegator,
115 ReachCommunication,
116 ReachGlobalContext,
117 ReachSessionContext,
118 ReachRequest,
119 (
120 Init,
121 InitAuthenticatedSession,
122 AuthenticationAssurance,
123 ReachableVerifyingKeys,
124 ReachablePublicKeyRing,
125 ),
126);
127
128request_handler_macro!(
129 request_handler,
130 ReachCommunication,
131 ReachRequest,
132 ReachGlobalContext,
133 ReachSessionContext,
134);
135
136impl From<Error> for ErrorResponse {
137 fn from(_: Error) -> Self {
138 Self::Internal
139 }
140}
141
142pub fn on_upgrade<const B: usize>(
143 upgrade: WebSocketUpgrade<RemoteServerContextExtensions, ReachDelegator, ReachCommunication>,
144 session_context_init: impl FnOnce() -> Arc<RwLock<ReachSessionContext>> + Clone + Send + 'static,
145) -> AxumResponse {
146 core_on_upgrade::<
147 B,
148 RemoteServerContextExtensions,
149 ReachDelegator,
150 ReachCommunication,
151 ReachConnectionContext,
152 >(upgrade, session_context_init)
153}
154
155request_handler!(
156 async fn handle(
157 request: Init,
158 _global_context: _,
159 _session_context: _,
160 ) -> Result<ReachOk, Error> {
161 let version = request.0;
162 println!("Init version: {}", version);
163
164 match version {
165 0 => Ok(ReachOk),
166 version => Err(Error::UnsupportedProtocolVersion(version)),
167 }
168 }
169);
170
171request_handler!(
172 async fn handle(
173 request: InitAuthenticatedSession,
174 _global_context: _,
175 session_context: _,
176 ) -> Result<AuthenticationChallenge, Error> {
177 let version = request.0;
178 println!("Init Authenticated Session version: {}", version);
179
180 match version {
181 0 => {
182 let mut csprng = rng()?;
183
184 let challenge = AuthenticationChallenge::random_from_rng(&mut csprng);
185 let mut ctx = session_context.write().await;
186
187 ctx.authentication = Authentication::Challenged(challenge);
188 Ok(challenge)
189 }
190 version => Err(Error::UnsupportedProtocolVersion(version)),
191 }
192 }
193);
194
195request_handler!(
196 async fn handle(
197 request: AuthenticationAssurance,
198 global_context: _,
199 session_context: _,
200 ) -> Result<ReachOk, Error> {
201 let challenge = {
202 let ctx = session_context.read().await;
203
204 match ctx.authentication {
205 Authentication::Challenged(challenge) => challenge,
206 _ => return Err(Error::BadRequest),
207 }
208 };
209
210 let read_txn = global_context.db.begin_read()?;
211 let table = read_txn.open_table(REACHABLE_VERIFYING_KEYS).ok();
212
213 let authentication = match table.map_or(Ok(None), |t| t.get(request.authenticator_id)) {
214 Ok(Some(rvk_ag)) => {
215 let rvk = rvk_ag.value();
216 rvk.audit_authentication_assurance(&request, &challenge)
217 .then_some(Authentication::Authenticated(Box::new(rvk)))
218 }
219 Ok(None) => global_context
220 .attestant_verifying_keys
221 .audit_authentication_assurance(&request, &challenge)
222 .then_some(Authentication::AuthenticatedAttestant),
223 Err(_) => Some(Authentication::Unauthenticated),
224 };
225
226 let mut ctx = session_context.write().await;
227 ctx.authentication = authentication
228 .and_then(|a| a.is_authenticated().then_some(a))
229 .ok_or(Error::AuthenticationFailed)?;
230
231 Ok(ReachOk)
232 }
233);
234
// Handles `ReachableVerifyingKeys` (only with the `server` feature): stores the
// submitted verifying keys under their finalized digest.
//
// Only a session in the `AuthenticatedAttestant` state may do this; every other
// state gets `Error::BadRequest`. Database failures are propagated.
#[cfg(feature = "server")]
request_handler!(
    async fn handle(
        request: ReachableVerifyingKeys,
        global_context: _,
        session_context: _,
    ) -> Result<ReachOk, Error> {
        // The read guard is a temporary and is released at the end of this statement.
        let is_attestant = matches!(
            session_context.read().await.authentication,
            Authentication::AuthenticatedAttestant
        );
        if !is_attestant {
            return Err(Error::BadRequest);
        }

        // The digest is taken before `request.inner` is moved out.
        let reachable_id = request.finalized_digest();
        let verifying_keys = request.inner;

        let write_txn = global_context.db.begin_write()?;
        {
            // Scoped so the table is dropped before the commit.
            let mut verifying_keys_table = write_txn.open_table(REACHABLE_VERIFYING_KEYS)?;
            verifying_keys_table.insert(&reachable_id, verifying_keys)?;
        }
        write_txn.commit()?;

        Ok(ReachOk)
    }
);
262
// Evaluates to `true` when every item produced by `$keys.iter()` passes
// `verify` against `$verifier`. An empty collection therefore yields `true`.
//
// `$keys` must be a single token tree (for example an identifier).
macro_rules! verify_all {
    ($keys:tt, $verifier:expr) => {
        $keys.iter().all(|key| key.verify($verifier))
    };
}
270
// Handles `ReachablePublicKeyRing` (only with the `server` feature): stores the
// submitted public keys under the verifying keys that vouch for all of them.
//
// * `Authenticated` session: the ring must verify against the session's own
//   verifying keys.
// * `AuthenticatedAttestant` session: the stored verifying keys are scanned and
//   the first entry that verifies the whole ring becomes the owner.
// * Any other session state, or no verifying keys found: `Error::BadRequest`.
//
// Database failures, including those hit while scanning, are propagated
// instead of being skipped.
//
// NOTE: an empty ring verifies vacuously (`all` over no items is `true`); it
// then commits a write transaction that inserts nothing.
#[cfg(feature = "server")]
request_handler!(
    async fn handle(
        request: ReachablePublicKeyRing,
        global_context: _,
        session_context: _,
    ) -> Result<ReachOk, Error> {
        let rpks = &request.reachable_public_keys;

        // Snapshot the state so the session lock is released before any
        // database work; the guard is a temporary dropped with this statement.
        let authentication = session_context.read().await.authentication.clone();

        let owner = match authentication {
            Authentication::Authenticated(rvk) => {
                verify_all!(rpks, rvk.as_ref()).then_some(*rvk)
            }
            Authentication::AuthenticatedAttestant => {
                let read_txn = global_context.db.begin_read()?;
                let table = read_txn.open_table(REACHABLE_VERIFYING_KEYS)?;

                let mut matching = None;
                for entry in table.iter()? {
                    // A storage error aborts the request rather than hiding a record.
                    let (_, stored) = entry?;
                    let rvk = stored.value();
                    if verify_all!(rpks, &rvk) {
                        matching = Some(rvk);
                        break;
                    }
                }
                matching
            }
            Authentication::Unauthenticated | Authentication::Challenged(_) => {
                return Err(Error::BadRequest);
            }
        };

        let rvk = owner.ok_or(Error::BadRequest)?;

        let write_txn = global_context.db.begin_write()?;
        {
            // Scoped so the table is dropped before the commit.
            let mut table = write_txn.open_multimap_table(REACHABLE_PUBLIC_KEYS)?;
            for rpk in rpks {
                table.insert(&rvk, rpk)?;
            }
        }
        write_txn.commit()?;

        Ok(ReachOk)
    }
);
314
/// States of the session state machine validated by `check_transition`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Starting state, before any init request has been accepted.
    Init,
    /// Reached after a successful `Init`; lookups and additions start here.
    Ready,
    /// A message vault was added and the matching envelope is still expected.
    PendingAddEnvelope,
    /// An authenticated session was initiated and the assurance is awaited.
    PendingAuthentication,
    /// Authentication completed as an attestant.
    AuthenticatedAttestant,
    /// Authentication completed as a reachable.
    AuthenticatedReachable,
    /// Terminal state; no transition out of it is accepted.
    Done,
}
325
326impl SessionState {
327 pub fn check_transition(
328 &mut self,
329 to: Self,
330 input: Request,
331 output: Response,
332 ) -> Result<Self, error::WireError> {
333 use Request as I;
334 use Response as O;
335 use SessionState as S;
336
337 match (*self, input, output, to) {
338 (S::Init, I::Init, O::Reach, S::Ready) => {}
339 (S::Ready, I::EnvelopeId, O::Envelope | O::ErrorNotFound, S::Ready) => {}
340 (S::Ready, I::MessageVaultId, O::MessageVault | O::ErrorNotFound, S::Ready) => {}
341 (S::Ready, I::RemoveEnvelopeIdHint, O::Ok | O::ErrorNotFound, S::Done) => {}
342 (S::Init, I::InitAuthenticatedSession, O::Ok, S::PendingAuthentication) => {}
343 (
344 S::PendingAuthentication,
345 I::AuthenticationAssurance,
346 O::Ok,
347 S::AuthenticatedAttestant | S::AuthenticatedReachable,
348 ) => {}
349 (
350 S::PendingAuthentication,
351 I::AuthenticationAssurance,
352 O::ErrorVerification,
353 S::PendingAuthentication,
354 ) => {}
355 (S::Init, I::Init | I::InitAuthenticatedSession, O::ErrorUnsupported, S::Init) => {}
356 (S::Ready, I::AddMessageVault, O::SealedMessageVaultId, S::PendingAddEnvelope) => {}
357 (S::PendingAddEnvelope, I::AddEnvelope, O::SealedEnvelopeId, S::Done) => {}
358 _ => return Err(error::WireError::Unsupported),
359 };
360
361 Ok(to)
362 }
363}