darkfi/net/message_publisher.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use std::{any::Any, collections::HashMap, sync::Arc, time::Duration};
20
21use async_trait::async_trait;
22use futures::stream::{FuturesUnordered, StreamExt};
23use log::{debug, error};
24use rand::{rngs::OsRng, Rng};
25use smol::{io::AsyncReadExt, lock::Mutex};
26
27use super::message::Message;
28use crate::{
29    net::{metering::MeteringQueue, transport::PtStream},
30    system::{msleep, timeout::timeout},
31    Error, Result,
32};
33use darkfi_serial::{AsyncDecodable, VarInt};
34
35/// 64-bit identifier for message subscription.
36pub type MessageSubscriptionId = u64;
37type MessageResult<M> = Result<Arc<M>>;
38
39/// Dispatcher subscriptions HashMap type.
40type DispatcherSubscriptionsMap<M> =
41    Mutex<HashMap<MessageSubscriptionId, smol::channel::Sender<(MessageResult<M>, Option<u64>)>>>;
42
43/// A dispatcher that is unique to every [`Message`].
44///
45/// Maintains a list of subscriptions to a unique Message
46/// type and handles sending messages across these
47/// subscriptions.
48///
49/// Additionally, holds a `MeteringQueue` using the
50/// [`Message`] configuration to perform rate limiting
51/// of propagation towards the subscriptions.
52#[derive(Debug)]
53struct MessageDispatcher<M: Message> {
54    subs: DispatcherSubscriptionsMap<M>,
55    metering_queue: Mutex<MeteringQueue>,
56}
57
58impl<M: Message> MessageDispatcher<M> {
59    /// Create a new message dispatcher
60    fn new() -> Self {
61        Self {
62            subs: Mutex::new(HashMap::new()),
63            metering_queue: Mutex::new(MeteringQueue::new(M::METERING_CONFIGURATION)),
64        }
65    }
66
67    /// Create a random ID.
68    fn random_id() -> MessageSubscriptionId {
69        OsRng.gen()
70    }
71
72    /// Subscribe to a channel.
73    /// Assigns a new ID and adds it to the list of subscriptions.
74    pub async fn subscribe(self: Arc<Self>) -> MessageSubscription<M> {
75        let (sender, recv_queue) = smol::channel::unbounded();
76        // Guard against overwriting
77        let mut id = Self::random_id();
78        let mut subs = self.subs.lock().await;
79        loop {
80            if subs.contains_key(&id) {
81                id = Self::random_id();
82                continue
83            }
84
85            subs.insert(id, sender);
86            break
87        }
88
89        drop(subs);
90        MessageSubscription { id, recv_queue, parent: self }
91    }
92
93    /// Unsubscribe from a channel.
94    /// Removes the associated ID from the subscriber list.
95    async fn unsubscribe(&self, sub_id: MessageSubscriptionId) {
96        self.subs.lock().await.remove(&sub_id);
97    }
98
99    /// Private function to concurrently transmit a message to all subscriber channels.
100    /// Automatically clear all inactive channels. Strictly used internally.
101    async fn _trigger_all(&self, message: MessageResult<M>) {
102        let mut subs = self.subs.lock().await;
103
104        let msg_result_type = if message.is_ok() { "Ok" } else { "Err" };
105        debug!(
106            target: "net::message_publisher::_trigger_all()", "START msg={}({}), subs={}",
107            msg_result_type,
108            M::NAME, subs.len(),
109        );
110
111        // Insert metering information and grab potential sleep time
112        let mut queue = self.metering_queue.lock().await;
113        queue.push(&M::METERING_SCORE);
114        let sleep_time = queue.sleep_time();
115        drop(queue);
116
117        let mut futures = FuturesUnordered::new();
118        let mut garbage_ids = vec![];
119
120        // Prep the futures for concurrent execution
121        for (sub_id, sub) in &*subs {
122            let sub_id = *sub_id;
123            let sub = sub.clone();
124            let message = message.clone();
125            futures.push(async move {
126                match sub.send((message, sleep_time)).await {
127                    Ok(res) => Ok((sub_id, res)),
128                    Err(err) => Err((sub_id, err)),
129                }
130            });
131        }
132
133        // Start polling
134        while let Some(r) = futures.next().await {
135            if let Err((sub_id, _err)) = r {
136                garbage_ids.push(sub_id);
137            }
138        }
139
140        // Garbage cleanup
141        for sub_id in garbage_ids {
142            subs.remove(&sub_id);
143        }
144
145        debug!(
146            target: "net::message_publisher::_trigger_all()", "END msg={}({}), subs={}",
147            msg_result_type,
148            M::NAME, subs.len(),
149        );
150    }
151}
152
153/// Handles message subscriptions through a subscription ID and
154/// a receiver channel.
155#[derive(Debug)]
156pub struct MessageSubscription<M: Message> {
157    id: MessageSubscriptionId,
158    recv_queue: smol::channel::Receiver<(MessageResult<M>, Option<u64>)>,
159    parent: Arc<MessageDispatcher<M>>,
160}
161
162impl<M: Message> MessageSubscription<M> {
163    /// Start receiving messages.
164    /// Sender also provides with a sleep time,
165    /// in case rate limit has started.
166    pub async fn receive(&self) -> MessageResult<M> {
167        let (message, sleep_time) = match self.recv_queue.recv().await {
168            Ok(pair) => pair,
169            Err(e) => panic!("MessageSubscription::receive(): recv_queue failed! {}", e),
170        };
171
172        // Check if we need to sleep
173        if message.is_ok() {
174            if let Some(sleep_time) = sleep_time {
175                msleep(sleep_time).await;
176            }
177        }
178
179        message
180    }
181
182    /// Start receiving messages with timeout.
183    pub async fn receive_with_timeout(&self, seconds: u64) -> MessageResult<M> {
184        let dur = Duration::from_secs(seconds);
185        let Ok(res) = timeout(dur, self.recv_queue.recv()).await else {
186            return Err(Error::ConnectTimeout)
187        };
188
189        let (message, sleep_time) = match res {
190            Ok(pair) => pair,
191            Err(e) => {
192                panic!("MessageSubscription::receive_with_timeout(): recv_queue failed! {}", e)
193            }
194        };
195
196        // Check if we need to sleep
197        if message.is_ok() {
198            if let Some(sleep_time) = sleep_time {
199                msleep(sleep_time).await;
200            }
201        }
202
203        message
204    }
205
206    /// Cleans existing items from the receiver channel.
207    pub async fn clean(&self) -> Result<()> {
208        loop {
209            match self.recv_queue.try_recv() {
210                Ok(_) => continue,
211                Err(smol::channel::TryRecvError::Empty) => return Ok(()),
212                Err(e) => panic!("MessageSubscription::receive(): recv_queue failed! {}", e),
213            }
214        }
215    }
216
217    /// Unsubscribe from a message subscription. Must be called manually.
218    pub async fn unsubscribe(&self) {
219        self.parent.unsubscribe(self.id).await
220    }
221}
222
223/// Generic interface for the message dispatcher.
224#[async_trait]
225trait MessageDispatcherInterface: Send + Sync {
226    async fn trigger(
227        &self,
228        stream: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
229    ) -> Result<()>;
230
231    async fn trigger_error(&self, err: Error);
232
233    async fn metering_score(&self) -> u64;
234
235    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
236}
237
238/// Local implementation of the Message Dispatcher Interface
239#[async_trait]
240impl<M: Message> MessageDispatcherInterface for MessageDispatcher<M> {
241    /// Internal function to deserialize data into a message type
242    /// and dispatch it across subscriber channels. Reads directly
243    /// from an inbound stream.
244    ///
245    /// We extract the message length from the stream and use `take()`
246    /// to allocate an appropiately sized buffer as a basic DDOS protection.
247    async fn trigger(
248        &self,
249        stream: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
250    ) -> Result<()> {
251        // Parse message length
252        let length = match VarInt::decode_async(stream).await {
253            Ok(int) => int.0,
254            Err(err) => {
255                error!(
256                    target: "net::message_publisher::trigger()",
257                    "Unable to decode VarInt. Dropping...: {}",
258                    err,
259                );
260                return Err(Error::MessageInvalid)
261            }
262        };
263
264        // Check the message length does not exceed set limit
265        if M::MAX_BYTES > 0 && length > M::MAX_BYTES {
266            error!(
267                target: "net::message_publisher::trigger()",
268                "Message length ({}) exceeds configured limit ({}). Dropping...",
269                length, M::MAX_BYTES,
270            );
271            return Err(Error::MessageInvalid)
272        }
273
274        // Deserialize stream into type
275        let mut take = stream.take(length);
276        let message = match M::decode_async(&mut take).await {
277            Ok(payload) => Ok(Arc::new(payload)),
278            Err(err) => {
279                error!(
280                    target: "net::message_publisher::trigger()",
281                    "Unable to decode data. Dropping...: {}",
282                    err,
283                );
284                return Err(Error::MessageInvalid)
285            }
286        };
287
288        // Send down the pipes
289        self._trigger_all(message).await;
290        Ok(())
291    }
292
293    /// Internal function that sends an error message to all subscriber channels.
294    async fn trigger_error(&self, err: Error) {
295        self._trigger_all(Err(err)).await;
296    }
297
298    /// Internal function to retrieve metering queue current total score,
299    /// after prunning expired metering information.
300    async fn metering_score(&self) -> u64 {
301        let mut lock = self.metering_queue.lock().await;
302        lock.clean();
303        lock.total()
304    }
305
306    /// Converts to `Any` trait. Enables the dynamic modification of static types.
307    fn as_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
308        self
309    }
310}
311
312/// Generic publish/subscribe class that maintains a list of dispatchers.
313///
314/// Dispatchers transmit messages to subscribers and are specific to one
315/// message type.
316///
317/// Additionally, holds a global metering limit, which is the sum of each
318/// dispatcher `MeteringQueue` threshold, to drop the connection if passed.
319#[derive(Default)]
320pub struct MessageSubsystem {
321    dispatchers: Mutex<HashMap<&'static str, Arc<dyn MessageDispatcherInterface>>>,
322    metering_limit: Mutex<u64>,
323}
324
325impl MessageSubsystem {
326    /// Create a new message subsystem.
327    pub fn new() -> Self {
328        Self { dispatchers: Mutex::new(HashMap::new()), metering_limit: Mutex::new(0) }
329    }
330
331    /// Add a new dispatcher for specified [`Message`].
332    pub async fn add_dispatch<M: Message>(&self) {
333        // First lock the dispatchers
334        let mut lock = self.dispatchers.lock().await;
335
336        // Update the metering limit
337        *self.metering_limit.lock().await += M::METERING_CONFIGURATION.threshold;
338
339        // Insert the new dispatcher
340        lock.insert(M::NAME, Arc::new(MessageDispatcher::<M>::new()));
341    }
342
343    /// Subscribes to a [`Message`]. Using the Message name, the method
344    /// returns the associated `MessageDispatcher` from the list of
345    /// dispatchers and calls `subscribe()`.
346    pub async fn subscribe<M: Message>(&self) -> Result<MessageSubscription<M>> {
347        let dispatcher = self.dispatchers.lock().await.get(M::NAME).cloned();
348
349        let sub = match dispatcher {
350            Some(dispatcher) => {
351                let dispatcher: Arc<MessageDispatcher<M>> = dispatcher
352                    .as_any()
353                    .downcast::<MessageDispatcher<M>>()
354                    .expect("Multiple messages registered with different names");
355
356                dispatcher.subscribe().await
357            }
358
359            None => {
360                // Normal return failure here
361                return Err(Error::NetworkOperationFailed)
362            }
363        };
364
365        Ok(sub)
366    }
367
368    /// Transmits a payload to a dispatcher.
369    /// Returns an error if the payload fails to transmit.
370    pub async fn notify(
371        &self,
372        command: &str,
373        reader: &mut smol::io::ReadHalf<Box<dyn PtStream + 'static>>,
374    ) -> Result<()> {
375        // Iterate over dispatchers and keep track of their current
376        // metering score
377        let mut found = false;
378        let mut total_score = 0;
379        for (name, dispatcher) in self.dispatchers.lock().await.iter() {
380            // If dispatcher is the command one, trasmit the message
381            if name == &command {
382                dispatcher.trigger(reader).await?;
383                found = true;
384            }
385
386            // Grab its total score
387            total_score += dispatcher.metering_score().await;
388        }
389
390        // Check if dispatcher was found
391        if !found {
392            return Err(Error::MissingDispatcher)
393        }
394
395        // Check if we are over the global metering limit
396        if total_score > *self.metering_limit.lock().await {
397            return Err(Error::MeteringLimitExceeded)
398        }
399
400        Ok(())
401    }
402
403    /// Concurrently transmits an error message across dispatchers.
404    pub async fn trigger_error(&self, err: Error) {
405        let mut futures = FuturesUnordered::new();
406
407        let dispatchers = self.dispatchers.lock().await;
408
409        for dispatcher in dispatchers.values() {
410            let dispatcher = dispatcher.clone();
411            let error = err.clone();
412            futures.push(async move { dispatcher.trigger_error(error).await });
413        }
414
415        drop(dispatchers);
416
417        while let Some(_r) = futures.next().await {}
418    }
419}