46
46
#include "mbedtls/debug.h"
47
47
#include "mbedtls/error.h"
48
48
49
+ // flags for _mp_obj_ssl_socket_t.poll_flag that control the poll ioctl
50
+ // the issue is that when using ipoll we may be polling only for reading, and the socket may never
51
+ // become readable because mbedtls needs to write soemthing (like a handshake or renegotiation) and
52
+ // so poll never returns "it's readable" or "it's writable" and so nothing ever makes progress.
53
+ // See also the commit message for
54
+ // https://github.com/micropython/micropython/commit/9c7c082396f717a8a8eb845a0af407e78d38165f
55
+ #define READ_NEEDS_WRITE 0x1 // mbedtls_ssl_read said "I need a write"
56
+ #define WRITE_NEEDS_READ 0x2 // mbedtls_ssl_write said "I need a read"
57
+
49
58
typedef struct _mp_obj_ssl_socket_t {
50
59
mp_obj_base_t base ;
51
60
mp_obj_t sock ;
@@ -56,6 +65,8 @@ typedef struct _mp_obj_ssl_socket_t {
56
65
mbedtls_x509_crt cacert ;
57
66
mbedtls_x509_crt cert ;
58
67
mbedtls_pk_context pkey ;
68
+ uint8_t poll_flag ;
69
+ uint8_t poll_by_read ; // true: at next poll try to read first
59
70
} mp_obj_ssl_socket_t ;
60
71
61
72
struct ssl_args {
@@ -76,46 +87,29 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons
76
87
}
77
88
#endif
78
89
79
- STATIC NORETURN void mbedtls_raise_error (int err ) {
80
- // _mbedtls_ssl_send and _mbedtls_ssl_recv (below) turn positive error codes from the
81
- // underlying socket into negative codes to pass them through mbedtls. Here we turn them
82
- // positive again so they get interpreted as the OSError they really are. The
83
- // cut-off of -256 is a bit hacky, sigh.
84
- if (err < 0 && err > -256 ) {
85
- mp_raise_OSError (- err );
86
- }
87
-
88
- #if defined(MBEDTLS_ERROR_C )
89
- // Including mbedtls_strerror takes about 1.5KB due to the error strings.
90
- // MBEDTLS_ERROR_C is the define used by mbedtls to conditionally include mbedtls_strerror.
91
- // It is set/unset in the MBEDTLS_CONFIG_FILE which is defined in the Makefile.
92
-
93
- // Try to allocate memory for the message
94
- #define ERR_STR_MAX 80 // mbedtls_strerror truncates if it doesn't fit
95
- mp_obj_str_t * o_str = m_new_obj_maybe (mp_obj_str_t );
96
- byte * o_str_buf = m_new_maybe (byte , ERR_STR_MAX );
97
- if (o_str == NULL || o_str_buf == NULL ) {
98
- mp_raise_OSError (err );
90
+ // mod_ssl_errstr returns the error string corresponding to the error code found in an OSError,
91
+ // such as returned by read/write.
92
+ STATIC mp_obj_t mod_ssl_errstr (mp_obj_t err_in ) {
93
+ size_t err = mp_obj_get_int (err_in );
94
+ vstr_t vstr ;
95
+ vstr_init_len (& vstr , 80 );
96
+
97
+ // Including mbedtls_strerror takes about 16KB on the esp32 due to all the strings
98
+ #if 1
99
+ vstr .buf [0 ] = 0 ;
100
+ mbedtls_strerror (err , vstr .buf , vstr .alloc );
101
+ vstr .len = strlen (vstr .buf );
102
+ if (vstr .len == 0 ) {
103
+ return MP_OBJ_NULL ;
99
104
}
100
-
101
- // print the error message into the allocated buffer
102
- mbedtls_strerror (err , (char * )o_str_buf , ERR_STR_MAX );
103
- size_t len = strlen ((char * )o_str_buf );
104
-
105
- // Put the exception object together
106
- o_str -> base .type = & mp_type_str ;
107
- o_str -> data = o_str_buf ;
108
- o_str -> len = len ;
109
- o_str -> hash = qstr_compute_hash (o_str -> data , o_str -> len );
110
- // raise
111
- mp_obj_t args [2 ] = { MP_OBJ_NEW_SMALL_INT (err ), MP_OBJ_FROM_PTR (o_str )};
112
- nlr_raise (mp_obj_exception_make_new (& mp_type_OSError , 2 , 0 , args ));
113
105
#else
114
- // mbedtls is compiled without error strings so we simply return the err number
115
- mp_raise_OSError (err ); // err is typically a large negative number
106
+ vstr_printf (vstr , "mbedtls error -0x%x\n" , - err );
116
107
#endif
108
+ return mp_obj_new_str_from_vstr (& mp_type_bytes , & vstr );
117
109
}
110
+ STATIC MP_DEFINE_CONST_FUN_OBJ_1 (mod_ssl_errstr_obj , mod_ssl_errstr );
118
111
112
+ // _mbedtls_ssl_send is called by mbedtls to send bytes onto the underlying socket
119
113
STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
120
114
mp_obj_t sock = * (mp_obj_t * )ctx ;
121
115
@@ -237,6 +231,8 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
237
231
}
238
232
}
239
233
234
+ o -> poll_flag = 0 ;
235
+ o -> poll_by_read = 0 ;
240
236
if (args -> do_handshake .u_bool ) {
241
237
while ((ret = mbedtls_ssl_handshake (& o -> ssl )) != 0 ) {
242
238
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE ) {
@@ -263,7 +259,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
263
259
} else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA ) {
264
260
mp_raise_ValueError (MP_ERROR_TEXT ("invalid cert" ));
265
261
} else {
266
- mbedtls_raise_error ( ret );
262
+ mp_raise_OSError ( - ret );
267
263
}
268
264
}
269
265
@@ -289,12 +285,16 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
289
285
STATIC mp_uint_t socket_read (mp_obj_t o_in , void * buf , mp_uint_t size , int * errcode ) {
290
286
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
291
287
288
+ o -> poll_flag &= ~READ_NEEDS_WRITE ; // clear flag
292
289
int ret = mbedtls_ssl_read (& o -> ssl , buf , size );
293
290
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ) {
294
291
// end of stream
295
292
return 0 ;
296
293
}
297
294
if (ret >= 0 ) {
295
+ // if we got all we wanted, for the next poll try a read first 'cause
296
+ // there may be data in the mbedtls record buffer
297
+ o -> poll_by_read = ret == size ;
298
298
return ret ;
299
299
}
300
300
if (ret == MBEDTLS_ERR_SSL_WANT_READ ) {
@@ -303,6 +303,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
303
303
// If handshake is not finished, read attempt may end up in protocol
304
304
// wanting to write next handshake message. The same may happen with
305
305
// renegotation.
306
+ o -> poll_flag |= READ_NEEDS_WRITE ; // set flag
306
307
ret = MP_EWOULDBLOCK ;
307
308
}
308
309
* errcode = ret ;
@@ -312,6 +313,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
312
313
STATIC mp_uint_t socket_write (mp_obj_t o_in , const void * buf , mp_uint_t size , int * errcode ) {
313
314
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
314
315
316
+ o -> poll_flag &= ~WRITE_NEEDS_READ ; // clear flag
315
317
int ret = mbedtls_ssl_write (& o -> ssl , buf , size );
316
318
if (ret >= 0 ) {
317
319
return ret ;
@@ -322,6 +324,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
322
324
// If handshake is not finished, write attempt may end up in protocol
323
325
// wanting to read next handshake message. The same may happen with
324
326
// renegotation.
327
+ o -> poll_flag |= WRITE_NEEDS_READ ; // set flag
325
328
ret = MP_EWOULDBLOCK ;
326
329
}
327
330
* errcode = ret ;
@@ -348,6 +351,43 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
348
351
mbedtls_ssl_config_free (& self -> conf );
349
352
mbedtls_ctr_drbg_free (& self -> ctr_drbg );
350
353
mbedtls_entropy_free (& self -> entropy );
354
+ } else if (request == MP_STREAM_POLL ) {
355
+ mp_uint_t ret = 0 ;
356
+ // If the last read returned everything asked for there may be more in the mbedtls buffer,
357
+ // so find out. (There doesn't seem to be an equivalent issue with writes.)
358
+ if ((arg & MP_STREAM_POLL_RD ) && self -> poll_by_read ) {
359
+ size_t avail = mbedtls_ssl_get_bytes_avail (& self -> ssl );
360
+ if (avail > 0 ) {
361
+ ret = MP_STREAM_POLL_RD ;
362
+ }
363
+ }
364
+ // If we're polling to read but not write but mbedtls previously said it needs to write in
365
+ // order to be able to read then poll for both and if either is available pretend the socket
366
+ // is readable. When the app then performs a read, mbedtls is happy to perform the writes as
367
+ // well. Essentially, what we're ensuring is that one of mbedtls' read/write functions is
368
+ // called as soon as the socket can do something.
369
+ if ((arg & MP_STREAM_POLL_RD ) && !(arg & MP_STREAM_POLL_WR ) &&
370
+ self -> poll_flag & READ_NEEDS_WRITE ) {
371
+ arg |= MP_STREAM_POLL_WR ;
372
+ ret |= mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
373
+ if (ret & MP_STREAM_POLL_WR ) {
374
+ ret |= MP_STREAM_POLL_RD ;
375
+ ret &= ~MP_STREAM_POLL_WR ;
376
+ }
377
+ return ret ;
378
+ // Now comes the same logic flipped around for write
379
+ } else if ((arg & MP_STREAM_POLL_WR ) && !(arg & MP_STREAM_POLL_RD ) &&
380
+ self -> poll_flag & WRITE_NEEDS_READ ) {
381
+ arg |= MP_STREAM_POLL_RD ;
382
+ ret |= mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
383
+ if (ret & MP_STREAM_POLL_RD ) {
384
+ ret |= MP_STREAM_POLL_WR ;
385
+ ret &= ~MP_STREAM_POLL_RD ;
386
+ }
387
+ return ret ;
388
+ }
389
+ // Pass down to underlying socket
390
+ return ret | mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
351
391
}
352
392
// Pass all requests down to the underlying socket
353
393
return mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
@@ -409,6 +449,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socke
409
449
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table [] = {
410
450
{ MP_ROM_QSTR (MP_QSTR___name__ ), MP_ROM_QSTR (MP_QSTR_ussl ) },
411
451
{ MP_ROM_QSTR (MP_QSTR_wrap_socket ), MP_ROM_PTR (& mod_ssl_wrap_socket_obj ) },
452
+ { MP_ROM_QSTR (MP_QSTR_errstr ), MP_ROM_PTR (& mod_ssl_errstr_obj ) },
412
453
};
413
454
414
455
STATIC MP_DEFINE_CONST_DICT (mp_module_ssl_globals , mp_module_ssl_globals_table );
0 commit comments