reachable_node/server/
net.rs

1// SPDX-FileCopyrightText: 2024-2025 eaon <eaon@posteo.net>
2// SPDX-FileCopyrightText: 2024 Michael Goldenberg <m@mgoldenberg.net>
3// SPDX-License-Identifier: EUPL-1.2
4
5use std::{path::PathBuf, sync::Arc};
6
7use axum::response::Response as AxumResponse;
8use redb::{Database, ReadableTable};
9use tokio::sync::RwLock;
10
11use reach_core::{
12    AuditAuthenticationAssurance, RandomFromRng, error,
13    storage::{ReachablePublicKeyRing, Storable},
14    wire::{
15        AttestantVerifyingKeys, AuthenticationAssurance, AuthenticationChallenge, ErrorResponse,
16        Init, InitAuthenticatedSession, ReachCommunication, ReachOk, ReachableVerifyingKeys,
17        Request, Response, Salts,
18    },
19};
20use reach_proc_macros::request_handler_macro;
21use reach_signatures::{Digestible, Verifiable};
22use reach_websocket::{
23    RemoteServerContext, RemoteServerContextExtensions, WebSocketServer, WebSocketUpgrade,
24    macros::*, on_upgrade as core_on_upgrade,
25};
26
27use crate::db::{REACHABLE_PUBLIC_KEYS, REACHABLE_VERIFYING_KEYS};
28
29use super::memory::rng;
30
31pub struct ReachGlobalContext {
32    pub db: Database,
33    pub attestant_verifying_keys: AttestantVerifyingKeys,
34    pub salts: Salts,
35}
36
37impl ReachGlobalContext {
38    pub fn new(db: Database, config_path: PathBuf) -> Result<Self, Error> {
39        let attestant_verifying_keys =
40            AttestantVerifyingKeys::load(&config_path.join("attestant_verifying_keys"))?;
41        let salts = Salts::load(&config_path.join("salts"))?;
42
43        Ok(Self {
44            db,
45            attestant_verifying_keys,
46            salts,
47        })
48    }
49}
50
51pub struct ReachSessionContext {
52    pub(crate) authentication: Authentication,
53}
54
55impl Default for ReachSessionContext {
56    fn default() -> Self {
57        Self {
58            authentication: Authentication::Unauthenticated,
59        }
60    }
61}
62
63impl ReachSessionContext {
64    pub fn new() -> Self {
65        Self::default()
66    }
67}
68
69#[derive(Clone)]
70pub(crate) enum Authentication {
71    Unauthenticated,
72    Challenged(AuthenticationChallenge),
73    AuthenticatedAttestant,
74    Authenticated(Box<ReachableVerifyingKeys>),
75}
76
77impl Authentication {
78    pub fn is_authenticated(&self) -> bool {
79        matches!(
80            self,
81            Authentication::Authenticated(_) | Authentication::AuthenticatedAttestant
82        )
83    }
84}
85
/// Remote-server connection context specialised to the Reach wire
/// `Request`/`Response` types.
pub type ReachConnectionContext = RemoteServerContext<Request, Response>;
/// WebSocket server driving a [`ReachConnectionContext`].
pub type ReachConnection = WebSocketServer<ReachConnectionContext>;
88
/// Errors produced by the request handlers and setup code in this module.
///
/// NOTE(review): the `From<Error> for ErrorResponse` impl in this file maps
/// every variant to `ErrorResponse::Internal`.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The `Init` / `InitAuthenticatedSession` handlers only accept version 0.
    #[error("unsupported protocol version: {0}")]
    UnsupportedProtocolVersion(u8),
    /// Presumably raised when obtaining the RNG via `memory::rng()` fails.
    #[error("rng error: {0}")]
    RngError(#[from] rand::Error),
    /// `reach_core` storage error, presumably from loading configuration in
    /// `ReachGlobalContext::new`.
    #[error("storage error: {0}")]
    StorageError(#[from] error::StorageError),
    /// The request is not allowed in the session's current authentication
    /// state, or no matching verifying keys were found.
    #[error("bad request")]
    BadRequest,
    /// Committing a redb write transaction failed.
    #[error("database commit: {0}")]
    DatabaseCommit(#[from] redb::CommitError),
    /// A redb read/write/iteration operation failed.
    #[error("database storage: {0}")]
    DatabaseStorage(#[from] redb::StorageError),
    /// Beginning a redb transaction failed.
    #[error("database transaction: {0}")]
    DatabaseTransaction(#[from] redb::TransactionError),
    /// Opening a redb table failed.
    #[error("database table: {0}")]
    DatabaseTable(#[from] redb::TableError),
    /// An `AuthenticationAssurance` did not verify against the issued challenge.
    #[error("authentication failed")]
    AuthenticationFailed,
}
110
// Declares `ReachRequest`, the wrapper around the wire `Request` type that the
// delegator and handler macros below operate on (see `reach_websocket::macros`).
request_wrapper!(ReachRequest, Request);

// Declares `ReachDelegator`, presumably dispatching each incoming
// `ReachCommunication` message to the `handle` implementation for the listed
// request types, with access to the global and per-session contexts.
// NOTE(review): the `ReachableVerifyingKeys` and `ReachablePublicKeyRing`
// handlers below are only compiled with the `server` feature while they are
// listed here unconditionally — confirm the crate builds without `server`.
request_delegator!(
    ReachDelegator,
    ReachCommunication,
    ReachGlobalContext,
    ReachSessionContext,
    ReachRequest,
    (
        Init,
        InitAuthenticatedSession,
        AuthenticationAssurance,
        ReachableVerifyingKeys,
        ReachablePublicKeyRing,
    ),
);

// Generates the local `request_handler!` macro used below to define one
// `handle` function per request type, bound to these communication, request
// and context types.
request_handler_macro!(
    request_handler,
    ReachCommunication,
    ReachRequest,
    ReachGlobalContext,
    ReachSessionContext,
);
135
136impl From<Error> for ErrorResponse {
137    fn from(_: Error) -> Self {
138        Self::Internal
139    }
140}
141
/// Upgrades an incoming WebSocket connection into a Reach protocol connection.
///
/// Thin wrapper around `reach_websocket::on_upgrade` that pins its generic
/// parameters to this module's `ReachDelegator`, `ReachCommunication` and
/// `ReachConnectionContext`. The const parameter `B` is forwarded unchanged
/// (NOTE(review): presumably a buffer size — confirm in `reach_websocket`).
///
/// `session_context_init` creates the shared [`ReachSessionContext`] for the
/// connection, e.g. `|| Arc::new(RwLock::new(ReachSessionContext::new()))`.
pub fn on_upgrade<const B: usize>(
    upgrade: WebSocketUpgrade<RemoteServerContextExtensions, ReachDelegator, ReachCommunication>,
    session_context_init: impl FnOnce() -> Arc<RwLock<ReachSessionContext>> + Clone + Send + 'static,
) -> AxumResponse {
    core_on_upgrade::<
        B,
        RemoteServerContextExtensions,
        ReachDelegator,
        ReachCommunication,
        ReachConnectionContext,
    >(upgrade, session_context_init)
}
154
155request_handler!(
156    async fn handle(
157        request: Init,
158        _global_context: _,
159        _session_context: _,
160    ) -> Result<ReachOk, Error> {
161        let version = request.0;
162        println!("Init version: {}", version);
163
164        match version {
165            0 => Ok(ReachOk),
166            version => Err(Error::UnsupportedProtocolVersion(version)),
167        }
168    }
169);
170
171request_handler!(
172    async fn handle(
173        request: InitAuthenticatedSession,
174        _global_context: _,
175        session_context: _,
176    ) -> Result<AuthenticationChallenge, Error> {
177        let version = request.0;
178        println!("Init Authenticated Session version: {}", version);
179
180        match version {
181            0 => {
182                let mut csprng = rng()?;
183
184                let challenge = AuthenticationChallenge::random_from_rng(&mut csprng);
185                let mut ctx = session_context.write().await;
186
187                ctx.authentication = Authentication::Challenged(challenge);
188                Ok(challenge)
189            }
190            version => Err(Error::UnsupportedProtocolVersion(version)),
191        }
192    }
193);
194
195request_handler!(
196    async fn handle(
197        request: AuthenticationAssurance,
198        global_context: _,
199        session_context: _,
200    ) -> Result<ReachOk, Error> {
201        let challenge = {
202            let ctx = session_context.read().await;
203
204            match ctx.authentication {
205                Authentication::Challenged(challenge) => challenge,
206                _ => return Err(Error::BadRequest),
207            }
208        };
209
210        let read_txn = global_context.db.begin_read()?;
211        let table = read_txn.open_table(REACHABLE_VERIFYING_KEYS).ok();
212
213        let authentication = match table.map_or(Ok(None), |t| t.get(request.authenticator_id)) {
214            Ok(Some(rvk_ag)) => {
215                let rvk = rvk_ag.value();
216                rvk.audit_authentication_assurance(&request, &challenge)
217                    .then_some(Authentication::Authenticated(Box::new(rvk)))
218            }
219            Ok(None) => global_context
220                .attestant_verifying_keys
221                .audit_authentication_assurance(&request, &challenge)
222                .then_some(Authentication::AuthenticatedAttestant),
223            Err(_) => Some(Authentication::Unauthenticated),
224        };
225
226        let mut ctx = session_context.write().await;
227        ctx.authentication = authentication
228            .and_then(|a| a.is_authenticated().then_some(a))
229            .ok_or(Error::AuthenticationFailed)?;
230
231        Ok(ReachOk)
232    }
233);
234
// `ReachableVerifyingKeys`: registers a reachable's verifying keys (only
// compiled with the `server` feature).
//
// Only a session in the `AuthenticatedAttestant` state may register keys; any
// other state yields `Error::BadRequest`. The keys (`request.inner`) are
// stored in `REACHABLE_VERIFYING_KEYS` under the request's
// `finalized_digest()`. redb's `insert` replaces an existing entry with the
// same key.
//
// NOTE(review): this handler does not itself call `verify` on the submitted
// keys — confirm the request wrapper already guarantees their integrity.
#[cfg(feature = "server")]
request_handler!(
    async fn handle(
        request: ReachableVerifyingKeys,
        global_context: _,
        session_context: _,
    ) -> Result<ReachOk, Error> {
        // The read guard is a temporary and is released at the end of this
        // statement.
        let is_attestant = matches!(
            session_context.read().await.authentication,
            Authentication::AuthenticatedAttestant
        );
        if !is_attestant {
            return Err(Error::BadRequest);
        }

        let key_id = request.finalized_digest();
        let verifying_keys = request.inner;

        let write_txn = global_context.db.begin_write()?;
        {
            // The table handle must be dropped before the transaction commits.
            let mut verifying_key_table = write_txn.open_table(REACHABLE_VERIFYING_KEYS)?;
            verifying_key_table.insert(&key_id, verifying_keys)?;
        }
        write_txn.commit()?;

        Ok(ReachOk)
    }
);
262
// Evaluates to `true` when every public key in `$reachable_public_keys`
// verifies against `$reachable_verifying_keys`, using each key's `verify`
// method (the `Verifiable` trait is imported at the top of this file).
//
// The collection only needs an `.iter()` method. An empty collection
// vacuously passes. `$reachable_verifying_keys` is re-evaluated for each key,
// so pass a cheap expression such as a reference.
//
// The collection is matched as an `expr` rather than a single `tt`, so field
// accesses and references such as `verify_all!(&request.reachable_public_keys,
// rvk)` are accepted as well as bare identifiers.
macro_rules! verify_all {
    ($reachable_public_keys:expr, $reachable_verifying_keys:expr) => {
        $reachable_public_keys
            .iter()
            .all(|rpk| rpk.verify($reachable_verifying_keys))
    };
}
270
// `ReachablePublicKeyRing`: registers a ring of reachable public keys (only
// compiled with the `server` feature).
//
// Every key in `request.reachable_public_keys` must verify (via `verify_all!`)
// against a single set of reachable verifying keys. The keys are then inserted
// into the `REACHABLE_PUBLIC_KEYS` multimap table under those verifying keys.
// * An `Authenticated(rvk)` session may only register keys that verify against
//   its own verifying keys.
// * An `AuthenticatedAttestant` session scans `REACHABLE_VERIFYING_KEYS` in
//   table order and uses the first stored entry that verifies every key.
//   Entries that fail to read are skipped.
// * Any other session state, or no matching verifying keys, yields
//   `Error::BadRequest`.
//
// NOTE(review): `verify_all!` is vacuously true for an empty ring. An empty
// request therefore succeeds without writing anything (for an attestant, as
// long as at least one stored entry is readable).
#[cfg(feature = "server")]
request_handler!(
    async fn handle(
        request: ReachablePublicKeyRing,
        global_context: _,
        session_context: _,
    ) -> Result<ReachOk, Error> {
        let public_keys = &request.reachable_public_keys;

        // The session read lock is held for the whole key lookup.
        let owner_keys = {
            let session = session_context.read().await;

            match &session.authentication {
                Authentication::Authenticated(own_keys) => {
                    if verify_all!(public_keys, own_keys.as_ref()) {
                        Some(ReachableVerifyingKeys::clone(own_keys))
                    } else {
                        None
                    }
                }
                Authentication::AuthenticatedAttestant => {
                    let read_txn = global_context.db.begin_read()?;
                    let stored_keys = read_txn.open_table(REACHABLE_VERIFYING_KEYS)?;
                    stored_keys.iter()?.find_map(|entry| {
                        let (_, guard) = entry.ok()?;
                        let candidate = guard.value();
                        verify_all!(public_keys, &candidate).then_some(candidate)
                    })
                }
                _ => return Err(Error::BadRequest),
            }
        };

        let Some(owner_keys) = owner_keys else {
            return Err(Error::BadRequest);
        };

        let write_txn = global_context.db.begin_write()?;
        {
            // The table handle must be dropped before the transaction commits.
            let mut ring_table = write_txn.open_multimap_table(REACHABLE_PUBLIC_KEYS)?;
            for public_key in public_keys {
                ring_table.insert(&owner_keys, public_key)?;
            }
        }
        write_txn.commit()?;

        Ok(ReachOk)
    }
);
314
/// States of the session protocol, as validated by
/// `SessionState::check_transition`.
///
/// NOTE(review): the request handlers in this file track authentication via
/// `Authentication` and do not reference this type — confirm where the state
/// machine is driven.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Fresh session; expects `Init` or `InitAuthenticatedSession`.
    Init,
    /// After an accepted `Init`; envelope and message-vault requests are allowed.
    Ready,
    /// After `AddMessageVault`; expects `AddEnvelope`.
    PendingAddEnvelope,
    /// After `InitAuthenticatedSession`; expects `AuthenticationAssurance`.
    PendingAuthentication,
    /// Reached from `PendingAuthentication` after an accepted assurance
    /// (presumably as an attestant). No outgoing transitions are listed.
    AuthenticatedAttestant,
    /// Reached from `PendingAuthentication` after an accepted assurance
    /// (presumably as a registered reachable). No outgoing transitions are listed.
    AuthenticatedReachable,
    /// Terminal: `check_transition` permits no transitions out of this state.
    Done,
}
325
326impl SessionState {
327    pub fn check_transition(
328        &mut self,
329        to: Self,
330        input: Request,
331        output: Response,
332    ) -> Result<Self, error::WireError> {
333        use Request as I;
334        use Response as O;
335        use SessionState as S;
336
337        match (*self, input, output, to) {
338            (S::Init, I::Init, O::Reach, S::Ready) => {}
339            (S::Ready, I::EnvelopeId, O::Envelope | O::ErrorNotFound, S::Ready) => {}
340            (S::Ready, I::MessageVaultId, O::MessageVault | O::ErrorNotFound, S::Ready) => {}
341            (S::Ready, I::RemoveEnvelopeIdHint, O::Ok | O::ErrorNotFound, S::Done) => {}
342            (S::Init, I::InitAuthenticatedSession, O::Ok, S::PendingAuthentication) => {}
343            (
344                S::PendingAuthentication,
345                I::AuthenticationAssurance,
346                O::Ok,
347                S::AuthenticatedAttestant | S::AuthenticatedReachable,
348            ) => {}
349            (
350                S::PendingAuthentication,
351                I::AuthenticationAssurance,
352                O::ErrorVerification,
353                S::PendingAuthentication,
354            ) => {}
355            (S::Init, I::Init | I::InitAuthenticatedSession, O::ErrorUnsupported, S::Init) => {}
356            (S::Ready, I::AddMessageVault, O::SealedMessageVaultId, S::PendingAddEnvelope) => {}
357            (S::PendingAddEnvelope, I::AddEnvelope, O::SealedEnvelopeId, S::Done) => {}
358            _ => return Err(error::WireError::Unsupported),
359        };
360
361        Ok(to)
362    }
363}