tokio_postgres_rustls/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![deny(missing_docs, unsafe_code)]
4#![warn(clippy::all, clippy::pedantic)]
5
6use std::{convert::TryFrom, sync::Arc};
7
8use rustls::{pki_types::ServerName, ClientConfig};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_postgres::tls::MakeTlsConnect;
11
12mod private {
13    use std::{
14        future::Future,
15        io,
16        pin::Pin,
17        task::{Context, Poll},
18    };
19
20    use const_oid::db::{
21        rfc5912::{
22            ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
23            SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
24            SHA_512_WITH_RSA_ENCRYPTION,
25        },
26        rfc8410::ID_ED_25519,
27    };
28    use ring::digest;
29    use rustls::pki_types::ServerName;
30    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
31    use tokio_postgres::tls::{ChannelBinding, TlsConnect};
32    use tokio_rustls::{client::TlsStream, TlsConnector};
33    use x509_cert::{der::Decode, TbsCertificate};
34
35    pub struct TlsConnectFuture<S> {
36        pub inner: tokio_rustls::Connect<S>,
37    }
38
39    impl<S> Future for TlsConnectFuture<S>
40    where
41        S: AsyncRead + AsyncWrite + Unpin,
42    {
43        type Output = io::Result<RustlsStream<S>>;
44
45        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
46            // SAFETY: If `self` is pinned, so is `inner`.
47            #[allow(unsafe_code)]
48            let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
49            fut.poll(cx).map_ok(RustlsStream)
50        }
51    }
52
53    pub struct RustlsConnect(pub RustlsConnectData);
54
55    pub struct RustlsConnectData {
56        pub hostname: ServerName<'static>,
57        pub connector: TlsConnector,
58    }
59
60    impl<S> TlsConnect<S> for RustlsConnect
61    where
62        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
63    {
64        type Stream = RustlsStream<S>;
65        type Error = io::Error;
66        type Future = TlsConnectFuture<S>;
67
68        fn connect(self, stream: S) -> Self::Future {
69            TlsConnectFuture {
70                inner: self.0.connector.connect(self.0.hostname, stream),
71            }
72        }
73    }
74
75    pub struct RustlsStream<S>(TlsStream<S>);
76
77    impl<S> RustlsStream<S> {
78        pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream<S>> {
79            // SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
80            #[allow(unsafe_code)]
81            unsafe {
82                self.map_unchecked_mut(|this| &mut this.0)
83            }
84        }
85    }
86
87    impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
88    where
89        S: AsyncRead + AsyncWrite + Unpin,
90    {
91        fn channel_binding(&self) -> ChannelBinding {
92            let (_, session) = self.0.get_ref();
93            match session.peer_certificates() {
94                Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0])
95                    .ok()
96                    .and_then(|cert| {
97                        let digest = match cert.signature.oid {
98                            // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
99                            ID_SHA_1
100                            | ID_SHA_256
101                            | SHA_1_WITH_RSA_ENCRYPTION
102                            | SHA_256_WITH_RSA_ENCRYPTION
103                            | ECDSA_WITH_SHA_256 => &digest::SHA256,
104                            ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
105                                &digest::SHA384
106                            }
107                            ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
108                                &digest::SHA512
109                            }
110                            _ => return None,
111                        };
112
113                        Some(digest)
114                    })
115                    .map_or_else(ChannelBinding::none, |algorithm| {
116                        let hash = digest::digest(algorithm, certs[0].as_ref());
117                        ChannelBinding::tls_server_end_point(hash.as_ref().into())
118                    }),
119                _ => ChannelBinding::none(),
120            }
121        }
122    }
123
124    impl<S> AsyncRead for RustlsStream<S>
125    where
126        S: AsyncRead + AsyncWrite + Unpin,
127    {
128        fn poll_read(
129            self: Pin<&mut Self>,
130            cx: &mut Context<'_>,
131            buf: &mut ReadBuf<'_>,
132        ) -> Poll<tokio::io::Result<()>> {
133            self.project_stream().poll_read(cx, buf)
134        }
135    }
136
137    impl<S> AsyncWrite for RustlsStream<S>
138    where
139        S: AsyncRead + AsyncWrite + Unpin,
140    {
141        fn poll_write(
142            self: Pin<&mut Self>,
143            cx: &mut Context<'_>,
144            buf: &[u8],
145        ) -> Poll<tokio::io::Result<usize>> {
146            self.project_stream().poll_write(cx, buf)
147        }
148
149        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
150            self.project_stream().poll_flush(cx)
151        }
152
153        fn poll_shutdown(
154            self: Pin<&mut Self>,
155            cx: &mut Context<'_>,
156        ) -> Poll<tokio::io::Result<()>> {
157            self.project_stream().poll_shutdown(cx)
158        }
159    }
160}
161
162/// A `MakeTlsConnect` implementation using `rustls`.
163///
164/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
165#[derive(Clone)]
166pub struct MakeRustlsConnect {
167    config: Arc<ClientConfig>,
168}
169
170impl MakeRustlsConnect {
171    /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
172    #[must_use]
173    pub fn new(config: ClientConfig) -> Self {
174        Self {
175            config: Arc::new(config),
176        }
177    }
178}
179
180impl<S> MakeTlsConnect<S> for MakeRustlsConnect
181where
182    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
183{
184    type Stream = private::RustlsStream<S>;
185    type TlsConnect = private::RustlsConnect;
186    type Error = rustls::pki_types::InvalidDnsNameError;
187
188    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
189        ServerName::try_from(hostname).map(|dns_name| {
190            private::RustlsConnect(private::RustlsConnectData {
191                hostname: dns_name.to_owned(),
192                connector: Arc::clone(&self.config).into(),
193            })
194        })
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::pki_types::{CertificateDer, UnixTime};
    use rustls::{
        client::danger::ServerCertVerifier,
        client::danger::{HandshakeSignatureValid, ServerCertVerified},
        Error, SignatureScheme,
    };

    /// Test-only verifier that accepts every server certificate and handshake
    /// signature, so the test can talk to a server with a self-signed certificate.
    #[derive(Debug)]
    struct AcceptAllVerifier;

    impl ServerCertVerifier for AcceptAllVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            vec![
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::ED25519,
            ]
        }
    }

    /// End-to-end check against a PostgreSQL server on `localhost:5432`.
    ///
    /// The test connects as `postgres` with `sslmode=require`, then runs `SELECT 1`.
    #[tokio::test]
    async fn it_works() {
        // `try_init` fails if a logger is already installed (e.g. by another test
        // in the same binary); that is harmless, so the error is ignored.
        let _ = env_logger::builder().is_test(true).try_init();

        let mut config = rustls::ClientConfig::builder()
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth();
        config
            .dangerous()
            .set_certificate_verifier(Arc::new(AcceptAllVerifier));
        let tls = super::MakeRustlsConnect::new(config);
        let (client, conn) = tokio_postgres::connect(
            "sslmode=require host=localhost port=5432 user=postgres",
            tls,
        )
        .await
        .expect("connect");
        tokio::spawn(async move { conn.await.map_err(|e| panic!("{e:?}")) });
        let stmt = client.prepare("SELECT 1").await.expect("prepare");
        let rows = client.query(&stmt, &[]).await.expect("query");
        assert_eq!(rows.len(), 1);
        let value: i32 = rows[0].get(0);
        assert_eq!(value, 1);
    }
}
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy