darkfi/net/transport/
tls.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2025 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use std::{io, sync::Arc};
20
21use futures_rustls::{
22    rustls::{
23        self,
24        client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
25        pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime},
26        server::danger::{ClientCertVerified, ClientCertVerifier},
27        version::TLS13,
28        ClientConfig, DigitallySignedStruct, DistinguishedName, ServerConfig, SignatureScheme,
29    },
30    TlsAcceptor, TlsConnector, TlsStream,
31};
32use log::error;
33use rustls_pemfile::pkcs8_private_keys;
34use x509_parser::{
35    parse_x509_certificate,
36    prelude::{GeneralName, ParsedExtension, X509Certificate},
37};
38
39/// Validate certificate DNSName.
40fn validate_dnsname(cert: &X509Certificate) -> std::result::Result<(), rustls::Error> {
41    #[rustfmt::skip]
42        let oid = x509_parser::oid_registry::asn1_rs::oid!(2.5.29.17);
43    let Ok(Some(extension)) = cert.get_extension_unique(&oid) else {
44        return Err(rustls::CertificateError::BadEncoding.into())
45    };
46
47    let dns_name = match extension.parsed_extension() {
48        ParsedExtension::SubjectAlternativeName(altname) => {
49            if altname.general_names.len() != 1 {
50                return Err(rustls::CertificateError::BadEncoding.into())
51            }
52
53            match altname.general_names[0] {
54                GeneralName::DNSName(dns_name) => dns_name,
55                _ => return Err(rustls::CertificateError::BadEncoding.into()),
56            }
57        }
58
59        _ => return Err(rustls::CertificateError::BadEncoding.into()),
60    };
61
62    if dns_name != "dark.fi" {
63        return Err(rustls::CertificateError::BadEncoding.into())
64    }
65
66    Ok(())
67}
68
69#[derive(Debug)]
70struct ServerCertificateVerifier;
71impl ServerCertVerifier for ServerCertificateVerifier {
72    fn verify_server_cert(
73        &self,
74        end_entity: &CertificateDer,
75        _intermediates: &[CertificateDer],
76        _server_name: &ServerName,
77        _ocsp_response: &[u8],
78        _now: UnixTime,
79    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
80        // Read the DER-encoded certificate into a buffer
81        let mut buf = Vec::with_capacity(end_entity.len());
82        for byte in end_entity.iter() {
83            buf.push(*byte);
84        }
85
86        // Parse the certificate
87        let Ok((_, cert)) = parse_x509_certificate(&buf) else {
88            error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
89            return Err(rustls::CertificateError::BadEncoding.into())
90        };
91
92        // Validate DNSName
93        validate_dnsname(&cert)?;
94
95        Ok(ServerCertVerified::assertion())
96    }
97
98    fn verify_tls12_signature(
99        &self,
100        _message: &[u8],
101        _cert: &CertificateDer,
102        _dss: &DigitallySignedStruct,
103    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
104        unreachable!()
105    }
106
107    fn verify_tls13_signature(
108        &self,
109        message: &[u8],
110        cert: &CertificateDer,
111        dss: &DigitallySignedStruct,
112    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
113        // Verify we're using the correct signature scheme
114        if dss.scheme != SignatureScheme::ED25519 {
115            return Err(rustls::CertificateError::BadSignature.into())
116        }
117
118        // Read the DER-encoded certificate into a buffer
119        let mut buf = Vec::with_capacity(cert.len());
120        for byte in cert.iter() {
121            buf.push(*byte);
122        }
123
124        // Parse the certificate and extract the public key
125        let Ok((_, cert)) = parse_x509_certificate(&buf) else {
126            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server TLS certificate");
127            return Err(rustls::CertificateError::BadEncoding.into())
128        };
129
130        let Ok(public_key) = ed25519_compact::PublicKey::from_der(cert.public_key().raw) else {
131            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server public key");
132            return Err(rustls::CertificateError::BadEncoding.into())
133        };
134
135        // Verify the signature
136        let Ok(signature) = ed25519_compact::Signature::from_slice(dss.signature()) else {
137            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature");
138            return Err(rustls::CertificateError::BadSignature.into())
139        };
140
141        if let Err(e) = public_key.verify(message, &signature) {
142            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature: {e}");
143            return Err(rustls::CertificateError::BadSignature.into())
144        }
145
146        Ok(HandshakeSignatureValid::assertion())
147    }
148
149    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
150        vec![SignatureScheme::ED25519]
151    }
152}
153
154#[derive(Debug)]
155struct ClientCertificateVerifier;
156impl ClientCertVerifier for ClientCertificateVerifier {
157    fn offer_client_auth(&self) -> bool {
158        true
159    }
160
161    fn client_auth_mandatory(&self) -> bool {
162        true
163    }
164
165    fn root_hint_subjects(&self) -> &[DistinguishedName] {
166        &[]
167    }
168
169    fn verify_client_cert(
170        &self,
171        end_entity: &CertificateDer,
172        _intermediates: &[CertificateDer],
173        _now: UnixTime,
174    ) -> std::result::Result<ClientCertVerified, rustls::Error> {
175        // Read the DER-encoded certificate into a buffer
176        let mut cert = Vec::with_capacity(end_entity.len());
177        for byte in end_entity.iter() {
178            cert.push(*byte);
179        }
180
181        // Parse the certificate
182        let Ok((_, cert)) = parse_x509_certificate(&cert) else {
183            error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
184            return Err(rustls::CertificateError::BadEncoding.into())
185        };
186
187        // Validate DNSName
188        validate_dnsname(&cert)?;
189
190        Ok(ClientCertVerified::assertion())
191    }
192
193    fn verify_tls12_signature(
194        &self,
195        _message: &[u8],
196        _cert: &CertificateDer,
197        _dss: &DigitallySignedStruct,
198    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
199        unreachable!()
200    }
201
202    fn verify_tls13_signature(
203        &self,
204        message: &[u8],
205        cert: &CertificateDer,
206        dss: &DigitallySignedStruct,
207    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
208        // Verify we're using the correct signature scheme
209        if dss.scheme != SignatureScheme::ED25519 {
210            return Err(rustls::CertificateError::BadSignature.into())
211        }
212
213        // Read the DER-encoded certificate into a buffer
214        let mut buf = Vec::with_capacity(cert.len());
215        for byte in cert.iter() {
216            buf.push(*byte);
217        }
218
219        // Parse the certificate and extract the public key
220        let Ok((_, cert)) = parse_x509_certificate(&buf) else {
221            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server TLS certificate");
222            return Err(rustls::CertificateError::BadEncoding.into())
223        };
224
225        let Ok(public_key) = ed25519_compact::PublicKey::from_der(cert.public_key().raw) else {
226            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server public key");
227            return Err(rustls::CertificateError::BadEncoding.into())
228        };
229
230        // Verify the signature
231        let Ok(signature) = ed25519_compact::Signature::from_slice(dss.signature()) else {
232            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature");
233            return Err(rustls::CertificateError::BadSignature.into())
234        };
235
236        if let Err(e) = public_key.verify(message, &signature) {
237            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature: {e}");
238            return Err(rustls::CertificateError::BadSignature.into())
239        }
240
241        Ok(HandshakeSignatureValid::assertion())
242    }
243
244    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
245        vec![SignatureScheme::ED25519]
246    }
247}
248
249pub struct TlsUpgrade {
250    /// TLS server configuration
251    server_config: Arc<ServerConfig>,
252    /// TLS client configuration
253    client_config: Arc<ClientConfig>,
254}
255
256impl TlsUpgrade {
257    pub async fn new() -> Self {
258        // On each instantiation, generate a new keypair and certificate
259        let keypair_pem = ed25519_compact::KeyPair::generate().to_pem();
260        let secret_key = pkcs8_private_keys(&mut keypair_pem.as_bytes()).next().unwrap().unwrap();
261        let secret_key = PrivateKeyDer::Pkcs8(secret_key);
262
263        let mut cert_params = rcgen::CertificateParams::new(&[]);
264        cert_params.alg = &rcgen::PKCS_ED25519;
265        cert_params.key_pair = Some(rcgen::KeyPair::from_pem(&keypair_pem).unwrap());
266        cert_params.subject_alt_names = vec![rcgen::SanType::DnsName("dark.fi".to_string())];
267        cert_params.extended_key_usages = vec![
268            rcgen::ExtendedKeyUsagePurpose::ClientAuth,
269            rcgen::ExtendedKeyUsagePurpose::ServerAuth,
270        ];
271
272        let certificate = rcgen::Certificate::from_params(cert_params).unwrap();
273        let certificate = certificate.serialize_der().unwrap();
274
275        // Server-side config
276        let client_cert_verifier = Arc::new(ClientCertificateVerifier {});
277        let server_config = Arc::new(
278            ServerConfig::builder_with_protocol_versions(&[&TLS13])
279                .with_client_cert_verifier(client_cert_verifier)
280                .with_single_cert(vec![certificate.clone().into()], secret_key.clone_key())
281                .unwrap(),
282        );
283
284        // Client-side config
285        let server_cert_verifier = Arc::new(ServerCertificateVerifier {});
286        let client_config = Arc::new(
287            ClientConfig::builder_with_protocol_versions(&[&TLS13])
288                .dangerous()
289                .with_custom_certificate_verifier(server_cert_verifier)
290                .with_client_auth_cert(vec![certificate.into()], secret_key)
291                .unwrap(),
292        );
293
294        Self { server_config, client_config }
295    }
296
297    pub async fn upgrade_dialer_tls<IO>(self, stream: IO) -> io::Result<TlsStream<IO>>
298    where
299        IO: super::PtStream,
300    {
301        let server_name = ServerName::try_from("dark.fi").unwrap();
302        let connector = TlsConnector::from(self.client_config);
303        let stream = connector.connect(server_name, stream).await?;
304        Ok(TlsStream::Client(stream))
305    }
306
307    // TODO: Try to find a transparent way for this instead of implementing
308    // the function separately for every transport type.
309    pub async fn upgrade_listener_tcp_tls(
310        self,
311        listener: smol::net::TcpListener,
312    ) -> io::Result<(TlsAcceptor, smol::net::TcpListener)> {
313        Ok((TlsAcceptor::from(self.server_config), listener))
314    }
315}