1#![allow(async_fn_in_trait)]
2#![cfg_attr(not(feature = "std"), no_std)]
3
4extern crate alloc;
14
15use alloc::vec::Vec;
16mod error;
17mod text;
18
19use umsh_core::{PacketType, PayloadType};
20use umsh_mac::SendOptions;
21use umsh_node::{LocalNode, Subscription, Transport};
22
23#[cfg(feature = "software-crypto")]
24use umsh_node::{BoundChannel, MacBackend};
25
26pub use error::{EncodeError, ParseError, TextSendError};
27pub use text::{
28 Fragment, MessageSequence, MessageType, OwnedTextMessage, Regarding, TextMessage,
29 encode as encode_text_message, parse as parse_text_message,
30};
31
32#[derive(Clone, Copy, Debug, PartialEq, Eq)]
34pub enum TextReceiveIssue {
35 WrongPayloadType(PayloadType),
36 Parse(ParseError),
37}
38
39pub fn parse_text_payload(
41 packet_type: PacketType,
42 payload: &[u8],
43) -> Result<TextMessage<'_>, ParseError> {
44 parse_text_message(expect_payload_type(
45 packet_type,
46 payload,
47 PayloadType::TextMessage,
48 )?)
49}
50
51pub mod text_message {
53 pub use crate::text::{encode, parse};
54}
55
56use umsh_node::PeerConnection;
57use umsh_node::SendProgressTicket;
58
59#[derive(Clone)]
61pub struct UnicastTextChatWrapper<T: Transport + Clone> {
62 peer: PeerConnection<T>,
63}
64
65impl<T: Transport + Clone> UnicastTextChatWrapper<T> {
66 pub fn new(peer: PeerConnection<T>) -> Self {
67 Self { peer }
68 }
69
70 pub fn from_peer(peer: &PeerConnection<T>) -> Self {
71 Self { peer: peer.clone() }
72 }
73
74 pub fn peer_connection(&self) -> &PeerConnection<T> {
75 &self.peer
76 }
77
78 pub fn peer(&self) -> &umsh_core::PublicKey {
79 self.peer.peer()
80 }
81
82 pub async fn send_message(
83 &self,
84 message: &TextMessage<'_>,
85 options: &SendOptions,
86 ) -> Result<SendProgressTicket, TextSendError<T::Error>> {
87 let payload = encode_text_payload(message)?;
88 self.peer
89 .send(&payload, options)
90 .await
91 .map_err(TextSendError::Transport)
92 }
93
94 pub async fn send_owned_message(
95 &self,
96 message: &OwnedTextMessage,
97 options: &SendOptions,
98 ) -> Result<SendProgressTicket, TextSendError<T::Error>> {
99 self.send_message(&message.as_borrowed(), options).await
100 }
101
102 pub async fn send_text(
103 &self,
104 body: &str,
105 options: &SendOptions,
106 ) -> Result<SendProgressTicket, TextSendError<T::Error>> {
107 let message = TextMessage {
108 message_type: MessageType::Basic,
109 sender_handle: None,
110 sequence: None,
111 sequence_reset: false,
112 regarding: None,
113 editing: None,
114 bg_color: None,
115 text_color: None,
116 body,
117 };
118 self.send_message(&message, options).await
119 }
120}
121
122impl<M: umsh_node::MacBackend> UnicastTextChatWrapper<LocalNode<M>> {
123 pub fn on_text<F>(&self, handler: F) -> Subscription
124 where
125 F: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextMessage<'_>) + 'static,
126 {
127 self.on_text_with_diagnostics(handler, |_, _| {})
128 }
129
130 pub fn on_text_with_diagnostics<F, D>(&self, mut handler: F, mut diagnostics: D) -> Subscription
131 where
132 F: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextMessage<'_>) + 'static,
133 D: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextReceiveIssue) + 'static,
134 {
135 self.peer.on_receive(move |packet| {
136 if packet.payload_type() != PayloadType::TextMessage {
137 diagnostics(
138 packet,
139 TextReceiveIssue::WrongPayloadType(packet.payload_type()),
140 );
141 return false;
142 }
143 let message = match parse_text_message(packet.payload()) {
144 Ok(message) => message,
145 Err(error) => {
146 diagnostics(packet, TextReceiveIssue::Parse(error));
147 return false;
148 }
149 };
150 handler(packet, message);
151 true
152 })
153 }
154}
155
/// Text-chat convenience layer over a [`BoundChannel`].
///
/// Only compiled when the `software-crypto` feature is enabled. Cloning the
/// wrapper clones the wrapped channel handle.
#[cfg(feature = "software-crypto")]
#[derive(Clone)]
pub struct MulticastTextChatWrapper<M>
where
    M: MacBackend,
{
    /// Channel used for every send and receive subscription made through this
    /// wrapper.
    channel: BoundChannel<M>,
}
161
#[cfg(feature = "software-crypto")]
impl<M> MulticastTextChatWrapper<M>
where
    M: MacBackend,
{
    /// Wraps an owned bound channel.
    pub fn new(channel: BoundChannel<M>) -> Self {
        Self { channel }
    }

    /// Wraps a clone of a borrowed bound channel.
    pub fn from_channel(channel: &BoundChannel<M>) -> Self {
        Self::new(channel.clone())
    }

    /// Borrows the wrapped bound channel.
    pub fn bound_channel(&self) -> &BoundChannel<M> {
        &self.channel
    }

    /// Encodes `message` as a text-message payload and sends it through the
    /// channel's `send_all`.
    ///
    /// # Errors
    ///
    /// Encoding failures are propagated with `?` (converted into
    /// [`TextSendError`]); failures reported by `send_all` are wrapped in
    /// [`TextSendError::Transport`].
    pub async fn send_message(
        &self,
        message: &TextMessage<'_>,
        options: &SendOptions,
    ) -> Result<SendProgressTicket, TextSendError<umsh_node::NodeError<M>>> {
        let encoded = encode_text_payload(message)?;
        let outcome = self.channel.send_all(&encoded, options).await;
        outcome.map_err(TextSendError::Transport)
    }

    /// Same as [`Self::send_message`], taking an owned message and sending its
    /// borrowed view.
    pub async fn send_owned_message(
        &self,
        message: &OwnedTextMessage,
        options: &SendOptions,
    ) -> Result<SendProgressTicket, TextSendError<umsh_node::NodeError<M>>> {
        let borrowed = message.as_borrowed();
        self.send_message(&borrowed, options).await
    }

    /// Sends `body` as a [`MessageType::Basic`] message with every optional
    /// field left unset and `sequence_reset` cleared.
    pub async fn send_text(
        &self,
        body: &str,
        options: &SendOptions,
    ) -> Result<SendProgressTicket, TextSendError<umsh_node::NodeError<M>>> {
        let basic = TextMessage {
            body,
            message_type: MessageType::Basic,
            sequence_reset: false,
            sender_handle: None,
            sequence: None,
            regarding: None,
            editing: None,
            bg_color: None,
            text_color: None,
        };
        self.send_message(&basic, options).await
    }

    /// Subscribes `handler` to text messages received on this channel,
    /// ignoring packets that are not valid text messages.
    ///
    /// Equivalent to [`Self::on_text_with_diagnostics`] with a diagnostics
    /// callback that does nothing.
    pub fn on_text<F>(&self, handler: F) -> Subscription
    where
        F: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextMessage<'_>) + 'static,
    {
        self.on_text_with_diagnostics(handler, |_, _| {})
    }

    /// Subscribes `handler` to text messages received on this channel and
    /// reports rejected packets to `diagnostics`.
    ///
    /// The callback is registered on the channel's node, so it first filters
    /// by channel: packets with no channel, or whose channel id differs from
    /// the id captured at subscription time, are skipped silently (no
    /// diagnostics). For packets on this channel it then:
    /// - calls `diagnostics` with [`TextReceiveIssue::WrongPayloadType`] when
    ///   the payload type is not [`PayloadType::TextMessage`];
    /// - calls `diagnostics` with [`TextReceiveIssue::Parse`] when the payload
    ///   does not parse;
    /// - otherwise calls `handler` with the parsed message.
    ///
    /// The callback returns `true` only when `handler` was invoked and `false`
    /// in every other case; how the node interprets that flag is not visible
    /// from this file.
    pub fn on_text_with_diagnostics<F, D>(&self, mut handler: F, mut diagnostics: D) -> Subscription
    where
        F: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextMessage<'_>) + 'static,
        D: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextReceiveIssue) + 'static,
    {
        // Captured by value so the callback does not borrow `self`.
        let wanted_id = *self.channel.channel().channel_id();
        let node = self.channel.node();
        node.on_receive(move |packet: &umsh_node::ReceivedPacketRef<'_>| {
            let Some(source) = packet.channel() else {
                return false;
            };
            if source.id() != wanted_id {
                return false;
            }

            let seen = packet.payload_type();
            if seen != PayloadType::TextMessage {
                diagnostics(packet, TextReceiveIssue::WrongPayloadType(seen));
                return false;
            }
            match parse_text_message(packet.payload()) {
                Ok(text) => {
                    handler(packet, text);
                    true
                }
                Err(parse_error) => {
                    diagnostics(packet, TextReceiveIssue::Parse(parse_error));
                    false
                }
            }
        })
    }
}
258
259fn encode_text_payload(message: &TextMessage<'_>) -> Result<Vec<u8>, EncodeError> {
260 let mut body = [0u8; 512];
261 let len = text_message::encode(message, &mut body)?;
262 let mut payload = Vec::with_capacity(len + 1);
263 payload.push(PayloadType::TextMessage as u8);
264 payload.extend_from_slice(&body[..len]);
265 Ok(payload)
266}
267
268fn split_payload_type(payload: &[u8]) -> Result<(PayloadType, &[u8]), ParseError> {
270 if payload.is_empty() {
271 return Ok((PayloadType::Empty, &[]));
272 }
273 if let Some(payload_type) = PayloadType::from_byte(payload[0]) {
274 Ok((payload_type, &payload[1..]))
275 } else {
276 Ok((PayloadType::Empty, payload))
277 }
278}
279
280fn expect_payload_type(
282 packet_type: PacketType,
283 payload: &[u8],
284 expected: PayloadType,
285) -> Result<&[u8], ParseError> {
286 let (payload_type, body) = split_payload_type(payload)?;
287 if !payload_type.allowed_for(packet_type) {
288 return Err(ParseError::PayloadTypeNotAllowed {
289 payload_type: payload_type as u8,
290 packet_type,
291 });
292 }
293 if payload_type != expected {
294 return Err(ParseError::InvalidPayloadType(payload_type as u8));
295 }
296 Ok(body)
297}