umsh_text/
lib.rs

1#![allow(async_fn_in_trait)]
2#![cfg_attr(not(feature = "std"), no_std)]
3
4//! Text-message support for UMSH.
5//!
6//! This crate owns the text-message payload types/codecs together with the
7//! node-layer convenience wrappers for plain chat-style applications.
8//!
9//! Applications that want raw packet access can stay at the `umsh-node`
10//! layer, while callers that want text-specific ergonomics can build on the
11//! wrappers here.
12
13extern crate alloc;
14
15use alloc::vec::Vec;
16mod error;
17mod text;
18
19use umsh_core::{PacketType, PayloadType};
20use umsh_mac::SendOptions;
21use umsh_node::{LocalNode, Subscription, Transport};
22
23#[cfg(feature = "software-crypto")]
24use umsh_node::{BoundChannel, MacBackend};
25
26pub use error::{EncodeError, ParseError, TextSendError};
27pub use text::{
28    Fragment, MessageSequence, MessageType, OwnedTextMessage, Regarding, TextMessage,
29    encode as encode_text_message, parse as parse_text_message,
30};
31
32/// Reason a received packet did not become a text message callback.
33#[derive(Clone, Copy, Debug, PartialEq, Eq)]
34pub enum TextReceiveIssue {
35    WrongPayloadType(PayloadType),
36    Parse(ParseError),
37}
38
39/// Parse a typed text payload in the context of the enclosing packet type.
40pub fn parse_text_payload(
41    packet_type: PacketType,
42    payload: &[u8],
43) -> Result<TextMessage<'_>, ParseError> {
44    parse_text_message(expect_payload_type(
45        packet_type,
46        payload,
47        PayloadType::TextMessage,
48    )?)
49}
50
51/// Namespace for raw text-message codec functions.
52pub mod text_message {
53    pub use crate::text::{encode, parse};
54}
55
56use umsh_node::PeerConnection;
57use umsh_node::SendProgressTicket;
58
59/// Thin convenience wrapper for plain unicast text chat over a peer connection.
60#[derive(Clone)]
61pub struct UnicastTextChatWrapper<T: Transport + Clone> {
62    peer: PeerConnection<T>,
63}
64
65impl<T: Transport + Clone> UnicastTextChatWrapper<T> {
66    pub fn new(peer: PeerConnection<T>) -> Self {
67        Self { peer }
68    }
69
70    pub fn from_peer(peer: &PeerConnection<T>) -> Self {
71        Self { peer: peer.clone() }
72    }
73
74    pub fn peer_connection(&self) -> &PeerConnection<T> {
75        &self.peer
76    }
77
78    pub fn peer(&self) -> &umsh_core::PublicKey {
79        self.peer.peer()
80    }
81
82    pub async fn send_message(
83        &self,
84        message: &TextMessage<'_>,
85        options: &SendOptions,
86    ) -> Result<SendProgressTicket, TextSendError<T::Error>> {
87        let payload = encode_text_payload(message)?;
88        self.peer
89            .send(&payload, options)
90            .await
91            .map_err(TextSendError::Transport)
92    }
93
94    pub async fn send_owned_message(
95        &self,
96        message: &OwnedTextMessage,
97        options: &SendOptions,
98    ) -> Result<SendProgressTicket, TextSendError<T::Error>> {
99        self.send_message(&message.as_borrowed(), options).await
100    }
101
102    pub async fn send_text(
103        &self,
104        body: &str,
105        options: &SendOptions,
106    ) -> Result<SendProgressTicket, TextSendError<T::Error>> {
107        let message = TextMessage {
108            message_type: MessageType::Basic,
109            sender_handle: None,
110            sequence: None,
111            sequence_reset: false,
112            regarding: None,
113            editing: None,
114            bg_color: None,
115            text_color: None,
116            body,
117        };
118        self.send_message(&message, options).await
119    }
120}
121
122impl<M: umsh_node::MacBackend> UnicastTextChatWrapper<LocalNode<M>> {
123    pub fn on_text<F>(&self, handler: F) -> Subscription
124    where
125        F: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextMessage<'_>) + 'static,
126    {
127        self.on_text_with_diagnostics(handler, |_, _| {})
128    }
129
130    pub fn on_text_with_diagnostics<F, D>(&self, mut handler: F, mut diagnostics: D) -> Subscription
131    where
132        F: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextMessage<'_>) + 'static,
133        D: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextReceiveIssue) + 'static,
134    {
135        self.peer.on_receive(move |packet| {
136            if packet.payload_type() != PayloadType::TextMessage {
137                diagnostics(
138                    packet,
139                    TextReceiveIssue::WrongPayloadType(packet.payload_type()),
140                );
141                return false;
142            }
143            let message = match parse_text_message(packet.payload()) {
144                Ok(message) => message,
145                Err(error) => {
146                    diagnostics(packet, TextReceiveIssue::Parse(error));
147                    return false;
148                }
149            };
150            handler(packet, message);
151            true
152        })
153    }
154}
155
/// Thin convenience wrapper for plain multicast text chat over a bound channel.
///
/// Only available with the `software-crypto` feature.
#[cfg(feature = "software-crypto")]
#[derive(Clone)]
pub struct MulticastTextChatWrapper<M: MacBackend> {
    /// Channel used for sending and for filtering received packets.
    channel: BoundChannel<M>,
}
161
#[cfg(feature = "software-crypto")]
impl<M: MacBackend> MulticastTextChatWrapper<M> {
    /// Wraps an owned bound channel.
    pub fn new(channel: BoundChannel<M>) -> Self {
        Self { channel }
    }

    /// Wraps a clone of a borrowed bound channel.
    pub fn from_channel(channel: &BoundChannel<M>) -> Self {
        Self::new(channel.clone())
    }

    /// Returns the wrapped bound channel.
    pub fn bound_channel(&self) -> &BoundChannel<M> {
        &self.channel
    }

    /// Encodes `message` as a typed text payload and sends it with `BoundChannel::send_all`.
    ///
    /// # Errors
    ///
    /// Encoding failures are converted into `TextSendError` via `?`. Errors from `send_all`
    /// are wrapped in `TextSendError::Transport`.
    pub async fn send_message(
        &self,
        message: &TextMessage<'_>,
        options: &SendOptions,
    ) -> Result<SendProgressTicket, TextSendError<umsh_node::NodeError<M>>> {
        let payload = encode_text_payload(message)?;
        let outcome = self.channel.send_all(&payload, options).await;
        outcome.map_err(TextSendError::Transport)
    }

    /// Sends an owned message by borrowing it and delegating to `send_message`.
    pub async fn send_owned_message(
        &self,
        message: &OwnedTextMessage,
        options: &SendOptions,
    ) -> Result<SendProgressTicket, TextSendError<umsh_node::NodeError<M>>> {
        let borrowed = message.as_borrowed();
        self.send_message(&borrowed, options).await
    }

    /// Sends `body` as a `MessageType::Basic` message with every optional field unset.
    pub async fn send_text(
        &self,
        body: &str,
        options: &SendOptions,
    ) -> Result<SendProgressTicket, TextSendError<umsh_node::NodeError<M>>> {
        let plain = TextMessage {
            body,
            message_type: MessageType::Basic,
            sender_handle: None,
            sequence: None,
            sequence_reset: false,
            regarding: None,
            editing: None,
            bg_color: None,
            text_color: None,
        };
        self.send_message(&plain, options).await
    }

    /// Subscribes `handler` to text messages received on this channel.
    ///
    /// Non-text or unparsable packets on this channel are dropped silently.
    pub fn on_text<F>(&self, handler: F) -> Subscription
    where
        F: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextMessage<'_>) + 'static,
    {
        self.on_text_with_diagnostics(handler, |_, _| {})
    }

    /// Subscribes `handler` to text messages received on this channel.
    ///
    /// The subscription is registered on the channel's node, so every received packet is
    /// seen. Packets with no channel, or with a channel id other than this one, are ignored
    /// without calling `diagnostics`. For packets on this channel, `diagnostics` reports a
    /// non-text payload type or a parse failure. The node-level callback returns `true`
    /// only when `handler` was invoked.
    pub fn on_text_with_diagnostics<F, D>(
        &self,
        mut handler: F,
        mut diagnostics: D,
    ) -> Subscription
    where
        F: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextMessage<'_>) + 'static,
        D: FnMut(&umsh_node::ReceivedPacketRef<'_>, TextReceiveIssue) + 'static,
    {
        // Captured by value so the 'static callback does not borrow `self`.
        let wanted_channel = *self.channel.channel().channel_id();
        self.channel
            .node()
            .on_receive(move |packet: &umsh_node::ReceivedPacketRef<'_>| {
                match packet.channel() {
                    Some(channel) if channel.id() == wanted_channel => {}
                    _ => return false,
                }
                let payload_type = packet.payload_type();
                if payload_type != PayloadType::TextMessage {
                    diagnostics(packet, TextReceiveIssue::WrongPayloadType(payload_type));
                    return false;
                }
                match parse_text_message(packet.payload()) {
                    Ok(message) => {
                        handler(packet, message);
                        true
                    }
                    Err(error) => {
                        diagnostics(packet, TextReceiveIssue::Parse(error));
                        false
                    }
                }
            })
    }
}
258
259fn encode_text_payload(message: &TextMessage<'_>) -> Result<Vec<u8>, EncodeError> {
260    let mut body = [0u8; 512];
261    let len = text_message::encode(message, &mut body)?;
262    let mut payload = Vec::with_capacity(len + 1);
263    payload.push(PayloadType::TextMessage as u8);
264    payload.extend_from_slice(&body[..len]);
265    Ok(payload)
266}
267
268/// Split a typed application payload into its type byte and body.
269fn split_payload_type(payload: &[u8]) -> Result<(PayloadType, &[u8]), ParseError> {
270    if payload.is_empty() {
271        return Ok((PayloadType::Empty, &[]));
272    }
273    if let Some(payload_type) = PayloadType::from_byte(payload[0]) {
274        Ok((payload_type, &payload[1..]))
275    } else {
276        Ok((PayloadType::Empty, payload))
277    }
278}
279
280/// Validate that a packet carries the expected typed text payload.
281fn expect_payload_type(
282    packet_type: PacketType,
283    payload: &[u8],
284    expected: PayloadType,
285) -> Result<&[u8], ParseError> {
286    let (payload_type, body) = split_payload_type(payload)?;
287    if !payload_type.allowed_for(packet_type) {
288        return Err(ParseError::PayloadTypeNotAllowed {
289            payload_type: payload_type as u8,
290            packet_type,
291        });
292    }
293    if payload_type != expected {
294        return Err(ParseError::InvalidPayloadType(payload_type as u8));
295    }
296    Ok(body)
297}