Skip to content

extmod/modtls_mbedtls: Add support for TLS PSK #17074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extmod/mbedtls/mbedtls_config_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#define MBEDTLS_KEY_EXCHANGE_RSA_ENABLED
#define MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
#define MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED
#define MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
#define MBEDTLS_CAN_ECDH
#define MBEDTLS_PK_CAN_ECDSA_SIGN
#define MBEDTLS_PKCS1_V15
Expand Down
121 changes: 121 additions & 0 deletions extmod/modtls_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#endif
#include "mbedtls/debug.h"
#include "mbedtls/error.h"
#include "mbedtls/ssl_ciphersuites.h"
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
#include "mbedtls/build_info.h"
#else
Expand Down Expand Up @@ -92,6 +93,10 @@
#if MICROPY_PY_SSL_ECDSA_SIGN_ALT
mp_obj_t ecdsa_sign_callback;
#endif

mp_obj_t psk_identity; // PSK identity (string)
mp_obj_t psk_key; // PSK key (bytes)
bool use_psk; // Flag to indicate if PSK should be used
} mp_obj_ssl_context_t;

// This corresponds to an SSLSocket object.
Expand Down Expand Up @@ -285,6 +290,11 @@
self->ecdsa_sign_callback = mp_const_none;
#endif

// Initialize PSK fields
self->psk_identity = mp_const_none;
self->psk_key = mp_const_none;
self->use_psk = false;

#ifdef MBEDTLS_DEBUG_C
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
mbedtls_debug_set_threshold(3);
Expand Down Expand Up @@ -382,10 +392,87 @@
}
static MP_DEFINE_CONST_FUN_OBJ_1(ssl_context_get_ciphers_obj, ssl_context_get_ciphers);

// Helper function to set PSK ciphersuites
static void set_psk_ciphersuites(mbedtls_ssl_config *conf) {
// Create a list of PSK ciphersuites
static int *psk_ciphersuites = NULL;

if (psk_ciphersuites == NULL) {
// Define known PSK ciphersuites
// These are common PSK ciphersuites supported by mbedtls
static const int known_psk_ciphersuites[] = {
MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA256,
MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA,
MBEDTLS_TLS_PSK_WITH_AES_256_CBC_SHA,
MBEDTLS_TLS_PSK_WITH_AES_128_GCM_SHA256,
MBEDTLS_TLS_PSK_WITH_AES_256_GCM_SHA384,
0 // Terminating zero
};

// Count available PSK ciphersuites
int count = 0;
for (int i = 0; known_psk_ciphersuites[i] != 0; i++) {
count++;
}

// Allocate memory for PSK ciphersuites
psk_ciphersuites = m_new(int, count + 1);
if (psk_ciphersuites == NULL) {
mp_raise_OSError(MP_ENOMEM);

Check warning on line 421 in extmod/modtls_mbedtls.c

View check run for this annotation

Codecov / codecov/patch

extmod/modtls_mbedtls.c#L421

Added line #L421 was not covered by tests
}

// Copy the PSK ciphersuites
for (int i = 0; i <= count; i++) { // Include terminating zero
psk_ciphersuites[i] = known_psk_ciphersuites[i];
}
}

// Set PSK ciphersuites
mbedtls_ssl_conf_ciphersuites(conf, psk_ciphersuites);
}

// SSLContext.set_ciphers(ciphersuite)
static mp_obj_t ssl_context_set_ciphers(mp_obj_t self_in, mp_obj_t ciphersuite) {
mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(self_in);

// Check if ciphersuite is a string
if (mp_obj_is_str(ciphersuite)) {
const char *ciphername = mp_obj_str_get_str(ciphersuite);

// Check for generic "PSK" mode
if (strcmp(ciphername, "PSK") == 0) {
ssl_context->use_psk = true;
set_psk_ciphersuites(&ssl_context->conf);
return mp_const_none;
}

// Check if this is a PSK ciphersuite name
if (strncmp(ciphername, "PSK-", 4) == 0 ||
strncmp(ciphername, "TLS-PSK-", 8) == 0 ||
strncmp(ciphername, "TLS_PSK_", 8) == 0) {

// Try to look up the ciphersuite ID
const int id = mbedtls_ssl_get_ciphersuite_id(ciphername);
if (id != 0) {
// Enable PSK mode
ssl_context->use_psk = true;

// Create a ciphersuite array with just this one ciphersuite
ssl_context->ciphersuites = m_new(int, 2);
if (ssl_context->ciphersuites == NULL) {
mp_raise_OSError(MP_ENOMEM);

Check warning on line 463 in extmod/modtls_mbedtls.c

View check run for this annotation

Codecov / codecov/patch

extmod/modtls_mbedtls.c#L463

Added line #L463 was not covered by tests
}
ssl_context->ciphersuites[0] = id;
ssl_context->ciphersuites[1] = 0; // Terminating zero

// Configure the ciphersuite
mbedtls_ssl_conf_ciphersuites(&ssl_context->conf, (const int *)ssl_context->ciphersuites);
return mp_const_none;
}
}
}

// Original implementation for non-PSK ciphersuites
// Check that ciphersuite is a list or tuple.
size_t len = 0;
mp_obj_t *ciphers;
Expand Down Expand Up @@ -467,6 +554,22 @@
}
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_load_verify_locations_obj, ssl_context_load_verify_locations);

// SSLContext.set_psk_identity(identity)
static mp_obj_t ssl_context_set_psk_identity(mp_obj_t self_in, mp_obj_t identity) {
mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(self_in);
self->psk_identity = identity;
return mp_const_none;
}
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_psk_identity_obj, ssl_context_set_psk_identity);

// SSLContext.set_psk_key(key)
static mp_obj_t ssl_context_set_psk_key(mp_obj_t self_in, mp_obj_t key) {
mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(self_in);
self->psk_key = key;
return mp_const_none;
}
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_psk_key_obj, ssl_context_set_psk_key);

static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname };
static const mp_arg_t allowed_args[] = {
Expand Down Expand Up @@ -496,6 +599,8 @@
{ MP_ROM_QSTR(MP_QSTR_load_cert_chain), MP_ROM_PTR(&ssl_context_load_cert_chain_obj)},
{ MP_ROM_QSTR(MP_QSTR_load_verify_locations), MP_ROM_PTR(&ssl_context_load_verify_locations_obj)},
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) },
{ MP_ROM_QSTR(MP_QSTR_set_psk_identity), MP_ROM_PTR(&ssl_context_set_psk_identity_obj) },
{ MP_ROM_QSTR(MP_QSTR_set_psk_key), MP_ROM_PTR(&ssl_context_set_psk_key_obj) },
};
static MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table);

Expand Down Expand Up @@ -603,6 +708,22 @@

mbedtls_ssl_init(&o->ssl);

// Configure PSK if enabled
if (ssl_context->use_psk && ssl_context->psk_identity != mp_const_none && ssl_context->psk_key != mp_const_none) {
// Get PSK identity and key
size_t psk_identity_len;
const byte *psk_identity = (const byte *)mp_obj_str_get_data(ssl_context->psk_identity, &psk_identity_len);

size_t psk_key_len;
const byte *psk_key = (const byte *)mp_obj_str_get_data(ssl_context->psk_key, &psk_key_len);

// Configure PSK
ret = mbedtls_ssl_conf_psk(&ssl_context->conf, psk_key, psk_key_len, psk_identity, psk_identity_len);
if (ret != 0) {
goto cleanup;

Check warning on line 723 in extmod/modtls_mbedtls.c

View check run for this annotation

Codecov / codecov/patch

extmod/modtls_mbedtls.c#L723

Added line #L723 was not covered by tests
}
}

ret = mbedtls_ssl_setup(&o->ssl, &ssl_context->conf);
#if !MICROPY_MBEDTLS_CONFIG_BARE_METAL
if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED) {
Expand Down
52 changes: 52 additions & 0 deletions tests/multi_net/sslcontext_server_client_psk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Test TCP server and client with TLS-PSK, using set_psk_identity(),
# set_psk_key(), and set_ciphers("PSK").

try:
import socket
import tls
except ImportError:
print("SKIP")
raise SystemExit

PORT = 8000

PSK_ID = "PSK-Identity-1"
PSK_KEY = "c0ffee"
PSK_CIPHER = "PSK"


# Server
def instance0():
multitest.globals(IP=multitest.get_network_ip())
s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1])
s.listen(1)
multitest.next()
s2, _ = s.accept()
server_ctx = tls.SSLContext(tls.PROTOCOL_TLS_SERVER)
# Configure PSK
server_ctx.set_psk_identity(PSK_ID)
server_ctx.set_psk_key(bytes.fromhex(PSK_KEY))
server_ctx.set_ciphers(PSK_CIPHER)
s2 = server_ctx.wrap_socket(s2, server_side=True)
print(s2.read(16))
s2.write(b"server to client")
s2.close()
s.close()


# Client
def instance1():
multitest.next()
s = socket.socket()
s.connect(socket.getaddrinfo(IP, PORT)[0][-1])
client_ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
# Configure PSK
client_ctx.set_psk_identity(PSK_ID)
client_ctx.set_psk_key(bytes.fromhex(PSK_KEY))
client_ctx.set_ciphers(PSK_CIPHER)
s = client_ctx.wrap_socket(s, server_hostname="micropython.local")
s.write(b"client to server")
print(s.read(16))
s.close()
4 changes: 4 additions & 0 deletions tests/multi_net/sslcontext_server_client_psk.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
--- instance0 ---
b'client to server'
--- instance1 ---
b'server to client'
52 changes: 52 additions & 0 deletions tests/multi_net/sslcontext_server_client_psk_cipher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Test TCP server and client with TLS-PSK, using set_psk_identity(),
# set_psk_key(), and set_ciphers("TLS-PSK-WITH-AES-128-CBC-SHA256").

try:
import socket
import tls
except ImportError:
print("SKIP")
raise SystemExit

PORT = 8000

PSK_ID = "PSK-Identity-1"
PSK_KEY = "c0ffee"
PSK_CIPHER = "TLS-PSK-WITH-AES-128-CBC-SHA256"


# Server
def instance0():
multitest.globals(IP=multitest.get_network_ip())
s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1])
s.listen(1)
multitest.next()
s2, _ = s.accept()
server_ctx = tls.SSLContext(tls.PROTOCOL_TLS_SERVER)
# Configure PSK with specific ciphersuite
server_ctx.set_psk_identity(PSK_ID)
server_ctx.set_psk_key(bytes.fromhex(PSK_KEY))
server_ctx.set_ciphers(PSK_CIPHER)
s2 = server_ctx.wrap_socket(s2, server_side=True)
print(s2.read(16))
s2.write(b"server to client")
s2.close()
s.close()


# Client
def instance1():
multitest.next()
s = socket.socket()
s.connect(socket.getaddrinfo(IP, PORT)[0][-1])
client_ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
# Configure PSK with specific ciphersuite
client_ctx.set_psk_identity(PSK_ID)
client_ctx.set_psk_key(bytes.fromhex(PSK_KEY))
client_ctx.set_ciphers(PSK_CIPHER)
s = client_ctx.wrap_socket(s, server_hostname="micropython.local")
s.write(b"client to server")
print(s.read(16))
s.close()
4 changes: 4 additions & 0 deletions tests/multi_net/sslcontext_server_client_psk_cipher.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
--- instance0 ---
b'client to server'
--- instance1 ---
b'server to client'
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy