Skip to main content

ecdh_omr/
take_the.rs

1// SPDX-FileCopyrightText: 2024-2026 eaon <eaon@posteo.net>
2// SPDX-License-Identifier: EUPL-1.2
3
4use aead::{Aead, KeyInit, Payload};
5#[cfg(feature = "rustcrypto-ec")]
6use elliptic_curve::{
7    Curve, CurveArithmetic,
8    point::PointCompression,
9    sec1::{CompressedPoint, FromSec1Point, ModulusSize, ToSec1Point},
10};
11
12#[cfg(feature = "dalek-x25519")]
13use crate::DalekX25519;
14#[cfg(feature = "rustcrypto-ec")]
15use crate::EllipticCurve;
16#[cfg(feature = "dalek-ristretto255")]
17use crate::{DalekRistretto255, dalek_ristretto255};
18use crate::{error::*, *};
19
20/// Decrypt [`Hint`] and [`Hints`]. Also a silly pun.
21pub trait TakeTheHint<K: KeyPair> {
22    /// Trial decryption of an individual [`Hint`].
23    fn take_the<A: Aead + KeyInit, L: ArraySize>(
24        &self,
25        hint: &Hint<K, A, L>,
26        context: &[u8],
27    ) -> Result<Array<u8, L>>
28    where
29        Hint<K, A, L>: HintSized<K, A, L>;
30
31    /// Trial decryption for a batch of [`Hints`].
32    fn take_all_the<A: Aead + KeyInit, L: ArraySize, const S: usize>(
33        &self,
34        hints: &Hints<Hint<K, A, L>, S>,
35        context: &[u8],
36    ) -> Vec<Array<u8, L>>
37    where
38        Hint<K, A, L>: HintSized<K, A, L> + Hinting<K, L>,
39    {
40        hints
41            .as_slice()
42            .iter()
43            .filter_map(|hint| self.take_the(hint, context).ok())
44            .collect()
45    }
46}
47
48fn decrypt<A: Aead + KeyInit, L: ArraySize>(
49    raw_shared_secret: impl AsRef<[u8]>,
50    blinded_blinding_factor: impl AsRef<[u8]>,
51    context: &[u8],
52    ciphertext: &[u8],
53) -> Result<Array<u8, L>> {
54    let shared_secret = shared_secret(
55        raw_shared_secret.as_ref(),
56        blinded_blinding_factor.as_ref(),
57        context,
58    );
59
60    let cipher = cipher_from_shared_secret::<A>(&shared_secret)?;
61    let nonce = nonce::<A>(blinded_blinding_factor.as_ref())?;
62
63    cipher
64        .decrypt(
65            &nonce,
66            Payload {
67                msg: ciphertext,
68                aad: blinded_blinding_factor.as_ref(),
69            },
70        )
71        .map(|m| m.into_iter().collect())
72        .map_err(Error::from)
73}
74
/// Implements [`TakeTheHint`] for a Dalek-style static secret.
///
/// `$curve` is the [`KeyPair`] type the impl is for, and `$secret` is the static secret type.
/// The generated code requires `$secret` to provide `diffie_hellman(&point)` returning a value
/// with `to_bytes()`. It also requires the hint's blinded blinding factor to provide
/// `to_bytes()`.
#[cfg(any(feature = "dalek-ristretto255", feature = "dalek-x25519"))]
macro_rules! dalek_take_the_hint {
    ($curve:ty, $secret:path) => {
        impl TakeTheHint<$curve> for $secret {
            fn take_the<A: Aead + KeyInit, L: ArraySize>(
                &self,
                hint: &Hint<$curve, A, L>,
                context: &[u8],
            ) -> Result<Array<u8, L>>
            where
                Hint<$curve, A, L>: HintSized<$curve, A, L>,
            {
                let bbf = &hint.blinded_blinding_factor;

                // Diffie-Hellman between our static secret and the hint's point.
                let dh_output = self.diffie_hellman(bbf);
                let dh_bytes = dh_output.to_bytes();
                let bbf_bytes = bbf.to_bytes();

                decrypt::<A, L>(dh_bytes, &bbf_bytes, context, hint.ciphertext.as_slice())
            }
        }
    };
}

// Ristretto255 backend.
#[cfg(feature = "dalek-ristretto255")]
dalek_take_the_hint! { DalekRistretto255, dalek_ristretto255::StaticSecret }

// X25519 backend.
#[cfg(feature = "dalek-x25519")]
dalek_take_the_hint! { DalekX25519, x25519_dalek::StaticSecret }
105
/// Trial decryption of hints addressed to RustCrypto `elliptic_curve` secret keys.
///
/// The hint's blinded blinding factor is used in two ways. It is the peer point for
/// Diffie-Hellman. Its SEC1 compressed encoding is the byte form handed to `decrypt`.
#[cfg(feature = "rustcrypto-ec")]
impl<C: CurveArithmetic> TakeTheHint<EllipticCurve<C>> for elliptic_curve::SecretKey<C>
where
    C: CurveArithmetic + PointCompression,
    <C as Curve>::FieldBytesSize: ModulusSize,
    <C as CurveArithmetic>::AffinePoint: ToSec1Point<C> + FromSec1Point<C>,
{
    fn take_the<A: Aead + KeyInit, L: ArraySize>(
        &self,
        hint: &Hint<EllipticCurve<C>, A, L>,
        context: &[u8],
    ) -> Result<Array<u8, L>>
    where
        Hint<EllipticCurve<C>, A, L>: HintSized<EllipticCurve<C>, A, L>,
    {
        let bbf = &hint.blinded_blinding_factor;

        // Byte encoding of the point, as consumed by `decrypt`.
        let encoded_bbf = CompressedPoint::<C>::from(bbf);

        let shared = elliptic_curve::ecdh::diffie_hellman(self.to_nonzero_scalar(), bbf.as_affine());

        decrypt::<A, L>(
            shared.raw_secret_bytes(),
            encoded_bbf,
            context,
            hint.ciphertext.as_slice(),
        )
    }
}