@@ -204,6 +204,58 @@ STATIC int _mbedtls_timing_get_delay(void *ctx) {
204
204
}
205
205
#endif
206
206
207
+ STATIC void set_ciphersuites (mbedtls_ssl_config * conf , int is_psk ) {
208
+ static int initialized = 0 ;
209
+ static int * psk_ciphers ;
210
+ static int * pki_ciphers ;
211
+
212
+ if (!initialized ) {
213
+ const int * ciphersuites = mbedtls_ssl_list_ciphersuites ();
214
+
215
+ int count_psk = 0 ;
216
+ int count_pki = 0 ;
217
+ int i = 0 ;
218
+ const mbedtls_ssl_ciphersuite_t * ciphersuite ;
219
+
220
+ for (i = 0 ; ciphersuites [i ] != 0 ; i ++ ) {
221
+ ciphersuite = mbedtls_ssl_ciphersuite_from_id (ciphersuites [i ]);
222
+ if (mbedtls_ssl_ciphersuite_uses_psk (ciphersuite )) {
223
+ count_psk ++ ;
224
+ } else {
225
+ count_pki ++ ;
226
+ }
227
+ }
228
+
229
+ psk_ciphers = malloc (sizeof (int ) * (count_psk + 1 ));
230
+ pki_ciphers = malloc (sizeof (int ) * (count_pki + 1 ));
231
+ if (psk_ciphers == NULL || pki_ciphers == NULL ) {
232
+ free (psk_ciphers );
233
+ free (pki_ciphers );
234
+ mp_raise_OSError (MP_ENOMEM );
235
+ }
236
+
237
+ int * psk_pos = psk_ciphers ;
238
+ int * pki_pos = pki_ciphers ;
239
+
240
+ for (i = 0 ; ciphersuites [i ] != 0 ; i ++ ) {
241
+ ciphersuite = mbedtls_ssl_ciphersuite_from_id (ciphersuites [i ]);
242
+ if (mbedtls_ssl_ciphersuite_uses_psk (ciphersuite )) {
243
+ * psk_pos = ciphersuites [i ];
244
+ psk_pos ++ ;
245
+ } else {
246
+ * pki_pos = ciphersuites [i ];
247
+ pki_pos ++ ;
248
+ }
249
+ }
250
+
251
+ * psk_pos = 0 ;
252
+ * pki_pos = 0 ;
253
+ initialized = 1 ;
254
+ }
255
+
256
+ mbedtls_ssl_conf_ciphersuites (conf , is_psk ? psk_ciphers : pki_ciphers );
257
+ }
258
+
207
259
STATIC mp_obj_ssl_socket_t * socket_new (mp_obj_t sock , struct ssl_args * args ) {
208
260
// Verify the socket object has the full stream protocol
209
261
mp_get_stream_raise (sock , MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL );
@@ -251,8 +303,8 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
251
303
#endif
252
304
253
305
#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED ) || defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED )
254
- // banana() ;
255
- if (args -> psk_identity . u_obj != mp_const_none && args -> psk_key . u_obj != mp_const_none ) {
306
+ int is_psk = args -> psk_identity . u_obj != mp_const_none && args -> psk_key . u_obj != mp_const_none ;
307
+ if (is_psk ) {
256
308
size_t psk_identity_len ;
257
309
size_t psk_key_len ;
258
310
const byte * psk_identity = (const byte * )mp_obj_str_get_data (args -> psk_identity .u_obj , & psk_identity_len );
@@ -264,6 +316,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
264
316
goto cleanup ;
265
317
}
266
318
}
319
+ set_ciphersuites (& o -> conf , is_psk );
267
320
#endif
268
321
269
322
ret = mbedtls_ssl_setup (& o -> ssl , & o -> conf );
0 commit comments