1use heapless::{LinearMap, Vec};
2use umsh_core::{ChannelId, ChannelKey, NodeHint, PublicKey, RouterHint};
3use umsh_crypto::{DerivedChannelKeys, PairwiseKeys};
4
5use crate::{CapacityError, cache::ReplayWindow};
6
/// Handle for a peer stored in a [`PeerRegistry`].
///
/// The wrapped value is the peer's index in the registry's backing vector:
/// `PeerRegistry::get` / `get_mut` use it directly as that index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PeerId(pub u8);
10
/// Route information cached for a peer (stored in `PeerInfo::route` via
/// `PeerRegistry::update_route`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CachedRoute {
    /// No routing data attached.
    /// NOTE(review): presumably "reachable without relaying" — confirm against the sender.
    Direct,
    /// Up to 15 router hints.
    /// NOTE(review): presumably an explicit source route in hop order — verify the ordering
    /// where this variant is built.
    Source(Vec<RouterHint, 15>),
    /// A hop count plus up to 8 two-byte `regions` entries.
    /// NOTE(review): whether `hops` is a limit or an observed distance is not visible here.
    Flood { hops: u8, regions: Vec<[u8; 2], 8> },
}
24
/// Everything the registry stores about a single peer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PeerInfo {
    /// The peer's public key. `PeerRegistry` looks peers up by this key and by its `hint()`.
    pub public_key: PublicKey,
    /// Set to `true` by `PeerRegistry::try_insert_or_update` and to `false` for peers first
    /// added by `try_insert_or_update_auto`. Only unpinned peers may be evicted.
    pub pinned: bool,
    /// Last route recorded with `PeerRegistry::update_route`; `None` until one is set.
    pub route: Option<CachedRoute>,
    /// Caller-supplied millisecond timestamp of the most recent sighting (`touch` / auto
    /// insert). Explicit inserts start at 0. When the registry is full, the unpinned peer
    /// with the smallest value is the one evicted.
    pub last_seen_ms: u64,
}
37
/// Result of `PeerRegistry::try_insert_or_update_auto`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AutoPeerUpdate {
    /// Id now assigned to the inserted or refreshed peer.
    pub peer_id: PeerId,
    /// Key of the unpinned peer that was replaced to make room, if any. Callers holding
    /// per-`PeerId` state for that peer should discard it, since the id was reused.
    pub evicted_key: Option<PublicKey>,
}
46
47#[derive(Clone, Debug)]
49pub struct PeerRegistry<const N: usize> {
50 peers: Vec<PeerInfo, N>,
51}
52
53impl<const N: usize> Default for PeerRegistry<N> {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl<const N: usize> PeerRegistry<N> {
60 pub fn new() -> Self {
62 Self { peers: Vec::new() }
63 }
64
65 pub fn lookup_by_hint(&self, hint: &NodeHint) -> impl Iterator<Item = (PeerId, &PeerInfo)> {
67 self.peers
68 .iter()
69 .enumerate()
70 .filter(move |(_, peer)| peer.public_key.hint() == *hint)
71 .map(|(index, peer)| (PeerId(index as u8), peer))
72 }
73
74 pub fn lookup_by_key(&self, key: &PublicKey) -> Option<(PeerId, &PeerInfo)> {
76 self.peers
77 .iter()
78 .enumerate()
79 .find(|(_, peer)| peer.public_key == *key)
80 .map(|(index, peer)| (PeerId(index as u8), peer))
81 }
82
83 pub fn get(&self, id: PeerId) -> Option<&PeerInfo> {
85 self.peers.get(id.0 as usize)
86 }
87
88 pub fn get_mut(&mut self, id: PeerId) -> Option<&mut PeerInfo> {
90 self.peers.get_mut(id.0 as usize)
91 }
92
93 pub fn try_insert_or_update(&mut self, key: PublicKey) -> Result<PeerId, CapacityError> {
95 if let Some((id, peer)) = self
96 .peers
97 .iter_mut()
98 .enumerate()
99 .find(|(_, peer)| peer.public_key == key)
100 {
101 peer.public_key = key;
102 peer.pinned = true;
103 return Ok(PeerId(id as u8));
104 }
105
106 self.peers
107 .push(PeerInfo {
108 public_key: key,
109 pinned: true,
110 route: None,
111 last_seen_ms: 0,
112 })
113 .map_err(|_| CapacityError)?;
114 Ok(PeerId((self.peers.len() - 1) as u8))
115 }
116
117 pub fn try_insert_or_update_auto(
122 &mut self,
123 key: PublicKey,
124 now_ms: u64,
125 ) -> Result<AutoPeerUpdate, CapacityError> {
126 if let Some((id, peer)) = self
127 .peers
128 .iter_mut()
129 .enumerate()
130 .find(|(_, peer)| peer.public_key == key)
131 {
132 peer.last_seen_ms = now_ms;
133 return Ok(AutoPeerUpdate {
134 peer_id: PeerId(id as u8),
135 evicted_key: None,
136 });
137 }
138
139 if self.peers.len() < N {
140 self.peers
141 .push(PeerInfo {
142 public_key: key,
143 pinned: false,
144 route: None,
145 last_seen_ms: now_ms,
146 })
147 .map_err(|_| CapacityError)?;
148 return Ok(AutoPeerUpdate {
149 peer_id: PeerId((self.peers.len() - 1) as u8),
150 evicted_key: None,
151 });
152 }
153
154 let Some((index, oldest)) = self
155 .peers
156 .iter()
157 .enumerate()
158 .filter(|(_, peer)| !peer.pinned)
159 .min_by_key(|(_, peer)| peer.last_seen_ms)
160 else {
161 return Err(CapacityError);
162 };
163
164 let evicted_key = oldest.public_key;
165 self.peers[index] = PeerInfo {
166 public_key: key,
167 pinned: false,
168 route: None,
169 last_seen_ms: now_ms,
170 };
171 Ok(AutoPeerUpdate {
172 peer_id: PeerId(index as u8),
173 evicted_key: Some(evicted_key),
174 })
175 }
176
177 pub fn update_route(&mut self, id: PeerId, route: CachedRoute) {
179 if let Some(peer) = self.get_mut(id) {
180 peer.route = Some(route);
181 }
182 }
183
184 pub fn touch(&mut self, id: PeerId, now_ms: u64) {
186 if let Some(peer) = self.get_mut(id) {
187 peer.last_seen_ms = now_ms;
188 }
189 }
190}
191
/// Cryptographic state kept per peer, stored in a [`PeerCryptoMap`] under the peer's [`PeerId`].
#[derive(Clone)]
pub struct PeerCryptoState {
    /// Keys shared with this peer (`umsh_crypto::PairwiseKeys`); derivation happens elsewhere.
    pub pairwise_keys: PairwiseKeys,
    /// Replay window for traffic from this peer.
    /// NOTE(review): exact semantics live in `cache::ReplayWindow` — confirm there.
    pub replay_window: ReplayWindow,
}
200
201#[derive(Clone)]
203pub struct PeerCryptoMap<const N: usize> {
204 entries: LinearMap<PeerId, PeerCryptoState, N>,
205}
206
207impl<const N: usize> Default for PeerCryptoMap<N> {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213impl<const N: usize> PeerCryptoMap<N> {
214 pub fn new() -> Self {
216 Self {
217 entries: LinearMap::new(),
218 }
219 }
220
221 pub fn get(&self, id: &PeerId) -> Option<&PeerCryptoState> {
223 self.entries.get(id)
224 }
225
226 pub fn get_mut(&mut self, id: &PeerId) -> Option<&mut PeerCryptoState> {
228 self.entries.get_mut(id)
229 }
230
231 pub fn insert(
233 &mut self,
234 id: PeerId,
235 state: PeerCryptoState,
236 ) -> Result<Option<PeerCryptoState>, CapacityError> {
237 self.entries.insert(id, state).map_err(|_| CapacityError)
238 }
239
240 pub fn remove(&mut self, id: &PeerId) -> Option<PeerCryptoState> {
242 self.entries.remove(id)
243 }
244}
245
/// Replay-tracking entry stored in `ChannelState::hint_replay`, keyed by a sender's `NodeHint`.
#[derive(Clone)]
pub struct HintReplayState {
    /// Replay window associated with this hint.
    pub window: ReplayWindow,
    /// Millisecond timestamp of the last update to this entry.
    /// NOTE(review): presumably used elsewhere to age out or evict entries — confirm.
    pub last_seen_ms: u64,
}
254
255#[derive(Clone)]
257pub struct ChannelState<const RN: usize = 8, const HN: usize = 8> {
258 pub channel_key: ChannelKey,
260 pub derived: DerivedChannelKeys,
262 pub replay: LinearMap<PeerId, ReplayWindow, RN>,
264 pub hint_replay: LinearMap<NodeHint, HintReplayState, HN>,
266}
267
268impl<const RN: usize, const HN: usize> ChannelState<RN, HN> {
269 pub fn new(channel_key: ChannelKey, derived: DerivedChannelKeys) -> Self {
271 Self {
272 channel_key,
273 derived,
274 replay: LinearMap::new(),
275 hint_replay: LinearMap::new(),
276 }
277 }
278}
279
280#[derive(Clone)]
282pub struct ChannelTable<const N: usize, const RN: usize = 8, const HN: usize = 8> {
283 channels: Vec<ChannelState<RN, HN>, N>,
284}
285
286impl<const N: usize, const RN: usize, const HN: usize> Default for ChannelTable<N, RN, HN> {
287 fn default() -> Self {
288 Self::new()
289 }
290}
291
292impl<const N: usize, const RN: usize, const HN: usize> ChannelTable<N, RN, HN> {
293 pub fn new() -> Self {
295 Self {
296 channels: Vec::new(),
297 }
298 }
299
300 pub fn len(&self) -> usize {
302 self.channels.len()
303 }
304
305 pub fn is_empty(&self) -> bool {
307 self.channels.is_empty()
308 }
309
310 pub fn lookup_by_id(&self, id: &ChannelId) -> impl Iterator<Item = &ChannelState<RN, HN>> {
312 self.channels
313 .iter()
314 .filter(move |channel| channel.derived.channel_id == *id)
315 }
316
317 pub fn get_mut_by_id(&mut self, id: &ChannelId) -> Option<&mut ChannelState<RN, HN>> {
319 self.channels
320 .iter_mut()
321 .find(|channel| channel.derived.channel_id == *id)
322 }
323
324 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut ChannelState<RN, HN>> {
326 self.channels.iter_mut()
327 }
328
329 pub fn try_add(
331 &mut self,
332 key: ChannelKey,
333 derived: DerivedChannelKeys,
334 ) -> Result<(), CapacityError> {
335 if let Some(channel) = self.get_mut_by_id(&derived.channel_id) {
336 channel.channel_key = key;
337 channel.derived = derived;
338 return Ok(());
339 }
340
341 self.channels
342 .push(ChannelState::new(key, derived))
343 .map_err(|_| CapacityError)
344 }
345}