darkfi/net/
message_publisher.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use std::{any::Any, collections::HashMap, sync::Arc, time::Duration};
20
21use async_trait::async_trait;
22use futures::stream::{FuturesUnordered, StreamExt};
23use log::{debug, error};
24use rand::{rngs::OsRng, Rng};
25use smol::{io::AsyncReadExt, lock::Mutex};
26
27use super::message::Message;
28use crate::{
29    net::{metering::MeteringQueue, transport::PtStream},
30    system::{msleep, timeout::timeout},
31    Error, Result,
32};
33use darkfi_serial::{AsyncDecodable, VarInt};
34
35/// 64-bit identifier for message subscription.
36pub type MessageSubscriptionId = u64;
37type MessageResult<M> = Result<Arc<M>>;
38
39/// Dispatcher subscriptions HashMap type.
40type DispatcherSubscriptionsMap<M> =
41    Mutex<HashMap<MessageSubscriptionId, smol::channel::Sender<(MessageResult<M>, Option<u64>)>>>;
42
43/// A dispatcher that is unique to every [`Message`].
44///
45/// Maintains a list of subscriptions to a unique Message
46/// type and handles sending messages across these
47/// subscriptions.
48///
49/// Additionally, holds a `MeteringQueue` using the
50/// [`Message`] configuration to perform rate limiting
51/// of propagation towards the subscriptions.
52#[derive(Debug)]
53struct MessageDispatcher<M: Message> {
54    subs: DispatcherSubscriptionsMap<M>,
55    metering_queue: Mutex<MeteringQueue>,
56}
57
58impl<M: Message> MessageDispatcher<M> {
59    /// Create a new message dispatcher
60    fn new() -> Self {
61        Self {
62            subs: Mutex::new(HashMap::new()),
63            metering_queue: Mutex::new(MeteringQueue::new(M::METERING_CONFIGURATION)),
64        }
65    }
66
67    /// Create a random ID.
68    fn random_id() -> MessageSubscriptionId {
69        OsRng.gen()
70    }
71
72    /// Subscribe to a channel.
73    /// Assigns a new ID and adds it to the list of subscriptions.
74    pub async fn subscribe(self: Arc<Self>) -> MessageSubscription<M> {
75        let (sender, recv_queue) = smol::channel::unbounded();
76        // Guard against overwriting
77        let mut id = Self::random_id();
78        let mut subs = self.subs.lock().await;
79        loop {
80            if subs.contains_key(&id) {
81                id = Self::random_id();
82                continue
83            }
84
85            subs.insert(id, sender);
86            break
87        }
88
89        drop(subs);
90        MessageSubscription { id, recv_queue, parent: self }
91    }
92
93    /// Unsubscribe from a channel.
94    /// Removes the associated ID from the subscriber list.
95    async fn unsubscribe(&self, sub_id: MessageSubscriptionId) {
96        self.subs.lock().await.remove(&sub_id);
97    }
98
99    /// Private function to concurrently transmit a message to all subscriber channels.
100    /// Automatically clear all inactive channels. Strictly used internally.
101    async fn _trigger_all(&self, message: MessageResult<M>) {
102        let mut subs = self.subs.lock().await;
103
104        let msg_result_type = if message.is_ok() { "Ok" } else { "Err" };
105        debug!(
106            target: "net::message_publisher::_trigger_all()", "START msg={msg_result_type}({}), subs={}",
107            M::NAME, subs.len()
108        );
109
110        // Insert metering information and grab potential sleep time
111        let mut queue = self.metering_queue.lock().await;
112        queue.push(&M::METERING_SCORE);
113        let sleep_time = queue.sleep_time();
114        drop(queue);
115
116        let mut futures = FuturesUnordered::new();
117        let mut garbage_ids = vec![];
118
119        // Prep the futures for concurrent execution
120        for (sub_id, sub) in &*subs {
121            let sub_id = *sub_id;
122            let sub = sub.clone();
123            let message = message.clone();
124            futures.push(async move {
125                match sub.send((message, sleep_time)).await {
126                    Ok(res) => Ok((sub_id, res)),
127                    Err(err) => Err((sub_id, err)),
128                }
129            });
130        }
131
132        // Start polling
133        while let Some(r) = futures.next().await {
134            if let Err((sub_id, _err)) = r {
135                garbage_ids.push(sub_id);
136            }
137        }
138
139        // Garbage cleanup
140        for sub_id in garbage_ids {
141            subs.remove(&sub_id);
142        }
143
144        debug!(
145            target: "net::message_publisher::_trigger_all()", "END msg={msg_result_type}({}), subs={}",
146            M::NAME, subs.len(),
147        );
148    }
149}
150
151/// Handles message subscriptions through a subscription ID and
152/// a receiver channel.
153#[derive(Debug)]
154pub struct MessageSubscription<M: Message> {
155    id: MessageSubscriptionId,
156    recv_queue: smol::channel::Receiver<(MessageResult<M>, Option<u64>)>,
157    parent: Arc<MessageDispatcher<M>>,
158}
159
160impl<M: Message> MessageSubscription<M> {
161    /// Start receiving messages.
162    /// Sender also provides with a sleep time,
163    /// in case rate limit has started.
164    pub async fn receive(&self) -> MessageResult<M> {
165        let (message, sleep_time) = match self.recv_queue.recv().await {
166            Ok(pair) => pair,
167            Err(e) => panic!("MessageSubscription::receive(): recv_queue failed! {e}"),
168        };
169
170        // Check if we need to sleep
171        if message.is_ok() {
172            if let Some(sleep_time) = sleep_time {
173                msleep(sleep_time).await;
174            }
175        }
176
177        message
178    }
179
180    /// Start receiving messages with timeout.
181    pub async fn receive_with_timeout(&self, seconds: u64) -> MessageResult<M> {
182        let dur = Duration::from_secs(seconds);
183        let Ok(res) = timeout(dur, self.recv_queue.recv()).await else {
184            return Err(Error::ConnectTimeout)
185        };
186
187        let (message, sleep_time) = match res {
188            Ok(pair) => pair,
189            Err(e) => {
190                panic!("MessageSubscription::receive_with_timeout(): recv_queue failed! {e}")
191            }
192        };
193
194        // Check if we need to sleep
195        if message.is_ok() {
196            if let Some(sleep_time) = sleep_time {
197                msleep(sleep_time).await;
198            }
199        }
200
201        message
202    }
203
204    /// Cleans existing items from the receiver channel.
205    pub async fn clean(&self) -> Result<()> {
206        loop {
207            match self.recv_queue.try_recv() {
208                Ok(_) => continue,
209                Err(smol::channel::TryRecvError::Empty) => return Ok(()),
210                Err(e) => panic!("MessageSubscription::receive(): recv_queue failed! {e}"),
211            }
212        }
213    }
214
215    /// Unsubscribe from a message subscription. Must be called manually.
216    pub async fn unsubscribe(&self) {
217        self.parent.unsubscribe(self.id).await
218    }
219}
220
221/// Generic interface for the message dispatcher.
222#[async_trait]
223trait MessageDispatcherInterface: Send + Sync {
224    async fn trigger(
225        &self,
226        stream: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
227    ) -> Result<()>;
228
229    async fn trigger_error(&self, err: Error);
230
231    async fn metering_score(&self) -> u64;
232
233    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
234}
235
236/// Local implementation of the Message Dispatcher Interface
237#[async_trait]
238impl<M: Message> MessageDispatcherInterface for MessageDispatcher<M> {
239    /// Internal function to deserialize data into a message type
240    /// and dispatch it across subscriber channels. Reads directly
241    /// from an inbound stream.
242    ///
243    /// We extract the message length from the stream and use `take()`
244    /// to allocate an appropriately sized buffer as a basic DDOS protection.
245    async fn trigger(
246        &self,
247        stream: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
248    ) -> Result<()> {
249        // Parse message length
250        let length = match VarInt::decode_async(stream).await {
251            Ok(int) => int.0,
252            Err(err) => {
253                error!(
254                    target: "net::message_publisher::trigger()",
255                    "Unable to decode VarInt. Dropping...: {err}"
256                );
257                return Err(Error::MessageInvalid)
258            }
259        };
260
261        // Check the message length does not exceed set limit
262        if M::MAX_BYTES > 0 && length > M::MAX_BYTES {
263            error!(
264                target: "net::message_publisher::trigger()",
265                "Message length ({length}) exceeds configured limit ({}). Dropping...",
266                M::MAX_BYTES
267            );
268            return Err(Error::MessageInvalid)
269        }
270
271        // Deserialize stream into type
272        let mut take = stream.take(length);
273        let message = match M::decode_async(&mut take).await {
274            Ok(payload) => Ok(Arc::new(payload)),
275            Err(err) => {
276                error!(
277                    target: "net::message_publisher::trigger()",
278                    "Unable to decode data. Dropping...: {err}"
279                );
280                return Err(Error::MessageInvalid)
281            }
282        };
283
284        // Send down the pipes
285        self._trigger_all(message).await;
286        Ok(())
287    }
288
289    /// Internal function that sends an error message to all subscriber channels.
290    async fn trigger_error(&self, err: Error) {
291        self._trigger_all(Err(err)).await;
292    }
293
294    /// Internal function to retrieve metering queue current total score,
295    /// after prunning expired metering information.
296    async fn metering_score(&self) -> u64 {
297        let mut lock = self.metering_queue.lock().await;
298        lock.clean();
299        lock.total()
300    }
301
302    /// Converts to `Any` trait. Enables the dynamic modification of static types.
303    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
304        self
305    }
306}
307
308/// Generic publish/subscribe class that maintains a list of dispatchers.
309///
310/// Dispatchers transmit messages to subscribers and are specific to one
311/// message type.
312///
313/// Additionally, holds a global metering limit, which is the sum of each
314/// dispatcher `MeteringQueue` threshold, to drop the connection if passed.
315#[derive(Default)]
316pub struct MessageSubsystem {
317    dispatchers: Mutex<HashMap<&'static str, Arc<dyn MessageDispatcherInterface>>>,
318    metering_limit: Mutex<u64>,
319}
320
321impl MessageSubsystem {
322    /// Create a new message subsystem.
323    pub fn new() -> Self {
324        Self { dispatchers: Mutex::new(HashMap::new()), metering_limit: Mutex::new(0) }
325    }
326
327    /// Add a new dispatcher for specified [`Message`].
328    pub async fn add_dispatch<M: Message>(&self) {
329        // First lock the dispatchers
330        let mut lock = self.dispatchers.lock().await;
331
332        // Update the metering limit
333        *self.metering_limit.lock().await += M::METERING_CONFIGURATION.threshold;
334
335        // Insert the new dispatcher
336        lock.insert(M::NAME, Arc::new(MessageDispatcher::<M>::new()));
337    }
338
339    /// Subscribes to a [`Message`]. Using the Message name, the method
340    /// returns the associated `MessageDispatcher` from the list of
341    /// dispatchers and calls `subscribe()`.
342    pub async fn subscribe<M: Message>(&self) -> Result<MessageSubscription<M>> {
343        let dispatcher = self.dispatchers.lock().await.get(M::NAME).cloned();
344
345        let sub = match dispatcher {
346            Some(dispatcher) => {
347                let dispatcher: Arc<MessageDispatcher<M>> = dispatcher
348                    .as_any()
349                    .downcast::<MessageDispatcher<M>>()
350                    .expect("Multiple messages registered with different names");
351
352                dispatcher.subscribe().await
353            }
354
355            None => {
356                // Normal return failure here
357                return Err(Error::NetworkOperationFailed)
358            }
359        };
360
361        Ok(sub)
362    }
363
364    /// Transmits a payload to a dispatcher.
365    /// Returns an error if the payload fails to transmit.
366    pub async fn notify(
367        &self,
368        command: &str,
369        reader: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
370    ) -> Result<()> {
371        // Iterate over dispatchers and keep track of their current
372        // metering score
373        let mut found = false;
374        let mut total_score = 0;
375        for (name, dispatcher) in self.dispatchers.lock().await.iter() {
376            // If dispatcher is the command one, trasmit the message
377            if name == &command {
378                dispatcher.trigger(reader).await?;
379                found = true;
380            }
381
382            // Grab its total score
383            total_score += dispatcher.metering_score().await;
384        }
385
386        // Check if dispatcher was found
387        if !found {
388            return Err(Error::MissingDispatcher)
389        }
390
391        // Check if we are over the global metering limit
392        if total_score > *self.metering_limit.lock().await {
393            return Err(Error::MeteringLimitExceeded)
394        }
395
396        Ok(())
397    }
398
399    /// Concurrently transmits an error message across dispatchers.
400    pub async fn trigger_error(&self, err: Error) {
401        let mut futures = FuturesUnordered::new();
402
403        let dispatchers = self.dispatchers.lock().await;
404
405        for dispatcher in dispatchers.values() {
406            let dispatcher = dispatcher.clone();
407            let error = err.clone();
408            futures.push(async move { dispatcher.trigger_error(error).await });
409        }
410
411        drop(dispatchers);
412
413        while let Some(_r) = futures.next().await {}
414    }
415}