1use std::{io, sync::Arc};
20
21use futures_rustls::{
22 rustls::{
23 self,
24 client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
25 pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime},
26 server::danger::{ClientCertVerified, ClientCertVerifier},
27 version::TLS13,
28 ClientConfig, DigitallySignedStruct, DistinguishedName, ServerConfig, SignatureScheme,
29 },
30 TlsAcceptor, TlsConnector, TlsStream,
31};
32use log::error;
33use rustls_pemfile::pkcs8_private_keys;
34use x509_parser::{
35 parse_x509_certificate,
36 prelude::{GeneralName, ParsedExtension, X509Certificate},
37};
38
39fn validate_dnsname(cert: &X509Certificate) -> std::result::Result<(), rustls::Error> {
41 #[rustfmt::skip]
42 let oid = x509_parser::oid_registry::asn1_rs::oid!(2.5.29.17);
43 let Ok(Some(extension)) = cert.get_extension_unique(&oid) else {
44 return Err(rustls::CertificateError::BadEncoding.into())
45 };
46
47 let dns_name = match extension.parsed_extension() {
48 ParsedExtension::SubjectAlternativeName(altname) => {
49 if altname.general_names.len() != 1 {
50 return Err(rustls::CertificateError::BadEncoding.into())
51 }
52
53 match altname.general_names[0] {
54 GeneralName::DNSName(dns_name) => dns_name,
55 _ => return Err(rustls::CertificateError::BadEncoding.into()),
56 }
57 }
58
59 _ => return Err(rustls::CertificateError::BadEncoding.into()),
60 };
61
62 if dns_name != "dark.fi" {
63 return Err(rustls::CertificateError::BadEncoding.into())
64 }
65
66 Ok(())
67}
68
69#[derive(Debug)]
70struct ServerCertificateVerifier;
71impl ServerCertVerifier for ServerCertificateVerifier {
72 fn verify_server_cert(
73 &self,
74 end_entity: &CertificateDer,
75 _intermediates: &[CertificateDer],
76 _server_name: &ServerName,
77 _ocsp_response: &[u8],
78 _now: UnixTime,
79 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
80 let mut buf = Vec::with_capacity(end_entity.len());
82 for byte in end_entity.iter() {
83 buf.push(*byte);
84 }
85
86 let Ok((_, cert)) = parse_x509_certificate(&buf) else {
88 error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
89 return Err(rustls::CertificateError::BadEncoding.into())
90 };
91
92 validate_dnsname(&cert)?;
94
95 Ok(ServerCertVerified::assertion())
96 }
97
98 fn verify_tls12_signature(
99 &self,
100 _message: &[u8],
101 _cert: &CertificateDer,
102 _dss: &DigitallySignedStruct,
103 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
104 unreachable!()
105 }
106
107 fn verify_tls13_signature(
108 &self,
109 message: &[u8],
110 cert: &CertificateDer,
111 dss: &DigitallySignedStruct,
112 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
113 if dss.scheme != SignatureScheme::ED25519 {
115 return Err(rustls::CertificateError::BadSignature.into())
116 }
117
118 let mut buf = Vec::with_capacity(cert.len());
120 for byte in cert.iter() {
121 buf.push(*byte);
122 }
123
124 let Ok((_, cert)) = parse_x509_certificate(&buf) else {
126 error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server TLS certificate");
127 return Err(rustls::CertificateError::BadEncoding.into())
128 };
129
130 let Ok(public_key) = ed25519_compact::PublicKey::from_der(cert.public_key().raw) else {
131 error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server public key");
132 return Err(rustls::CertificateError::BadEncoding.into())
133 };
134
135 let Ok(signature) = ed25519_compact::Signature::from_slice(dss.signature()) else {
137 error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature");
138 return Err(rustls::CertificateError::BadSignature.into())
139 };
140
141 if let Err(e) = public_key.verify(message, &signature) {
142 error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature: {e}");
143 return Err(rustls::CertificateError::BadSignature.into())
144 }
145
146 Ok(HandshakeSignatureValid::assertion())
147 }
148
149 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
150 vec![SignatureScheme::ED25519]
151 }
152}
153
154#[derive(Debug)]
155struct ClientCertificateVerifier;
156impl ClientCertVerifier for ClientCertificateVerifier {
157 fn offer_client_auth(&self) -> bool {
158 true
159 }
160
161 fn client_auth_mandatory(&self) -> bool {
162 true
163 }
164
165 fn root_hint_subjects(&self) -> &[DistinguishedName] {
166 &[]
167 }
168
169 fn verify_client_cert(
170 &self,
171 end_entity: &CertificateDer,
172 _intermediates: &[CertificateDer],
173 _now: UnixTime,
174 ) -> std::result::Result<ClientCertVerified, rustls::Error> {
175 let mut cert = Vec::with_capacity(end_entity.len());
177 for byte in end_entity.iter() {
178 cert.push(*byte);
179 }
180
181 let Ok((_, cert)) = parse_x509_certificate(&cert) else {
183 error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
184 return Err(rustls::CertificateError::BadEncoding.into())
185 };
186
187 validate_dnsname(&cert)?;
189
190 Ok(ClientCertVerified::assertion())
191 }
192
193 fn verify_tls12_signature(
194 &self,
195 _message: &[u8],
196 _cert: &CertificateDer,
197 _dss: &DigitallySignedStruct,
198 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
199 unreachable!()
200 }
201
202 fn verify_tls13_signature(
203 &self,
204 message: &[u8],
205 cert: &CertificateDer,
206 dss: &DigitallySignedStruct,
207 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
208 if dss.scheme != SignatureScheme::ED25519 {
210 return Err(rustls::CertificateError::BadSignature.into())
211 }
212
213 let mut buf = Vec::with_capacity(cert.len());
215 for byte in cert.iter() {
216 buf.push(*byte);
217 }
218
219 let Ok((_, cert)) = parse_x509_certificate(&buf) else {
221 error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server TLS certificate");
222 return Err(rustls::CertificateError::BadEncoding.into())
223 };
224
225 let Ok(public_key) = ed25519_compact::PublicKey::from_der(cert.public_key().raw) else {
226 error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server public key");
227 return Err(rustls::CertificateError::BadEncoding.into())
228 };
229
230 let Ok(signature) = ed25519_compact::Signature::from_slice(dss.signature()) else {
232 error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature");
233 return Err(rustls::CertificateError::BadSignature.into())
234 };
235
236 if let Err(e) = public_key.verify(message, &signature) {
237 error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature: {e}");
238 return Err(rustls::CertificateError::BadSignature.into())
239 }
240
241 Ok(HandshakeSignatureValid::assertion())
242 }
243
244 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
245 vec![SignatureScheme::ED25519]
246 }
247}
248
249pub struct TlsUpgrade {
250 server_config: Arc<ServerConfig>,
252 client_config: Arc<ClientConfig>,
254}
255
256impl TlsUpgrade {
257 pub async fn new() -> Self {
258 let keypair_pem = ed25519_compact::KeyPair::generate().to_pem();
260 let secret_key = pkcs8_private_keys(&mut keypair_pem.as_bytes()).next().unwrap().unwrap();
261 let secret_key = PrivateKeyDer::Pkcs8(secret_key);
262
263 let mut cert_params = rcgen::CertificateParams::new(&[]);
264 cert_params.alg = &rcgen::PKCS_ED25519;
265 cert_params.key_pair = Some(rcgen::KeyPair::from_pem(&keypair_pem).unwrap());
266 cert_params.subject_alt_names = vec![rcgen::SanType::DnsName("dark.fi".to_string())];
267 cert_params.extended_key_usages = vec![
268 rcgen::ExtendedKeyUsagePurpose::ClientAuth,
269 rcgen::ExtendedKeyUsagePurpose::ServerAuth,
270 ];
271
272 let certificate = rcgen::Certificate::from_params(cert_params).unwrap();
273 let certificate = certificate.serialize_der().unwrap();
274
275 let client_cert_verifier = Arc::new(ClientCertificateVerifier {});
277 let server_config = Arc::new(
278 ServerConfig::builder_with_protocol_versions(&[&TLS13])
279 .with_client_cert_verifier(client_cert_verifier)
280 .with_single_cert(vec![certificate.clone().into()], secret_key.clone_key())
281 .unwrap(),
282 );
283
284 let server_cert_verifier = Arc::new(ServerCertificateVerifier {});
286 let client_config = Arc::new(
287 ClientConfig::builder_with_protocol_versions(&[&TLS13])
288 .dangerous()
289 .with_custom_certificate_verifier(server_cert_verifier)
290 .with_client_auth_cert(vec![certificate.into()], secret_key)
291 .unwrap(),
292 );
293
294 Self { server_config, client_config }
295 }
296
297 pub async fn upgrade_dialer_tls<IO>(self, stream: IO) -> io::Result<TlsStream<IO>>
298 where
299 IO: super::PtStream,
300 {
301 let server_name = ServerName::try_from("dark.fi").unwrap();
302 let connector = TlsConnector::from(self.client_config);
303 let stream = connector.connect(server_name, stream).await?;
304 Ok(TlsStream::Client(stream))
305 }
306
307 pub async fn upgrade_listener_tcp_tls(
310 self,
311 listener: smol::net::TcpListener,
312 ) -> io::Result<(TlsAcceptor, smol::net::TcpListener)> {
313 Ok((TlsAcceptor::from(self.server_config), listener))
314 }
315}