ecdh_omr/
take_the.rs

1// SPDX-FileCopyrightText: 2024-2025 eaon <eaon@posteo.net>
2// SPDX-License-Identifier: EUPL-1.2
3
4use aead::{Aead, AeadCore, KeyInit, generic_array::typenum::marker_traits::Unsigned};
5#[cfg(feature = "rustcrypto-ec")]
6use elliptic_curve::{
7    Curve, CurveArithmetic,
8    point::PointCompression,
9    sec1::{CompressedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint},
10};
11use sha3::{Digest, Sha3_256};
12
13use crate::{
14    Hint, Hinting, Hints, cipher_from_shared_secret, curves,
15    curves::{KeyPair, sealed},
16    error::*,
17};
18#[cfg(feature = "rustcrypto-ec")]
19use curves::{EllipticCurve, rcec};
20#[cfg(feature = "dalek")]
21use curves::{X25519, dalek};
22
23/// Decrypt [`Hint`] and [`Hints`]. Also a silly pun.
24pub trait TakeTheHint<K: KeyPair> {
25    /// Trial decryption of an individual [`Hint`].
26    fn take_the<A: Aead + KeyInit, const L: usize>(
27        &self,
28        hint: &Hint<K, A, L>,
29        salt: &[u8],
30    ) -> Result<[u8; L]>;
31
32    /// Trial decryption for a batch of [`Hints`].
33    fn take_all_the<A: Aead + KeyInit, const L: usize, const S: usize>(
34        &self,
35        hints: &Hints<Hint<K, A, L>, S>,
36        salt: &[u8],
37    ) -> Vec<[u8; L]>
38    where
39        Hint<K, A, L>: Hinting<K, L>,
40        K::SecretKey: sealed::RandomSecretKey,
41    {
42        hints
43            .as_slice()
44            .iter()
45            .filter_map(|hint| self.take_the(hint, salt).ok())
46            .collect()
47    }
48}
49
50fn decrypt<A: Aead + KeyInit>(
51    nonce: &[u8],
52    shared_secret: impl AsRef<[u8]>,
53    ciphertext: &[u8],
54) -> Result<Vec<u8>> {
55    let cipher: A = cipher_from_shared_secret(shared_secret);
56    let nonce_size = <A as AeadCore>::NonceSize::to_usize();
57    let nonce = aead::Nonce::<A>::from_slice(&nonce[..nonce_size]);
58
59    Ok(cipher.decrypt(nonce, ciphertext)?)
60}
61
#[cfg(feature = "dalek")]
impl TakeTheHint<X25519> for dalek::StaticSecret {
    /// Trial-decrypts one X25519 [`Hint`] with this static secret.
    ///
    /// The AEAD key material is `SHA3-256(DH(self, bbf) || bbf || salt)`, where `bbf`
    /// is the hint's blinded blinding factor. The leading bytes of `bbf` also serve as
    /// the AEAD nonce.
    ///
    /// # Errors
    ///
    /// Propagates any error from decryption (e.g. the hint is not addressed to this
    /// key), and returns [`Error::MessageLength`] if the plaintext is not exactly `L`
    /// bytes long.
    fn take_the<A: Aead + KeyInit, const L: usize>(
        &self,
        hint: &Hint<X25519, A, L>,
        salt: &[u8],
    ) -> Result<[u8; L]> {
        let bbf = &hint.blinded_blinding_factor;
        // NOTE(review): the DH output is used without checking whether it was
        // contributory (low-order points are not rejected here); confirm this is intended.
        let dh = self.diffie_hellman(bbf);

        let key_material = <Sha3_256 as Digest>::new()
            .chain_update(dh.as_bytes())
            .chain_update(bbf.as_bytes())
            .chain_update(salt)
            .finalize();

        let plaintext = decrypt::<A>(bbf.as_bytes(), key_material, hint.ciphertext.as_slice())?;
        <[u8; L]>::try_from(plaintext).map_err(|_| Error::MessageLength)
    }
}
89
#[cfg(feature = "rustcrypto-ec")]
impl<C: CurveArithmetic> TakeTheHint<EllipticCurve<C>> for rcec::SecretKey<C>
where
    C: CurveArithmetic + PointCompression,
    <C as Curve>::FieldBytesSize: ModulusSize,
    <C as CurveArithmetic>::AffinePoint: ToEncodedPoint<C> + FromEncodedPoint<C>,
{
    /// Trial-decrypts one [`Hint`] on a RustCrypto elliptic curve with this secret key.
    ///
    /// The AEAD key material is `SHA3-256(ECDH(self, bbf) || compressed(bbf) || salt)`,
    /// where `bbf` is the hint's blinded blinding factor and `compressed(bbf)` is its SEC1
    /// compressed encoding. The leading bytes of that encoding also serve as the AEAD
    /// nonce.
    ///
    /// # Errors
    ///
    /// Propagates any error from decryption (e.g. the hint is not addressed to this
    /// key), and returns [`Error::MessageLength`] if the plaintext is not exactly `L`
    /// bytes long.
    fn take_the<A: Aead + KeyInit, const L: usize>(
        &self,
        hint: &Hint<EllipticCurve<C>, A, L>,
        salt: &[u8],
    ) -> Result<[u8; L]> {
        let bbf_compressed = CompressedPoint::<C>::from(hint.blinded_blinding_factor);
        let dh = elliptic_curve::ecdh::diffie_hellman(
            self.to_nonzero_scalar(),
            hint.blinded_blinding_factor.as_affine(),
        );

        let key_material = <Sha3_256 as Digest>::new()
            .chain_update(dh.raw_secret_bytes())
            .chain_update(bbf_compressed.as_slice())
            .chain_update(salt)
            .finalize();

        let plaintext = decrypt::<A>(
            bbf_compressed.as_slice(),
            key_material,
            hint.ciphertext.as_slice(),
        )?;
        <[u8; L]>::try_from(plaintext).map_err(|_| Error::MessageLength)
    }
}