umsh_crypto/
lib.rs

1#![allow(async_fn_in_trait)]
2#![cfg_attr(not(feature = "std"), no_std)]
3
4//! Cryptographic traits and UMSH-specific key/packet operations.
5//!
6//! This crate separates algorithm providers from protocol logic. The low-level
7//! traits such as [`AesProvider`] and [`Sha256Provider`] can be backed either by
8//! software implementations or hardware accelerators, while [`CryptoEngine`]
9//! implements the UMSH-specific derivation and packet-authentication rules.
10//!
11//! # Example
12//!
13//! ```rust
14//! use umsh_crypto::software::{SoftwareAes, SoftwareIdentity, SoftwareSha256};
15//! use umsh_crypto::{CryptoEngine, NodeIdentity};
16//!
17//! let alice = SoftwareIdentity::from_secret_bytes(&[0x11; 32]);
18//! let bob = SoftwareIdentity::from_secret_bytes(&[0x22; 32]);
19//! let shared = alice.shared_secret_with(bob.public_key()).unwrap();
20//! let engine = CryptoEngine::new(SoftwareAes, SoftwareSha256);
21//! let keys = engine.derive_pairwise_keys(&shared);
22//!
23//! assert_ne!(keys.k_enc, [0u8; 16]);
24//! assert_ne!(keys.k_mic, [0u8; 16]);
25//! ```
26
27use core::ops::Range;
28
29use umsh_core::{
30    ChannelId, ChannelKey, PacketHeader, PacketType, PublicKey, SourceAddrRef, UnsealedPacket,
31    feed_aad,
32};
33use zeroize::{Zeroize, ZeroizeOnDrop};
34
/// A keyed AES block cipher that transforms single 16-byte blocks.
///
/// [`CmacState`] and [`CryptoEngine`] build AES-CMAC and AES-CTR on top of
/// [`encrypt_block`](Self::encrypt_block).
pub trait AesCipher {
    /// Encrypt the 16-byte `block` in place.
    fn encrypt_block(&self, block: &mut [u8; 16]);
    /// Decrypt the 16-byte `block` in place.
    fn decrypt_block(&self, block: &mut [u8; 16]);
}
42
43/// Factory for keyed AES cipher instances.
44pub trait AesProvider {
45    /// Concrete cipher type returned by [`new_cipher`](Self::new_cipher).
46    type Cipher: AesCipher;
47
48    /// Create a new AES-128 cipher using `key`.
49    fn new_cipher(&self, key: &[u8; 16]) -> Self::Cipher;
50}
51
/// Provider of SHA-256 and HMAC-SHA-256.
///
/// Both operations take the message as a list of slices that are processed
/// as if they had been concatenated, so callers never need to assemble a
/// contiguous buffer.
pub trait Sha256Provider {
    /// Return the SHA-256 digest of the concatenation of `data`.
    fn hash(&self, data: &[&[u8]]) -> [u8; 32];
    /// Return HMAC-SHA-256 of the concatenation of `data` under `key`.
    fn hmac(&self, key: &[u8], data: &[&[u8]]) -> [u8; 32];
}
59
60/// Raw X25519 shared secret.
61#[derive(Clone, Zeroize, ZeroizeOnDrop)]
62pub struct SharedSecret(pub [u8; 32]);
63
64/// Node identity capable of signing and key agreement.
65pub trait NodeIdentity {
66    type Error;
67
68    /// Return the long-term Ed25519 public key for this identity.
69    fn public_key(&self) -> &PublicKey;
70
71    /// Return the three-byte node hint derived from [`public_key`](Self::public_key).
72    fn hint(&self) -> umsh_core::NodeHint {
73        self.public_key().hint()
74    }
75
76    /// Sign an arbitrary message.
77    async fn sign(&self, message: &[u8]) -> Result<[u8; 64], Self::Error>;
78    /// Perform X25519-style key agreement with a peer public key.
79    async fn agree(&self, peer: &PublicKey) -> Result<SharedSecret, Self::Error>;
80}
81
/// Failures reported by the protocol-level crypto operations.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CryptoError {
    /// A peer public key could not be decoded.
    InvalidPublicKey,
    /// Key agreement produced an unusable (all-zero) shared secret.
    InvalidSharedSecret,
    /// The packet or its parsed header lacks required fields or is malformed.
    InvalidPacket,
    /// The recomputed MIC does not match the one carried by the packet.
    AuthenticationFailed,
}
90
91/// Derived pairwise transport keys.
92#[derive(Clone, Zeroize, ZeroizeOnDrop)]
93pub struct PairwiseKeys {
94    pub k_enc: [u8; 16],
95    pub k_mic: [u8; 16],
96}
97
98/// Derived multicast or channel transport keys.
99#[derive(Clone)]
100pub struct DerivedChannelKeys {
101    pub k_enc: [u8; 16],
102    pub k_mic: [u8; 16],
103    pub channel_id: ChannelId,
104}
105
106impl Zeroize for DerivedChannelKeys {
107    fn zeroize(&mut self) {
108        self.k_enc.zeroize();
109        self.k_mic.zeroize();
110    }
111}
112
113impl Drop for DerivedChannelKeys {
114    fn drop(&mut self) {
115        self.zeroize();
116    }
117}
118
119/// Incremental AES-CMAC state.
120pub struct CmacState<C: AesCipher> {
121    cipher: C,
122    state: [u8; 16],
123    buffer: [u8; 16],
124    pos: usize,
125    k1: [u8; 16],
126    k2: [u8; 16],
127}
128
129impl<C: AesCipher> CmacState<C> {
130    /// Initialize a new incremental CMAC state.
131    pub fn new(cipher: C) -> Self {
132        let mut l = [0u8; 16];
133        cipher.encrypt_block(&mut l);
134        let k1 = dbl(&l);
135        let k2 = dbl(&k1);
136        Self {
137            cipher,
138            state: [0u8; 16],
139            buffer: [0u8; 16],
140            pos: 0,
141            k1,
142            k2,
143        }
144    }
145
146    /// Feed additional bytes into the MAC state.
147    pub fn update(&mut self, mut data: &[u8]) {
148        while !data.is_empty() {
149            let space = 16 - self.pos;
150            let take = space.min(data.len());
151            self.buffer[self.pos..self.pos + take].copy_from_slice(&data[..take]);
152            self.pos += take;
153            data = &data[take..];
154            if self.pos == 16 && !data.is_empty() {
155                self.process_buffer();
156            }
157        }
158    }
159
160    /// Finalize and return the full 16-byte CMAC value.
161    pub fn finalize(self) -> [u8; 16] {
162        let this = self;
163        let mut last = [0u8; 16];
164        if this.pos == 16 {
165            last.copy_from_slice(&this.buffer);
166            xor_in_place(&mut last, &this.k1);
167        } else {
168            last[..this.pos].copy_from_slice(&this.buffer[..this.pos]);
169            last[this.pos] = 0x80;
170            xor_in_place(&mut last, &this.k2);
171        }
172
173        xor_in_place(&mut last, &this.state);
174        this.cipher.encrypt_block(&mut last);
175        last
176    }
177
178    fn process_buffer(&mut self) {
179        let mut block = self.buffer;
180        xor_in_place(&mut block, &self.state);
181        self.cipher.encrypt_block(&mut block);
182        self.state = block;
183        self.buffer = [0u8; 16];
184        self.pos = 0;
185    }
186}
187
188/// UMSH protocol crypto engine.
189pub struct CryptoEngine<A: AesProvider, S: Sha256Provider> {
190    aes: A,
191    sha: S,
192}
193
194impl<A: AesProvider, S: Sha256Provider> CryptoEngine<A, S> {
195    /// Create a new engine from algorithm providers.
196    pub fn new(aes: A, sha: S) -> Self {
197        Self { aes, sha }
198    }
199
200    /// Derive stable pairwise encryption and MIC keys from a shared secret.
201    pub fn derive_pairwise_keys(&self, shared_secret: &SharedSecret) -> PairwiseKeys {
202        let mut okm = [0u8; 32];
203        self.hkdf(
204            &shared_secret.0,
205            b"UMSH-PAIRWISE-SALT",
206            b"UMSH-UNICAST-V1",
207            &mut okm,
208        );
209        let mut keys = PairwiseKeys {
210            k_enc: [0u8; 16],
211            k_mic: [0u8; 16],
212        };
213        keys.k_enc.copy_from_slice(&okm[..16]);
214        keys.k_mic.copy_from_slice(&okm[16..32]);
215        okm.zeroize();
216        keys
217    }
218
219    /// Derive the channel identifier from a raw channel key.
220    pub fn derive_channel_id(&self, channel_key: &ChannelKey) -> ChannelId {
221        let mut out = [0u8; 2];
222        self.hkdf(&channel_key.0, b"UMSH-CHAN-ID", b"", &mut out);
223        ChannelId(out)
224    }
225
226    /// Derive multicast transport keys and the channel identifier.
227    pub fn derive_channel_keys(&self, channel_key: &ChannelKey) -> DerivedChannelKeys {
228        let channel_id = self.derive_channel_id(channel_key);
229        let mut info = [0u8; 15];
230        info[..13].copy_from_slice(b"UMSH-MCAST-V1");
231        info[13..15].copy_from_slice(&channel_id.0);
232        let mut okm = [0u8; 32];
233        self.hkdf(&channel_key.0, b"UMSH-MCAST-SALT", &info, &mut okm);
234        let mut derived = DerivedChannelKeys {
235            k_enc: [0u8; 16],
236            k_mic: [0u8; 16],
237            channel_id,
238        };
239        derived.k_enc.copy_from_slice(&okm[..16]);
240        derived.k_mic.copy_from_slice(&okm[16..32]);
241        okm.zeroize();
242        derived
243    }
244
245    /// Combine pairwise and channel keys for blind-unicast payload protection.
246    pub fn derive_blind_keys(
247        &self,
248        pairwise: &PairwiseKeys,
249        channel: &DerivedChannelKeys,
250    ) -> PairwiseKeys {
251        let mut keys = PairwiseKeys {
252            k_enc: [0u8; 16],
253            k_mic: [0u8; 16],
254        };
255        for (dst, (left, right)) in keys
256            .k_enc
257            .iter_mut()
258            .zip(pairwise.k_enc.iter().zip(channel.k_enc.iter()))
259        {
260            *dst = left ^ right;
261        }
262        for (dst, (left, right)) in keys
263            .k_mic
264            .iter_mut()
265            .zip(pairwise.k_mic.iter().zip(channel.k_mic.iter()))
266        {
267            *dst = left ^ right;
268        }
269        keys
270    }
271
272    /// Derive a channel key from a human-readable channel name.
273    pub fn derive_named_channel_key(&self, name: &str) -> ChannelKey {
274        ChannelKey(self.sha.hmac(b"UMSH-CHANNEL-V1", &[name.as_bytes()]))
275    }
276
277    /// Seal a unicast or multicast packet in place.
278    pub fn seal_packet(
279        &self,
280        packet: &mut UnsealedPacket<'_>,
281        keys: &PairwiseKeys,
282    ) -> Result<usize, CryptoError> {
283        let header = packet.header().map_err(|_| CryptoError::InvalidPacket)?;
284        let sec_info = header.sec_info.ok_or(CryptoError::InvalidPacket)?;
285        let full_mac = {
286            let bytes = packet.as_bytes();
287            let mut cmac = self.cmac_state(&keys.k_mic);
288            feed_aad(&header, bytes, |chunk| cmac.update(chunk));
289            cmac.update(packet.body());
290            cmac.finalize()
291        };
292
293        let mic_len = sec_info
294            .scf
295            .mic_size()
296            .map_err(|_| CryptoError::InvalidPacket)?
297            .byte_len();
298        packet.mic_slot()[..mic_len].copy_from_slice(&full_mac[..mic_len]);
299
300        if sec_info.scf.encrypted() {
301            let iv = self.build_ctr_iv(
302                &full_mac[..mic_len],
303                &packet.as_bytes()[packet.sec_info_range()],
304            );
305            self.aes_ctr(&keys.k_enc, &iv, packet.body_mut());
306        }
307
308        Ok(mic_len)
309    }
310
311    /// Seal a blind-unicast packet, including its hidden address block.
312    pub fn seal_blind_packet(
313        &self,
314        packet: &mut UnsealedPacket<'_>,
315        blind_keys: &PairwiseKeys,
316        channel_keys: &DerivedChannelKeys,
317    ) -> Result<usize, CryptoError> {
318        let header = packet.header().map_err(|_| CryptoError::InvalidPacket)?;
319        match header.packet_type() {
320            PacketType::BlindUnicast | PacketType::BlindUnicastAckReq => {}
321            _ => return Err(CryptoError::InvalidPacket),
322        }
323
324        let sec_info = header.sec_info.ok_or(CryptoError::InvalidPacket)?;
325        let blind_addr_range = packet
326            .blind_addr_range()
327            .ok_or(CryptoError::InvalidPacket)?;
328        let full_mac = {
329            let bytes = packet.as_bytes();
330            let mut cmac = self.cmac_state(&blind_keys.k_mic);
331            feed_aad(&header, bytes, |chunk| cmac.update(chunk));
332            cmac.update(packet.body());
333            cmac.finalize()
334        };
335
336        let mic_len = sec_info
337            .scf
338            .mic_size()
339            .map_err(|_| CryptoError::InvalidPacket)?
340            .byte_len();
341        let iv = self.build_ctr_iv(
342            &full_mac[..mic_len],
343            &packet.as_bytes()[packet.sec_info_range()],
344        );
345        packet.mic_slot()[..mic_len].copy_from_slice(&full_mac[..mic_len]);
346        if sec_info.scf.encrypted() {
347            self.aes_ctr(&blind_keys.k_enc, &iv, packet.body_mut());
348            self.aes_ctr(
349                &channel_keys.k_enc,
350                &iv,
351                &mut packet.as_bytes_mut()[blind_addr_range],
352            );
353        }
354        Ok(mic_len)
355    }
356
357    /// Verify and, if needed, decrypt a received secure packet in place.
358    pub fn open_packet(
359        &self,
360        buf: &mut [u8],
361        header: &PacketHeader,
362        keys: &PairwiseKeys,
363    ) -> Result<Range<usize>, CryptoError> {
364        let sec_info = header.sec_info.ok_or(CryptoError::InvalidPacket)?;
365        let mut mic = [0u8; 16];
366        let mic_len = header.mic_range.end - header.mic_range.start;
367        mic[..mic_len].copy_from_slice(&buf[header.mic_range.clone()]);
368        if sec_info.scf.encrypted() {
369            let sec_info_range =
370                header.body_range.start - sec_info.wire_len()..header.body_range.start;
371            let iv = self.build_ctr_iv(&mic[..mic_len], &buf[sec_info_range]);
372            self.aes_ctr(&keys.k_enc, &iv, &mut buf[header.body_range.clone()]);
373        }
374
375        let full_mac = {
376            let mut cmac = self.cmac_state(&keys.k_mic);
377            feed_aad(header, buf, |chunk| cmac.update(chunk));
378            cmac.update(&buf[header.body_range.clone()]);
379            cmac.finalize()
380        };
381        if !constant_time_eq(&mic[..mic_len], &full_mac[..mic_len]) {
382            return Err(CryptoError::AuthenticationFailed);
383        }
384
385        let body_range = match (header.packet_type(), header.source) {
386            (PacketType::Multicast, SourceAddrRef::Encrypted { len, .. }) => {
387                (header.body_range.start + len)..header.body_range.end
388            }
389            _ => header.body_range.clone(),
390        };
391        Ok(body_range)
392    }
393
394    /// Decrypt the blinded destination/source address block of a blind unicast.
395    pub fn decrypt_blind_addr(
396        &self,
397        buf: &mut [u8],
398        header: &PacketHeader,
399        channel_keys: &DerivedChannelKeys,
400    ) -> Result<(umsh_core::NodeHint, SourceAddrRef), CryptoError> {
401        match header.source {
402            SourceAddrRef::Encrypted { offset, len } => {
403                let addr_start = offset.checked_sub(3).ok_or(CryptoError::InvalidPacket)?;
404                let addr_end = addr_start + 3 + len;
405                let sec_info = header.sec_info.ok_or(CryptoError::InvalidPacket)?;
406                let sec_info_range =
407                    header.body_range.start - sec_info.wire_len()..header.body_range.start;
408                let iv = self.build_ctr_iv(&buf[header.mic_range.clone()], &buf[sec_info_range]);
409                self.aes_ctr(&channel_keys.k_enc, &iv, &mut buf[addr_start..addr_end]);
410                let dst = umsh_core::NodeHint([
411                    buf[addr_start],
412                    buf[addr_start + 1],
413                    buf[addr_start + 2],
414                ]);
415                let source = if len == 3 {
416                    SourceAddrRef::Hint(umsh_core::NodeHint([
417                        buf[addr_start + 3],
418                        buf[addr_start + 4],
419                        buf[addr_start + 5],
420                    ]))
421                } else if len == 32 {
422                    SourceAddrRef::FullKeyAt {
423                        offset: addr_start + 3,
424                    }
425                } else {
426                    return Err(CryptoError::InvalidPacket);
427                };
428                Ok((dst, source))
429            }
430            SourceAddrRef::Hint(hint) => {
431                let dst = header.dst.ok_or(CryptoError::InvalidPacket)?;
432                Ok((dst, SourceAddrRef::Hint(hint)))
433            }
434            SourceAddrRef::FullKeyAt { offset } => {
435                let dst = header.dst.ok_or(CryptoError::InvalidPacket)?;
436                let _ = buf
437                    .get(offset..offset + 32)
438                    .ok_or(CryptoError::InvalidPacket)?;
439                Ok((dst, SourceAddrRef::FullKeyAt { offset }))
440            }
441            SourceAddrRef::None => Err(CryptoError::InvalidPacket),
442        }
443    }
444
445    /// Compute the 8-byte transport ACK tag from a full CMAC and `k_enc`.
446    pub fn compute_ack_tag(&self, full_cmac: &[u8; 16], k_enc: &[u8; 16]) -> [u8; 8] {
447        let cipher = self.aes.new_cipher(k_enc);
448        let mut block = *full_cmac;
449        cipher.encrypt_block(&mut block);
450        let mut ack = [0u8; 8];
451        ack.copy_from_slice(&block[..8]);
452        ack
453    }
454
455    /// Create a reusable incremental CMAC state.
456    pub fn cmac_state(&self, key: &[u8; 16]) -> CmacState<A::Cipher> {
457        CmacState::new(self.aes.new_cipher(key))
458    }
459
460    /// Convenience wrapper for AES-CMAC over concatenated slices.
461    pub fn aes_cmac(&self, key: &[u8; 16], data: &[&[u8]]) -> [u8; 16] {
462        let mut state = self.cmac_state(key);
463        for chunk in data {
464            state.update(chunk);
465        }
466        state.finalize()
467    }
468
469    /// Apply AES-CTR using `iv` as the initial counter block.
470    pub fn aes_ctr(&self, key: &[u8; 16], iv: &[u8; 16], data: &mut [u8]) {
471        let cipher = self.aes.new_cipher(key);
472        let mut counter = *iv;
473        for chunk in data.chunks_mut(16) {
474            let mut stream = counter;
475            cipher.encrypt_block(&mut stream);
476            for (dst, src) in chunk.iter_mut().zip(stream.iter()) {
477                *dst ^= *src;
478            }
479            increment_counter(&mut counter);
480        }
481    }
482
483    /// Construct the CTR IV from MIC bytes and SECINFO bytes.
484    pub fn build_ctr_iv(&self, mic: &[u8], sec_info_bytes: &[u8]) -> [u8; 16] {
485        let mut iv = [0u8; 16];
486        let mut written = 0usize;
487        for byte in mic.iter().chain(sec_info_bytes.iter()).take(16) {
488            iv[written] = *byte;
489            written += 1;
490        }
491        iv
492    }
493
494    /// Run HKDF-SHA256 and write the output into `okm`.
495    pub fn hkdf(&self, ikm: &[u8], salt: &[u8], info: &[u8], okm: &mut [u8]) {
496        let prk = self.sha.hmac(salt, &[ikm]);
497        let mut previous = [0u8; 32];
498        let mut previous_len = 0usize;
499        let mut written = 0usize;
500        let mut counter = 1u8;
501        while written < okm.len() {
502            let next = self
503                .sha
504                .hmac(&prk, &[&previous[..previous_len], info, &[counter]]);
505            let take = (okm.len() - written).min(next.len());
506            okm[written..written + take].copy_from_slice(&next[..take]);
507            previous = next;
508            previous_len = 32;
509            written += take;
510            counter = counter.wrapping_add(1);
511        }
512    }
513}
514
/// XOR `rhs` into `dst`, byte by byte.
fn xor_in_place(dst: &mut [u8; 16], rhs: &[u8; 16]) {
    dst.iter_mut()
        .zip(rhs.iter())
        .for_each(|(acc, mask)| *acc ^= *mask);
}
520
/// Double a 128-bit block in GF(2^128), as used for CMAC subkey generation.
///
/// The block is shifted left by one bit as a big-endian integer; if a bit is
/// shifted out of the top, the last byte is XORed with `0x87`.
///
/// The input is derived from the key, so the reduction is applied through a
/// mask rather than a branch on the top bit.
fn dbl(block: &[u8; 16]) -> [u8; 16] {
    let mut out = [0u8; 16];
    let mut carry = 0u8;
    // Walk from the least significant (last) byte to the most significant.
    for (dst, src) in out.iter_mut().zip(block.iter()).rev() {
        *dst = (*src << 1) | carry;
        carry = *src >> 7;
    }
    // `carry` now holds the bit shifted out of the block: 1 -> mask 0xFF,
    // 0 -> mask 0x00.
    out[15] ^= 0x87 & carry.wrapping_neg();
    out
}
533
/// Add one to `counter`, treated as a 128-bit big-endian integer.
///
/// An all-`0xFF` counter wraps around to zero.
fn increment_counter(counter: &mut [u8; 16]) {
    for byte in counter.iter_mut().rev() {
        *byte = byte.wrapping_add(1);
        // A non-zero result means this byte did not overflow: no carry.
        if *byte != 0 {
            break;
        }
    }
}
543
/// Compare two byte strings without stopping at the first difference.
///
/// Every position up to the longer length is visited, with missing bytes
/// read as zero. A length difference alone already makes the result `false`.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    let span = left.len().max(right.len());
    let byte_at = |bytes: &[u8], index: usize| bytes.get(index).copied().unwrap_or(0);
    let mismatch = (0..span).fold(left.len() ^ right.len(), |acc, index| {
        acc | usize::from(byte_at(left, index) ^ byte_at(right, index))
    });
    mismatch == 0
}
554
555#[cfg(feature = "software-crypto")]
556pub mod software {
557    use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit, generic_array::GenericArray};
558    use curve25519_dalek::{edwards::CompressedEdwardsY, montgomery::MontgomeryPoint};
559    use ed25519_dalek::{Signer, SigningKey};
560    use rand_core::CryptoRngCore;
561    use sha2::{Digest, Sha256, Sha512};
562
563    use super::*;
564
565    pub struct SoftwareAes;
566
567    pub struct SoftwareAesCipher(aes::Aes128);
568
569    impl AesCipher for SoftwareAesCipher {
570        fn encrypt_block(&self, block: &mut [u8; 16]) {
571            self.0.encrypt_block(GenericArray::from_mut_slice(block));
572        }
573
574        fn decrypt_block(&self, block: &mut [u8; 16]) {
575            self.0.decrypt_block(GenericArray::from_mut_slice(block));
576        }
577    }
578
579    impl AesProvider for SoftwareAes {
580        type Cipher = SoftwareAesCipher;
581
582        fn new_cipher(&self, key: &[u8; 16]) -> Self::Cipher {
583            SoftwareAesCipher(aes::Aes128::new(GenericArray::from_slice(key)))
584        }
585    }
586
587    pub struct SoftwareSha256;
588
589    impl Sha256Provider for SoftwareSha256 {
590        fn hash(&self, data: &[&[u8]]) -> [u8; 32] {
591            let mut hasher = Sha256::new();
592            for chunk in data {
593                hasher.update(chunk);
594            }
595            hasher.finalize().into()
596        }
597
598        fn hmac(&self, key: &[u8], data: &[&[u8]]) -> [u8; 32] {
599            const BLOCK_LEN: usize = 64;
600            let mut key_block = [0u8; BLOCK_LEN];
601            if key.len() > BLOCK_LEN {
602                key_block[..32].copy_from_slice(&self.hash(&[key]));
603            } else {
604                key_block[..key.len()].copy_from_slice(key);
605            }
606
607            let mut ipad = [0x36u8; BLOCK_LEN];
608            let mut opad = [0x5Cu8; BLOCK_LEN];
609            for index in 0..BLOCK_LEN {
610                ipad[index] ^= key_block[index];
611                opad[index] ^= key_block[index];
612            }
613
614            let mut inner = Sha256::new();
615            inner.update(ipad);
616            for chunk in data {
617                inner.update(chunk);
618            }
619            let inner_hash = inner.finalize();
620
621            let mut outer = Sha256::new();
622            outer.update(opad);
623            outer.update(inner_hash);
624            outer.finalize().into()
625        }
626    }
627
628    pub struct SoftwareIdentity {
629        secret: SigningKey,
630        public: PublicKey,
631    }
632
633    impl SoftwareIdentity {
634        pub fn generate(rng: &mut impl CryptoRngCore) -> Self {
635            let secret = SigningKey::generate(rng);
636            let public = PublicKey(secret.verifying_key().to_bytes());
637            Self { secret, public }
638        }
639
640        pub fn from_secret_bytes(bytes: &[u8; 32]) -> Self {
641            let secret = SigningKey::from_bytes(bytes);
642            let public = PublicKey(secret.verifying_key().to_bytes());
643            Self { secret, public }
644        }
645
646        pub fn shared_secret_with(&self, peer: &PublicKey) -> Result<SharedSecret, CryptoError> {
647            let local = signing_key_to_x25519(&self.secret);
648            let remote = public_key_to_x25519(peer)?;
649            let shared = local.diffie_hellman(&remote);
650            let bytes = shared.to_bytes();
651            if bytes.iter().all(|byte| *byte == 0) {
652                return Err(CryptoError::InvalidSharedSecret);
653            }
654            Ok(SharedSecret(bytes))
655        }
656    }
657
658    impl NodeIdentity for SoftwareIdentity {
659        type Error = CryptoError;
660
661        fn public_key(&self) -> &PublicKey {
662            &self.public
663        }
664
665        async fn sign(&self, message: &[u8]) -> Result<[u8; 64], Self::Error> {
666            Ok(self.secret.sign(message).to_bytes())
667        }
668
669        async fn agree(&self, peer: &PublicKey) -> Result<SharedSecret, Self::Error> {
670            self.shared_secret_with(peer)
671        }
672    }
673
674    pub type SoftwareCryptoEngine = CryptoEngine<SoftwareAes, SoftwareSha256>;
675
676    fn signing_key_to_x25519(secret: &SigningKey) -> x25519_dalek::StaticSecret {
677        let digest = Sha512::digest(secret.to_bytes());
678        let mut scalar = [0u8; 32];
679        scalar.copy_from_slice(&digest[..32]);
680        scalar[0] &= 248;
681        scalar[31] &= 127;
682        scalar[31] |= 64;
683        x25519_dalek::StaticSecret::from(scalar)
684    }
685
686    fn public_key_to_x25519(public: &PublicKey) -> Result<x25519_dalek::PublicKey, CryptoError> {
687        let compressed = CompressedEdwardsY(public.0);
688        let edwards = compressed
689            .decompress()
690            .ok_or(CryptoError::InvalidPublicKey)?;
691        let mont: MontgomeryPoint = edwards.to_montgomery();
692        Ok(x25519_dalek::PublicKey::from(mont.to_bytes()))
693    }
694}
695
696#[cfg(feature = "software-crypto")]
697pub use software::*;
698
699#[cfg(test)]
700mod tests {
701    use crate::{constant_time_eq, dbl};
702
703    #[cfg(feature = "software-crypto")]
704    use super::{software::*, *};
705    #[cfg(feature = "software-crypto")]
706    use umsh_core::{MicSize, NodeHint, PacketBuilder, PacketHeader, PublicKey};
707
708    #[cfg(feature = "software-crypto")]
709    #[test]
710    fn pairwise_hkdf_is_stable() {
711        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
712        let keys = engine.derive_pairwise_keys(&SharedSecret([7u8; 32]));
713        assert_ne!(keys.k_enc, [0u8; 16]);
714        assert_ne!(keys.k_mic, [0u8; 16]);
715    }
716
717    #[cfg(feature = "software-crypto")]
718    #[test]
719    fn aes_cmac_matches_rfc4493_example_2() {
720        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
721        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
722        let msg = hex_vec("6bc1bee22e409f96e93d7e117393172a");
723        let expected = hex_16("070a16b46b4d4144f79bdd9dd04a287c");
724        assert_eq!(engine.aes_cmac(&key, &[&msg]), expected);
725    }
726
727    #[cfg(feature = "software-crypto")]
728    #[test]
729    fn aes_cmac_matches_rfc4493_example_1_empty() {
730        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
731        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
732        let expected = hex_16("bb1d6929e95937287fa37d129b756746");
733        assert_eq!(engine.aes_cmac(&key, &[&[]]), expected);
734    }
735
736    #[cfg(feature = "software-crypto")]
737    #[test]
738    fn aes_cmac_matches_rfc4493_example_3_64b() {
739        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
740        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
741        let msg = hex_vec(
742            "6bc1bee22e409f96e93d7e117393172a\
743             ae2d8a571e03ac9c9eb76fac45af8e51\
744             30c81c46a35ce411e5fbc1191a0a52ef\
745             f69f2445df4f9b17ad2b417be66c3710",
746        );
747        let expected = hex_16("51f0bebf7e3b9d92fc49741779363cfe");
748        assert_eq!(engine.aes_cmac(&key, &[&msg]), expected);
749    }
750
751    #[cfg(feature = "software-crypto")]
752    #[test]
753    fn aes_cmac_matches_rfc4493_example_4_40b() {
754        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
755        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
756        let msg = hex_vec(
757            "6bc1bee22e409f96e93d7e117393172a\
758             ae2d8a571e03ac9c9eb76fac45af8e51\
759             30c81c46a35ce411",
760        );
761        let expected = hex_16("dfa66747de9ae63030ca32611497c827");
762        assert_eq!(engine.aes_cmac(&key, &[&msg]), expected);
763    }
764
765    /// NIST SP 800-38A Section F.5.1 — AES-128 CTR mode, single block.
766    #[cfg(feature = "software-crypto")]
767    #[test]
768    fn aes_ctr_matches_nist_sp800_38a_block_1() {
769        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
770        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
771        let iv = hex_16("f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff");
772        let mut data = hex_vec("6bc1bee22e409f96e93d7e117393172a");
773        engine.aes_ctr(&key, &iv, &mut data);
774        assert_eq!(data, hex_vec("874d6191b620e3261bef6864990db6ce"));
775    }
776
777    /// NIST SP 800-38A Section F.5.1 — AES-128 CTR mode, 4 blocks.
778    /// Verifies counter increment across multiple blocks.
779    #[cfg(feature = "software-crypto")]
780    #[test]
781    fn aes_ctr_matches_nist_sp800_38a_4_blocks() {
782        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
783        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
784        let iv = hex_16("f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff");
785        let mut data = hex_vec(
786            "6bc1bee22e409f96e93d7e117393172a\
787             ae2d8a571e03ac9c9eb76fac45af8e51\
788             30c81c46a35ce411e5fbc1191a0a52ef\
789             f69f2445df4f9b17ad2b417be66c3710",
790        );
791        let expected = hex_vec(
792            "874d6191b620e3261bef6864990db6ce\
793             9806f66b7970fdff8617187bb9fffdff\
794             5ae4df3edbd5d35e5b4f09020db03eab\
795             1e031dda2fbe03d1792170a0f3009cee",
796        );
797        engine.aes_ctr(&key, &iv, &mut data);
798        assert_eq!(data, expected);
799    }
800
801    /// NIST SP 800-38A — CTR decrypt (symmetric operation).
802    #[cfg(feature = "software-crypto")]
803    #[test]
804    fn aes_ctr_decrypt_matches_nist_sp800_38a() {
805        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
806        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
807        let iv = hex_16("f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff");
808        let mut data = hex_vec(
809            "874d6191b620e3261bef6864990db6ce\
810             9806f66b7970fdff8617187bb9fffdff\
811             5ae4df3edbd5d35e5b4f09020db03eab\
812             1e031dda2fbe03d1792170a0f3009cee",
813        );
814        let expected = hex_vec(
815            "6bc1bee22e409f96e93d7e117393172a\
816             ae2d8a571e03ac9c9eb76fac45af8e51\
817             30c81c46a35ce411e5fbc1191a0a52ef\
818             f69f2445df4f9b17ad2b417be66c3710",
819        );
820        engine.aes_ctr(&key, &iv, &mut data);
821        assert_eq!(data, expected);
822    }
823
824    /// AES-CTR with partial final block (non-aligned length).
825    #[cfg(feature = "software-crypto")]
826    #[test]
827    fn aes_ctr_partial_block() {
828        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
829        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
830        let iv = hex_16("f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff");
831        // Encrypt 5 bytes (less than one block)
832        let mut data = hex_vec("6bc1bee22e");
833        engine.aes_ctr(&key, &iv, &mut data);
834        // Should match first 5 bytes of full block 1 ciphertext
835        assert_eq!(data, hex_vec("874d6191b6"));
836        // Decrypt back
837        engine.aes_ctr(&key, &iv, &mut data);
838        assert_eq!(data, hex_vec("6bc1bee22e"));
839    }
840
841    #[cfg(feature = "software-crypto")]
842    #[test]
843    fn hmac_sha256_matches_rfc4231_case_1() {
844        let sha = SoftwareSha256;
845        let key = [0x0bu8; 20];
846        let expected = hex_32("b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7");
847        assert_eq!(sha.hmac(&key, &[b"Hi There"]), expected);
848    }
849
850    #[cfg(feature = "software-crypto")]
851    #[test]
852    fn hkdf_matches_rfc5869_case_1() {
853        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
854        let ikm = [0x0bu8; 22];
855        let salt = hex_vec("000102030405060708090a0b0c");
856        let info = hex_vec("f0f1f2f3f4f5f6f7f8f9");
857        let mut okm = [0u8; 42];
858        engine.hkdf(&ikm, &salt, &info, &mut okm);
859        assert_eq!(
860            okm.to_vec(),
861            hex_vec(
862                "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865"
863            )
864        );
865    }
866
867    #[test]
868    fn cmac_doubling_matches_rfc4493_subkey_generation() {
869        let l = hex_16("7df76b0c1ab899b33e42f047b91b546f");
870        let expected_k1 = hex_16("fbeed618357133667c85e08f7236a8de");
871        let expected_k2 = hex_16("f7ddac306ae266ccf90bc11ee46d513b");
872        let k1 = dbl(&l);
873        let k2 = dbl(&k1);
874        assert_eq!(k1, expected_k1);
875        assert_eq!(k2, expected_k2);
876    }
877
878    #[test]
879    fn constant_time_eq_rejects_mismatch() {
880        assert!(constant_time_eq(&[1, 2, 3], &[1, 2, 3]));
881        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2, 4]));
882        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2, 3, 4]));
883    }
884
885    #[cfg(feature = "software-crypto")]
886    #[test]
887    fn unicast_seal_and_open_round_trip() {
888        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
889        let keys = engine.derive_pairwise_keys(&SharedSecret([9u8; 32]));
890        let src = PublicKey([0xA1; 32]);
891        let dst = NodeHint([0xC3, 0xD4, 0x25]);
892        let mut buf = [0u8; 128];
893        let mut packet = PacketBuilder::new(&mut buf)
894            .unicast(dst)
895            .source_full(&src)
896            .frame_counter(1)
897            .encrypted()
898            .mic_size(MicSize::Mic16)
899            .payload(b"hello")
900            .build()
901            .unwrap();
902
903        engine.seal_packet(&mut packet, &keys).unwrap();
904        let header = PacketHeader::parse(packet.as_bytes()).unwrap();
905        let mut wire = packet.as_bytes().to_vec();
906        let range = engine.open_packet(&mut wire, &header, &keys).unwrap();
907        assert_eq!(&wire[range], b"hello");
908    }
909
910    /// Verify compute_ack_tag produces a deterministic 8-byte value and that
911    /// the same CMAC+key always yields the same tag.
912    #[cfg(feature = "software-crypto")]
913    #[test]
914    fn compute_ack_tag_is_deterministic() {
915        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
916        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
917        let cmac = hex_16("070a16b46b4d4144f79bdd9dd04a287c"); // RFC 4493 example 2
918        let tag1 = engine.compute_ack_tag(&cmac, &key);
919        let tag2 = engine.compute_ack_tag(&cmac, &key);
920        assert_eq!(tag1, tag2);
921        assert_eq!(tag1.len(), 8);
922        // Tag should be the first 8 bytes of AES-ECB(key, cmac)
923        let cipher = SoftwareAes.new_cipher(&key);
924        let mut block = cmac;
925        cipher.encrypt_block(&mut block);
926        assert_eq!(tag1, block[..8]);
927    }
928
929    /// Verify that compute_ack_tag with different CMACs produces different tags,
930    /// and that the tag is the first 8 bytes of AES-ECB(k_enc, cmac).
931    #[cfg(feature = "software-crypto")]
932    #[test]
933    fn compute_ack_tag_differs_for_different_cmacs() {
934        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
935        let key = hex_16("2b7e151628aed2a6abf7158809cf4f3c");
936        let cmac_a = hex_16("070a16b46b4d4144f79bdd9dd04a287c");
937        let cmac_b = hex_16("51f0bebf7e3b9d92fc49741779363cfe");
938        let tag_a = engine.compute_ack_tag(&cmac_a, &key);
939        let tag_b = engine.compute_ack_tag(&cmac_b, &key);
940        assert_ne!(tag_a, tag_b);
941        assert_ne!(tag_a, [0u8; 8]);
942        assert_ne!(tag_b, [0u8; 8]);
943    }
944
945    #[cfg(feature = "software-crypto")]
946    #[test]
947    fn encrypted_multicast_round_trip_preserves_source_prefix() {
948        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
949        let channel_key = ChannelKey([0x5Au8; 32]);
950        let derived = engine.derive_channel_keys(&channel_key);
951        let src = PublicKey([0xA1; 32]);
952        let mut buf = [0u8; 160];
953        let mut packet = PacketBuilder::new(&mut buf)
954            .multicast(derived.channel_id)
955            .source_hint(src.hint())
956            .frame_counter(5)
957            .encrypted()
958            .mic_size(MicSize::Mic16)
959            .payload(b"hello")
960            .build()
961            .unwrap();
962
963        let multicast_keys = PairwiseKeys {
964            k_enc: derived.k_enc,
965            k_mic: derived.k_mic,
966        };
967        engine.seal_packet(&mut packet, &multicast_keys).unwrap();
968        let header = PacketHeader::parse(packet.as_bytes()).unwrap();
969        let mut wire = packet.as_bytes().to_vec();
970        let range = engine
971            .open_packet(&mut wire, &header, &multicast_keys)
972            .unwrap();
973        assert_eq!(
974            &wire[header.body_range.start..header.body_range.start + 3],
975            &src.hint().0
976        );
977        assert_eq!(&wire[range], b"hello");
978    }
979
980    #[cfg(feature = "software-crypto")]
981    #[test]
982    fn blind_unicast_round_trip_recovers_addresses_and_payload() {
983        let engine = SoftwareCryptoEngine::new(SoftwareAes, SoftwareSha256);
984        let shared = SharedSecret([0x33u8; 32]);
985        let pairwise = engine.derive_pairwise_keys(&shared);
986        let channel_key = ChannelKey([0x5Au8; 32]);
987        let channel = engine.derive_channel_keys(&channel_key);
988        let blind_keys = engine.derive_blind_keys(&pairwise, &channel);
989        let src = PublicKey([
990            0xA1, 0xB2, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
991            0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C,
992            0x1D, 0x1E, 0x1F, 0x20,
993        ]);
994        let dst = NodeHint([0xC3, 0xD4, 0x25]);
995        let mut buf = [0u8; 160];
996        let mut packet = PacketBuilder::new(&mut buf)
997            .blind_unicast(channel.channel_id, dst)
998            .source_full(&src)
999            .frame_counter(5)
1000            .mic_size(MicSize::Mic16)
1001            .payload(b"hello")
1002            .build()
1003            .unwrap();
1004
1005        engine
1006            .seal_blind_packet(&mut packet, &blind_keys, &channel)
1007            .unwrap();
1008        let header = PacketHeader::parse(packet.as_bytes()).unwrap();
1009        let mut wire = packet.as_bytes().to_vec();
1010        let (decoded_dst, decoded_src) = engine
1011            .decrypt_blind_addr(&mut wire, &header, &channel)
1012            .unwrap();
1013        let range = engine.open_packet(&mut wire, &header, &blind_keys).unwrap();
1014
1015        assert_eq!(decoded_dst, dst);
1016        assert_eq!(
1017            decoded_src,
1018            SourceAddrRef::FullKeyAt {
1019                offset: header.body_range.start - 32
1020            }
1021        );
1022        assert_eq!(
1023            &wire[header.body_range.start - 32..header.body_range.start],
1024            &src.0
1025        );
1026        assert_eq!(&wire[range], b"hello");
1027    }
1028
1029    #[cfg(feature = "software-crypto")]
1030    #[test]
1031    fn software_identity_agreement_is_symmetric() {
1032        let alice = SoftwareIdentity::from_secret_bytes(&[
1033            0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E,
1034            0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C,
1035            0x2D, 0x2E, 0x2F, 0x30,
1036        ]);
1037        let bob = SoftwareIdentity::from_secret_bytes(&[
1038            0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E,
1039            0x3F, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, 0x4B, 0x4C,
1040            0x4D, 0x4E, 0x4F, 0x50,
1041        ]);
1042
1043        let ab = alice.shared_secret_with(bob.public_key()).unwrap();
1044        let ba = bob.shared_secret_with(alice.public_key()).unwrap();
1045        assert_eq!(ab.0, ba.0);
1046    }
1047
1048    fn hex_vec(input: &str) -> std::vec::Vec<u8> {
1049        assert_eq!(input.len() % 2, 0);
1050        let mut out = std::vec::Vec::with_capacity(input.len() / 2);
1051        let bytes = input.as_bytes();
1052        for index in (0..bytes.len()).step_by(2) {
1053            out.push((decode_hex(bytes[index]) << 4) | decode_hex(bytes[index + 1]));
1054        }
1055        out
1056    }
1057
1058    fn hex_16(input: &str) -> [u8; 16] {
1059        let bytes = hex_vec(input);
1060        let mut out = [0u8; 16];
1061        out.copy_from_slice(&bytes);
1062        out
1063    }
1064
1065    #[cfg(feature = "software-crypto")]
1066    fn hex_32(input: &str) -> [u8; 32] {
1067        let bytes = hex_vec(input);
1068        let mut out = [0u8; 32];
1069        out.copy_from_slice(&bytes);
1070        out
1071    }
1072
1073    fn decode_hex(byte: u8) -> u8 {
1074        match byte {
1075            b'0'..=b'9' => byte - b'0',
1076            b'a'..=b'f' => byte - b'a' + 10,
1077            b'A'..=b'F' => byte - b'A' + 10,
1078            _ => panic!("invalid hex"),
1079        }
1080    }
1081}