Skip to content

Support setting principal and SASL extensions in oauth_cb, handle failures #1402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 97 additions & 6 deletions src/confluent_kafka/src/confluent_kafka.c
Original file line number Diff line number Diff line change
Expand Up @@ -1522,13 +1522,73 @@ static void log_cb (const rd_kafka_t *rk, int level,
CallState_resume(cs);
}

/**
* @brief Translate Python \p key and \p value to C types and set on
* provided \p extensions char* array at the provided index.
*
* @returns 1 on success or 0 if an exception was raised.
*/
static int py_extensions_to_c (char **extensions, Py_ssize_t idx,
PyObject *key, PyObject *value) {
PyObject *ks, *ks8, *vo8 = NULL;
const char *k;
const char *v;
Py_ssize_t ksize = 0;
Py_ssize_t vsize = 0;

if (!(ks = cfl_PyObject_Unistr(key))) {
PyErr_SetString(PyExc_TypeError,
"expected extension key to be unicode "
"string");
return 0;
}

k = cfl_PyUnistr_AsUTF8(ks, &ks8);
ksize = (Py_ssize_t)strlen(k);

if (cfl_PyUnistr(_Check(value))) {
/* Unicode string, translate to utf-8. */
v = cfl_PyUnistr_AsUTF8(value, &vo8);
if (!v) {
Py_DECREF(ks);
Py_XDECREF(ks8);
return 0;
}
vsize = (Py_ssize_t)strlen(v);
} else {
PyErr_Format(PyExc_TypeError,
"expected extension value to be "
"unicode string, not %s",
((PyTypeObject *)PyObject_Type(value))->
tp_name);
Py_DECREF(ks);
Py_XDECREF(ks8);
return 0;
}

extensions[idx] = (char*)malloc(ksize);
strcpy(extensions[idx], k);
extensions[idx + 1] = (char*)malloc(vsize);
strcpy(extensions[idx + 1], v);

Py_DECREF(ks);
Py_XDECREF(ks8);
Py_XDECREF(vo8);

return 1;
}

static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
void *opaque) {
Handle *h = opaque;
PyObject *eo, *result;
CallState *cs;
const char *token;
double expiry;
const char *principal = "";
PyObject *extensions = NULL;
char **rd_extensions = NULL;
Py_ssize_t rd_extensions_size = 0;
char err_msg[2048];
rd_kafka_resp_err_t err_code;

Expand All @@ -1539,26 +1599,57 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
Py_DECREF(eo);

if (!result) {
goto err;
goto fail;
}
if (!PyArg_ParseTuple(result, "sd", &token, &expiry)) {
if (!PyArg_ParseTuple(result, "sd|sO!", &token, &expiry, &principal, &PyDict_Type, &extensions)) {
Py_DECREF(result);
PyErr_Format(PyExc_TypeError,
PyErr_SetString(PyExc_TypeError,
"expect returned value from oauth_cb "
"to be (token_str, expiry_time) tuple");
goto err;
}

if (extensions) {
int len = (int)PyDict_Size(extensions);
rd_extensions = (char **)malloc(2 * len * sizeof(char *));
Py_ssize_t pos = 0;
PyObject *ko, *vo;
while (PyDict_Next(extensions, &pos, &ko, &vo)) {
if (!py_extensions_to_c(rd_extensions, rd_extensions_size, ko, vo)) {
Py_DECREF(result);
free(rd_extensions);
goto err;
}
rd_extensions_size = rd_extensions_size + 2;
}
}

err_code = rd_kafka_oauthbearer_set_token(h->rk, token,
(int64_t)(expiry * 1000),
"", NULL, 0, err_msg,
principal, (const char **)rd_extensions, rd_extensions_size, err_msg,
sizeof(err_msg));
Py_DECREF(result);
if (err_code) {
if (rd_extensions) {
for(int i = 0; i < rd_extensions_size; i++) {
free(rd_extensions[i]);
}
free(rd_extensions);
}

if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) {
PyErr_Format(PyExc_ValueError, "%s", err_msg);
goto err;
goto fail;
}
goto done;

fail:
err_code = rd_kafka_oauthbearer_set_token_failure(h->rk, "OAuth callback raised exception");
if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) {
PyErr_SetString(PyExc_ValueError, "Failed to set token failure");
goto err;
}
PyErr_Clear();
goto done;
err:
CallState_crash(cs);
rd_kafka_yield(h->rk);
Expand Down
99 changes: 67 additions & 32 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,49 +24,41 @@ def test_version():
assert confluent_kafka.version()[0] == confluent_kafka.__version__


# global variable for error_cb call back function
seen_error_cb = False


def test_error_cb():
""" Tests error_cb. """
seen_error_cb = False

def error_cb(error_msg):
global seen_error_cb
nonlocal seen_error_cb
seen_error_cb = True
acceptable_error_codes = (confluent_kafka.KafkaError._TRANSPORT, confluent_kafka.KafkaError._ALL_BROKERS_DOWN)
assert error_msg.code() in acceptable_error_codes

conf = {'bootstrap.servers': 'localhost:65531', # Purposely cause connection refused error
'group.id': 'test',
'socket.timeout.ms': '100',
'session.timeout.ms': 1000, # Avoid close() blocking too long
'error_cb': error_cb
}

kc = confluent_kafka.Consumer(**conf)
kc.subscribe(["test"])
while not seen_error_cb:
kc.poll(timeout=1)
kc.poll(timeout=0.1)

kc.close()


# global variable for stats_cb call back function
seen_stats_cb = False


def test_stats_cb():
""" Tests stats_cb. """
seen_stats_cb = False

def stats_cb(stats_json_str):
global seen_stats_cb
nonlocal seen_stats_cb
seen_stats_cb = True
stats_json = json.loads(stats_json_str)
assert len(stats_json['name']) > 0

conf = {'group.id': 'test',
'socket.timeout.ms': '100',
'session.timeout.ms': 1000, # Avoid close() blocking too long
'statistics.interval.ms': 200,
'stats_cb': stats_cb
Expand All @@ -76,22 +68,20 @@ def stats_cb(stats_json_str):

kc.subscribe(["test"])
while not seen_stats_cb:
kc.poll(timeout=1)
kc.poll(timeout=0.1)
kc.close()


seen_stats_cb_check_no_brokers = False


def test_conf_none():
""" Issue #133
Test that None can be passed for NULL by setting bootstrap.servers
to None. If None would be converted to a string then a broker would
show up in statistics. Verify that it doesnt. """
seen_stats_cb_check_no_brokers = False

def stats_cb_check_no_brokers(stats_json_str):
""" Make sure no brokers are reported in stats """
global seen_stats_cb_check_no_brokers
nonlocal seen_stats_cb_check_no_brokers
stats = json.loads(stats_json_str)
assert len(stats['brokers']) == 0, "expected no brokers in stats: %s" % stats_json_str
seen_stats_cb_check_no_brokers = True
Expand All @@ -101,9 +91,8 @@ def stats_cb_check_no_brokers(stats_json_str):
'stats_cb': stats_cb_check_no_brokers}

p = confluent_kafka.Producer(conf)
p.poll(timeout=1)
p.poll(timeout=0.1)

global seen_stats_cb_check_no_brokers
assert seen_stats_cb_check_no_brokers


Expand All @@ -130,23 +119,19 @@ def test_throttle_event_types():
assert str(throttle_event) == "broker/0 throttled for 10000 ms"


# global variable for oauth_cb call back function
seen_oauth_cb = False


def test_oauth_cb():
""" Tests oauth_cb. """
seen_oauth_cb = False

def oauth_cb(oauth_config):
global seen_oauth_cb
nonlocal seen_oauth_cb
seen_oauth_cb = True
assert oauth_config == 'oauth_cb'
return 'token', time.time() + 300.0

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'socket.timeout.ms': '100',
'session.timeout.ms': 1000, # Avoid close() blocking too long
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
Expand All @@ -155,7 +140,59 @@ def oauth_cb(oauth_config):
kc = confluent_kafka.Consumer(**conf)

while not seen_oauth_cb:
kc.poll(timeout=1)
kc.poll(timeout=0.1)
kc.close()


def test_oauth_cb_principal_sasl_extensions():
""" Tests oauth_cb. """
seen_oauth_cb = False

def oauth_cb(oauth_config):
nonlocal seen_oauth_cb
seen_oauth_cb = True
assert oauth_config == 'oauth_cb'
return 'token', time.time() + 300.0, oauth_config, {"extone": "extoneval", "exttwo": "exttwoval"}

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 100, # Avoid close() blocking too long
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
}

kc = confluent_kafka.Consumer(**conf)

while not seen_oauth_cb:
kc.poll(timeout=0.1)
kc.close()


def test_oauth_cb_failure():
""" Tests oauth_cb. """
oauth_cb_count = 0

def oauth_cb(oauth_config):
nonlocal oauth_cb_count
oauth_cb_count += 1
assert oauth_config == 'oauth_cb'
if oauth_cb_count == 2:
return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"}
raise Exception

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 1000, # Avoid close() blocking too long
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
}

kc = confluent_kafka.Consumer(**conf)

while oauth_cb_count < 2:
kc.poll(timeout=0.1)
kc.close()


Expand Down Expand Up @@ -194,11 +231,9 @@ def test_unordered_dict(init_func):
client.poll(0)


# global variable for on_delivery call back function
seen_delivery_cb = False


def test_topic_config_update():
seen_delivery_cb = False

# *NOTE* default.topic.config has been deprecated.
# This example remains to ensure backward-compatibility until its removal.
confs = [{"message.timeout.ms": 600000, "default.topic.config": {"message.timeout.ms": 1000}},
Expand All @@ -207,7 +242,7 @@ def test_topic_config_update():

def on_delivery(err, msg):
# Since there is no broker, produced messages should time out.
global seen_delivery_cb
nonlocal seen_delivery_cb
seen_delivery_cb = True
assert err.code() == confluent_kafka.KafkaError._MSG_TIMED_OUT

Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy