reach_proc_macros/
lib.rs

1// SPDX-FileCopyrightText: 2023—2025 eaon <eaon@posteo.net>
2// SPDX-FileCopyrightText: 2023 Sam Schlinkert <sschlinkert@gmail.com>
3// SPDX-License-Identifier: EUPL-1.2
4
5use std::{collections::HashMap, str::FromStr, sync::LazyLock};
6
7use proc_macro2::{Ident, Span, TokenStream, TokenTree};
8use quote::{quote, quote_spanned};
9use syn::{Data, DeriveInput, GenericArgument, PathArguments, Type, parse_macro_input};
10
11type ConvFn = fn(&Ident, &Ident) -> TokenStream;
12
13static CONV_FNS: LazyLock<HashMap<String, (ConvFn, ConvFn, ConvFn)>> = LazyLock::new(|| {
14    let map: Vec<(&str, ConvFn)> = vec![
15        ("nop", |_, _| quote! {}),
16        ("to_vec", |n, _| quote! { from.#n.to_vec() }),
17        ("inner to_vec", |n, _| quote! { from.#n.inner.to_vec() }),
18        ("as_bytes", |n, _| quote! { from.#n.as_bytes().to_vec() }),
19        ("to_bytes", |n, _| quote! { from.#n.to_bytes().to_vec() }),
20        ("move", |n, _| quote! { from.#n }),
21        (
22            "zeroizing",
23            |n, _| quote! { zeroize::Zeroizing::new(from.#n.clone()) },
24        ),
25        ("zeroizing deref clone", |n, t| {
26            quote! { <zeroize::Zeroizing<#t> as std::ops::Deref>::deref(&from.#n).clone() }
27        }),
28        ("clone", |n, _| quote! { from.#n.clone() }),
29        ("as u32", |n, _| quote! { from.#n as u32 }),
30        ("direct", |n, t| quote! { #t::try_from(from.#n)? }),
31        ("slice", |n, t| quote! { #t::try_from(from.#n.as_slice())? }),
32        (
33            "try_from slice",
34            |n, t| quote! { #t::try_from(from.#n.as_slice())? },
35        ),
36        (
37            "from_bytes",
38            |n, t| quote! { #t::from_bytes(from.#n.as_slice().try_into()?) },
39        ),
40        (
41            "from_bytes?",
42            |n, t| quote! { #t::from_bytes(from.#n.as_slice())? },
43        ),
44        (
45            "via [u8; 24]",
46            |n, t| quote! { #t::from(<[u8; 24]>::try_from(from.#n.as_slice())?) },
47        ),
48        (
49            "via [u8; 32]",
50            |n, t| quote! { #t::from(<[u8; 32]>::try_from(from.#n.as_slice())?) },
51        ),
52        (
53            "via &[u8; 64]",
54            |n, t| quote! { #t::from_bytes(&<[u8; 64]>::try_from(from.#n.as_slice())?) },
55        ),
56        ("Option Box try", |n, _| {
57            quote! {
58                match from.#n {
59                    Some(inner) => Some(Box::new(inner.try_into()?)),
60                    None => None,
61                }
62            }
63        }),
64        (
65            "Option Box as_ref",
66            |n, t| quote! { from.#n.as_ref().map(|inner| proto::#t::from(inner.deref())) },
67        ),
68        ("Option try", |n, _| {
69            quote! {
70                match from.#n {
71                    Some(inner) => Some(inner.try_into()?),
72                    None => None,
73                }
74            }
75        }),
76        (
77            "Option as_ref",
78            |n, t| quote! { from.#n.as_ref().map(proto::#t::from) },
79        ),
80        (
81            "Option map from",
82            |n, t| quote! { from.#n.map(proto::#t::from) },
83        ),
84        ("Option Result", |n, t| {
85            let core_path = core_path(n.span());
86            quote! { #t::try_from(from.#n.ok_or(#core_path::error::DecodeError)?)? }
87        }),
88        (
89            "Option Some",
90            |n, t| quote! { Some(proto::#t::from(&from.#n)) },
91        ),
92        (
93            "Option Some Owned",
94            |n, t| quote! { Some(proto::#t::from(from.#n)) },
95        ),
96        (
97            "into_iter try_from",
98            |n, t| quote! { from.#n.into_iter().map(#t::try_from).collect::<Result<_, _>>()? },
99        ),
100        ("into_iter from_bytes", |n, _| {
101            quote! { from.#n.into_iter().map(|item| {
102                item.as_slice().try_into()
103            }).collect::<Result<_, _>>()? }
104        }),
105        (
106            "iter",
107            |n, t| quote! { from.#n.iter().map(proto::#t::from).collect() },
108        ),
109        (
110            "into_iter from",
111            |n, t| quote! { from.#n.into_iter().map(proto::#t::from).collect() },
112        ),
113        (
114            "iter to_bytes BlindedPublicKey",
115            |n, _| quote! { from.#n.iter().map(|item| item.to_bytes().to_vec()).collect() },
116        ),
117        (
118            "proto decode",
119            |n, t| quote! { #t::decode(from.#n.as_slice())? },
120        ),
121        ("proto encode", |n, _| quote! { from.#n.encode_to_vec() }),
122        ("into_iter proto decode", |n, t| {
123            quote! {
124                from.#n.into_iter().map(#t::decode).collect::<Result<_, _>>()?
125            }
126        }),
127        (
128            "iter proto encode",
129            |n, _| quote! { from.#n.iter().map(|item| item.encode_to_vec()).collect() },
130        ),
131    ];
132
133    let map = map.into_iter().collect::<HashMap<&'static str, ConvFn>>();
134
135    [
136        ("BlindedPublicKey", "from_bytes", "to_bytes", "nop"),
137        ("Ed25519Signature", "via &[u8; 64]", "to_bytes", "to_bytes"),
138        ("Ed25519Signing", "slice", "to_bytes", "to_bytes"),
139        ("Ed25519Verifying", "slice", "as_bytes", "as_bytes"),
140        ("FnDsaVerifying", "slice", "to_vec", "to_vec"),
141        ("FnDsaSignature", "slice", "to_vec", "to_vec"),
142        ("FnDsaSigning", "slice", "inner to_vec", "inner to_vec"),
143        ("EnvelopeIdHints", "from_bytes?", "to_bytes", "to_bytes"),
144        ("MlKemCiphertext", "try_from slice", "to_vec", "to_vec"),
145        ("MlKemPublic", "from_bytes", "as_bytes", "as_bytes"),
146        ("MlKemSecret", "from_bytes", "as_bytes", "as_bytes"),
147        ("X25519Public", "via [u8; 32]", "as_bytes", "as_bytes"),
148        ("X25519Secret", "via [u8; 32]", "as_bytes", "as_bytes"),
149        ("XChaChaKey", "via [u8; 32]", "to_vec", "to_vec"),
150        ("XNonce", "via [u8; 24]", "to_vec", "to_vec"),
151        ("SixFour", "slice", "to_vec", "to_vec"),
152        ("String", "move", "clone", "move"),
153        ("ThreeTwo", "slice", "to_vec", "to_vec"),
154        ("TwoFour", "slice", "to_vec", "to_vec"),
155        ("OneSix", "slice", "to_vec", "to_vec"),
156        ("u16", "direct", "as u32", "as u32"),
157        ("Vec<u8>", "move", "clone", "move"),
158        (
159            "Vec<BlindedPublicKey>",
160            "into_iter from_bytes",
161            "iter to_bytes BlindedPublicKey",
162            "nop",
163        ),
164        ("Vec<_>", "into_iter try_from", "iter", "into_iter from"),
165        (
166            "Option<MessageMetadata>",
167            "Option try",
168            "nop",
169            "Option map from",
170        ),
171        (
172            "Option<Box<_>>",
173            "Option Box try",
174            "Option Box as_ref",
175            "nop",
176        ),
177        ("Option<T>", "move", "clone", "move"),
178        ("Option<_>", "Option try", "Option as_ref", "Option as_ref"),
179        ("_", "Option Result", "Option Some", "Option Some Owned"),
180        ("as ProstMessage", "proto decode", "proto encode", "nop"),
181        (
182            "Vec<as ProstMessage>",
183            "into_iter proto decode",
184            "iter proto encode",
185            "nop",
186        ),
187        (
188            "Zeroizing<_>",
189            "zeroizing",
190            "zeroizing deref clone",
191            "zeroizing",
192        ),
193    ]
194    .into_iter()
195    .map(|(key, de, en, eno)| (key.to_string(), (map[de], map[en], map[eno])))
196    .collect()
197});
198
199#[derive(Debug)]
200struct LocalField<'a> {
201    field_name: &'a Ident,
202    field_type: &'a Ident,
203    field_inner_types: Vec<&'a Ident>,
204}
205
206impl LocalField<'_> {
207    fn conv_material(&self) -> ((ConvFn, ConvFn, ConvFn), &Ident) {
208        let field_type = self.field_type.to_string();
209        let mut target_field_type = self.field_type;
210        let fits = &self.field_inner_types;
211
212        let key = if field_type.starts_with("Wire")
213            || field_type.starts_with("Memory")
214            || field_type.starts_with("Storage")
215        {
216            "as ProstMessage".to_string()
217        } else if fits.is_empty() {
218            field_type
219        } else {
220            match (field_type.as_str(), fits.first(), fits.get(1)) {
221                ("Option", Some(one), Some(two)) => {
222                    let inner_type = one.to_string();
223                    if inner_type == "Box" {
224                        target_field_type = two;
225                        "Option<Box<_>>"
226                    } else {
227                        "Option<T>"
228                    }
229                    .to_string()
230                }
231                ("Option", Some(one), _) => {
232                    let inner_type = one.to_string();
233                    if inner_type == "String" || inner_type == "u32" {
234                        "Option<T>"
235                    } else if inner_type == "MessageMetadata" {
236                        target_field_type = one;
237                        "Option<MessageMetadata>"
238                    } else {
239                        target_field_type = one;
240                        "Option<_>"
241                    }
242                    .to_string()
243                }
244                ("Vec", Some(one), _) => {
245                    let inner_type = one.to_string();
246                    if inner_type.starts_with("Wire")
247                        || inner_type.starts_with("Memory")
248                        || inner_type.starts_with("Storage")
249                    {
250                        target_field_type = one;
251                        "Vec<as ProstMessage>"
252                    } else if inner_type == "BlindedPublicKey" {
253                        "Vec<BlindedPublicKey>"
254                    } else if inner_type == "u8" {
255                        "Vec<u8>"
256                    } else {
257                        target_field_type = one;
258                        "Vec<_>"
259                    }
260                    .to_string()
261                }
262                ("Zeroizing", Some(one), _) => {
263                    target_field_type = one;
264                    "Zeroizing<_>".to_string()
265                }
266                _ => field_type,
267            }
268        };
269
270        (
271            match CONV_FNS.contains_key(&key) {
272                true => CONV_FNS[&key],
273                false => CONV_FNS["_"],
274            },
275            target_field_type,
276        )
277    }
278
279    fn convert(&self, conv: TokenStream) -> TokenStream {
280        let field_name = self.field_name;
281        quote_spanned! {field_name.span()=> #field_name: #conv, }
282    }
283
284    fn decode(&self) -> TokenStream {
285        let ((decode, _, _), target_field_type) = self.conv_material();
286        self.convert(decode(self.field_name, target_field_type))
287    }
288
289    fn encode(&self) -> TokenStream {
290        let ((_, encode, _), target_field_type) = self.conv_material();
291        self.convert(encode(self.field_name, target_field_type))
292    }
293
294    fn encode_owned(&self) -> TokenStream {
295        let ((_, _, encode_owned), target_field_type) = self.conv_material();
296        self.convert(encode_owned(self.field_name, target_field_type))
297    }
298}
299
300fn decode(data: &Data) -> Vec<TokenStream> {
301    parse_fields(data).iter().map(LocalField::decode).collect()
302}
303
304fn encode(data: &Data) -> Vec<TokenStream> {
305    parse_fields(data).iter().map(LocalField::encode).collect()
306}
307
308fn encode_owned(data: &Data) -> Vec<TokenStream> {
309    parse_fields(data)
310        .iter()
311        .map(LocalField::encode_owned)
312        .collect()
313}
314
315#[allow(clippy::panic)]
316fn prosted_derivations(
317    tt: &str,
318    input: &DeriveInput,
319    mod_base: &TokenTree,
320    prost_type_name: &TokenTree,
321    type_name: &Ident,
322) -> TokenStream {
323    let core_path = core_path(type_name.span());
324
325    if tt == "Decode" {
326        let decode_fields = decode(&input.data);
327
328        quote! {
329            impl TryFrom<#mod_base::#prost_type_name> for #type_name {
330                type Error = #core_path::error::DecodeError;
331
332                fn try_from(from: #mod_base::#prost_type_name) -> Result<Self, Self::Error> {
333                    Ok(#type_name {
334                        #(#decode_fields)*
335                    })
336                }
337            }
338        }
339    } else if tt == "Encode" {
340        let encode_fields = encode(&input.data);
341
342        quote! {
343            impl From<&#type_name> for #mod_base::#prost_type_name {
344                fn from(from: &#type_name) -> Self {
345                    Self {
346                        #(#encode_fields)*
347                    }
348                }
349            }
350        }
351    } else if tt == "EncodeOwned" {
352        let encode_owned_fields = encode_owned(&input.data);
353
354        quote! {
355            impl From<#type_name> for #mod_base::#prost_type_name {
356                fn from(from: #type_name) -> Self {
357                    Self {
358                        #(#encode_owned_fields)*
359                    }
360                }
361            }
362        }
363    } else if tt == "ProstTraits" {
364        quote! {
365            impl #core_path::ProstDecode for #type_name {
366                type EncodedType = #mod_base::#prost_type_name;
367            }
368
369            impl #core_path::ProstEncode for #type_name {
370                type EncodedType = #mod_base::#prost_type_name;
371            }
372        }
373    } else if tt == "ProstTraitsOwned" {
374        quote! {
375            impl #core_path::ProstDecode for #type_name {
376                type EncodedType = #mod_base::#prost_type_name;
377            }
378
379            impl #core_path::ProstEncodeOwned for #type_name {
380                type EncodedType = #mod_base::#prost_type_name;
381            }
382        }
383    } else {
384        panic!("Unsupported prosted derivation.");
385    }
386}
387
388#[proc_macro_attribute]
389pub fn prosted(
390    attr: proc_macro::TokenStream,
391    item: proc_macro::TokenStream,
392) -> proc_macro::TokenStream {
393    let structure = TokenStream::from(item.clone());
394    let input = parse_macro_input!(item as DeriveInput);
395    let type_name = &input.ident;
396
397    let attr = TokenStream::from(attr)
398        .into_iter()
399        .filter(|t| {
400            let tk = t.to_string();
401            let tk_s = tk.as_str();
402            tk_s != "," && tk_s != ":"
403        })
404        .collect::<Vec<_>>();
405
406    let mod_base = &attr[0];
407    let prost_type_name = &attr[1];
408
409    let impls = attr
410        .iter()
411        .skip(2)
412        .map(|tt| {
413            prosted_derivations(
414                tt.to_string().as_str(),
415                &input,
416                mod_base,
417                prost_type_name,
418                type_name,
419            )
420        })
421        .collect::<Vec<_>>();
422
423    quote! {
424        #structure
425
426        #(#impls)*
427    }
428    .into()
429}
430
431#[proc_macro_attribute]
432pub fn communicable(
433    attr: proc_macro::TokenStream,
434    item: proc_macro::TokenStream,
435) -> proc_macro::TokenStream {
436    let structure = TokenStream::from(item.clone());
437    let input = parse_macro_input!(item as DeriveInput);
438    let type_name = &input.ident;
439
440    let core_path = core_path(type_name.span());
441
442    let attr = TokenStream::from(attr).into_iter().collect::<Vec<_>>();
443    let base_type = &attr[0];
444    let variant = &attr[3];
445    let kind = attr[&attr.len() - 1].to_string(); // owned, marker
446
447    #[allow(clippy::unwrap_used)]
448    let (prost, trait_name) = match kind.as_str() {
449        "owned" => (true, "ProstCommunicableOwned"),
450        "marker" => (false, "Communicable"),
451        _ => (true, "ProstCommunicable"),
452    };
453    #[allow(clippy::unwrap_used)]
454    let trait_name = TokenStream::from_str(trait_name).unwrap();
455    if prost {
456        quote! {
457            #structure
458
459            impl #core_path::communication::#trait_name<#base_type> for #type_name {
460                const COMMUNICATION_VARIANT: #base_type = #base_type::#variant;
461            }
462        }
463    } else {
464        quote! {
465            #structure
466
467            impl #core_path::communication::#trait_name<#base_type> for #type_name {
468                fn to_communication(&self) -> #core_path::communication::Communication<#base_type> {
469                    #core_path::communication::Communication::from(#base_type::#variant, vec![], None)
470                }
471
472                fn try_from_communication(
473                    communication: &#core_path::communication::Communication<#base_type>,
474                ) -> Result<Self, #core_path::error::DecodeError> {
475                    if communication.r#type != #base_type::#variant {
476                        return Err(#core_path::error::DecodeError);
477                    }
478
479                    Ok(Self {})
480                }
481            }
482        }
483    }
484    .into()
485}
486
487fn parse_fields(data: &Data) -> Vec<LocalField<'_>> {
488    let mut fields = Vec::new();
489
490    match data {
491        Data::Struct(inner) => {
492            for field in &inner.fields {
493                let field_name = field.ident.as_ref().expect("Expected field name");
494                if let Type::Path(inner) = &field.ty {
495                    let field_type = &inner.path.segments[0].ident;
496                    let mut field_inner_types = vec![];
497                    all_anglebracketed(&inner.path.segments[0].arguments, &mut field_inner_types);
498
499                    fields.push(LocalField {
500                        field_name,
501                        field_type,
502                        field_inner_types,
503                    });
504                }
505            }
506        }
507        _ => unimplemented!(),
508    }
509
510    fields
511}
512
513fn all_anglebracketed<'a>(arguments: &'a PathArguments, idents: &mut Vec<&'a Ident>) {
514    if let PathArguments::AngleBracketed(inner) = arguments {
515        if let GenericArgument::Type(Type::Path(inner)) = &inner.args[0] {
516            idents.push(&inner.path.segments[0].ident);
517            all_anglebracketed(&inner.path.segments[0].arguments, idents);
518        }
519    }
520}
521
522fn core_path(span: Span) -> Ident {
523    Ident::new(
524        match std::env::var("CARGO_PKG_NAME")
525            .expect("Should use cargo")
526            .as_str()
527        {
528            "reach-core" => "crate",
529            _ => "reach_core",
530        },
531        span,
532    )
533}
534
535#[proc_macro_derive(ParticipantPublicKeys)]
536pub fn derive_participant_public_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
537    let input = parse_macro_input!(input as DeriveInput);
538    let type_name = &input.ident;
539    let participant_variant = if type_name.to_string().starts_with("Reaching") {
540        Ident::new("Reaching", type_name.span())
541    } else {
542        Ident::new("Reachable", type_name.span())
543    };
544
545    let core_path = core_path(type_name.span());
546
547    quote! {
548        impl #core_path::ParticipantPublicKeys for #type_name {
549            fn ec_public_key(&self) -> &reach_aliases::X25519Public {
550                &self.ec_public_key
551            }
552
553            fn pq_public_key(&self) -> &reach_aliases::MlKemPublic {
554                &self.pq_public_key
555            }
556
557            fn participant_type(&self) -> #core_path::ParticipantType {
558                #core_path::ParticipantType::#participant_variant
559            }
560        }
561    }
562    .into()
563}
564
565#[proc_macro_derive(ParticipantSecretKeys)]
566pub fn derive_participant_secret_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
567    let input = parse_macro_input!(input as DeriveInput);
568    let type_name = &input.ident;
569
570    quote! {
571        impl reach_encryption::ParticipantSecretKeys for #type_name {
572            fn ec_secret_key(&self) -> &reach_aliases::X25519Secret {
573                &self.ec_secret_key
574            }
575
576            fn pq_secret_key(&self) -> &reach_aliases::MlKemSecret {
577                &self.pq_secret_key
578            }
579        }
580    }
581    .into()
582}
583
584#[proc_macro_derive(SecretKeysRandomFromRng)]
585pub fn derive_secret_keys_random_from_rng(
586    input: proc_macro::TokenStream,
587) -> proc_macro::TokenStream {
588    let input = parse_macro_input!(input as DeriveInput);
589    let type_name = &input.ident;
590
591    let core_path = core_path(type_name.span());
592
593    quote! {
594        impl #core_path::RandomFromRng for #type_name {
595            fn random_from_rng(csprng: &mut impl rand_core::CryptoRngCore) -> Self {
596                let (pq_secret_key, pq_public_key) = reach_aliases::MlKem::generate(csprng);
597
598                Self {
599                    ec_secret_key: reach_aliases::X25519Secret::random_from_rng(csprng),
600                    pq_secret_key,
601                    pq_public_key,
602                }
603            }
604        }
605    }
606    .into()
607}
608
609#[proc_macro_derive(UnsignedPublicKeys)]
610pub fn derive_unsigned_public_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
611    let input = parse_macro_input!(input as DeriveInput);
612    let type_name = &input.ident;
613
614    let public_keys_type_name = Ident::new(
615        format!(
616            "Unsigned{}",
617            type_name.to_string().replace("SecretKeys", "PublicKeys")
618        )
619        .as_str(),
620        type_name.span(),
621    );
622
623    public_keys_tokenstream(type_name, &public_keys_type_name).into()
624}
625
626#[proc_macro_derive(PublicKeys)]
627pub fn derive_public_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
628    let input = parse_macro_input!(input as DeriveInput);
629    let type_name = &input.ident;
630
631    let public_keys_type_name = Ident::new(
632        type_name
633            .to_string()
634            .as_str()
635            .replace("SecretKeys", "PublicKeys")
636            .as_str(),
637        type_name.span(),
638    );
639
640    public_keys_tokenstream(type_name, &public_keys_type_name).into()
641}
642
643fn public_keys_tokenstream(type_name: &Ident, public_keys_type_name: &Ident) -> TokenStream {
644    quote! {
645        impl From<&#type_name> for #public_keys_type_name {
646            fn from(from: &#type_name) -> Self {
647                Self {
648                    ec_public_key: reach_aliases::X25519Public::from(&from.ec_secret_key),
649                    pq_public_key: from.pq_public_key.clone(),
650                }
651            }
652        }
653    }
654}
655
656#[proc_macro_derive(Verifier)]
657pub fn derive_verifier(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
658    let input = parse_macro_input!(input as DeriveInput);
659    let type_name = &input.ident;
660
661    quote! {
662        impl reach_signatures::Verifier for #type_name {
663            fn ec_verifying_key(&self) -> &reach_aliases::Ed25519Verifying {
664                &self.ec_verifying_key
665            }
666
667            fn pq_verifying_key(&self) -> &reach_aliases::FnDsaVerifying {
668                &self.pq_verifying_key
669            }
670        }
671    }
672    .into()
673}
674
675fn verifying_keys(type_name: &Ident, verifying_type_name: &Ident) -> proc_macro::TokenStream {
676    quote! {
677        impl From<&#type_name> for #verifying_type_name {
678            fn from(from: &#type_name) -> Self {
679                Self {
680                    ec_verifying_key: from.ec_signing_key.verifying_key(),
681                    pq_verifying_key: from.pq_verifying_key,
682                }
683            }
684        }
685    }
686    .into()
687}
688
689#[proc_macro_derive(UnsignedVerifyingKeys)]
690pub fn derive_unsigned_verifying_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
691    let input = parse_macro_input!(input as DeriveInput);
692    let type_name = &input.ident;
693    let verifying_type_name = Ident::new(
694        format!(
695            "Unsigned{}",
696            type_name.to_string().replace("Signing", "Verifying")
697        )
698        .as_str(),
699        type_name.span(),
700    );
701
702    verifying_keys(type_name, &verifying_type_name)
703}
704
705#[proc_macro_derive(VerifyingKeys)]
706pub fn derive_verifying_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
707    let input = parse_macro_input!(input as DeriveInput);
708    let type_name = &input.ident;
709    let verifying_type_name = Ident::new(
710        type_name
711            .to_string()
712            .replace("Signing", "Verifying")
713            .as_str(),
714        type_name.span(),
715    );
716
717    verifying_keys(type_name, &verifying_type_name)
718}
719
720#[proc_macro_derive(SigningKeys)]
721pub fn derive_signing_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
722    let input = parse_macro_input!(input as DeriveInput);
723    let type_name = &input.ident;
724
725    let core_path = core_path(type_name.span());
726
727    quote! {
728        impl #core_path::RandomFromRng for #type_name {
729            fn random_from_rng(csprng: &mut impl rand_core::CryptoRngCore) -> Self {
730                let mut kg = fn_dsa_kgen::KeyPairGeneratorStandard::default();
731                let mut pq_signing_key_inner = reach_signatures::FN_DSA_SIGNING_EMPTY;
732                let mut pq_verifying_key = reach_aliases::FN_DSA_VERIFYING_EMPTY;
733                kg.keygen(
734                    reach_aliases::FN_DSA_DEGREES,
735                    csprng,
736                    &mut pq_signing_key_inner,
737                    &mut pq_verifying_key
738                );
739
740                Self {
741                    ec_signing_key: reach_aliases::Ed25519Signing::generate(csprng),
742                    pq_signing_key: reach_aliases::FnDsaSigning {
743                        inner: pq_signing_key_inner,
744                    },
745                    pq_verifying_key,
746                }
747            }
748        }
749
750        impl reach_signatures::Sign for #type_name {
751            fn ec_signing_key(&self) -> &reach_aliases::Ed25519Signing {
752                &self.ec_signing_key
753            }
754
755            fn pq_signing_key(&self) -> &reach_aliases::FnDsaSigning {
756                &self.pq_signing_key
757            }
758        }
759    }
760    .into()
761}
762
763#[proc_macro]
764pub fn request_handler_macro(attr: proc_macro::TokenStream) -> proc_macro::TokenStream {
765    let attr = TokenStream::from(attr)
766        .into_iter()
767        .filter(|t| {
768            let tk = t.to_string();
769            let tk_s = tk.as_str();
770            tk_s != ","
771        })
772        .collect::<Vec<_>>();
773
774    let macro_name = &attr[0];
775    let communication_type = &attr[1];
776    let request_wrapper_type = &attr[2];
777    let global_context_type = &attr[3];
778    let session_context_type = &attr[4];
779
780    quote! {
781        macro_rules! #macro_name {
782            (
783                async fn handle(
784                    $request:tt: $request_type:ty,
785                    $global_context:tt: _,
786                    $session_context:tt: _$(,)?
787                ) -> $response_type:ty
788                $handler_body:block
789            ) => {
790                #macro_name!(
791                    async fn handle(
792                        $request: $request_type,
793                        $global_context: #global_context_type,
794                        $session_context: #session_context_type,
795                    ) -> $response_type $handler_body
796                );
797            };
798
799            (
800                async fn handle(
801                    $request:tt: $request_type:ty,
802                    $global_context:tt: $global_context_type:ty,
803                    $session_context:tt: _$(,)?
804                ) -> $response_type:ty
805                $handler_body:block
806            ) => {
807                #macro_name!(
808                    async fn handle(
809                        $request: $request_type,
810                        $global_context: $global_context_type,
811                        $session_context: #session_context_type,
812                    ) -> $response_type $handler_body
813                );
814            };
815
816            (
817                async fn handle(
818                    $request:tt: $request_type:ty,
819                    $global_context:tt: _,
820                    $session_context:tt: $session_context_type:ty$(,)?
821                ) -> $response_type:ty
822                $handler_body:block
823            ) => {
824                #macro_name!(
825                    async fn handle(
826                        $request: $request_type,
827                        $global_context: #global_context_type,
828                        $session_context: $session_context_type,
829                    ) -> $response_type $handler_body
830                );
831            };
832
833            (
834                async fn handle(
835                    $request:tt: $request_type:ty,
836                    $global_context:tt: $global_context_type:ty,
837                    $session_context:tt: $session_context_type:ty$(,)?
838                ) -> $response_type:ty
839                $handler_body:block
840            ) => {
841                impl reach_websocket::RequestHandler<#communication_type>
842                    for #request_wrapper_type<$request_type>
843                {
844                    type GlobalContext = $global_context_type;
845                    type SessionContext = $session_context_type;
846
847                    async fn handle(
848                        $request: Self,
849                        $global_context: std::sync::Arc<Self::GlobalContext>,
850                        $session_context: std::sync::Arc<tokio::sync::RwLock<Self::SessionContext>>,
851                    ) -> $response_type $handler_body
852                }
853            };
854        }
855
856        pub(crate) use #macro_name;
857    }
858    .into()
859}