1use heapless::{LinearMap, Vec};
2use umsh_core::{ChannelId, ChannelKey, NodeHint, PublicKey, RouterHint};
3use umsh_crypto::{DerivedChannelKeys, PairwiseKeys};
4
5use crate::{CapacityError, cache::ReplayWindow};
6
/// Handle for a peer stored in a [`PeerRegistry`]: the wrapped value is the peer's index
/// in the registry's internal table.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct PeerId(pub u8);
10
11#[derive(Clone, Debug, PartialEq, Eq)]
13pub enum CachedRoute {
14 Source(Vec<RouterHint, 15>),
16 Flood { hops: u8 },
18}
19
20#[derive(Clone, Debug, PartialEq, Eq)]
22pub struct PeerInfo {
23 pub public_key: PublicKey,
25 pub pinned: bool,
27 pub route: Option<CachedRoute>,
29 pub last_seen_ms: u64,
31}
32
33#[derive(Clone, Copy, Debug, PartialEq, Eq)]
35pub struct AutoPeerUpdate {
36 pub peer_id: PeerId,
38 pub evicted_key: Option<PublicKey>,
40}
41
42#[derive(Clone, Debug)]
44pub struct PeerRegistry<const N: usize> {
45 peers: Vec<PeerInfo, N>,
46}
47
48impl<const N: usize> Default for PeerRegistry<N> {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl<const N: usize> PeerRegistry<N> {
55 pub fn new() -> Self {
57 Self { peers: Vec::new() }
58 }
59
60 pub fn lookup_by_hint(&self, hint: &NodeHint) -> impl Iterator<Item = (PeerId, &PeerInfo)> {
62 self.peers
63 .iter()
64 .enumerate()
65 .filter(move |(_, peer)| peer.public_key.hint() == *hint)
66 .map(|(index, peer)| (PeerId(index as u8), peer))
67 }
68
69 pub fn lookup_by_key(&self, key: &PublicKey) -> Option<(PeerId, &PeerInfo)> {
71 self.peers
72 .iter()
73 .enumerate()
74 .find(|(_, peer)| peer.public_key == *key)
75 .map(|(index, peer)| (PeerId(index as u8), peer))
76 }
77
78 pub fn get(&self, id: PeerId) -> Option<&PeerInfo> {
80 self.peers.get(id.0 as usize)
81 }
82
83 pub fn get_mut(&mut self, id: PeerId) -> Option<&mut PeerInfo> {
85 self.peers.get_mut(id.0 as usize)
86 }
87
88 pub fn try_insert_or_update(&mut self, key: PublicKey) -> Result<PeerId, CapacityError> {
90 if let Some((id, peer)) = self
91 .peers
92 .iter_mut()
93 .enumerate()
94 .find(|(_, peer)| peer.public_key == key)
95 {
96 peer.public_key = key;
97 peer.pinned = true;
98 return Ok(PeerId(id as u8));
99 }
100
101 self.peers
102 .push(PeerInfo {
103 public_key: key,
104 pinned: true,
105 route: None,
106 last_seen_ms: 0,
107 })
108 .map_err(|_| CapacityError)?;
109 Ok(PeerId((self.peers.len() - 1) as u8))
110 }
111
112 pub fn try_insert_or_update_auto(
117 &mut self,
118 key: PublicKey,
119 now_ms: u64,
120 ) -> Result<AutoPeerUpdate, CapacityError> {
121 if let Some((id, peer)) = self
122 .peers
123 .iter_mut()
124 .enumerate()
125 .find(|(_, peer)| peer.public_key == key)
126 {
127 peer.last_seen_ms = now_ms;
128 return Ok(AutoPeerUpdate {
129 peer_id: PeerId(id as u8),
130 evicted_key: None,
131 });
132 }
133
134 if self.peers.len() < N {
135 self.peers
136 .push(PeerInfo {
137 public_key: key,
138 pinned: false,
139 route: None,
140 last_seen_ms: now_ms,
141 })
142 .map_err(|_| CapacityError)?;
143 return Ok(AutoPeerUpdate {
144 peer_id: PeerId((self.peers.len() - 1) as u8),
145 evicted_key: None,
146 });
147 }
148
149 let Some((index, oldest)) = self
150 .peers
151 .iter()
152 .enumerate()
153 .filter(|(_, peer)| !peer.pinned)
154 .min_by_key(|(_, peer)| peer.last_seen_ms)
155 else {
156 return Err(CapacityError);
157 };
158
159 let evicted_key = oldest.public_key;
160 self.peers[index] = PeerInfo {
161 public_key: key,
162 pinned: false,
163 route: None,
164 last_seen_ms: now_ms,
165 };
166 Ok(AutoPeerUpdate {
167 peer_id: PeerId(index as u8),
168 evicted_key: Some(evicted_key),
169 })
170 }
171
172 pub fn update_route(&mut self, id: PeerId, route: CachedRoute) {
174 if let Some(peer) = self.get_mut(id) {
175 peer.route = Some(route);
176 }
177 }
178
179 pub fn touch(&mut self, id: PeerId, now_ms: u64) {
181 if let Some(peer) = self.get_mut(id) {
182 peer.last_seen_ms = now_ms;
183 }
184 }
185}
186
187#[derive(Clone)]
189pub struct PeerCryptoState {
190 pub pairwise_keys: PairwiseKeys,
192 pub replay_window: ReplayWindow,
194}
195
196#[derive(Clone)]
198pub struct PeerCryptoMap<const N: usize> {
199 entries: LinearMap<PeerId, PeerCryptoState, N>,
200}
201
202impl<const N: usize> Default for PeerCryptoMap<N> {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208impl<const N: usize> PeerCryptoMap<N> {
209 pub fn new() -> Self {
211 Self {
212 entries: LinearMap::new(),
213 }
214 }
215
216 pub fn get(&self, id: &PeerId) -> Option<&PeerCryptoState> {
218 self.entries.get(id)
219 }
220
221 pub fn get_mut(&mut self, id: &PeerId) -> Option<&mut PeerCryptoState> {
223 self.entries.get_mut(id)
224 }
225
226 pub fn insert(
228 &mut self,
229 id: PeerId,
230 state: PeerCryptoState,
231 ) -> Result<Option<PeerCryptoState>, CapacityError> {
232 self.entries.insert(id, state).map_err(|_| CapacityError)
233 }
234
235 pub fn remove(&mut self, id: &PeerId) -> Option<PeerCryptoState> {
237 self.entries.remove(id)
238 }
239}
240
241#[derive(Clone)]
243pub struct HintReplayState {
244 pub window: ReplayWindow,
246 pub last_seen_ms: u64,
248}
249
250#[derive(Clone)]
252pub struct ChannelState<const RN: usize = 8, const HN: usize = 8> {
253 pub channel_key: ChannelKey,
255 pub derived: DerivedChannelKeys,
257 pub replay: LinearMap<PeerId, ReplayWindow, RN>,
259 pub hint_replay: LinearMap<NodeHint, HintReplayState, HN>,
261}
262
263impl<const RN: usize, const HN: usize> ChannelState<RN, HN> {
264 pub fn new(channel_key: ChannelKey, derived: DerivedChannelKeys) -> Self {
266 Self {
267 channel_key,
268 derived,
269 replay: LinearMap::new(),
270 hint_replay: LinearMap::new(),
271 }
272 }
273}
274
275#[derive(Clone)]
277pub struct ChannelTable<const N: usize, const RN: usize = 8, const HN: usize = 8> {
278 channels: Vec<ChannelState<RN, HN>, N>,
279}
280
281impl<const N: usize, const RN: usize, const HN: usize> Default for ChannelTable<N, RN, HN> {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
287impl<const N: usize, const RN: usize, const HN: usize> ChannelTable<N, RN, HN> {
288 pub fn new() -> Self {
290 Self {
291 channels: Vec::new(),
292 }
293 }
294
295 pub fn len(&self) -> usize {
297 self.channels.len()
298 }
299
300 pub fn is_empty(&self) -> bool {
302 self.channels.is_empty()
303 }
304
305 pub fn lookup_by_id(&self, id: &ChannelId) -> impl Iterator<Item = &ChannelState<RN, HN>> {
307 self.channels
308 .iter()
309 .filter(move |channel| channel.derived.channel_id == *id)
310 }
311
312 pub fn get_mut_by_id(&mut self, id: &ChannelId) -> Option<&mut ChannelState<RN, HN>> {
314 self.channels
315 .iter_mut()
316 .find(|channel| channel.derived.channel_id == *id)
317 }
318
319 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut ChannelState<RN, HN>> {
321 self.channels.iter_mut()
322 }
323
324 pub fn try_add(
326 &mut self,
327 key: ChannelKey,
328 derived: DerivedChannelKeys,
329 ) -> Result<(), CapacityError> {
330 if let Some(channel) = self.get_mut_by_id(&derived.channel_id) {
331 channel.channel_key = key;
332 channel.derived = derived;
333 return Ok(());
334 }
335
336 self.channels
337 .push(ChannelState::new(key, derived))
338 .map_err(|_| CapacityError)
339 }
340}