diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 0000000..f357307 --- /dev/null +++ b/.clippy.toml @@ -0,0 +1 @@ +doc-valid-idents = ["PostgreSQL"] diff --git a/Cargo.toml b/Cargo.toml index 23b6ae7..8b83023 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,23 +1,31 @@ [package] name = "tokio-postgres-rustls" description = "Rustls integration for tokio-postgres" -version = "0.4.1" -authors = ["Jasper "] +version = "0.13.0" +authors = ["Jasper Hugo "] repository = "https://github.com/jbg/tokio-postgres-rustls" edition = "2018" license = "MIT" readme = "README.md" [dependencies] -bytes = "0.5.4" -futures = "0.3.4" -ring = "0.16.11" -rustls = "0.17.0" -tokio = "0.2.16" -tokio-postgres = "0.5.3" -tokio-rustls = "0.13.0" -webpki = "0.21.2" +const-oid = { version = "0.9.6", default-features = false, features = ["db"] } +ring = { version = "0.17", default-features = false } +rustls = { version = "0.23", default-features = false } +tokio = { version = "1", default-features = false } +tokio-postgres = { version = "0.7", default-features = false } +tokio-rustls = { version = "0.26", default-features = false } +x509-cert = { version = "0.2.5", default-features = false, features = ["std"] } [dev-dependencies] -env_logger = { version = "0.7.1", default-features = false } -tokio = { version = "0.2.16", features = ["macros"] } +env_logger = { version = "0.11", default-features = false } +tokio = { version = "1", default-features = false, features = ["macros", "rt"] } +tokio-postgres = { version = "0.7", default-features = false, features = [ + "runtime", +] } +rustls = { version = "0.23", default-features = false, features = [ + "std", + "logging", + "tls12", + "ring", +] } diff --git a/README.md b/README.md index 1202d64..e10efb5 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,9 @@ and the [tokio-postgres asynchronous PostgreSQL client library](https://github.c # Example ``` -let config = rustls::ClientConfig::new(); +let config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config); let connect_fut = tokio_postgres::connect("sslmode=require host=localhost user=postgres", tls); // ... @@ -17,4 +19,3 @@ let connect_fut = tokio_postgres::connect("sslmode=require host=localhost user=p # License tokio-postgres-rustls is distributed under the MIT license. - diff --git a/src/lib.rs b/src/lib.rs index 218e7d6..eccd1ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,166 +1,272 @@ -use std::{ - future::Future, - io, - mem::MaybeUninit, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use bytes::{Buf, BufMut}; -use futures::future::{FutureExt, TryFutureExt}; -use ring::digest; -use rustls::{ClientConfig, Session}; +#![doc = include_str!("../README.md")] +#![forbid(rust_2018_idioms)] +#![deny(missing_docs, unsafe_code)] +#![warn(clippy::all, clippy::pedantic)] + +use std::{convert::TryFrom, sync::Arc}; + +use rustls::{pki_types::ServerName, ClientConfig}; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect}; -use tokio_rustls::{client::TlsStream, TlsConnector}; -use webpki::{DNSName, DNSNameRef}; +use tokio_postgres::tls::MakeTlsConnect; -#[derive(Clone)] -pub struct MakeRustlsConnect { - config: Arc, -} +mod private { + use std::{ + future::Future, + io, + pin::Pin, + task::{Context, Poll}, + }; -impl MakeRustlsConnect { - pub fn new(config: ClientConfig) -> Self { - Self { - config: Arc::new(config), + use const_oid::db::{ + rfc5912::{ + ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512, + SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, + SHA_512_WITH_RSA_ENCRYPTION, + }, + rfc8410::ID_ED_25519, + }; + use ring::digest; + use rustls::pki_types::ServerName; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_postgres::tls::{ChannelBinding, TlsConnect}; + use tokio_rustls::{client::TlsStream, TlsConnector}; + use x509_cert::{der::Decode, TbsCertificate}; + + pub struct TlsConnectFuture { + pub inner: tokio_rustls::Connect, + } + + impl Future for TlsConnectFuture + where + S: AsyncRead + AsyncWrite + Unpin, + { + type Output = io::Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: If `self` is pinned, so is `inner`. + #[allow(unsafe_code)] + let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) }; + fut.poll(cx).map_ok(RustlsStream) } } -} -impl MakeTlsConnect for MakeRustlsConnect -where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Stream = RustlsStream; - type TlsConnect = RustlsConnect; - type Error = io::Error; + pub struct RustlsConnect(pub RustlsConnectData); - fn make_tls_connect(&mut self, hostname: &str) -> io::Result { - DNSNameRef::try_from_ascii_str(hostname) - .map(|dns_name| RustlsConnect { - hostname: dns_name.to_owned(), - connector: Arc::clone(&self.config).into(), - }) - .map_err(|_| io::ErrorKind::InvalidInput.into()) + pub struct RustlsConnectData { + pub hostname: ServerName<'static>, + pub connector: TlsConnector, } -} -pub struct RustlsConnect { - hostname: DNSName, - connector: TlsConnector, -} + impl TlsConnect for RustlsConnect + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Stream = RustlsStream; + type Error = io::Error; + type Future = TlsConnectFuture; -impl TlsConnect for RustlsConnect -where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Stream = RustlsStream; - type Error = io::Error; - type Future = Pin>> + Send>>; - - fn connect(self, stream: S) -> Self::Future { - self.connector - .connect(self.hostname.as_ref(), stream) - .map_ok(|s| RustlsStream(Box::pin(s))) - .boxed() + fn connect(self, stream: S) -> Self::Future { + TlsConnectFuture { + inner: self.0.connector.connect(self.0.hostname, stream), + } + } } -} -pub struct RustlsStream(Pin>>); + pub struct RustlsStream(TlsStream); -impl tokio_postgres::tls::TlsStream for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn channel_binding(&self) -> ChannelBinding { - let (_, session) = self.0.get_ref(); - match session.get_peer_certificates() { - Some(certs) if certs.len() > 0 => { - let sha256 = digest::digest(&digest::SHA256, certs[0].as_ref()); - ChannelBinding::tls_server_end_point(sha256.as_ref().into()) + impl RustlsStream { + pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream> { + // SAFETY: When `Self` is pinned, so is the inner `TlsStream`. + #[allow(unsafe_code)] + unsafe { + self.map_unchecked_mut(|this| &mut this.0) } - _ => ChannelBinding::none(), } } -} -impl AsyncRead for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { - self.0.as_mut().poll_read(cx, buf) + impl tokio_postgres::tls::TlsStream for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn channel_binding(&self) -> ChannelBinding { + let (_, session) = self.0.get_ref(); + match session.peer_certificates() { + Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0]) + .ok() + .and_then(|cert| { + let digest = match cert.signature.oid { + // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1 + ID_SHA_1 + | ID_SHA_256 + | SHA_1_WITH_RSA_ENCRYPTION + | SHA_256_WITH_RSA_ENCRYPTION + | ECDSA_WITH_SHA_256 => &digest::SHA256, + ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => { + &digest::SHA384 + } + ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => { + &digest::SHA512 + } + _ => return None, + }; + + Some(digest) + }) + .map_or_else(ChannelBinding::none, |algorithm| { + let hash = digest::digest(algorithm, certs[0].as_ref()); + ChannelBinding::tls_server_end_point(hash.as_ref().into()) + }), + _ => ChannelBinding::none(), + } + } } - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.0.prepare_uninitialized_buffer(buf) + impl AsyncRead for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project_stream().poll_read(cx, buf) + } } - fn poll_read_buf( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B, - ) -> Poll> + impl AsyncWrite for RustlsStream where - Self: Sized, + S: AsyncRead + AsyncWrite + Unpin, { - self.0.as_mut().poll_read_buf(cx, buf) + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project_stream().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project_stream().poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project_stream().poll_shutdown(cx) + } } } -impl AsyncWrite for RustlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - self.0.as_mut().poll_write(cx, buf) - } +/// A `MakeTlsConnect` implementation using `rustls`. +/// +/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. +#[derive(Clone)] +pub struct MakeRustlsConnect { + config: Arc, +} - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.0.as_mut().poll_flush(cx) +impl MakeRustlsConnect { + /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. + #[must_use] + pub fn new(config: ClientConfig) -> Self { + Self { + config: Arc::new(config), + } } +} - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.0.as_mut().poll_shutdown(cx) - } +impl MakeTlsConnect for MakeRustlsConnect +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Stream = private::RustlsStream; + type TlsConnect = private::RustlsConnect; + type Error = rustls::pki_types::InvalidDnsNameError; - fn poll_write_buf( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B, - ) -> Poll> - where - Self: Sized, - { - self.0.as_mut().poll_write_buf(cx, buf) + fn make_tls_connect(&mut self, hostname: &str) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: Arc::clone(&self.config).into(), + }) + }) } } #[cfg(test)] mod tests { - use futures::future::TryFutureExt; + use super::*; + use rustls::pki_types::{CertificateDer, UnixTime}; + use rustls::{ + client::danger::ServerCertVerifier, + client::danger::{HandshakeSignatureValid, ServerCertVerified}, + Error, SignatureScheme, + }; + + #[derive(Debug)] + struct AcceptAllVerifier {} + impl ServerCertVerifier for AcceptAllVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::ED25519, + ] + } + } #[tokio::test] async fn it_works() { env_logger::builder().is_test(true).try_init().unwrap(); - let config = rustls::ClientConfig::new(); + let mut config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + config + .dangerous() + .set_certificate_verifier(Arc::new(AcceptAllVerifier {})); let tls = super::MakeRustlsConnect::new(config); - let (client, conn) = - tokio_postgres::connect("sslmode=require host=localhost user=postgres", tls) - .await - .expect("connect"); - tokio::spawn(conn.map_err(|e| panic!("{:?}", e))); + let (client, conn) = tokio_postgres::connect( + "sslmode=require host=localhost port=5432 user=postgres", + tls, + ) + .await + .expect("connect"); + tokio::spawn(async move { conn.await.map_err(|e| panic!("{:?}", e)) }); let stmt = client.prepare("SELECT 1").await.expect("prepare"); let _ = client.query(&stmt, &[]).await.expect("query"); } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy