1use std::{collections::HashMap, str::FromStr, sync::LazyLock};
6
7use proc_macro2::{Ident, Span, TokenStream, TokenTree};
8use quote::{quote, quote_spanned};
9use syn::{Data, DeriveInput, GenericArgument, PathArguments, Type, parse_macro_input};
10
11type ConvFn = fn(&Ident, &Ident) -> TokenStream;
12
13static CONV_FNS: LazyLock<HashMap<String, (ConvFn, ConvFn, ConvFn)>> = LazyLock::new(|| {
14 let map: Vec<(&str, ConvFn)> = vec![
15 ("nop", |_, _| quote! {}),
16 ("to_vec", |n, _| quote! { from.#n.to_vec() }),
17 ("inner to_vec", |n, _| quote! { from.#n.inner.to_vec() }),
18 ("as_bytes", |n, _| quote! { from.#n.as_bytes().to_vec() }),
19 ("to_bytes", |n, _| quote! { from.#n.to_bytes().to_vec() }),
20 ("move", |n, _| quote! { from.#n }),
21 (
22 "zeroizing",
23 |n, _| quote! { zeroize::Zeroizing::new(from.#n.clone()) },
24 ),
25 ("zeroizing deref clone", |n, t| {
26 quote! { <zeroize::Zeroizing<#t> as std::ops::Deref>::deref(&from.#n).clone() }
27 }),
28 ("clone", |n, _| quote! { from.#n.clone() }),
29 ("as u32", |n, _| quote! { from.#n as u32 }),
30 ("direct", |n, t| quote! { #t::try_from(from.#n)? }),
31 ("slice", |n, t| quote! { #t::try_from(from.#n.as_slice())? }),
32 (
33 "try_from slice",
34 |n, t| quote! { #t::try_from(from.#n.as_slice())? },
35 ),
36 (
37 "from_bytes",
38 |n, t| quote! { #t::from_bytes(from.#n.as_slice().try_into()?) },
39 ),
40 (
41 "from_bytes?",
42 |n, t| quote! { #t::from_bytes(from.#n.as_slice())? },
43 ),
44 (
45 "via [u8; 24]",
46 |n, t| quote! { #t::from(<[u8; 24]>::try_from(from.#n.as_slice())?) },
47 ),
48 (
49 "via [u8; 32]",
50 |n, t| quote! { #t::from(<[u8; 32]>::try_from(from.#n.as_slice())?) },
51 ),
52 (
53 "via &[u8; 64]",
54 |n, t| quote! { #t::from_bytes(&<[u8; 64]>::try_from(from.#n.as_slice())?) },
55 ),
56 ("Option Box try", |n, _| {
57 quote! {
58 match from.#n {
59 Some(inner) => Some(Box::new(inner.try_into()?)),
60 None => None,
61 }
62 }
63 }),
64 (
65 "Option Box as_ref",
66 |n, t| quote! { from.#n.as_ref().map(|inner| proto::#t::from(inner.deref())) },
67 ),
68 ("Option try", |n, _| {
69 quote! {
70 match from.#n {
71 Some(inner) => Some(inner.try_into()?),
72 None => None,
73 }
74 }
75 }),
76 (
77 "Option as_ref",
78 |n, t| quote! { from.#n.as_ref().map(proto::#t::from) },
79 ),
80 (
81 "Option map from",
82 |n, t| quote! { from.#n.map(proto::#t::from) },
83 ),
84 ("Option Result", |n, t| {
85 let core_path = core_path(n.span());
86 quote! { #t::try_from(from.#n.ok_or(#core_path::error::DecodeError)?)? }
87 }),
88 (
89 "Option Some",
90 |n, t| quote! { Some(proto::#t::from(&from.#n)) },
91 ),
92 (
93 "Option Some Owned",
94 |n, t| quote! { Some(proto::#t::from(from.#n)) },
95 ),
96 (
97 "into_iter try_from",
98 |n, t| quote! { from.#n.into_iter().map(#t::try_from).collect::<Result<_, _>>()? },
99 ),
100 ("into_iter from_bytes", |n, _| {
101 quote! { from.#n.into_iter().map(|item| {
102 item.as_slice().try_into()
103 }).collect::<Result<_, _>>()? }
104 }),
105 (
106 "iter",
107 |n, t| quote! { from.#n.iter().map(proto::#t::from).collect() },
108 ),
109 (
110 "into_iter from",
111 |n, t| quote! { from.#n.into_iter().map(proto::#t::from).collect() },
112 ),
113 (
114 "iter to_bytes BlindedPublicKey",
115 |n, _| quote! { from.#n.iter().map(|item| item.to_bytes().to_vec()).collect() },
116 ),
117 (
118 "proto decode",
119 |n, t| quote! { #t::decode(from.#n.as_slice())? },
120 ),
121 ("proto encode", |n, _| quote! { from.#n.encode_to_vec() }),
122 ("into_iter proto decode", |n, t| {
123 quote! {
124 from.#n.into_iter().map(#t::decode).collect::<Result<_, _>>()?
125 }
126 }),
127 (
128 "iter proto encode",
129 |n, _| quote! { from.#n.iter().map(|item| item.encode_to_vec()).collect() },
130 ),
131 ];
132
133 let map = map.into_iter().collect::<HashMap<&'static str, ConvFn>>();
134
135 [
136 ("BlindedPublicKey", "from_bytes", "to_bytes", "nop"),
137 ("Ed25519Signature", "via &[u8; 64]", "to_bytes", "to_bytes"),
138 ("Ed25519Signing", "slice", "to_bytes", "to_bytes"),
139 ("Ed25519Verifying", "slice", "as_bytes", "as_bytes"),
140 ("FnDsaVerifying", "slice", "to_vec", "to_vec"),
141 ("FnDsaSignature", "slice", "to_vec", "to_vec"),
142 ("FnDsaSigning", "slice", "inner to_vec", "inner to_vec"),
143 ("EnvelopeIdHints", "from_bytes?", "to_bytes", "to_bytes"),
144 ("MlKemCiphertext", "try_from slice", "to_vec", "to_vec"),
145 ("MlKemPublic", "from_bytes", "as_bytes", "as_bytes"),
146 ("MlKemSecret", "from_bytes", "as_bytes", "as_bytes"),
147 ("X25519Public", "via [u8; 32]", "as_bytes", "as_bytes"),
148 ("X25519Secret", "via [u8; 32]", "as_bytes", "as_bytes"),
149 ("XChaChaKey", "via [u8; 32]", "to_vec", "to_vec"),
150 ("XNonce", "via [u8; 24]", "to_vec", "to_vec"),
151 ("SixFour", "slice", "to_vec", "to_vec"),
152 ("String", "move", "clone", "move"),
153 ("ThreeTwo", "slice", "to_vec", "to_vec"),
154 ("TwoFour", "slice", "to_vec", "to_vec"),
155 ("OneSix", "slice", "to_vec", "to_vec"),
156 ("u16", "direct", "as u32", "as u32"),
157 ("Vec<u8>", "move", "clone", "move"),
158 (
159 "Vec<BlindedPublicKey>",
160 "into_iter from_bytes",
161 "iter to_bytes BlindedPublicKey",
162 "nop",
163 ),
164 ("Vec<_>", "into_iter try_from", "iter", "into_iter from"),
165 (
166 "Option<MessageMetadata>",
167 "Option try",
168 "nop",
169 "Option map from",
170 ),
171 (
172 "Option<Box<_>>",
173 "Option Box try",
174 "Option Box as_ref",
175 "nop",
176 ),
177 ("Option<T>", "move", "clone", "move"),
178 ("Option<_>", "Option try", "Option as_ref", "Option as_ref"),
179 ("_", "Option Result", "Option Some", "Option Some Owned"),
180 ("as ProstMessage", "proto decode", "proto encode", "nop"),
181 (
182 "Vec<as ProstMessage>",
183 "into_iter proto decode",
184 "iter proto encode",
185 "nop",
186 ),
187 (
188 "Zeroizing<_>",
189 "zeroizing",
190 "zeroizing deref clone",
191 "zeroizing",
192 ),
193 ]
194 .into_iter()
195 .map(|(key, de, en, eno)| (key.to_string(), (map[de], map[en], map[eno])))
196 .collect()
197});
198
199#[derive(Debug)]
200struct LocalField<'a> {
201 field_name: &'a Ident,
202 field_type: &'a Ident,
203 field_inner_types: Vec<&'a Ident>,
204}
205
206impl LocalField<'_> {
207 fn conv_material(&self) -> ((ConvFn, ConvFn, ConvFn), &Ident) {
208 let field_type = self.field_type.to_string();
209 let mut target_field_type = self.field_type;
210 let fits = &self.field_inner_types;
211
212 let key = if field_type.starts_with("Wire")
213 || field_type.starts_with("Memory")
214 || field_type.starts_with("Storage")
215 {
216 "as ProstMessage".to_string()
217 } else if fits.is_empty() {
218 field_type
219 } else {
220 match (field_type.as_str(), fits.first(), fits.get(1)) {
221 ("Option", Some(one), Some(two)) => {
222 let inner_type = one.to_string();
223 if inner_type == "Box" {
224 target_field_type = two;
225 "Option<Box<_>>"
226 } else {
227 "Option<T>"
228 }
229 .to_string()
230 }
231 ("Option", Some(one), _) => {
232 let inner_type = one.to_string();
233 if inner_type == "String" || inner_type == "u32" {
234 "Option<T>"
235 } else if inner_type == "MessageMetadata" {
236 target_field_type = one;
237 "Option<MessageMetadata>"
238 } else {
239 target_field_type = one;
240 "Option<_>"
241 }
242 .to_string()
243 }
244 ("Vec", Some(one), _) => {
245 let inner_type = one.to_string();
246 if inner_type.starts_with("Wire")
247 || inner_type.starts_with("Memory")
248 || inner_type.starts_with("Storage")
249 {
250 target_field_type = one;
251 "Vec<as ProstMessage>"
252 } else if inner_type == "BlindedPublicKey" {
253 "Vec<BlindedPublicKey>"
254 } else if inner_type == "u8" {
255 "Vec<u8>"
256 } else {
257 target_field_type = one;
258 "Vec<_>"
259 }
260 .to_string()
261 }
262 ("Zeroizing", Some(one), _) => {
263 target_field_type = one;
264 "Zeroizing<_>".to_string()
265 }
266 _ => field_type,
267 }
268 };
269
270 (
271 match CONV_FNS.contains_key(&key) {
272 true => CONV_FNS[&key],
273 false => CONV_FNS["_"],
274 },
275 target_field_type,
276 )
277 }
278
279 fn convert(&self, conv: TokenStream) -> TokenStream {
280 let field_name = self.field_name;
281 quote_spanned! {field_name.span()=> #field_name: #conv, }
282 }
283
284 fn decode(&self) -> TokenStream {
285 let ((decode, _, _), target_field_type) = self.conv_material();
286 self.convert(decode(self.field_name, target_field_type))
287 }
288
289 fn encode(&self) -> TokenStream {
290 let ((_, encode, _), target_field_type) = self.conv_material();
291 self.convert(encode(self.field_name, target_field_type))
292 }
293
294 fn encode_owned(&self) -> TokenStream {
295 let ((_, _, encode_owned), target_field_type) = self.conv_material();
296 self.convert(encode_owned(self.field_name, target_field_type))
297 }
298}
299
300fn decode(data: &Data) -> Vec<TokenStream> {
301 parse_fields(data).iter().map(LocalField::decode).collect()
302}
303
304fn encode(data: &Data) -> Vec<TokenStream> {
305 parse_fields(data).iter().map(LocalField::encode).collect()
306}
307
308fn encode_owned(data: &Data) -> Vec<TokenStream> {
309 parse_fields(data)
310 .iter()
311 .map(LocalField::encode_owned)
312 .collect()
313}
314
315#[allow(clippy::panic)]
316fn prosted_derivations(
317 tt: &str,
318 input: &DeriveInput,
319 mod_base: &TokenTree,
320 prost_type_name: &TokenTree,
321 type_name: &Ident,
322) -> TokenStream {
323 let core_path = core_path(type_name.span());
324
325 if tt == "Decode" {
326 let decode_fields = decode(&input.data);
327
328 quote! {
329 impl TryFrom<#mod_base::#prost_type_name> for #type_name {
330 type Error = #core_path::error::DecodeError;
331
332 fn try_from(from: #mod_base::#prost_type_name) -> Result<Self, Self::Error> {
333 Ok(#type_name {
334 #(#decode_fields)*
335 })
336 }
337 }
338 }
339 } else if tt == "Encode" {
340 let encode_fields = encode(&input.data);
341
342 quote! {
343 impl From<&#type_name> for #mod_base::#prost_type_name {
344 fn from(from: &#type_name) -> Self {
345 Self {
346 #(#encode_fields)*
347 }
348 }
349 }
350 }
351 } else if tt == "EncodeOwned" {
352 let encode_owned_fields = encode_owned(&input.data);
353
354 quote! {
355 impl From<#type_name> for #mod_base::#prost_type_name {
356 fn from(from: #type_name) -> Self {
357 Self {
358 #(#encode_owned_fields)*
359 }
360 }
361 }
362 }
363 } else if tt == "ProstTraits" {
364 quote! {
365 impl #core_path::ProstDecode for #type_name {
366 type EncodedType = #mod_base::#prost_type_name;
367 }
368
369 impl #core_path::ProstEncode for #type_name {
370 type EncodedType = #mod_base::#prost_type_name;
371 }
372 }
373 } else if tt == "ProstTraitsOwned" {
374 quote! {
375 impl #core_path::ProstDecode for #type_name {
376 type EncodedType = #mod_base::#prost_type_name;
377 }
378
379 impl #core_path::ProstEncodeOwned for #type_name {
380 type EncodedType = #mod_base::#prost_type_name;
381 }
382 }
383 } else {
384 panic!("Unsupported prosted derivation.");
385 }
386}
387
388#[proc_macro_attribute]
389pub fn prosted(
390 attr: proc_macro::TokenStream,
391 item: proc_macro::TokenStream,
392) -> proc_macro::TokenStream {
393 let structure = TokenStream::from(item.clone());
394 let input = parse_macro_input!(item as DeriveInput);
395 let type_name = &input.ident;
396
397 let attr = TokenStream::from(attr)
398 .into_iter()
399 .filter(|t| {
400 let tk = t.to_string();
401 let tk_s = tk.as_str();
402 tk_s != "," && tk_s != ":"
403 })
404 .collect::<Vec<_>>();
405
406 let mod_base = &attr[0];
407 let prost_type_name = &attr[1];
408
409 let impls = attr
410 .iter()
411 .skip(2)
412 .map(|tt| {
413 prosted_derivations(
414 tt.to_string().as_str(),
415 &input,
416 mod_base,
417 prost_type_name,
418 type_name,
419 )
420 })
421 .collect::<Vec<_>>();
422
423 quote! {
424 #structure
425
426 #(#impls)*
427 }
428 .into()
429}
430
431#[proc_macro_attribute]
432pub fn communicable(
433 attr: proc_macro::TokenStream,
434 item: proc_macro::TokenStream,
435) -> proc_macro::TokenStream {
436 let structure = TokenStream::from(item.clone());
437 let input = parse_macro_input!(item as DeriveInput);
438 let type_name = &input.ident;
439
440 let core_path = core_path(type_name.span());
441
442 let attr = TokenStream::from(attr).into_iter().collect::<Vec<_>>();
443 let base_type = &attr[0];
444 let variant = &attr[3];
445 let kind = attr[&attr.len() - 1].to_string(); #[allow(clippy::unwrap_used)]
448 let (prost, trait_name) = match kind.as_str() {
449 "owned" => (true, "ProstCommunicableOwned"),
450 "marker" => (false, "Communicable"),
451 _ => (true, "ProstCommunicable"),
452 };
453 #[allow(clippy::unwrap_used)]
454 let trait_name = TokenStream::from_str(trait_name).unwrap();
455 if prost {
456 quote! {
457 #structure
458
459 impl #core_path::communication::#trait_name<#base_type> for #type_name {
460 const COMMUNICATION_VARIANT: #base_type = #base_type::#variant;
461 }
462 }
463 } else {
464 quote! {
465 #structure
466
467 impl #core_path::communication::#trait_name<#base_type> for #type_name {
468 fn to_communication(&self) -> #core_path::communication::Communication<#base_type> {
469 #core_path::communication::Communication::from(#base_type::#variant, vec![], None)
470 }
471
472 fn try_from_communication(
473 communication: &#core_path::communication::Communication<#base_type>,
474 ) -> Result<Self, #core_path::error::DecodeError> {
475 if communication.r#type != #base_type::#variant {
476 return Err(#core_path::error::DecodeError);
477 }
478
479 Ok(Self {})
480 }
481 }
482 }
483 }
484 .into()
485}
486
487fn parse_fields(data: &Data) -> Vec<LocalField<'_>> {
488 let mut fields = Vec::new();
489
490 match data {
491 Data::Struct(inner) => {
492 for field in &inner.fields {
493 let field_name = field.ident.as_ref().expect("Expected field name");
494 if let Type::Path(inner) = &field.ty {
495 let field_type = &inner.path.segments[0].ident;
496 let mut field_inner_types = vec![];
497 all_anglebracketed(&inner.path.segments[0].arguments, &mut field_inner_types);
498
499 fields.push(LocalField {
500 field_name,
501 field_type,
502 field_inner_types,
503 });
504 }
505 }
506 }
507 _ => unimplemented!(),
508 }
509
510 fields
511}
512
513fn all_anglebracketed<'a>(arguments: &'a PathArguments, idents: &mut Vec<&'a Ident>) {
514 if let PathArguments::AngleBracketed(inner) = arguments {
515 if let GenericArgument::Type(Type::Path(inner)) = &inner.args[0] {
516 idents.push(&inner.path.segments[0].ident);
517 all_anglebracketed(&inner.path.segments[0].arguments, idents);
518 }
519 }
520}
521
522fn core_path(span: Span) -> Ident {
523 Ident::new(
524 match std::env::var("CARGO_PKG_NAME")
525 .expect("Should use cargo")
526 .as_str()
527 {
528 "reach-core" => "crate",
529 _ => "reach_core",
530 },
531 span,
532 )
533}
534
535#[proc_macro_derive(ParticipantPublicKeys)]
536pub fn derive_participant_public_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
537 let input = parse_macro_input!(input as DeriveInput);
538 let type_name = &input.ident;
539 let participant_variant = if type_name.to_string().starts_with("Reaching") {
540 Ident::new("Reaching", type_name.span())
541 } else {
542 Ident::new("Reachable", type_name.span())
543 };
544
545 let core_path = core_path(type_name.span());
546
547 quote! {
548 impl #core_path::ParticipantPublicKeys for #type_name {
549 fn ec_public_key(&self) -> &reach_aliases::X25519Public {
550 &self.ec_public_key
551 }
552
553 fn pq_public_key(&self) -> &reach_aliases::MlKemPublic {
554 &self.pq_public_key
555 }
556
557 fn participant_type(&self) -> #core_path::ParticipantType {
558 #core_path::ParticipantType::#participant_variant
559 }
560 }
561 }
562 .into()
563}
564
565#[proc_macro_derive(ParticipantSecretKeys)]
566pub fn derive_participant_secret_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
567 let input = parse_macro_input!(input as DeriveInput);
568 let type_name = &input.ident;
569
570 quote! {
571 impl reach_encryption::ParticipantSecretKeys for #type_name {
572 fn ec_secret_key(&self) -> &reach_aliases::X25519Secret {
573 &self.ec_secret_key
574 }
575
576 fn pq_secret_key(&self) -> &reach_aliases::MlKemSecret {
577 &self.pq_secret_key
578 }
579 }
580 }
581 .into()
582}
583
584#[proc_macro_derive(SecretKeysRandomFromRng)]
585pub fn derive_secret_keys_random_from_rng(
586 input: proc_macro::TokenStream,
587) -> proc_macro::TokenStream {
588 let input = parse_macro_input!(input as DeriveInput);
589 let type_name = &input.ident;
590
591 let core_path = core_path(type_name.span());
592
593 quote! {
594 impl #core_path::RandomFromRng for #type_name {
595 fn random_from_rng(csprng: &mut impl rand_core::CryptoRngCore) -> Self {
596 let (pq_secret_key, pq_public_key) = reach_aliases::MlKem::generate(csprng);
597
598 Self {
599 ec_secret_key: reach_aliases::X25519Secret::random_from_rng(csprng),
600 pq_secret_key,
601 pq_public_key,
602 }
603 }
604 }
605 }
606 .into()
607}
608
609#[proc_macro_derive(UnsignedPublicKeys)]
610pub fn derive_unsigned_public_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
611 let input = parse_macro_input!(input as DeriveInput);
612 let type_name = &input.ident;
613
614 let public_keys_type_name = Ident::new(
615 format!(
616 "Unsigned{}",
617 type_name.to_string().replace("SecretKeys", "PublicKeys")
618 )
619 .as_str(),
620 type_name.span(),
621 );
622
623 public_keys_tokenstream(type_name, &public_keys_type_name).into()
624}
625
626#[proc_macro_derive(PublicKeys)]
627pub fn derive_public_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
628 let input = parse_macro_input!(input as DeriveInput);
629 let type_name = &input.ident;
630
631 let public_keys_type_name = Ident::new(
632 type_name
633 .to_string()
634 .as_str()
635 .replace("SecretKeys", "PublicKeys")
636 .as_str(),
637 type_name.span(),
638 );
639
640 public_keys_tokenstream(type_name, &public_keys_type_name).into()
641}
642
643fn public_keys_tokenstream(type_name: &Ident, public_keys_type_name: &Ident) -> TokenStream {
644 quote! {
645 impl From<&#type_name> for #public_keys_type_name {
646 fn from(from: &#type_name) -> Self {
647 Self {
648 ec_public_key: reach_aliases::X25519Public::from(&from.ec_secret_key),
649 pq_public_key: from.pq_public_key.clone(),
650 }
651 }
652 }
653 }
654}
655
656#[proc_macro_derive(Verifier)]
657pub fn derive_verifier(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
658 let input = parse_macro_input!(input as DeriveInput);
659 let type_name = &input.ident;
660
661 quote! {
662 impl reach_signatures::Verifier for #type_name {
663 fn ec_verifying_key(&self) -> &reach_aliases::Ed25519Verifying {
664 &self.ec_verifying_key
665 }
666
667 fn pq_verifying_key(&self) -> &reach_aliases::FnDsaVerifying {
668 &self.pq_verifying_key
669 }
670 }
671 }
672 .into()
673}
674
675fn verifying_keys(type_name: &Ident, verifying_type_name: &Ident) -> proc_macro::TokenStream {
676 quote! {
677 impl From<&#type_name> for #verifying_type_name {
678 fn from(from: &#type_name) -> Self {
679 Self {
680 ec_verifying_key: from.ec_signing_key.verifying_key(),
681 pq_verifying_key: from.pq_verifying_key,
682 }
683 }
684 }
685 }
686 .into()
687}
688
689#[proc_macro_derive(UnsignedVerifyingKeys)]
690pub fn derive_unsigned_verifying_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
691 let input = parse_macro_input!(input as DeriveInput);
692 let type_name = &input.ident;
693 let verifying_type_name = Ident::new(
694 format!(
695 "Unsigned{}",
696 type_name.to_string().replace("Signing", "Verifying")
697 )
698 .as_str(),
699 type_name.span(),
700 );
701
702 verifying_keys(type_name, &verifying_type_name)
703}
704
705#[proc_macro_derive(VerifyingKeys)]
706pub fn derive_verifying_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
707 let input = parse_macro_input!(input as DeriveInput);
708 let type_name = &input.ident;
709 let verifying_type_name = Ident::new(
710 type_name
711 .to_string()
712 .replace("Signing", "Verifying")
713 .as_str(),
714 type_name.span(),
715 );
716
717 verifying_keys(type_name, &verifying_type_name)
718}
719
720#[proc_macro_derive(SigningKeys)]
721pub fn derive_signing_keys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
722 let input = parse_macro_input!(input as DeriveInput);
723 let type_name = &input.ident;
724
725 let core_path = core_path(type_name.span());
726
727 quote! {
728 impl #core_path::RandomFromRng for #type_name {
729 fn random_from_rng(csprng: &mut impl rand_core::CryptoRngCore) -> Self {
730 let mut kg = fn_dsa_kgen::KeyPairGeneratorStandard::default();
731 let mut pq_signing_key_inner = reach_signatures::FN_DSA_SIGNING_EMPTY;
732 let mut pq_verifying_key = reach_aliases::FN_DSA_VERIFYING_EMPTY;
733 kg.keygen(
734 reach_aliases::FN_DSA_DEGREES,
735 csprng,
736 &mut pq_signing_key_inner,
737 &mut pq_verifying_key
738 );
739
740 Self {
741 ec_signing_key: reach_aliases::Ed25519Signing::generate(csprng),
742 pq_signing_key: reach_aliases::FnDsaSigning {
743 inner: pq_signing_key_inner,
744 },
745 pq_verifying_key,
746 }
747 }
748 }
749
750 impl reach_signatures::Sign for #type_name {
751 fn ec_signing_key(&self) -> &reach_aliases::Ed25519Signing {
752 &self.ec_signing_key
753 }
754
755 fn pq_signing_key(&self) -> &reach_aliases::FnDsaSigning {
756 &self.pq_signing_key
757 }
758 }
759 }
760 .into()
761}
762
763#[proc_macro]
764pub fn request_handler_macro(attr: proc_macro::TokenStream) -> proc_macro::TokenStream {
765 let attr = TokenStream::from(attr)
766 .into_iter()
767 .filter(|t| {
768 let tk = t.to_string();
769 let tk_s = tk.as_str();
770 tk_s != ","
771 })
772 .collect::<Vec<_>>();
773
774 let macro_name = &attr[0];
775 let communication_type = &attr[1];
776 let request_wrapper_type = &attr[2];
777 let global_context_type = &attr[3];
778 let session_context_type = &attr[4];
779
780 quote! {
781 macro_rules! #macro_name {
782 (
783 async fn handle(
784 $request:tt: $request_type:ty,
785 $global_context:tt: _,
786 $session_context:tt: _$(,)?
787 ) -> $response_type:ty
788 $handler_body:block
789 ) => {
790 #macro_name!(
791 async fn handle(
792 $request: $request_type,
793 $global_context: #global_context_type,
794 $session_context: #session_context_type,
795 ) -> $response_type $handler_body
796 );
797 };
798
799 (
800 async fn handle(
801 $request:tt: $request_type:ty,
802 $global_context:tt: $global_context_type:ty,
803 $session_context:tt: _$(,)?
804 ) -> $response_type:ty
805 $handler_body:block
806 ) => {
807 #macro_name!(
808 async fn handle(
809 $request: $request_type,
810 $global_context: $global_context_type,
811 $session_context: #session_context_type,
812 ) -> $response_type $handler_body
813 );
814 };
815
816 (
817 async fn handle(
818 $request:tt: $request_type:ty,
819 $global_context:tt: _,
820 $session_context:tt: $session_context_type:ty$(,)?
821 ) -> $response_type:ty
822 $handler_body:block
823 ) => {
824 #macro_name!(
825 async fn handle(
826 $request: $request_type,
827 $global_context: #global_context_type,
828 $session_context: $session_context_type,
829 ) -> $response_type $handler_body
830 );
831 };
832
833 (
834 async fn handle(
835 $request:tt: $request_type:ty,
836 $global_context:tt: $global_context_type:ty,
837 $session_context:tt: $session_context_type:ty$(,)?
838 ) -> $response_type:ty
839 $handler_body:block
840 ) => {
841 impl reach_websocket::RequestHandler<#communication_type>
842 for #request_wrapper_type<$request_type>
843 {
844 type GlobalContext = $global_context_type;
845 type SessionContext = $session_context_type;
846
847 async fn handle(
848 $request: Self,
849 $global_context: std::sync::Arc<Self::GlobalContext>,
850 $session_context: std::sync::Arc<tokio::sync::RwLock<Self::SessionContext>>,
851 ) -> $response_type $handler_body
852 }
853 };
854 }
855
856 pub(crate) use #macro_name;
857 }
858 .into()
859}