1
- #![ feature( type_alias_impl_trait) ]
2
-
3
1
use std:: {
4
2
io,
5
3
future:: Future ,
@@ -10,8 +8,9 @@ use std::{
10
8
} ;
11
9
12
10
use bytes:: { Buf , BufMut } ;
13
- use futures:: future:: TryFutureExt ;
14
- use rustls:: ClientConfig ;
11
+ use futures:: future:: { FutureExt , TryFutureExt } ;
12
+ use ring:: digest;
13
+ use rustls:: { ClientConfig , Session } ;
15
14
use tokio:: io:: { AsyncRead , AsyncWrite } ;
16
15
use tokio_postgres:: tls:: { ChannelBinding , MakeTlsConnect , TlsConnect } ;
17
16
use tokio_rustls:: { client:: TlsStream , TlsConnector } ;
@@ -30,13 +29,13 @@ impl MakeRustlsConnect {
30
29
31
30
impl < S > MakeTlsConnect < S > for MakeRustlsConnect
32
31
where
33
- S : AsyncRead + AsyncWrite + Unpin ,
32
+ S : AsyncRead + AsyncWrite + Unpin + Send + ' static ,
34
33
{
35
34
type Stream = RustlsStream < S > ;
36
35
type TlsConnect = RustlsConnect ;
37
- type Error = std :: io:: Error ;
36
+ type Error = io:: Error ;
38
37
39
- fn make_tls_connect ( & mut self , hostname : & str ) -> std :: io:: Result < RustlsConnect > {
38
+ fn make_tls_connect ( & mut self , hostname : & str ) -> io:: Result < RustlsConnect > {
40
39
DNSNameRef :: try_from_ascii_str ( hostname)
41
40
. map ( |dns_name| RustlsConnect {
42
41
hostname : dns_name. to_owned ( ) ,
@@ -53,15 +52,16 @@ pub struct RustlsConnect {
53
52
54
53
impl < S > TlsConnect < S > for RustlsConnect
55
54
where
56
- S : AsyncRead + AsyncWrite + Unpin ,
55
+ S : AsyncRead + AsyncWrite + Unpin + Send + ' static ,
57
56
{
58
57
type Stream = RustlsStream < S > ;
59
- type Error = std :: io:: Error ;
60
- type Future = impl Future < Output = std :: io:: Result < RustlsStream < S > > > ;
58
+ type Error = io:: Error ;
59
+ type Future = Pin < Box < dyn Future < Output = io:: Result < RustlsStream < S > > > > > ;
61
60
62
61
fn connect ( self , stream : S ) -> Self :: Future {
63
62
self . connector . connect ( self . hostname . as_ref ( ) , stream)
64
63
. map_ok ( |s| RustlsStream ( Box :: pin ( s) ) )
64
+ . boxed ( )
65
65
}
66
66
}
67
67
72
72
S : AsyncRead + AsyncWrite + Unpin ,
73
73
{
74
74
fn channel_binding ( & self ) -> ChannelBinding {
75
- ChannelBinding :: none ( ) // TODO
75
+ let ( _, session) = self . 0 . get_ref ( ) ;
76
+ match session. get_peer_certificates ( ) {
77
+ Some ( certs) if certs. len ( ) > 0 => {
78
+ let sha256 = digest:: digest ( & digest:: SHA256 , certs[ 0 ] . as_ref ( ) ) ;
79
+ ChannelBinding :: tls_server_end_point ( sha256. as_ref ( ) . into ( ) )
80
+ } ,
81
+ _ => ChannelBinding :: none ( ) ,
82
+ }
76
83
}
77
84
}
78
85
0 commit comments