Skip to content

moduasyncio: Add SSL support #5840

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions docs/library/ussl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,24 @@ facilities for network sockets, both client-side and server-side.
Functions
---------

.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True)
.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, server_hostname=None, do_handshake=True)

Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type),
and returns an instance of ssl.SSLSocket, which wraps the underlying stream in
an SSL context. Returned object has the usual `stream` interface methods like
an SSL context. The returned object has the usual `stream` interface methods like
``read()``, ``write()``, etc.
A server-side SSL socket should be created from a normal socket returned from
:meth:`~usocket.socket.accept()` on a non-SSL listening server socket.

- *do_handshake* determines whether the handshake is done as part of the ``wrap_socket``
Parameters:

- ``server_side``: creates a server connection if True, else client connection. A
server connection requires a ``keyfile`` and a ``certfile``.
- ``cert_reqs``: specifies the level of certificate checking to be performed.
- ``ca_certs``: root certificates to use for certificate checking.
- ``server_hostname``: specifies the hostname of the server for verification purposes
as well for SNI (Server Name Identification).
- ``do_handshake``: determines whether the handshake is done as part of the ``wrap_socket``
or whether it is deferred to be done as part of the initial reads or writes
(there is no ``do_handshake`` method as in CPython).
For blocking sockets doing the handshake immediately is standard. For non-blocking
Expand Down Expand Up @@ -58,3 +67,11 @@ Constants
ussl.CERT_REQUIRED

Supported values for *cert_reqs* parameter.

- CERT_NONE: in client mode accept just about any cert, in server mode do not
request a cert from the client.
- CERT_OPTIONAL: in client mode behaves the same as CERT_REQUIRED and in server
mode requests an optional cert from the client for authentication.
- CERT_REQUIRED: in client mode validates the server's cert and
in server mode requires the client to send a cert for authentication. Note that
ussl does not actually support client authentication.
113 changes: 77 additions & 36 deletions extmod/modussl_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@
#include "mbedtls/debug.h"
#include "mbedtls/error.h"

// flags for _mp_obj_ssl_socket_t.poll_flag that control the poll ioctl
// the issue is that when using ipoll we may be polling only for reading, and the socket may never
// become readable because mbedtls needs to write soemthing (like a handshake or renegotiation) and
// so poll never returns "it's readable" or "it's writable" and so nothing ever makes progress.
// See also the commit message for
// https://github.com/micropython/micropython/commit/9c7c082396f717a8a8eb845a0af407e78d38165f
#define READ_NEEDS_WRITE 0x1 // mbedtls_ssl_read said "I need a write"
#define WRITE_NEEDS_READ 0x2 // mbedtls_ssl_write said "I need a read"

typedef struct _mp_obj_ssl_socket_t {
mp_obj_base_t base;
mp_obj_t sock;
Expand All @@ -56,6 +65,8 @@ typedef struct _mp_obj_ssl_socket_t {
mbedtls_x509_crt cacert;
mbedtls_x509_crt cert;
mbedtls_pk_context pkey;
uint8_t poll_flag;
uint8_t poll_by_read; // true: at next poll try to read first
} mp_obj_ssl_socket_t;

struct ssl_args {
Expand All @@ -76,46 +87,29 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons
}
#endif

STATIC NORETURN void mbedtls_raise_error(int err) {
// _mbedtls_ssl_send and _mbedtls_ssl_recv (below) turn positive error codes from the
// underlying socket into negative codes to pass them through mbedtls. Here we turn them
// positive again so they get interpreted as the OSError they really are. The
// cut-off of -256 is a bit hacky, sigh.
if (err < 0 && err > -256) {
mp_raise_OSError(-err);
}

#if defined(MBEDTLS_ERROR_C)
// Including mbedtls_strerror takes about 1.5KB due to the error strings.
// MBEDTLS_ERROR_C is the define used by mbedtls to conditionally include mbedtls_strerror.
// It is set/unset in the MBEDTLS_CONFIG_FILE which is defined in the Makefile.

// Try to allocate memory for the message
#define ERR_STR_MAX 80 // mbedtls_strerror truncates if it doesn't fit
mp_obj_str_t *o_str = m_new_obj_maybe(mp_obj_str_t);
byte *o_str_buf = m_new_maybe(byte, ERR_STR_MAX);
if (o_str == NULL || o_str_buf == NULL) {
mp_raise_OSError(err);
// mod_ssl_errstr returns the error string corresponding to the error code found in an OSError,
// such as returned by read/write.
STATIC mp_obj_t mod_ssl_errstr(mp_obj_t err_in) {
size_t err = mp_obj_get_int(err_in);
vstr_t vstr;
vstr_init_len(&vstr, 80);

// Including mbedtls_strerror takes about 16KB on the esp32 due to all the strings
#if 1
vstr.buf[0] = 0;
mbedtls_strerror(err, vstr.buf, vstr.alloc);
vstr.len = strlen(vstr.buf);
if (vstr.len == 0) {
return MP_OBJ_NULL;
}

// print the error message into the allocated buffer
mbedtls_strerror(err, (char *)o_str_buf, ERR_STR_MAX);
size_t len = strlen((char *)o_str_buf);

// Put the exception object together
o_str->base.type = &mp_type_str;
o_str->data = o_str_buf;
o_str->len = len;
o_str->hash = qstr_compute_hash(o_str->data, o_str->len);
// raise
mp_obj_t args[2] = { MP_OBJ_NEW_SMALL_INT(err), MP_OBJ_FROM_PTR(o_str)};
nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args));
#else
// mbedtls is compiled without error strings so we simply return the err number
mp_raise_OSError(err); // err is typically a large negative number
vstr_printf(vstr, "mbedtls error -0x%x\n", -err);
#endif
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_ssl_errstr_obj, mod_ssl_errstr);

// _mbedtls_ssl_send is called by mbedtls to send bytes onto the underlying socket
STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;

Expand Down Expand Up @@ -237,6 +231,8 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
}
}

o->poll_flag = 0;
o->poll_by_read = 0;
if (args->do_handshake.u_bool) {
while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
Expand All @@ -263,7 +259,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
} else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA) {
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
} else {
mbedtls_raise_error(ret);
mp_raise_OSError(-ret);
}
}

Expand All @@ -289,12 +285,16 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);

o->poll_flag &= ~READ_NEEDS_WRITE; // clear flag
int ret = mbedtls_ssl_read(&o->ssl, buf, size);
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
// end of stream
return 0;
}
if (ret >= 0) {
// if we got all we wanted, for the next poll try a read first 'cause
// there may be data in the mbedtls record buffer
o->poll_by_read = ret == size;
return ret;
}
if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
Expand All @@ -303,6 +303,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
// If handshake is not finished, read attempt may end up in protocol
// wanting to write next handshake message. The same may happen with
// renegotation.
o->poll_flag |= READ_NEEDS_WRITE; // set flag
ret = MP_EWOULDBLOCK;
}
*errcode = ret;
Expand All @@ -312,6 +313,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);

o->poll_flag &= ~WRITE_NEEDS_READ; // clear flag
int ret = mbedtls_ssl_write(&o->ssl, buf, size);
if (ret >= 0) {
return ret;
Expand All @@ -322,6 +324,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
// If handshake is not finished, write attempt may end up in protocol
// wanting to read next handshake message. The same may happen with
// renegotation.
o->poll_flag |= WRITE_NEEDS_READ; // set flag
ret = MP_EWOULDBLOCK;
}
*errcode = ret;
Expand All @@ -348,6 +351,43 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
mbedtls_ssl_config_free(&self->conf);
mbedtls_ctr_drbg_free(&self->ctr_drbg);
mbedtls_entropy_free(&self->entropy);
} else if (request == MP_STREAM_POLL) {
mp_uint_t ret = 0;
// If the last read returned everything asked for there may be more in the mbedtls buffer,
// so find out. (There doesn't seem to be an equivalent issue with writes.)
if ((arg & MP_STREAM_POLL_RD) && self->poll_by_read) {
size_t avail = mbedtls_ssl_get_bytes_avail(&self->ssl);
if (avail > 0) {
ret = MP_STREAM_POLL_RD;
}
}
// If we're polling to read but not write but mbedtls previously said it needs to write in
// order to be able to read then poll for both and if either is available pretend the socket
// is readable. When the app then performs a read, mbedtls is happy to perform the writes as
// well. Essentially, what we're ensuring is that one of mbedtls' read/write functions is
// called as soon as the socket can do something.
if ((arg & MP_STREAM_POLL_RD) && !(arg & MP_STREAM_POLL_WR) &&
self->poll_flag & READ_NEEDS_WRITE) {
arg |= MP_STREAM_POLL_WR;
ret |= mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
if (ret & MP_STREAM_POLL_WR) {
ret |= MP_STREAM_POLL_RD;
ret &= ~MP_STREAM_POLL_WR;
}
return ret;
// Now comes the same logic flipped around for write
} else if ((arg & MP_STREAM_POLL_WR) && !(arg & MP_STREAM_POLL_RD) &&
self->poll_flag & WRITE_NEEDS_READ) {
arg |= MP_STREAM_POLL_RD;
ret |= mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
if (ret & MP_STREAM_POLL_RD) {
ret |= MP_STREAM_POLL_WR;
ret &= ~MP_STREAM_POLL_RD;
}
return ret;
}
// Pass down to underlying socket
return ret | mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
}
// Pass all requests down to the underlying socket
return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
Expand Down Expand Up @@ -409,6 +449,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socke
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ussl) },
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
{ MP_ROM_QSTR(MP_QSTR_errstr), MP_ROM_PTR(&mod_ssl_errstr_obj) },
};

STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
Expand Down
27 changes: 24 additions & 3 deletions extmod/uasyncio/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

from . import core

try:
import ssl as modssl # module is used in function that has an ssl parameter
except:
modssl = None


class Stream:
def __init__(self, s, e={}):
Expand Down Expand Up @@ -71,20 +76,36 @@ async def drain(self):


# Create a TCP stream connection to a remote host
async def open_connection(host, port):
async def open_connection(host, port, ssl=None, server_hostname=None):
from uerrno import EINPROGRESS
import usocket as socket

ai = socket.getaddrinfo(host, port)[0] # TODO this is blocking!
s = socket.socket()
s.setblocking(False)
ss = Stream(s)
try:
s.connect(ai[-1])
except OSError as er:
if er.args[0] != EINPROGRESS:
raise er
yield core._io_queue.queue_write(s)
# wrap with SSL, if requested
if ssl:
if not modssl:
raise ValueError("SSL not supported")
if ssl is True:
ssl = {} # spec says to use ssl.create_default_context() but we don't have that
elif isinstance(ssl, dict):
# non-standard: accept dict with KW args suitable to call ssl.wrap_socket()
if server_hostname:
# spec: server_hostname sets or overrides the hostname that the target server’s
# certificate will be matched against.
ssl["server_hostname"] = server_hostname
else:
# spec says we should handle ssl.SSLContext object, but ain't got that
raise ValueError("invalid ssl param")
ssl["do_handshake"] = False # as non-blocking as possible
s = modssl.wrap_socket(s, **ssl)
ss = Stream(s)
return ss, ss


Expand Down
3 changes: 3 additions & 0 deletions ports/esp32/modsocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ STATIC mp_obj_t socket_connect(const mp_obj_t arg0, const mp_obj_t arg1) {
MP_THREAD_GIL_ENTER();
lwip_freeaddrinfo(res);
if (r != 0) {
// side-note: LwIP internally doesn't seem to have an error code for ECONNREFUSED and
// so refused connections show up as ECONNRESET. Could be band-aided for blocking connect,
// harder to do for nonblocking.
mp_raise_OSError(errno);
}

Expand Down
60 changes: 58 additions & 2 deletions tests/multi_net/ssl_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,58 @@
# Simple test creating an SSL connection and transferring some data
# This test won't run under CPython because it requires key/cert

import usocket as socket, ussl as ssl
try:
import usocket as socket, ussl as ssl, ubinascii as binascii, uselect as select
except ModuleNotFoundError:
import socket, ssl, binascii, select

PORT = 8000

# This self-signed key/cert pair is randomly generated and to be used for
# testing/demonstration only.
# openssl req -x509 -newkey rsa:1024 -keyout key.pem -out cert.pem -days 36500 -nodes
cert = """
-----BEGIN CERTIFICATE-----
MIICaDCCAdGgAwIBAgIUaYEwlY581HuPWHm2ndTWejuggAIwDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMDA0MTgxOTAwMDBaGA8yMTIw
MDMyNTE5MDAwMFowRTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCBnzANBgkqhkiG9w0B
AQEFAAOBjQAwgYkCgYEAxmACtMgGR2tTKVHzxG67Yx61pWNynXUE0q00yJ0a34AK
uQKzvyEdvkk5lL3snV4N5wKeRgWmS3/krl/YQO+Rk4eSJRqJc8INd3qSOFSNUgPg
W0VPP9vPox8au5Ngqn06jgtdD1F0a6Z+f+N3+JyRPAaetIWlFC9WEn+zzz0/cmkC
AwEAAaNTMFEwHQYDVR0OBBYEFBaI7GVj4GjxPWq+RO7A/4INOq2RMB8GA1UdIwQY
MBaAFBaI7GVj4GjxPWq+RO7A/4INOq2RMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZI
hvcNAQELBQADgYEAMpdYd8jkWxoXMxV+X2rpyx/BnPrPa+l2LehlulrU7lRh4QIU
t4f+W+yBvkFscPatpRfJoXXqregmhLxo8poKw08pjn7DNKBzcsPsxnmRIvFZuL2J
wYHGyP9HcMpsnx+UW2YjjQ4R1I0smRI7ZKiax8AJkN/P9eHH9Xku6ostXYk=
-----END CERTIFICATE-----
"""
key = """
-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAMZgArTIBkdrUylR
88Ruu2MetaVjcp11BNKtNMidGt+ACrkCs78hHb5JOZS97J1eDecCnkYFpkt/5K5f
2EDvkZOHkiUaiXPCDXd6kjhUjVID4FtFTz/bz6MfGruTYKp9Oo4LXQ9RdGumfn/j
d/ickTwGnrSFpRQvVhJ/s889P3JpAgMBAAECgYBPkxnizM3//iRY0d/37vdKFnqF
AnRqhxNNM1+WDbdG6kTi3BugUrdsqlDnwpvUsHLhNOKqcf+4D3B7JkVIHxGEqLSl
YMbQrldodPwIP0ycf9hegzuhEvuYGkex22edmQ5brkdIt6QCv0QRtProYowJx4p6
CuM5423ORejs6Vw9gQJBAOF//1Ovmm5Q1d90ZzjFhZCwG3/z5uwqZMGBxJTaibSC
O5cci3n9Tcc4AebnMf5eyrXHovtSg1FfDxS+IUccXRECQQDhNM3R31YvYmRZwrTn
f71y+buXpUtMDUDhFK8FNZN1/zJ6dJVrWQ/MVj+TaNjLUYNdPmRPHQdt8+Fx65y9
95/ZAkEAqgmkdGwz3P9jZm4V778xqhrBgche1rJY63l4zG3F7LFPUfEaU1BoN9LJ
zF2FWzQLUutIwI5FqzQs4Q1FdqOyoQJBALAL1iUMwFO0R5v/X+lj6xXY8PM/jJf7
+E67G4In+okQIEanojJTYc0rUvGJ0YdGxjj6z/EkUS17qy2hsFq0GykCQQCiucp9
7kbPpzw/gW+ERfoLgtZKrP/+Au9C5sz2wxUpeKhYihVePF8pmytyD8mqt/3LIJhZ
NA2FEss2+KJUCjHc
-----END PRIVATE KEY-----
"""
chain = cert + key
# Produce cert/key for MicroPython
cert = cert[cert.index("M") : cert.index("=") + 2]
key = key[key.index("M") : key.rstrip().rindex("\n") + 1]
cert = binascii.a2b_base64(cert)
key = binascii.a2b_base64(key)


# Server
def instance0():
Expand All @@ -15,7 +63,15 @@ def instance0():
s.listen(1)
multitest.next()
s2, _ = s.accept()
s2 = ssl.wrap_socket(s2, server_side=True)
if hasattr(ssl, "SSLContext"):
fn = "/tmp/MP_test_cert.pem"
with open(fn, "w") as f:
f.write(chain)
ctx = ssl.SSLContext()
ctx.load_cert_chain(fn)
s2 = ctx.wrap_socket(s2, server_side=True)
else:
s2 = ssl.wrap_socket(s2, server_side=True, key=key, cert=cert)
print(s2.read(16))
s2.write(b"server to client")
s.close()
Expand Down
4 changes: 0 additions & 4 deletions tests/multi_net/ssl_data.py.exp

This file was deleted.

Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy